[RFC] Allow Hydra to live in the base class
SeanNaren opened this issue · 3 comments
RFC
Proposal
Marry the project with Hydra: allow Hydra to live in a base class, TaskTransformer (or something similarly named). Remove the complications with the instantiator and make it clear in the code that we're using hydra.utils.instantiate.
Motivation
While working on #53 I ran into many API issues and a lot of confusion. If I'm hitting these problems after developing the code myself, a normal user would be even more confused.
It isn't clear what base class to use.
I want to use Hydra conf as it provides niceties like instantiating my optimizer/scheduler, but at the same time our current base class doesn't provide any of that:
class TaskTransformer(LitTransformer):
    """
    Base class for task specific transformers
    """

    def setup(self, stage: str):
        self.configure_metrics(stage)

    def configure_metrics(self, stage: str) -> Optional[Any]:
        """
        Override to configure metrics for train/validation/test.
        This is called on fit start to have access to the data module,
        and initialize any data specific metrics.
        """
        pass
My proposal is to move configure_optimizers here, so instantiation happens in the base class. If a user doesn't want to use instantiation, fine: they can override configure_optimizers. This leads onto the next point.
    def __init__(self, optimizer_cfg: Any, scheduler_cfg: Any):
        super().__init__()
        self.optimizer_cfg = optimizer_cfg
        self.scheduler_cfg = scheduler_cfg

    def configure_optimizers(self) -> Dict:
        self.optimizer = self.optimizer(self.optimizer_cfg, self.model)
        # prepare_warmup needs the datamodule to be available when `self.num_training_steps`
        # is called that is why this is done here and not in the __init__
        self.prepare_warmup(self.scheduler_cfg)
        self.scheduler = self.scheduler(self.scheduler_cfg, self.optimizer)
        return super().configure_optimizers()

    def prepare_warmup(self, cfg: SchedulerConfig):
        if cfg.num_training_steps < 0:
            # less than 0 specifies to infer number of training steps
            cfg.num_training_steps = self.num_training_steps
            log.info(f"Inferring number of training steps, set to {cfg.num_training_steps}")
        if isinstance(cfg.num_warmup_steps, float):
            # Convert float values to percentage of training steps to use as warmup
            cfg.num_warmup_steps *= cfg.num_training_steps
            log.info(f"Inferring number of warmup steps from ratio, set to {cfg.num_warmup_steps}")
    def optimizer(self, cfg: Any, model: torch.nn.Module) -> torch.optim.Optimizer:
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": cfg.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return hydra.utils.instantiate(cfg, grouped_parameters)

    def scheduler(self, cfg: Any, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        return hydra.utils.instantiate(cfg, optimizer=optimizer)
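To make the warmup handling in prepare_warmup easier to reason about, here is a minimal standalone sketch of the same two rules, using only the standard library. SchedulerConfigSketch and resolve_warmup are hypothetical names made up for illustration, not part of the project, and truncating the resolved warmup to an int is an assumption of this sketch (the snippet above multiplies in place and keeps a float).

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class SchedulerConfigSketch:
    # Hypothetical stand-in for the project's SchedulerConfig.
    num_training_steps: int = -1  # negative means "infer from the trainer"
    num_warmup_steps: Union[int, float] = 0.25  # a float is read as a ratio


def resolve_warmup(cfg: SchedulerConfigSketch, inferred_training_steps: int) -> SchedulerConfigSketch:
    """Fill in the step counts the way prepare_warmup does."""
    if cfg.num_training_steps < 0:
        # A negative value asks for the number of training steps to be inferred.
        cfg.num_training_steps = inferred_training_steps
    if isinstance(cfg.num_warmup_steps, float):
        # A float is a fraction of the training steps.
        # Assumption of this sketch: truncate to an int.
        cfg.num_warmup_steps = int(cfg.num_warmup_steps * cfg.num_training_steps)
    return cfg


inferred = resolve_warmup(SchedulerConfigSketch(num_training_steps=-1, num_warmup_steps=0.25), 1000)
explicit = resolve_warmup(SchedulerConfigSketch(num_training_steps=500, num_warmup_steps=50), 1000)
print(inferred.num_training_steps, inferred.num_warmup_steps)
print(explicit.num_training_steps, explicit.num_warmup_steps)
```

Explicit integer values pass through untouched, so a user who already knows their schedule never triggers the inference path.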
When using Hydra, what should my class look like?
Unless I look at the HF base class, I have no idea how the instantiator works. Once I look at the HF base class it makes sense:
class HFTransformer(HydraTaskTransformer):
    """
    Base class for task specific transformers, wrapping pre-trained language models for downstream tasks.
    The API is built on top of AutoModel and AutoConfig, provided by HuggingFace.
    see: https://huggingface.co/transformers/model_doc/auto.html
    """

    def __init__(
        self,
        downstream_model_type: str,
        backbone: HFBackboneConfig,
        optimizer: OptimizerConfig,
        scheduler: HFSchedulerConfig,
        **config_data_args,
    ):
Hydra passes these configs, which are defined in the task conf. But the fact that there is no class signature to inherit from made this extremely difficult to parse. So my proposal here is to be opinionated and set this class signature in the TaskTransformer.
Aside from an additional scheduler config and removing the instantiator functions that are now hard-coded into the module, these were the changes required to get DALLE in (from the model perspective; the data API needs some changes, but that's less controversial).
Notes
Let's say a user doesn't want to pass in a backbone config. What should they do? They should be able to just omit it from their class signature. I think the best way to support this is to have the class signature take an optional backbone config (though in most cases you will pass one).
What if I don't want my base class to take all these configs? Use the LitTransformer base class.
Makes sense to me!
All instantiation should be done in the base class to simplify people's lives. They will still have the ability to override it if needed.
About marrying with Hydra, I am 100% for it. I think it is a chance to show how to properly use both together.
Best,
T.C
My main concern about marrying Hydra is that we might also be marrying our children to it.
After all, regardless of Hydra's potential, we want to provide transformers utilities at every level. From zero to hero. From the user running a train.py command in the README to the Lightning (power) user who wants to import and subclass one of our transformer classes.
What about Flash? Say they provide their NLP tasks by importing ours and maybe polishing them. They'll have to install Hydra and pass DictConfigs around!
What if a user chose plain argparse? Do we want to force them to provide a conversion from argparse.Namespace to DictConfig?
These are just some examples. It's hard to see the big picture because we are in a very early stage, but this decision will have a large impact down the line (for better or worse).
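To give a rough sense of what that argparse conversion would involve, here is a minimal standard-library sketch that folds dotted argument names into the nested structure a Hydra-style config expects. The helper name namespace_to_nested_dict and the argument names are hypothetical; the resulting plain dict is the kind of input OmegaConf.create accepts to build a DictConfig, but that last step is left out so the snippet has no third-party dependencies.

```python
import argparse
from typing import Any, Dict


def namespace_to_nested_dict(args: argparse.Namespace) -> Dict[str, Any]:
    """Turn flat dotted names such as ``optimizer.lr`` into nested dictionaries."""
    nested: Dict[str, Any] = {}
    for dotted_name, value in vars(args).items():
        *parents, leaf = dotted_name.split(".")
        node = nested
        for key in parents:
            # Walk down, creating intermediate dictionaries as needed.
            node = node.setdefault(key, {})
        node[leaf] = value
    return nested


parser = argparse.ArgumentParser()
# Hypothetical options; argparse keeps the dots in the destination name.
parser.add_argument("--optimizer.lr", type=float, default=1e-3)
parser.add_argument("--optimizer.weight_decay", type=float, default=0.0)
parser.add_argument("--scheduler.num_warmup_steps", type=float, default=0.1)

args = parser.parse_args(["--optimizer.lr", "5e-5"])
config = namespace_to_nested_dict(args)
print(config)
```

The conversion itself is small, but every user of the argparse route would have to carry something like it, which is the cost being weighed here.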
The decision is to change our base classes to look like this:
import logging
from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch

# Module-level logger used by prepare_warmup
log = logging.getLogger(__name__)


class LitTransformer(pl.LightningModule):
    """
    Base class for transformers.
    Provides a few helper functions primarily for optimization.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        super().__init__()
        self.model = model
        # some optimizers/schedulers need parameters only known dynamically
        # allow users to override the getter to instantiate them lazily
        self.optimizer = optimizer
        self.scheduler = scheduler

    def configure_optimizers(self) -> Dict:
        """Prepare optimizer and scheduler"""
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {"scheduler": self.scheduler, "interval": "step", "frequency": 1},
        }

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if self.trainer.max_steps:
            return self.trainer.max_steps

        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            # limit_train_batches is a percentage of batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        return (dataset_size // effective_batch_size) * self.trainer.max_epochs

    def prepare_warmup(self, cfg: SchedulerConfig):
        # Takes the scheduler config, matching the call in TaskTransformer.configure_optimizers
        if cfg.num_training_steps < 0:
            # less than 0 specifies to infer number of training steps
            cfg.num_training_steps = self.num_training_steps
            log.info(f"Inferring number of training steps, set to {cfg.num_training_steps}")
        if isinstance(cfg.num_warmup_steps, float):
            # Convert float values to percentage of training steps to use as warmup
            cfg.num_warmup_steps *= cfg.num_training_steps
            log.info(f"Inferring number of warmup steps from ratio, set to {cfg.num_warmup_steps}")

    def setup(self, stage: str):
        self.configure_metrics(stage)

    def configure_metrics(self, stage: str) -> Optional[Any]:
        """
        Override to configure metrics for train/validation/test.
        This is called on fit start to have access to the data module,
        and initialize any data specific metrics.
        """
        pass


class TaskTransformer(LitTransformer):
    """
    Base class for task specific transformers
    """

    def __init__(
        self,
        optimizer: OptimizerConfig,
        scheduler: SchedulerConfig,
        instantiator: Optional[Instantiator] = None,
    ):
        super().__init__()
        self.instantiator = instantiator
        self.optimizer_cfg = optimizer
        self.scheduler_cfg = scheduler

    def configure_optimizers(self) -> Dict:
        self.optimizer = self.instantiator.optimizer(self.model, self.optimizer_cfg)
        # prepare_warmup needs the datamodule to be available when `self.num_training_steps`
        # is called that is why this is done here and not in the __init__
        self.prepare_warmup(self.scheduler_cfg)
        self.scheduler = self.instantiator.scheduler(self.scheduler_cfg, self.optimizer)
        return super().configure_optimizers()

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
        # Save the instantiator so it is available again for inference/fine-tuning
        checkpoint["instantiator"] = self.instantiator

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.instantiator = checkpoint["instantiator"]
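As a sanity check on the arithmetic in num_training_steps, the calculation can be pulled out into a plain function. This is only a sketch: estimate_training_steps and its parameters are hypothetical names that mirror the trainer attributes read above, with num_batches standing in for the length of the train dataloader.

```python
def estimate_training_steps(
    num_batches: int,
    max_epochs: int,
    accumulate_grad_batches: int = 1,
    num_devices: int = 1,
    max_steps: int = 0,
) -> int:
    """Mirror the step arithmetic of num_training_steps without a trainer."""
    if max_steps:
        # An explicit step budget wins over anything inferred.
        return max_steps
    # Batches are split across devices and merged by gradient accumulation,
    # so both reduce the number of optimizer steps per epoch.
    effective_batch_size = accumulate_grad_batches * max(1, num_devices)
    return (num_batches // effective_batch_size) * max_epochs


steps = estimate_training_steps(num_batches=1000, max_epochs=3, accumulate_grad_batches=2, num_devices=4)
capped = estimate_training_steps(num_batches=1000, max_epochs=3, max_steps=500)
print(steps, capped)
```

With 1000 batches, 4 devices and 2 accumulation steps, each epoch contributes 125 optimizer steps, which is the number a warmup ratio would be applied to.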
Keep instantiator as is, pass through and save in the model to avoid inference issues/finetuning issues.
Keep dataclasses as is!