/my-haiku

Code written while following the "Build your own Haiku" tutorial (https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html).

Primary language: Jupyter Notebook · License: Apache License 2.0 (Apache-2.0)

my-haiku

This file serves as both the README and the index of the documentation.

Install

pip install my_haiku

How to use

from my_haiku import basics as basics, basics_with_init as haiku_mini
def model_predict(x):
    """Run ``x`` through a single mini-Haiku ``Linear`` layer with 4 outputs.

    Meant to be wrapped by ``haiku_mini.transform``; a new ``Linear(4)``
    module is constructed on every call.
    """
    return haiku_mini.Linear(4)(x)

from jax import numpy as jnp

# Turn the model function into an (init, apply) pair of functions.
init, apply = haiku_mini.transform(model_predict)

# `init` receives a *sample input* and returns the model parameters. The
# all-ones array is only that sample input -- it does NOT set the weights
# to 1. Presumably only its shape (last axis = 5 features) is used to size
# the parameters; TODO confirm against basics_with_init.
model_params = init(jnp.ones((5, 5)))

# Apply the initialized parameters to a single row of five 5's (4 outputs).
apply(model_params, jnp.full((1, 5), 5))
Array([[-21.917768,  -4.325401,  -4.116453,   4.437202]], dtype=float32)
import numpy as np

# Re-initialize, this time passing a 1-D random sample input of shape (5,).
# The resulting parameters are then applied to a 2-D (5, 5) batch, which
# presumably works because only the last axis (5 features) sizes the
# weights -- TODO confirm. numpy is imported directly instead of borrowing
# the `np` name that `haiku_mini` happens to import internally.
model_params = init(np.random.normal(size=(5,)))

# Every input row is identical, so every output row is identical too.
apply(model_params, jnp.full((5, 5), 5))
Array([[21.342882, -1.453893, -5.462   ,  7.012759],
       [21.342882, -1.453893, -5.462   ,  7.012759],
       [21.342882, -1.453893, -5.462   ,  7.012759],
       [21.342882, -1.453893, -5.462   ,  7.012759],
       [21.342882, -1.453893, -5.462   ,  7.012759]], dtype=float32)