Deepspeed sharding and load from checkpoint with custom lightning module - setup() not called during checkpoint loading
maxzvyagin opened this issue · 2 comments
❓ Questions and Help
What is your question?
Hi, I'm doing training from scratch using deepspeed, pytorch lightning, and transformers in a multi node setting, and wanted to know how to setup the code to handle loading from a pytorch checkpoint.
Going off of the docs here, I see that the model is intended to be defined in `setup()`. However, this doesn't work when loading from a state dict, since `setup()` is not called during checkpoint loading. What's the right way to structure the code here? Does `enable_transformers_pretrained_deepspeed_sharding` need to be called in `setup()`, or can it be called in the constructor?
This has been my potential workaround of calling it in the constructor, since the call does seem to fail on certain ranks:
```python
def __init__(self, config):
    # irrelevant constructor things here
    try:
        enable_transformers_pretrained_deepspeed_sharding(self)
    except AttributeError:
        pl.utilities.rank_zero.rank_zero_warn(
            "Transformers sharding initialization not enabled..."
        )
    # needed to load from checkpoint
    self.model = AutoModelForCausalLM.from_config(self.base_config)
```
As opposed to:
```python
def setup(self, stage):
    if not hasattr(self, "model"):
        try:
            # sometimes using DDP for inference, so this will fail
            enable_transformers_pretrained_deepspeed_sharding(self)
        except AttributeError:
            pl.utilities.rank_zero.rank_zero_warn(
                "Transformers sharding initialization not enabled - likely not using DeepSpeed..."
            )
        self.model = AutoModelForCausalLM.from_config(self.base_config)
```
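For illustration, here is a minimal sketch of the ordering problem, using plain-Python stand-ins rather than the real Lightning API: `load_from_checkpoint` roughly runs the constructor and then restores the state dict, while `setup()` only runs later under the Trainer, so a model built only in `setup()` is not there when weights are restored.

```python
# Minimal simulation of the checkpoint-loading order (stand-ins only;
# no Lightning/DeepSpeed required).

class LazyModule:
    """Stand-in LightningModule that only builds its model in setup()."""

    def __init__(self):
        self.model = None

    def setup(self, stage):
        # stand-in for AutoModelForCausalLM.from_config(self.base_config)
        self.model = {"weight": 0.0}

    def load_state_dict(self, state):
        if self.model is None:
            raise AttributeError("cannot restore weights: model not built yet")
        self.model.update(state["model"])

# Restoring before setup() fails, which is the problem described above:
restore_failed = False
m = LazyModule()
try:
    m.load_state_dict({"model": {"weight": 1.5}})
except AttributeError:
    restore_failed = True
assert restore_failed

# Building the model up front (as in the constructor workaround) succeeds:
m2 = LazyModule()
m2.setup("fit")  # emulate building the model in __init__ instead
m2.load_state_dict({"model": {"weight": 1.5}})
assert m2.model["weight"] == 1.5
```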
What's your environment?
- OS: Linux
- Packaging: conda/pip
- deepspeed==0.7.3
- pytorch-lightning==1.6.5
- lightning-transformers==0.2.1
Thanks in advance for the help!
Hi @maxzvyagin, what I understand from your question is that you want to perform the `enable_transformers_pretrained_deepspeed_sharding` call on a custom Lightning module. For that, I think one simple and straightforward strategy would be to inherit the `TaskTransformer` class and override the `initialize_model` method, as mentioned here.
Do correct me if I did not understand your question correctly.
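The suggested pattern can be sketched roughly as follows. `TaskTransformer` and the sharding helper are replaced by minimal local stand-ins so the snippet runs without lightning-transformers or DeepSpeed installed; a real implementation would import the actual classes and call the real `enable_transformers_pretrained_deepspeed_sharding` and `AutoModelForCausalLM.from_config` instead.

```python
# Rough sketch: subclass the task class and override initialize_model()
# so sharding is enabled right before the model is constructed.

class TaskTransformer:  # stand-in for lightning-transformers' TaskTransformer
    def initialize_model(self, pretrained_model_name_or_path):
        raise NotImplementedError

def enable_transformers_pretrained_deepspeed_sharding(module):
    # stand-in: the real helper makes transformers build weights
    # directly into DeepSpeed ZeRO stage-3 shards
    module.sharding_enabled = True

class MyCausalLMTask(TaskTransformer):
    def __init__(self, base_config):
        self.base_config = base_config
        self.sharding_enabled = False
        self.model = None

    def initialize_model(self, pretrained_model_name_or_path=None):
        # enable sharded construction before the model is built
        enable_transformers_pretrained_deepspeed_sharding(self)
        # build from config (random init) so a custom PyTorch
        # checkpoint can be restored on top afterwards
        self.model = {"config": self.base_config}  # stand-in for from_config(...)

task = MyCausalLMTask(base_config={"vocab_size": 8})
task.initialize_model()
assert task.sharding_enabled and task.model is not None
```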
Regards,
Akarsh
Hi Akarsh, thanks for checking this out! I guess my question is partly why `enable_transformers_pretrained_deepspeed_sharding(self)` is called in the `setup()` function rather than in the constructor, and whether there's a detrimental effect to calling it in the constructor when loading from a pre-trained checkpoint. Otherwise, we're not able to load from our own PyTorch checkpoint file and continue training with a sharded DeepSpeed approach.