From 9dadc7ebf4d0de077adc05f81cf6fc32425e9b96 Mon Sep 17 00:00:00 2001 From: Enrico Olivelli Date: Tue, 31 Oct 2023 13:48:48 +0100 Subject: [PATCH] [agents] Add Agent to call LangServe (langserve-invoke) (#673) --- .../applications/langserve-invoke/README.md | 68 +++ .../applications/langserve-invoke/example.py | 41 ++ .../langserve-invoke/gateways.yaml | 25 + .../langserve-invoke/pipeline.yaml | 41 ++ .../langserve-service}/.langstreamignore | 0 .../langserve-service}/README.md | 4 +- .../langserve-service}/gateways.yaml | 0 .../langserve-service}/pipeline.yaml | 0 .../langserve-service}/python/example.py | 0 .../{source => }/AzureBlobStorageSource.java | 2 +- .../AzureBlobStorageSourceCodeProvider.java | 2 +- ...ngstream.api.runner.code.AgentCodeProvider | 2 +- .../AzureBlobStorageSourceTest.java | 2 +- .../langstream-agent-http-request/pom.xml | 5 + .../source => http}/HttpRequestAgent.java | 8 +- .../HttpRequestAgentProvider.java | 8 +- .../agents/http/LangServeInvokeAgent.java | 477 ++++++++++++++++++ .../META-INF/ai.langstream.agents.index | 3 +- ...ngstream.api.runner.code.AgentCodeProvider | 2 +- .../agents/http/LangServeInvokeAgentTest.java | 295 +++++++++++ .../src/test/resources/logback-test.xml | 34 ++ .../ai/agents/GenAIToolKitAgent.java | 8 +- .../k8s/agents/HttpRequestAgentProvider.java | 186 ++++++- .../kafka/runner/KafkaProducerWrapper.java | 4 +- .../kafka/runner/KafkaReaderWrapper.java | 4 +- .../kafka/LangServeInvokeAgentRunnerIT.java | 206 ++++++++ .../src/main/assemble/logback.xml | 1 + pom.xml | 2 +- 28 files changed, 1410 insertions(+), 20 deletions(-) create mode 100644 examples/applications/langserve-invoke/README.md create mode 100644 examples/applications/langserve-invoke/example.py create mode 100644 examples/applications/langserve-invoke/gateways.yaml create mode 100644 examples/applications/langserve-invoke/pipeline.yaml rename examples/applications/{langserve => python/langserve-service}/.langstreamignore (100%) rename 
examples/applications/{langserve => python/langserve-service}/README.md (80%) rename examples/applications/{langserve => python/langserve-service}/gateways.yaml (100%) rename examples/applications/{langserve => python/langserve-service}/pipeline.yaml (100%) rename examples/applications/{langserve => python/langserve-service}/python/example.py (100%) rename langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/{source => }/AzureBlobStorageSource.java (99%) rename langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/{source => }/AzureBlobStorageSourceCodeProvider.java (95%) rename langstream-agents/langstream-agent-azure-blob-storage-source/src/test/java/ai/langstream/agents/azureblobstorage/{source => }/AzureBlobStorageSourceTest.java (97%) rename langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/{azureblobstorage/source => http}/HttpRequestAgent.java (96%) rename langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/{azureblobstorage/source => http}/HttpRequestAgentProvider.java (84%) create mode 100644 langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/LangServeInvokeAgent.java create mode 100644 langstream-agents/langstream-agent-http-request/src/test/java/ai/langstream/agents/http/LangServeInvokeAgentTest.java create mode 100644 langstream-agents/langstream-agent-http-request/src/test/resources/logback-test.xml create mode 100644 langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/LangServeInvokeAgentRunnerIT.java diff --git a/examples/applications/langserve-invoke/README.md b/examples/applications/langserve-invoke/README.md new file mode 100644 index 000000000..80ff126c2 --- /dev/null +++ b/examples/applications/langserve-invoke/README.md @@ -0,0 +1,68 @@ +# Invoking a LangServe service + +This sample application explains 
how to invoke a LangServe service and leverage streaming capabilities. + +## Set up your LangServe environment + +Start your LangServe application. The example below uses the LangServe sample [application](https://github.com/langchain-ai/langserve) + +```python +#!/usr/bin/env python +from fastapi import FastAPI +from langchain.prompts import ChatPromptTemplate +from langchain.chat_models import ChatAnthropic, ChatOpenAI +from langserve import add_routes + + +app = FastAPI( + title="LangChain Server", + version="1.0", + description="A simple api server using Langchain's Runnable interfaces", +) + +model = ChatOpenAI() +prompt = ChatPromptTemplate.from_template("tell me a joke about {topic}") +add_routes( + app, + prompt | model, + path="/chain", +) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="localhost", port=8000) +``` + +## Configure your OpenAI API Key and run the application + +```bash +export OPENAI_API_KEY=... +pip install fastapi langserve langchain openai sse_starlette uvicorn +python example.py +``` + +The sample application is exposing a chain at http://localhost:8000/chain/stream and http://localhost:8000/chain/invoke. + +The application, running in docker, connects to http://host.docker.internal:8000/chain/stream + +LangStream sends an input like this: + +```json +{ + "input": { + "topic": "cats" + } +} +``` + +Here "topic" is the topic of the joke you want to generate; it is taken from the user input. 
+ +## Deploy the LangStream application +``` +./bin/langstream docker run test -app examples/applications/langserve-invoke +``` + +## Interact with the application + +You can now interact with the application using the UI opening your browser at http://localhost:8092/ diff --git a/examples/applications/langserve-invoke/example.py b/examples/applications/langserve-invoke/example.py new file mode 100644 index 000000000..859382c75 --- /dev/null +++ b/examples/applications/langserve-invoke/example.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from fastapi import FastAPI +from langchain.prompts import ChatPromptTemplate +from langchain.chat_models import ChatOpenAI +from langserve import add_routes + + +app = FastAPI( + title="LangChain Server", + version="1.0", + description="A simple api server using Langchain's Runnable interfaces", +) + +model = ChatOpenAI() +prompt = ChatPromptTemplate.from_template("tell me a joke about {topic}") +add_routes( + app, + prompt | model, + path="/chain", + ) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="localhost", port=8000) \ No newline at end of file diff --git a/examples/applications/langserve-invoke/gateways.yaml b/examples/applications/langserve-invoke/gateways.yaml new file mode 100644 index 000000000..902f7f24c --- /dev/null +++ b/examples/applications/langserve-invoke/gateways.yaml @@ -0,0 +1,25 @@ +# +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +gateways: + - id: chat + type: chat + chat-options: + answers-topic: streaming-answers-topic + questions-topic: input-topic + headers: + - value-from-parameters: session-id \ No newline at end of file diff --git a/examples/applications/langserve-invoke/pipeline.yaml b/examples/applications/langserve-invoke/pipeline.yaml new file mode 100644 index 000000000..d8fe2051d --- /dev/null +++ b/examples/applications/langserve-invoke/pipeline.yaml @@ -0,0 +1,41 @@ +# +# Copyright DataStax, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +topics: + - name: "input-topic" + creation-mode: create-if-not-exists + - name: "output-topic" + creation-mode: create-if-not-exists + - name: "streaming-answers-topic" + creation-mode: create-if-not-exists +pipeline: + - type: "langserve-invoke" + input: input-topic + output: output-topic + id: step1 + configuration: + output-field: value.answer + stream-to-topic: streaming-answers-topic + stream-response-field: value + min-chunks-per-message: 10 + debug: false + method: POST + allow-redirects: true + handle-cookies: false + url: "http://host.docker.internal:8000/chain/stream" + fields: + - name: topic + expression: "value" diff --git a/examples/applications/langserve/.langstreamignore b/examples/applications/python/langserve-service/.langstreamignore similarity index 100% rename from examples/applications/langserve/.langstreamignore rename to examples/applications/python/langserve-service/.langstreamignore diff --git a/examples/applications/langserve/README.md b/examples/applications/python/langserve-service/README.md similarity index 80% rename from examples/applications/langserve/README.md rename to examples/applications/python/langserve-service/README.md index cd60733ca..e1e9101c3 100644 --- a/examples/applications/langserve/README.md +++ b/examples/applications/python/langserve-service/README.md @@ -10,7 +10,7 @@ Export to the ENV the access key to OpenAI export OPEN_AI_ACCESS_KEY=... 
``` -The default [secrets file](../../secrets/secrets.yaml) reads from the ENV. Check out the file to learn more about +The default [secrets file](../../../secrets/secrets.yaml) reads from the ENV. Check out the file to learn more about the default settings, you can change them by exporting other ENV variables. @@ -26,7 +26,7 @@ export LANGSMITH_APIKEY=xxxxx ## Deploy the LangStream application ``` -./bin/langstream docker run test -app examples/applications/langserve -s examples/secrets/secrets.yaml --start-broker=false +./bin/langstream docker run test -app examples/applications/python/langserve-service -s examples/secrets/secrets.yaml --start-broker=false ``` ## Interact with the application diff --git a/examples/applications/langserve/gateways.yaml b/examples/applications/python/langserve-service/gateways.yaml similarity index 100% rename from examples/applications/langserve/gateways.yaml rename to examples/applications/python/langserve-service/gateways.yaml diff --git a/examples/applications/langserve/pipeline.yaml b/examples/applications/python/langserve-service/pipeline.yaml similarity index 100% rename from examples/applications/langserve/pipeline.yaml rename to examples/applications/python/langserve-service/pipeline.yaml diff --git a/examples/applications/langserve/python/example.py b/examples/applications/python/langserve-service/python/example.py similarity index 100% rename from examples/applications/langserve/python/example.py rename to examples/applications/python/langserve-service/python/example.py diff --git a/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSource.java b/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSource.java similarity index 99% rename from 
langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSource.java rename to langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSource.java index 87087a00d..c4d26511a 100644 --- a/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSource.java +++ b/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSource.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.agents.azureblobstorage.source; +package ai.langstream.agents.azureblobstorage; import ai.langstream.api.runner.code.AbstractAgentCode; import ai.langstream.api.runner.code.AgentSource; diff --git a/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSourceCodeProvider.java b/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSourceCodeProvider.java similarity index 95% rename from langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSourceCodeProvider.java rename to langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSourceCodeProvider.java index 36182aa9f..2ac425eee 100644 --- a/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSourceCodeProvider.java +++ 
b/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSourceCodeProvider.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.agents.azureblobstorage.source; +package ai.langstream.agents.azureblobstorage; import ai.langstream.api.runner.code.AgentCode; import ai.langstream.api.runner.code.AgentCodeProvider; diff --git a/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider b/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider index 4f4a7838b..6ee309d73 100644 --- a/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider +++ b/langstream-agents/langstream-agent-azure-blob-storage-source/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider @@ -1 +1 @@ -ai.langstream.agents.azureblobstorage.source.AzureBlobStorageSourceCodeProvider \ No newline at end of file +ai.langstream.agents.azureblobstorage.AzureBlobStorageSourceCodeProvider \ No newline at end of file diff --git a/langstream-agents/langstream-agent-azure-blob-storage-source/src/test/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSourceTest.java b/langstream-agents/langstream-agent-azure-blob-storage-source/src/test/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSourceTest.java similarity index 97% rename from langstream-agents/langstream-agent-azure-blob-storage-source/src/test/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSourceTest.java rename to 
langstream-agents/langstream-agent-azure-blob-storage-source/src/test/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSourceTest.java index 70c88c1b5..cf83d214a 100644 --- a/langstream-agents/langstream-agent-azure-blob-storage-source/src/test/java/ai/langstream/agents/azureblobstorage/source/AzureBlobStorageSourceTest.java +++ b/langstream-agents/langstream-agent-azure-blob-storage-source/src/test/java/ai/langstream/agents/azureblobstorage/AzureBlobStorageSourceTest.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.agents.azureblobstorage.source; +package ai.langstream.agents.azureblobstorage; import static org.junit.jupiter.api.Assertions.*; diff --git a/langstream-agents/langstream-agent-http-request/pom.xml b/langstream-agents/langstream-agent-http-request/pom.xml index a03ef4310..43641e87d 100644 --- a/langstream-agents/langstream-agent-http-request/pom.xml +++ b/langstream-agents/langstream-agent-http-request/pom.xml @@ -65,6 +65,11 @@ junit-jupiter test + + com.github.tomakehurst + wiremock + test + diff --git a/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/azureblobstorage/source/HttpRequestAgent.java b/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/HttpRequestAgent.java similarity index 96% rename from langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/azureblobstorage/source/HttpRequestAgent.java rename to langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/HttpRequestAgent.java index 8a40b130b..d41d30b15 100644 --- a/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/azureblobstorage/source/HttpRequestAgent.java +++ b/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/HttpRequestAgent.java @@ -13,7 +13,7 @@ * See the License for the 
specific language governing permissions and * limitations under the License. */ -package ai.langstream.agents.azureblobstorage.source; +package ai.langstream.agents.http; import ai.langstream.ai.agents.commons.JsonRecord; import ai.langstream.ai.agents.commons.MutableRecord; @@ -191,6 +191,12 @@ public void processRecord(Record record, RecordSink recordSink) { recordSink.emit( new SourceRecordAndResult(record, List.of(), e)); } + }) + .exceptionally( + error -> { + log.error("Error processing record: {}", record, error); + recordSink.emit(new SourceRecordAndResult(record, null, error)); + return null; }); } catch (Throwable error) { log.error("Error processing record: {}", record, error); diff --git a/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/azureblobstorage/source/HttpRequestAgentProvider.java b/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/HttpRequestAgentProvider.java similarity index 84% rename from langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/azureblobstorage/source/HttpRequestAgentProvider.java rename to langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/HttpRequestAgentProvider.java index b8a228543..045fdbec7 100644 --- a/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/azureblobstorage/source/HttpRequestAgentProvider.java +++ b/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/HttpRequestAgentProvider.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package ai.langstream.agents.azureblobstorage.source; +package ai.langstream.agents.http; import ai.langstream.api.runner.code.AgentCodeProvider; import ai.langstream.api.runner.code.AgentProcessor; @@ -25,7 +25,11 @@ public class HttpRequestAgentProvider implements AgentCodeProvider { private static final Map> FACTORIES = - Map.of("http-request", HttpRequestAgent::new); + Map.of( + "http-request", + HttpRequestAgent::new, + "langserve-invoke", + LangServeInvokeAgent::new); @Override public boolean supports(String agentType) { diff --git a/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/LangServeInvokeAgent.java b/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/LangServeInvokeAgent.java new file mode 100644 index 000000000..3f7996a04 --- /dev/null +++ b/langstream-agents/langstream-agent-http-request/src/main/java/ai/langstream/agents/http/LangServeInvokeAgent.java @@ -0,0 +1,477 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package ai.langstream.agents.http; + +import ai.langstream.ai.agents.commons.JsonRecord; +import ai.langstream.ai.agents.commons.MutableRecord; +import ai.langstream.ai.agents.commons.jstl.JstlEvaluator; +import ai.langstream.api.runner.code.AbstractAgentCode; +import ai.langstream.api.runner.code.AgentContext; +import ai.langstream.api.runner.code.AgentProcessor; +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.code.RecordSink; +import ai.langstream.api.runner.topics.TopicProducer; +import ai.langstream.api.runtime.ComponentType; +import ai.langstream.api.util.ConfigurationUtils; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.samskivert.mustache.Mustache; +import com.samskivert.mustache.Template; +import java.io.IOException; +import java.io.StringWriter; +import java.net.CookieManager; +import java.net.CookiePolicy; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.extern.slf4j.Slf4j; +import org.apache.avro.Schema; + +@Slf4j +public class LangServeInvokeAgent extends AbstractAgentCode implements AgentProcessor { + + record FieldDefinition(String name, JstlEvaluator expressionEvaluator) {} + + private final List fields = new ArrayList<>(); + private TopicProducer topicProducer; + private String streamToTopic; + private int minChunksPerMessage; + private String contentField; + private boolean debug; + static final 
ObjectMapper mapper = new ObjectMapper(); + private final Map avroValueSchemaCache = new ConcurrentHashMap<>(); + + private final Map avroKeySchemaCache = new ConcurrentHashMap<>(); + + private AgentContext agentContext; + private ExecutorService executor; + private HttpClient httpClient; + private String url; + private String method; + private Map headersTemplates; + private String outputFieldName; + private String streamResponseCompletionField; + + @SuppressWarnings("unchecked") + @Override + public void init(Map configuration) { + url = + ConfigurationUtils.requiredNonEmptyField( + configuration, "url", () -> "langserve-invoke agent"); + + outputFieldName = + ConfigurationUtils.requiredNonEmptyField( + configuration, "output-field", () -> "langserve-invoke agent"); + streamToTopic = ConfigurationUtils.getString("stream-to-topic", "", configuration); + streamResponseCompletionField = + ConfigurationUtils.getString( + "stream-response-field", outputFieldName, configuration); + + minChunksPerMessage = + ConfigurationUtils.getInteger("min-chunks-per-message", 20, configuration); + contentField = ConfigurationUtils.getString("content-field", "content", configuration); + + debug = ConfigurationUtils.getBoolean("debug", false, configuration); + method = ConfigurationUtils.getString("method", "POST", configuration); + List> fields = + (List>) configuration.getOrDefault("fields", List.of()); + fields.forEach( + r -> { + String name = ConfigurationUtils.getString("name", "", r); + String expression = ConfigurationUtils.getString("expression", "", r); + log.info("Sending field with name {} computed as {}", name, expression); + JstlEvaluator expressionEvaluator = + new JstlEvaluator<>("${" + expression + "}", Object.class); + this.fields.add(new FieldDefinition(name, expressionEvaluator)); + }); + final boolean allowRedirects = + ConfigurationUtils.getBoolean("allow-redirects", true, configuration); + final boolean handleCookies = + 
ConfigurationUtils.getBoolean("handle-cookies", true, configuration); + + final Map headers = + ConfigurationUtils.getMap("headers", new HashMap<>(), configuration); + headersTemplates = new HashMap<>(); + for (Map.Entry entry : headers.entrySet()) { + headersTemplates.put( + entry.getKey(), Mustache.compiler().compile(entry.getValue().toString())); + } + + executor = Executors.newCachedThreadPool(); + CookieManager cookieManager = new CookieManager(); + cookieManager.setCookiePolicy( + handleCookies ? CookiePolicy.ACCEPT_ALL : CookiePolicy.ACCEPT_NONE); + httpClient = + HttpClient.newBuilder() + .followRedirects( + allowRedirects + ? HttpClient.Redirect.NORMAL + : HttpClient.Redirect.NEVER) + .cookieHandler(cookieManager) + .executor(executor) + .build(); + } + + @Override + public void setContext(AgentContext context) throws Exception { + this.agentContext = context; + } + + @Override + public void process(List records, RecordSink recordSink) { + for (Record record : records) { + processRecord(record, recordSink); + } + } + + @Override + public ComponentType componentType() { + return ComponentType.PROCESSOR; + } + + public void processRecord(Record record, RecordSink recordSink) { + try { + MutableRecord context = MutableRecord.recordToMutableRecord(record, true); + final JsonRecord jsonRecord = context.toJsonRecord(); + + final URI uri = URI.create(url); + final String body = buildBody(context); + + final HttpRequest.BodyPublisher bodyPublisher = + HttpRequest.BodyPublishers.ofString(body); + + final HttpRequest.Builder requestBuilder = + HttpRequest.newBuilder() + .uri(uri) + .version(HttpClient.Version.HTTP_1_1) + .method(this.method, bodyPublisher); + requestBuilder.header("Content-Type", "application/json"); + headersTemplates.forEach( + (key, value) -> requestBuilder.header(key, value.execute(jsonRecord))); + final HttpRequest request = requestBuilder.build(); + if (debug) { + log.info("Sending request {}", request); + log.info("Body {}", body); + } + + if 
(url.endsWith("/invoke")) { + invoke(record, recordSink, request, context); + } else if (url.endsWith("/stream")) { + stream(record, recordSink, request, context); + } else { + recordSink.emitError( + record, new UnsupportedOperationException("Unsupported url: " + url)); + } + } catch (Throwable error) { + log.error("Error processing record: {}", record, error); + recordSink.emit(new SourceRecordAndResult(record, null, error)); + } + } + + private void invoke( + Record record, RecordSink recordSink, HttpRequest request, MutableRecord context) { + httpClient + .sendAsync(request, HttpResponse.BodyHandlers.ofString()) + .thenAccept( + response -> { + if (debug) { + log.info("Response {}", response); + log.info("Response body {}", response.body()); + } + try { + if (response.statusCode() >= 400) { + throw new RuntimeException( + "Error processing record: " + + record + + " with response: " + + response); + } + final Object responseBody = + parseResponseBody(response.body(), false); + applyResultFieldToContext(context, responseBody.toString(), false); + Optional recordResult = + MutableRecord.mutableRecordToRecord(context); + if (log.isDebugEnabled()) { + log.debug("recordResult {}", recordResult); + } + recordSink.emit( + new SourceRecordAndResult( + record, List.of(recordResult.orElseThrow()), null)); + } catch (Exception e) { + log.error("Error processing record: {}", record, e); + recordSink.emitError(record, e); + } + }) + .exceptionally( + error -> { + log.error("Error processing record: {}", record, error); + recordSink.emit(new SourceRecordAndResult(record, null, error)); + return null; + }); + } + + enum EventType { + data, + end, + emptyLine + } + + private void stream( + Record record, RecordSink recordSink, HttpRequest request, MutableRecord context) { + StreamResponseProcessor streamResponseProcessor = + new StreamResponseProcessor( + minChunksPerMessage, + new StreamingChunksConsumer() { + @Override + public void consumeChunk( + String answerId, int 
index, String chunk, boolean last) { + if (topicProducer == null) { + // no streaming output + return; + } + MutableRecord copy = context.copy(); + applyResultFieldToContext(copy, chunk, true); + copy.getProperties().put("stream-id", answerId); + copy.getProperties().put("stream-index", index + ""); + copy.getProperties().put("stream-last-message", last + ""); + Optional recordResult = + MutableRecord.mutableRecordToRecord(copy); + if (log.isDebugEnabled()) { + log.debug("recordResult {}", recordResult); + } + topicProducer + .write(recordResult.orElseThrow()) + .exceptionally( + e -> { + log.error("Error writing chunk to topic", e); + return null; + }); + } + }); + + httpClient.sendAsync( + request, HttpResponse.BodyHandlers.fromLineSubscriber(streamResponseProcessor)); + + streamResponseProcessor.whenComplete( + (r, e) -> { + if (e != null) { + log.error("Error processing record: {}", record, e); + recordSink.emitError(record, e); + } else { + MutableRecord copy = context.copy(); + applyResultFieldToContext( + copy, streamResponseProcessor.buildTotalAnswerMessage(), false); + Optional recordResult = MutableRecord.mutableRecordToRecord(copy); + recordSink.emitSingleResult(record, recordResult.orElseThrow()); + } + }); + } + + private static EventType parseEventType(String body) { + if (body == null) { + return EventType.end; + } + if (body.startsWith("event: end")) { + return EventType.end; + } else if (body.startsWith("event: data")) { + return EventType.data; + } else if (body.isEmpty()) { + return EventType.emptyLine; + } else { + return null; + } + } + + private String buildBody(MutableRecord context) throws IOException { + Map values = new HashMap<>(); + for (FieldDefinition field : fields) { + values.put(field.name, field.expressionEvaluator.evaluate(context)); + } + Map request = Map.of("input", values); + return mapper.writeValueAsString(request); + } + + private Object parseResponseBody(String body, boolean streaming) { + try { + Map map = + 
mapper.readValue(body, new TypeReference>() {}); + if (!streaming) { + map = (Map) map.get("output"); + } + if (contentField.isEmpty()) { + return map; + } else { + return map.get(contentField); + } + } catch (JsonProcessingException ex) { + log.debug("Not able to parse response to json: {}, {}", body, ex); + } + return body; + } + + @Override + public void start() throws Exception { + if (!streamToTopic.isEmpty()) { + log.info("Streaming answers to topic {}", streamToTopic); + topicProducer = + agentContext + .getTopicConnectionProvider() + .createProducer( + agentContext.getGlobalAgentId(), streamToTopic, Map.of()); + topicProducer.start(); + } + } + + @Override + public void close() throws Exception { + if (topicProducer != null) { + topicProducer.close(); + topicProducer = null; + } + if (executor != null) { + executor.shutdownNow(); + } + } + + private class StreamResponseProcessor extends CompletableFuture + implements Flow.Subscriber { + + Flow.Subscription subscription; + + private final StringWriter totalAnswer = new StringWriter(); + + private final StringWriter writer = new StringWriter(); + private final AtomicInteger numberOfChunks = new AtomicInteger(); + private final int minChunksPerMessage; + + private final AtomicInteger currentChunkSize = new AtomicInteger(1); + private final AtomicInteger index = new AtomicInteger(); + + private final StreamingChunksConsumer streamingChunksConsumer; + + private final String answerId = java.util.UUID.randomUUID().toString(); + + public StreamResponseProcessor( + int minChunksPerMessage, StreamingChunksConsumer streamingChunksConsumer) { + this.minChunksPerMessage = minChunksPerMessage; + this.streamingChunksConsumer = streamingChunksConsumer; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public synchronized void onNext(String body) { + EventType eventType = parseEventType(body); + if (eventType == 
null || eventType == EventType.end) { + boolean last = false; + String content; + if (body.startsWith("data: ")) { + body = body.substring("data: ".length()); + final Object responseBody = parseResponseBody(body, true); + if (responseBody == null) { + content = ""; + } else { + content = responseBody.toString(); + } + } else if (eventType == EventType.end) { + content = ""; + last = true; + } else { + content = ""; + } + + if (!content.isEmpty()) { + writer.write(content); + totalAnswer.write(content); + numberOfChunks.incrementAndGet(); + } + + // start from 1 chunk, then double the size until we reach the minChunksPerMessage + // this gives better latencies for the first message + int currentMinChunksPerMessage = currentChunkSize.get(); + + if (numberOfChunks.get() >= currentMinChunksPerMessage || last) { + currentChunkSize.set( + Math.min(currentMinChunksPerMessage * 2, minChunksPerMessage)); + String chunk = writer.toString(); + streamingChunksConsumer.consumeChunk( + answerId, index.incrementAndGet(), chunk, last); + writer.getBuffer().setLength(0); + numberOfChunks.set(0); + } + if (last) { + this.complete(null); + } + } + subscription.request(1); + } + + @Override + public void onError(Throwable error) { + log.error("IO Error while calling LangServe", error); + this.completeExceptionally(error); + } + + @Override + public void onComplete() { + if (!this.isDone()) { + this.complete(null); + } + } + + public String buildTotalAnswerMessage() { + return totalAnswer.toString(); + } + } + + interface StreamingChunksConsumer { + void consumeChunk(String answerId, int index, String chunk, boolean last); + } + + private void applyResultFieldToContext( + MutableRecord mutableRecord, String content, boolean streamingAnswer) { + String fieldName = outputFieldName; + + // maybe we want a different field in the streaming answer + // typically you want to directly stream the answer as the whole "value" + if (streamingAnswer) { + fieldName = streamResponseCompletionField; + 
} + mutableRecord.setResultField( + content, + fieldName, + Schema.create(Schema.Type.STRING), + avroKeySchemaCache, + avroValueSchemaCache); + } +} diff --git a/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/ai.langstream.agents.index b/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/ai.langstream.agents.index index 1ddcb4e59..a7778ca5b 100644 --- a/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/ai.langstream.agents.index +++ b/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/ai.langstream.agents.index @@ -1 +1,2 @@ -query-http \ No newline at end of file +http-request +langserve-invoke \ No newline at end of file diff --git a/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider b/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider index afa4d7185..1a2acf0aa 100644 --- a/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider +++ b/langstream-agents/langstream-agent-http-request/src/main/resources/META-INF/services/ai.langstream.api.runner.code.AgentCodeProvider @@ -1 +1 @@ -ai.langstream.agents.azureblobstorage.source.HttpRequestAgentProvider \ No newline at end of file +ai.langstream.agents.http.HttpRequestAgentProvider \ No newline at end of file diff --git a/langstream-agents/langstream-agent-http-request/src/test/java/ai/langstream/agents/http/LangServeInvokeAgentTest.java b/langstream-agents/langstream-agent-http-request/src/test/java/ai/langstream/agents/http/LangServeInvokeAgentTest.java new file mode 100644 index 000000000..9dfd38c45 --- /dev/null +++ b/langstream-agents/langstream-agent-http-request/src/test/java/ai/langstream/agents/http/LangServeInvokeAgentTest.java @@ -0,0 +1,295 @@ +/* + * Copyright 
DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.http; + +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.ok; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static org.junit.jupiter.api.Assertions.*; + +import ai.langstream.api.runner.code.AgentContext; +import ai.langstream.api.runner.code.AgentProcessor; +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.code.RecordSink; +import ai.langstream.api.runner.code.SimpleRecord; +import ai.langstream.api.runner.topics.TopicAdmin; +import ai.langstream.api.runner.topics.TopicConnectionProvider; +import ai.langstream.api.runner.topics.TopicConsumer; +import ai.langstream.api.runner.topics.TopicProducer; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; + 
+@WireMockTest +@Slf4j +class LangServeInvokeAgentTest { + + @Test + public void testInvoke(WireMockRuntimeInfo wireMockRuntimeInfo) throws Exception { + stubFor( + post("/chain/invoke") + .withRequestBody( + equalTo(""" + {"input":{"topic":"cats"}}""")) + .willReturn( + okJson( + """ + {"output":{"content":"Why don't cats play poker in the wild? Too many cheetahs!","additional_kwargs":{},"type":"ai","example":false},"callback_events":[]} + """))); + Map configuration = + Map.of( + "fields", + List.of(Map.of("name", "topic", "expression", "value.foo")), + "url", + wireMockRuntimeInfo.getHttpBaseUrl() + "/chain/invoke", + "output-field", + "value", + "debug", + true); + try (LangServeInvokeAgent agent = new LangServeInvokeAgent(); ) { + agent.init(configuration); + agent.start(); + List records = new CopyOnWriteArrayList<>(); + RecordSink sink = (records::add); + + SimpleRecord input = + SimpleRecord.of( + null, + """ + { + "foo": "cats" + } + """); + agent.processRecord(input, sink); + + Awaitility.await() + .untilAsserted( + () -> { + assertEquals(1, records.size()); + }); + + Record record = records.get(0).resultRecords().get(0); + assertEquals( + """ + Why don't cats play poker in the wild? 
Too many cheetahs!""", + record.value()); + } + } + + @Test + public void testStreamingOutput(WireMockRuntimeInfo wireMockRuntimeInfo) throws Exception { + String response = + """ + event: data + data: {"content": "", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "Why", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " don", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "'t", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " cats", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " play", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " poker", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " in", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " the", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " wild", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "?\\n\\n", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "Too", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " many", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " che", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "et", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "ah", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + 
data: {"content": "s", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "!", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: end"""; + + stubFor( + post("/chain/stream") + .withRequestBody( + equalTo(""" + {"input":{"topic":"cats"}}""")) + .willReturn(ok(response))); + + Map configuration = + Map.of( + "fields", + List.of(Map.of("name", "topic", "expression", "value.foo")), + "url", + wireMockRuntimeInfo.getHttpBaseUrl() + "/chain/stream", + "output-field", + "value", + "debug", + true, + "stream-to-topic", + "some-topic"); + + List streamingAnswers = new ArrayList<>(); + + try (LangServeInvokeAgent agent = new LangServeInvokeAgent(); ) { + agent.init(configuration); + + setupMockTopicProducer(streamingAnswers, agent); + agent.start(); + List records = new CopyOnWriteArrayList<>(); + RecordSink sink = (records::add); + + SimpleRecord input = + SimpleRecord.of( + null, + """ + { + "foo": "cats" + } + """); + agent.processRecord(input, sink); + + Awaitility.await() + .atMost(1, TimeUnit.DAYS) + .untilAsserted( + () -> { + assertEquals(1, records.size()); + }); + + streamingAnswers.forEach( + record -> { + log.info("Answer {}", record); + }); + + Record record = records.get(0).resultRecords().get(0); + log.info("Main answer: {}", record); + assertEquals( + """ + Why don't cats play poker in the wild? 
+ + Too many cheetahs!""", + record.value()); + + assertEquals("Why", streamingAnswers.get(0).value()); + assertEquals(" don't", streamingAnswers.get(1).value()); + assertEquals(" cats play poker in", streamingAnswers.get(2).value()); + assertEquals(" the wild?\n\nToo many cheetah", streamingAnswers.get(3).value()); + assertEquals("s!", streamingAnswers.get(4).value()); + assertEquals(5, streamingAnswers.size()); + } + } + + private static void setupMockTopicProducer( + List streamingAnswers, LangServeInvokeAgent agent) throws Exception { + TopicProducer topicProducer = + new TopicProducer() { + @Override + public CompletableFuture write(Record record) { + streamingAnswers.add(record); + return CompletableFuture.completedFuture(null); + } + + @Override + public long getTotalIn() { + return 0; + } + }; + + TopicConnectionProvider topicConnectionProvider = + new TopicConnectionProvider() { + @Override + public TopicProducer createProducer( + String agentId, String topic, Map config) { + assertEquals("some-topic", topic); + return topicProducer; + } + }; + + agent.setContext( + new AgentContext() { + @Override + public TopicConsumer getTopicConsumer() { + return null; + } + + @Override + public TopicProducer getTopicProducer() { + return null; + } + + @Override + public String getGlobalAgentId() { + return null; + } + + @Override + public TopicAdmin getTopicAdmin() { + return null; + } + + @Override + public TopicConnectionProvider getTopicConnectionProvider() { + return topicConnectionProvider; + } + + @Override + public Path getCodeDirectory() { + return null; + } + }); + } +} diff --git a/langstream-agents/langstream-agent-http-request/src/test/resources/logback-test.xml b/langstream-agents/langstream-agent-http-request/src/test/resources/logback-test.xml new file mode 100644 index 000000000..fdfe741f8 --- /dev/null +++ b/langstream-agents/langstream-agent-http-request/src/test/resources/logback-test.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + %d{HH:mm:ss.SSS} 
[%thread] %-5level %logger{36} -%kvp- %msg%n + + + + + + + diff --git a/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/GenAIToolKitAgent.java b/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/GenAIToolKitAgent.java index c74048e64..bca4210bb 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/GenAIToolKitAgent.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/GenAIToolKitAgent.java @@ -218,7 +218,13 @@ public void streamAnswerChunk( last, record); } - topicProducer.write(record.get()).join(); + topicProducer + .write(record.get()) + .exceptionally( + e -> { + log.error("Error writing chunk to topic", e); + return null; + }); } } diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/HttpRequestAgentProvider.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/HttpRequestAgentProvider.java index c44454eef..faff9cf2f 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/HttpRequestAgentProvider.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/HttpRequestAgentProvider.java @@ -17,9 +17,16 @@ import ai.langstream.api.doc.AgentConfig; import ai.langstream.api.doc.ConfigProperty; +import ai.langstream.api.doc.ExtendedValidationType; import ai.langstream.api.model.AgentConfiguration; +import ai.langstream.api.model.Module; +import ai.langstream.api.model.Pipeline; import ai.langstream.api.runtime.ComponentType; +import ai.langstream.api.runtime.ComputeClusterRuntime; +import ai.langstream.api.runtime.ExecutionPlan; +import ai.langstream.api.runtime.PluginsRegistry; import ai.langstream.impl.agents.AbstractComposableAgentProvider; +import ai.langstream.impl.uti.ClassConfigValidator; import 
ai.langstream.runtime.impl.k8s.KubernetesClusterRuntime; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; @@ -31,7 +38,8 @@ @Slf4j public class HttpRequestAgentProvider extends AbstractComposableAgentProvider { - private static final Set SUPPORTED_AGENT_TYPES = Set.of("http-request"); + private static final Set SUPPORTED_AGENT_TYPES = + Set.of("http-request", "langserve-invoke"); public HttpRequestAgentProvider() { super(SUPPORTED_AGENT_TYPES, List.of(KubernetesClusterRuntime.CLUSTER_TYPE, "none")); @@ -44,7 +52,47 @@ protected ComponentType getComponentType(AgentConfiguration agentConfiguration) @Override protected Class getAgentConfigModelClass(String type) { - return Config.class; + return switch (type) { + case "http-request" -> HttpRequestConfig.class; + case "langserve-invoke" -> LangServeInvokeConfig.class; + default -> throw new IllegalArgumentException("Unknown agent type: " + type); + }; + } + + @Override + protected Map computeAgentConfiguration( + AgentConfiguration agentConfiguration, + Module module, + Pipeline pipeline, + ExecutionPlan executionPlan, + ComputeClusterRuntime clusterRuntime, + PluginsRegistry pluginsRegistry) { + if (agentConfiguration.getType().equals("langserve-invoke")) { + LangServeInvokeConfig config = + ClassConfigValidator.convertValidatedConfiguration( + agentConfiguration.getConfiguration(), LangServeInvokeConfig.class); + if (config.getStreamToTopic() != null && !config.getStreamToTopic().isEmpty()) { + log.info("Validating topic reference {}", config.getStreamToTopic()); + module.resolveTopic(config.getStreamToTopic()); + } + String url = config.getUrl(); + if (url == null || url.isEmpty()) { + throw new IllegalArgumentException("Invalid empty url for langserve-invoke"); + } + if (!url.endsWith("/stream") && !url.endsWith("/invoke")) { + throw new IllegalArgumentException( + "Invalid url " + + url + + " for langserve-invoke, must end with /stream or /invoke"); + } + } + return 
super.computeAgentConfiguration( + agentConfiguration, + module, + pipeline, + executionPlan, + clusterRuntime, + pluginsRegistry); } @AgentConfig( @@ -54,7 +102,7 @@ protected Class getAgentConfigModelClass(String type) { Agent for enriching data with an HTTP request. """) @Data - public static class Config { + public static class HttpRequestConfig { @ConfigProperty( description = """ @@ -121,4 +169,136 @@ public static class Config { @JsonProperty("handle-cookies") private boolean handleCookies; } + + @AgentConfig( + name = "Invoke LangServe", + description = + """ + Agent for invoking LangServe based applications + """) + @Data + public static class LangServeInvokeConfig { + @ConfigProperty( + description = + """ + Url to send the request to. For adding query string parameters, use the `query-string` field. + """, + required = true) + private String url; + + @ConfigProperty( + description = + """ + The field that will hold the results, it can be the same as "field" to override it. + """, + required = true, + defaultValue = "value") + @JsonProperty("output-field") + private String outputFieldName = "value"; + + @ConfigProperty( + description = + """ + Field in the response that will be used as the content of the record. + """, + required = false, + defaultValue = "content") + @JsonProperty("content-field") + private String contentFieldName = "content"; + + @ConfigProperty( + description = + """ + Enable streaming of the results. If enabled, the results are streamed to the specified topic in small chunks. The entire messages will be sent to the output topic instead. + """) + @JsonProperty(value = "stream-to-topic") + private String streamToTopic; + + @ConfigProperty( + description = + """ + Field to use to store the completion results in the stream-to-topic topic. Use "value" to write the result without a structured schema. Use "value." to write the result in a specific field. 
+ """) + @JsonProperty(value = "stream-response-field") + private String streamResponseCompletionField; + + @ConfigProperty( + description = + """ + Minimum number of chunks to send to the stream-to-topic topic. The chunks are sent as soon as they are available. + The chunks are sent in the order they are received from the AI Service. + To improve the TTFB (Time-To-First-Byte), the chunk size starts from 1 and doubles until it reaches the max-chunks-per-message value. + """, + defaultValue = "20") + @JsonProperty(value = "min-chunks-per-message") + private int minChunksPerMessage = 20; + + @ConfigProperty( + description = + """ + Field in the response that will be used as the content of the record. + """) + @JsonProperty("debug") + private boolean debug; + + @ConfigProperty( + description = + """ + Http method to use for the request. + """, + defaultValue = "POST") + private String method = "POST"; + + @ConfigProperty( + description = + """ + Headers to send with the request. You can use the Mustache syntax to inject value from the context. + """) + private Map headers; + + @ConfigProperty( + description = + """ + Whether or not to follow redirects. + """, + defaultValue = "true") + @JsonProperty("allow-redirects") + private boolean allowRedirects; + + @ConfigProperty( + description = + """ + Whether or not to handle cookies during the redirects. + """, + defaultValue = "true") + @JsonProperty("handle-cookies") + private boolean handleCookies; + + @ConfigProperty( + description = + """ + Fields of the generated records. + """) + List fields; + } + + @Data + public static class FieldConfiguration { + @ConfigProperty( + description = + """ + Name of the field like value.xx, key.xxx, properties.xxx + """, + required = true) + String name; + + @ConfigProperty( + description = + """ + Expression to compute the value of the field. This is a standard EL expression. 
+ """, + required = true, + extendedValidationType = ExtendedValidationType.EL_EXPRESSION) + String expression; + } } diff --git a/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaProducerWrapper.java b/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaProducerWrapper.java index 511e6e060..3d522f18d 100644 --- a/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaProducerWrapper.java +++ b/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaProducerWrapper.java @@ -103,7 +103,7 @@ public KafkaProducerWrapper(Map copy, String topicName) { org.apache.kafka.common.serialization.ByteArraySerializer.class.getName(), copy.get(VALUE_SERIALIZER_CLASS_CONFIG)); if (!forcedKeySerializer) { - log.info( + log.debug( "The Producer to {} is configured without a key serializer, we will use reflection to find the right one", topicName); } else { @@ -113,7 +113,7 @@ public KafkaProducerWrapper(Map copy, String topicName) { copy.get(KEY_SERIALIZER_CLASS_CONFIG)); } if (!forcedValueSerializer) { - log.info( + log.debug( "The Producer to {} is configured without a value serializer, we will use reflection to find the right one", topicName); } else { diff --git a/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaReaderWrapper.java b/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaReaderWrapper.java index 8d53f63f0..1738845ad 100644 --- a/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaReaderWrapper.java +++ b/langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaReaderWrapper.java @@ -119,8 +119,8 @@ public TopicReadResult read() throws JsonProcessingException { records.add(KafkaRecord.fromKafkaConsumerRecord(record)); } final Set assignment = consumer.assignment(); - if (!records.isEmpty()) { - log.info("Received {} records from Kafka topics {}", records.size(), assignment); + if (!records.isEmpty() && log.isDebugEnabled()) 
{ + log.debug("Received {} records from Kafka topics {}", records.size(), assignment); } Map offsets = consumer.endOffsets(assignment); diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/LangServeInvokeAgentRunnerIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/LangServeInvokeAgentRunnerIT.java new file mode 100644 index 000000000..3feb4ae77 --- /dev/null +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/LangServeInvokeAgentRunnerIT.java @@ -0,0 +1,206 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package ai.langstream.kafka; + +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.ok; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +@Slf4j +@WireMockTest +class LangServeInvokeAgentRunnerIT extends AbstractKafkaApplicationRunner { + + static WireMockRuntimeInfo wireMockRuntimeInfo; + + @BeforeAll + static void onBeforeAll(WireMockRuntimeInfo info) { + wireMockRuntimeInfo = info; + } + + @Test + void testStreamOuput() throws Exception { + + String response = + """ + event: data + data: {"content": "", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "Why", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " don", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "'t", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " cats", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " play", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " poker", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " in", "additional_kwargs": {}, "type": "AIMessageChunk", 
"example": false} + + event: data + data: {"content": " the", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " wild", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "?\\n\\n", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "Too", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " many", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": " che", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "et", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "ah", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "s", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "!", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: data + data: {"content": "", "additional_kwargs": {}, "type": "AIMessageChunk", "example": false} + + event: end"""; + + stubFor( + post("/chain/stream") + .withRequestBody( + equalTo(""" + {"input":{"topic":"cats"}}""")) + .willReturn(ok(response))); + + Map application = + Map.of( + "module.yaml", + """ + topics: + - name: "input-topic" + creation-mode: create-if-not-exists + deletion-mode: delete + - name: "output-topic" + creation-mode: create-if-not-exists + deletion-mode: delete + - name: "streaming-answers-topic" + creation-mode: create-if-not-exists + deletion-mode: delete + pipeline: + - type: "langserve-invoke" + input: input-topic + output: output-topic + id: step1 + configuration: + output-field: value.answer + stream-to-topic: streaming-answers-topic + stream-response-field: value + min-chunks-per-message: 10 + debug: 
false + method: POST + allow-redirects: true + handle-cookies: false + url: %s/chain/stream + headers: + Authorisation: "Bearer {{secrets.langserve.token}}" + fields: + - name: topic + expression: "value.topic" + """ + .formatted(wireMockRuntimeInfo.getHttpBaseUrl())); + + String tenant = "tenant"; + String[] expectedAgents = {"app-step1"}; + + // write some data + try (ApplicationRuntime applicationRuntime = + deployApplicationWithSecrets( + tenant, + "app", + application, + buildInstanceYaml(), + """ + secrets: + - id: langserve + data: + token: "my-token" + """, + expectedAgents)) { + try (KafkaProducer producer = createProducer(); + KafkaConsumer consumer = createConsumer("output-topic"); + KafkaConsumer consumerStreaming = + createConsumer("streaming-answers-topic")) { + sendMessage("input-topic", "{\"topic\":\"cats\"}", producer); + executeAgentRunners(applicationRuntime); + + waitForMessages( + consumer, + List.of( + "{\"answer\":\"Why don't cats play poker in the wild?\\n\\nToo many cheetahs!\",\"topic\":\"cats\"}")); + + List streamingAnswers = + waitForMessages( + consumerStreaming, + List.of( + "Why", + " don't", + " cats play poker in", + " the wild?\n\nToo many cheetah", + "s!")); + streamingAnswers.forEach( + a -> { + log.info("Record: {}={}", a.key(), a.value()); + a.headers() + .forEach( + h -> { + log.info( + "header: {}={}", + h.key(), + new String(h.value())); + }); + }); + } + } + } +} diff --git a/langstream-runtime/langstream-runtime-tester/src/main/assemble/logback.xml b/langstream-runtime/langstream-runtime-tester/src/main/assemble/logback.xml index 27af08f73..1aa55b6a5 100644 --- a/langstream-runtime/langstream-runtime-tester/src/main/assemble/logback.xml +++ b/langstream-runtime/langstream-runtime-tester/src/main/assemble/logback.xml @@ -34,6 +34,7 @@ + diff --git a/pom.xml b/pom.xml index ac4d81a8f..465a570fe 100644 --- a/pom.xml +++ b/pom.xml @@ -66,7 +66,7 @@ 8.5.4 1.70 1.4.1 - 3.0.0-beta-10 + 3.0.1 0.25.5 0.11.5 1.16.0