Can't save checkpoint to file descriptor.
GLivshits commented
System Info
safetensors==0.4.5
torch==2.4.0+cu124
Information
- The official example scripts
- My own modified scripts
Reproduction
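import io

import torch
from safetensors.torch import save_file as safetensors_save

# Build a dictionary of random tensors to serialize.
data = {}
tensor_shape = (64, 128, 256)
num_tensors = 16
for i in range(num_tensors):
    data[str(i)] = torch.randn(*tensor_shape)

# Try to write the checkpoint to an in-memory buffer instead of a path.
stream = io.BytesIO()
out = safetensors_save(data, stream)  # raises the TypeError shown below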
Error:
TypeError: argument 'filename': expected str, bytes or os.PathLike object, not BytesIO
Expected behavior
I expected the serialized checkpoint bytes to be written to the buffer.
Writing to a file object would be convenient for:
- Uploading the checkpoint directly from a file object.
- Using a custom file object, for example one that represents a database file.
Currently there are only two options: serialize the checkpoint and get the bytes back (which is inconvenient when the checkpoint is large and is materialized only on the master process), or serialize the checkpoint to a specified filename.
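Until something like this is supported, one possible workaround is to write to a temporary file and copy it into the file object in chunks. This is only a sketch under assumptions, not part of safetensors: `save_to_fileobj` and its `save_fn` parameter are hypothetical names, and `save_fn` is any callable that writes the checkpoint to the path it is given. It avoids holding the full serialized bytes in memory, at the cost of extra disk I/O.

```python
import os
import shutil
import tempfile
from typing import BinaryIO, Callable


def save_to_fileobj(
    save_fn: Callable[[str], None],  # hypothetical: writes the checkpoint to a path
    fileobj: BinaryIO,
    chunk_size: int = 16 * 1024 * 1024,
) -> None:
    """Write a checkpoint into a binary file object via a temporary file."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "checkpoint.safetensors")
        # Let the path-based saver write to disk first.
        save_fn(path)
        # Stream the file into the target object chunk by chunk,
        # so the whole payload is never held in memory at once.
        with open(path, "rb") as src:
            shutil.copyfileobj(src, fileobj, length=chunk_size)
```

With the reproduction above, this could be called as `save_to_fileobj(lambda path: safetensors_save(data, path), stream)`. The target only needs a `write` method that accepts bytes, so an upload stream or a custom database-backed file object should work as well.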