[DOC] Update mnist.py example
orion160 opened this issue · 10 comments
orion160 commented
Update the example at https://github.com/pytorch/examples/blob/main/mnist/main.py to use torch.compile features.
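For reference, a minimal sketch of the core torch.compile usage (the tiny model below is a stand-in for the Net class in main.py, not the real one):

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Flatten the 1x28x28 image and return log-probabilities, as in main.py.
        return F.log_softmax(self.fc(x.flatten(1)), dim=1)


model = TinyNet()
# Compile once; later forward/backward calls run through the compiled graph.
# "inductor" is the default backend, shown explicitly here.
compiled = torch.compile(model, backend="inductor")
out = compiled(torch.randn(8, 1, 28, 28))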
orion160 commented
Proposal:
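In short: the diff below switches the data pipeline to torchvision.transforms.v2, moves the model and inputs to channels_last memory format, wraps the model with torch.compile (backend and mode exposed as CLI flags), and adds an optional AMP training path behind a --use-amp flag.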
diff --git a/mnist/main.py b/mnist/main.py
index 184dc47..a3cffd1 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -3,13 +3,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
from torch.optim.lr_scheduler import StepLR
class Net(nn.Module):
def __init__(self):
- super(Net, self).__init__()
+ super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ class Net(nn.Module):
return output
-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
- data, target = data.to(device), target.to(device)
- optimizer.zero_grad()
+ data, target = data.to(device, memory_format=torch.channels_last), target.to(
+ device
+ )
+ opt.zero_grad()
+ with torch.autocast(device_type=device.type):
+ output = model(data)
+ loss = F.nll_loss(output, target)
+ scaler.scale(loss).backward()
+ scaler.step(opt)
+ scaler.update()
+ if batch_idx % args.log_interval == 0:
+ print(
+ f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+ )
+ if args.dry_run:
+ break
+
+
+def train(args, model, device, train_loader, opt, epoch):
+ model.train()
+ for batch_idx, (data, target) in enumerate(train_loader):
+ data, target = data.to(device, memory_format=torch.channels_last), target.to(
+ device
+ )
+ opt.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
- optimizer.step()
+ opt.step()
if batch_idx % args.log_interval == 0:
- print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
- epoch, batch_idx * len(data), len(train_loader.dataset),
- 100. * batch_idx / len(train_loader), loss.item()))
+ print(
+ f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+ )
if args.dry_run:
break
@@ -58,43 +82,125 @@ def test(model, device, test_loader):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
- test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
- pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
+ test_loss += F.nll_loss(
+ output, target, reduction="sum"
+ ).item() # sum up batch loss
+ pred = output.argmax(
+ dim=1, keepdim=True
+ ) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
- print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
- test_loss, correct, len(test_loader.dataset),
- 100. * correct / len(test_loader.dataset)))
+ print(
+ f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+ )
-def main():
+def parse_args():
# Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=14, metavar='N',
- help='number of epochs to train (default: 14)')
- parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
- help='learning rate (default: 1.0)')
- parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
- help='Learning rate step gamma (default: 0.7)')
- parser.add_argument('--no-cuda', action='store_true', default=False,
- help='disables CUDA training')
- parser.add_argument('--no-mps', action='store_true', default=False,
- help='disables macOS GPU training')
- parser.add_argument('--dry-run', action='store_true', default=False,
- help='quickly check a single pass')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--log-interval', type=int, default=10, metavar='N',
- help='how many batches to wait before logging training status')
- parser.add_argument('--save-model', action='store_true', default=False,
- help='For Saving the current Model')
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=14,
+ metavar="N",
+ help="number of epochs to train (default: 14)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=1.0,
+ metavar="LR",
+ help="learning rate (default: 1.0)",
+ )
+ parser.add_argument(
+ "--gamma",
+ type=float,
+ default=0.7,
+ metavar="M",
+ help="Learning rate step gamma (default: 0.7)",
+ )
+ parser.add_argument(
+ "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+ )
+ parser.add_argument(
+ "--no-mps",
+ action="store_true",
+ default=False,
+ help="disables macOS GPU training",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ default=False,
+ help="quickly check a single pass",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--use-amp",
+ action="store_true",
+ default=False,
+ help="use automatic mixed precision",
+ )
+ parser.add_argument(
+ "--compile-backend",
+ type=str,
+ default="inductor",
+ metavar="BACKEND",
+ help="backend to compile the model with",
+ )
+ parser.add_argument(
+ "--compile-mode",
+ type=str,
+ default="default",
+ metavar="MODE",
+ help="compilation mode",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--data-dir",
+ type=str,
+ default="../data",
+ metavar="DIR",
+ help="path to the data directory",
+ )
args = parser.parse_args()
+
+ return args
+
+
+def main():
+ args = parse_args()
+
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()
@@ -107,32 +213,43 @@ def main():
else:
device = torch.device("cpu")
- train_kwargs = {'batch_size': args.batch_size}
- test_kwargs = {'batch_size': args.test_batch_size}
+ train_kwargs = {"batch_size": args.batch_size}
+ test_kwargs = {"batch_size": args.test_batch_size}
if use_cuda:
- cuda_kwargs = {'num_workers': 1,
- 'pin_memory': True,
- 'shuffle': True}
+ cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])
- dataset1 = datasets.MNIST('../data', train=True, download=True,
- transform=transform)
- dataset2 = datasets.MNIST('../data', train=False,
- transform=transform)
- train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+ transform = transforms.Compose(
+ [
+ transforms.ToImage(),
+ transforms.ToDtype(torch.float32, scale=True),
+ transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+ ]
+ )
+
+ data_dir = args.data_dir
+
+ dataset1 = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+ dataset2 = datasets.MNIST(data_dir, train=False, transform=transform)
+ train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
- model = Net().to(device)
- optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+ model = Net().to(device, memory_format=torch.channels_last)
+ model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+ optimizer = optim.Adadelta(model.parameters(), lr=torch.tensor(args.lr))
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+ scaler = None
+ if use_cuda and args.use_amp:
+ scaler = torch.amp.GradScaler(device.type)
+
for epoch in range(1, args.epochs + 1):
- train(args, model, device, train_loader, optimizer, epoch)
+ if scaler is None:
+ train(args, model, device, train_loader, optimizer, epoch)
+ else:
+ train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
test(model, device, test_loader)
scheduler.step()
@@ -140,5 +257,5 @@ def main():
torch.save(model.state_dict(), "mnist_cnn.pt")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
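For reference, the AMP branch above follows the standard autocast + GradScaler pattern; a self-contained sketch (assumes a CUDA device and uses a dummy model and optimizer, not the real Net):

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda")
model = nn.Linear(10, 2).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(device.type)

data = torch.randn(4, 10, device=device)
target = torch.randint(0, 2, (4,), device=device)

opt.zero_grad()
# Run the forward pass and loss in reduced precision.
with torch.autocast(device_type=device.type):
    loss = F.cross_entropy(model(data), target)
# Scale the loss so fp16 gradients do not underflow; step() unscales before updating.
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()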
doshi-kevin commented
Hey, can I work on this issue?
doshi-kevin commented