poets-ai/elegy

[Bug] Problem with computing metrics

organic-chemistry opened this issue · 6 comments

Describe the bug
Hi, when I call the fit function I get an error message saying that the update function is not provided with y_true and y_pred.
It seems to come from the model's metrics, because if I comment out the metrics line there is no error.

TypeError: update() missing 2 required positional arguments: 'y_true' and 'y_pred'

Minimal code to reproduce
Small snippet that contains a minimal amount of code.

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import elegy as eg


class eCNN(eg.Module):
    """A simple CNN model."""

    @eg.compact
    def __call__(self, x):
        x = eg.Conv(10, kernel_size=(10,))(x)
        x = jax.nn.relu(x)
        x = eg.Linear(1)(x)
        x = jax.nn.sigmoid(x)
        return x

n = 200
X_train = np.random.rand(n * 100).reshape(n, 100)
y_train = np.random.rand(n).reshape(n, 1)
print(X_train.shape)
print(y_train.shape)

model = eg.Model(
    module=eCNN(),
    loss=[
        eg.losses.MeanSquaredError(),
    ],
    metrics=eg.metrics.MeanSquareError(),  # Comment out this line to get rid of the error
    optimizer=optax.rmsprop(1e-3),
)

model.fit(
    X_train,
    y_train,
    epochs=10,
    batch_size=20,
    # validation_data=0.1,
    shuffle=False,
    callbacks=[eg.callbacks.TensorBoard("summaries")],
)

Library Info
Please provide os info and elegy version.

import elegy
print(elegy.__version__) 
# 0.8.4

It is probably related, because I get the same error when I define a custom loss (replacing the model with the following code):

class BCE(eg.Loss):
    def call(self, y_true, y_pred):
        return -jnp.mean(y_true * jnp.log(y_pred + 1e-7) + (1 - y_true) * jnp.log(1 - y_pred + 1e-7))

model = eg.Model(
    module=eCNN(),
    loss=[
        BCE(),
    ],
    # metrics=eg.metrics.MeanSquareError(),
    optimizer=optax.rmsprop(1e-3),
)

Hey @organic-chemistry! I think the use of Losses and Metrics within Elegy needs to be properly documented.

Elegy uses a simple name-based dependency injection system, meaning there is a fixed set of names you can use for the arguments of the Loss.call and Metric.update methods:

From elegy/elegy/model/model.py, lines 220 to 233 at commit 546c504:

extended_labels = {
    "inputs": inputs,
    "preds": preds,
    "model": model,
    "parameters": model.parameters(),
    "batch_stats": model.batch_stats(),
    "rngs": model.rngs(),
    "model_states": model.model_states(),
    "states": model.states(),
    "metric_logs": model.metric_logs(),
    "loss_logs": model.loss_logs(),
    "logs": model.logs(),
    **labels,
}

where labels usually contains the target key. The problem is that metrics.MeanSquareError and metrics.MeanAbsoluteError, which come from Treex, were recently changed to use the unsupported y_true and y_pred names (cgarciae/treex#55); this should be fixed soon.
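To make the mechanism concrete, here is a minimal sketch of name-based injection in plain Python. It is not Elegy's implementation: call_with_injection, GoodMetric and BadMetric are hypothetical names, and the available dict only loosely mirrors extended_labels. It shows why an update method that asks for y_true and y_pred fails with the same kind of TypeError reported in this issue, while one that asks for target and preds receives its arguments.

```python
import inspect


def call_with_injection(fn, available):
    """Hypothetical helper (not Elegy's code): call fn with only the
    entries of `available` whose names appear in fn's signature."""
    params = inspect.signature(fn).parameters
    # If fn accepts **kwargs, hand it everything
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**available)
    # Otherwise keep only the names fn actually declares
    kwargs = {name: value for name, value in available.items() if name in params}
    return fn(**kwargs)


class GoodMetric:
    def update(self, target, preds):  # names that are injected
        return sum((t - p) ** 2 for t, p in zip(target, preds)) / len(target)


class BadMetric:
    def update(self, y_true, y_pred):  # names that are never injected
        return 0.0


# Loosely mirrors extended_labels: "preds" plus the "target" key from labels
available = {"inputs": [0.0, 0.0], "preds": [0.5, 1.0], "target": [1.0, 1.0]}

result = call_with_injection(GoodMetric().update, available)
print(result)  # 0.125

try:
    call_with_injection(BadMetric().update, available)
except TypeError as err:
    error_message = str(err)
    print(error_message)  # ...missing 2 required positional arguments: 'y_true' and 'y_pred'
```

Because only parameters named in the signature get filled, renaming the arguments is enough to make the error go away.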

On the other hand, to fix your BCE loss just change y_true -> target and y_pred -> preds.
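Applied to the BCE loss from the issue, the rename looks roughly like this. To keep the sketch runnable without Elegy it uses NumPy and a plain class; in Elegy the class would subclass eg.Loss and use jax.numpy as in the original snippet. The target and preds values are made up for illustration.

```python
import numpy as np

EPS = 1e-7  # same epsilon as the original snippet, keeps log() away from 0


class BCE:  # stand-in; in Elegy this would be `class BCE(eg.Loss)`
    def call(self, target, preds):
        # Argument names match the injected keys "target" and "preds"
        # instead of the unsupported y_true / y_pred
        return -np.mean(
            target * np.log(preds + EPS) + (1 - target) * np.log(1 - preds + EPS)
        )


target = np.array([1.0, 0.0, 1.0, 0.0])
preds = np.array([0.9, 0.1, 0.8, 0.2])

# Pass by keyword, the way a name-based injector would
loss = BCE().call(target=target, preds=preds)
print(round(float(loss), 4))  # 0.1643
```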

BTW: unless it's for pedagogical reasons, you can use eg.losses.Crossentropy(binary=True) if you want binary cross entropy.

OK, thank you, it worked.
The example was indeed for pedagogical reasons.
I based it on an example from the docs here:
https://poets-ai.github.io/elegy/basic-api/modules-losses-metrics/ (the paragraph about losses).
Thank you,

@organic-chemistry thanks for the report! A recent refactor broke a lot of the documentation; I'll open an issue to remove the old links. Some of this now lives in Treex and should be documented there. Sorry for the confusion 😅

OK. Do you mean that you are going to stop working on Elegy, and that Treex is the 'new' version?
By the way, this is unrelated, but thank you for the libraries you develop and the articles you write;
I found the one on quantile regression really interesting.

OK. Do you mean that you are going to stop working on Elegy, and that Treex is the 'new' version?

Oh no, sorry for the confusion. Treex is a low level library that implements Modules, Losses and Metrics, while Elegy is a high-level API. Elegy existed before Treex, but once Treex was stable it made sense to refactor Elegy on top of Treex as it simplified the codebase.

By the way, this is unrelated, but thank you for the libraries you develop and the articles you write

Thanks ☺