MadryLab/journey-TRAK

unexpected keyword argument 'grad_wrt' when using TRAKer


I'm following the MSCOCO demo script in the examples folder and got the following error:

traker = TRAKer(
  File "/home/user_name/.conda/envs/sd38/lib/python3.8/site-packages/trak/traker.py", line 177, in __init__
    self.gradient_computer = gradient_computer(
TypeError: __init__() got an unexpected keyword argument 'grad_wrt'

I haven't changed the initialization from the demo:

traker = TRAKer(
    model=model,
    task=task,
    gradient_computer=DiffusionGradientComputer,
    proj_dim=2048,
    train_set_size=len(loader_train.dataset),
    save_dir='./MSCOCO_trak_results',
    device='cuda'
)

I'm using traker 0.3.2. Can you help me locate the bug?
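
The traceback shows TRAKer passing grad_wrt into the gradient computer's constructor, so one way to narrow this down is to check whether a given class's __init__ accepts that keyword at all. Below is a minimal sketch using only the standard library. OldStyleComputer and NewStyleComputer are hypothetical stand-ins I made up for illustration; they are not the real trak classes, and their parameter lists are assumptions.

```python
import inspect


def accepts_kwarg(cls, name):
    """Return True if cls.__init__ can be called with the keyword `name`."""
    params = inspect.signature(cls.__init__).parameters
    if name in params and params[name].kind is not inspect.Parameter.POSITIONAL_ONLY:
        return True
    # A **kwargs catch-all in the signature would also accept the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for illustration only (NOT the real trak classes).
class OldStyleComputer:
    def __init__(self, model, task, grad_dim):
        pass


class NewStyleComputer:
    def __init__(self, model, task, grad_dim, grad_wrt=None):
        pass


print(accepts_kwarg(OldStyleComputer, "grad_wrt"))  # False
print(accepts_kwarg(NewStyleComputer, "grad_wrt"))  # True
```

In the real environment, the same helper could be called as accepts_kwarg(DiffusionGradientComputer, "grad_wrt") after importing the class the way the demo script does. A False result would suggest the class was written against a constructor signature that does not include grad_wrt.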

###############################################################

I also got the following warning. I'm including it here, although it seems unrelated to the error.

The configuration file of this scheduler: DDPMScheduler {
  "_class_name": "DDPMScheduler",
  "_diffusers_version": "0.15.1",
  "beta_end": 0.02,
  "beta_schedule": "linear",
  "beta_start": 0.0001,
  "clip_sample": true,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "sample_max_value": 1.0,
  "thresholding": false,
  "trained_betas": null,
  "variance_type": "fixed_small"
}

 has not set the configuration `clip_sample`. `clip_sample` should be set to False in the configuration file. Please make sure to update the config accordingly as not setting `clip_sample` in the config might lead to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file

  deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)

Thanks!

Looking back after several months: pinning traker==0.2.2 and diffusers==0.15.1 makes it work, although according to the setup.py, traker 0.3.2 should also work.
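
For anyone comparing their environment against those pins, here is a small sketch that prints the installed versions using only the standard library (importlib.metadata, available from Python 3.8). The helper name installed_version is my own; the distribution names are the ones used in the pins above.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string for a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Print whatever is installed in the current environment (None if missing).
for name in ("traker", "diffusers"):
    print(name, installed_version(name))
```

If the printed versions differ from traker 0.2.2 and diffusers 0.15.1, that mismatch is a reasonable first thing to rule out before digging further.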