pytorch_ema

A very small library for computing exponential moving averages of model parameters.

This library was written for personal use. Nevertheless, if you run into issues or have suggestions for improvement, feel free to open an issue or a pull request.
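For reference, each call to update() nudges a set of shadow parameters toward the current model parameters using the standard exponential moving average rule. The snippet below only illustrates that arithmetic; the tensor names are made up for the illustration and are not part of the library's API.

import torch

decay = 0.995
param = torch.randn(10)   # current value of one model parameter
shadow = param.clone()    # running average for that parameter
# Standard EMA step: shadow <- decay * shadow + (1 - decay) * param
shadow = decay * shadow + (1.0 - decay) * param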

Installation

pip install -U git+https://github.com/fadel/pytorch_ema

Example

import torch
import torch.nn.functional as F

from torch_ema import ExponentialMovingAverage


x_train = torch.rand((100, 10))
y_train = torch.rand(100).round().long()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

# Train for a few epochs
model.train()
for _ in range(10):
    logits = model(x_train)
    loss = F.cross_entropy(logits, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model.parameters())  # update the moving averages with the new parameter values

# Compare losses:
# Original
model.eval()
logits = model(x_train)
loss = F.cross_entropy(logits, y_train)
print(loss.item())

# With EMA (note that copy_to overwrites the model's current parameters)
ema.copy_to(model.parameters())
logits = model(x_train)
loss = F.cross_entropy(logits, y_train)
print(loss.item())
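In the comparison above, copy_to() leaves the averaged weights in the model, so the trained weights from before the copy are lost. Depending on the installed version, the library also provides store() and restore() (newer releases additionally offer an average_parameters() context manager), which let you evaluate with the averaged weights and then switch back. A minimal sketch of that pattern, assuming store() and restore() are available:

# Keep the trained weights, evaluate with the EMA weights, then switch back.
ema.store(model.parameters())      # save the current (trained) weights
ema.copy_to(model.parameters())    # load the averaged weights into the model
logits = model(x_train)
print(F.cross_entropy(logits, y_train).item())
ema.restore(model.parameters())    # put the trained weights back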