From 3174e77b8ac1c07f1384666325fd84e830dbb3a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Thu, 26 Oct 2023 14:56:31 +0200 Subject: [PATCH 1/4] Gateway: accept HTTP request for producing messages --- .../google/GoogleAuthenticationProvider.java | 6 +- .../jwt/admin/HttpAuthenticationProvider.java | 3 +- .../jwt/admin/JwtAuthenticationProvider.java | 3 +- langstream-api-gateway/pom.xml | 10 + .../apigateway/gateways/ConsumeGateway.java | 215 ++++++++ .../gateways/GatewayRequestHandler.java | 280 ++++++++++ .../GatewayRequestHandlerFactory.java | 21 + .../apigateway/gateways/ProduceGateway.java | 202 +++++++ .../apigateway/http/GatewayResource.java | 91 ++++ .../http/ResourceErrorsHandler.java | 45 ++ .../langstream/apigateway/util/HttpUtil.java | 30 ++ .../websocket/AuthenticationInterceptor.java | 158 +----- .../apigateway/websocket/WebSocketConfig.java | 5 +- .../websocket/handlers/AbstractHandler.java | 494 +++--------------- .../websocket/handlers/ChatHandler.java | 89 ++-- .../websocket/handlers/ConsumeHandler.java | 57 +- .../websocket/handlers/ProduceHandler.java | 50 +- ...uthenticatedGatewayRequestContextImpl.java | 1 + .../apigateway/http/GatewayResourceTest.java | 413 +++++++++++++++ .../http/KafkaGatewayResourceTest.java | 98 ++++ .../http/PulsarGatewayResourceTest.java | 71 +++ .../TestGatewayAuthenticationProvider.java | 2 +- .../handlers/ProduceConsumeHandlerTest.java | 85 ++- .../cli/commands/gateway/BaseGatewayCmd.java | 26 +- .../cli/commands/gateway/ChatGatewayCmd.java | 9 +- .../commands/gateway/ConsumeGatewayCmd.java | 3 +- .../commands/gateway/ProduceGatewayCmd.java | 42 +- 27 files changed, 1774 insertions(+), 735 deletions(-) create mode 100644 langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java create mode 100644 langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java create mode 100644 langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java create mode 100644 langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java create mode 100644 langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java create mode 100644 langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/ResourceErrorsHandler.java create mode 100644 langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java create mode 100644 langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java create mode 100644 langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java create mode 100644 langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java diff --git a/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java index f2e8bce5e..76ddc5404 100644 --- a/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java +++ b/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java @@ -59,7 +59,11 @@ public void initialize(Map configuration) { @Override public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { try { - GoogleIdToken idToken = verifier.verify(context.credentials()); + final String credentials = context.credentials(); + if (credentials == null) { + return GatewayAuthenticationResult.authenticationFailed("Token not found."); + } + GoogleIdToken idToken = verifier.verify(credentials); if (idToken != null) { final GoogleIdToken.Payload payload = idToken.getPayload(); Map result = new HashMap<>(); diff --git a/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProvider.java index cb484c89c..2e1f9dc1c 100644 --- a/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProvider.java +++ b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProvider.java @@ -64,7 +64,8 @@ public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { final HttpRequest.Builder builder = HttpRequest.newBuilder().uri(URI.create(url)); httpConfiguration.getHeaders().forEach(builder::header); - builder.header("Authorization", "Bearer " + context.credentials()); + final String credentials = context.credentials(); + builder.header("Authorization", "Bearer " + (credentials == null ? "" : credentials)); final HttpRequest request = builder.GET().build(); final HttpResponse response; diff --git a/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProvider.java index 9577bd1b9..e669d9e1d 100644 --- a/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProvider.java +++ b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProvider.java @@ -66,7 +66,8 @@ public void initialize(Map configuration) { public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { String role; try { - role = authenticationProviderToken.authenticate(context.credentials()); + final String credentials = context.credentials(); + role = authenticationProviderToken.authenticate(credentials == null ? "" : credentials); } catch (AuthenticationProviderToken.AuthenticationException ex) { return GatewayAuthenticationResult.authenticationFailed(ex.getMessage()); } diff --git a/langstream-api-gateway/pom.xml b/langstream-api-gateway/pom.xml index b465cb315..d13c3c8e9 100644 --- a/langstream-api-gateway/pom.xml +++ b/langstream-api-gateway/pom.xml @@ -70,6 +70,16 @@ org.springframework.boot spring-boot-starter-websocket + + org.springdoc + springdoc-openapi-starter-webmvc-ui + ${springdoc-openapi-starter-webmvc.version} + + + org.springdoc + springdoc-openapi-starter-webmvc-api + ${springdoc-openapi-starter-webmvc.version} + ch.qos.logback logback-core diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java new file mode 100644 index 000000000..b7686988b --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -0,0 +1,215 @@ +package ai.langstream.apigateway.gateways; + +import ai.langstream.api.model.Gateway; +import ai.langstream.api.model.StreamingCluster; +import ai.langstream.api.runner.code.Header; +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.topics.TopicConnectionsRuntime; +import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; +import ai.langstream.api.runner.topics.TopicOffsetPosition; +import ai.langstream.api.runner.topics.TopicReadResult; +import ai.langstream.api.runner.topics.TopicReader; +import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import ai.langstream.apigateway.websocket.api.ConsumePushMessage; +import ai.langstream.apigateway.websocket.api.ProduceResponse; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.Closeable; +import java.util.Base64; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class ConsumeGateway implements Closeable { + + protected static final ObjectMapper mapper = new ObjectMapper(); + + @Getter + public static class ProduceException extends Exception { + + private final ProduceResponse.Status status; + + public ProduceException(String message, ProduceResponse.Status status) { + super(message); + this.status = status; + } + } + + public static class ProduceGatewayRequestValidator implements GatewayRequestHandler.GatewayRequestValidator { + @Override + public List getAllRequiredParameters(Gateway gateway) { + return gateway.getParameters(); + } + + @Override + public void validateOptions(Map options) { + for (Map.Entry option : options.entrySet()) { + switch (option.getKey()) { + default -> throw new IllegalArgumentException("Unknown option " + option.getKey()); + } + } + } + } + + + private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; + private TopicReader reader; + private AuthenticatedGatewayRequestContext requestContext; + private List> filters; + + public ConsumeGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) { + this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry; + } + + @SneakyThrows + public void setup(String topic, List> filters, AuthenticatedGatewayRequestContext requestContext) { + this.requestContext = requestContext; + this.filters = filters == null ? List.of() : filters; + + final StreamingCluster streamingCluster = requestContext.application().getInstance().streamingCluster(); + final TopicConnectionsRuntime topicConnectionsRuntime = + topicConnectionsRuntimeRegistry + .getTopicConnectionsRuntime(streamingCluster) + .asTopicConnectionsRuntime(); + + topicConnectionsRuntime.init(streamingCluster); + + final String positionParameter = requestContext.options().getOrDefault("position", "latest"); + TopicOffsetPosition position = + switch (positionParameter) { + case "latest" -> TopicOffsetPosition.LATEST; + case "earliest" -> TopicOffsetPosition.EARLIEST; + default -> TopicOffsetPosition.absolute( + Base64.getDecoder().decode(positionParameter)); + }; + reader = + topicConnectionsRuntime.createReader( + streamingCluster, Map.of("topic", topic), position); + reader.start(); + } + + + public CompletableFuture startReading(Executor executor, Supplier stop, Consumer onMessage) { + if (requestContext == null || reader == null) { + throw new IllegalStateException("Not initialized"); + } + return CompletableFuture.runAsync( + () -> { + + try { + final String tenant = requestContext.tenant(); + + final String gatewayId = requestContext.gateway().getId(); + final String applicationId = requestContext.applicationId(); + log.info( + "Started reader for gateway {}/{}/{}", + tenant, + applicationId, + gatewayId); + readMessages(stop, onMessage); + } catch (InterruptedException | CancellationException ex) { + // ignore + } catch (Throwable ex) { + log.error(ex.getMessage(), ex); + } finally { + closeReader(); + } + }, + executor); + } + + + protected void readMessages(Supplier stop, Consumer onMessage) + throws Exception { + while (true) { + if (Thread.currentThread().isInterrupted()) { + return; + } + if (stop.get()) { + return; + } + final TopicReadResult readResult = reader.read(); + final List records = readResult.records(); + for (Record record : records) { + log.debug("Received record {}", record); + boolean skip = false; + if (filters != null) { + for (Function filter : filters) { + if (!filter.apply(record)) { + skip = true; + log.debug("Skipping record {}", requestContext, record); + break; + } + } + } + if (!skip) { + final Map messageHeaders = computeMessageHeaders(record); + final String offset = computeOffset(readResult); + + final ConsumePushMessage message = + new ConsumePushMessage( + new ConsumePushMessage.Record( + record.key(), record.value(), messageHeaders), + offset); + final String jsonMessage = mapper.writeValueAsString(message); + onMessage.accept(jsonMessage); + + } + } + } + } + + + + + private static Map computeMessageHeaders(Record record) { + final Collection
headers = record.headers(); + final Map messageHeaders; + if (headers == null) { + messageHeaders = Map.of(); + } else { + messageHeaders = new HashMap<>(); + headers.forEach(h -> messageHeaders.put(h.key(), h.valueAsString())); + } + return messageHeaders; + } + + private static String computeOffset(TopicReadResult readResult) { + final byte[] offset = readResult.offset(); + if (offset == null) { + return null; + } + return Base64.getEncoder().encodeToString(offset); + } + + + + + private void closeReader() { + if (reader != null) { + try { + reader.close(); + } catch (Exception e) { + log.error("error closing reader", e); + } + } + } + + + + @Override + public void close() { + closeReader(); + } + +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java new file mode 100644 index 000000000..53cf1cdd2 --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java @@ -0,0 +1,280 @@ +package ai.langstream.apigateway.gateways; + +import ai.langstream.api.gateway.GatewayAuthenticationProvider; +import ai.langstream.api.gateway.GatewayAuthenticationProviderRegistry; +import ai.langstream.api.gateway.GatewayAuthenticationResult; +import ai.langstream.api.gateway.GatewayRequestContext; +import ai.langstream.api.model.Application; +import ai.langstream.api.model.ApplicationSpecs; +import ai.langstream.api.model.Gateway; +import ai.langstream.api.model.Gateways; +import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; +import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import ai.langstream.apigateway.websocket.AuthenticationInterceptor; +import ai.langstream.apigateway.websocket.impl.AuthenticatedGatewayRequestContextImpl; +import ai.langstream.apigateway.websocket.impl.GatewayRequestContextImpl; +import ai.langstream.impl.common.ApplicationPlaceholderResolver; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.codec.digest.DigestUtils; +import org.springframework.util.StringUtils; + +@Slf4j +public class GatewayRequestHandler { + + public static class AuthFailedException extends Exception { + public AuthFailedException(String message) { + super(message); + } + } + + public interface GatewayRequestValidator { + List getAllRequiredParameters(Gateway gateway); + void validateOptions(Map options); + } + + + private final ApplicationStore applicationStore; + private final GatewayAuthenticationProvider authTestProvider; + + + + public GatewayRequestHandler( + ApplicationStore applicationStore, + GatewayTestAuthenticationProperties testAuthenticationProperties) { + this.applicationStore = applicationStore; + if (testAuthenticationProperties.getType() != null) { + authTestProvider = + GatewayAuthenticationProviderRegistry.loadProvider( + testAuthenticationProperties.getType(), + testAuthenticationProperties.getConfiguration()); + log.info( + "Loaded test authentication provider {}", + authTestProvider.getClass().getName()); + } else { + authTestProvider = null; + log.info("No test authentication provider configured"); + } + } + + + public GatewayRequestContext validateRequest( + String tenant, + String applicationId, + String gatewayId, + Gateway.GatewayType expectedGatewayType, + Map queryString, + Map httpHeaders, + GatewayRequestValidator validator) { + Map options = new HashMap<>(); + Map userParameters = new HashMap<>(); + + final String credentials = queryString.remove("credentials"); + final String testCredentials = queryString.remove("test-credentials"); + + for (Map.Entry entry : queryString.entrySet()) { + if (entry.getKey().startsWith("option:")) { + options.put(entry.getKey().substring("option:".length()), entry.getValue()); + } else if (entry.getKey().startsWith("param:")) { + userParameters.put(entry.getKey().substring("param:".length()), entry.getValue()); + } else { + throw new IllegalArgumentException( + "invalid query parameter " + + entry.getKey() + + ". " + + "To specify a gateway parameter, use the format param:." + + "To specify a option, use the format option:."); + } + } + + final Application application = getResolvedApplication(tenant, applicationId); + final Gateway gateway = extractGateway(gatewayId, application, expectedGatewayType); + + final List requiredParameters = validator.getAllRequiredParameters(gateway); + Set allUserParameterKeys = new HashSet<>(userParameters.keySet()); + if (requiredParameters != null) { + for (String requiredParameter : requiredParameters) { + final String value = userParameters.get(requiredParameter); + if (!StringUtils.hasText(value)) { + throw new IllegalArgumentException( + formatErrorMessage( + tenant, + applicationId, + gateway, + "missing required parameter " + + requiredParameter + + ". Required parameters: " + + requiredParameters)); + } + allUserParameterKeys.remove(requiredParameter); + } + } + if (!allUserParameterKeys.isEmpty()) { + throw new IllegalArgumentException( + formatErrorMessage( + tenant, + applicationId, + gateway, + "unknown parameters: " + allUserParameterKeys)); + } + validator.validateOptions(options); + + if (credentials != null && testCredentials != null) { + throw new IllegalArgumentException( + formatErrorMessage( + tenant, + applicationId, + gateway, + "credentials and test-credentials cannot be used together")); + } + return GatewayRequestContextImpl.builder() + .tenant(tenant) + .applicationId(applicationId) + .application(application) + .credentials(credentials) + .testCredentials(testCredentials) + .httpHeaders(httpHeaders) + .options(options) + .userParameters(userParameters) + .gateway(gateway) + .build(); + } + + private static String formatErrorMessage( + String tenant, String applicationId, Gateway gateway, String error) { + return "Error for gateway %s (tenant: %s, appId: %s): %s" + .formatted(gateway.getId(), tenant, applicationId, error); + } + + private Application getResolvedApplication(String tenant, String applicationId) { + final ApplicationSpecs applicationSpecs = applicationStore.getSpecs(tenant, applicationId); + if (applicationSpecs == null) { + throw new IllegalArgumentException("application " + applicationId + " not found"); + } + final Application application = applicationSpecs.getApplication(); + application.setSecrets(applicationStore.getSecrets(tenant, applicationId)); + return ApplicationPlaceholderResolver.resolvePlaceholders(application); + } + + + + private Gateway extractGateway( + String gatewayId, Application application, Gateway.GatewayType type) { + final Gateways gatewaysObj = application.getGateways(); + if (gatewaysObj == null) { + throw new IllegalArgumentException("no gateways defined for the application"); + } + final List gateways = gatewaysObj.gateways(); + if (gateways == null) { + throw new IllegalArgumentException("no gateways defined for the application"); + } + + Gateway selectedGateway = null; + + for (Gateway gateway : gateways) { + if (gateway.getId().equals(gatewayId) && type == gateway.getType()) { + selectedGateway = gateway; + break; + } + } + if (selectedGateway == null) { + throw new IllegalArgumentException( + "gateway " + + gatewayId + + " of type " + + type + + " is not defined in the application"); + } + return selectedGateway; + } + + + public AuthenticatedGatewayRequestContext authenticate(GatewayRequestContext gatewayRequestContext) + throws AuthFailedException { + + final Gateway.Authentication authentication = + gatewayRequestContext.gateway().getAuthentication(); + + if (authentication == null) { + return getAuthenticatedGatewayRequestContext(gatewayRequestContext, Map.of(), new HashMap<>()); + } + + final GatewayAuthenticationResult result; + if (gatewayRequestContext.isTestMode()) { + if (!authentication.isAllowTestMode()) { + throw new AuthFailedException( + "Gateway " + + gatewayRequestContext.gateway().getId() + + " of tenant " + + gatewayRequestContext.tenant() + + " does not allow test mode."); + } + if (authTestProvider == null) { + throw new AuthFailedException("No test auth provider specified"); + } + result = authTestProvider.authenticate(gatewayRequestContext); + } else { + final String provider = authentication.getProvider(); + final GatewayAuthenticationProvider authProvider = + GatewayAuthenticationProviderRegistry.loadProvider( + provider, authentication.getConfiguration()); + result = authProvider.authenticate(gatewayRequestContext); + } + if (result == null) { + throw new AuthFailedException("Authentication provider returned null"); + } + if (!result.authenticated()) { + throw new AuthFailedException(result.reason()); + } + final Map principalValues = getPrincipalValues(result, gatewayRequestContext); + return getAuthenticatedGatewayRequestContext(gatewayRequestContext, principalValues, new HashMap<>()); + } + + private Map getPrincipalValues( + GatewayAuthenticationResult result, GatewayRequestContext context) { + if (!context.isTestMode()) { + final Map values = result.principalValues(); + if (values == null) { + return Map.of(); + } + return values; + } else { + final Map values = new HashMap<>(); + final String principalSubject = DigestUtils.sha256Hex(context.credentials()); + final int principalNumericId = principalSubject.hashCode(); + final String principalEmail = "%s@locahost".formatted(principalSubject); + + // google + values.putIfAbsent("subject", principalSubject); + values.putIfAbsent("email", principalEmail); + values.putIfAbsent("name", principalSubject); + + // github + values.putIfAbsent("login", principalSubject); + values.putIfAbsent("id", principalNumericId + ""); + return values; + } + } + + private AuthenticatedGatewayRequestContext getAuthenticatedGatewayRequestContext( + GatewayRequestContext gatewayRequestContext, + Map principalValues, + Map attributes) { + + return AuthenticatedGatewayRequestContextImpl.builder() + .gatewayRequestContext(gatewayRequestContext) + .attributes(attributes) + .principalValues(principalValues) + .build(); + } + + + + +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java new file mode 100644 index 000000000..321d3f7d2 --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java @@ -0,0 +1,21 @@ +package ai.langstream.apigateway.gateways; + +import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.api.storage.ApplicationStoreRegistry; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; +import ai.langstream.apigateway.config.StorageProperties; +import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; +import java.util.Objects; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class GatewayRequestHandlerFactory { + + @Bean + public GatewayRequestHandler gatewayRequestHandler(ApplicationStore applicationStore, + GatewayTestAuthenticationProperties testAuthenticationProperties) { + return new GatewayRequestHandler(applicationStore, testAuthenticationProperties); + } + +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java new file mode 100644 index 000000000..d26a5ba19 --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java @@ -0,0 +1,202 @@ +package ai.langstream.apigateway.gateways; + +import ai.langstream.api.model.Gateway; +import ai.langstream.api.model.StreamingCluster; +import ai.langstream.api.runner.code.Header; +import ai.langstream.api.runner.code.SimpleRecord; +import ai.langstream.api.runner.topics.TopicConnectionsRuntime; +import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; +import ai.langstream.api.runner.topics.TopicProducer; +import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import ai.langstream.apigateway.websocket.api.ProduceRequest; +import ai.langstream.apigateway.websocket.api.ProduceResponse; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class ProduceGateway implements Closeable { + + protected static final ObjectMapper mapper = new ObjectMapper(); + + @Getter + public static class ProduceException extends Exception { + + private final ProduceResponse.Status status; + + public ProduceException(String message, ProduceResponse.Status status) { + super(message); + this.status = status; + } + } + + public static class ProduceGatewayRequestValidator implements GatewayRequestHandler.GatewayRequestValidator { + @Override + public List getAllRequiredParameters(Gateway gateway) { + return gateway.getParameters(); + } + + @Override + public void validateOptions(Map options) { + for (Map.Entry option : options.entrySet()) { + switch (option.getKey()) { + default -> throw new IllegalArgumentException("Unknown option " + option.getKey()); + } + } + } + } + + + private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; + private TopicProducer producer; + private List
commonHeaders; + + public ProduceGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) { + this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry; + } + + public void start(String topic, List
commonHeaders, AuthenticatedGatewayRequestContext requestContext) { + this.commonHeaders = commonHeaders == null ? List.of() : commonHeaders; + + setupProducer( + topic, + requestContext.application().getInstance().streamingCluster(), + requestContext.tenant(), + requestContext.applicationId(), + requestContext.gateway().getId()); + } + + protected void setupProducer( + String topic, + StreamingCluster streamingCluster, + final String tenant, + final String applicationId, + final String gatewayId) { + final TopicConnectionsRuntime topicConnectionsRuntime = + topicConnectionsRuntimeRegistry + .getTopicConnectionsRuntime(streamingCluster) + .asTopicConnectionsRuntime(); + + topicConnectionsRuntime.init(streamingCluster); + + producer = + topicConnectionsRuntime.createProducer( + null, streamingCluster, Map.of("topic", topic)); + producer.start(); + log.info( + "Started producer for gateway {}/{}/{} on topic {}", + tenant, + applicationId, + gatewayId, + topic); + } + + public void produceMessage(String payload) throws ProduceException { + final ProduceRequest produceRequest; + try { + produceRequest = mapper.readValue(payload, ProduceRequest.class); + } catch (JsonProcessingException err) { + throw new ProduceException(err.getMessage(), ProduceResponse.Status.BAD_REQUEST); + } + produceMessage(produceRequest); + } + + + public void produceMessage(ProduceRequest produceRequest) throws ProduceException { + if (produceRequest.value() == null && produceRequest.key() == null) { + throw new ProduceException("Either key or value must be set.", ProduceResponse.Status.BAD_REQUEST); + } + if (producer == null) { + throw new ProduceException("Producer not initialized", ProduceResponse.Status.PRODUCER_ERROR); + } + + final Collection
headers = new ArrayList<>(commonHeaders); + if (produceRequest.headers() != null) { + final Set configuredHeaders = + headers.stream().map(Header::key).collect(Collectors.toSet()); + for (Map.Entry messageHeader : produceRequest.headers().entrySet()) { + if (configuredHeaders.contains(messageHeader.getKey())) { + throw new ProduceException("Header " + + messageHeader.getKey() + + " is configured as parameter-level header.", ProduceResponse.Status.BAD_REQUEST); + } + headers.add( + SimpleRecord.SimpleHeader.of( + messageHeader.getKey(), messageHeader.getValue())); + } + } + try { + final SimpleRecord record = + SimpleRecord.builder() + .key(produceRequest.key()) + .value(produceRequest.value()) + .headers(headers) + .build(); + producer.write(record).get(); + log.info("Produced record {}", record); + } catch (Throwable tt) { + throw new ProduceException(tt.getMessage(), ProduceResponse.Status.PRODUCER_ERROR); + } + } + + @Override + public void close() { + + if (producer != null) { + try { + producer.close(); + } catch (Exception e) { + log.error("error closing producer", e); + } + } + + } + + public static List
getProducerCommonHeaders(Gateway.ProduceOptions produceOptions, + AuthenticatedGatewayRequestContext context) { + if (produceOptions != null) { + return getProducerCommonHeaders( + produceOptions.headers(), + context.userParameters(), + context.principalValues()); + } + return null; + } + public static List
getProducerCommonHeaders( + List headerFilters, + Map passedParameters, + Map principalValues) { + final List
headers = new ArrayList<>(); + if (headerFilters == null) { + return headers; + } + for (Gateway.KeyValueComparison mapping : headerFilters) { + if (mapping.key() == null || mapping.key().isEmpty()) { + throw new IllegalArgumentException("Header key cannot be empty"); + } + String value = mapping.value(); + if (value == null && mapping.valueFromParameters() != null) { + value = passedParameters.get(mapping.valueFromParameters()); + } + if (value == null && mapping.valueFromAuthentication() != null) { + value = principalValues.get(mapping.valueFromAuthentication()); + } + if (value == null) { + throw new IllegalArgumentException("header " + mapping.key() + " cannot be empty"); + } + headers.add(SimpleRecord.SimpleHeader.of(mapping.key(), value)); + } + return headers; + } +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java new file mode 100644 index 000000000..a37e8a9be --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java @@ -0,0 +1,91 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.http; + +import ai.langstream.api.gateway.GatewayRequestContext; +import ai.langstream.api.model.Gateway; +import ai.langstream.api.runner.code.Header; +import ai.langstream.apigateway.gateways.GatewayRequestHandler; +import ai.langstream.apigateway.gateways.ProduceGateway; +import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; +import ai.langstream.apigateway.util.HttpUtil; +import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import ai.langstream.apigateway.websocket.api.ProduceRequest; +import ai.langstream.apigateway.websocket.api.ProduceResponse; + +import jakarta.validation.constraints.NotBlank; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.request.WebRequest; +import org.springframework.web.server.ResponseStatusException; + +@RestController +@RequestMapping("/api/gateways") +@Slf4j +@AllArgsConstructor +public class GatewayResource { + + private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider; + private final GatewayRequestHandler gatewayRequestHandler; + + @PostMapping(value = "/produce/{tenant}/{application}/{gateway}", consumes = MediaType.APPLICATION_JSON_VALUE) + ProduceResponse produce( + WebRequest request, + @NotBlank @PathVariable("tenant") String tenant, + @NotBlank @PathVariable("application") String application, + @NotBlank @PathVariable("gateway") String gateway, + @RequestBody ProduceRequest produceRequest) throws ProduceGateway.ProduceException { + + final Map queryString = request.getParameterMap().keySet().stream() + .collect(Collectors.toMap(k -> k, k -> request.getParameter(k))); + final Map headers = new HashMap<>(); + request.getHeaderNames().forEachRemaining(name -> headers.put(name, request.getHeader(name))); + final GatewayRequestContext context = + gatewayRequestHandler.validateRequest(tenant, application, gateway, Gateway.GatewayType.produce, + queryString, + headers, + new ProduceGateway.ProduceGatewayRequestValidator()); + final AuthenticatedGatewayRequestContext authContext; + try { + authContext = gatewayRequestHandler.authenticate(context); + } catch (GatewayRequestHandler.AuthFailedException e) { + throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage()); + } + + final ProduceGateway produceGateway = + new ProduceGateway(topicConnectionsRuntimeRegistryProvider.getTopicConnectionsRuntimeRegistry()); + final List
commonHeaders = + ProduceGateway.getProducerCommonHeaders(context.gateway().getProduceOptions(), authContext); + produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext); + produceGateway.produceMessage(produceRequest); + return ProduceResponse.OK; + + } +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/ResourceErrorsHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/ResourceErrorsHandler.java new file mode 100644 index 000000000..90e5f93f9 --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/ResourceErrorsHandler.java @@ -0,0 +1,45 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.http; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.http.HttpStatus; +import org.springframework.http.ProblemDetail; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.server.ResponseStatusException; + +@ControllerAdvice +@Order(Ordered.LOWEST_PRECEDENCE) +@Slf4j +public class ResourceErrorsHandler { + + @ExceptionHandler(Throwable.class) + ProblemDetail handleAll(Throwable exception) { + if (exception instanceof final ResponseStatusException rs) { + return ProblemDetail.forStatusAndDetail(rs.getStatusCode(), rs.getMessage()); + } + if (exception instanceof IllegalArgumentException) { + log.error("Bad request", exception); + return ProblemDetail.forStatusAndDetail(HttpStatus.BAD_REQUEST, exception.getMessage()); + } + log.error("Internal error", exception); + return ProblemDetail.forStatusAndDetail( + HttpStatus.INTERNAL_SERVER_ERROR, exception.getMessage()); + } +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java new file mode 100644 index 000000000..40e756f6f --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java @@ -0,0 +1,30 @@ +package ai.langstream.apigateway.util; + +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +public class HttpUtil { + + public static Map parseQuerystring(String queryString) { + Map map = new HashMap<>(); + if (queryString == null || queryString.isBlank()) { + return map; + } + String[] params = queryString.split("&"); + for (String param : params) { + String[] keyValuePair = param.split("=", 2); + String name = URLDecoder.decode(keyValuePair[0], StandardCharsets.UTF_8); + if ("".equals(name)) { + continue; + } + String value = + keyValuePair.length > 1 + ? URLDecoder.decode(keyValuePair[1], StandardCharsets.UTF_8) + : ""; + map.put(name, value); + } + return map; + } +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java index 49efcf746..adb5618e7 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java @@ -21,12 +21,15 @@ import ai.langstream.api.gateway.GatewayRequestContext; import ai.langstream.api.model.Gateway; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; +import ai.langstream.apigateway.gateways.GatewayRequestHandler; +import ai.langstream.apigateway.util.HttpUtil; import ai.langstream.apigateway.websocket.handlers.AbstractHandler; import ai.langstream.apigateway.websocket.impl.AuthenticatedGatewayRequestContextImpl; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; +import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.codec.digest.DigestUtils; import org.springframework.http.HttpStatus; @@ -40,38 +43,23 @@ import org.springframework.web.socket.server.HandshakeInterceptor; @Slf4j +@AllArgsConstructor public class AuthenticationInterceptor implements HandshakeInterceptor { - private final GatewayAuthenticationProvider authTestProvider; - - public AuthenticationInterceptor( - GatewayTestAuthenticationProperties testAuthenticationProperties) { - if (testAuthenticationProperties.getType() != null) { - authTestProvider = - GatewayAuthenticationProviderRegistry.loadProvider( - testAuthenticationProperties.getType(), - testAuthenticationProperties.getConfiguration()); - log.info( - "Loaded test authentication provider {}", - authTestProvider.getClass().getName()); - } else { - authTestProvider = null; - log.info("No test authentication provider configured"); - } - } + private final GatewayRequestHandler gatewayRequestHandler; @Override public boolean beforeHandshake( ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, - Map attributes) + Map sessionAttributes) throws Exception { final ServletServerHttpRequest httpRequest = (ServletServerHttpRequest) request; final ServletServerHttpResponse httpResponse = (ServletServerHttpResponse) response; try { final String queryString = httpRequest.getServletRequest().getQueryString(); - final Map querystring = parseQuerystring(queryString); + final Map querystring = HttpUtil.parseQuerystring(queryString); final WebSocketHandler delegate = ((ExceptionWebSocketHandlerDecorator) wsHandler).getLastHandler(); @@ -82,28 +70,32 @@ public boolean beforeHandshake( final Map vars = antPathMatcher.extractUriTemplateVariables(handler.path(), path); final GatewayRequestContext gatewayRequestContext = - handler.validateRequest( - vars, querystring, request.getHeaders().toSingleValueMap()); - - final Map principalValues; + gatewayRequestHandler.validateRequest( + handler.tenantFromPath(vars, querystring), + handler.applicationIdFromPath(vars, querystring), + handler.gatewayFromPath(vars, querystring), + handler.gatewayType(), + querystring, + request.getHeaders().toSingleValueMap(), + handler.validator()); + + final AuthenticatedGatewayRequestContext authenticatedGatewayRequestContext; try { - principalValues = authenticate(gatewayRequestContext); - } catch (AuthFailedException authFailedException) { + authenticatedGatewayRequestContext = gatewayRequestHandler.authenticate(gatewayRequestContext); + } catch (GatewayRequestHandler.AuthFailedException authFailedException) { log.info("Authentication failed {}", authFailedException.getMessage()); String error = authFailedException.getMessage(); if (error == null || error.isEmpty()) { error = "unknown"; } - httpResponse.getServletResponse().sendError(HttpStatus.FORBIDDEN.value(), error); + httpResponse.getServletResponse().sendError(HttpStatus.UNAUTHORIZED.value(), error); return false; } - log.info("Authentication passed!"); + log.debug("Authentication OK"); - final AuthenticatedGatewayRequestContext authenticatedGatewayRequestContext = - getAuthenticatedGatewayRequestContext( - gatewayRequestContext, principalValues, attributes); - attributes.put("context", authenticatedGatewayRequestContext); - handler.onBeforeHandshakeCompleted(authenticatedGatewayRequestContext, attributes); + sessionAttributes.put("context", authenticatedGatewayRequestContext); + handler.onBeforeHandshakeCompleted(authenticatedGatewayRequestContext, + authenticatedGatewayRequestContext.attributes()); return true; } catch (Throwable error) { log.info("Internal error {}", error.getMessage(), error); @@ -114,89 +106,7 @@ public boolean beforeHandshake( } } - private static class AuthFailedException extends Exception { - public AuthFailedException(String message) { - super(message); - } - } - private Map authenticate(GatewayRequestContext gatewayRequestContext) - throws AuthFailedException { - - final Gateway.Authentication authentication = - gatewayRequestContext.gateway().getAuthentication(); - - if (authentication == null) { - return Map.of(); - } - - final GatewayAuthenticationResult result; - if (gatewayRequestContext.isTestMode()) { - if (!authentication.isAllowTestMode()) { - throw new AuthFailedException( - "Gateway " - + gatewayRequestContext.gateway().getId() - + " of tenant " - + gatewayRequestContext.tenant() - + " does not allow test mode."); - } - if (authTestProvider == null) { - throw new AuthFailedException("No test auth provider specified"); - } - result = authTestProvider.authenticate(gatewayRequestContext); - } else { - final String provider = authentication.getProvider(); - final GatewayAuthenticationProvider authProvider = - GatewayAuthenticationProviderRegistry.loadProvider( - provider, authentication.getConfiguration()); - result = authProvider.authenticate(gatewayRequestContext); - } - if (result == null) { - throw new AuthFailedException("Authentication provider returned null"); - } - if (!result.authenticated()) { - throw new AuthFailedException(result.reason()); - } - return getPrincipalValues(result, gatewayRequestContext); - } - - private Map getPrincipalValues( - GatewayAuthenticationResult result, GatewayRequestContext context) { - if (!context.isTestMode()) { - final Map values = result.principalValues(); - if (values == null) { - return Map.of(); - } - return values; - } else { - final Map values = new HashMap<>(); - final String principalSubject = DigestUtils.sha256Hex(context.credentials()); - final int principalNumericId = principalSubject.hashCode(); - final String principalEmail = "%s@locahost".formatted(principalSubject); - - // google - values.putIfAbsent("subject", principalSubject); - values.putIfAbsent("email", principalEmail); - values.putIfAbsent("name", principalSubject); - - // github - values.putIfAbsent("login", principalSubject); - values.putIfAbsent("id", principalNumericId + ""); - return values; - } - } - - private AuthenticatedGatewayRequestContext getAuthenticatedGatewayRequestContext( - GatewayRequestContext gatewayRequestContext, - Map principalValues, - Map attributes) { - - return AuthenticatedGatewayRequestContextImpl.builder() - .gatewayRequestContext(gatewayRequestContext) - .attributes(attributes) - .principalValues(principalValues) - .build(); - } @Override public void afterHandshake( @@ -205,24 +115,4 @@ public void afterHandshake( WebSocketHandler wsHandler, Exception exception) {} - private static Map parseQuerystring(String queryString) { - Map map = new HashMap<>(); - if (queryString == null || queryString.isBlank()) { - return map; - } - String[] params = queryString.split("&"); - for (String param : params) { - String[] keyValuePair = param.split("=", 2); - String name = URLDecoder.decode(keyValuePair[0], StandardCharsets.UTF_8); - if ("".equals(name)) { - continue; - } - String value = - keyValuePair.length > 1 - ? URLDecoder.decode(keyValuePair[1], StandardCharsets.UTF_8) - : ""; - map.put(name, value); - } - return map; - } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java index 58add2d30..a85f0c853 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java @@ -18,6 +18,7 @@ import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; +import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; import ai.langstream.apigateway.websocket.handlers.ChatHandler; import ai.langstream.apigateway.websocket.handlers.ConsumeHandler; @@ -47,7 +48,7 @@ public class WebSocketConfig implements WebSocketConfigurer { private final ApplicationStore applicationStore; private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider; - private final GatewayTestAuthenticationProperties adminAuthenticationProperties; + private final GatewayRequestHandler gatewayRequestHandler; private final ExecutorService consumeThreadPool = Executors.newCachedThreadPool(); @Override @@ -72,7 +73,7 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { .setAllowedOrigins("*") .addInterceptors( new HttpSessionHandshakeInterceptor(), - new AuthenticationInterceptor(adminAuthenticationProperties)); + new AuthenticationInterceptor(gatewayRequestHandler)); } @Bean diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java index 003f6aadc..391116eff 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java @@ -18,48 +18,34 @@ import ai.langstream.api.events.EventRecord; import ai.langstream.api.events.EventSources; import ai.langstream.api.events.GatewayEventData; -import ai.langstream.api.gateway.GatewayRequestContext; -import ai.langstream.api.model.Application; -import ai.langstream.api.model.ApplicationSpecs; import ai.langstream.api.model.Gateway; -import ai.langstream.api.model.Gateways; import ai.langstream.api.model.StreamingCluster; import ai.langstream.api.runner.code.Header; import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.code.SimpleRecord; import ai.langstream.api.runner.topics.TopicConnectionsRuntime; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; -import ai.langstream.api.runner.topics.TopicOffsetPosition; import ai.langstream.api.runner.topics.TopicProducer; import ai.langstream.api.runner.topics.TopicReadResult; import ai.langstream.api.runner.topics.TopicReader; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.gateways.ConsumeGateway; +import ai.langstream.apigateway.gateways.GatewayRequestHandler; +import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; import ai.langstream.apigateway.websocket.api.ConsumePushMessage; -import ai.langstream.apigateway.websocket.api.ProduceRequest; import ai.langstream.apigateway.websocket.api.ProduceResponse; -import ai.langstream.apigateway.websocket.impl.GatewayRequestContextImpl; -import ai.langstream.impl.common.ApplicationPlaceholderResolver; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; import java.util.ArrayList; -import java.util.Base64; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Function; -import java.util.stream.Collectors; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; -import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; @@ -68,6 +54,8 @@ @Slf4j public abstract class AbstractHandler extends TextWebSocketHandler { protected static final ObjectMapper mapper = new ObjectMapper(); + protected static final String ATTRIBUTE_PRODUCE_GATEWAY = "__produce_gateway"; + protected static final String ATTRIBUTE_CONSUME_GATEWAY = "__consume_gateway"; protected final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; protected final ApplicationStore applicationStore; @@ -80,20 +68,24 @@ public AbstractHandler( public abstract String path(); - abstract Gateway.GatewayType gatewayType(); + public abstract Gateway.GatewayType gatewayType(); - abstract String tenantFromPath(Map parsedPath, Map queryString); + public abstract String tenantFromPath(Map parsedPath, Map queryString); - abstract String applicationIdFromPath( + public abstract String applicationIdFromPath( Map parsedPath, Map queryString); - abstract String gatewayFromPath( + public abstract String gatewayFromPath( Map parsedPath, Map queryString); + + public abstract GatewayRequestHandler.GatewayRequestValidator validator(); + public void onBeforeHandshakeCompleted( AuthenticatedGatewayRequestContext gatewayRequestContext, Map attributes) - throws Exception {} + throws Exception { + } abstract void onOpen( WebSocketSession webSocketSession, @@ -112,8 +104,6 @@ abstract void onClose( CloseStatus status) throws Exception; - abstract void validateOptions(Map options); - @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { super.afterConnectionEstablished(session); @@ -137,7 +127,7 @@ private void closeSession(WebSocketSession session, Throwable throwable) throws try { session.close(status.withReason(throwable.getMessage())); } finally { - closeCloseableResources(session); + callHandlerOnClose(session, status); } } @@ -156,13 +146,16 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message) public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { super.afterConnectionClosed(session, status); + callHandlerOnClose(session, status); + sendClientDisconnectedEvent(getContext(session)); + } + + private void callHandlerOnClose(WebSocketSession session, CloseStatus status) { try { onClose(session, getContext(session), status); } catch (Throwable throwable) { log.error("[{}] error while closing websocket", session.getId(), throwable); } - closeCloseableResources(session); - sendClientDisconnectedEvent(getContext(session)); } @Override @@ -170,163 +163,6 @@ public boolean supportsPartialMessages() { return true; } - private Application getResolvedApplication(String tenant, String applicationId) { - final ApplicationSpecs applicationSpecs = applicationStore.getSpecs(tenant, applicationId); - if (applicationSpecs == null) { - throw new IllegalArgumentException("application " + applicationId + " not found"); - } - final Application application = applicationSpecs.getApplication(); - application.setSecrets(applicationStore.getSecrets(tenant, applicationId)); - return ApplicationPlaceholderResolver.resolvePlaceholders(application); - } - - private Gateway extractGateway( - String gatewayId, Application application, Gateway.GatewayType type) { - final Gateways gatewaysObj = application.getGateways(); - if (gatewaysObj == null) { - throw new IllegalArgumentException("no gateways defined for the application"); - } - final List gateways = gatewaysObj.gateways(); - if (gateways == null) { - throw new IllegalArgumentException("no gateways defined for the application"); - } - - Gateway selectedGateway = null; - - for (Gateway gateway : gateways) { - if (gateway.getId().equals(gatewayId) && type == gateway.getType()) { - selectedGateway = gateway; - break; - } - } - if (selectedGateway == null) { - throw new IllegalArgumentException( - "gateway " - + gatewayId - + " of type " - + type - + " is not defined in the application"); - } - return selectedGateway; - } - - public GatewayRequestContext validateRequest( - Map pathVars, - Map queryString, - Map httpHeaders) { - Map options = new HashMap<>(); - Map userParameters = new HashMap<>(); - - final String credentials = queryString.remove("credentials"); - final String testCredentials = queryString.remove("test-credentials"); - - for (Map.Entry entry : queryString.entrySet()) { - if (entry.getKey().startsWith("option:")) { - options.put(entry.getKey().substring("option:".length()), entry.getValue()); - } else if (entry.getKey().startsWith("param:")) { - userParameters.put(entry.getKey().substring("param:".length()), entry.getValue()); - } else { - throw new IllegalArgumentException( - "invalid query parameter " - + entry.getKey() - + ". " - + "To specify a gateway parameter, use the format param:." - + "To specify a option, use the format option:."); - } - } - - final String tenant = tenantFromPath(pathVars, queryString); - final String applicationId = applicationIdFromPath(pathVars, queryString); - final String gatewayId = gatewayFromPath(pathVars, queryString); - - final Application application = getResolvedApplication(tenant, applicationId); - final Gateway.GatewayType type = gatewayType(); - final Gateway gateway = extractGateway(gatewayId, application, type); - - final List requiredParameters = getAllRequiredParameters(gateway); - Set allUserParameterKeys = new HashSet<>(userParameters.keySet()); - if (requiredParameters != null) { - for (String requiredParameter : requiredParameters) { - final String value = userParameters.get(requiredParameter); - if (!StringUtils.hasText(value)) { - throw new IllegalArgumentException( - formatErrorMessage( - tenant, - applicationId, - gateway, - "missing required parameter " - + requiredParameter - + ". Required parameters: " - + requiredParameters)); - } - allUserParameterKeys.remove(requiredParameter); - } - } - if (!allUserParameterKeys.isEmpty()) { - throw new IllegalArgumentException( - formatErrorMessage( - tenant, - applicationId, - gateway, - "unknown parameters: " + allUserParameterKeys)); - } - validateOptions(options); - - if (credentials != null && testCredentials != null) { - throw new IllegalArgumentException( - formatErrorMessage( - tenant, - applicationId, - gateway, - "credentials and test-credentials cannot be used together")); - } - return GatewayRequestContextImpl.builder() - .tenant(tenant) - .applicationId(applicationId) - .application(application) - .credentials(credentials) - .testCredentials(testCredentials) - .httpHeaders(httpHeaders) - .options(options) - .userParameters(userParameters) - .gateway(gateway) - .build(); - } - - private static String formatErrorMessage( - String tenant, String applicationId, Gateway gateway, String error) { - return "Error for gateway %s (tenant: %s, appId: %s): %s" - .formatted(gateway.getId(), tenant, applicationId, error); - } - - protected abstract List getAllRequiredParameters(Gateway gateway); - - protected static void recordCloseableResource( - Map attributes, AutoCloseable... closeables) { - List currentCloseable = (List) attributes.get("closeables"); - - if (currentCloseable == null) { - currentCloseable = new ArrayList<>(); - } - Collections.addAll(currentCloseable, closeables); - attributes.put("closeables", currentCloseable); - } - - private void closeCloseableResources(WebSocketSession webSocketSession) { - List currentCloseable = - (List) webSocketSession.getAttributes().get("closeables"); - - if (currentCloseable != null) { - for (AutoCloseable autoCloseable : currentCloseable) { - try { - autoCloseable.close(); - } catch (Throwable e) { - log.error("error while closing resource", e); - } - } - } - } - protected void sendClientConnectedEvent(AuthenticatedGatewayRequestContext context) { sendEvent(EventRecord.Types.ClientConnected, context); } @@ -355,10 +191,10 @@ protected void sendEvent(EventRecord.Types type, AuthenticatedGatewayRequestCont topicConnectionsRuntime.init(streamingCluster); try (final TopicProducer producer = - topicConnectionsRuntime.createProducer( - "langstream-events", - streamingCluster, - Map.of("topic", gateway.getEventsTopic()))) { + topicConnectionsRuntime.createProducer( + "langstream-events", + streamingCluster, + Map.of("topic", gateway.getEventsTopic()))) { producer.start(); final EventSources.GatewaySource source = @@ -392,111 +228,20 @@ protected void sendEvent(EventRecord.Types type, AuthenticatedGatewayRequestCont } } - protected static void startReadingMessages( - WebSocketSession session, - AuthenticatedGatewayRequestContext context, - Executor executor) { - final CompletableFuture future = - CompletableFuture.runAsync( - () -> { - final Map attributes = session.getAttributes(); - TopicReader reader = (TopicReader) attributes.get("topicReader"); - final String tenant = context.tenant(); - final String gatewayId = context.gateway().getId(); - final String applicationId = context.applicationId(); - try { - log.info( - "[{}] Started reader for gateway {}/{}/{}", - session.getId(), - tenant, - applicationId, - gatewayId); - readMessages( - session, - (List>) - attributes.get("consumeFilters"), - reader); - } catch (InterruptedException | CancellationException ex) { - // ignore - } catch (Throwable ex) { - log.error(ex.getMessage(), ex); - } finally { - closeReader(reader); - } - }, - executor); - session.getAttributes().put("future", future); - } - - protected static void readMessages( - WebSocketSession session, List> filters, TopicReader reader) - throws Exception { - while (true) { - if (Thread.currentThread().isInterrupted()) { - return; - } - if (!session.isOpen()) { - return; - } - final TopicReadResult readResult = reader.read(); - final List records = readResult.records(); - for (Record record : records) { - log.debug("[{}] Received record {}", session.getId(), record); - boolean skip = false; - if (filters != null) { - for (Function filter : filters) { - if (!filter.apply(record)) { - skip = true; - log.debug("[{}] Skipping record {}", session.getId(), record); - break; - } + protected void startReadingMessages(WebSocketSession webSocketSession, Executor executor) { + final AuthenticatedGatewayRequestContext context = getContext(webSocketSession); + final ConsumeGateway consumeGateway = (ConsumeGateway) context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); + final CompletableFuture future = consumeGateway.startReading( + executor, + () -> !webSocketSession.isOpen(), + message -> { + try { + webSocketSession.sendMessage(new TextMessage(message)); + } catch (IOException ex) { + throw new RuntimeException(ex); } - } - if (!skip) { - final Map messageHeaders = computeMessageHeaders(record); - final String offset = computeOffset(readResult); - - final ConsumePushMessage message = - new ConsumePushMessage( - new ConsumePushMessage.Record( - record.key(), record.value(), messageHeaders), - offset); - final String jsonMessage = mapper.writeValueAsString(message); - session.sendMessage(new TextMessage(jsonMessage)); - } - } - } - } - - private static void closeReader(TopicReader reader) { - if (reader == null) { - return; - } - try { - reader.close(); - } catch (Exception e) { - log.error("error closing reader", e); - } - } - - private static Map computeMessageHeaders(Record record) { - final Collection
headers = record.headers(); - final Map messageHeaders; - if (headers == null) { - messageHeaders = Map.of(); - } else { - messageHeaders = new HashMap<>(); - headers.forEach(h -> messageHeaders.put(h.key(), h.valueAsString())); - } - return messageHeaders; - } - - private static String computeOffset(TopicReadResult readResult) { - final byte[] offset = readResult.offset(); - if (offset == null) { - return null; - } - return Base64.getEncoder().encodeToString(offset); + }); + webSocketSession.getAttributes().put("future", future); } protected static List> createMessageFilters( @@ -538,34 +283,14 @@ record -> { } protected void setupReader( - Map sessionAttributes, String topic, - StreamingCluster streamingCluster, List> filters, - Map options) + AuthenticatedGatewayRequestContext context) throws Exception { - sessionAttributes.put("consumeFilters", filters); + final ConsumeGateway consumeGateway = new ConsumeGateway(topicConnectionsRuntimeRegistry); + context.attributes().put(ATTRIBUTE_CONSUME_GATEWAY, consumeGateway); + consumeGateway.setup(topic, filters, context); - final TopicConnectionsRuntime topicConnectionsRuntime = - topicConnectionsRuntimeRegistry - .getTopicConnectionsRuntime(streamingCluster) - .asTopicConnectionsRuntime(); - - topicConnectionsRuntime.init(streamingCluster); - - final String positionParameter = options.getOrDefault("position", "latest"); - TopicOffsetPosition position = - switch (positionParameter) { - case "latest" -> TopicOffsetPosition.LATEST; - case "earliest" -> TopicOffsetPosition.EARLIEST; - default -> TopicOffsetPosition.absolute( - Base64.getDecoder().decode(positionParameter)); - }; - TopicReader reader = - topicConnectionsRuntime.createReader( - streamingCluster, Map.of("topic", topic), position); - reader.start(); - sessionAttributes.put("topicReader", reader); } protected void stopReadingMessages(WebSocketSession webSocketSession) { @@ -576,122 +301,38 @@ protected void stopReadingMessages(WebSocketSession webSocketSession) { } } - protected void setupProducer( - Map sessionAttributes, - String topic, - StreamingCluster streamingCluster, - final List
commonHeaders, - final String tenant, - final String applicationId, - final String gatewayId) { - final TopicConnectionsRuntime topicConnectionsRuntime = - topicConnectionsRuntimeRegistry - .getTopicConnectionsRuntime(streamingCluster) - .asTopicConnectionsRuntime(); - topicConnectionsRuntime.init(streamingCluster); - - final TopicProducer producer = - topicConnectionsRuntime.createProducer( - null, streamingCluster, Map.of("topic", topic)); - recordCloseableResource(sessionAttributes, producer); - producer.start(); - - sessionAttributes.put("producer", producer); - sessionAttributes.put( - "headers", - commonHeaders == null ? List.of() : Collections.unmodifiableList(commonHeaders)); - log.info( - "Started produced for gateway {}/{}/{} on topic {}", - tenant, - applicationId, - gatewayId, - topic); + protected void setupProducer(String topic, List
commonHeaders, AuthenticatedGatewayRequestContext context) { + final ProduceGateway produceGateway = new ProduceGateway(topicConnectionsRuntimeRegistry); + context.attributes().put(ATTRIBUTE_PRODUCE_GATEWAY, produceGateway); + produceGateway.start(topic, commonHeaders, context); } - protected static List
getProducerCommonHeaders( - List headerFilters, - Map passedParameters, - Map principalValues) { - final List
headers = new ArrayList<>(); - if (headerFilters == null) { - return headers; - } - for (Gateway.KeyValueComparison mapping : headerFilters) { - if (mapping.key() == null || mapping.key().isEmpty()) { - throw new IllegalArgumentException("Header key cannot be empty"); - } - String value = mapping.value(); - if (value == null && mapping.valueFromParameters() != null) { - value = passedParameters.get(mapping.valueFromParameters()); - } - if (value == null && mapping.valueFromAuthentication() != null) { - value = principalValues.get(mapping.valueFromAuthentication()); - } - if (value == null) { - throw new IllegalArgumentException("header " + mapping.key() + " cannot be empty"); - } - headers.add(SimpleRecord.SimpleHeader.of(mapping.key(), value)); - } - return headers; - } - protected static void produceMessage(WebSocketSession webSocketSession, TextMessage message) + protected void produceMessage(WebSocketSession webSocketSession, + TextMessage message) throws IOException { - final TopicProducer topicProducer = getTopicProducer(webSocketSession, true); - final ProduceRequest produceRequest; try { - produceRequest = mapper.readValue(message.getPayload(), ProduceRequest.class); - } catch (JsonProcessingException err) { - sendResponse(webSocketSession, ProduceResponse.Status.BAD_REQUEST, err.getMessage()); - return; - } - if (produceRequest.value() == null && produceRequest.key() == null) { - sendResponse( - webSocketSession, - ProduceResponse.Status.BAD_REQUEST, - "Either key or value must be set."); - return; + final AuthenticatedGatewayRequestContext context = getContext(webSocketSession); + final ProduceGateway produceGateway = (ProduceGateway) context.attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); + produceGateway.produceMessage(message.getPayload()); + webSocketSession.sendMessage( + new TextMessage(mapper.writeValueAsString(ProduceResponse.OK))); + } catch (ProduceGateway.ProduceException exception) { + sendResponse(webSocketSession, exception.getStatus(), exception.getMessage()); } + } - final Collection
headers = - new ArrayList<>((List
) webSocketSession.getAttributes().get("headers")); - if (produceRequest.headers() != null) { - final Set configuredHeaders = - headers.stream().map(Header::key).collect(Collectors.toSet()); - for (Map.Entry messageHeader : produceRequest.headers().entrySet()) { - if (configuredHeaders.contains(messageHeader.getKey())) { - sendResponse( - webSocketSession, - ProduceResponse.Status.BAD_REQUEST, - "Header " - + messageHeader.getKey() - + " is configured as parameter-level header."); - return; - } - headers.add( - SimpleRecord.SimpleHeader.of( - messageHeader.getKey(), messageHeader.getValue())); - } - } - try { - final SimpleRecord record = - SimpleRecord.builder() - .key(produceRequest.key()) - .value(produceRequest.value()) - .headers(headers) - .build(); - topicProducer.write(record).get(); - log.info("[{}] Produced record {}", webSocketSession.getId(), record); - } catch (Throwable tt) { - sendResponse(webSocketSession, ProduceResponse.Status.PRODUCER_ERROR, tt.getMessage()); + protected void closeProduceGateway(WebSocketSession webSocketSession) { + final ProduceGateway produceGateway = (ProduceGateway) getContext(webSocketSession).attributes().get( + ATTRIBUTE_PRODUCE_GATEWAY); + if (produceGateway == null) { return; } - - webSocketSession.sendMessage( - new TextMessage(mapper.writeValueAsString(ProduceResponse.OK))); + produceGateway.close(); } + private static void sendResponse( WebSocketSession webSocketSession, ProduceResponse.Status status, String reason) throws IOException { @@ -699,17 +340,4 @@ private static void sendResponse( new TextMessage(mapper.writeValueAsString(new ProduceResponse(status, reason)))); } - private static TopicProducer getTopicProducer( - WebSocketSession webSocketSession, boolean throwIfNotFound) { - final TopicProducer topicProducer = - (TopicProducer) webSocketSession.getAttributes().get("producer"); - if (topicProducer == null) { - if (throwIfNotFound) { - log.error("No producer found for session {}", webSocketSession.getId()); - throw new IllegalStateException( - "No producer found for session " + webSocketSession.getId()); - } - } - return topicProducer; - } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java index 06e9b9e8f..3d6999951 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java @@ -22,6 +22,8 @@ import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.gateways.GatewayRequestHandler; +import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; import java.util.ArrayList; import java.util.List; @@ -53,39 +55,59 @@ public String path() { } @Override - Gateway.GatewayType gatewayType() { + public Gateway.GatewayType gatewayType() { return Gateway.GatewayType.chat; } @Override - String tenantFromPath(Map parsedPath, Map queryString) { + public String tenantFromPath(Map parsedPath, Map queryString) { return parsedPath.get("tenant"); } @Override - String applicationIdFromPath(Map parsedPath, Map queryString) { + public String applicationIdFromPath(Map parsedPath, Map queryString) { return parsedPath.get("application"); } @Override - String gatewayFromPath(Map parsedPath, Map queryString) { + public String gatewayFromPath(Map parsedPath, Map queryString) { return parsedPath.get("gateway"); } @Override - protected List getAllRequiredParameters(Gateway gateway) { - List parameters = gateway.getParameters(); - if (parameters == null) { - parameters = new ArrayList<>(); - } - if (gateway.getChatOptions() != null && gateway.getChatOptions().getHeaders() != null) { - for (Gateway.KeyValueComparison header : gateway.getChatOptions().getHeaders()) { - if (header.valueFromParameters() != null) { - parameters.add(header.valueFromParameters()); + public GatewayRequestHandler.GatewayRequestValidator validator() { + return new GatewayRequestHandler.GatewayRequestValidator() { + @Override + public List getAllRequiredParameters(Gateway gateway) { + List parameters = gateway.getParameters(); + if (parameters == null) { + parameters = new ArrayList<>(); + } + if (gateway.getChatOptions() != null && gateway.getChatOptions().getHeaders() != null) { + for (Gateway.KeyValueComparison header : gateway.getChatOptions().getHeaders()) { + if (header.valueFromParameters() != null) { + parameters.add(header.valueFromParameters()); + } + } } + return parameters; } - } - return parameters; + + @Override + public void validateOptions(Map options) { + for (Map.Entry option : options.entrySet()) { + switch (option.getKey()) { + case "position": + if (!StringUtils.hasText(option.getValue())) { + throw new IllegalArgumentException("'position' cannot be blank"); + } + break; + default: + throw new IllegalArgumentException("Unknown option " + option.getKey()); + } + } + } + }; } @Override @@ -110,16 +132,10 @@ private void setupProducer(AuthenticatedGatewayRequestContext context) { } } final List
commonHeaders = - getProducerCommonHeaders( - headerConfig, context.userParameters(), context.principalValues()); - setupProducer( - context.attributes(), - chatOptions.getQuestionsTopic(), - context.application().getInstance().streamingCluster(), - commonHeaders, - context.tenant(), - context.applicationId(), - context.gateway().getId()); + ProduceGateway.getProducerCommonHeaders(headerConfig, context.userParameters(), + context.principalValues()); + + setupProducer(chatOptions.getQuestionsTopic(), commonHeaders, context); } private void setupReader(AuthenticatedGatewayRequestContext context) throws Exception { @@ -136,18 +152,15 @@ private void setupReader(AuthenticatedGatewayRequestContext context) throws Exce createMessageFilters( headerFilters, context.userParameters(), context.principalValues()); - setupReader( - context.attributes(), - chatOptions.getAnswersTopic(), - context.application().getInstance().streamingCluster(), + setupReader(chatOptions.getAnswersTopic(), messageFilters, - context.options()); + context); } @Override public void onOpen( WebSocketSession webSocketSession, AuthenticatedGatewayRequestContext context) { - startReadingMessages(webSocketSession, context, executor); + startReadingMessages(webSocketSession, executor); } @Override @@ -165,18 +178,4 @@ public void onClose( AuthenticatedGatewayRequestContext context, CloseStatus status) {} - @Override - void validateOptions(Map options) { - for (Map.Entry option : options.entrySet()) { - switch (option.getKey()) { - case "position": - if (!StringUtils.hasText(option.getValue())) { - throw new IllegalArgumentException("'position' cannot be blank"); - } - break; - default: - throw new IllegalArgumentException("Unknown option " + option.getKey()); - } - } - } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java index efdcaa8b3..e5a510677 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java @@ -21,6 +21,7 @@ import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; import java.util.List; import java.util.Map; @@ -51,28 +52,48 @@ public String path() { } @Override - Gateway.GatewayType gatewayType() { + public Gateway.GatewayType gatewayType() { return Gateway.GatewayType.consume; } @Override - String tenantFromPath(Map parsedPath, Map queryString) { + public String tenantFromPath(Map parsedPath, Map queryString) { return parsedPath.get("tenant"); } @Override - String applicationIdFromPath(Map parsedPath, Map queryString) { + public String applicationIdFromPath(Map parsedPath, Map queryString) { return parsedPath.get("application"); } @Override - String gatewayFromPath(Map parsedPath, Map queryString) { + public String gatewayFromPath(Map parsedPath, Map queryString) { return parsedPath.get("gateway"); } @Override - protected List getAllRequiredParameters(Gateway gateway) { - return gateway.getParameters(); + public GatewayRequestHandler.GatewayRequestValidator validator() { + return new GatewayRequestHandler.GatewayRequestValidator() { + @Override + public List getAllRequiredParameters(Gateway gateway) { + return gateway.getParameters(); + } + + @Override + public void validateOptions(Map options) { + for (Map.Entry option : options.entrySet()) { + switch (option.getKey()) { + case "position": + if (!StringUtils.hasText(option.getValue())) { + throw new IllegalArgumentException("'position' cannot be blank"); + } + break; + default: + throw new IllegalArgumentException("Unknown option " + option.getKey()); + } + } + } + }; } @Override @@ -92,19 +113,15 @@ public void onBeforeHandshakeCompleted( } else { messageFilters = null; } - - setupReader( - context.attributes(), - context.gateway().getTopic(), - context.application().getInstance().streamingCluster(), + setupReader(context.gateway().getTopic(), messageFilters, - context.options()); + context); sendClientConnectedEvent(context); } @Override public void onOpen(WebSocketSession session, AuthenticatedGatewayRequestContext context) { - startReadingMessages(session, context, executor); + startReadingMessages(session, executor); } @Override @@ -121,18 +138,4 @@ public void onClose( stopReadingMessages(webSocketSession); } - @Override - void validateOptions(Map options) { - for (Map.Entry option : options.entrySet()) { - switch (option.getKey()) { - case "position": - if (!StringUtils.hasText(option.getValue())) { - throw new IllegalArgumentException("'position' cannot be blank"); - } - break; - default: - throw new IllegalArgumentException("Unknown option " + option.getKey()); - } - } - } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java index 403ed42ee..d39e7903a 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java @@ -21,7 +21,10 @@ import ai.langstream.api.runner.code.Header; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.gateways.GatewayRequestHandler; +import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import java.util.ArrayList; import java.util.List; import java.util.Map; import lombok.extern.slf4j.Slf4j; @@ -44,54 +47,38 @@ public String path() { } @Override - Gateway.GatewayType gatewayType() { + public Gateway.GatewayType gatewayType() { return Gateway.GatewayType.produce; } @Override - String tenantFromPath(Map parsedPath, Map queryString) { + public String tenantFromPath(Map parsedPath, Map queryString) { return parsedPath.get("tenant"); } @Override - String applicationIdFromPath(Map parsedPath, Map queryString) { + public String applicationIdFromPath(Map parsedPath, Map queryString) { return parsedPath.get("application"); } @Override - String gatewayFromPath(Map parsedPath, Map queryString) { + public String gatewayFromPath(Map parsedPath, Map queryString) { return parsedPath.get("gateway"); } @Override - protected List getAllRequiredParameters(Gateway gateway) { - return gateway.getParameters(); + public GatewayRequestHandler.GatewayRequestValidator validator() { + return new ProduceGateway.ProduceGatewayRequestValidator(); } @Override public void onBeforeHandshakeCompleted( AuthenticatedGatewayRequestContext context, Map attributes) throws Exception { - final Gateway gateway = context.gateway(); - final Gateway.ProduceOptions produceOptions = gateway.getProduceOptions(); - final List
commonHeaders; - if (produceOptions != null) { - commonHeaders = - getProducerCommonHeaders( - produceOptions.headers(), - context.userParameters(), - context.principalValues()); - } else { - commonHeaders = null; - } - setupProducer( - context.attributes(), - gateway.getTopic(), - context.application().getInstance().streamingCluster(), - commonHeaders, - context.tenant(), - context.applicationId(), - gateway.getId()); + final List
commonHeaders = + ProduceGateway.getProducerCommonHeaders(context.gateway().getProduceOptions(), context); + setupProducer(context.gateway().getTopic(), commonHeaders, context); + sendClientConnectedEvent(context); } @@ -112,14 +99,7 @@ public void onMessage( public void onClose( WebSocketSession webSocketSession, AuthenticatedGatewayRequestContext context, - CloseStatus status) {} - - @Override - void validateOptions(Map options) { - for (Map.Entry option : options.entrySet()) { - switch (option.getKey()) { - default -> throw new IllegalArgumentException("Unknown option " + option.getKey()); - } - } + CloseStatus status) { + closeProduceGateway(webSocketSession); } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/AuthenticatedGatewayRequestContextImpl.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/AuthenticatedGatewayRequestContextImpl.java index 9c75f47ce..99ac59dd3 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/AuthenticatedGatewayRequestContextImpl.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/AuthenticatedGatewayRequestContextImpl.java @@ -25,6 +25,7 @@ @Builder public class AuthenticatedGatewayRequestContextImpl implements AuthenticatedGatewayRequestContext { + private final String sessionId; private final GatewayRequestContext gatewayRequestContext; private final Map attributes; private final Map principalValues; diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java new file mode 100644 index 000000000..7d4485173 --- /dev/null +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java @@ -0,0 +1,413 @@ +package ai.langstream.apigateway.http; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import ai.langstream.api.events.EventRecord; +import ai.langstream.api.events.EventSources; +import ai.langstream.api.events.GatewayEventData; +import ai.langstream.api.model.Application; +import ai.langstream.api.model.ApplicationSpecs; +import ai.langstream.api.model.Gateway; +import ai.langstream.api.model.Gateways; +import ai.langstream.api.model.StoredApplication; +import ai.langstream.api.model.StreamingCluster; +import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; +import ai.langstream.api.runtime.ClusterRuntimeRegistry; +import ai.langstream.api.runtime.PluginsRegistry; +import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; +import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; +import ai.langstream.apigateway.websocket.api.ConsumePushMessage; +import ai.langstream.apigateway.websocket.api.ProduceRequest; +import ai.langstream.apigateway.websocket.api.ProduceResponse; +import ai.langstream.apigateway.websocket.handlers.TestWebSocketClient; +import ai.langstream.impl.deploy.ApplicationDeployer; +import ai.langstream.impl.parser.ModelBuilder; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import jakarta.websocket.CloseReason; +import jakarta.websocket.DeploymentException; +import jakarta.websocket.Session; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import lombok.Cleanup; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.awaitility.Awaitility; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mockito; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.web.server.LocalServerPort; + + +@SpringBootTest( + webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, + properties = { + "spring.main.allow-bean-definition-overriding=true", + }) +@WireMockTest +@Slf4j +abstract class GatewayResourceTest { + + public static final Path agentsDirectory; + protected static final HttpClient CLIENT = HttpClient.newHttpClient(); + + static { + agentsDirectory = Path.of(System.getProperty("user.dir"), "target", "agents"); + log.info("Agents directory is {}", agentsDirectory); + } + + protected static final ObjectMapper MAPPER = new ObjectMapper(); + + static List topics; + static Gateways testGateways; + + protected static ApplicationStore getMockedStore(String instanceYaml) { + ApplicationStore mock = Mockito.mock(ApplicationStore.class); + doAnswer( + invocationOnMock -> { + final StoredApplication storedApplication = new StoredApplication(); + final Application application = buildApp(instanceYaml); + storedApplication.setInstance(application); + return storedApplication; + }) + .when(mock) + .get(anyString(), anyString(), anyBoolean()); + doAnswer( + invocationOnMock -> + ApplicationSpecs.builder() + .application(buildApp(instanceYaml)) + .build()) + .when(mock) + .getSpecs(anyString(), anyString()); + + return mock; + } + + protected static GatewayTestAuthenticationProperties getGatewayTestAuthenticationProperties() { + final GatewayTestAuthenticationProperties props = new GatewayTestAuthenticationProperties(); + props.setType("http"); + props.setConfiguration( + Map.of( + "base-url", + wireMockBaseUrl, + "path-template", + "/auth/{tenant}", + "headers", + Map.of("h1", "v1"))); + return props; + } + + @Autowired + private TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeProvider; + + @NotNull + private static Application buildApp(String instanceYaml) throws Exception { + final Map module = + Map.of( + "module", + "mod1", + "id", + "p", + "topics", + topics.stream() + .map( + t -> + Map.of( + "name", + t, + "creation-mode", + "create-if-not-exists")) + .collect(Collectors.toList())); + + final Application application = + ModelBuilder.buildApplicationInstance( + Map.of( + "module.yaml", + new ObjectMapper(new YAMLFactory()) + .writeValueAsString(module)), + instanceYaml, + null) + .getApplication(); + application.setGateways(testGateways); + return application; + } + + @LocalServerPort + int port; + + @Autowired ApplicationStore store; + + static WireMock wireMock; + static String wireMockBaseUrl; + static AtomicInteger topicCounter = new AtomicInteger(); + + private static String genTopic() { + return "topic" + topicCounter.incrementAndGet(); + } + + + @BeforeAll + public static void beforeAll(WireMockRuntimeInfo wmRuntimeInfo) { + wireMock = wmRuntimeInfo.getWireMock(); + wireMockBaseUrl = wmRuntimeInfo.getHttpBaseUrl(); + } + + @BeforeEach + public void beforeEach(WireMockRuntimeInfo wmRuntimeInfo) { + testGateways = null; + topics = null; + Awaitility.setDefaultTimeout(30, TimeUnit.SECONDS); + } + + @AfterAll + public static void afterAll() { + Awaitility.reset(); + } + + + @SneakyThrows + void produceAndExpectOk(String url, String content) { + final HttpRequest request = + HttpRequest.newBuilder(URI.create(url)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(content)) + .build(); + final HttpResponse response = CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + assertEquals(200, response.statusCode()); + assertEquals(""" + {"status":"OK","reason":null}""", response.body()); + + } + + @SneakyThrows + void produceAndExpectBadRequest(String url, String content, String errorMessage) { + final HttpRequest request = + HttpRequest.newBuilder(URI.create(url)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(content)) + .build(); + final HttpResponse response = CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + assertEquals(400, response.statusCode()); + log.info("Response body: {}", response.body()); + final Map map = new ObjectMapper().readValue(response.body(), Map.class); + String detail = (String)map.get("detail"); + assertTrue(detail.contains(errorMessage)); + + } + + @SneakyThrows + void produceAndExpectUnauthorized(String url, String content) { + final HttpRequest request = + HttpRequest.newBuilder(URI.create(url)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(content)) + .build(); + final HttpResponse response = CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + assertEquals(401, response.statusCode()); + log.info("Response body: {}", response.body()); + + } + + @Test + void testSimpleProduce() throws Exception { + final String topic = genTopic(); + prepareTopicsForTest(topic); + testGateways = + new Gateways( + List.of( + Gateway.builder() + .id("produce") + .type(Gateway.GatewayType.produce) + .topic(topic) + .build(), + Gateway.builder() + .id("consume") + .type(Gateway.GatewayType.consume) + .topic(topic) + .build())); + + final String url = + "http://localhost:%d/api/gateways/produce/tenant1/application1/produce".formatted(port); + + produceAndExpectOk(url, "{\"value\": \"my-value\"}"); + produceAndExpectOk(url, "{\"key\": \"my-key\"}"); + produceAndExpectOk(url, "{\"key\": \"my-key\", \"headers\": {\"h1\": \"v1\"}}"); + } + + + @Test + void testParametersRequired() throws Exception { + final String topic = genTopic(); + prepareTopicsForTest(topic); + + testGateways = + new Gateways( + List.of( + Gateway.builder() + .id("gw") + .type(Gateway.GatewayType.produce) + .topic(topic) + .parameters(List.of("session-id")) + .build())); + + final String baseUrl = + "http://localhost:%d/api/gateways/produce/tenant1/application1/gw".formatted(port); + + final String content = "{\"value\": \"my-value\"}"; + produceAndExpectBadRequest(baseUrl, content, "missing required parameter session-id"); + produceAndExpectBadRequest(baseUrl+ "?param:otherparam=1", content, "missing required parameter session-id"); + produceAndExpectBadRequest(baseUrl+ "?param:session-id=", content, "missing required parameter session-id"); + produceAndExpectBadRequest(baseUrl+ "?param:session-id=ok¶m:another-non-declared=y", content, "unknown parameters: [another-non-declared]"); + produceAndExpectOk(baseUrl+ "?param:session-id=1", content); + produceAndExpectOk(baseUrl+ "?param:session-id=string-value", content); + + } + + + + @Test + void testAuthentication() throws Exception { + final String topic = genTopic(); + prepareTopicsForTest(topic); + + testGateways = + new Gateways( + List.of( + Gateway.builder() + .id("produce") + .type(Gateway.GatewayType.produce) + .topic(topic) + .authentication( + new Gateway.Authentication( + "test-auth", Map.of(), true)) + .produceOptions( + new Gateway.ProduceOptions( + List.of( + Gateway.KeyValueComparison + .valueFromAuthentication( + "header1", + "login")))) + .build(), + Gateway.builder() + .id("consume") + .type(Gateway.GatewayType.consume) + .topic(topic) + .authentication( + new Gateway.Authentication( + "test-auth", Map.of(), true)) + .consumeOptions( + new Gateway.ConsumeOptions( + new Gateway.ConsumeOptionsFilters( + List.of( + Gateway.KeyValueComparison + .valueFromAuthentication( + "header1", + "login"))))) + .build())); + + final String baseUrl = + "http://localhost:%d/api/gateways/produce/tenant1/application1/produce".formatted(port); + + produceAndExpectUnauthorized(baseUrl, "{\"value\": \"my-value\"}"); + produceAndExpectUnauthorized(baseUrl + "?credentials=", "{\"value\": \"my-value\"}"); + produceAndExpectUnauthorized(baseUrl + "?credentials=error", "{\"value\": \"my-value\"}"); + produceAndExpectOk(baseUrl + "?credentials=test-user-password", "{\"value\": \"my-value\"}"); + } + + @Test + void testTestCredentials() throws Exception { + wireMock.register( + WireMock.get("/auth/tenant1") + .withHeader("Authorization", WireMock.equalTo("Bearer test-user-password")) + .withHeader("h1", WireMock.equalTo("v1")) + .willReturn(WireMock.ok(""))); + final String topic = genTopic(); + prepareTopicsForTest(topic); + + testGateways = + new Gateways( + List.of( + Gateway.builder() + .id("produce") + .type(Gateway.GatewayType.produce) + .topic(topic) + .authentication( + new Gateway.Authentication( + "test-auth", Map.of(), true)) + .produceOptions( + new Gateway.ProduceOptions( + List.of( + Gateway.KeyValueComparison + .valueFromAuthentication( + "header1", + "login")))) + .build(), + Gateway.builder() + .id("produce-no-test") + .type(Gateway.GatewayType.produce) + .topic(topic) + .authentication( + new Gateway.Authentication( + "test-auth", Map.of(), false)) + .build())); + + final String baseUrl = + "http://localhost:%d/api/gateways/produce/tenant1/application1/produce".formatted(port); + + + produceAndExpectUnauthorized(baseUrl + "?test-credentials=test", "{\"value\": \"my-value\"}"); + produceAndExpectOk(baseUrl + "?test-credentials=test-user-password", "{\"value\": \"my-value\"}"); + produceAndExpectUnauthorized("http://localhost:%d/api/gateways/produce/tenant1/application1/produce-no-test?test-credentials=test-user-password".formatted(port), "{\"value\": \"my-value\"}"); + + } + + + + protected abstract StreamingCluster getStreamingCluster(); + + private void prepareTopicsForTest(String... topic) throws Exception { + topics = List.of(topic); + TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry = + topicConnectionsRuntimeProvider.getTopicConnectionsRuntimeRegistry(); + final ApplicationDeployer deployer = + ApplicationDeployer.builder() + .pluginsRegistry(new PluginsRegistry()) + .registry(new ClusterRuntimeRegistry()) + .topicConnectionsRuntimeRegistry(topicConnectionsRuntimeRegistry) + .build(); + final StreamingCluster streamingCluster = getStreamingCluster(); + topicConnectionsRuntimeRegistry + .getTopicConnectionsRuntime(streamingCluster) + .asTopicConnectionsRuntime() + .deploy( + deployer.createImplementation( + "app", store.get("t", "app", false).getInstance())); + } + +} \ No newline at end of file diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java new file mode 100644 index 000000000..097ecdbcb --- /dev/null +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java @@ -0,0 +1,98 @@ +package ai.langstream.apigateway.http; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import ai.langstream.api.model.Application; +import ai.langstream.api.model.ApplicationSpecs; +import ai.langstream.api.model.Gateway; +import ai.langstream.api.model.Gateways; +import ai.langstream.api.model.StoredApplication; +import ai.langstream.api.model.StreamingCluster; +import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; +import ai.langstream.api.runtime.ClusterRuntimeRegistry; +import ai.langstream.api.runtime.PluginsRegistry; +import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; +import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; +import ai.langstream.impl.deploy.ApplicationDeployer; +import ai.langstream.impl.parser.ModelBuilder; +import ai.langstream.kafka.extensions.KafkaContainerExtension; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.awaitility.Awaitility; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.mockito.Mockito; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; + + +class KafkaGatewayResourceTest extends GatewayResourceTest { + + @RegisterExtension + static KafkaContainerExtension kafkaContainer = new KafkaContainerExtension(); + + @Override + protected StreamingCluster getStreamingCluster() { + return new StreamingCluster( + "kafka", + Map.of( + "admin", + Map.of( + "bootstrap.servers", + kafkaContainer.getBootstrapServers(), + "default.api.timeout.ms", + 5000))); + } + + @TestConfiguration + public static class WebSocketTestConfig { + + @Bean + @Primary + public ApplicationStore store() { + String instanceYaml = + """ + instance: + streamingCluster: + type: "kafka" + configuration: + admin: + bootstrap.servers: "%s" + computeCluster: + type: "none" + """ + .formatted(kafkaContainer.getBootstrapServers()); + return getMockedStore(instanceYaml); + } + + @Bean + @Primary + public GatewayTestAuthenticationProperties gatewayTestAuthenticationProperties() { + return getGatewayTestAuthenticationProperties(); + } + } +} \ No newline at end of file diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java new file mode 100644 index 000000000..ab4987fb4 --- /dev/null +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java @@ -0,0 +1,71 @@ +package ai.langstream.apigateway.http; + +import ai.langstream.api.model.StreamingCluster; +import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; +import ai.langstream.apigateway.websocket.handlers.PulsarContainerExtension; +import ai.langstream.kafka.extensions.KafkaContainerExtension; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; + + +class PulsarGatewayResourceTest extends GatewayResourceTest { + + @RegisterExtension + static PulsarContainerExtension pulsarContainer = new PulsarContainerExtension(); + + @Override + protected StreamingCluster getStreamingCluster() { + return new StreamingCluster( + "pulsar", + Map.of( + "admin", + Map.of("serviceUrl", pulsarContainer.getHttpServiceUrl()), + "service", + Map.of("serviceUrl", pulsarContainer.getBrokerUrl()), + "default-tenant", + "public", + "default-namespace", + "default")); + } + + @TestConfiguration + public static class WebSocketTestConfig { + + @Bean + @Primary + public ApplicationStore store() { + String instanceYaml = + """ + instance: + streamingCluster: + type: "pulsar" + configuration: + admin: + serviceUrl: "%s" + service: + serviceUrl: "%s" + default-tenant: "public" + default-namespace: "default" + computeCluster: + type: "none" + """ + .formatted( + pulsarContainer.getHttpServiceUrl(), + pulsarContainer.getBrokerUrl()); + return getMockedStore(instanceYaml); + } + + @Bean + @Primary + public GatewayTestAuthenticationProperties gatewayTestAuthenticationProperties() { + return getGatewayTestAuthenticationProperties(); + } + } +} \ No newline at end of file diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java index c3782954b..85dc521c4 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java @@ -35,7 +35,7 @@ public void initialize(Map configuration) {} @Override public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { log.info("Authenticating {}", context.credentials()); - if (context.credentials().startsWith("test-user-password")) { + if (context.credentials() != null && context.credentials().startsWith("test-user-password")) { return GatewayAuthenticationResult.authenticationSuccessful( Map.of("login", context.credentials())); } else { diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java index 0f99b7fc8..4f027e753 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java @@ -291,34 +291,22 @@ void testParametersRequired(String type) throws Exception { .topic(topic) .parameters(List.of("session-id")) .build())); - connectAndExpectClose( - URI.create("ws://localhost:%d/v1/%s/tenant1/application1/gw".formatted(port, type)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "missing required parameter session-id")); - connectAndExpectClose( + connectAndExpectHttpError( + URI.create("ws://localhost:%d/v1/%s/tenant1/application1/gw".formatted(port, type)), 500); + connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/%s/tenant1/application1/gw?param:otherparam=1" - .formatted(port, type)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "missing required parameter session-id")); - connectAndExpectClose( + .formatted(port, type)), 500); + connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/%s/tenant1/application1/gw?param:session-id=" - .formatted(port, type)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "missing required parameter session-id")); + .formatted(port, type)), 500); - connectAndExpectClose( + connectAndExpectHttpError( URI.create( ("ws://localhost:%d/v1/%s/tenant1/application1/gw?param:session-id=ok¶m:another-non" + "-declared=y") - .formatted(port, type)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "unknown parameters: [another-non-declared]")); + .formatted(port, type)), 500); connectAndExpectRunning( URI.create( @@ -474,27 +462,18 @@ void testAuthentication() throws Exception { "login"))))) .build())); - connectAndExpectClose( + connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce" - .formatted(port)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "missing required parameter session-id")); - connectAndExpectClose( + .formatted(port)), 401); + connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce?credentials=" - .formatted(port)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "missing required parameter session-id")); - connectAndExpectClose( + .formatted(port)), 401); + connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce?credentials=error" - .formatted(port)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "missing required parameter session-id")); + .formatted(port)), 401); connectAndExpectRunning( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce?credentials=test-user-password" @@ -622,21 +601,17 @@ void testTestCredentials() throws Exception { "9d75ff199d33e051209b59702de27d1e470eafb58ac6d8865788bf23b48e6818"))), user1Messages)); - connectAndExpectClose( + connectAndExpectHttpError( URI.create( - ("ws://localhost:%d/v1/consume/tenant1/application1/consume-no-admin?test-credentials=test" + ("ws://localhost:%d/v1/consume/tenant1/application1/consume-no-test?test-credentials=test" + "-user-password") - .formatted(port)), - new CloseReason( - CloseReason.CloseCodes.VIOLATED_POLICY, - "Gateway consume-no-test of tenant tenant1 does not allow test mode.")); + .formatted(port)), 401); - connectAndExpectClose( + connectAndExpectHttpError( URI.create( ("ws://localhost:%d/v1/produce/tenant1/application1/produce?test-credentials=test-user" + "-password-but-wrong") - .formatted(port)), - new CloseReason(CloseReason.CloseCodes.VIOLATED_POLICY, "Invalid credentials")); + .formatted(port)), 401); } private record MsgRecord(Object key, Object value, Map headers) {} @@ -1109,8 +1084,17 @@ private static String genTopic() { return "topic" + topicCounter.incrementAndGet(); } - @SneakyThrows + private void connectAndExpectClose(URI connectTo, CloseReason expectedCloseReason) { + connectAndExpectClose(connectTo, expectedCloseReason, -1); + + } + private void connectAndExpectHttpError(URI connectTo, int code) { + connectAndExpectClose(connectTo, null, code); + + } + @SneakyThrows + private void connectAndExpectClose(URI connectTo, CloseReason expectedCloseReason, int code) { CountDownLatch countDownLatch = new CountDownLatch(1); AtomicReference closeReason = new AtomicReference<>(); @@ -1137,11 +1121,20 @@ public void onError(Throwable throwable) { .connect(connectTo)) { Thread.sleep(5000); countDownLatch.await(); + if (expectedCloseReason == null) { + throw new RuntimeException("close reason not expected"); + } assertEquals( expectedCloseReason.getReasonPhrase(), closeReason.get().getReasonPhrase()); assertEquals(expectedCloseReason.getCloseCode(), closeReason.get().getCloseCode()); } catch (DeploymentException e) { - // ok + if (code > 0) { + if (e.getMessage().contains("[" + code + "]")) { + return; + } + throw new RuntimeException("expected http error code " + code, e); + } + throw new RuntimeException(e); } } diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java index 8a15d3da9..7164ceebb 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java @@ -34,6 +34,9 @@ public abstract class BaseGatewayCmd extends BaseCmd { protected static final ObjectMapper messageMapper = new ObjectMapper(); + protected enum Protocols { + ws, http; + } @CommandLine.ParentCommand private RootGatewayCmd cmd; @@ -83,7 +86,8 @@ protected String validateGatewayAndGetUrl( Map params, Map options, String credentials, - String testCredentials) { + String testCredentials, + Protocols protocol) { validateGateway( applicationId, gatewayId, type, params, options, credentials, testCredentials); @@ -94,6 +98,21 @@ protected String validateGatewayAndGetUrl( if (testCredentials != null) { systemParams.put("test-credentials", testCredentials); } + if (protocol == Protocols.http) { + if (!type.equals("produce")) { + throw new IllegalArgumentException("HTTP protocol is only supported for produce"); + } + return String.format( + "%s/api/gateways/%s/%s/%s/%s?%s", + getApiGatewayUrlHttp(), + type, + getTenant(), + applicationId, + gatewayId, + computeQueryString(systemParams, params, options)); + + } + return String.format( "%s/v1/%s/%s/%s/%s?%s", getApiGatewayUrl(), @@ -112,6 +131,11 @@ private String getApiGatewayUrl() { return getCurrentProfile().getApiGatewayUrl(); } + + private String getApiGatewayUrlHttp() { + return getApiGatewayUrl().replace("wss://", "https://").replace("ws://", "http://"); + } + @SneakyThrows protected void validateGateway( String application, diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java index 4ee0b00c5..98c1ae2ae 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java @@ -238,7 +238,8 @@ private GatewayConnection createGatewayConnection() { finalParams, consumeGatewayOptions, credentials, - testCredentials); + testCredentials, + Protocols.ws); return new ChatGatewayConnection(url, connectTimeout); } @@ -250,7 +251,8 @@ private GatewayConnection createGatewayConnection() { params, consumeGatewayOptions, credentials, - testCredentials); + testCredentials, + Protocols.ws); final String producePath = validateGatewayAndGetUrl( applicationId, @@ -259,7 +261,8 @@ private GatewayConnection createGatewayConnection() { params, Map.of(), credentials, - testCredentials); + testCredentials, + Protocols.ws); return new ProduceConsumeGatewaysConnection(consumePath, producePath, connectTimeout); } diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java index d09dc839f..3a8970079 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java @@ -83,7 +83,8 @@ public void run() { params, options, credentials, - testCredentials); + testCredentials, + Protocols.ws); final Duration connectTimeout = connectTimeoutSeconds > 0 ? Duration.ofSeconds(connectTimeoutSeconds) : null; diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java index 4627b2332..8e62c8bf6 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java @@ -19,6 +19,9 @@ import ai.langstream.cli.websocket.WebSocketClient; import jakarta.websocket.CloseReason; import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -82,6 +85,11 @@ static class ProduceRequest { description = "Test credentials for the gateway.") private String testCredentials; + @CommandLine.Option( + names = {"--protocol"}, + description = "Protocol to use: http or ws", defaultValue = "ws") + private Protocols protocol = Protocols.ws; + @Override @SneakyThrows public void run() { @@ -93,9 +101,23 @@ public void run() { params, Map.of(), credentials, - testCredentials); + testCredentials, + protocol); final Duration connectTimeout = connectTimeoutSeconds > 0 ? Duration.ofSeconds(connectTimeoutSeconds) : null; + + final ProduceRequest produceRequest = + new ProduceRequest(messageKey, messageValue, headers); + final String json = messageMapper.writeValueAsString(produceRequest); + + if (protocol == Protocols.http) { + produceHttp(producePath, connectTimeout, json); + } else { + produceWebSocket(producePath, connectTimeout, json); + } + } + + private void produceWebSocket(String producePath, Duration connectTimeout, String json) throws Exception { CountDownLatch countDownLatch = new CountDownLatch(1); try (final WebSocketClient client = new WebSocketClient( @@ -128,11 +150,23 @@ public void onError(Throwable throwable) { } }) .connect(URI.create(producePath), connectTimeout)) { - final ProduceRequest produceRequest = - new ProduceRequest(messageKey, messageValue, headers); - final String json = messageMapper.writeValueAsString(produceRequest); + client.send(json); countDownLatch.await(); } } + + private void produceHttp(String producePath, Duration connectTimeout, String json) throws Exception { + final HttpRequest.Builder builder = HttpRequest.newBuilder(URI.create(producePath)) + .header("Content-Type", "application/json") + .version(HttpClient.Version.HTTP_1_1) + .POST(HttpRequest.BodyPublishers.ofString(json)); + if (connectTimeout != null) { + builder.timeout(connectTimeout); + } + final HttpRequest request = builder.build(); + final HttpResponse response = + getClient().getHttpClientFacade().http(request, HttpResponse.BodyHandlers.ofString()); + log(response.body()); + } } From 7f586ea62a976c682cdce31db1bd359329dbebd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Thu, 26 Oct 2023 14:57:50 +0200 Subject: [PATCH 2/4] fix --- .../apigateway/gateways/ConsumeGateway.java | 93 +++++++------ .../gateways/GatewayRequestHandler.java | 72 +++++----- .../GatewayRequestHandlerFactory.java | 25 +++- .../apigateway/gateways/ProduceGateway.java | 57 +++++--- .../apigateway/http/GatewayResource.java | 34 +++-- .../langstream/apigateway/util/HttpUtil.java | 15 ++ .../websocket/AuthenticationInterceptor.java | 19 +-- .../apigateway/websocket/WebSocketConfig.java | 1 - .../websocket/handlers/AbstractHandler.java | 62 ++++----- .../websocket/handlers/ChatHandler.java | 20 +-- .../websocket/handlers/ConsumeHandler.java | 8 +- .../websocket/handlers/ProduceHandler.java | 7 +- .../apigateway/http/GatewayResourceTest.java | 130 +++++++++--------- .../http/KafkaGatewayResourceTest.java | 57 +++----- .../http/PulsarGatewayResourceTest.java | 22 ++- .../TestGatewayAuthenticationProvider.java | 3 +- .../handlers/ProduceConsumeHandlerTest.java | 32 +++-- .../cli/commands/gateway/BaseGatewayCmd.java | 6 +- .../commands/gateway/ProduceGatewayCmd.java | 25 ++-- 19 files changed, 365 insertions(+), 323 deletions(-) diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java index b7686988b..1cff2b3f9 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -1,3 +1,18 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.gateways; import ai.langstream.api.model.Gateway; @@ -45,7 +60,8 @@ public ProduceException(String message, ProduceResponse.Status status) { } } - public static class ProduceGatewayRequestValidator implements GatewayRequestHandler.GatewayRequestValidator { + public static class ProduceGatewayRequestValidator + implements GatewayRequestHandler.GatewayRequestValidator { @Override public List getAllRequiredParameters(Gateway gateway) { return gateway.getParameters(); @@ -55,13 +71,13 @@ public List getAllRequiredParameters(Gateway gateway) { public void validateOptions(Map options) { for (Map.Entry option : options.entrySet()) { switch (option.getKey()) { - default -> throw new IllegalArgumentException("Unknown option " + option.getKey()); + default -> throw new IllegalArgumentException( + "Unknown option " + option.getKey()); } } } } - private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; private TopicReader reader; private AuthenticatedGatewayRequestContext requestContext; @@ -72,11 +88,15 @@ public ConsumeGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeReg } @SneakyThrows - public void setup(String topic, List> filters, AuthenticatedGatewayRequestContext requestContext) { + public void setup( + String topic, + List> filters, + AuthenticatedGatewayRequestContext requestContext) { this.requestContext = requestContext; this.filters = filters == null ? List.of() : filters; - final StreamingCluster streamingCluster = requestContext.application().getInstance().streamingCluster(); + final StreamingCluster streamingCluster = + requestContext.application().getInstance().streamingCluster(); final TopicConnectionsRuntime topicConnectionsRuntime = topicConnectionsRuntimeRegistry .getTopicConnectionsRuntime(streamingCluster) @@ -84,7 +104,8 @@ public void setup(String topic, List> filters, Authent topicConnectionsRuntime.init(streamingCluster); - final String positionParameter = requestContext.options().getOrDefault("position", "latest"); + final String positionParameter = + requestContext.options().getOrDefault("position", "latest"); TopicOffsetPosition position = switch (positionParameter) { case "latest" -> TopicOffsetPosition.LATEST; @@ -98,37 +119,35 @@ public void setup(String topic, List> filters, Authent reader.start(); } - - public CompletableFuture startReading(Executor executor, Supplier stop, Consumer onMessage) { + public CompletableFuture startReading( + Executor executor, Supplier stop, Consumer onMessage) { if (requestContext == null || reader == null) { throw new IllegalStateException("Not initialized"); } return CompletableFuture.runAsync( - () -> { - - try { - final String tenant = requestContext.tenant(); - - final String gatewayId = requestContext.gateway().getId(); - final String applicationId = requestContext.applicationId(); - log.info( - "Started reader for gateway {}/{}/{}", - tenant, - applicationId, - gatewayId); - readMessages(stop, onMessage); - } catch (InterruptedException | CancellationException ex) { - // ignore - } catch (Throwable ex) { - log.error(ex.getMessage(), ex); - } finally { - closeReader(); - } - }, - executor); + () -> { + try { + final String tenant = requestContext.tenant(); + + final String gatewayId = requestContext.gateway().getId(); + final String applicationId = requestContext.applicationId(); + log.info( + "Started reader for gateway {}/{}/{}", + tenant, + applicationId, + gatewayId); + readMessages(stop, onMessage); + } catch (InterruptedException | CancellationException ex) { + // ignore + } catch (Throwable ex) { + log.error(ex.getMessage(), ex); + } finally { + closeReader(); + } + }, + executor); } - protected void readMessages(Supplier stop, Consumer onMessage) throws Exception { while (true) { @@ -163,15 +182,11 @@ protected void readMessages(Supplier stop, Consumer onMessage) offset); final String jsonMessage = mapper.writeValueAsString(message); onMessage.accept(jsonMessage); - } } } } - - - private static Map computeMessageHeaders(Record record) { final Collection
headers = record.headers(); final Map messageHeaders; @@ -192,9 +207,6 @@ private static String computeOffset(TopicReadResult readResult) { return Base64.getEncoder().encodeToString(offset); } - - - private void closeReader() { if (reader != null) { try { @@ -205,11 +217,8 @@ private void closeReader() { } } - - @Override - public void close() { + public void close() { closeReader(); } - } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java index 53cf1cdd2..d5d8b10bc 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandler.java @@ -1,3 +1,18 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.gateways; import ai.langstream.api.gateway.GatewayAuthenticationProvider; @@ -11,7 +26,6 @@ import ai.langstream.api.storage.ApplicationStore; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; -import ai.langstream.apigateway.websocket.AuthenticationInterceptor; import ai.langstream.apigateway.websocket.impl.AuthenticatedGatewayRequestContextImpl; import ai.langstream.apigateway.websocket.impl.GatewayRequestContextImpl; import ai.langstream.impl.common.ApplicationPlaceholderResolver; @@ -20,7 +34,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.codec.digest.DigestUtils; import org.springframework.util.StringUtils; @@ -36,15 +49,13 @@ public AuthFailedException(String message) { public interface GatewayRequestValidator { List getAllRequiredParameters(Gateway gateway); + void validateOptions(Map options); } - private final ApplicationStore applicationStore; private final GatewayAuthenticationProvider authTestProvider; - - public GatewayRequestHandler( ApplicationStore applicationStore, GatewayTestAuthenticationProperties testAuthenticationProperties) { @@ -63,7 +74,6 @@ public GatewayRequestHandler( } } - public GatewayRequestContext validateRequest( String tenant, String applicationId, @@ -86,10 +96,10 @@ public GatewayRequestContext validateRequest( } else { throw new IllegalArgumentException( "invalid query parameter " - + entry.getKey() - + ". " - + "To specify a gateway parameter, use the format param:." - + "To specify a option, use the format option:."); + + entry.getKey() + + ". " + + "To specify a gateway parameter, use the format param:." + + "To specify a option, use the format option:."); } } @@ -108,9 +118,9 @@ public GatewayRequestContext validateRequest( applicationId, gateway, "missing required parameter " - + requiredParameter - + ". Required parameters: " - + requiredParameters)); + + requiredParameter + + ". Required parameters: " + + requiredParameters)); } allUserParameterKeys.remove(requiredParameter); } @@ -162,8 +172,6 @@ private Application getResolvedApplication(String tenant, String applicationId) return ApplicationPlaceholderResolver.resolvePlaceholders(application); } - - private Gateway extractGateway( String gatewayId, Application application, Gateway.GatewayType type) { final Gateways gatewaysObj = application.getGateways(); @@ -186,23 +194,23 @@ private Gateway extractGateway( if (selectedGateway == null) { throw new IllegalArgumentException( "gateway " - + gatewayId - + " of type " - + type - + " is not defined in the application"); + + gatewayId + + " of type " + + type + + " is not defined in the application"); } return selectedGateway; } - - public AuthenticatedGatewayRequestContext authenticate(GatewayRequestContext gatewayRequestContext) - throws AuthFailedException { + public AuthenticatedGatewayRequestContext authenticate( + GatewayRequestContext gatewayRequestContext) throws AuthFailedException { final Gateway.Authentication authentication = gatewayRequestContext.gateway().getAuthentication(); if (authentication == null) { - return getAuthenticatedGatewayRequestContext(gatewayRequestContext, Map.of(), new HashMap<>()); + return getAuthenticatedGatewayRequestContext( + gatewayRequestContext, Map.of(), new HashMap<>()); } final GatewayAuthenticationResult result; @@ -210,10 +218,10 @@ public AuthenticatedGatewayRequestContext authenticate(GatewayRequestContext gat if (!authentication.isAllowTestMode()) { throw new AuthFailedException( "Gateway " - + gatewayRequestContext.gateway().getId() - + " of tenant " - + gatewayRequestContext.tenant() - + " does not allow test mode."); + + gatewayRequestContext.gateway().getId() + + " of tenant " + + gatewayRequestContext.tenant() + + " does not allow test mode."); } if (authTestProvider == null) { throw new AuthFailedException("No test auth provider specified"); @@ -232,8 +240,10 @@ public AuthenticatedGatewayRequestContext authenticate(GatewayRequestContext gat if (!result.authenticated()) { throw new AuthFailedException(result.reason()); } - final Map principalValues = getPrincipalValues(result, gatewayRequestContext); - return getAuthenticatedGatewayRequestContext(gatewayRequestContext, principalValues, new HashMap<>()); + final Map principalValues = + getPrincipalValues(result, gatewayRequestContext); + return getAuthenticatedGatewayRequestContext( + gatewayRequestContext, principalValues, new HashMap<>()); } private Map getPrincipalValues( @@ -273,8 +283,4 @@ private AuthenticatedGatewayRequestContext getAuthenticatedGatewayRequestContext .principalValues(principalValues) .build(); } - - - - } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java index 321d3f7d2..2a4832012 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/GatewayRequestHandlerFactory.java @@ -1,11 +1,22 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.gateways; import ai.langstream.api.storage.ApplicationStore; -import ai.langstream.api.storage.ApplicationStoreRegistry; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; -import ai.langstream.apigateway.config.StorageProperties; -import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; -import java.util.Objects; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -13,9 +24,9 @@ public class GatewayRequestHandlerFactory { @Bean - public GatewayRequestHandler gatewayRequestHandler(ApplicationStore applicationStore, - GatewayTestAuthenticationProperties testAuthenticationProperties) { + public GatewayRequestHandler gatewayRequestHandler( + ApplicationStore applicationStore, + GatewayTestAuthenticationProperties testAuthenticationProperties) { return new GatewayRequestHandler(applicationStore, testAuthenticationProperties); } - } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java index d26a5ba19..0cea54761 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java @@ -1,3 +1,18 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.gateways; import ai.langstream.api.model.Gateway; @@ -13,15 +28,12 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.Closeable; -import java.io.IOException; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import lombok.AllArgsConstructor; import lombok.Getter; import lombok.extern.slf4j.Slf4j; @@ -41,7 +53,8 @@ public ProduceException(String message, ProduceResponse.Status status) { } } - public static class ProduceGatewayRequestValidator implements GatewayRequestHandler.GatewayRequestValidator { + public static class ProduceGatewayRequestValidator + implements GatewayRequestHandler.GatewayRequestValidator { @Override public List getAllRequiredParameters(Gateway gateway) { return gateway.getParameters(); @@ -51,13 +64,13 @@ public List getAllRequiredParameters(Gateway gateway) { public void validateOptions(Map options) { for (Map.Entry option : options.entrySet()) { switch (option.getKey()) { - default -> throw new IllegalArgumentException("Unknown option " + option.getKey()); + default -> throw new IllegalArgumentException( + "Unknown option " + option.getKey()); } } } } - private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; private TopicProducer producer; private List
commonHeaders; @@ -66,7 +79,10 @@ public ProduceGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeReg this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry; } - public void start(String topic, List
commonHeaders, AuthenticatedGatewayRequestContext requestContext) { + public void start( + String topic, + List
commonHeaders, + AuthenticatedGatewayRequestContext requestContext) { this.commonHeaders = commonHeaders == null ? List.of() : commonHeaders; setupProducer( @@ -112,13 +128,14 @@ public void produceMessage(String payload) throws ProduceException { produceMessage(produceRequest); } - public void produceMessage(ProduceRequest produceRequest) throws ProduceException { if (produceRequest.value() == null && produceRequest.key() == null) { - throw new ProduceException("Either key or value must be set.", ProduceResponse.Status.BAD_REQUEST); + throw new ProduceException( + "Either key or value must be set.", ProduceResponse.Status.BAD_REQUEST); } if (producer == null) { - throw new ProduceException("Producer not initialized", ProduceResponse.Status.PRODUCER_ERROR); + throw new ProduceException( + "Producer not initialized", ProduceResponse.Status.PRODUCER_ERROR); } final Collection
headers = new ArrayList<>(commonHeaders); @@ -127,9 +144,11 @@ public void produceMessage(ProduceRequest produceRequest) throws ProduceExceptio headers.stream().map(Header::key).collect(Collectors.toSet()); for (Map.Entry messageHeader : produceRequest.headers().entrySet()) { if (configuredHeaders.contains(messageHeader.getKey())) { - throw new ProduceException("Header " - + messageHeader.getKey() - + " is configured as parameter-level header.", ProduceResponse.Status.BAD_REQUEST); + throw new ProduceException( + "Header " + + messageHeader.getKey() + + " is configured as parameter-level header.", + ProduceResponse.Status.BAD_REQUEST); } headers.add( SimpleRecord.SimpleHeader.of( @@ -151,7 +170,7 @@ public void produceMessage(ProduceRequest produceRequest) throws ProduceExceptio } @Override - public void close() { + public void close() { if (producer != null) { try { @@ -160,19 +179,17 @@ public void close() { log.error("error closing producer", e); } } - } - public static List
getProducerCommonHeaders(Gateway.ProduceOptions produceOptions, - AuthenticatedGatewayRequestContext context) { + public static List
getProducerCommonHeaders( + Gateway.ProduceOptions produceOptions, AuthenticatedGatewayRequestContext context) { if (produceOptions != null) { return getProducerCommonHeaders( - produceOptions.headers(), - context.userParameters(), - context.principalValues()); + produceOptions.headers(), context.userParameters(), context.principalValues()); } return null; } + public static List
getProducerCommonHeaders( List headerFilters, Map passedParameters, diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java index a37e8a9be..c96b10396 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java @@ -21,11 +21,9 @@ import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; -import ai.langstream.apigateway.util.HttpUtil; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; import ai.langstream.apigateway.websocket.api.ProduceRequest; import ai.langstream.apigateway.websocket.api.ProduceResponse; - import jakarta.validation.constraints.NotBlank; import java.util.HashMap; import java.util.List; @@ -35,11 +33,8 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.PutMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -55,20 +50,29 @@ public class GatewayResource { private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider; private final GatewayRequestHandler gatewayRequestHandler; - @PostMapping(value = "/produce/{tenant}/{application}/{gateway}", consumes = MediaType.APPLICATION_JSON_VALUE) + @PostMapping( + value = "/produce/{tenant}/{application}/{gateway}", + consumes = MediaType.APPLICATION_JSON_VALUE) ProduceResponse produce( WebRequest request, @NotBlank @PathVariable("tenant") String tenant, @NotBlank @PathVariable("application") String application, @NotBlank @PathVariable("gateway") String gateway, - @RequestBody ProduceRequest produceRequest) throws ProduceGateway.ProduceException { + @RequestBody ProduceRequest produceRequest) + throws ProduceGateway.ProduceException { - final Map queryString = request.getParameterMap().keySet().stream() - .collect(Collectors.toMap(k -> k, k -> request.getParameter(k))); + final Map queryString = + request.getParameterMap().keySet().stream() + .collect(Collectors.toMap(k -> k, k -> request.getParameter(k))); final Map headers = new HashMap<>(); - request.getHeaderNames().forEachRemaining(name -> headers.put(name, request.getHeader(name))); + request.getHeaderNames() + .forEachRemaining(name -> headers.put(name, request.getHeader(name))); final GatewayRequestContext context = - gatewayRequestHandler.validateRequest(tenant, application, gateway, Gateway.GatewayType.produce, + gatewayRequestHandler.validateRequest( + tenant, + application, + gateway, + Gateway.GatewayType.produce, queryString, headers, new ProduceGateway.ProduceGatewayRequestValidator()); @@ -80,12 +84,14 @@ ProduceResponse produce( } final ProduceGateway produceGateway = - new ProduceGateway(topicConnectionsRuntimeRegistryProvider.getTopicConnectionsRuntimeRegistry()); + new ProduceGateway( + topicConnectionsRuntimeRegistryProvider + .getTopicConnectionsRuntimeRegistry()); final List
commonHeaders = - ProduceGateway.getProducerCommonHeaders(context.gateway().getProduceOptions(), authContext); + ProduceGateway.getProducerCommonHeaders( + context.gateway().getProduceOptions(), authContext); produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext); produceGateway.produceMessage(produceRequest); return ProduceResponse.OK; - } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java index 40e756f6f..2fe1b93e5 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/util/HttpUtil.java @@ -1,3 +1,18 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.util; import java.net.URLDecoder; diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java index adb5618e7..1d22f2555 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java @@ -15,23 +15,13 @@ */ package ai.langstream.apigateway.websocket; -import ai.langstream.api.gateway.GatewayAuthenticationProvider; -import ai.langstream.api.gateway.GatewayAuthenticationProviderRegistry; -import ai.langstream.api.gateway.GatewayAuthenticationResult; import ai.langstream.api.gateway.GatewayRequestContext; -import ai.langstream.api.model.Gateway; -import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.util.HttpUtil; import ai.langstream.apigateway.websocket.handlers.AbstractHandler; -import ai.langstream.apigateway.websocket.impl.AuthenticatedGatewayRequestContextImpl; -import java.net.URLDecoder; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; import java.util.Map; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.codec.digest.DigestUtils; import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -81,7 +71,8 @@ public boolean beforeHandshake( final AuthenticatedGatewayRequestContext authenticatedGatewayRequestContext; try { - authenticatedGatewayRequestContext = gatewayRequestHandler.authenticate(gatewayRequestContext); + authenticatedGatewayRequestContext = + gatewayRequestHandler.authenticate(gatewayRequestContext); } catch (GatewayRequestHandler.AuthFailedException authFailedException) { log.info("Authentication failed {}", authFailedException.getMessage()); String error = authFailedException.getMessage(); @@ -94,7 +85,8 @@ public boolean beforeHandshake( log.debug("Authentication OK"); sessionAttributes.put("context", authenticatedGatewayRequestContext); - handler.onBeforeHandshakeCompleted(authenticatedGatewayRequestContext, + handler.onBeforeHandshakeCompleted( + authenticatedGatewayRequestContext, authenticatedGatewayRequestContext.attributes()); return true; } catch (Throwable error) { @@ -106,13 +98,10 @@ public boolean beforeHandshake( } } - - @Override public void afterHandshake( ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {} - } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java index a85f0c853..0100210e8 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java @@ -17,7 +17,6 @@ import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; -import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; import ai.langstream.apigateway.websocket.handlers.ChatHandler; diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java index 391116eff..871bb2080 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java @@ -26,21 +26,17 @@ import ai.langstream.api.runner.topics.TopicConnectionsRuntime; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.runner.topics.TopicProducer; -import ai.langstream.api.runner.topics.TopicReadResult; -import ai.langstream.api.runner.topics.TopicReader; import ai.langstream.api.storage.ApplicationStore; import ai.langstream.apigateway.gateways.ConsumeGateway; import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; -import ai.langstream.apigateway.websocket.api.ConsumePushMessage; import ai.langstream.apigateway.websocket.api.ProduceResponse; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Function; @@ -70,7 +66,8 @@ public AbstractHandler( public abstract Gateway.GatewayType gatewayType(); - public abstract String tenantFromPath(Map parsedPath, Map queryString); + public abstract String tenantFromPath( + Map parsedPath, Map queryString); public abstract String applicationIdFromPath( Map parsedPath, Map queryString); @@ -78,14 +75,12 @@ public abstract String applicationIdFromPath( public abstract String gatewayFromPath( Map parsedPath, Map queryString); - public abstract GatewayRequestHandler.GatewayRequestValidator validator(); public void onBeforeHandshakeCompleted( AuthenticatedGatewayRequestContext gatewayRequestContext, Map attributes) - throws Exception { - } + throws Exception {} abstract void onOpen( WebSocketSession webSocketSession, @@ -191,10 +186,10 @@ protected void sendEvent(EventRecord.Types type, AuthenticatedGatewayRequestCont topicConnectionsRuntime.init(streamingCluster); try (final TopicProducer producer = - topicConnectionsRuntime.createProducer( - "langstream-events", - streamingCluster, - Map.of("topic", gateway.getEventsTopic()))) { + topicConnectionsRuntime.createProducer( + "langstream-events", + streamingCluster, + Map.of("topic", gateway.getEventsTopic()))) { producer.start(); final EventSources.GatewaySource source = @@ -230,17 +225,19 @@ protected void sendEvent(EventRecord.Types type, AuthenticatedGatewayRequestCont protected void startReadingMessages(WebSocketSession webSocketSession, Executor executor) { final AuthenticatedGatewayRequestContext context = getContext(webSocketSession); - final ConsumeGateway consumeGateway = (ConsumeGateway) context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); - final CompletableFuture future = consumeGateway.startReading( - executor, - () -> !webSocketSession.isOpen(), - message -> { - try { - webSocketSession.sendMessage(new TextMessage(message)); - } catch (IOException ex) { - throw new RuntimeException(ex); - } - }); + final ConsumeGateway consumeGateway = + (ConsumeGateway) context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); + final CompletableFuture future = + consumeGateway.startReading( + executor, + () -> !webSocketSession.isOpen(), + message -> { + try { + webSocketSession.sendMessage(new TextMessage(message)); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); webSocketSession.getAttributes().put("future", future); } @@ -290,7 +287,6 @@ protected void setupReader( final ConsumeGateway consumeGateway = new ConsumeGateway(topicConnectionsRuntimeRegistry); context.attributes().put(ATTRIBUTE_CONSUME_GATEWAY, consumeGateway); consumeGateway.setup(topic, filters, context); - } protected void stopReadingMessages(WebSocketSession webSocketSession) { @@ -301,20 +297,19 @@ protected void stopReadingMessages(WebSocketSession webSocketSession) { } } - - protected void setupProducer(String topic, List
commonHeaders, AuthenticatedGatewayRequestContext context) { + protected void setupProducer( + String topic, List
commonHeaders, AuthenticatedGatewayRequestContext context) { final ProduceGateway produceGateway = new ProduceGateway(topicConnectionsRuntimeRegistry); context.attributes().put(ATTRIBUTE_PRODUCE_GATEWAY, produceGateway); produceGateway.start(topic, commonHeaders, context); } - - protected void produceMessage(WebSocketSession webSocketSession, - TextMessage message) + protected void produceMessage(WebSocketSession webSocketSession, TextMessage message) throws IOException { try { final AuthenticatedGatewayRequestContext context = getContext(webSocketSession); - final ProduceGateway produceGateway = (ProduceGateway) context.attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); + final ProduceGateway produceGateway = + (ProduceGateway) context.attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); produceGateway.produceMessage(message.getPayload()); webSocketSession.sendMessage( new TextMessage(mapper.writeValueAsString(ProduceResponse.OK))); @@ -324,20 +319,19 @@ protected void produceMessage(WebSocketSession webSocketSession, } protected void closeProduceGateway(WebSocketSession webSocketSession) { - final ProduceGateway produceGateway = (ProduceGateway) getContext(webSocketSession).attributes().get( - ATTRIBUTE_PRODUCE_GATEWAY); + final ProduceGateway produceGateway = + (ProduceGateway) + getContext(webSocketSession).attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); if (produceGateway == null) { return; } produceGateway.close(); } - private static void sendResponse( WebSocketSession webSocketSession, ProduceResponse.Status status, String reason) throws IOException { webSocketSession.sendMessage( new TextMessage(mapper.writeValueAsString(new ProduceResponse(status, reason)))); } - } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java index 3d6999951..91b9e98bd 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java @@ -60,12 +60,13 @@ public Gateway.GatewayType gatewayType() { } @Override - public String tenantFromPath(Map parsedPath, Map queryString) { + public String tenantFromPath(Map parsedPath, Map queryString) { return parsedPath.get("tenant"); } @Override - public String applicationIdFromPath(Map parsedPath, Map queryString) { + public String applicationIdFromPath( + Map parsedPath, Map queryString) { return parsedPath.get("application"); } @@ -83,8 +84,10 @@ public List getAllRequiredParameters(Gateway gateway) { if (parameters == null) { parameters = new ArrayList<>(); } - if (gateway.getChatOptions() != null && gateway.getChatOptions().getHeaders() != null) { - for (Gateway.KeyValueComparison header : gateway.getChatOptions().getHeaders()) { + if (gateway.getChatOptions() != null + && gateway.getChatOptions().getHeaders() != null) { + for (Gateway.KeyValueComparison header : + gateway.getChatOptions().getHeaders()) { if (header.valueFromParameters() != null) { parameters.add(header.valueFromParameters()); } @@ -132,8 +135,8 @@ private void setupProducer(AuthenticatedGatewayRequestContext context) { } } final List
commonHeaders = - ProduceGateway.getProducerCommonHeaders(headerConfig, context.userParameters(), - context.principalValues()); + ProduceGateway.getProducerCommonHeaders( + headerConfig, context.userParameters(), context.principalValues()); setupProducer(chatOptions.getQuestionsTopic(), commonHeaders, context); } @@ -152,9 +155,7 @@ private void setupReader(AuthenticatedGatewayRequestContext context) throws Exce createMessageFilters( headerFilters, context.userParameters(), context.principalValues()); - setupReader(chatOptions.getAnswersTopic(), - messageFilters, - context); + setupReader(chatOptions.getAnswersTopic(), messageFilters, context); } @Override @@ -177,5 +178,4 @@ public void onClose( WebSocketSession webSocketSession, AuthenticatedGatewayRequestContext context, CloseStatus status) {} - } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java index e5a510677..03e25ea85 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java @@ -62,7 +62,8 @@ public String tenantFromPath(Map parsedPath, Map } @Override - public String applicationIdFromPath(Map parsedPath, Map queryString) { + public String applicationIdFromPath( + Map parsedPath, Map queryString) { return parsedPath.get("application"); } @@ -113,9 +114,7 @@ public void onBeforeHandshakeCompleted( } else { messageFilters = null; } - setupReader(context.gateway().getTopic(), - messageFilters, - context); + setupReader(context.gateway().getTopic(), messageFilters, context); sendClientConnectedEvent(context); } @@ -137,5 +136,4 @@ public void onClose( CloseStatus closeStatus) { stopReadingMessages(webSocketSession); } - } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java index d39e7903a..38a217dc4 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java @@ -24,7 +24,6 @@ import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; -import java.util.ArrayList; import java.util.List; import java.util.Map; import lombok.extern.slf4j.Slf4j; @@ -57,7 +56,8 @@ public String tenantFromPath(Map parsedPath, Map } @Override - public String applicationIdFromPath(Map parsedPath, Map queryString) { + public String applicationIdFromPath( + Map parsedPath, Map queryString) { return parsedPath.get("application"); } @@ -76,7 +76,8 @@ public void onBeforeHandshakeCompleted( AuthenticatedGatewayRequestContext context, Map attributes) throws Exception { final List
commonHeaders = - ProduceGateway.getProducerCommonHeaders(context.gateway().getProduceOptions(), context); + ProduceGateway.getProducerCommonHeaders( + context.gateway().getProduceOptions(), context); setupProducer(context.gateway().getTopic(), commonHeaders, context); sendClientConnectedEvent(context); diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java index 7d4485173..c876060ff 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java @@ -1,12 +1,25 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.http; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; -import ai.langstream.api.events.EventRecord; -import ai.langstream.api.events.EventSources; -import ai.langstream.api.events.GatewayEventData; + import ai.langstream.api.model.Application; import ai.langstream.api.model.ApplicationSpecs; import ai.langstream.api.model.Gateway; @@ -19,36 +32,23 @@ import ai.langstream.api.storage.ApplicationStore; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; -import ai.langstream.apigateway.websocket.api.ConsumePushMessage; -import ai.langstream.apigateway.websocket.api.ProduceRequest; -import ai.langstream.apigateway.websocket.api.ProduceResponse; -import ai.langstream.apigateway.websocket.handlers.TestWebSocketClient; import ai.langstream.impl.deploy.ApplicationDeployer; import ai.langstream.impl.parser.ModelBuilder; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.github.tomakehurst.wiremock.client.WireMock; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; -import jakarta.websocket.CloseReason; -import jakarta.websocket.DeploymentException; -import jakarta.websocket.Session; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.file.Path; -import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import lombok.Cleanup; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.awaitility.Awaitility; @@ -57,18 +57,15 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mockito; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.web.server.LocalServerPort; - @SpringBootTest( webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, properties = { - "spring.main.allow-bean-definition-overriding=true", + "spring.main.allow-bean-definition-overriding=true", }) @WireMockTest @Slf4j @@ -90,19 +87,19 @@ abstract class GatewayResourceTest { protected static ApplicationStore getMockedStore(String instanceYaml) { ApplicationStore mock = Mockito.mock(ApplicationStore.class); doAnswer( - invocationOnMock -> { - final StoredApplication storedApplication = new StoredApplication(); - final Application application = buildApp(instanceYaml); - storedApplication.setInstance(application); - return storedApplication; - }) + invocationOnMock -> { + final StoredApplication storedApplication = new StoredApplication(); + final Application application = buildApp(instanceYaml); + storedApplication.setInstance(application); + return storedApplication; + }) .when(mock) .get(anyString(), anyString(), anyBoolean()); doAnswer( - invocationOnMock -> - ApplicationSpecs.builder() - .application(buildApp(instanceYaml)) - .build()) + invocationOnMock -> + ApplicationSpecs.builder() + .application(buildApp(instanceYaml)) + .build()) .when(mock) .getSpecs(anyString(), anyString()); @@ -123,8 +120,7 @@ protected static GatewayTestAuthenticationProperties getGatewayTestAuthenticatio return props; } - @Autowired - private TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeProvider; + @Autowired private TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeProvider; @NotNull private static Application buildApp(String instanceYaml) throws Exception { @@ -158,8 +154,7 @@ private static Application buildApp(String instanceYaml) throws Exception { return application; } - @LocalServerPort - int port; + @LocalServerPort int port; @Autowired ApplicationStore store; @@ -171,7 +166,6 @@ private static String genTopic() { return "topic" + topicCounter.incrementAndGet(); } - @BeforeAll public static void beforeAll(WireMockRuntimeInfo wmRuntimeInfo) { wireMock = wmRuntimeInfo.getWireMock(); @@ -190,7 +184,6 @@ public static void afterAll() { Awaitility.reset(); } - @SneakyThrows void produceAndExpectOk(String url, String content) { final HttpRequest request = @@ -198,11 +191,11 @@ void produceAndExpectOk(String url, String content) { .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(content)) .build(); - final HttpResponse response = CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + final HttpResponse response = + CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); assertEquals(200, response.statusCode()); assertEquals(""" {"status":"OK","reason":null}""", response.body()); - } @SneakyThrows @@ -212,13 +205,13 @@ void produceAndExpectBadRequest(String url, String content, String errorMessage) .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(content)) .build(); - final HttpResponse response = CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + final HttpResponse response = + CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); assertEquals(400, response.statusCode()); log.info("Response body: {}", response.body()); final Map map = new ObjectMapper().readValue(response.body(), Map.class); - String detail = (String)map.get("detail"); + String detail = (String) map.get("detail"); assertTrue(detail.contains(errorMessage)); - } @SneakyThrows @@ -228,10 +221,10 @@ void produceAndExpectUnauthorized(String url, String content) { .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(content)) .build(); - final HttpResponse response = CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + final HttpResponse response = + CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); assertEquals(401, response.statusCode()); log.info("Response body: {}", response.body()); - } @Test @@ -253,14 +246,14 @@ void testSimpleProduce() throws Exception { .build())); final String url = - "http://localhost:%d/api/gateways/produce/tenant1/application1/produce".formatted(port); + "http://localhost:%d/api/gateways/produce/tenant1/application1/produce" + .formatted(port); produceAndExpectOk(url, "{\"value\": \"my-value\"}"); produceAndExpectOk(url, "{\"key\": \"my-key\"}"); produceAndExpectOk(url, "{\"key\": \"my-key\", \"headers\": {\"h1\": \"v1\"}}"); } - @Test void testParametersRequired() throws Exception { final String topic = genTopic(); @@ -281,16 +274,18 @@ void testParametersRequired() throws Exception { final String content = "{\"value\": \"my-value\"}"; produceAndExpectBadRequest(baseUrl, content, "missing required parameter session-id"); - produceAndExpectBadRequest(baseUrl+ "?param:otherparam=1", content, "missing required parameter session-id"); - produceAndExpectBadRequest(baseUrl+ "?param:session-id=", content, "missing required parameter session-id"); - produceAndExpectBadRequest(baseUrl+ "?param:session-id=ok¶m:another-non-declared=y", content, "unknown parameters: [another-non-declared]"); - produceAndExpectOk(baseUrl+ "?param:session-id=1", content); - produceAndExpectOk(baseUrl+ "?param:session-id=string-value", content); - + produceAndExpectBadRequest( + baseUrl + "?param:otherparam=1", content, "missing required parameter session-id"); + produceAndExpectBadRequest( + baseUrl + "?param:session-id=", content, "missing required parameter session-id"); + produceAndExpectBadRequest( + baseUrl + "?param:session-id=ok¶m:another-non-declared=y", + content, + "unknown parameters: [another-non-declared]"); + produceAndExpectOk(baseUrl + "?param:session-id=1", content); + produceAndExpectOk(baseUrl + "?param:session-id=string-value", content); } - - @Test void testAuthentication() throws Exception { final String topic = genTopic(); @@ -332,12 +327,14 @@ void testAuthentication() throws Exception { .build())); final String baseUrl = - "http://localhost:%d/api/gateways/produce/tenant1/application1/produce".formatted(port); + "http://localhost:%d/api/gateways/produce/tenant1/application1/produce" + .formatted(port); produceAndExpectUnauthorized(baseUrl, "{\"value\": \"my-value\"}"); produceAndExpectUnauthorized(baseUrl + "?credentials=", "{\"value\": \"my-value\"}"); produceAndExpectUnauthorized(baseUrl + "?credentials=error", "{\"value\": \"my-value\"}"); - produceAndExpectOk(baseUrl + "?credentials=test-user-password", "{\"value\": \"my-value\"}"); + produceAndExpectOk( + baseUrl + "?credentials=test-user-password", "{\"value\": \"my-value\"}"); } @Test @@ -378,17 +375,19 @@ void testTestCredentials() throws Exception { .build())); final String baseUrl = - "http://localhost:%d/api/gateways/produce/tenant1/application1/produce".formatted(port); - - - produceAndExpectUnauthorized(baseUrl + "?test-credentials=test", "{\"value\": \"my-value\"}"); - produceAndExpectOk(baseUrl + "?test-credentials=test-user-password", "{\"value\": \"my-value\"}"); - produceAndExpectUnauthorized("http://localhost:%d/api/gateways/produce/tenant1/application1/produce-no-test?test-credentials=test-user-password".formatted(port), "{\"value\": \"my-value\"}"); - + "http://localhost:%d/api/gateways/produce/tenant1/application1/produce" + .formatted(port); + + produceAndExpectUnauthorized( + baseUrl + "?test-credentials=test", "{\"value\": \"my-value\"}"); + produceAndExpectOk( + baseUrl + "?test-credentials=test-user-password", "{\"value\": \"my-value\"}"); + produceAndExpectUnauthorized( + "http://localhost:%d/api/gateways/produce/tenant1/application1/produce-no-test?test-credentials=test-user-password" + .formatted(port), + "{\"value\": \"my-value\"}"); } - - protected abstract StreamingCluster getStreamingCluster(); private void prepareTopicsForTest(String... topic) throws Exception { @@ -409,5 +408,4 @@ private void prepareTopicsForTest(String... topic) throws Exception { deployer.createImplementation( "app", store.get("t", "app", false).getInstance())); } - -} \ No newline at end of file +} diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java index 097ecdbcb..4e40c18e7 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/KafkaGatewayResourceTest.java @@ -1,55 +1,30 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.http; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doAnswer; -import ai.langstream.api.model.Application; -import ai.langstream.api.model.ApplicationSpecs; -import ai.langstream.api.model.Gateway; -import ai.langstream.api.model.Gateways; -import ai.langstream.api.model.StoredApplication; import ai.langstream.api.model.StreamingCluster; -import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; -import ai.langstream.api.runtime.ClusterRuntimeRegistry; -import ai.langstream.api.runtime.PluginsRegistry; import ai.langstream.api.storage.ApplicationStore; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; -import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; -import ai.langstream.impl.deploy.ApplicationDeployer; -import ai.langstream.impl.parser.ModelBuilder; import ai.langstream.kafka.extensions.KafkaContainerExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.github.tomakehurst.wiremock.client.WireMock; -import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; -import com.github.tomakehurst.wiremock.junit5.WireMockTest; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.file.Path; -import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.awaitility.Awaitility; -import org.jetbrains.annotations.NotNull; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import org.mockito.Mockito; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.TestConfiguration; -import org.springframework.boot.test.web.server.LocalServerPort; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; - class KafkaGatewayResourceTest extends GatewayResourceTest { @RegisterExtension @@ -95,4 +70,4 @@ public GatewayTestAuthenticationProperties gatewayTestAuthenticationProperties() return getGatewayTestAuthenticationProperties(); } } -} \ No newline at end of file +} diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java index ab4987fb4..1c5e49797 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/PulsarGatewayResourceTest.java @@ -1,20 +1,30 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ai.langstream.apigateway.http; import ai.langstream.api.model.StreamingCluster; import ai.langstream.api.storage.ApplicationStore; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.websocket.handlers.PulsarContainerExtension; -import ai.langstream.kafka.extensions.KafkaContainerExtension; -import com.github.tomakehurst.wiremock.junit5.WireMockTest; import java.util.Map; -import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.extension.RegisterExtension; -import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.TestConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; - class PulsarGatewayResourceTest extends GatewayResourceTest { @RegisterExtension @@ -68,4 +78,4 @@ public GatewayTestAuthenticationProperties gatewayTestAuthenticationProperties() return getGatewayTestAuthenticationProperties(); } } -} \ No newline at end of file +} diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java index 85dc521c4..26fbbda5b 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java @@ -35,7 +35,8 @@ public void initialize(Map configuration) {} @Override public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { log.info("Authenticating {}", context.credentials()); - if (context.credentials() != null && context.credentials().startsWith("test-user-password")) { + if (context.credentials() != null + && context.credentials().startsWith("test-user-password")) { return GatewayAuthenticationResult.authenticationSuccessful( Map.of("login", context.credentials())); } else { diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java index 4f027e753..4c4b60784 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java @@ -292,21 +292,25 @@ void testParametersRequired(String type) throws Exception { .parameters(List.of("session-id")) .build())); connectAndExpectHttpError( - URI.create("ws://localhost:%d/v1/%s/tenant1/application1/gw".formatted(port, type)), 500); + URI.create("ws://localhost:%d/v1/%s/tenant1/application1/gw".formatted(port, type)), + 500); connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/%s/tenant1/application1/gw?param:otherparam=1" - .formatted(port, type)), 500); + .formatted(port, type)), + 500); connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/%s/tenant1/application1/gw?param:session-id=" - .formatted(port, type)), 500); + .formatted(port, type)), + 500); connectAndExpectHttpError( URI.create( ("ws://localhost:%d/v1/%s/tenant1/application1/gw?param:session-id=ok¶m:another-non" + "-declared=y") - .formatted(port, type)), 500); + .formatted(port, type)), + 500); connectAndExpectRunning( URI.create( @@ -465,15 +469,18 @@ void testAuthentication() throws Exception { connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce" - .formatted(port)), 401); + .formatted(port)), + 401); connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce?credentials=" - .formatted(port)), 401); + .formatted(port)), + 401); connectAndExpectHttpError( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce?credentials=error" - .formatted(port)), 401); + .formatted(port)), + 401); connectAndExpectRunning( URI.create( "ws://localhost:%d/v1/produce/tenant1/application1/produce?credentials=test-user-password" @@ -605,13 +612,15 @@ void testTestCredentials() throws Exception { URI.create( ("ws://localhost:%d/v1/consume/tenant1/application1/consume-no-test?test-credentials=test" + "-user-password") - .formatted(port)), 401); + .formatted(port)), + 401); connectAndExpectHttpError( URI.create( ("ws://localhost:%d/v1/produce/tenant1/application1/produce?test-credentials=test-user" + "-password-but-wrong") - .formatted(port)), 401); + .formatted(port)), + 401); } private record MsgRecord(Object key, Object value, Map headers) {} @@ -1084,15 +1093,14 @@ private static String genTopic() { return "topic" + topicCounter.incrementAndGet(); } - private void connectAndExpectClose(URI connectTo, CloseReason expectedCloseReason) { connectAndExpectClose(connectTo, expectedCloseReason, -1); - } + private void connectAndExpectHttpError(URI connectTo, int code) { connectAndExpectClose(connectTo, null, code); - } + @SneakyThrows private void connectAndExpectClose(URI connectTo, CloseReason expectedCloseReason, int code) { CountDownLatch countDownLatch = new CountDownLatch(1); diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java index 7164ceebb..b15fb23a9 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java @@ -34,8 +34,10 @@ public abstract class BaseGatewayCmd extends BaseCmd { protected static final ObjectMapper messageMapper = new ObjectMapper(); + protected enum Protocols { - ws, http; + ws, + http; } @CommandLine.ParentCommand private RootGatewayCmd cmd; @@ -110,7 +112,6 @@ protected String validateGatewayAndGetUrl( applicationId, gatewayId, computeQueryString(systemParams, params, options)); - } return String.format( @@ -131,7 +132,6 @@ private String getApiGatewayUrl() { return getCurrentProfile().getApiGatewayUrl(); } - private String getApiGatewayUrlHttp() { return getApiGatewayUrl().replace("wss://", "https://").replace("ws://", "http://"); } diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java index 8e62c8bf6..a46c7d578 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java @@ -87,7 +87,8 @@ static class ProduceRequest { @CommandLine.Option( names = {"--protocol"}, - description = "Protocol to use: http or ws", defaultValue = "ws") + description = "Protocol to use: http or ws", + defaultValue = "ws") private Protocols protocol = Protocols.ws; @Override @@ -106,8 +107,7 @@ public void run() { final Duration connectTimeout = connectTimeoutSeconds > 0 ? Duration.ofSeconds(connectTimeoutSeconds) : null; - final ProduceRequest produceRequest = - new ProduceRequest(messageKey, messageValue, headers); + final ProduceRequest produceRequest = new ProduceRequest(messageKey, messageValue, headers); final String json = messageMapper.writeValueAsString(produceRequest); if (protocol == Protocols.http) { @@ -117,7 +117,8 @@ public void run() { } } - private void produceWebSocket(String producePath, Duration connectTimeout, String json) throws Exception { + private void produceWebSocket(String producePath, Duration connectTimeout, String json) + throws Exception { CountDownLatch countDownLatch = new CountDownLatch(1); try (final WebSocketClient client = new WebSocketClient( @@ -156,17 +157,21 @@ public void onError(Throwable throwable) { } } - private void produceHttp(String producePath, Duration connectTimeout, String json) throws Exception { - final HttpRequest.Builder builder = HttpRequest.newBuilder(URI.create(producePath)) - .header("Content-Type", "application/json") - .version(HttpClient.Version.HTTP_1_1) - .POST(HttpRequest.BodyPublishers.ofString(json)); + private void produceHttp(String producePath, Duration connectTimeout, String json) + throws Exception { + final HttpRequest.Builder builder = + HttpRequest.newBuilder(URI.create(producePath)) + .header("Content-Type", "application/json") + .version(HttpClient.Version.HTTP_1_1) + .POST(HttpRequest.BodyPublishers.ofString(json)); if (connectTimeout != null) { builder.timeout(connectTimeout); } final HttpRequest request = builder.build(); final HttpResponse response = - getClient().getHttpClientFacade().http(request, HttpResponse.BodyHandlers.ofString()); + getClient() + .getHttpClientFacade() + .http(request, HttpResponse.BodyHandlers.ofString()); log(response.body()); } } From fee02b6cbaea70be82c2f3e7fa450e2b0196cf01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Fri, 27 Oct 2023 09:35:58 +0200 Subject: [PATCH 3/4] comments and leaks --- .../google/GoogleAuthenticationProvider.java | 4 +- langstream-api-gateway/pom.xml | 7 ++ .../apigateway/gateways/ConsumeGateway.java | 77 +++++++++++-------- .../apigateway/gateways/ProduceGateway.java | 25 +++--- .../apigateway/http/GatewayResource.java | 16 ++-- .../websocket/handlers/AbstractHandler.java | 49 ++++++++---- .../websocket/handlers/ChatHandler.java | 22 +++++- .../websocket/handlers/ConsumeHandler.java | 12 ++- .../apigateway/http/GatewayResourceTest.java | 2 +- langstream-kafka-runtime/pom.xml | 2 +- langstream-pulsar-runtime/pom.xml | 2 +- ...PulsarTopicConnectionsRuntimeProvider.java | 9 ++- 12 files changed, 148 insertions(+), 79 deletions(-) diff --git a/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java index 76ddc5404..a2219371f 100644 --- a/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java +++ b/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java @@ -61,7 +61,7 @@ public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { try { final String credentials = context.credentials(); if (credentials == null) { - return GatewayAuthenticationResult.authenticationFailed("Token not found."); + return GatewayAuthenticationResult.authenticationFailed("Credentials not provided."); } GoogleIdToken idToken = verifier.verify(credentials); if (idToken != null) { @@ -73,7 +73,7 @@ public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { result.put(FIELD_LOCALE, (String) payload.get("locale")); return GatewayAuthenticationResult.authenticationSuccessful(result); } else { - return GatewayAuthenticationResult.authenticationFailed("Invalid token."); + return GatewayAuthenticationResult.authenticationFailed("Invalid credentials."); } } catch (Exception e) { throw new RuntimeException(e); diff --git a/langstream-api-gateway/pom.xml b/langstream-api-gateway/pom.xml index d13c3c8e9..ef2dd74af 100644 --- a/langstream-api-gateway/pom.xml +++ b/langstream-api-gateway/pom.xml @@ -61,6 +61,13 @@ runtime + + ${project.groupId} + langstream-pulsar-runtime + ${project.version} + provided + + org.springframework.boot spring-boot-configuration-processor diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java index 1cff2b3f9..8e7934732 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -36,12 +36,14 @@ import java.util.Map; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import lombok.Getter; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -79,7 +81,11 @@ public void validateOptions(Map options) { } private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; - private TopicReader reader; + + private volatile TopicReader reader; + private volatile boolean interrupted; + private volatile String logRef; + private CompletableFuture readerFuture; private AuthenticatedGatewayRequestContext requestContext; private List> filters; @@ -87,11 +93,14 @@ public ConsumeGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeReg this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry; } - @SneakyThrows public void setup( String topic, List> filters, - AuthenticatedGatewayRequestContext requestContext) { + AuthenticatedGatewayRequestContext requestContext) throws Exception { + this.logRef = "%s/%s/%s".formatted( + requestContext.tenant(), + requestContext.applicationId(), + requestContext.gateway().getId()); this.requestContext = requestContext; this.filters = filters == null ? List.of() : filters; @@ -119,39 +128,32 @@ public void setup( reader.start(); } - public CompletableFuture startReading( + public void startReadingAsync( Executor executor, Supplier stop, Consumer onMessage) { if (requestContext == null || reader == null) { throw new IllegalStateException("Not initialized"); } - return CompletableFuture.runAsync( - () -> { - try { - final String tenant = requestContext.tenant(); - - final String gatewayId = requestContext.gateway().getId(); - final String applicationId = requestContext.applicationId(); - log.info( - "Started reader for gateway {}/{}/{}", - tenant, - applicationId, - gatewayId); - readMessages(stop, onMessage); - } catch (InterruptedException | CancellationException ex) { - // ignore - } catch (Throwable ex) { - log.error(ex.getMessage(), ex); - } finally { - closeReader(); - } - }, - executor); + if (readerFuture != null) { + throw new IllegalStateException("Already started"); + } + readerFuture = CompletableFuture.runAsync( + () -> { + try { + log.debug("[{}] Started reader", logRef); + readMessages(stop, onMessage); + } catch (Throwable ex) { + throw new RuntimeException(ex); + } finally { + closeReader(); + } + }, + executor); } protected void readMessages(Supplier stop, Consumer onMessage) throws Exception { while (true) { - if (Thread.currentThread().isInterrupted()) { + if (interrupted) { return; } if (stop.get()) { @@ -160,13 +162,13 @@ protected void readMessages(Supplier stop, Consumer onMessage) final TopicReadResult readResult = reader.read(); final List records = readResult.records(); for (Record record : records) { - log.debug("Received record {}", record); + log.debug("[{}] Received record {}", logRef, record); boolean skip = false; if (filters != null) { for (Function filter : filters) { if (!filter.apply(record)) { skip = true; - log.debug("Skipping record {}", requestContext, record); + log.debug("[{}] Skipping record {}", logRef, record); break; } } @@ -212,13 +214,24 @@ private void closeReader() { try { reader.close(); } catch (Exception e) { - log.error("error closing reader", e); + log.warn("error closing reader", e); } } } @Override public void close() { - closeReader(); + if (readerFuture != null) { + + interrupted = true; + try { + // reader.close must be done by the same thread that started the consumer + readerFuture.get(10, TimeUnit.SECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.debug("error waiting for reader to stop", e); + } + } else { + closeReader(); + } } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java index 0cea54761..30eb2038e 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java @@ -51,6 +51,11 @@ public ProduceException(String message, ProduceResponse.Status status) { super(message); this.status = status; } + + public ProduceException(String message, ProduceResponse.Status status, Throwable tt) { + super(message, tt); + this.status = status; + } } public static class ProduceGatewayRequestValidator @@ -74,6 +79,7 @@ public void validateOptions(Map options) { private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; private TopicProducer producer; private List
commonHeaders; + private String logRef; public ProduceGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) { this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry; @@ -83,6 +89,10 @@ public void start( String topic, List
commonHeaders, AuthenticatedGatewayRequestContext requestContext) { + this.logRef = "%s/%s/%s".formatted( + requestContext.tenant(), + requestContext.applicationId(), + requestContext.gateway().getId()); this.commonHeaders = commonHeaders == null ? List.of() : commonHeaders; setupProducer( @@ -110,12 +120,7 @@ protected void setupProducer( topicConnectionsRuntime.createProducer( null, streamingCluster, Map.of("topic", topic)); producer.start(); - log.info( - "Started producer for gateway {}/{}/{} on topic {}", - tenant, - applicationId, - gatewayId, - topic); + log.debug("[{}] Started producer on topic {}",logRef, topic); } public void produceMessage(String payload) throws ProduceException { @@ -163,20 +168,20 @@ public void produceMessage(ProduceRequest produceRequest) throws ProduceExceptio .headers(headers) .build(); producer.write(record).get(); - log.info("Produced record {}", record); + log.debug("[{}] Produced record {}",logRef, record); } catch (Throwable tt) { - throw new ProduceException(tt.getMessage(), ProduceResponse.Status.PRODUCER_ERROR); + log.error("[{}] Error producing message: {}", logRef, tt.getMessage(), tt); + throw new ProduceException(tt.getMessage(), ProduceResponse.Status.PRODUCER_ERROR, tt); } } @Override public void close() { - if (producer != null) { try { producer.close(); } catch (Exception e) { - log.error("error closing producer", e); + log.debug("[{}] Error closing producer: {}", logRef, e.getMessage(), e); } } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java index c96b10396..d8cca19df 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java @@ -87,11 +87,15 @@ ProduceResponse produce( new ProduceGateway( topicConnectionsRuntimeRegistryProvider .getTopicConnectionsRuntimeRegistry()); - final List
commonHeaders = - ProduceGateway.getProducerCommonHeaders( - context.gateway().getProduceOptions(), authContext); - produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext); - produceGateway.produceMessage(produceRequest); - return ProduceResponse.OK; + try { + final List
commonHeaders = + ProduceGateway.getProducerCommonHeaders( + context.gateway().getProduceOptions(), authContext); + produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext); + produceGateway.produceMessage(produceRequest); + return ProduceResponse.OK; + } finally { + produceGateway.close(); + } } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java index 871bb2080..258ef01d0 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java @@ -227,8 +227,7 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor final AuthenticatedGatewayRequestContext context = getContext(webSocketSession); final ConsumeGateway consumeGateway = (ConsumeGateway) context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); - final CompletableFuture future = - consumeGateway.startReading( + consumeGateway.startReadingAsync( executor, () -> !webSocketSession.isOpen(), message -> { @@ -238,7 +237,6 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor throw new RuntimeException(ex); } }); - webSocketSession.getAttributes().put("future", future); } protected static List> createMessageFilters( @@ -285,23 +283,29 @@ protected void setupReader( AuthenticatedGatewayRequestContext context) throws Exception { final ConsumeGateway consumeGateway = new ConsumeGateway(topicConnectionsRuntimeRegistry); + try { + consumeGateway.setup(topic, filters, context); + } catch (Exception ex) { + log.error(ex.getMessage(), ex); + consumeGateway.close(); + throw ex; + } context.attributes().put(ATTRIBUTE_CONSUME_GATEWAY, consumeGateway); - consumeGateway.setup(topic, filters, context); - } - protected void stopReadingMessages(WebSocketSession webSocketSession) { - final CompletableFuture future = - (CompletableFuture) webSocketSession.getAttributes().get("future"); - if (future != null && !future.isDone()) { - future.cancel(true); - } } protected void setupProducer( - String topic, List
commonHeaders, AuthenticatedGatewayRequestContext context) { + String topic, List
commonHeaders, AuthenticatedGatewayRequestContext context) throws Exception { final ProduceGateway produceGateway = new ProduceGateway(topicConnectionsRuntimeRegistry); + + try { + produceGateway.start(topic, commonHeaders, context); + } catch (Exception ex) { + log.error(ex.getMessage(), ex); + produceGateway.close(); + throw ex; + } context.attributes().put(ATTRIBUTE_PRODUCE_GATEWAY, produceGateway); - produceGateway.start(topic, commonHeaders, context); } protected void produceMessage(WebSocketSession webSocketSession, TextMessage message) @@ -318,10 +322,27 @@ protected void produceMessage(WebSocketSession webSocketSession, TextMessage mes } } + protected void closeConsumeGateway(WebSocketSession webSocketSession) { + closeConsumeGateway(getContext(webSocketSession)); + } + protected void closeConsumeGateway(AuthenticatedGatewayRequestContext context) { + final ConsumeGateway consumeGateway = + (ConsumeGateway) + context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); + if (consumeGateway == null) { + return; + } + consumeGateway.close(); + } + protected void closeProduceGateway(WebSocketSession webSocketSession) { + closeProduceGateway(getContext(webSocketSession)); + } + + protected void closeProduceGateway(AuthenticatedGatewayRequestContext context) { final ProduceGateway produceGateway = (ProduceGateway) - getContext(webSocketSession).attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); + context.attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); if (produceGateway == null) { return; } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java index 91b9e98bd..c224fc240 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java @@ -118,13 +118,24 @@ public void onBeforeHandshakeCompleted( AuthenticatedGatewayRequestContext context, Map attributes) throws Exception { - setupReader(context); - setupProducer(context); + try { + setupReader(context); + } catch (Exception ex) { + log.error("Error setting up reader", ex); + throw ex; + } + try { + setupProducer(context); + } catch (Exception ex) { + log.error("Error setting up producer", ex); + closeConsumeGateway(context); + throw ex; + } sendClientConnectedEvent(context); } - private void setupProducer(AuthenticatedGatewayRequestContext context) { + private void setupProducer(AuthenticatedGatewayRequestContext context) throws Exception { final Gateway.ChatOptions chatOptions = context.gateway().getChatOptions(); List headerConfig = new ArrayList<>(); @@ -177,5 +188,8 @@ public void onMessage( public void onClose( WebSocketSession webSocketSession, AuthenticatedGatewayRequestContext context, - CloseStatus status) {} + CloseStatus status) { + closeConsumeGateway(webSocketSession); + closeProduceGateway(webSocketSession); + } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java index 03e25ea85..efcdc77e9 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java @@ -99,8 +99,7 @@ public void validateOptions(Map options) { @Override public void onBeforeHandshakeCompleted( - AuthenticatedGatewayRequestContext context, Map attributes) - throws Exception { + AuthenticatedGatewayRequestContext context, Map attributes) { final Gateway.ConsumeOptions consumeOptions = context.gateway().getConsumeOptions(); @@ -114,7 +113,12 @@ public void onBeforeHandshakeCompleted( } else { messageFilters = null; } - setupReader(context.gateway().getTopic(), messageFilters, context); + try { + setupReader(context.gateway().getTopic(), messageFilters, context); + } catch (Exception ex) { + log.error("Error setting up reader", ex); + throw new RuntimeException(ex); + } sendClientConnectedEvent(context); } @@ -134,6 +138,6 @@ public void onClose( WebSocketSession webSocketSession, AuthenticatedGatewayRequestContext context, CloseStatus closeStatus) { - stopReadingMessages(webSocketSession); + closeConsumeGateway(webSocketSession); } } diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java index c876060ff..db2f5d2db 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java @@ -249,7 +249,7 @@ void testSimpleProduce() throws Exception { "http://localhost:%d/api/gateways/produce/tenant1/application1/produce" .formatted(port); - produceAndExpectOk(url, "{\"value\": \"my-value\"}"); + produceAndExpectOk(url, "{\"key\": \"my-key\", \"value\": \"my-value\"}"); produceAndExpectOk(url, "{\"key\": \"my-key\"}"); produceAndExpectOk(url, "{\"key\": \"my-key\", \"headers\": {\"h1\": \"v1\"}}"); } diff --git a/langstream-kafka-runtime/pom.xml b/langstream-kafka-runtime/pom.xml index 8986700c1..b7145f56b 100644 --- a/langstream-kafka-runtime/pom.xml +++ b/langstream-kafka-runtime/pom.xml @@ -25,7 +25,7 @@ 4.0.0 langstream-kafka-runtime jar - LangStream - Kafka Implementation + LangStream - Kafka runtime ${project.build.directory} diff --git a/langstream-pulsar-runtime/pom.xml b/langstream-pulsar-runtime/pom.xml index 89c32c40f..afeb366db 100644 --- a/langstream-pulsar-runtime/pom.xml +++ b/langstream-pulsar-runtime/pom.xml @@ -25,7 +25,7 @@ 4.0.0 langstream-pulsar-runtime jar - LangStream - Pulsar Implementation + LangStream - Pulsar runtime ${project.build.directory} diff --git a/langstream-pulsar-runtime/src/main/java/ai/langstream/pulsar/runner/PulsarTopicConnectionsRuntimeProvider.java b/langstream-pulsar-runtime/src/main/java/ai/langstream/pulsar/runner/PulsarTopicConnectionsRuntimeProvider.java index 7400e5b6f..0faf90c3a 100644 --- a/langstream-pulsar-runtime/src/main/java/ai/langstream/pulsar/runner/PulsarTopicConnectionsRuntimeProvider.java +++ b/langstream-pulsar-runtime/src/main/java/ai/langstream/pulsar/runner/PulsarTopicConnectionsRuntimeProvider.java @@ -667,11 +667,12 @@ public CompletableFuture write(Record r) { totalIn.addAndGet(1); if (schema == null) { try { - if (r.value() == null) { - throw new IllegalStateException( - "Cannot infer schema because value is null"); + final Schema valueSchema; + if (r.value() != null) { + valueSchema = getSchema(r.value().getClass()); + } else { + valueSchema = Schema.BYTES; } - Schema valueSchema = getSchema(r.value().getClass()); if (r.key() != null) { Schema keySchema = getSchema(r.key().getClass()); schema = From 5a565a45a0ed8bb4b5a48fde1079ff7ddd989f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Fri, 27 Oct 2023 09:36:18 +0200 Subject: [PATCH 4/4] fix --- .../google/GoogleAuthenticationProvider.java | 3 +- .../apigateway/gateways/ConsumeGateway.java | 17 ++++++----- .../apigateway/gateways/ProduceGateway.java | 14 +++++---- .../websocket/handlers/AbstractHandler.java | 30 +++++++++---------- 4 files changed, 34 insertions(+), 30 deletions(-) diff --git a/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java index a2219371f..6fa74fd65 100644 --- a/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java +++ b/langstream-api-gateway-auth/langstream-google-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/google/GoogleAuthenticationProvider.java @@ -61,7 +61,8 @@ public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { try { final String credentials = context.credentials(); if (credentials == null) { - return GatewayAuthenticationResult.authenticationFailed("Credentials not provided."); + return GatewayAuthenticationResult.authenticationFailed( + "Credentials not provided."); } GoogleIdToken idToken = verifier.verify(credentials); if (idToken != null) { diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java index 8e7934732..827bdc489 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -34,7 +34,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -96,11 +95,14 @@ public ConsumeGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeReg public void setup( String topic, List> filters, - AuthenticatedGatewayRequestContext requestContext) throws Exception { - this.logRef = "%s/%s/%s".formatted( - requestContext.tenant(), - requestContext.applicationId(), - requestContext.gateway().getId()); + AuthenticatedGatewayRequestContext requestContext) + throws Exception { + this.logRef = + "%s/%s/%s" + .formatted( + requestContext.tenant(), + requestContext.applicationId(), + requestContext.gateway().getId()); this.requestContext = requestContext; this.filters = filters == null ? List.of() : filters; @@ -136,7 +138,8 @@ public void startReadingAsync( if (readerFuture != null) { throw new IllegalStateException("Already started"); } - readerFuture = CompletableFuture.runAsync( + readerFuture = + CompletableFuture.runAsync( () -> { try { log.debug("[{}] Started reader", logRef); diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java index 30eb2038e..0f6871425 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java @@ -89,10 +89,12 @@ public void start( String topic, List
commonHeaders, AuthenticatedGatewayRequestContext requestContext) { - this.logRef = "%s/%s/%s".formatted( - requestContext.tenant(), - requestContext.applicationId(), - requestContext.gateway().getId()); + this.logRef = + "%s/%s/%s" + .formatted( + requestContext.tenant(), + requestContext.applicationId(), + requestContext.gateway().getId()); this.commonHeaders = commonHeaders == null ? List.of() : commonHeaders; setupProducer( @@ -120,7 +122,7 @@ protected void setupProducer( topicConnectionsRuntime.createProducer( null, streamingCluster, Map.of("topic", topic)); producer.start(); - log.debug("[{}] Started producer on topic {}",logRef, topic); + log.debug("[{}] Started producer on topic {}", logRef, topic); } public void produceMessage(String payload) throws ProduceException { @@ -168,7 +170,7 @@ public void produceMessage(ProduceRequest produceRequest) throws ProduceExceptio .headers(headers) .build(); producer.write(record).get(); - log.debug("[{}] Produced record {}",logRef, record); + log.debug("[{}] Produced record {}", logRef, record); } catch (Throwable tt) { log.error("[{}] Error producing message: {}", logRef, tt.getMessage(), tt); throw new ProduceException(tt.getMessage(), ProduceResponse.Status.PRODUCER_ERROR, tt); diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java index 258ef01d0..7f5615a14 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java @@ -37,7 +37,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Function; import lombok.SneakyThrows; @@ -228,15 +227,15 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor final ConsumeGateway consumeGateway = (ConsumeGateway) context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); consumeGateway.startReadingAsync( - executor, - () -> !webSocketSession.isOpen(), - message -> { - try { - webSocketSession.sendMessage(new TextMessage(message)); - } catch (IOException ex) { - throw new RuntimeException(ex); - } - }); + executor, + () -> !webSocketSession.isOpen(), + message -> { + try { + webSocketSession.sendMessage(new TextMessage(message)); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); } protected static List> createMessageFilters( @@ -291,11 +290,11 @@ protected void setupReader( throw ex; } context.attributes().put(ATTRIBUTE_CONSUME_GATEWAY, consumeGateway); - } protected void setupProducer( - String topic, List
commonHeaders, AuthenticatedGatewayRequestContext context) throws Exception { + String topic, List
commonHeaders, AuthenticatedGatewayRequestContext context) + throws Exception { final ProduceGateway produceGateway = new ProduceGateway(topicConnectionsRuntimeRegistry); try { @@ -325,10 +324,10 @@ protected void produceMessage(WebSocketSession webSocketSession, TextMessage mes protected void closeConsumeGateway(WebSocketSession webSocketSession) { closeConsumeGateway(getContext(webSocketSession)); } + protected void closeConsumeGateway(AuthenticatedGatewayRequestContext context) { final ConsumeGateway consumeGateway = - (ConsumeGateway) - context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); + (ConsumeGateway) context.attributes().get(ATTRIBUTE_CONSUME_GATEWAY); if (consumeGateway == null) { return; } @@ -341,8 +340,7 @@ protected void closeProduceGateway(WebSocketSession webSocketSession) { protected void closeProduceGateway(AuthenticatedGatewayRequestContext context) { final ProduceGateway produceGateway = - (ProduceGateway) - context.attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); + (ProduceGateway) context.attributes().get(ATTRIBUTE_PRODUCE_GATEWAY); if (produceGateway == null) { return; }