Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): fixes custom recorder race condition with servlet filter #53

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public AWSXRayServletFilter(String fixedSegmentName) {
}

public AWSXRayServletFilter(SegmentNamingStrategy segmentNamingStrategy) {
this(segmentNamingStrategy, AWSXRay.getGlobalRecorder());
this(segmentNamingStrategy, null);
}

public AWSXRayServletFilter(SegmentNamingStrategy segmentNamingStrategy, AWSXRayRecorder recorder) {
Expand Down Expand Up @@ -245,6 +245,7 @@ private String getSegmentName(HttpServletRequest httpServletRequest) {
}

private SamplingResponse fromSamplingStrategy(HttpServletRequest httpServletRequest) {
AWSXRayRecorder recorder = getRecorder();
SamplingRequest samplingRequest = new SamplingRequest(getSegmentName(httpServletRequest), getHost(httpServletRequest).orElse(null), httpServletRequest.getRequestURI(), httpServletRequest.getMethod(), recorder.getOrigin());
SamplingResponse sample = recorder.getSamplingStrategy().shouldTrace(samplingRequest);
return sample;
Expand All @@ -260,7 +261,15 @@ private SampleDecision getSampleDecision(SamplingResponse sample) {
}
}

private AWSXRayRecorder getRecorder() {
if (recorder == null) {
recorder = AWSXRay.getGlobalRecorder();
}
return recorder;
}

public Segment preFilter(ServletRequest request, ServletResponse response) {
AWSXRayRecorder recorder = getRecorder();
Segment created = null;
HttpServletRequest httpServletRequest = castServletRequest(request);
if (null == httpServletRequest) {
Expand Down Expand Up @@ -348,6 +357,7 @@ public Segment preFilter(ServletRequest request, ServletResponse response) {
}

public void postFilter(ServletRequest request, ServletResponse response) {
AWSXRayRecorder recorder = getRecorder();
Segment segment = recorder.getCurrentSegment();
if (null != segment) {
HttpServletResponse httpServletResponse = castServletResponse(response);
Expand Down Expand Up @@ -397,7 +407,15 @@ public AWSXRayServletAsyncListener(AWSXRayServletFilter filter, AWSXRayRecorder
this.recorder = recorder;
}

private AWSXRayRecorder getRecorder() {
if (recorder == null) {
recorder = AWSXRay.getGlobalRecorder();
}
return recorder;
}

private void processEvent(AsyncEvent event) throws IOException {
AWSXRayRecorder recorder = getRecorder();
Entity prior = recorder.getTraceEntity();
try {
Entity entity = (Entity) event.getSuppliedRequest().getAttribute(ENTITY_ATTRIBUTE_KEY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.amazonaws.xray.AWSXRayRecorder;
import com.amazonaws.xray.strategy.FixedSegmentNamingStrategy;
import com.amazonaws.xray.strategy.sampling.LocalizedSamplingStrategy;
import org.junit.Assert;
import org.junit.Before;
Expand Down Expand Up @@ -39,12 +41,16 @@ public class AWSXRayServletFilterTest {

@Before
public void setupAWSXRay() {
AWSXRay.setGlobalRecorder(getMockRecorder());
AWSXRay.clearTraceEntity();
}

private AWSXRayRecorder getMockRecorder() {
Emitter blankEmitter = Mockito.mock(Emitter.class);
LocalizedSamplingStrategy defaultSamplingStrategy = new LocalizedSamplingStrategy();
Mockito.doReturn(true).when(blankEmitter).sendSegment(Mockito.anyObject());
Mockito.doReturn(true).when(blankEmitter).sendSubsegment(Mockito.anyObject());
AWSXRay.setGlobalRecorder(AWSXRayRecorderBuilder.standard().withEmitter(blankEmitter).withSamplingStrategy(defaultSamplingStrategy).build());
AWSXRay.clearTraceEntity();
return AWSXRayRecorderBuilder.standard().withEmitter(blankEmitter).withSamplingStrategy(defaultSamplingStrategy).build();
}

@Test
Expand All @@ -68,6 +74,73 @@ public void testAsyncServletRequestHasListenerAdded() throws IOException, Servle
Mockito.verify(asyncContext, Mockito.times(1)).addListener(Mockito.any());
}

@Test
public void testServletLazilyLoadsRecorder() throws IOException, ServletException {
AWSXRayServletFilter servletFilter = new AWSXRayServletFilter("test");

AsyncContext asyncContext = Mockito.mock(AsyncContext.class);
AWSXRayRecorder customRecorder = getMockRecorder();
Mockito.spy(customRecorder);
AWSXRay.setGlobalRecorder(customRecorder);

HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
Mockito.when(request.getRequestURL()).thenReturn(new StringBuffer("test_url"));
Mockito.when(request.getMethod()).thenReturn("TEST_METHOD");
Mockito.when(request.isAsyncStarted()).thenReturn(true);
Mockito.when(request.getAsyncContext()).thenReturn(asyncContext);

HttpServletResponse response = Mockito.mock(HttpServletResponse.class);

FilterChain chain = Mockito.mock(FilterChain.class);

AsyncEvent event = Mockito.mock(AsyncEvent.class);
Mockito.when(event.getSuppliedRequest()).thenReturn(request);
Mockito.when(event.getSuppliedResponse()).thenReturn(response);

servletFilter.doFilter(request, response, chain);

Entity currentEntity = AWSXRay.getTraceEntity();
Mockito.when(request.getAttribute("com.amazonaws.xray.entities.Entity")).thenReturn(currentEntity);

AWSXRayServletAsyncListener listener = (AWSXRayServletAsyncListener) Whitebox.getInternalState(servletFilter, "listener");
listener.onComplete(event);

Mockito.verify(customRecorder.getEmitter(), Mockito.times(1)).sendSegment(Mockito.any());
}

@Test
public void testServletUsesPassedInRecorder() throws IOException, ServletException {
AWSXRayRecorder customRecorder = getMockRecorder();
Mockito.spy(customRecorder);
AWSXRayServletFilter servletFilter = new AWSXRayServletFilter(new FixedSegmentNamingStrategy("test"), customRecorder);

AsyncContext asyncContext = Mockito.mock(AsyncContext.class);

HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
Mockito.when(request.getRequestURL()).thenReturn(new StringBuffer("test_url"));
Mockito.when(request.getMethod()).thenReturn("TEST_METHOD");
Mockito.when(request.isAsyncStarted()).thenReturn(true);
Mockito.when(request.getAsyncContext()).thenReturn(asyncContext);

HttpServletResponse response = Mockito.mock(HttpServletResponse.class);

FilterChain chain = Mockito.mock(FilterChain.class);

AsyncEvent event = Mockito.mock(AsyncEvent.class);
Mockito.when(event.getSuppliedRequest()).thenReturn(request);
Mockito.when(event.getSuppliedResponse()).thenReturn(response);

servletFilter.doFilter(request, response, chain);

Entity currentEntity = AWSXRay.getTraceEntity();
Mockito.when(request.getAttribute("com.amazonaws.xray.entities.Entity")).thenReturn(currentEntity);

AWSXRayServletAsyncListener listener = (AWSXRayServletAsyncListener) Whitebox.getInternalState(servletFilter, "listener");
listener.onComplete(event);

Mockito.verify(customRecorder.getEmitter(), Mockito.times(1)).sendSegment(Mockito.any());
}

@Test
public void testAWSXRayServletAsyncListenerEmitsSegmentWhenProcessingEvent() throws IOException, ServletException {
AWSXRayServletFilter servletFilter = new AWSXRayServletFilter("test");
Expand Down