diff --git a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java index d56946f334..1906a3e40f 100644 --- a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java +++ b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java @@ -16,6 +16,7 @@ import org.opensearch.dataprepper.http.HttpServerConfig; import org.opensearch.dataprepper.http.LogThrottlingRejectHandler; import org.opensearch.dataprepper.http.LogThrottlingStrategy; +import org.opensearch.dataprepper.http.certificate.CertificateProviderFactory; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; @@ -32,7 +33,6 @@ import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; import org.opensearch.dataprepper.plugins.certificate.model.Certificate; import org.opensearch.dataprepper.plugins.codec.CompressionOption; -import org.opensearch.dataprepper.http.certificate.CertificateProviderFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java index c6078c4095..f51666db59 100644 --- a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java +++ b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java @@ -9,6 +9,7 @@ import com.linecorp.armeria.client.ResponseTimeoutException; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpMethod; @@ -23,7 +24,6 @@ import io.netty.util.AsciiString; import org.apache.commons.io.IOUtils; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -36,6 +36,7 @@ import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; import org.opensearch.dataprepper.armeria.authentication.HttpBasicAuthenticationConfig; import org.opensearch.dataprepper.http.LogThrottlingRejectHandler; +import org.opensearch.dataprepper.http.certificate.CertificateProviderFactory; import org.opensearch.dataprepper.metrics.MetricNames; import org.opensearch.dataprepper.metrics.MetricsTestUtil; import org.opensearch.dataprepper.metrics.PluginMetrics; @@ -49,12 +50,15 @@ import org.opensearch.dataprepper.model.types.ByteCount; import org.opensearch.dataprepper.plugins.HttpBasicArmeriaHttpAuthenticationProvider; import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer; +import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; +import org.opensearch.dataprepper.plugins.certificate.model.Certificate; import org.opensearch.dataprepper.plugins.codec.CompressionOption; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Field; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; @@ -68,6 +72,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.zip.GZIPOutputStream; @@ -78,6 +83,11 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; @@ -112,6 +122,15 @@ class HTTPSourceTest { @Mock private CompletableFuture completableFuture; + @Mock + private CertificateProviderFactory certificateProviderFactory; + + @Mock + private CertificateProvider certificateProvider; + + @Mock + private Certificate certificate; + private BlockingBuffer> testBuffer; private HTTPSource HTTPSourceUnderTest; private List requestsReceivedMeasurements; @@ -252,29 +271,29 @@ public void testHTTPJsonResponse200() { .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.OK)).join(); // Then - Assertions.assertFalse(testBuffer.isEmpty()); + assertFalse(testBuffer.isEmpty()); final Map.Entry>, CheckpointState> result = testBuffer.read(100); List> records = new ArrayList<>(result.getKey()); - Assertions.assertEquals(1, records.size()); + assertEquals(1, records.size()); final Record record = records.get(0); - Assertions.assertEquals("somelog", record.getData().get("log", String.class)); + assertEquals("somelog", record.getData().get("log", String.class)); // Verify metrics final Measurement requestReceivedCount = MetricsTestUtil.getMeasurementFromList( requestsReceivedMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestReceivedCount.getValue()); + assertEquals(1.0, requestReceivedCount.getValue()); final Measurement successRequestsCount = MetricsTestUtil.getMeasurementFromList( successRequestsMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, successRequestsCount.getValue()); + assertEquals(1.0, successRequestsCount.getValue()); final Measurement requestProcessDurationCount = MetricsTestUtil.getMeasurementFromList( requestProcessDurationMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestProcessDurationCount.getValue()); + assertEquals(1.0, requestProcessDurationCount.getValue()); final Measurement requestProcessDurationMax = MetricsTestUtil.getMeasurementFromList( requestProcessDurationMeasurements, Statistic.MAX); - Assertions.assertTrue(requestProcessDurationMax.getValue() > 0); + assertTrue(requestProcessDurationMax.getValue() > 0); final Measurement payloadSizeMax = MetricsTestUtil.getMeasurementFromList( payloadSizeSummaryMeasurements, Statistic.MAX); - Assertions.assertEquals(testPayloadSize, payloadSizeMax.getValue()); + assertEquals(testPayloadSize, payloadSizeMax.getValue()); } @Test @@ -299,26 +318,26 @@ public void testHttpCompressionResponse200() throws IOException { .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.OK)).join(); // Then - Assertions.assertFalse(testBuffer.isEmpty()); + assertFalse(testBuffer.isEmpty()); final Map.Entry>, CheckpointState> result = testBuffer.read(100); List> records = new ArrayList<>(result.getKey()); - Assertions.assertEquals(1, records.size()); + assertEquals(1, records.size()); final Record record = records.get(0); - Assertions.assertEquals("somelog", record.getData().get("log", String.class)); + assertEquals("somelog", record.getData().get("log", String.class)); // Verify metrics final Measurement requestReceivedCount = MetricsTestUtil.getMeasurementFromList( requestsReceivedMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestReceivedCount.getValue()); + assertEquals(1.0, requestReceivedCount.getValue()); final Measurement successRequestsCount = MetricsTestUtil.getMeasurementFromList( successRequestsMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, successRequestsCount.getValue()); + assertEquals(1.0, successRequestsCount.getValue()); final Measurement requestProcessDurationCount = MetricsTestUtil.getMeasurementFromList( requestProcessDurationMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestProcessDurationCount.getValue()); + assertEquals(1.0, requestProcessDurationCount.getValue()); final Measurement requestProcessDurationMax = MetricsTestUtil.getMeasurementFromList( requestProcessDurationMeasurements, Statistic.MAX); - Assertions.assertTrue(requestProcessDurationMax.getValue() > 0); + assertTrue(requestProcessDurationMax.getValue() > 0); } @Test @@ -383,14 +402,14 @@ public void testHTTPJsonResponse400() { .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.BAD_REQUEST)).join(); // Then - Assertions.assertTrue(testBuffer.isEmpty()); + assertTrue(testBuffer.isEmpty()); // Verify metrics final Measurement requestReceivedCount = MetricsTestUtil.getMeasurementFromList( requestsReceivedMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestReceivedCount.getValue()); + assertEquals(1.0, requestReceivedCount.getValue()); final Measurement badRequestsCount = MetricsTestUtil.getMeasurementFromList( badRequestsMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, badRequestsCount.getValue()); + assertEquals(1.0, badRequestsCount.getValue()); } @Test @@ -414,26 +433,26 @@ public void testHTTPJsonResponse413() throws InterruptedException { .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.REQUEST_ENTITY_TOO_LARGE)).join(); // Then - Assertions.assertTrue(testBuffer.isEmpty()); + assertTrue(testBuffer.isEmpty()); // Verify metrics final Measurement requestReceivedCount = MetricsTestUtil.getMeasurementFromList( requestsReceivedMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestReceivedCount.getValue()); + assertEquals(1.0, requestReceivedCount.getValue()); final Measurement successRequestsCount = MetricsTestUtil.getMeasurementFromList( successRequestsMeasurements, Statistic.COUNT); - Assertions.assertEquals(0.0, successRequestsCount.getValue()); + assertEquals(0.0, successRequestsCount.getValue()); final Measurement requestsTooLargeCount = MetricsTestUtil.getMeasurementFromList( requestsTooLargeMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestsTooLargeCount.getValue()); + assertEquals(1.0, requestsTooLargeCount.getValue()); final Measurement requestProcessDurationCount = MetricsTestUtil.getMeasurementFromList( requestProcessDurationMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestProcessDurationCount.getValue()); + assertEquals(1.0, requestProcessDurationCount.getValue()); final Measurement requestProcessDurationMax = MetricsTestUtil.getMeasurementFromList( requestProcessDurationMeasurements, Statistic.MAX); - Assertions.assertTrue(requestProcessDurationMax.getValue() > 0); + assertTrue(requestProcessDurationMax.getValue() > 0); final Measurement payloadSizeMax = MetricsTestUtil.getMeasurementFromList( payloadSizeSummaryMeasurements, Statistic.MAX); - Assertions.assertEquals(testPayloadSize, payloadSizeMax.getValue()); + assertEquals(testPayloadSize, payloadSizeMax.getValue()); } @Test @@ -473,17 +492,17 @@ public void testHTTPJsonResponse408() { // verify metrics final Measurement requestReceivedCount = MetricsTestUtil.getMeasurementFromList( requestsReceivedMeasurements, Statistic.COUNT); - Assertions.assertEquals(2.0, requestReceivedCount.getValue()); + assertEquals(2.0, requestReceivedCount.getValue()); final Measurement successRequestsCount = MetricsTestUtil.getMeasurementFromList( successRequestsMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, successRequestsCount.getValue()); + assertEquals(1.0, successRequestsCount.getValue()); final Measurement requestTimeoutsCount = MetricsTestUtil.getMeasurementFromList( requestTimeoutsMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, requestTimeoutsCount.getValue()); + assertEquals(1.0, requestTimeoutsCount.getValue()); final Measurement requestProcessDurationMax = MetricsTestUtil.getMeasurementFromList( requestProcessDurationMeasurements, Statistic.MAX); final double maxDurationInMillis = 1000 * requestProcessDurationMax.getValue(); - Assertions.assertTrue(maxDurationInMillis > bufferTimeoutInMillis); + assertTrue(maxDurationInMillis > bufferTimeoutInMillis); } @Test @@ -520,7 +539,7 @@ public void testHTTPJsonResponse429() throws InterruptedException { // Set the client timeout to be less than source serverTimeoutInMillis / (testMaxPendingRequests + testThreadCount) WebClient testWebClient = WebClient.builder().responseTimeoutMillis(clientTimeoutInMillis).build(); for (int i = 0; i < testMaxPendingRequests + testThreadCount; i++) { - CompletionException actualException = Assertions.assertThrows( + CompletionException actualException = assertThrows( CompletionException.class, () -> testWebClient.execute(testRequestHeaders, testHttpData).aggregate().join()); assertThat(actualException.getCause(), instanceOf(ResponseTimeoutException.class)); } @@ -532,19 +551,19 @@ public void testHTTPJsonResponse429() throws InterruptedException { // Wait until source server timeout a request processing thread Thread.sleep(serverTimeoutInMillis); // New request should timeout instead of being rejected - CompletionException actualException = Assertions.assertThrows( + CompletionException actualException = assertThrows( CompletionException.class, () -> testWebClient.execute(testRequestHeaders, testHttpData).aggregate().join()); assertThat(actualException.getCause(), instanceOf(ResponseTimeoutException.class)); // verify metrics final Measurement requestReceivedCount = MetricsTestUtil.getMeasurementFromList( requestsReceivedMeasurements, Statistic.COUNT); - Assertions.assertEquals(testMaxPendingRequests + testThreadCount + 2, requestReceivedCount.getValue()); + assertEquals(testMaxPendingRequests + testThreadCount + 2, requestReceivedCount.getValue()); final Measurement successRequestsCount = MetricsTestUtil.getMeasurementFromList( successRequestsMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, successRequestsCount.getValue()); + assertEquals(1.0, successRequestsCount.getValue()); final Measurement rejectedRequestsCount = MetricsTestUtil.getMeasurementFromList( rejectedRequestsMeasurements, Statistic.COUNT); - Assertions.assertEquals(1.0, rejectedRequestsCount.getValue()); + assertEquals(1.0, rejectedRequestsCount.getValue()); } @Test @@ -556,7 +575,7 @@ public void testServerConnectionsMetric() throws InterruptedException { // Verify connections metric value is 0 Measurement serverConnectionsMeasurement = MetricsTestUtil.getMeasurementFromList(serverConnectionsMeasurements, Statistic.VALUE); - Assertions.assertEquals(0, serverConnectionsMeasurement.getValue()); + assertEquals(0, serverConnectionsMeasurement.getValue()); final RequestHeaders testRequestHeaders = RequestHeaders.builder().scheme(SessionProtocol.HTTP) .authority("127.0.0.1:2021") @@ -572,7 +591,7 @@ public void testServerConnectionsMetric() throws InterruptedException { // Verify connections metric value is 1 serverConnectionsMeasurement = MetricsTestUtil.getMeasurementFromList(serverConnectionsMeasurements, Statistic.VALUE); - Assertions.assertEquals(1.0, serverConnectionsMeasurement.getValue()); + assertEquals(1.0, serverConnectionsMeasurement.getValue()); } @Test @@ -603,6 +622,74 @@ public void testServerStartCertFileSuccess() throws IOException { } } + @Test + public void testServerStartCertFileMissing() { + when(sourceConfig.isSsl()).thenReturn(true); + when(sourceConfig.getSslCertificateFile()).thenReturn(null); + when(sourceConfig.getSslKeyFile()).thenReturn(null); + HTTPSourceUnderTest = new HTTPSource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + assertThrows(NullPointerException.class, () -> HTTPSourceUnderTest.start(testBuffer)); + } + } + + @Test + void testServerStartACMCertSuccess() throws IOException, NoSuchFieldException, IllegalAccessException { + final Path certFilePath = new File(TEST_SSL_CERTIFICATE_FILE).toPath(); + final Path keyFilePath = new File(TEST_SSL_KEY_FILE).toPath(); + final String certAsString = Files.readString(certFilePath); + final String keyAsString = Files.readString(keyFilePath); + + when(certificate.getCertificate()).thenReturn(certAsString); + when(certificate.getPrivateKey()).thenReturn(keyAsString); + when(certificateProvider.getCertificate()).thenReturn(certificate); + when(certificateProviderFactory.getCertificateProvider()).thenReturn(certificateProvider); + when(sourceConfig.isSsl()).thenReturn(true); + when(server.stop()).thenReturn(completableFuture); + + HTTPSourceUnderTest = new HTTPSource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + + Field field = HTTPSourceUnderTest.getClass().getDeclaredField("certificateProviderFactory"); + field.setAccessible(true); + field.set(HTTPSourceUnderTest, certificateProviderFactory); + + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + HTTPSourceUnderTest.start(testBuffer); + } + HTTPSourceUnderTest.stop(); + + final ArgumentCaptor certificateIs = ArgumentCaptor.forClass(InputStream.class); + final ArgumentCaptor privateKeyIs = ArgumentCaptor.forClass(InputStream.class); + verify(serverBuilder).tls(certificateIs.capture(), privateKeyIs.capture()); + final String actualCertificate = IOUtils.toString(certificateIs.getValue(), StandardCharsets.UTF_8.name()); + final String actualPrivateKey = IOUtils.toString(privateKeyIs.getValue(), StandardCharsets.UTF_8.name()); + assertThat(actualCertificate, is(certAsString)); + assertThat(actualPrivateKey, is(keyAsString)); + } + + @Test + void testServerStartACMCertNull() throws NoSuchFieldException, IllegalAccessException { + when(certificate.getCertificate()).thenReturn(null); + when(certificateProvider.getCertificate()).thenReturn(certificate); + when(certificateProviderFactory.getCertificateProvider()).thenReturn(certificateProvider); + when(sourceConfig.isSsl()).thenReturn(true); + + HTTPSourceUnderTest = new HTTPSource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + + Field field = HTTPSourceUnderTest.getClass().getDeclaredField("certificateProviderFactory"); + field.setAccessible(true); + field.set(HTTPSourceUnderTest, certificateProviderFactory); + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + assertThrows(NullPointerException.class, () -> HTTPSourceUnderTest.start(testBuffer)); + } + } + + + @Test void testHTTPSJsonResponse() { reset(sourceConfig); @@ -632,6 +719,43 @@ void testHTTPSJsonResponse() { .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.OK)).join(); } + + @Test + void testHTTPRequestWhenSSLRequiredNoResponse() { + reset(sourceConfig); + when(sourceConfig.getPort()).thenReturn(2021); + when(sourceConfig.getPath()).thenReturn(HTTPSourceConfig.DEFAULT_LOG_INGEST_URI); + when(sourceConfig.getThreadCount()).thenReturn(200); + when(sourceConfig.getMaxConnectionCount()).thenReturn(500); + when(sourceConfig.getMaxPendingRequests()).thenReturn(1024); + when(sourceConfig.getRequestTimeoutInMillis()).thenReturn(200); + when(sourceConfig.isSsl()).thenReturn(true); + when(sourceConfig.getSslCertificateFile()).thenReturn(TEST_SSL_CERTIFICATE_FILE); + when(sourceConfig.getSslKeyFile()).thenReturn(TEST_SSL_KEY_FILE); + HTTPSourceUnderTest = new HTTPSource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + + testBuffer = getBuffer(); + HTTPSourceUnderTest.start(testBuffer); + + CompletableFuture future = WebClient.builder() + .factory(ClientFactory.insecure()) + .build() + .execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:2021") + .method(HttpMethod.POST) + .path("/log/ingest") + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.ofUtf8("[{\"log\": \"somelog\"}]")) + .aggregate(); + + ExecutionException exception = assertThrows(ExecutionException.class, + () -> future.get(2, TimeUnit.SECONDS) + ); + assertInstanceOf(ClosedSessionException.class, exception.getCause()); + } + @Test void testHTTPSJsonResponse_with_custom_path_along_with_placeholder() { reset(sourceConfig); @@ -669,13 +793,13 @@ public void testDoubleStart() { // starting server HTTPSourceUnderTest.start(testBuffer); // double start server - Assertions.assertThrows(IllegalStateException.class, () -> HTTPSourceUnderTest.start(testBuffer)); + assertThrows(IllegalStateException.class, () -> HTTPSourceUnderTest.start(testBuffer)); } @Test public void testStartWithEmptyBuffer() { final HTTPSource source = new HTTPSource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - Assertions.assertThrows(IllegalStateException.class, () -> source.start(null)); + assertThrows(IllegalStateException.class, () -> source.start(null)); } @Test @@ -687,7 +811,7 @@ public void testStartWithServerExecutionExceptionNoCause() throws ExecutionExcep when(completableFuture.get()).thenThrow(new ExecutionException("", null)); // When/Then - Assertions.assertThrows(RuntimeException.class, () -> source.start(testBuffer)); + assertThrows(RuntimeException.class, () -> source.start(testBuffer)); } } @@ -701,8 +825,8 @@ public void testStartWithServerExecutionExceptionWithCause() throws ExecutionExc when(completableFuture.get()).thenThrow(new ExecutionException("", expCause)); // When/Then - final RuntimeException ex = Assertions.assertThrows(RuntimeException.class, () -> source.start(testBuffer)); - Assertions.assertEquals(expCause, ex); + final RuntimeException ex = assertThrows(RuntimeException.class, () -> source.start(testBuffer)); + assertEquals(expCause, ex); } } @@ -715,8 +839,8 @@ public void testStartWithInterruptedException() throws ExecutionException, Inter when(completableFuture.get()).thenThrow(new InterruptedException()); // When/Then - Assertions.assertThrows(RuntimeException.class, () -> source.start(testBuffer)); - Assertions.assertTrue(Thread.interrupted()); + assertThrows(RuntimeException.class, () -> source.start(testBuffer)); + assertTrue(Thread.interrupted()); } } @@ -731,7 +855,7 @@ public void testStopWithServerExecutionExceptionNoCause() throws ExecutionExcept // When/Then when(completableFuture.get()).thenThrow(new ExecutionException("", null)); - Assertions.assertThrows(RuntimeException.class, source::stop); + assertThrows(RuntimeException.class, source::stop); } } @@ -747,8 +871,8 @@ public void testStopWithServerExecutionExceptionWithCause() throws ExecutionExce when(completableFuture.get()).thenThrow(new ExecutionException("", expCause)); // When/Then - final RuntimeException ex = Assertions.assertThrows(RuntimeException.class, source::stop); - Assertions.assertEquals(expCause, ex); + final RuntimeException ex = assertThrows(RuntimeException.class, source::stop); + assertEquals(expCause, ex); } } @@ -763,8 +887,8 @@ public void testStopWithInterruptedException() throws ExecutionException, Interr when(completableFuture.get()).thenThrow(new InterruptedException()); // When/Then - Assertions.assertThrows(RuntimeException.class, source::stop); - Assertions.assertTrue(Thread.interrupted()); + assertThrows(RuntimeException.class, source::stop); + assertTrue(Thread.interrupted()); } } @@ -775,7 +899,7 @@ public void testRunAnotherSourceWithSamePort() { final HTTPSource secondSource = new HTTPSource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); //Expect RuntimeException because when port is already in use, BindException is thrown which is not RuntimeException - Assertions.assertThrows(RuntimeException.class, () -> secondSource.start(testBuffer)); + assertThrows(RuntimeException.class, () -> secondSource.start(testBuffer)); } @Test @@ -802,7 +926,7 @@ public void request_that_exceeds_maxRequestLength_returns_413() { .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.REQUEST_ENTITY_TOO_LARGE)).join(); // Then - Assertions.assertTrue(testBuffer.isEmpty()); + assertTrue(testBuffer.isEmpty()); } } diff --git a/data-prepper-plugins/otel-logs-source/src/test/java/org/opensearch/dataprepper/plugins/source/otellogs/OTelLogsSourceTest.java b/data-prepper-plugins/otel-logs-source/src/test/java/org/opensearch/dataprepper/plugins/source/otellogs/OTelLogsSourceTest.java index 2ce4daba91..4bcacdaa99 100644 --- a/data-prepper-plugins/otel-logs-source/src/test/java/org/opensearch/dataprepper/plugins/source/otellogs/OTelLogsSourceTest.java +++ b/data-prepper-plugins/otel-logs-source/src/test/java/org/opensearch/dataprepper/plugins/source/otellogs/OTelLogsSourceTest.java @@ -10,6 +10,7 @@ import com.google.protobuf.util.JsonFormat; import com.linecorp.armeria.client.ClientFactory; import com.linecorp.armeria.client.Clients; +import com.linecorp.armeria.client.UnprocessedRequestException; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.HttpData; @@ -91,6 +92,7 @@ import java.util.StringJoiner; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -105,6 +107,7 @@ import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.params.provider.Arguments.arguments; @@ -267,6 +270,41 @@ void testHttpsFullJsonWithNonUnframedRequests() throws InvalidProtocolBufferExce .join(); } + @Test + void testHttpRequestWhenSSLRequiredNoResponse() throws InvalidProtocolBufferException { + final Map settingsMap = new HashMap<>(); + settingsMap.put("request_timeout", 5); + settingsMap.put(SSL, true); + settingsMap.put("useAcmCertForSSL", false); + settingsMap.put("sslKeyCertChainFile", "data/certificate/test_cert.crt"); + settingsMap.put("sslKeyFile", "data/certificate/test_decrypted_key.key"); + pluginSetting = new PluginSetting("otel_logs", settingsMap); + pluginSetting.setPipelineName("pipeline"); + + oTelLogsSourceConfig = OBJECT_MAPPER.convertValue(pluginSetting.getSettings(), OTelLogsSourceConfig.class); + SOURCE = new OTelLogsSource(oTelLogsSourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + + SOURCE.start(buffer); + + CompletableFuture future = WebClient.builder() + .factory(ClientFactory.insecure()) + .build() + .execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:2021") + .method(HttpMethod.POST) + .path("/opentelemetry.proto.collector.logs.v1.LogsService/Export") + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(LOGS_REQUEST).getBytes())) + .aggregate(); + + ExecutionException exception = assertThrows(ExecutionException.class, + () -> future.get(2, TimeUnit.SECONDS) + ); + assertInstanceOf(UnprocessedRequestException.class, exception.getCause()); + } + @Test void testHttpFullBytesWithNonUnframedRequests() { SOURCE.start(buffer); @@ -770,6 +808,40 @@ void gRPC_with_auth_request_writes_to_buffer_with_successful_response() throws E assertThat(actualBufferWrites, hasSize(1)); } + @Test + void gRPC_with_auth_request_with_different_basic_auth_credentials_does_not_write_to_buffer_with_401_response() throws Exception { + when(httpBasicAuthenticationConfig.getUsername()).thenReturn(USERNAME); + when(httpBasicAuthenticationConfig.getPassword()).thenReturn(PASSWORD); + final GrpcAuthenticationProvider grpcAuthenticationProvider = new GrpcBasicAuthenticationProvider(httpBasicAuthenticationConfig); + + when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) + .thenReturn(grpcAuthenticationProvider); + when(oTelLogsSourceConfig.enableUnframedRequests()).thenReturn(true); + when(oTelLogsSourceConfig.getAuthentication()).thenReturn(new PluginModel("http_basic", + Map.of( + "username", USERNAME, + "password", PASSWORD + ))); + configureObjectUnderTest(); + SOURCE.start(buffer); + + final String wrongCredentials = Base64.getEncoder() + .encodeToString(String.format("%s:%s", "wrong Username", "wrong Password").getBytes(StandardCharsets.UTF_8)); + + final LogsServiceGrpc.LogsServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) + .addHeader("Authorization", "Basic " + wrongCredentials) + .build(LogsServiceGrpc.LogsServiceBlockingStub.class); + + StatusRuntimeException exception = assertThrows( + StatusRuntimeException.class, + () -> client.export(createExportLogsRequest()) + ); + + assertEquals(Status.Code.UNAUTHENTICATED, exception.getStatus().getCode()); + + verify(buffer, never()).writeAll(any(), anyInt()); + } + @Test void gRPC_request_with_custom_path_throws_when_written_to_default_path() { when(oTelLogsSourceConfig.getPath()).thenReturn(TEST_PATH);