quarkiverse/quarkus-langchain4j

Cannot get the service to return anything but a `String`.

FroMage opened this issue · 14 comments

The docs at https://docs.quarkiverse.io/quarkus-langchain4j/dev/ai-services.html#_ai_method_return_type say I should be able to get a structured return type from my service, but it does not seem to support anything but a String.

package util;

import java.util.List;

import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService(
        retrievalAugmentor = ScheduleDocumentRetreiver.class
)
public interface ScheduleAI {

    @SystemMessage("You are a computer science conference organiser")
    @UserMessage("""
            I want to find the talks from the conference program that match my interests and constraints.
            Give me the list of talks that match my interests and constraints: {topics}
            """)
    List<AITalk> findTalks(String topics);

    public static class AITalk {
        public String title;
        public String id;
    }
}
package util;

import java.util.function.Supplier;

import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import io.quarkiverse.langchain4j.pgvector.PgVectorEmbeddingStore;
import jakarta.inject.Singleton;

@Singleton
public class ScheduleDocumentRetreiver implements Supplier<RetrievalAugmentor> {

    private final RetrievalAugmentor augmentor;

    ScheduleDocumentRetreiver(PgVectorEmbeddingStore store, EmbeddingModel model) {
        EmbeddingStoreContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingModel(model)
                .embeddingStore(store)
                .maxResults(3)
                .build();
        augmentor = DefaultRetrievalAugmentor
                .builder()
                .contentRetriever(contentRetriever)
                .build();
    }

    @Override
    public RetrievalAugmentor get() {
        return augmentor;
    }

}
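package rest;

import io.quarkus.qute.CheckedTemplate;
import io.quarkus.qute.TemplateInstance;
import io.smallrye.common.annotation.Blocking;
import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import org.jboss.resteasy.reactive.RestQuery;
import util.ScheduleAI;
import util.ScheduleAI.AITalk;

@Path("/")
@Blocking
public class Application {

    @CheckedTemplate
    public static class Templates {
        public static native TemplateInstance ai(String results);
    }

    @Inject
    ScheduleAI ai;

    @GET
    @Path("/ai")
    public TemplateInstance ai(@RestQuery String topics) {
        StringBuilder results = new StringBuilder();
        if (topics != null && !topics.isBlank()) {
            results.append("Results: ");
            for (AITalk aiTalk : ai.findTalks(topics)) { // ClassCastException is thrown here
                results.append(aiTalk.id);
                results.append(", ");
            }
        }
        return Templates.ai(results.toString());
    }
}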

This generates the following exception:

2024-06-13 12:07:03,893 ERROR [io.qua.ver.htt.run.QuarkusErrorHandler] (executor-thread-1) HTTP Request to /ai?csrf-token=dQpPok314ThUaiceewld7g&topics=I+want+to+learn+about+hibernate%2C+and+web+applications failed, error id: a57d74f1-f0a0-4176-a120-25be5321f0d9-1: java.lang.ClassCastException: class java.lang.String cannot be cast to class util.ScheduleAI$AITalk (java.lang.String is in module java.base of loader 'bootstrap'; util.ScheduleAI$AITalk is in unnamed module of loader io.quarkus.bootstrap.classloading.QuarkusClassLoader @49cb1baf)
	at rest.Application.ai(Application.java:518)
	at rest.Application_ClientProxy.ai(Unknown Source)
	at rest.Application$quarkusrestinvoker$ai_c38fa7a6cc502cf9089d9e3dafbbfabcc91df15a.invoke(Unknown Source)
	at org.jboss.resteasy.reactive.server.handlers.InvocationHandler.handle(InvocationHandler.java:29)
	at io.quarkus.resteasy.reactive.server.runtime.QuarkusResteasyReactiveRequestContext.invokeHandler(QuarkusResteasyReactiveRequestContext.java:141)
	at org.jboss.resteasy.reactive.common.core.AbstractResteasyReactiveContext.run(AbstractResteasyReactiveContext.java:147)
	at io.quarkus.vertx.core.runtime.VertxCoreRecorder$14.runWith(VertxCoreRecorder.java:599)
	at org.jboss.threads.EnhancedQueueExecutor$Task.doRunWith(EnhancedQueueExecutor.java:2516)
	at org.jboss.threads.EnhancedQueueExecutor$Task.run(EnhancedQueueExecutor.java:2495)
	at org.jboss.threads.EnhancedQueueExecutor$ThreadBody.run(EnhancedQueueExecutor.java:1521)
	at org.jboss.threads.DelegatingRunnable.run(DelegatingRunnable.java:11)
	at org.jboss.threads.ThreadLocalResettingRunnable.run(ThreadLocalResettingRunnable.java:11)
	at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
	at java.base/java.lang.Thread.run(Thread.java:1583)

So it looks like the generated service interface has String as a method return type.

Also, @geoand told me:

First of all, try returning a domain object instead of a string.
That will result in the prompt being augmented with the proper instructions on how to return the result.

But the docs say:

If the prompt defines the JSON response format precisely, you can map the response directly to an object:

So that's contradictory.

If you can attach a sample application that we can use to run and debug things, that would be most helpful

Thanks

I assume I need some data as well?

Nope

I see what the problem is.

dev.langchain4j.service.ServiceOutputParser from upstream LangChain4j does not really handle collections all that well.
Let me see if we can do better on our side
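
For anyone reading along, here is a minimal, self-contained sketch of how a `ClassCastException` like the one above can arise. It assumes (this is not confirmed from the upstream source) that the parser only looks at the raw `List` return type and fills the list with strings. `ErasureDemo`, `parseIgnoringElementType` and `iterationFails` are hypothetical names used only for illustration; nothing here calls LangChain4j.

```java
import java.util.ArrayList;
import java.util.List;

class ErasureDemo {

    static class AITalk {
        public String title;
        public String id;
    }

    // Hypothetical parser: it only knows the raw type List, so it stores
    // each line of the model output as a String, not as an AITalk.
    @SuppressWarnings({ "unchecked", "rawtypes" })
    static List<AITalk> parseIgnoringElementType(String modelOutput) {
        List raw = new ArrayList();
        for (String line : modelOutput.split("\n")) {
            raw.add(line);
        }
        // Compiles with an unchecked warning: generics are erased at runtime,
        // so nothing verifies the element type at this point.
        return raw;
    }

    // Returns true if reading the elements as AITalk fails.
    static boolean iterationFails(List<AITalk> talks) {
        try {
            for (AITalk talk : talks) { // the compiler inserts a cast to AITalk here
                String ignored = talk.id;
            }
            return false;
        } catch (ClassCastException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        List<AITalk> talks = parseIgnoringElementType("talk-1\ntalk-2");
        System.out.println("size = " + talks.size());
        System.out.println("iteration fails = " + iterationFails(talks));
    }
}
```

The call itself succeeds; the failure only surfaces at the first element access, which matches the stack trace pointing at the `for` loop in `Application.ai` rather than at the service call.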

I tried with arrays before and got a funky error as well.

Yeah, they don't work great

Seems like in upstream the methods I need are private, so I'll have to open a PR to open them.

At some point we probably want to write our own JsonSchema generator from a Jandex type (should be fairly easy), but I don't want to do it now since I have a talk next week.

Maybe I can bait you into doing that? :P

The upstream code that tries to do this is here.
What we need is something that takes an org.jboss.jandex.Type and returns a String containing the JSON schema.
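
To make the shape of that function concrete, here is a rough sketch of the idea. It is an illustration under stated assumptions, not the upstream or Quarkus implementation: it uses `java.lang.reflect.Type` as a stand-in for `org.jboss.jandex.Type` so that it runs with the JDK alone, it handles only a handful of types, and it does not guard against self-referential classes. `SchemaSketch`, `schemaFor` and `findTalksReturnType` are hypothetical names.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Collection;
import java.util.List;
import java.util.StringJoiner;

class SchemaSketch {

    static class AITalk {
        public String title;
        public String id;
    }

    // Stand-in for the AI service interface from the issue (annotations omitted).
    interface ScheduleAI {
        List<AITalk> findTalks(String topics);
    }

    // Maps a Java type to a JSON schema fragment. Recursive types are not handled.
    static String schemaFor(Type type) {
        if (type instanceof ParameterizedType) {
            ParameterizedType p = (ParameterizedType) type;
            Type raw = p.getRawType();
            if (raw instanceof Class && Collection.class.isAssignableFrom((Class<?>) raw)) {
                return array(p.getActualTypeArguments()[0]); // e.g. List<AITalk>
            }
        }
        if (type instanceof Class) {
            Class<?> c = (Class<?>) type;
            if (c.isArray()) return array(c.getComponentType());
            if (c == String.class || c == char.class || c == Character.class) return "{\"type\":\"string\"}";
            if (c == boolean.class || c == Boolean.class) return "{\"type\":\"boolean\"}";
            if (c == int.class || c == Integer.class || c == long.class || c == Long.class) return "{\"type\":\"integer\"}";
            if (c.isPrimitive() || Number.class.isAssignableFrom(c)) return "{\"type\":\"number\"}";
            // Anything else is treated as an object described by its instance fields.
            StringJoiner props = new StringJoiner(",", "{\"type\":\"object\",\"properties\":{", "}}");
            for (Field f : c.getDeclaredFields()) {
                if (Modifier.isStatic(f.getModifiers()) || f.isSynthetic()) continue;
                props.add("\"" + f.getName() + "\":" + schemaFor(f.getGenericType()));
            }
            return props.toString();
        }
        throw new IllegalArgumentException("Unsupported type: " + type);
    }

    private static String array(Type elementType) {
        return "{\"type\":\"array\",\"items\":" + schemaFor(elementType) + "}";
    }

    // Reads the generic return type List<AITalk> from the service method.
    static Type findTalksReturnType() {
        try {
            return ScheduleAI.class.getMethod("findTalks", String.class).getGenericReturnType();
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(schemaFor(findTalksReturnType()));
    }
}
```

A Jandex-based version would presumably follow the same recursion (parameterized type, array, primitive, class), with the difference that Jandex resolves fields from the build-time index instead of loading classes.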

I wish I had time to do something like that, but I really can't promise anything soon. I also filed #675 which I hope I might be able to work on in the future.

🙏🏼

It would be helpful if the new parser could be customized using the quarkus.langchain4j.response-schema property (today this is a boolean, but can become an object with different values).

For example, today the {response_schema} placeholder is replaced by the default value:

You must respond strictly in the following JSON format: <object_schema>

With the custom parser, we could have control of the prefix, which could be enabled or disabled, with a property like

quarkus.langchain4j.response-schema.prefix={true|false}

This allows the developer to remove the prefix and add something else, for example if they are writing the prompt in a language other than English.
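
A small sketch of the behaviour being proposed, to show what the switch would change. The `prefix` option does not exist today, so everything here is hypothetical: `ResponseSchemaPlaceholder` and `render` are illustrative names, and the `prefixEnabled` parameter stands in for the suggested property.

```java
class ResponseSchemaPlaceholder {

    // Default text quoted in the comment above.
    static final String DEFAULT_PREFIX = "You must respond strictly in the following JSON format: ";

    // Replaces {response_schema} in the prompt, with or without the default prefix.
    // prefixEnabled models the proposed (not yet existing) prefix option.
    static String render(String promptTemplate, String schema, boolean prefixEnabled) {
        String replacement = prefixEnabled ? DEFAULT_PREFIX + schema : schema;
        return promptTemplate.replace("{response_schema}", replacement);
    }

    public static void main(String[] args) {
        String schema = "{\"title\": \"string\", \"id\": \"string\"}";
        // With the prefix: equivalent to the behaviour described above.
        System.out.println(render("{response_schema}", schema, true));
        // Without the prefix: the developer supplies the instruction, here in Italian.
        System.out.println(render("Rispondi esclusivamente nel seguente formato JSON: {response_schema}", schema, false));
    }
}
```

With the prefix disabled, only the schema itself is injected, so the instruction around it stays fully under the developer's control.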

That's a good point.

We just need to do this at some point: I don't envision it to be more than a day's work (likely less) :)