coreylowman/dfdx

Question: How to load checkpoints?

rjzak opened this issue · 1 comment

rjzak commented

Is loading checkpoints from PyTorch or JAX supported? I've only seen examples of saving checkpoints.

This is not supported, and likely won't be, because loading pickled objects in Rust is complex. PyTorch also saves tensors in many complicated ways (e.g. sparse tensors).

Instead, convert PyTorch/JAX pickles into either the .safetensors format (used by Hugging Face) or a .npz file.
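A minimal sketch of that conversion, assuming the checkpoint has already been loaded in Python and flattened into a dict mapping tensor names to NumPy arrays. For a plain PyTorch state dict that might be `{k: v.detach().cpu().numpy() for k, v in torch.load(path, map_location="cpu").items()}`; for JAX, `numpy.asarray` on each leaf of the parameter tree. The `.npz` path uses NumPy's `np.savez`. The `.safetensors` writer is hand-rolled with only the standard library to show the file layout: an 8-byte little-endian header length, a JSON header giving each tensor's dtype, shape, and byte offsets, then the raw little-endian data. In practice the `safetensors` Python package (e.g. `safetensors.numpy.save_file` or `safetensors.torch.save_file`) does this for you. The helper names `save_npz` and `save_safetensors` are hypothetical, and the tensor names will probably need renaming to match the parameter names your dfdx model uses when loading (that mapping is not shown).

```python
import json
import struct

import numpy as np

# (dtype kind, itemsize in bytes) -> safetensors dtype tag.
# This sketch only covers a subset of the dtypes the format supports.
_ST_DTYPES = {
    ("f", 2): "F16",
    ("f", 4): "F32",
    ("f", 8): "F64",
    ("i", 4): "I32",
    ("i", 8): "I64",
    ("u", 1): "U8",
}


def save_npz(tensors, path):
    """Hypothetical helper: write a dict of NumPy arrays to an uncompressed .npz.

    Each dict key becomes an array name inside the archive. NumPy appends
    ".npz" to `path` if it does not already end with it.
    """
    np.savez(path, **tensors)


def save_safetensors(tensors, path):
    """Hypothetical helper: write a dict of NumPy arrays in safetensors layout.

    Layout: u64 little-endian header size N, N bytes of JSON header,
    then every tensor's bytes back to back with no gaps.
    """
    header = {}
    chunks = []
    offset = 0
    for name in sorted(tensors):
        arr = np.asarray(tensors[name])
        key = (arr.dtype.kind, arr.dtype.itemsize)
        if key not in _ST_DTYPES:
            raise TypeError(f"unsupported dtype {arr.dtype} for {name!r}")
        # The format stores data row-major (C order) and little-endian.
        little = np.ascontiguousarray(arr, dtype=arr.dtype.newbyteorder("<"))
        data = little.tobytes()
        header[name] = {
            "dtype": _ST_DTYPES[key],
            "shape": list(arr.shape),
            # Offsets are relative to the start of the data section.
            "data_offsets": [offset, offset + len(data)],
        }
        chunks.append(data)
        offset += len(data)

    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    # Pad the JSON with spaces (allowed by the format) so the data section
    # starts on an 8-byte boundary.
    header_bytes += b" " * (-len(header_bytes) % 8)

    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        for chunk in chunks:
            f.write(chunk)
```

For example, `save_safetensors(state, "model.safetensors")` or `save_npz(state, "model.npz")`, where `state` is the name-to-array dict described above. Writing the safetensors file by hand is only for illustration; the official package also validates names and supports more dtypes.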