BlearnerForSummarization with t5-base errors out: 'function' object has no attribute 'setup'
SiddharthPant commented
I tried running the example code from the summarization page of the docs with the 't5-base' model, but it errors out. I have tried both the latest release and master of blurr and fastcore, but the issue persists. Here's the sample and the error it produces:
```python
learn = BlearnerForSummarization.from_data(
    cnndm_df,
    "t5-base",
    text_attr="article",
    summary_attr="highlights",
    max_length=256,
    max_target_length=130,
    dblock_splitter=RandomSplitter(),
    dl_kwargs={"bs": 2},
).to_fp16()
```
The error I get:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In [7], line 1
----> 1 learn = BlearnerForSummarization.from_data(
      2     cnndm_df,
      3     "t5-base",
      4     text_attr="article",
      5     summary_attr="highlights",
      6     max_length=256,
      7     max_target_length=130,
      8     dblock_splitter=RandomSplitter(),
      9     dl_kwargs={"bs": 2},
     10 ).to_fp16()

File ~/mambaforge/envs/aranya/lib/python3.8/site-packages/blurr/text/modeling/seq2seq/summarization.py:146, in BlearnerForSummarization.from_data(cls, data, pretrained_model_name_or_path, text_attr, summary_attr, max_length, max_target_length, dblock_splitter, hf_tok_kwargs, text_gen_kwargs, dl_kwargs, learner_kwargs)
    143 get_y = ItemGetter(summary_attr)
    145 if hf_arch == "t5":
--> 146     get_x.add(cls._add_t5_prefix)
    148 # define our DataBlock and DataLoaders
    149 batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
    150     hf_arch,
    151     hf_config,
   (...)
    156     text_gen_kwargs=text_gen_kwargs,
    157 )

File ~/mambaforge/envs/aranya/lib/python3.8/site-packages/fastcore/transform.py:204, in Pipeline.add(self, ts, items, train_setup)
    202 def add(self,ts, items=None, train_setup=False):
    203     if not is_listy(ts): ts=[ts]
--> 204     for t in ts: t.setup(items, train_setup)
    205     self.fs+=ts
    206     self.fs = self.fs.sorted(key='order')

AttributeError: 'function' object has no attribute 'setup'
```
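The last frame points at the cause: fastcore's Pipeline.add calls t.setup(items, train_setup) on every object it receives, and blurr passes cls._add_t5_prefix in as a plain function rather than a Transform, so there is no setup attribute to call. The sketch below is a simplified, stdlib-only model of that behaviour, not fastcore or blurr code: MiniTransform, MiniPipeline and add_t5_prefix are hypothetical stand-ins, and the "summarize: " prefix is an assumption about what blurr's helper adds. It reproduces the same AttributeError and shows that wrapping the function in a transform-like object before calling add avoids it, which suggests wrapping _add_t5_prefix in a fastcore Transform as one possible fix.

```python
# Simplified model (hypothetical names, not fastcore's real classes) of the
# failure in the traceback: Pipeline.add calls t.setup(...) on each object it
# is given, and a bare Python function has no .setup attribute.


class MiniTransform:
    """Hypothetical stand-in for a fastcore Transform: a callable with a setup hook."""

    def __init__(self, fn):
        self.fn = fn
        self.order = 0  # Pipeline.add sorts its functions by this attribute

    def setup(self, items=None, train_setup=False):
        # A real Transform may prepare state from `items`; nothing to do here.
        pass

    def __call__(self, x):
        return self.fn(x)


class MiniPipeline:
    """Hypothetical stand-in for fastcore's Pipeline, reduced to add() and __call__."""

    def __init__(self, funcs):
        self.fs = list(funcs)

    def add(self, ts, items=None, train_setup=False):
        if not isinstance(ts, (list, tuple)):
            ts = [ts]
        for t in ts:
            t.setup(items, train_setup)  # same call as fastcore/transform.py line 204
        self.fs += ts
        self.fs.sort(key=lambda t: t.order)

    def __call__(self, x):
        for f in self.fs:
            x = f(x)
        return x


def add_t5_prefix(text):
    # Hypothetical equivalent of blurr's _add_t5_prefix; the prefix is an assumption.
    return f"summarize: {text}"


get_x = MiniPipeline([MiniTransform(lambda row: row["article"])])

# Adding the bare function reproduces the reported error; nothing is appended
# because setup() fails before self.fs is extended.
error_message = None
try:
    get_x.add(add_t5_prefix)
except AttributeError as e:
    error_message = str(e)
print(error_message)  # 'function' object has no attribute 'setup'

# Wrapping the function in a transform-like object first lets add() succeed.
get_x.add(MiniTransform(add_t5_prefix))
print(get_x({"article": "Some news text."}))  # summarize: Some news text.
```

In real fastcore code the analogous change would be passing a Transform (for example Transform(cls._add_t5_prefix)) to get_x.add, though that is a suggestion rather than a confirmed upstream fix.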
SiddharthPant commented
I was able to work around this by using the mid-level API instead of the high-level BlearnerForSummarization.from_data function above, but I wanted to flag the issue for the devs.