bojone/bert4keras

运行example/basic_language_model_gpt2_ml.py生成时报错ValueError: Error when checking model input

nameless0704 opened this issue · 1 comments

提问时请尽可能提供如下信息:

基本信息

  • 你使用的操作系统: Win10
  • 你使用的Python版本: 3.7.9
  • 你使用的Tensorflow版本: 1.14.0
  • 你使用的Keras版本: 2.3.1
  • 你使用的bert4keras版本: 0.11.4
  • 你使用纯keras还是tf.keras: 纯keras
  • 你加载的预训练模型: roberta_zh_L-6-H-768_A-12,来自https://github.com/brightmart/roberta_zh

核心代码

#使用basic_language_model_gpt2_ml.py原文,仅model的model参数改为‘roberta’
class ArticleCompletion(AutoRegressiveDecoder):
    """基于随机采样的文章续写
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids = np.concatenate([inputs[0], output_ids], 1)
        return self.last_token(model).predict(token_ids)

    def generate(self, text, n=1, topp=0.95):
        token_ids, _ = tokenizer.encode(text)
        results = self.random_sample([token_ids], n, topp=topp)  # 基于随机采样
        return [text + tokenizer.decode(ids) for ids in results]

article_completion = ArticleCompletion(
    start_id=None,
    end_id=511,  # 511是中文句号
    maxlen=256,
    minlen=128
)
print(article_completion.generate(u'今天天气不错'))

输出信息

ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays: [array([[ 791, 1921, 1921, 3698,  679, 7231]])]

自我尝试

看到了 Issue #446 里面写tf2.0有问题,但是我降到了1.15或者1.14试了都还是报错了,所以求救,谢谢。

报错的原因就是roberta不能用来替换gpt,所以【仅model的model参数改为‘roberta’】就是错误原因所在。