bitsandbytes-foundation/bitsandbytes

Paged optimizer: resuming from checkpoint fails with AttributeError: 'int' object has no attribute 'cpu'

shivam15s opened this issue · 1 comment

System Info

Platform: Linux-5.15.148.2-2.cm2-x86_64-with-glibc2.35
Python version: 3.10.14
Bitsandbytes version: 0.43.1
Safetensors version: 0.4.5
Accelerate version: 0.34.2
Accelerate config: not found
PyTorch version (GPU?): 2.4.0+cu124 (True)
Tensorflow version (GPU?): 2.16.2 (True)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Using distributed or parallel set-up in script?: yes
Using GPU in script?: yes
GPU type: NVIDIA A100-SXM4-80GB

Reproduction

import torch
import transformers
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("openai/gsm8k", "main", split="train")
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# first run: train for 4 steps and save a checkpoint
training_args = SFTConfig(
    dataset_text_field="question",
    per_device_train_batch_size=1,
    eval_steps=4,
    output_dir="tmp1",
    save_steps=4,
    max_steps=4,
    bf16=True,
    fsdp="full_shard auto_wrap",
    optim="paged_adamw_32bit"
)

model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

trainer = SFTTrainer(model=model, train_dataset=dataset, args=training_args)
trainer.train()

# resume from checkpoint for another 4 steps
training_args = SFTConfig(
    dataset_text_field="question",
    per_device_train_batch_size=1,
    eval_steps=4,
    output_dir="tmp1",
    save_steps=4,
    max_steps=8,
    bf16=True,
    fsdp="full_shard auto_wrap",
    optim="paged_adamw_32bit"
)
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
trainer = SFTTrainer(model=model, train_dataset=dataset, args=training_args)
trainer.train(resume_from_checkpoint=True)

Expected behavior

Training should resume from the checkpoint without errors. Instead, the second trainer.train(resume_from_checkpoint=True) call fails with: AttributeError: 'int' object has no attribute 'cpu'
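For context, here is a minimal sketch of what I suspect is happening (this is an assumption on my part, not a confirmed trace): the bitsandbytes paged optimizer keeps the per-parameter "step" counter as a plain Python int, while the FSDP checkpoint save/load path appears to treat every optimizer state value as a tensor and call .cpu() on it.

import torch

# Sketch of the suspected failure mode (assumption, not a confirmed trace):
# "step" is a plain Python int, the other state entries are tensors.
param_state = {"step": 4, "state1": torch.zeros(8), "state2": torch.zeros(8)}

for name, value in param_state.items():
    value.cpu()  # fails on "step": AttributeError: 'int' object has no attribute 'cpu'

If that is indeed the cause, casting the int "step" to a 0-dim tensor (or special-casing non-tensor state values) before the optimizer state is gathered/restored would presumably avoid the crash, but I have not verified where in the transformers/accelerate/bitsandbytes stack such a fix belongs.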

@matthewdouglas Would really appreciate any tips here. TIA!