google-deepmind/dm-haiku

More fine-grained mixed-precision configuration

llCurious opened this issue · 2 comments

I noticed that Haiku integrates JMP and supports mixed precision. Example code for ResNet is as follows:

# Assign mixed precision policies to modules. Note that when training in f16
# we keep BatchNorm in full precision. When training with bf16 you can often
# use bf16 for BatchNorm.
mp_policy = get_policy()
bn_policy = get_bn_policy().with_output_dtype(mp_policy.compute_dtype)
hk.mixed_precision.set_policy(hk.BatchNorm, bn_policy)
hk.mixed_precision.set_policy(hk.nets.ResNet50, mp_policy)

The practice is to keep the computation of BN in full precision. I wonder if there is an approach that lets me configure a designated layer to use some precision, instead of configuring by module class. For instance,

ln_policy = jmp.get_policy('p=f16,c=f16,o=f16')
special_policy = jmp.get_policy('p=f32,c=f32,o=f32')
model = get_model() # with 3 linear layers
hk.mixed_precision.set_policy(hk.Linear, ln_policy)
hk.mixed_precision.set_policy(last_linear_layer, special_policy)  # pseudocode: a single layer instance, not a class

I would like to know:

  1. Does Haiku support such functionality?
  2. If not, can you suggest a workaround?

Hi there, instance-level mixed precision with jmp is not built into Haiku; however, you can implement it in a few ways. I've added them to this Colab notebook: https://colab.research.google.com/gist/tomhennigan/874de0420c55f7bd062f24f7ec6b0e51/instance-level-half-precision.ipynb

Roughly, the options are:

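  1. Create a subclass and use that for instances you want in a different precision:

import haiku as hk
import jax.numpy as jnp
import jmp

# An empty subclass gives these instances their own class to attach a policy to.
class LowPrecisionLinear(hk.Linear):
  pass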

half_policy = jmp.get_policy('compute=half')
hk.mixed_precision.set_policy(LowPrecisionLinear, half_policy)

def f(x):
  net = hk.Sequential([
      hk.Linear(300), jnp.tan,
      hk.Linear(100), jnp.tan,
      LowPrecisionLinear(10),
  ])
  return net(x)
  2. Create a wrapper function that applies a specific policy before calling the module:

import functools

import haiku as hk
import jax.numpy as jnp
import jmp

def wrap_with_policy(mod: hk.Module, policy: jmp.Policy):
  cls = type(mod)
  @functools.wraps(mod.__call__)
  def wrapper(*args, **kwargs):
    # Remember the policy currently set for this module's class (if any).
    old_policy = hk.mixed_precision.get_policy(cls)
    hk.mixed_precision.set_policy(cls, policy)
    try:
      return mod(*args, **kwargs)
    finally:
      # Restore the previous state so other instances of cls are unaffected.
      if old_policy is not None:
        hk.mixed_precision.set_policy(cls, old_policy)
      else:
        hk.mixed_precision.clear_policy(cls)
  return wrapper

half_policy = jmp.get_policy('compute=half')

def f(x):
  net = hk.Sequential([
      hk.Linear(300), jnp.tan,
      hk.Linear(100), jnp.tan,
      wrap_with_policy(hk.Linear(10), half_policy),
  ])
  return net(x)

Thank you for the suggested options. This is of great help to me.