replicate/replicate-python

Feature request: access to logs in the python API

johny-b opened this issue · 4 comments

Right now the `run` function returns only `prediction.output`:

```python
def run(
    client: "Client",
    ref: Union["Model", "Version", "ModelVersionIdentifier", str],
    input: Optional[Dict[str, Any]] = None,
    **params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
    """
    Run a model and wait for its output.
    """
    version, owner, name, version_id = identifier._resolve(ref)

    if version_id is not None:
        prediction = client.predictions.create(
            version=version_id, input=input or {}, **params
        )
    elif owner and name:
        prediction = client.models.predictions.create(
            model=(owner, name), input=input or {}, **params
        )
    else:
        raise ValueError(
            f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
        )

    if not version and (owner and name and version_id):
        version = Versions(client, model=(owner, name)).get(version_id)

    if version and (iterator := _make_output_iterator(version, prediction)):
        return iterator

    prediction.wait()

    if prediction.status == "failed":
        raise ModelError(prediction.error)

    return prediction.output
```

I would also like to access `prediction.logs`. One trivial solution would be something like a `create_prediction` function that contains all the logic from `run` except for waiting and returning the output; `run` would then call `create_prediction` internally.

Does that make sense? If yes, would you be open to a PR that adds this?
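For concreteness, a minimal sketch of the split I have in mind. All names here are hypothetical, and the `Prediction` class is just a stand-in for the real one so the example is self-contained:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Prediction:
    """Stand-in for replicate's Prediction object."""
    output: Any = None
    logs: str = ""
    status: str = "starting"

    def wait(self) -> None:
        # The real client polls the API until the prediction finishes.
        self.status = "succeeded"
        self.output = "output-url"
        self.logs = "iteration: 0, render:loss: -0.6171875"


def create_prediction(ref: str, input: Optional[Dict[str, Any]] = None) -> Prediction:
    # Would hold all of run()'s resolve/create logic, minus the waiting.
    return Prediction()


def run(ref: str, input: Optional[Dict[str, Any]] = None) -> Any:
    # run() keeps its current contract by delegating to create_prediction().
    prediction = create_prediction(ref, input)
    prediction.wait()
    if prediction.status == "failed":
        raise RuntimeError(prediction.logs)
    return prediction.output


# Callers who want logs (or metrics) use create_prediction() directly:
prediction = create_prediction("owner/name:version", {"prompt": "hello"})
prediction.wait()
print(prediction.logs)
```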

I also want to get more information, like metrics, so maybe just returning the whole prediction response is a good choice.

I got it: use `replicate.predictions.create` instead of `replicate.run`.

Yeah, I can do this, but is this supposed to be a stable/public interface?

Hi @johny-b. You can access a prediction's logs through its `logs` property. From the README:

```python
>>> prediction = replicate.predictions.create(
...     version="(your model version)",
...     input={"prompt": "Watercolor painting of an underwater submarine"})

>>> prediction
Prediction(...)

>>> prediction.wait()

>>> prediction.status
'succeeded'

>>> print(prediction.logs)
iteration: 0, render:loss: -0.6171875
iteration: 10, render:loss: -0.92236328125
iteration: 20, render:loss: -1.197265625
iteration: 30, render:loss: -1.3994140625

>>> prediction.output
'https://.../output.png'
```
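If you want to follow logs while the prediction is still in flight, you can refresh it in a loop. A rough sketch of the pattern, using a stand-in class instead of a live API call; `reload()` is assumed to re-fetch the prediction from the API, as in recent client versions:

```python
class FakePrediction:
    """Stand-in for replicate's Prediction object, so the polling
    pattern can be shown without a live API call."""

    def __init__(self) -> None:
        self._ticks = 0
        self.status = "processing"
        self.logs = ""

    def reload(self) -> None:
        # In the real client, reloading refreshes status, logs,
        # output, and metrics from the API.
        self._ticks += 1
        self.logs += f"iteration: {10 * (self._ticks - 1)}, render:loss: ...\n"
        if self._ticks >= 3:
            self.status = "succeeded"


prediction = FakePrediction()
seen = 0
while prediction.status not in ("succeeded", "failed", "canceled"):
    prediction.reload()
    print(prediction.logs[seen:], end="")  # print only the new log output
    seen = len(prediction.logs)
    # a real loop should also sleep between polls
```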