replicate/replicate-python

max_length

syeminpark opened this issue · 2 comments

Hello, I am trying to use LangChain with replicate-python.
https://github.com/replicate/replicate-python

However, I am confused about how to modify max_new_tokens for the LLM. To be specific, here is a small part of my code:

# main.py
llm = Replicate(
    model="joehoover/falcon-40b-instruct:xxxxxxxx",
    model_kwargs={"max_length": 1000},
    input={"max_length": 1000},
)

I put max_length everywhere, and it still isn't reflected.
According to the docstring in
https://github.com/hwchase17/langchain/blob/master/langchain/llms/replicate.py
you just need to add the following:

  from langchain.llms import Replicate
  replicate = Replicate(
      model=(
          "stability-ai/stable-diffusion:"
          "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
      ),
      input={"image_dimensions": "512x512"},
  )

However, this method is both outdated and not working.

Here is the rest of my code. It is nearly identical to the implementation at
https://github.com/hwchase17/langchain/blob/master/langchain/llms/replicate.py:

# replicate.py

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to replicate endpoint."""
        try:
            import replicate as replicate_python
        except ImportError:
            raise ImportError(
                "Could not import replicate python package. "
                "Please install it with `pip install replicate`."
            )

        # Get the model and version.
        model_str, version_str = self.model.split(":")
        model = replicate_python.models.get(model_str)
        version = model.versions.get(version_str)

        # Sort through the OpenAPI schema to find the name of the first input,
        # which is the property the prompt gets bound to.
        input_properties = sorted(
            version.openapi_schema["components"]["schemas"]["Input"][
                "properties"
            ].items(),
            key=lambda item: item[1].get("x-order", 0),
        )
        first_input_name = input_properties[0][0]
        print("first input", first_input_name)

        inputs = {first_input_name: prompt, **self.input}

        # Extra parameters go in via `input`.
        prediction = replicate_python.predictions.create(
            version, input={**inputs, **kwargs}
        )
        print("kwargs", kwargs)
        print("status", prediction.status)

        # Busy-poll until the prediction finishes
        # (this will spin forever if the prediction fails).
        while prediction.status != "succeeded":
            prediction.reload()

        print("end")
        iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
        print("".join(output for output in iterator))
        return "".join(prediction.output)
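
While debugging, it is useful to dump the model's whole input schema instead of just the first property, to confirm whether the model even accepts max_length or max_new_tokens and what the defaults are. A minimal sketch using the same client calls as above (the version id is a placeholder):

    import replicate

    model = replicate.models.get("joehoover/falcon-40b-instruct")
    version = model.versions.get("xxxxxxxx")  # placeholder version id
    properties = version.openapi_schema["components"]["schemas"]["Input"]["properties"]
    for name, spec in properties.items():
        # Print each accepted input alongside its default value.
        print(name, spec.get("default"))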

The reason I want to change max_length or max_new_tokens is that I am providing the LLM on Replicate with a lot of context, e.g. in the ConversationalRetrievalChain workflow.

However, max_length gives me truncated responses, because my chunk_size values are equal to or larger than the default max_length, which is 500.
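
To make the truncation concrete, here is the rough arithmetic (illustrative numbers, assuming max_length is a combined budget for prompt plus completion, per finding 1 below):

    max_length = 500      # the model's default
    prompt_tokens = 500   # a single retrieved chunk at chunk_size 500
    output_budget = max_length - prompt_tokens  # 0 tokens left for the answer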

  1. max_length seems to include both the given input and the output.

  2. In main.py,

         model="joehoover/falcon-40b-instruct:xxxxxx",
         input={"max_length": 1500, "max_new_tokens": 1500})

     is the right solution, rather than

         model="joehoover/falcon-40b-instruct:xxxxxxx",
         model_kwargs={"max_length": 1500, "max_new_tokens": 1500},

  3. In replicate.py, since it is the input parameter that matters,

         prediction = replicate_python.predictions.create(version, input={**inputs, **kwargs})
         while prediction.status != "succeeded":
             prediction.reload()
         return "".join(prediction.output)

     works, and so does

         iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
         return "".join(iterator)

Don't be fooled into thinking the bare **kwargs does anything on its own; the values only take effect once they are merged into input.
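
Putting the two findings together, a minimal working setup would look like the following (a sketch based on the behaviour described above; the version hash is a placeholder, and the parameters must actually exist in the model's input schema):

    from langchain.llms import Replicate

    llm = Replicate(
        model="joehoover/falcon-40b-instruct:xxxxxxxx",
        # input, not model_kwargs, is what gets merged into the prediction's inputs
        input={"max_length": 1500, "max_new_tokens": 1500},
    )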

mattt commented

Hi, @syeminpark. If you have questions about Replicate's integration with LangChain, you should check out their Discord. If you think the max_length parameter isn't working as expected, then please open a new issue on that model's repo. Beyond that, there's not much I can do to help from the context of this project, which is scoped to direct interactions with Replicate's API.