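This file serves as both the project README and the index page of the documentation. The examples below show how to install `my_haiku` and build a simple linear model with it.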
pip install my_haiku
from my_haiku import basics as basics, basics_with_init as haiku_mini
def model_predict(x):
    """Build a 4-output `haiku_mini.Linear` layer and apply it to `x`.

    The layer is constructed inside this function, which is what gets passed
    to `haiku_mini.transform` (presumably so the transform can own the layer's
    parameters, as in Haiku - confirm against `basics_with_init`).
    """
    return haiku_mini.Linear(4)(x)
init, apply = haiku_mini.transform(model_predict)
from jax import numpy as jnp
# `init` receives a sample *input* (a 5x5 array of ones) and returns the
# model's parameters; the ones are not the weights themselves (with all-one
# weights, an input of all fives would give four identical outputs).
model_params = init(jnp.ones((5, 5)))
# Run the model on a single row of fives using those parameters.
apply(model_params, jnp.full((1, 5), 5))
Array([[-21.917768, -4.325401, -4.116453, 4.437202]], dtype=float32)
# Re-initialize, this time from a random 1-D sample input of length 5, then
# apply the model to a 5x5 batch of fives (every row is identical, so every
# output row is too).
model_params = init(haiku_mini.np.random.normal(size=(5,)))
apply(model_params, jnp.full((5, 5), 5))
Array([[21.342882, -1.453893, -5.462 , 7.012759],
[21.342882, -1.453893, -5.462 , 7.012759],
[21.342882, -1.453893, -5.462 , 7.012759],
[21.342882, -1.453893, -5.462 , 7.012759],
[21.342882, -1.453893, -5.462 , 7.012759]], dtype=float32)