# Betty

An automatic differentiation library for generalized meta-learning and multilevel optimization

Docs | Tutorials | Examples | Paper | CASL Project

```bash
pip install betty-ml
```
## Introduction
Betty is a PyTorch library for generalized meta-learning (GML) and multilevel optimization (MLO) that provides a unified programming interface for a number of GML/MLO applications including meta-learning, hyperparameter optimization, neural architecture search, data reweighting, adversarial learning, and reinforcement learning.
## Benefits
- Easy-to-use and unified programming interface for GML/MLO.
- Various system support for large-scale GML/MLO (e.g., distributed training).
- Flexible support for complex GML/MLO applications beyond two levels.
Implementing generalized meta-learning and multilevel optimization is notoriously complicated. For example, it requires approximating gradients using iterative/implicit differentiation, and writing nested for-loops to handle hierarchical dependencies between multiple levels.
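To make the pain point concrete, here is a toy bilevel problem written the manual way, in pure Python (a hypothetical sketch, not Betty code): the lower level fits a parameter `w` to a hyperparameter `lam` with an inner gradient-descent loop, and the upper level tunes `lam` using finite-difference hypergradients, one of the approximation strategies mentioned above.

```python
# lower level: w*(lam) = argmin_w (w - lam)^2, solved by an inner loop
# upper level: minimize (w*(lam) - 3.0)^2 over lam, via an outer loop

def lower_solve(lam, steps=100, lr=0.1):
    """Inner loop: gradient descent on (w - lam)^2."""
    w = 0.0
    for _ in range(steps):
        w -= lr * 2 * (w - lam)
    return w

def upper_loss(lam):
    # upper-level objective evaluated at the (approximate) lower-level solution
    return (lower_solve(lam) - 3.0) ** 2

# outer loop: gradient descent on lam with finite-difference hypergradients
lam, eps, lr = 0.0, 1e-4, 0.1
for _ in range(200):
    grad = (upper_loss(lam + eps) - upper_loss(lam - eps)) / (2 * eps)
    lam -= lr * grad

print(round(lam, 2))  # → 3.0, since w*(lam) ≈ lam
```

Even for this two-level toy, the nesting, the inner-solver budget, and the hypergradient approximation are all hand-written; Betty's API exists to take exactly this bookkeeping off the user's hands.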
Betty aims to abstract away low-level implementation details behind its API, while allowing users to write only high-level declarative code. Now, users simply need to do two things to implement any GML/MLO program:
- Define each level's optimization problem using the `Problem` class.
- Define the hierarchical problem structure using the `Engine` class.
## Applications
We provide reference implementations of several GML/MLO applications, including:
- Hyperparameter Optimization
- Neural Architecture Search
- Data Reweighting
- Domain Adaptation for Pretraining & Finetuning
- (Implicit) Model-Agnostic Meta-Learning
While each of the above examples traditionally has a distinct implementation style, note that our implementations share the same code structure thanks to Betty. More examples are on the way!
## Quick Start

### Problem

#### Basics
Each level's problem can be defined with seven components: (1) module, (2) optimizer, (3) data loader, (4) loss function, (5) problem configuration, (6) name, and (7) other optional components (e.g., a learning rate scheduler). The loss function (4) is defined via the `training_step` method, while all other components are provided through the class constructor. For example, an image classification problem can be defined as follows:
```python
import torch.nn.functional as F

from betty.problems import ImplicitProblem
from betty.configs import Config

# set up module, optimizer, data loader (i.e., (1)-(3))
cls_module, cls_optimizer, cls_data_loader = setup_classification()

class Classifier(ImplicitProblem):
    # set up loss function
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.module(inputs)
        loss = F.cross_entropy(outputs, labels)
        return loss

# set up problem configuration
cls_config = Config(type='darts', unroll_steps=1, log_step=100)

# Classifier problem class instantiation
cls_prob = Classifier(name='classifier',
                      module=cls_module,
                      optimizer=cls_optimizer,
                      train_data_loader=cls_data_loader,
                      config=cls_config)
```
#### Interactions between problems
In GML/MLO, each problem often needs to access modules from other problems to define its loss function. This can be achieved through the `name` attribute, as follows:
```python
class HPO(ImplicitProblem):
    def training_step(self, batch):
        # set up hyperparameter optimization loss
        ...

# HPO problem class instantiation
hpo_prob = HPO(name='hpo', module=...)

class Classifier(ImplicitProblem):
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.module(inputs)
        loss = F.cross_entropy(outputs, labels)
        # accessing the weight decay hyperparameter from the HPO problem
        # is achieved through its name 'hpo'
        weight_decay = self.hpo()
        reg_loss = weight_decay * sum(p.norm().pow(2) for p in self.module.parameters())
        return loss + reg_loss

cls_prob = Classifier(name='classifier', module=...)
```
### Engine

#### Basics
The `Engine` class handles the hierarchical dependencies between problems. In GML/MLO, there are two types of dependencies: upper-to-lower (`u2l`) and lower-to-upper (`l2u`). Both types of dependencies can be defined with a Python dictionary, where each key is a starting node and each value is the list of destination nodes.
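Such a dependency dictionary defines a directed acyclic graph over problems. As a rough sketch (not Betty's actual internals), an execution order that steps lower problems before the upper problems that depend on them can be recovered by topologically sorting the `l2u` graph; plain strings stand in for `Problem` objects here.

```python
def execution_order(l2u):
    """Topological sort of an l2u dependency dict via Kahn's algorithm."""
    # collect all nodes and count incoming edges (dependencies)
    nodes = set(l2u) | {d for dsts in l2u.values() for d in dsts}
    indegree = {n: 0 for n in nodes}
    for dsts in l2u.values():
        for d in dsts:
            indegree[d] += 1
    order = []
    ready = [n for n in nodes if indegree[n] == 0]  # problems with no lower level
    while ready:
        n = ready.pop()
        order.append(n)
        for d in l2u.get(n, []):
            indegree[d] -= 1
            if indegree[d] == 0:
                ready.append(d)
    return order

# the two-level example from this README: classifier (lower) feeds hpo (upper)
print(execution_order({'classifier': ['hpo']}))  # → ['classifier', 'hpo']
```

The same idea extends to arbitrarily deep hierarchies, which is why dependencies beyond two levels need no extra user code.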
```python
from betty import Engine
from betty.configs import EngineConfig

# set up all involved problems
problems = [cls_prob, hpo_prob]

# set up upper-to-lower and lower-to-upper dependencies
u2l = {hpo_prob: [cls_prob]}
l2u = {cls_prob: [hpo_prob]}
dependencies = {'u2l': u2l, 'l2u': l2u}

# set up Engine configuration
engine_config = EngineConfig(train_iters=10000, valid_step=100)

# instantiate Engine class
engine = Engine(problems=problems, dependencies=dependencies, config=engine_config)

# execute multilevel optimization
engine.run()
```
Since `Engine` manages the whole GML/MLO program, you can also perform a global validation stage within it. All problems that make up the GML/MLO program can again be accessed by their names.
```python
class HPOEngine(Engine):
    # set up global validation
    @torch.no_grad()
    def validation(self):
        loss = 0
        for inputs, labels in test_loader:
            outputs = self.classifier(inputs)
            loss += F.cross_entropy(outputs, labels)
        # the returned dict will be automatically logged after each validation
        return {'loss': loss}

...

engine = HPOEngine(problems=problems, dependencies=dependencies, config=engine_config)
engine.run()
```
Once we define all optimization problems and the hierarchical dependencies between them with, respectively, the `Problem` class and the `Engine` class, all of the complicated internal mechanisms of GML/MLO, such as gradient calculation and optimization execution order, are handled by Betty. For more details and advanced features, check out our Documentation and Tutorials.
Happy multilevel optimization programming!
## Features

### Gradient Approximation Methods

- Implicit Differentiation
  - Finite Difference (DARTS: Differentiable Architecture Search)
  - Neumann Series (Optimizing Millions of Hyperparameters by Implicit Differentiation)
  - Conjugate Gradient (Meta-Learning with Implicit Gradients)
- Iterative Differentiation
  - Reverse-mode Automatic Differentiation (Model-Agnostic Meta-Learning (MAML))
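To give a feel for one of these methods, here is a minimal numerical sketch (hypothetical, not Betty's implementation) of the Neumann-series approach: implicit differentiation needs an inverse-Hessian-vector product, which the Neumann series approximates as `H^{-1} v ≈ alpha * sum_k (I - alpha*H)^k v`, avoiding any explicit matrix inversion.

```python
def matvec(M, v):
    """Dense matrix-vector product for small illustrative matrices."""
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]

def neumann_ihvp(H, v, alpha=0.1, k=500):
    """Approximate H^{-1} v with a truncated Neumann series."""
    p = list(v)    # running sum of the series (k = 0 term included)
    cur = list(v)  # current term (I - alpha*H)^k v
    for _ in range(k):
        Hcur = matvec(H, cur)
        cur = [c - alpha * h for c, h in zip(cur, Hcur)]
        p = [a + b for a, b in zip(p, cur)]
    return [alpha * x for x in p]

H = [[2.0, 0.0], [0.0, 4.0]]  # a simple positive-definite "Hessian"
v = [1.0, 1.0]
approx = neumann_ihvp(H, v)
# exact H^{-1} v = [0.5, 0.25]; the truncated series converges to it
```

In practice only Hessian-vector products are needed (never the Hessian itself), which is what makes the method viable for models with millions of parameters.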
### Training

- Gradient accumulation
- FP16 training
- Non-distributed data-parallel training
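Gradient accumulation, for instance, rests on a simple identity: micro-batch gradients, each weighted by its share of the full batch, sum to the full-batch gradient of the mean loss. A minimal sketch (independent of Betty) on a scalar "model":

```python
# the "model" is a scalar w with loss mean((w - x_i)^2) over a batch

def grad(w, batch):
    """Gradient of the mean squared loss over one (micro-)batch."""
    return sum(2 * (w - x) for x in batch) / len(batch)

data = [1.0, 2.0, 3.0, 4.0]
w = 0.0

# accumulate over two micro-batches of size 2, weighting each by its share
accum = 0.0
for micro in (data[:2], data[2:]):
    accum += grad(w, micro) * (len(micro) / len(data))

full = grad(w, data)
# accum == full: accumulation reproduces one full-batch gradient
```

This is why accumulation lets large effective batch sizes fit in limited memory: only one micro-batch's activations are live at a time, at the cost of a single deferred optimizer step.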
### Logging
## Contributing
We welcome contributions from the community! Please see our contributing guidelines for details on how to contribute to Betty.
## License
Betty is licensed under the Apache 2.0 License.