ImportError: cannot import name 'context_models' from 'src.config' when running transformers_predictor.py
zrc0123 opened this issue · 2 comments
Hello,
I got ImportError: cannot import name 'context_models' from 'src.config' when running transformers_predictor.py. I also couldn't find a definition of context_models anywhere in the src folder. Was it renamed or removed?
Thank you!
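One quick way to confirm whether a name is really absent from a module (rather than, say, shadowed by a stale install) is to import the module and check its attributes. The following is only a minimal sketch: the helper name missing_names is hypothetical, and the standard-library json module stands in for src.config so the snippet runs outside the repository.

```python
import importlib


def missing_names(module_name, names):
    """Return the subset of `names` that the module does not define.

    `missing_names` is a hypothetical helper, not part of the repository.
    """
    module = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]


# `json` is a stand-in here; inside the repository you would pass "src.config".
print(missing_names("json", ["loads", "context_models"]))
```

If context_models shows up in the returned list for src.config, the import in transformers_predictor.py refers to a name that the module no longer provides.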
Hello,
Now I can run the file with these modifications:
- from transformers import RobertaTokenizerFast
- self.tokenizer = RobertaTokenizerFast.from_pretrained(self.conf.embedder_type, add_prefix_space=True)
- batch_max_scores, batch_max_ids = self.model(subword_input_ids=batch.input_ids.to(self.conf.device),
      word_seq_lens=batch.word_seq_len.to(self.conf.device),
      orig_to_tok_index=batch.orig_to_tok_index.to(self.conf.device),
      attention_mask=batch.attention_mask.to(self.conf.device),
      is_train=False)
Here are my modifications:
- Establish tokenizer from AutoTokenizer:
      from transformers import AutoTokenizer
      self.tokenizer = AutoTokenizer.from_pretrained(self.conf.embedder_type, add_prefix_space=True, use_fast=True)
- Modify the parameters when calling self.model:
      batch_max_scores, batch_max_ids = self.model(subword_input_ids=batch.input_ids.to(self.conf.device),
          word_seq_lens=batch.word_seq_len.to(self.conf.device),
          orig_to_tok_index=batch.orig_to_tok_index.to(self.conf.device),
          attention_mask=batch.attention_mask.to(self.conf.device),
          is_train=False)
Then it works well.
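For anyone applying the same change, the keyword-argument call can be factored into a small helper so the field-to-parameter mapping lives in one place. This is a sketch under assumptions, not the repository's code: build_model_inputs and FakeTensor are hypothetical names, and FakeTensor only mimics the .to(device) method of a torch.Tensor so the snippet runs without PyTorch installed. The keyword names are the ones used in the calls quoted in this thread.

```python
from types import SimpleNamespace


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; records only the target device."""

    def __init__(self, name, device="cpu"):
        self.name = name
        self.device = device

    def to(self, device):
        # Like torch.Tensor.to, return a new object placed on `device`.
        return FakeTensor(self.name, device)


def build_model_inputs(batch, device):
    """Map batch fields to the keyword arguments used in the model call."""
    return {
        "subword_input_ids": batch.input_ids.to(device),
        "word_seq_lens": batch.word_seq_len.to(device),
        "orig_to_tok_index": batch.orig_to_tok_index.to(device),
        "attention_mask": batch.attention_mask.to(device),
        "is_train": False,  # inference only, as in the predictor
    }


# Demo with placeholder data; a real batch would hold torch tensors.
batch = SimpleNamespace(
    input_ids=FakeTensor("input_ids"),
    word_seq_len=FakeTensor("word_seq_len"),
    orig_to_tok_index=FakeTensor("orig_to_tok_index"),
    attention_mask=FakeTensor("attention_mask"),
)
inputs = build_model_inputs(batch, "cuda:0")
print(sorted(inputs))
```

In the predictor, the call would then read batch_max_scores, batch_max_ids = self.model(**build_model_inputs(batch, self.conf.device)), assuming the model's forward signature accepts exactly these keywords.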