发新帖

报错 TypeError: slice indices must be integers or None or have an __index__ method

[复制链接]
90 0

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
您好,有幸读到您的手册,我有一点小小的疑问,求解答!
我的运行环境:
PyCharm 2019.2.3
Python 3.7
TensorFlow 2.0
代码如下:
  1. import tensorflow as tf
  2. import numpy as np
  3. class DataLoader():
  4.     def __init__(self):
  5.         path = tf.keras.utils.get_file('nietzsche.txt',
  6.             origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
  7.         with open(path, encoding='utf-8') as f:
  8.             self.raw_text = f.read().lower()
  9.         self.chars = sorted(list(set(self.raw_text)))
  10.         self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
  11.         self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
  12.         self.text = [self.char_indices[c] for c in self.raw_text]

  13.     def get_batch(self, seq_length, batch_size):
  14.         seq = []
  15.         next_char = []
  16.         for i in range(batch_size):
  17.             index = np.random.randint(0, len(self.text) - seq_length)
  18.             seq.append(self.text[index:index+seq_length])
  19.             next_char.append(self.text[index+seq_length])
  20.         return np.array(seq), np.array(next_char)       # [batch_size, seq_length], [num_batch]


  21. class RNN(tf.keras.Model):
  22.     def __init__(self, num_chars, batch_size, seq_length):
  23.         super().__init__()
  24.         self.num_chars = num_chars
  25.         self.seq_length = seq_length
  26.         self.batch_size = batch_size
  27.         self.cell = tf.keras.layers.LSTMCell(units=256)
  28.         self.dense = tf.keras.layers.Dense(units=self.num_chars)

  29.     def call(self, inputs, from_logits=False):
  30.         inputs = tf.one_hot(inputs, depth=self.num_chars)       # [batch_size, seq_length, num_chars]
  31.         state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)
  32.         for t in range(self.seq_length):
  33.             output, state = self.cell(inputs[:, t, :], state)
  34.         logits = self.dense(output)
  35.         if from_logits:
  36.             return logits
  37.         else:
  38.             return tf.nn.softmax(logits)
  39. num_batches = 10
  40. seq_length = 40
  41. batch_size = 50
  42. learning_rate = 1e-3
  43. data_loader = DataLoader()
  44. model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
  45. optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  46. for batch_index in range(num_batches):
  47.     X, y = data_loader.get_batch(seq_length, batch_size)
  48.     with tf.GradientTape() as tape:
  49.         y_pred = model(X)
  50.         loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
  51.         loss = tf.reduce_mean(loss)
  52.         print("batch %d: loss %f" % (batch_index, loss.numpy()))
  53.     grads = tape.gradient(loss, model.variables)
  54.     optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

  55. def predict(self, inputs, temperature=1.):
  56.     batch_size, _ = tf.shape(inputs)
  57.     logits = self(inputs, from_logits=True)
  58.     print(logits)
  59.     print(temperature)
  60.     prob = tf.nn.softmax(logits / temperature).numpy()
  61.     return np.array([np.random.choice(self.num_chars, p=prob[i, :])
  62.                      for i in range(batch_size.numpy())])

  63. X_, _ = data_loader.get_batch(seq_length, 1)

  64. for diversity in [0.2, 0.5, 1.0, 1.2]:
  65.     X = X_
  66.     print("diversity %f:" % diversity)
  67.     for t in range(400):
  68.         y_pred = model.predict(X, diversity)
  69.         print(data_loader.indices_char[y_pred[0]], end='', flush=True)
  70.         X = np.concatenate([X[:, 1:], np.expand_dims(y_pred, axis=1)], axis=-1)
  71.     print("\n")
复制代码
报错:
  1. runfile('F:/pyth/pj3/study3.py', wdir='F:/pyth/pj3')
  2. batch 0: loss 4.044459
  3. batch 1: loss 4.025946
  4. batch 2: loss 4.001545
  5. batch 3: loss 3.980800
  6. batch 4: loss 3.945248
  7. batch 5: loss 3.867068
  8. batch 6: loss 3.684950
  9. batch 7: loss 3.236459
  10. batch 8: loss 3.574704
  11. batch 9: loss 3.551273
  12. diversity 0.200000:
  13. Traceback (most recent call last):
  14.   File "<input>", line 1, in <module>
  15.   File "D:\Program Files\JetBrains\PyCharm 2019.2.3\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
  16.     pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  17.   File "D:\Program Files\JetBrains\PyCharm 2019.2.3\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
  18.     exec(compile(contents+"\n", file, 'exec'), glob, loc)
  19.   File "F:/pyth/pj3/study3.py", line 93, in <module>
  20.     y_pred = model.predict(X, diversity)
  21.   File "D:\ProgramData\Anaconda3\envs\kingtf2\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 909, in predict
  22.     use_multiprocessing=use_multiprocessing)
  23.   File "D:\ProgramData\Anaconda3\envs\kingtf2\lib\site-packages\tensorflow_core\python\keras\engine\training_arrays.py", line 722, in predict
  24.     callbacks=callbacks)
  25.   File "D:\ProgramData\Anaconda3\envs\kingtf2\lib\site-packages\tensorflow_core\python\keras\engine\training_arrays.py", line 362, in model_iteration
  26.     batch_ids = index_array[batch_start:batch_end]
  27. TypeError: slice indices must be integers or None or have an __index__ method
复制代码


您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

主题

帖子

4

积分
快速回复 返回顶部 返回列表