kherud/java-llama.cpp

How to reset the context on every turn?

Opened this issue · 1 comment

w.r.t. this example:

while (true) {
    prompt += "\nUser: ";
    System.out.print("\nUser: ");
    String input = reader.readLine();
    prompt += input;
    System.out.print("Llama: ");
    prompt += "\nLlama: ";
    InferenceParameters inferParams = new InferenceParameters(prompt)
            .setTemperature(0.7f)
            .setPenalizeNl(true)
            .setMiroStat(MiroStat.V2)
            .setStopStrings("User:");
    for (LlamaOutput output : model.generate(inferParams)) {
        System.out.print(output);
        prompt += output;
    }
}

How can I reset the context window on every iteration of the while loop? The program becomes very slow as the context fills up, and once the context window is full, the LLM starts spewing garbage. What is the fix for this? Thanks.

In the above code you are appending to prompt, which only ever grows. You can instead keep a queue of the last n messages, so that once the history grows beyond 10 messages, the oldest one is dropped.

Something like this:

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.LinkedList;
import java.util.Queue;

import de.kherud.llama.InferenceParameters;
import de.kherud.llama.LlamaModel;
import de.kherud.llama.LlamaOutput;
import de.kherud.llama.ModelParameters;
import de.kherud.llama.args.MiroStat;

public class ChatExample {

    private static final int MAX_HISTORY = 10; // Keep only the last 10 messages

    public static void main(String[] args) throws Exception {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8));
        Queue<String> conversationHistory = new LinkedList<>(); // FIFO queue of messages
        ModelParameters modelParams = new ModelParameters()
                .setModelFilePath("/path/to/model.gguf"); // adjust to your model
        try (LlamaModel model = new LlamaModel(modelParams)) {
            while (true) {
                System.out.print("\nUser: ");
                String input = reader.readLine();

                // Store user input in the queue
                addMessageToQueue(conversationHistory, "User: " + input);

                // Rebuild the prompt from the bounded conversation history
                StringBuilder promptBuilder = new StringBuilder();
                for (String message : conversationHistory) {
                    promptBuilder.append(message).append("\n");
                }

                String prompt = promptBuilder.toString();
                System.out.print("Llama: ");

                // Generate response from Llama
                InferenceParameters inferParams = new InferenceParameters(prompt)
                        .setTemperature(0.7f)
                        .setPenalizeNl(true)
                        .setMiroStat(MiroStat.V2)
                        .setStopStrings("User:");

                StringBuilder llamaResponse = new StringBuilder();
                for (LlamaOutput output : model.generate(inferParams)) {
                    System.out.print(output);
                    llamaResponse.append(output);
                }

                // Store the model response in the queue
                addMessageToQueue(conversationHistory, "Llama: " + llamaResponse.toString());
            }
        }
    }

    // Helper method to cap the queue size
    private static void addMessageToQueue(Queue<String> queue, String message) {
        if (queue.size() >= MAX_HISTORY) {
            queue.poll(); // Remove the oldest message
        }
        queue.add(message); // Add the new message
    }
}
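
Note that capping the history at a fixed number of messages only bounds the prompt indirectly, since messages vary in length. If you want a hard bound in tokens instead, you can trim the queue until the assembled prompt fits a budget. Here is a minimal sketch, assuming your java-llama.cpp version exposes LlamaModel.encode(String) for tokenization (the method name may differ between versions); maxPromptTokens is a hypothetical budget you choose:

    // Drop oldest messages until the assembled prompt fits the token budget.
    // Assumes model.encode(String) returns the prompt's tokens as an int[].
    private static String buildBoundedPrompt(LlamaModel model, Queue<String> history, int maxPromptTokens) {
        while (true) {
            StringBuilder promptBuilder = new StringBuilder();
            for (String message : history) {
                promptBuilder.append(message).append("\n");
            }
            String prompt = promptBuilder.toString();
            if (model.encode(prompt).length <= maxPromptTokens || history.size() <= 1) {
                return prompt; // fits the budget, or nothing left to drop
            }
            history.poll(); // remove the oldest message and retry
        }
    }

Sizing maxPromptTokens a few hundred tokens below the context size you configure on ModelParameters leaves room for the generated reply, which keeps the model from ever hitting the full context window.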