Lightning-Universe/lightning-transformers

[RFC] Allow Hydra to live in the base class

SeanNaren opened this issue · 3 comments

🚀 RFC

Proposal

Marry the project to Hydra: allow Hydra to live in a base class, TaskTransformer (or something similarly named). Remove the complications introduced by the instantiator, and make it explicit in the code that we're using hydra.utils.instantiate.

Motivation

I've been working on #53 and ran into many API issues and a lot of confusion. The big concern: if I'm running into issues (after developing the code), a normal user would be even more confused.

It isn't clear which base class to use.

I want to use the Hydra conf because it provides niceties like instantiating my optimizer/scheduler, but our current base class doesn't provide any of that:

from typing import Any, Optional


class TaskTransformer(LitTransformer):
    """
    Base class for task specific transformers
    """

    def setup(self, stage: str):
        self.configure_metrics(stage)

    def configure_metrics(self, stage: str) -> Optional[Any]:
        """
        Override to configure metrics for train/validation/test.
        This is called on fit start to have access to the data module,
        and initialize any data specific metrics.
        """
        pass

My proposal is to move configure_optimizers here, so that instantiation happens in the base class. If a user doesn't want to use instantiation, that's fine: they can override configure_optimizers. This leads on to the next point.

    def __init__(self, optimizer_cfg: Any, scheduler_cfg: Any):
        super().__init__()
        self.optimizer_cfg = optimizer_cfg
        self.scheduler_cfg = scheduler_cfg

    def configure_optimizers(self) -> Dict:
        self.optimizer = self.optimizer(self.optimizer_cfg, self.model)
        # prepare_warmup needs the datamodule to be available when `self.num_training_steps`
        # is called, which is why this is done here and not in __init__
        self.prepare_warmup(self.scheduler_cfg)
        self.scheduler = self.scheduler(self.scheduler_cfg, self.optimizer)
        return super().configure_optimizers()

    def prepare_warmup(self, cfg: SchedulerConfig):
        if cfg.num_training_steps < 0:
            # less than 0 specifies to infer number of training steps
            cfg.num_training_steps = self.num_training_steps
            log.info(f"Inferring number of training steps, set to {cfg.num_training_steps}")

        if isinstance(cfg.num_warmup_steps, float):
            # Treat float values as the fraction of training steps to use as warmup
            cfg.num_warmup_steps *= cfg.num_training_steps
            log.info(f"Inferring number of warmup steps from ratio, set to {cfg.num_warmup_steps}")

    def optimizer(self, cfg: Any, model: torch.nn.Module) -> torch.optim.Optimizer:
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": cfg.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return hydra.utils.instantiate(cfg, grouped_parameters)

    def scheduler(self, cfg: Any, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        return hydra.utils.instantiate(cfg, optimizer=optimizer)
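The parameter grouping in optimizer() above can be checked without a model. Below is a minimal sketch, assuming a hypothetical free function group_parameters that receives (name, param) pairs the way model.named_parameters() yields them; plain strings stand in for tensors so it runs without torch. The returned list has the per-group dict shape that torch.optim optimizers accept as parameter groups.

```python
from typing import Any, Dict, Iterable, List, Tuple

# Substrings marking parameters that should not receive weight decay,
# mirroring the no_decay list in optimizer() above.
NO_DECAY = ("bias", "LayerNorm.weight")


def group_parameters(
    named_parameters: Iterable[Tuple[str, Any]], weight_decay: float
) -> List[Dict[str, Any]]:
    """Hypothetical helper: split parameters into decay / no-decay groups."""
    decay, no_decay = [], []
    for name, param in named_parameters:
        # Same test as `any(nd in n for nd in no_decay)` in the snippet above
        if any(nd in name for nd in NO_DECAY):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Strings stand in for tensors so the sketch runs without torch.
named = [
    ("encoder.layer.0.attention.query.weight", "W_q"),
    ("encoder.layer.0.attention.query.bias", "b_q"),
    ("encoder.layer.0.output.LayerNorm.weight", "ln_w"),
    ("encoder.layer.0.output.LayerNorm.bias", "ln_b"),
]
groups = group_parameters(named, weight_decay=0.01)
print(groups[0]["params"])  # ['W_q']
print(groups[1]["params"])  # ['b_q', 'ln_w', 'ln_b']
```

With torch installed, model.named_parameters() can be passed in directly; scanning the parameters once gives the same groups as the two list comprehensions in the original.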

When using Hydra, what should my class look like?

Unless I look at the HF (HuggingFace) base class, I have no idea how the instantiator works. Once I look at the HF base class, it makes sense:

class HFTransformer(HydraTaskTransformer):
    """
    Base class for task specific transformers, wrapping pre-trained language models for downstream tasks.
    The API is built on top of AutoModel and AutoConfig, provided by HuggingFace.

    see: https://huggingface.co/transformers/model_doc/auto.html
    """

    def __init__(
        self,
        downstream_model_type: str,
        backbone: HFBackboneConfig,
        optimizer: OptimizerConfig,
        scheduler: HFSchedulerConfig,
        **config_data_args,
    ):
        ...

Hydra passes in these configs, which are defined in the task conf. But because there is no class signature to inherit from, this was extremely difficult to parse. So my proposal here is to be opinionated and set this class signature in TaskTransformer.
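It helps to recall what instantiation does with those configs: the config names a class in `_target_`, and every other key becomes a keyword argument to that class's `__init__`, so the config keys and the constructor signature must agree. The sketch below is a simplified stand-in for hydra.utils.instantiate (flat configs only, no recursive instantiation or other Hydra options), written with importlib so it runs without Hydra; datetime.timedelta is just a convenient stdlib target.

```python
import importlib
from typing import Any, Dict


def instantiate(cfg: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
    """Simplified stand-in for hydra.utils.instantiate (flat configs only)."""
    module_path, _, class_name = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), class_name)
    # Every key except _target_ becomes a keyword argument ...
    params = {k: v for k, v in cfg.items() if k != "_target_"}
    # ... and call-site kwargs (like optimizer=optimizer above) override config values.
    params.update(kwargs)
    return target(*args, **params)


# A task conf entry: its keys must match the target's __init__ signature,
# which is why a fixed, documented base-class signature matters.
cfg = {"_target_": "datetime.timedelta", "days": 1}
delta = instantiate(cfg, hours=2)
print(delta)  # 1 day, 2:00:00
```

Positional arguments are forwarded as-is, which is how the snippet above passes grouped_parameters to the optimizer.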

Aside from an additional scheduler config and removing the instantiator functions (which are now hard-coded into the module), these were the changes required to get DALLE in from the model perspective. The data API needs some changes as well, but those are less controversial.

Notes

Let's say a user doesn't want to pass in a backbone config. What should they do? They should be able to simply omit it from their class signature. I think the best way to support this is to have the class signature take an optional backbone config (though in most cases, you will pass one).

What if I don't want my base class to take all these configs? Use the LitTransformer base class.
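A minimal sketch of the optional-backbone idea, assuming the base signature is (optimizer, scheduler, backbone=None). TaskBase, NoBackboneTask and the dataclass fields are hypothetical, and plain classes stand in for the Lightning base classes so it runs without Lightning or Hydra.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical config dataclasses; the project's real ones may have other fields.
@dataclass
class OptimizerConfig:
    lr: float = 1e-3
    weight_decay: float = 0.0


@dataclass
class SchedulerConfig:
    num_training_steps: int = -1
    num_warmup_steps: float = 0.1


@dataclass
class BackboneConfig:
    pretrained_model_name_or_path: str = "bert-base-uncased"


class TaskBase:
    """Plain-Python stand-in for the proposed opinionated TaskTransformer signature."""

    def __init__(
        self,
        optimizer: OptimizerConfig,
        scheduler: SchedulerConfig,
        backbone: Optional[BackboneConfig] = None,  # optional, though most tasks pass one
    ):
        self.optimizer_cfg = optimizer
        self.scheduler_cfg = scheduler
        self.backbone_cfg = backbone


class NoBackboneTask(TaskBase):
    """A task that builds its own model and simply omits the backbone."""

    def __init__(self, optimizer: OptimizerConfig, scheduler: SchedulerConfig):
        super().__init__(optimizer=optimizer, scheduler=scheduler)


hf_like = TaskBase(OptimizerConfig(lr=2e-5), SchedulerConfig(), BackboneConfig())
custom = NoBackboneTask(OptimizerConfig(), SchedulerConfig())
print(hf_like.backbone_cfg.pretrained_model_name_or_path)  # bert-base-uncased
print(custom.backbone_cfg)  # None
```

A user who wants none of these configs would skip this layer entirely and subclass the lower-level base class instead.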

Makes sense to me!

All instantiation should be done in the base class to simplify people's lives, and they will still be able to override it if needed.

As for marrying the project to Hydra, I am 100% for it. I think it is a chance to show how to properly use both together.

Best,
T.C

My main concern about marrying Hydra is that we might also be marrying our children to it.

After all, regardless of Hydra's potential, we want to provide transformer utilities at every level, from zero to hero: from the user running the train.py command in the README to the Lightning (power) user who wants to import and subclass one of our transformer classes.

What about Flash? Say they provide their NLP tasks by importing ours and maybe polishing them. They'll have to install Hydra and pass DictConfigs around!

What if a user chooses plain argparse? Do we want to force them to provide a conversion from argparse.Namespace to DictConfig?
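For a sense of what that conversion involves, the sketch below uses hypothetical --lr/--weight_decay flags, turns the argparse.Namespace into a dict with vars(), and only if omegaconf happens to be installed wraps it in a DictConfig via OmegaConf.create. The conversion itself is short; the concern is the extra dependency and the DictConfig-shaped API it would impose on every user.

```python
import argparse

# Hypothetical CLI flags for an argparse-based training script.
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--weight_decay", type=float, default=0.0)
args = parser.parse_args(["--lr", "2e-5"])

# vars() turns the Namespace into a plain dict, the usual first step.
cfg_dict = vars(args)
print(cfg_dict)  # {'lr': 2e-05, 'weight_decay': 0.0}

try:
    # Only needed if the API insists on a DictConfig.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(cfg_dict)
except ImportError:
    cfg = None
```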

These are just some examples. It's hard to see the big picture because we are in a very early stage, but this decision will have a large impact down the line (for better or worse).

The decision is to change our base classes to look like this:

import logging
from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch

log = logging.getLogger(__name__)


class LitTransformer(pl.LightningModule):
    """
    Base class for transformers.
    Provides a few helper functions primarily for optimization.
    """

    def __init__(
            self,
            model: torch.nn.Module,
            optimizer: Optional[torch.optim.Optimizer] = None,
            scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        super().__init__()
        self.model = model
        # some optimizers/schedulers need parameters only known dynamically
        # allow users to override the getter to instantiate them lazily
        self.optimizer = optimizer
        self.scheduler = scheduler

    def configure_optimizers(self) -> Dict:
        """Prepare optimizer and scheduler"""
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {"scheduler": self.scheduler, "interval": "step", "frequency": 1},
        }

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if self.trainer.max_steps:
            return self.trainer.max_steps

        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            # limit_train_batches is a percentage of batches
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        return (dataset_size // effective_batch_size) * self.trainer.max_epochs

    def prepare_warmup(self, num_training_steps: int, num_warmup_steps: Any):
        """Resolve placeholder values and return (num_training_steps, num_warmup_steps)."""
        if num_training_steps < 0:
            # less than 0 specifies to infer number of training steps
            num_training_steps = self.num_training_steps
            log.info(f"Inferring number of training steps, set to {num_training_steps}")

        if isinstance(num_warmup_steps, float):
            # Treat float values as the fraction of training steps to use as warmup
            num_warmup_steps *= num_training_steps
            log.info(f"Inferring number of warmup steps from ratio, set to {num_warmup_steps}")
        return num_training_steps, num_warmup_steps

    def setup(self, stage: str):
        self.configure_metrics(stage)

    def configure_metrics(self, stage: str) -> Optional[Any]:
        """
        Override to configure metrics for train/validation/test.
        This is called on fit start to have access to the data module,
        and initialize any data specific metrics.
        """
        pass


class TaskTransformer(LitTransformer):
    """
    Base class for task specific transformers
    """

    def __init__(
            self,
            model: torch.nn.Module,
            optimizer: OptimizerConfig,
            scheduler: SchedulerConfig,
            instantiator: Optional[Instantiator] = None,
    ):
        # LitTransformer requires the model; optimizer/scheduler are built lazily below
        super().__init__(model=model)
        self.instantiator = instantiator
        self.optimizer_cfg = optimizer
        self.scheduler_cfg = scheduler

    def configure_optimizers(self) -> Dict:
        self.optimizer = self.instantiator.optimizer(self.model, self.optimizer_cfg)
        # prepare_warmup needs the datamodule to be available when `self.num_training_steps`
        # is called, which is why this is done here and not in __init__
        cfg = self.scheduler_cfg
        cfg.num_training_steps, cfg.num_warmup_steps = self.prepare_warmup(
            cfg.num_training_steps, cfg.num_warmup_steps
        )
        self.scheduler = self.instantiator.scheduler(cfg, self.optimizer)
        return super().configure_optimizers()

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
        # Save the instantiator so the model can be rebuilt for inference/finetuning
        checkpoint["instantiator"] = self.instantiator

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.instantiator = checkpoint["instantiator"]
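The step arithmetic in num_training_steps and prepare_warmup is easy to check by hand. The sketch below lifts it into two hypothetical free functions, estimate_training_steps and resolve_warmup, that take the Trainer values as plain arguments (ignoring limit_train_batches and the device-count lookup), so it runs without Lightning.

```python
from typing import Optional, Tuple, Union


def estimate_training_steps(
    num_batches: int,
    max_epochs: int,
    accumulate_grad_batches: int = 1,
    num_devices: int = 1,
    max_steps: Optional[int] = None,
) -> int:
    """Same arithmetic as LitTransformer.num_training_steps, without a Trainer."""
    if max_steps:
        return max_steps
    effective_batch_size = accumulate_grad_batches * num_devices
    return (num_batches // effective_batch_size) * max_epochs


def resolve_warmup(
    num_training_steps: int, num_warmup_steps: Union[int, float], inferred_steps: int
) -> Tuple[int, Union[int, float]]:
    """Same rules as prepare_warmup: negative means infer, float means ratio."""
    if num_training_steps < 0:
        num_training_steps = inferred_steps
    if isinstance(num_warmup_steps, float):
        num_warmup_steps *= num_training_steps
    return num_training_steps, num_warmup_steps


# 1000 batches per epoch, 4 devices, accumulating 2 batches -> 125 steps per epoch.
steps = estimate_training_steps(
    num_batches=1000, max_epochs=3, accumulate_grad_batches=2, num_devices=4
)
print(steps)  # 375
print(resolve_warmup(-1, 0.1, inferred_steps=steps))
```

Note that a float ratio yields a float warmup count (here roughly 37.5), since neither version rounds the product.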

Keep the instantiator as is: pass it through and save it in the model to avoid inference/finetuning issues.
Keep the dataclasses as is!
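Why saving the instantiator matters can be shown with a toy round trip. In the sketch below, ToyInstantiator and ToyTask are hypothetical plain-Python stand-ins for Instantiator and TaskTransformer; the hook names mirror the on_save_checkpoint/on_load_checkpoint methods in the code above. A module restored for finetuning picks the instantiator up from the checkpoint and can still build its optimizer lazily.

```python
from typing import Any, Dict, List, Optional


class ToyInstantiator:
    """Hypothetical instantiator: builds objects from configs on demand."""

    def optimizer(self, params: List[float], cfg: Dict[str, Any]) -> Dict[str, Any]:
        # A dict stands in for a torch optimizer so the sketch runs anywhere.
        return {"params": params, **cfg}


class ToyTask:
    """Hypothetical stand-in for TaskTransformer's lazy build and checkpoint hooks."""

    def __init__(self, optimizer_cfg: Dict[str, Any], instantiator: Optional[ToyInstantiator] = None):
        self.optimizer_cfg = optimizer_cfg
        self.instantiator = instantiator
        self.params = [0.0, 1.0]

    def configure_optimizers(self) -> Dict[str, Any]:
        # Lazy: only works once an instantiator was passed in or restored.
        if self.instantiator is None:
            raise RuntimeError("no instantiator available")
        return self.instantiator.optimizer(self.params, self.optimizer_cfg)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        checkpoint["instantiator"] = self.instantiator

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.instantiator = checkpoint["instantiator"]


# Training run: the save hook stores the instantiator in the checkpoint dict.
# (Lightning writes that dict with torch.save, so the instantiator must be picklable.)
trained = ToyTask({"lr": 1e-3}, instantiator=ToyInstantiator())
checkpoint: Dict[str, Any] = {}
trained.on_save_checkpoint(checkpoint)

# Finetuning run: a module created without an instantiator recovers it on load.
restored = ToyTask({"lr": 1e-4})
restored.on_load_checkpoint(checkpoint)
print(restored.configure_optimizers())  # {'params': [0.0, 1.0], 'lr': 0.0001}
```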