airaria/TextBrewer

random_token_example error

tanyaroosta opened this issue · 5 comments

Hi,
I pulled the latest version of the repo and compiled the package. When I run random_token_example/distill.py, I get the error shown below. I am not sure how to fix it: distill.py does not appear to take any user input, so I expected it to run without problems.


AttributeError Traceback (most recent call last)
in
29 with distiller:
30 distiller.train(optimizer, dataloader, num_epochs=num_epochs,
---> 31 scheduler_class=scheduler_class, scheduler_args=scheduler_args, callback=callback_fun)

~/opt/anaconda3/lib/python3.7/site-packages/textbrewer/distiller_basic.py in train(self, optimizer, dataloader, num_epochs, scheduler_class, scheduler_args, scheduler, max_grad_norm, num_steps, callback, batch_postprocessor, **args)
281 self.train_with_num_steps(optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_steps, callback, batch_postprocessor, **args)
282 else:
--> 283 self.train_with_num_epochs(optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_epochs, callback, batch_postprocessor, **args)
284
285

~/opt/anaconda3/lib/python3.7/site-packages/textbrewer/distiller_basic.py in train_with_num_epochs(self, optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_epochs, callback, batch_postprocessor, **args)
244 if (global_step%train_steps_per_epoch in checkpoints)
245 and ((current_epoch+1)%self.t_config.ckpt_epoch_frequency==0 or current_epoch+1==num_epochs):
--> 246 self.save_and_callback(global_step, step, current_epoch, callback)
247
248 logger.info(f"Epoch {current_epoch+1} finished")

~/opt/anaconda3/lib/python3.7/site-packages/textbrewer/distiller_general.py in save_and_callback(self, global_step, step, epoch, callback)
63 self.model_T._forward_hooks = OrderedDict()
64
---> 65 super(GeneralDistiller, self).save_and_callback(global_step, step, epoch, callback)
66
67 if self.has_custom_matches:

~/opt/anaconda3/lib/python3.7/site-packages/textbrewer/distiller_basic.py in save_and_callback(self, global_step, step, epoch, callback)
36 if callback is not None:
37 logger.info("Running callback function...")
---> 38 callback(model=self.model_S, step=global_step)
39 self.model_S.train()
40

in predict(model, eval_dataset, step, device)
49 with torch.no_grad():
50 logits, _ = model(input_ids=input_ids, attention_mask=attention_mask)
---> 51 cpu_logits = logits.detach().cpu()
52 for i in range(len(cpu_logits)):
53 pred_logits.append(cpu_logits[i].numpy())

AttributeError: 'str' object has no attribute 'detach'
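
A possible reading of this traceback, offered as an assumption rather than a diagnosis confirmed anywhere in this thread: the model call on line 50 may be returning a dict-like output object instead of a plain tuple. Unpacking a dict-like object with `logits, _ = ...` iterates over its keys, so `logits` ends up bound to a string, which would explain `'str' object has no attribute 'detach'`. The sketch below reproduces that behaviour with a plain `OrderedDict` standing in for the model output. The helper name `extract_logits` and the sample values are hypothetical and are not part of TextBrewer or of the example script.

```python
from collections import OrderedDict
from collections.abc import Mapping


def extract_logits(outputs):
    """Return the logits from either a dict-style or a tuple-style output.

    Hypothetical helper for illustration; assumes dict-style outputs
    store the logits under the key "logits" and tuple-style outputs
    store them at index 0.
    """
    if isinstance(outputs, Mapping):
        return outputs["logits"]
    return outputs[0]


# Stand-ins for the two output styles (made-up values, no real model involved).
dict_style = OrderedDict(logits=[0.1, 0.9], hidden_states=[[0.0], [1.0]])
tuple_style = ([0.1, 0.9], [[0.0], [1.0]])

# Tuple unpacking on a dict-like object iterates over its KEYS,
# so `unpacked` is a str, not the logits.
unpacked, _ = dict_style
print(type(unpacked).__name__)

# Looking the value up explicitly works for both output styles.
print(extract_logits(dict_style))
print(extract_logits(tuple_style))
```

If this is indeed the cause, replacing the tuple unpacking in `predict` with an explicit lookup of this kind would avoid the error. For a Hugging Face Transformers model whose forward call accepts `return_dict`, passing `return_dict=False` is another option, assuming the installed version supports that argument.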

stale commented

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale commented

Closing the issue, since no updates were observed. Feel free to re-open if you need any further assistance.