diff --git a/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilter.java b/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilter.java index 3650ae76..bee53517 100644 --- a/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilter.java +++ b/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilter.java @@ -56,7 +56,7 @@ public AWSXRayServletFilter(String fixedSegmentName) { } public AWSXRayServletFilter(SegmentNamingStrategy segmentNamingStrategy) { - this(segmentNamingStrategy, AWSXRay.getGlobalRecorder()); + this(segmentNamingStrategy, null); } public AWSXRayServletFilter(SegmentNamingStrategy segmentNamingStrategy, AWSXRayRecorder recorder) { @@ -245,6 +245,7 @@ private String getSegmentName(HttpServletRequest httpServletRequest) { } private SamplingResponse fromSamplingStrategy(HttpServletRequest httpServletRequest) { + AWSXRayRecorder recorder = getRecorder(); SamplingRequest samplingRequest = new SamplingRequest(getSegmentName(httpServletRequest), getHost(httpServletRequest).orElse(null), httpServletRequest.getRequestURI(), httpServletRequest.getMethod(), recorder.getOrigin()); SamplingResponse sample = recorder.getSamplingStrategy().shouldTrace(samplingRequest); return sample; @@ -260,7 +261,15 @@ private SampleDecision getSampleDecision(SamplingResponse sample) { } } + private AWSXRayRecorder getRecorder() { + if (recorder == null) { + recorder = AWSXRay.getGlobalRecorder(); + } + return recorder; + } + public Segment preFilter(ServletRequest request, ServletResponse response) { + AWSXRayRecorder recorder = getRecorder(); Segment created = null; HttpServletRequest httpServletRequest = castServletRequest(request); if (null == httpServletRequest) { @@ -348,6 +357,7 @@ public Segment preFilter(ServletRequest request, ServletResponse response) { } public void postFilter(ServletRequest request, ServletResponse response) { + AWSXRayRecorder recorder = getRecorder(); Segment segment = recorder.getCurrentSegment(); if (null != segment) { HttpServletResponse httpServletResponse = castServletResponse(response); @@ -397,7 +407,15 @@ public AWSXRayServletAsyncListener(AWSXRayServletFilter filter, AWSXRayRecorder this.recorder = recorder; } + private AWSXRayRecorder getRecorder() { + if (recorder == null) { + recorder = AWSXRay.getGlobalRecorder(); + } + return recorder; + } + private void processEvent(AsyncEvent event) throws IOException { + AWSXRayRecorder recorder = getRecorder(); Entity prior = recorder.getTraceEntity(); try { Entity entity = (Entity) event.getSuppliedRequest().getAttribute(ENTITY_ATTRIBUTE_KEY); diff --git a/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilterTest.java b/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilterTest.java index 37864d87..9a5b4271 100644 --- a/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilterTest.java +++ b/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/javax/servlet/AWSXRayServletFilterTest.java @@ -9,6 +9,8 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import com.amazonaws.xray.AWSXRayRecorder; +import com.amazonaws.xray.strategy.FixedSegmentNamingStrategy; import com.amazonaws.xray.strategy.sampling.LocalizedSamplingStrategy; import org.junit.Assert; import org.junit.Before; @@ -39,12 +41,16 @@ public class AWSXRayServletFilterTest { @Before public void setupAWSXRay() { + AWSXRay.setGlobalRecorder(getMockRecorder()); + AWSXRay.clearTraceEntity(); + } + + private AWSXRayRecorder getMockRecorder() { Emitter blankEmitter = Mockito.mock(Emitter.class); LocalizedSamplingStrategy defaultSamplingStrategy = new LocalizedSamplingStrategy(); Mockito.doReturn(true).when(blankEmitter).sendSegment(Mockito.anyObject()); Mockito.doReturn(true).when(blankEmitter).sendSubsegment(Mockito.anyObject()); - AWSXRay.setGlobalRecorder(AWSXRayRecorderBuilder.standard().withEmitter(blankEmitter).withSamplingStrategy(defaultSamplingStrategy).build()); - AWSXRay.clearTraceEntity(); + return AWSXRayRecorderBuilder.standard().withEmitter(blankEmitter).withSamplingStrategy(defaultSamplingStrategy).build(); } @Test @@ -68,6 +74,73 @@ public void testAsyncServletRequestHasListenerAdded() throws IOException, Servle Mockito.verify(asyncContext, Mockito.times(1)).addListener(Mockito.any()); } + @Test + public void testServletLazilyLoadsRecorder() throws IOException, ServletException { + AWSXRayServletFilter servletFilter = new AWSXRayServletFilter("test"); + + AsyncContext asyncContext = Mockito.mock(AsyncContext.class); + AWSXRayRecorder customRecorder = getMockRecorder(); + Mockito.spy(customRecorder); + AWSXRay.setGlobalRecorder(customRecorder); + + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + Mockito.when(request.getRequestURL()).thenReturn(new StringBuffer("test_url")); + Mockito.when(request.getMethod()).thenReturn("TEST_METHOD"); + Mockito.when(request.isAsyncStarted()).thenReturn(true); + Mockito.when(request.getAsyncContext()).thenReturn(asyncContext); + + HttpServletResponse response = Mockito.mock(HttpServletResponse.class); + + FilterChain chain = Mockito.mock(FilterChain.class); + + AsyncEvent event = Mockito.mock(AsyncEvent.class); + Mockito.when(event.getSuppliedRequest()).thenReturn(request); + Mockito.when(event.getSuppliedResponse()).thenReturn(response); + + servletFilter.doFilter(request, response, chain); + + Entity currentEntity = AWSXRay.getTraceEntity(); + Mockito.when(request.getAttribute("com.amazonaws.xray.entities.Entity")).thenReturn(currentEntity); + + AWSXRayServletAsyncListener listener = (AWSXRayServletAsyncListener) Whitebox.getInternalState(servletFilter, "listener"); + listener.onComplete(event); + + Mockito.verify(customRecorder.getEmitter(), Mockito.times(1)).sendSegment(Mockito.any()); + } + + @Test + public void testServletUsesPassedInRecorder() throws IOException, ServletException { + AWSXRayRecorder customRecorder = getMockRecorder(); + Mockito.spy(customRecorder); + AWSXRayServletFilter servletFilter = new AWSXRayServletFilter(new FixedSegmentNamingStrategy("test"), customRecorder); + + AsyncContext asyncContext = Mockito.mock(AsyncContext.class); + + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + Mockito.when(request.getRequestURL()).thenReturn(new StringBuffer("test_url")); + Mockito.when(request.getMethod()).thenReturn("TEST_METHOD"); + Mockito.when(request.isAsyncStarted()).thenReturn(true); + Mockito.when(request.getAsyncContext()).thenReturn(asyncContext); + + HttpServletResponse response = Mockito.mock(HttpServletResponse.class); + + FilterChain chain = Mockito.mock(FilterChain.class); + + AsyncEvent event = Mockito.mock(AsyncEvent.class); + Mockito.when(event.getSuppliedRequest()).thenReturn(request); + Mockito.when(event.getSuppliedResponse()).thenReturn(response); + + servletFilter.doFilter(request, response, chain); + + Entity currentEntity = AWSXRay.getTraceEntity(); + Mockito.when(request.getAttribute("com.amazonaws.xray.entities.Entity")).thenReturn(currentEntity); + + AWSXRayServletAsyncListener listener = (AWSXRayServletAsyncListener) Whitebox.getInternalState(servletFilter, "listener"); + listener.onComplete(event); + + Mockito.verify(customRecorder.getEmitter(), Mockito.times(1)).sendSegment(Mockito.any()); + } + @Test public void testAWSXRayServletAsyncListenerEmitsSegmentWhenProcessingEvent() throws IOException, ServletException { AWSXRayServletFilter servletFilter = new AWSXRayServletFilter("test");