cgarciae/nnx

Optimal way to handle Variable metadata

cgarciae opened this issue · 2 comments

Hi Cristian, love the work you're doing with this library 🔥

Quick question related to this issue: is it currently possible to access Variable metadata from within an nnx.Module (e.g. inside __call__)? If not, do you plan to add this feature?

Thanks @frazane!

There is no official API, but you can access Variables from a Module by reading the attribute from __dict__ directly, e.g. via vars:

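import nnx  # assumes the standalone package from this repo

module = nnx.Linear(32, 10, ctx=nnx.context(0))
# vars(module) reads __dict__ directly, so this is the Variable object itself
kernel: nnx.Param = vars(module)["kernel"]
# access Variable's metadata (e.g. sharding)
sharding = kernel.sharding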
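
To show why the vars trick also works inside __call__, here is a toy pure-Python sketch. It is not nnx's actual implementation: ToyVariable and ToyModule are hypothetical names, and the idea that the module unwraps Variables on normal attribute access is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyVariable:
    """Hypothetical stand-in for a Variable: a value plus metadata."""
    value: Any
    sharding: Any = None


class ToyModule:
    """Hypothetical module that unwraps ToyVariable on attribute access."""

    def __init__(self, kernel: ToyVariable):
        # Regular attribute assignment stores the wrapper in the instance __dict__.
        self.kernel = kernel

    def __getattribute__(self, name: str) -> Any:
        attr = object.__getattribute__(self, name)
        # Assumption: attribute access returns the raw value, hiding the metadata.
        if isinstance(attr, ToyVariable):
            return attr.value
        return attr

    def __call__(self, x):
        # vars(self) reads __dict__ directly, so the wrapper is still visible here.
        kernel_var = vars(self)["kernel"]
        return x * self.kernel, kernel_var.sharding


module = ToyModule(ToyVariable(value=2.0, sharding=("data", None)))
out, sharding = module(3.0)
```

In this sketch, self.kernel gives the plain value used for the computation, while vars(self)["kernel"] gives the wrapper that carries the metadata, so both are available from inside __call__.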