diff --git a/src/main/java/edu/ohsu/cmp/ecp/sds/SupplementalDataStorePartitioningConfig.java b/src/main/java/edu/ohsu/cmp/ecp/sds/SupplementalDataStorePartitioningConfig.java index 8d8f7c9..50b4711 100644 --- a/src/main/java/edu/ohsu/cmp/ecp/sds/SupplementalDataStorePartitioningConfig.java +++ b/src/main/java/edu/ohsu/cmp/ecp/sds/SupplementalDataStorePartitioningConfig.java @@ -4,12 +4,20 @@ import ca.uhn.fhir.jpa.model.config.PartitionSettings.CrossPartitionReferenceMode; import ca.uhn.fhir.jpa.searchparam.matcher.AuthorizationSearchParamMatcher; import ca.uhn.fhir.jpa.searchparam.matcher.SearchParamMatcher; +import ca.uhn.fhir.jpa.starter.AppProperties; +import ca.uhn.fhir.jpa.starter.annotations.OnCorsPresent; import ca.uhn.fhir.rest.server.RestfulServer; +import ca.uhn.fhir.rest.server.interceptor.CorsInterceptor; import ca.uhn.fhir.rest.server.interceptor.auth.IAuthorizationSearchParamMatcher; import ca.uhn.fhir.rest.server.interceptor.consent.ConsentInterceptor; import ca.uhn.fhir.rest.server.interceptor.consent.RuleFilteringConsentService; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Conditional; import org.springframework.context.annotation.Configuration; +import org.springframework.web.cors.CorsConfiguration; + +import java.util.Optional; import javax.annotation.PostConstruct; import javax.inject.Inject; @@ -26,6 +34,9 @@ public class SupplementalDataStorePartitioningConfig { @Inject PartitionSettings partitionSettings; + @Inject() + Optional corsInterceptor; + @Inject SupplementalDataStoreAuthorizationInterceptor authorizationInterceptor; @@ -49,6 +60,11 @@ public void configurePartitioning() { partitionSettings.setPartitioningEnabled(true); partitionSettings.setAllowReferencesAcrossPartitions(CrossPartitionReferenceMode.ALLOWED_UNQUALIFIED); server.registerInterceptor(partitionInterceptor); + + corsInterceptor.ifPresent( cors -> { + CorsConfiguration config = cors.getConfig() ; + config.addAllowedHeader( sdsProperties.getPartition().getHttpHeader() ); + }) ; } @PostConstruct diff --git a/src/test/java/edu/ohsu/cmp/ecp/sds/BaseSuppplementalDataStoreTest.java b/src/test/java/edu/ohsu/cmp/ecp/sds/BaseSuppplementalDataStoreTest.java index 5b16006..b84099e 100644 --- a/src/test/java/edu/ohsu/cmp/ecp/sds/BaseSuppplementalDataStoreTest.java +++ b/src/test/java/edu/ohsu/cmp/ecp/sds/BaseSuppplementalDataStoreTest.java @@ -2,13 +2,24 @@ import static java.util.stream.Collectors.toList; +import java.io.IOException; import java.lang.reflect.Method; +import java.net.URI; import java.util.Calendar; import java.util.Date; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.function.Function; +import junit.framework.AssertionFailedError; + +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.HttpResponse; +import org.apache.http.client.ClientProtocolException; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpUriRequest; import org.hl7.fhir.instance.model.api.IBaseResource; import org.hl7.fhir.instance.model.api.IIdType; import org.hl7.fhir.r4.model.Bundle; @@ -28,7 +39,9 @@ import ca.uhn.fhir.jpa.starter.Application; import ca.uhn.fhir.jpa.starter.JpaStarterWebsocketDispatcherConfig; import ca.uhn.fhir.parser.IParser; +import ca.uhn.fhir.rest.client.apache.ApacheRestfulClientFactory; import ca.uhn.fhir.rest.client.api.IGenericClient; +import ca.uhn.fhir.rest.client.api.IRestfulClientFactory; import ca.uhn.fhir.rest.client.api.ServerValidationModeEnum; import ca.uhn.fhir.rest.client.interceptor.BearerTokenAuthInterceptor; import ca.uhn.fhir.util.BundleBuilder; @@ -82,7 +95,7 @@ protected String fhirServerlBase() { private int testSpecificIdCount ; private String testSpecificIdBase ; - + protected String createTestSpecificId() { return String.format( "%1$s-%2$03d", testSpecificIdBase, testSpecificIdCount++ ) ; } @@ -99,7 +112,40 @@ public void resetTestSpecificIdComponents(TestInfo testInfo) { testNameHashCode ) ; } + + protected HttpResponse executeRequest( Function requestFactory, String relativePath ) { + return executeRequest( requestFactory, URI.create(relativePath) ) ; + } + + protected HttpResponse executeRequest( Function requestFactory, String relativePath, Function configurer ) { + return executeRequest( requestFactory, URI.create(relativePath), configurer ) ; + } + protected HttpResponse executeRequest( Function requestFactory, URI relativeUri ) { + return executeRequest( requestFactory, relativeUri, q -> q ) ; + } + + protected HttpResponse executeRequest( Function requestFactory, URI relativeUri, Function configurer ) { + URI baseUri = URI.create(ourServerBase) ; + URI uri = baseUri.resolve( relativeUri ) ; + T request = requestFactory.apply( uri ) ; + T configuredRequest = configurer.apply( request ) ; + try { + HttpResponse response = nativeHttpClient().execute( configuredRequest ); + return response ; + } catch (IOException ex) { + throw new AssertionFailedError( "request failed: " + configuredRequest.getMethod() + " " + configuredRequest.getURI() + "\n" + ex.getMessage() ) ; + } + } + + private HttpClient nativeHttpClient() { + IRestfulClientFactory restfulClientFactory = ctx.getRestfulClientFactory(); + if ( !(restfulClientFactory instanceof ApacheRestfulClientFactory) ) + throw new AssertionFailedError() ; + ApacheRestfulClientFactory apacheRestfulClientFactory = (ApacheRestfulClientFactory)restfulClientFactory ; + return apacheRestfulClientFactory.getNativeHttpClient() ; + } + protected IGenericClient client() { return ctx.newRestfulGenericClient(ourServerBase) ; } protected IGenericClient clientTargetingPartition( String partitionName ) { diff --git a/src/test/java/edu/ohsu/cmp/ecp/sds/CorsOptionTest.java b/src/test/java/edu/ohsu/cmp/ecp/sds/CorsOptionTest.java new file mode 100644 index 0000000..8dfe918 --- /dev/null +++ b/src/test/java/edu/ohsu/cmp/ecp/sds/CorsOptionTest.java @@ -0,0 +1,141 @@ +package edu.ohsu.cmp.ecp.sds; + +import static java.util.stream.Collectors.toList; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasItem; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import org.apache.http.Header; +import org.apache.http.HttpHeaders; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpOptions; +import org.apache.http.client.methods.HttpUriRequest; +import org.hl7.fhir.instance.model.api.IIdType; +import org.hl7.fhir.r4.model.Bundle ; +import org.hl7.fhir.r4.model.CapabilityStatement; +import org.hl7.fhir.r4.model.IdType; +import org.hl7.fhir.r4.model.Linkage ; +import org.hl7.fhir.r4.model.Patient ; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ActiveProfiles; + +import ca.uhn.fhir.jpa.starter.AppTestMockPrincipalRegistry; +import ca.uhn.fhir.rest.api.MethodOutcome; + +@ActiveProfiles( "auth-aware-test") +public class CorsOptionTest extends BaseSuppplementalDataStoreTest { + + private static final String FOREIGN_PARTITION_NAME = "http://my.ehr.org/fhir/R4/" ; + + @Autowired + AppTestMockPrincipalRegistry mockPrincipalRegistry ; + + @Autowired + private SupplementalDataStoreProperties sdsProperties ; + + private IIdType authorizedPatientId ; + + private String authToken ; + + @BeforeEach + public void setupAuthorization() { + authorizedPatientId = new IdType( FOREIGN_PARTITION_NAME, "Patient", createTestSpecificId(), null ) ; + authToken = mockPrincipalRegistry.register().principal( "MyPatient", authorizedPatientId.toString() ).token() ; + } + + private T authorize( T request ) { + request.addHeader( HttpHeaders.AUTHORIZATION, "Bearer " + authToken ) ; + return request ; + } + + @Test + public void doesIncludePartitionHeaderInCorsOptionsForCapabilityStatement() { + CapabilityStatement cap = client().capabilities().ofType( CapabilityStatement.class ).execute() ; + assertThat( cap, notNullValue() ) ; + + HttpResponse resp2 = executeRequest( HttpGet::new, "metadata" ) ; + assertThat( resp2.getStatusLine().getStatusCode(), equalTo( 200 ) ) ; + + doesIncludePartitionHeaderInCorsOptions( "GET", "metadata" ); + } + + @Test + public void doesIncludePartitionHeaderInCorsOptionsForLinkages() { + Bundle linkageBundle = + authenticatingClient( authToken ) + .search() + .forResource( Linkage.class ) + .where( Linkage.ITEM.hasId( authorizedPatientId.toUnqualifiedVersionless() ) ) + .returnBundle( Bundle.class ) + .execute() + ; + assertThat( linkageBundle, notNullValue() ) ; + + doesIncludePartitionHeaderInCorsOptions( "GET", "Linkage?item=" + authorizedPatientId.toUnqualifiedVersionless() ); + } + + @Test + public void doesIncludePartitionHeaderInCorsOptionsForLinkagesInForeignPartition() { + Bundle linkageBundle = + authenticatingClientTargetingPartition( authToken, FOREIGN_PARTITION_NAME ) + .search() + .forResource( Linkage.class ) + .where( Linkage.ITEM.hasId( authorizedPatientId.toUnqualifiedVersionless() ) ) + .returnBundle( Bundle.class ) + .execute() + ; + assertThat( linkageBundle, notNullValue() ) ; + + doesIncludePartitionHeaderInCorsOptions( "GET", "Linkage?item=" + authorizedPatientId.toUnqualifiedVersionless() ); + } + + @Test + public void doesIncludePartitionHeaderInCorsOptionsForPatientInForeignPartition() { + MethodOutcome outcome = + authenticatingClientTargetingPartition( authToken, FOREIGN_PARTITION_NAME ) + .update() + .resource( new Patient().setId( authorizedPatientId.toUnqualifiedVersionless() ) ) + .execute() + ; + assertThat( outcome, notNullValue() ) ; + + doesIncludePartitionHeaderInCorsOptions( "PUT", "Patient/" + authorizedPatientId.getIdPart() ); + } + + private void doesIncludePartitionHeaderInCorsOptions( String requestMethod, String relativePath ) { + final String sdsPartitionHeaderName = sdsProperties.getPartition().getHttpHeader(); + + HttpResponse resp = executeRequest( corsRequest(requestMethod), relativePath ) ; + assertThat( resp.getStatusLine().getStatusCode(), equalTo( 200 ) ) ; + List valuesOfAccessControlAllowHeaders = + Arrays.stream( resp.getHeaders( "Access-Control-Allow-Headers" ) ) + .map( Header::getValue ) + .map( v -> v.split(",") ) + .flatMap( Arrays::stream ) + .map( String::trim ) + .map( String::toLowerCase ) + .collect( toList() ) ; + assertThat( valuesOfAccessControlAllowHeaders, hasItem( sdsPartitionHeaderName.toLowerCase() ) ) ; + } + + private Function corsRequest( String requestMethod ) { + final String sdsPartitionHeaderName = sdsProperties.getPartition().getHttpHeader(); + + return uri -> { + HttpOptions optionsRequest = new HttpOptions( uri ) ; + optionsRequest.addHeader( "Origin", "https://10.0.1.1" ) ; + optionsRequest.addHeader( "Access-Control-Request-Method", requestMethod ) ; + optionsRequest.addHeader( "Access-Control-Request-Headers", "content-type, authorization, " + sdsPartitionHeaderName ) ; + return optionsRequest ; + } ; + } +} diff --git a/src/test/resources/application.yaml b/src/test/resources/application.yaml index b25036a..5ba1c54 100644 --- a/src/test/resources/application.yaml +++ b/src/test/resources/application.yaml @@ -122,11 +122,11 @@ hapi: # partitioning: # allow_references_across_partitions: false # partitioning_include_in_search_hashes: false - #cors: - # allow_Credentials: true + cors: + allow_Credentials: true # These are allowed_origin patterns, see: https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/cors/CorsConfiguration.html#setAllowedOriginPatterns-java.util.List- - # allowed_origin: - # - '*' + allowed_origin: + - '*' # Search coordinator thread pool sizes search-coord-core-pool-size: 20