google-deepmind/dm-haiku

Feature request: declare_global_parameter / get_global_parameter

hrbigelow opened this issue · 6 comments

Hi,

There is probably a good reason this doesn't exist, but I can't figure it out. I'd like to propose an addition to Haiku's parameter management system, as follows. The use case is sharing the same parameter between two hk.Module instances that belong to different hk.Module subclasses, ideally without placing any burden on the user to mess with module names.

The use case here is to re-use the embedding matrix to both embed and de-embed at the beginning and end of a transformer, as they do in the original "Attention is All You Need" paper.

Apologies, I read through the parameter sharing tutorial but could not figure out how to do this.

Would something like this be possible? Or, what would be a good workaround if not? Any help would be greatly appreciated!

class Embed(hk.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
        # it is an error if `global_embed` wasn't declared first
==>     w = hk.get_global_parameter('global_embed')
        return jnp.take(w, x, axis=0)

class DeEmbed(hk.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
==>     w = hk.get_global_parameter('global_embed')
        return jnp.take(jnp.transpose(w), x, axis=0)

class Parent(hk.Module):
    def __init__(self):
        super().__init__()
        init_fn = hk.initializers.RandomNormal(1.0, 0.0) 
        # it is an error to call this twice for same global_name
==>     hk.declare_global_parameter('global_embed', [100, 50], np.float32, init_fn)
        self.embed = Embed()
        self.debed = DeEmbed()

    def __call__(self, x):
        e = self.embed(x)
        d = self.debed(e)
        return d

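def make_mod(cls):
    def fn(*call_args):
        # modules must be constructed inside the transformed function
        mod = cls()
        return mod(*call_args)
    # hk.transform provides the init / apply pair used in main()
    return hk.transform(fn)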

def main():
    model = make_mod(Parent)
    key = jax.random.PRNGKey(42)

    params = model.init(key, jnp.empty((5, 100)))
    # params = { 
    # 'Parent': { 
    #   'Embed': { 'global_embed': <Array> }, 
    #   'DeEmbed': { 'global_embed': <Array> },
    # }
    # where <Array> is the same object 

    # perhaps the apply function would check that global parameters were indeed the same objects
    # where they were accessed
    # apply takes (params, rng, *args), in that order
    out = model.apply(params, key, jnp.empty((5, 100)))

Hi @hrbigelow, the simplest way to achieve this is to create a module that contains the parameter and can embed/deembed:

class Embedding(hk.Module):

  def __init__(self, name=None):
    super().__init__(name=name)
    init_fn = hk.initializers.RandomNormal(1.0, 0.0) 
    self.w = hk.get_parameter('embed', [100, 50], np.float32, init_fn)

  def embed(self, x):
    return jnp.take(self.w, x, axis=0)

  def deembed(self, x):
    return jnp.take(jnp.transpose(self.w), x, axis=0)

If this refactoring is too difficult in your actual program, then Haiku does have a "break glass" mechanism for you to take full control over parameter naming:

def get_global_embedding() -> jax.Array:
  with hk.experimental.name_scope(hk.experimental.force_name("global_embed")):
    init_fn = hk.initializers.RandomNormal(1.0, 0.0) 
    return hk.get_parameter('embed', [100, 50], np.float32, init_fn)

Here are examples of using both approaches: https://colab.research.google.com/gist/tomhennigan/a2251d80021e445db3fddc71807c84bb/shared-embedding-param-example.ipynb

I think we will likely avoid adding the API you proposed to Haiku itself. While it is possible to implement (and already possible to do using our public API), I think that using global parameters is generally an anti-pattern that makes code harder to maintain (e.g. it would be non-trivial to determine the dependency between Embed and DeEmbed), and I would want to avoid encouraging it by making it simpler to use. There are situations (especially in research codebases) where it is pragmatic to use these kinds of tricks, but in general the first solution I proposed is what I would strongly encourage our users to follow.

That looks great! Actually I somehow made a false association between the hk.transform and the hk.Module's __call__ method, so never considered one module implementing two separate methods. I see now that the tracing mechanism is indifferent to what the method is called.

I do think the 'break glass' approach is still important, because there could be situations in which two separate modules might want to share parameters, but also have their own, different set of locally scoped parameters. In that case, you'd need two separate instances, so the first, dual-method approach would not work. That is, if I understand things correctly.

I do think the 'break glass' approach is still important, because there could be situations in which two separate modules might want to share parameters, but also have their own, different set of locally scoped parameters. In that case, you'd need two separate instances, so the first, dual-method approach would not work. That is, if I understand things correctly.

This is definitely a real use case, but in almost all the cases I've seen, the cleanest way to implement it is to refactor the modules into three parts: one module containing any common parameters/logic, and two others that use that module and apply whatever other logic/params they need on top.

class Common(hk.Module):
  @property
  def w(self):
    return hk.get_parameter(..)

  def __call__(self, x):
    return something(x, self.w)

class ModuleA(hk.Module):
  def __init__(self, common: Common, name=None):
    super().__init__(name=name)
    self.common = common

  def __call__(self, x):
    w = self.common.w  # direct reference to `w` if needed.
    return something_else(x, w)

class ModuleB(hk.Module):
  def __init__(self, common: Common, name=None):
    super().__init__(name=name)
    self.common = common

  def __call__(self, x):
    x = self.common(x)  # Use param and math from common
    return something_else_again(x)

def f(x):
  common = Common()
  mod_a = ModuleA(common)
  mod_b = ModuleB(common)
  return common(x), mod_a(x), mod_b(x)

Typically the break glass mechanism makes sense in a large existing codebase, where the cost of refactoring in order to try a given idea is very high (e.g. as you proposed in your original comment, trying to make two independent modules share a specific parameter).

Would this be a good summary of what you are saying?

Overriding principles:

  1. every parameter instance should be owned by exactly one hk.Module instance (not class)
  2. only the owner of a parameter should directly use the parameter in a computation

In cases where you want to use the same parameter in two different computations (like embed / de-embed)

  • express both computations as member functions of one hk.Module class
  • instantiate the class once
  • use that single instance as a shared sub-module where you want to use one or more of its computations

Apologies for the delay.

every parameter instance should be owned by exactly one hk.Module instance (not class)

Agreed.

only the owner of a parameter should directly use the parameter in a computation

This is more nuanced. Your code will probably be simpler if this is the case so maybe it is a sensible default.

However, I think it is fine to create accessors on module instances for parameters, and then use those accessors outside of the module instance. I tried to show this in the Common example above. In one of these uses I call the __call__ method of the module. In the other use I get the parameter directly using common.w. Both are safe.

express both computations as member functions of one hk.Module class
instantiate the class once
use that single instance as a shared sub-module where you want to use one or more of its computations

I think this is a neat way to design things and parameter sharing is very clear here.

Great, thanks so much Tom.

However, I think it is fine to create accessors on module instances for parameters, and then use those accessors outside
of the module instance. I tried to show this in the Common example above. In one of these uses I call the __call__
method of the module. In the other use I get the parameter directly using common.w. Both are safe.

Yes - this answers the other question I had: what if you wanted to share a parameter between two different computations, but none of the computational logic that uses the parameter is shared? (I missed that in your example at first.) Your example above seems to cover all the cases I can think of, and it does so entirely through composition of modules.

In my opinion, the parameter sharing tutorial would be much improved if it had your above snippet, as it really answers this question most directly.