Batch Renormalization in PyTorch

Batch Renormalization

PyTorch implementation of Batch Renormalization, introduced in https://arxiv.org/abs/1702.03275.

Installation

Requires Python 3.7. To install the latest version of this package, run:

pip install git+https://github.com/ludvb/batchrenorm@master

Usage

import torch
from batchrenorm import BatchRenorm2d

# Create batch renormalization layer
br = BatchRenorm2d(3)

# Create some example data with dimensions N x C x H x W
x = torch.randn(1, 3, 10, 10)

# Batch renormalize the data
x = br(x)
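
What the layer computes

For readers who want to see the idea behind the layer, here is a minimal sketch of the training-mode computation described in the paper, written in plain Python for a single feature. It is not this package's implementation: the function name `batch_renorm_1d`, its parameter names, and the default values are assumptions made for illustration, and the learnable scale and shift are left out.

```python
import math


def batch_renorm_1d(xs, running_mean, running_std,
                    r_max=3.0, d_max=5.0, eps=1e-5, momentum=0.01):
    """Batch-renormalize one feature over a mini-batch (hypothetical helper).

    Returns the normalized values and the updated running statistics.
    """
    n = len(xs)
    batch_mean = sum(xs) / n
    batch_var = sum((x - batch_mean) ** 2 for x in xs) / n
    batch_std = math.sqrt(batch_var + eps)

    # Correction terms that tie the batch statistics to the running ones.
    # The paper treats r and d as constants during backpropagation.
    r = min(max(batch_std / running_std, 1.0 / r_max), r_max)
    d = min(max((batch_mean - running_mean) / running_std, -d_max), d_max)

    # Normalize with batch statistics, then apply the correction.
    ys = [(x - batch_mean) / batch_std * r + d for x in xs]

    # Moving-average update of the running statistics.
    new_mean = running_mean + momentum * (batch_mean - running_mean)
    new_std = running_std + momentum * (batch_std - running_std)
    return ys, new_mean, new_std


xs = [1.0, 2.0, 3.0, 4.0]

# r and d are not clipped here, so the result matches normalizing
# with the running statistics: (x - 2.0) / 1.5
ys, new_mean, new_std = batch_renorm_1d(xs, running_mean=2.0, running_std=1.5)

# With r_max=1 and d_max=0 the corrections vanish (r = 1, d = 0),
# which reduces to ordinary batch normalization.
ys_bn, _, _ = batch_renorm_1d(xs, running_mean=2.0, running_std=1.5,
                              r_max=1.0, d_max=0.0)
```

When `r` and `d` stay inside their clipping ranges, the output depends only on the running statistics, which is what makes training and inference behave consistently for small batches. Clipping limits how far a single batch can be pulled toward running statistics that are still inaccurate early in training.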