problem with get_gradients
soumbane opened this issue · 3 comments
Is anyone else having a problem with get_gradient? I get the following error:
Traceback (most recent call last):
  File "ace_run.py", line 112, in <module>
    main(parse_arguments(sys.argv[1:]))
  File "ace_run.py", line 68, in main
    scores = cd.tcavs(test=False)
  File "/gstore/home/baners20/GA_progression/interpretability/ACE/ace.py", line 667, in tcavs
    gradients = self._return_gradients(tcav_score_images)
  File "/gstore/home/baners20/GA_progression/interpretability/ACE/ace.py", line 622, in _return_gradients
    acts[i:i+1], [class_id], bn).reshape(-1)
TypeError: get_gradient() missing 1 required positional argument: 'example'
For future readers who face the same problem:
The cause is that the recent TCAV version is incompatible with the current ACE implementation.
To fix it, install version 0.2 of tcav (pip install tcav==0.2 instead of pip install tcav), and it should work fine.
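To confirm which tcav version is actually installed before re-running ACE, something like the following sketch may help. It only uses the standard library; installed_version is a hypothetical helper name, and the comparison assumes the package metadata reports the version exactly as "0.2".

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


tcav_version = installed_version("tcav")
if tcav_version is None:
    print("tcav is not installed")
elif tcav_version != "0.2":
    # Assumption: the pinned release reports itself exactly as "0.2".
    print(f"tcav {tcav_version} found; ACE may need tcav==0.2")
else:
    print("tcav 0.2 is installed")
```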
Hi @MichaelDoron,
Yes, I solved the issue. TCAV does not actually need the example positional argument for this call, so passing None for it works.
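A minimal sketch of that workaround, assuming the newer signature is get_gradient(acts, y, bottleneck_name, example) as the TypeError suggests. FakeModel, return_gradients, and the bottleneck name "mixed4c" are hypothetical stand-ins for illustration, not the real tcav or ACE code; the only point is the extra None passed as the last argument.

```python
import numpy as np


class FakeModel:
    """Hypothetical stand-in for a TCAV model wrapper (not the real tcav class)."""

    def get_gradient(self, acts, y, bottleneck_name, example):
        # Mimics a four-argument signature; `example` is ignored here,
        # which is the assumption behind passing None for it.
        return np.ones_like(acts)


def return_gradients(model, acts, class_id, bottlenecks):
    """Collect flattened per-example gradients for each bottleneck."""
    gradients = {}
    for bn in bottlenecks:
        per_example = []
        for i in range(len(acts)):
            # Pass None for `example` to satisfy the extra positional argument.
            grad = model.get_gradient(acts[i:i + 1], [class_id], bn, None)
            per_example.append(grad.reshape(-1))
        gradients[bn] = np.array(per_example)
    return gradients


acts = np.zeros((3, 2, 2))  # three dummy activations of shape (2, 2)
grads = return_gradients(FakeModel(), acts, class_id=0, bottlenecks=["mixed4c"])
print(grads["mixed4c"].shape)
```

In ace.py the equivalent change would be adding None as the fourth argument to the get_gradient call inside _return_gradients (around line 622 in the traceback above).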
Thanks,
Soumyanil Banerjee