replicate/replicate-python

Set API endpoint on Replicate python client? (e.g. for use with Cloudflare AI Gateway)

JLHasson opened this issue · 3 comments

Cloudflare announced AI Gateway today, and it supports Replicate. They show how to use the REST API, but I didn't find instructions here on how to change the API URL in the Python library. Is that possible?

E.g. https://developers.cloudflare.com/ai-gateway/get-started/connecting-applications/#replicate

Looks like you just set

os.environ["REPLICATE_API_BASE_URL"] = "https://gateway.ai.cloudflare.com/v1/.../.../replicate"

based on https://github.com/replicate/replicate-python/blob/main/replicate/client.py#L25

Actually, looks like this may not work?

code:

os.environ["REPLICATE_API_BASE_URL"] = "https://gateway.ai.cloudflare.com/v1/.../.../replicate"
replicate.run(
    "...",
    input={
        "image": url,
        "prompt": "prompt",
    },
)

error:

File ~/miniconda3/lib/python3.10/site-packages/replicate/client.py:138, in Client.run(self, model_version, **kwargs)
    134     raise ReplicateError(
    135         f"Invalid model_version: {model_version}. Expected format: owner/name:version"
    136     )
    137 model = self.models.get(m.group("model"))
--> 138 version = model.versions.get(m.group("version"))
    139 prediction = self.predictions.create(version=version, **kwargs)
    140 # Return an iterator of the output

File ~/miniconda3/lib/python3.10/site-packages/replicate/version.py:89, in VersionCollection.get(self, id)
     80 def get(self, id: str) -> Version:
     81     """
     82     Get a specific model version.
     83 
   (...)
     87         The model version.
     88     """
---> 89     resp = self._client._request(
     90         "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}"
     91     )
     92     return self.prepare_model(resp.json())

File ~/miniconda3/lib/python3.10/site-packages/replicate/client.py:80, in Client._request(self, method, path, **kwargs)
     78 if 400 <= resp.status_code < 600:
     79     try:
---> 80         raise ReplicateError(resp.json()["detail"])
     81     except (JSONDecodeError, KeyError):
     82         pass

ReplicateError: Not found.

Hi @JLHasson, following up on this:

The Replicate client will use the REPLICATE_BASE_URL environment variable, but only if it's set before the underlying HTTP client is lazily initialized:

base_url = (
    base_url or os.environ.get("REPLICATE_BASE_URL") or "https://api.replicate.com"
)
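To illustrate the timing issue, here is a toy sketch of that lazy-resolution pattern. It is not the library's actual code: `LazyClient`, `DEFAULT_BASE_URL`, and `demo` are hypothetical names made up for this example, and only the fallback expression mirrors the snippet above.

```python
import os

DEFAULT_BASE_URL = "https://api.replicate.com"


class LazyClient:
    """Hypothetical stand-in for a client that resolves its base URL on first use."""

    def __init__(self, base_url=None):
        self._explicit_base_url = base_url
        self._resolved_base_url = None  # filled in on first access, then cached

    @property
    def base_url(self):
        if self._resolved_base_url is None:
            # Same precedence as the snippet above: argument, then env var, then default.
            self._resolved_base_url = (
                self._explicit_base_url
                or os.environ.get("REPLICATE_BASE_URL")
                or DEFAULT_BASE_URL
            )
        return self._resolved_base_url


def demo():
    os.environ.pop("REPLICATE_BASE_URL", None)
    client = LazyClient()
    first = client.base_url  # resolved and cached before the env var exists
    os.environ["REPLICATE_BASE_URL"] = "https://example.com"
    second = client.base_url  # cached value is returned; the env var is ignored
    explicit = LazyClient(base_url="https://example.com").base_url
    os.environ.pop("REPLICATE_BASE_URL", None)  # leave the environment clean
    return first, second, explicit
```

In this sketch, the client that was used before the variable was set keeps the default URL, while the client given an explicit `base_url` does not depend on the environment at all.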

Setting an environment variable at runtime isn't guaranteed to work as intended. A better way would be to create a Client explicitly and use that instead of the default client used by replicate.run et al.

import os

from replicate import Client

replicate = Client(api_token=os.environ["REPLICATE_API_TOKEN"], base_url="https://example.com")

replicate.run(...)
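If you want to keep the environment variable as an optional override while still passing the URL explicitly, one possible approach is a small helper. This is only a sketch: `resolve_base_url` and `make_client` are hypothetical names that are not part of the replicate package, and the only library call assumed is the `Client(api_token=..., base_url=...)` constructor shown above.

```python
import os


def resolve_base_url(explicit=None, env=None):
    """Pick a base URL: explicit argument, then REPLICATE_BASE_URL, then the default."""
    env = os.environ if env is None else env
    return explicit or env.get("REPLICATE_BASE_URL") or "https://api.replicate.com"


def make_client(base_url=None):
    """Build a client whose base URL is fixed at construction time."""
    # Imported here so resolve_base_url can be used without the package installed.
    from replicate import Client

    return Client(
        api_token=os.environ["REPLICATE_API_TOKEN"],
        base_url=resolve_base_url(base_url),
    )
```

Because the URL is resolved when `make_client` is called and then handed to the constructor, later changes to the environment do not affect a client that already exists.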