Skip to content

Commit

Permalink
allow partition http header on CORS requests
Browse files Browse the repository at this point in the history
  • Loading branch information
timcoffman committed Aug 27, 2024
1 parent b53ad92 commit 73c8542
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
import ca.uhn.fhir.jpa.model.config.PartitionSettings.CrossPartitionReferenceMode;
import ca.uhn.fhir.jpa.searchparam.matcher.AuthorizationSearchParamMatcher;
import ca.uhn.fhir.jpa.searchparam.matcher.SearchParamMatcher;
import ca.uhn.fhir.jpa.starter.AppProperties;
import ca.uhn.fhir.jpa.starter.annotations.OnCorsPresent;
import ca.uhn.fhir.rest.server.RestfulServer;
import ca.uhn.fhir.rest.server.interceptor.CorsInterceptor;
import ca.uhn.fhir.rest.server.interceptor.auth.IAuthorizationSearchParamMatcher;
import ca.uhn.fhir.rest.server.interceptor.consent.ConsentInterceptor;
import ca.uhn.fhir.rest.server.interceptor.consent.RuleFilteringConsentService;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Conditional;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.cors.CorsConfiguration;

import java.util.Optional;

import javax.annotation.PostConstruct;
import javax.inject.Inject;
Expand All @@ -26,6 +34,9 @@ public class SupplementalDataStorePartitioningConfig {
@Inject
PartitionSettings partitionSettings;

@Inject()
Optional<CorsInterceptor> corsInterceptor;

@Inject
SupplementalDataStoreAuthorizationInterceptor authorizationInterceptor;

Expand All @@ -49,6 +60,11 @@ public void configurePartitioning() {
partitionSettings.setPartitioningEnabled(true);
partitionSettings.setAllowReferencesAcrossPartitions(CrossPartitionReferenceMode.ALLOWED_UNQUALIFIED);
server.registerInterceptor(partitionInterceptor);

corsInterceptor.ifPresent( cors -> {
CorsConfiguration config = cors.getConfig() ;
config.addAllowedHeader( sdsProperties.getPartition().getHttpHeader() );
}) ;
}

@PostConstruct
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@

import static java.util.stream.Collectors.toList;

import java.io.IOException;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.Calendar;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

import junit.framework.AssertionFailedError;

import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpUriRequest;
import org.hl7.fhir.instance.model.api.IBaseResource;
import org.hl7.fhir.instance.model.api.IIdType;
import org.hl7.fhir.r4.model.Bundle;
Expand All @@ -28,7 +39,9 @@
import ca.uhn.fhir.jpa.starter.Application;
import ca.uhn.fhir.jpa.starter.JpaStarterWebsocketDispatcherConfig;
import ca.uhn.fhir.parser.IParser;
import ca.uhn.fhir.rest.client.apache.ApacheRestfulClientFactory;
import ca.uhn.fhir.rest.client.api.IGenericClient;
import ca.uhn.fhir.rest.client.api.IRestfulClientFactory;
import ca.uhn.fhir.rest.client.api.ServerValidationModeEnum;
import ca.uhn.fhir.rest.client.interceptor.BearerTokenAuthInterceptor;
import ca.uhn.fhir.util.BundleBuilder;
Expand Down Expand Up @@ -82,7 +95,7 @@ protected String fhirServerlBase() {

private int testSpecificIdCount ;
private String testSpecificIdBase ;

protected String createTestSpecificId() {
return String.format( "%1$s-%2$03d", testSpecificIdBase, testSpecificIdCount++ ) ;
}
Expand All @@ -99,7 +112,40 @@ public void resetTestSpecificIdComponents(TestInfo testInfo) {
testNameHashCode
) ;
}

protected <T extends HttpUriRequest> HttpResponse executeRequest( Function<URI,T> requestFactory, String relativePath ) {
return executeRequest( requestFactory, URI.create(relativePath) ) ;
}

protected <T extends HttpUriRequest> HttpResponse executeRequest( Function<URI,T> requestFactory, String relativePath, Function<T,T> configurer ) {
return executeRequest( requestFactory, URI.create(relativePath), configurer ) ;
}

protected <T extends HttpUriRequest> HttpResponse executeRequest( Function<URI,T> requestFactory, URI relativeUri ) {
return executeRequest( requestFactory, relativeUri, q -> q ) ;
}

protected <T extends HttpUriRequest> HttpResponse executeRequest( Function<URI,T> requestFactory, URI relativeUri, Function<T,T> configurer ) {
URI baseUri = URI.create(ourServerBase) ;
URI uri = baseUri.resolve( relativeUri ) ;
T request = requestFactory.apply( uri ) ;
T configuredRequest = configurer.apply( request ) ;
try {
HttpResponse response = nativeHttpClient().execute( configuredRequest );
return response ;
} catch (IOException ex) {
throw new AssertionFailedError( "request failed: " + configuredRequest.getMethod() + " " + configuredRequest.getURI() + "\n" + ex.getMessage() ) ;
}
}

private HttpClient nativeHttpClient() {
IRestfulClientFactory restfulClientFactory = ctx.getRestfulClientFactory();
if ( !(restfulClientFactory instanceof ApacheRestfulClientFactory) )
throw new AssertionFailedError() ;
ApacheRestfulClientFactory apacheRestfulClientFactory = (ApacheRestfulClientFactory)restfulClientFactory ;
return apacheRestfulClientFactory.getNativeHttpClient() ;
}

protected IGenericClient client() { return ctx.newRestfulGenericClient(ourServerBase) ; }

protected IGenericClient clientTargetingPartition( String partitionName ) {
Expand Down
141 changes: 141 additions & 0 deletions src/test/java/edu/ohsu/cmp/ecp/sds/CorsOptionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package edu.ohsu.cmp.ecp.sds;

import static java.util.stream.Collectors.toList;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasItem;

import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

import org.apache.http.Header;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpOptions;
import org.apache.http.client.methods.HttpUriRequest;
import org.hl7.fhir.instance.model.api.IIdType;
import org.hl7.fhir.r4.model.Bundle ;
import org.hl7.fhir.r4.model.CapabilityStatement;
import org.hl7.fhir.r4.model.IdType;
import org.hl7.fhir.r4.model.Linkage ;
import org.hl7.fhir.r4.model.Patient ;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.context.ActiveProfiles;

import ca.uhn.fhir.jpa.starter.AppTestMockPrincipalRegistry;
import ca.uhn.fhir.rest.api.MethodOutcome;

@ActiveProfiles( "auth-aware-test")
public class CorsOptionTest extends BaseSuppplementalDataStoreTest {

private static final String FOREIGN_PARTITION_NAME = "http://my.ehr.org/fhir/R4/" ;

@Autowired
AppTestMockPrincipalRegistry mockPrincipalRegistry ;

@Autowired
private SupplementalDataStoreProperties sdsProperties ;

private IIdType authorizedPatientId ;

private String authToken ;

@BeforeEach
public void setupAuthorization() {
authorizedPatientId = new IdType( FOREIGN_PARTITION_NAME, "Patient", createTestSpecificId(), null ) ;
authToken = mockPrincipalRegistry.register().principal( "MyPatient", authorizedPatientId.toString() ).token() ;
}

private <T extends HttpUriRequest> T authorize( T request ) {
request.addHeader( HttpHeaders.AUTHORIZATION, "Bearer " + authToken ) ;
return request ;
}

@Test
public void doesIncludePartitionHeaderInCorsOptionsForCapabilityStatement() {
CapabilityStatement cap = client().capabilities().ofType( CapabilityStatement.class ).execute() ;
assertThat( cap, notNullValue() ) ;

HttpResponse resp2 = executeRequest( HttpGet::new, "metadata" ) ;
assertThat( resp2.getStatusLine().getStatusCode(), equalTo( 200 ) ) ;

doesIncludePartitionHeaderInCorsOptions( "GET", "metadata" );
}

@Test
public void doesIncludePartitionHeaderInCorsOptionsForLinkages() {
Bundle linkageBundle =
authenticatingClient( authToken )
.search()
.forResource( Linkage.class )
.where( Linkage.ITEM.hasId( authorizedPatientId.toUnqualifiedVersionless() ) )
.returnBundle( Bundle.class )
.execute()
;
assertThat( linkageBundle, notNullValue() ) ;

doesIncludePartitionHeaderInCorsOptions( "GET", "Linkage?item=" + authorizedPatientId.toUnqualifiedVersionless() );
}

@Test
public void doesIncludePartitionHeaderInCorsOptionsForLinkagesInForeignPartition() {
Bundle linkageBundle =
authenticatingClientTargetingPartition( authToken, FOREIGN_PARTITION_NAME )
.search()
.forResource( Linkage.class )
.where( Linkage.ITEM.hasId( authorizedPatientId.toUnqualifiedVersionless() ) )
.returnBundle( Bundle.class )
.execute()
;
assertThat( linkageBundle, notNullValue() ) ;

doesIncludePartitionHeaderInCorsOptions( "GET", "Linkage?item=" + authorizedPatientId.toUnqualifiedVersionless() );
}

@Test
public void doesIncludePartitionHeaderInCorsOptionsForPatientInForeignPartition() {
MethodOutcome outcome =
authenticatingClientTargetingPartition( authToken, FOREIGN_PARTITION_NAME )
.update()
.resource( new Patient().setId( authorizedPatientId.toUnqualifiedVersionless() ) )
.execute()
;
assertThat( outcome, notNullValue() ) ;

doesIncludePartitionHeaderInCorsOptions( "PUT", "Patient/" + authorizedPatientId.getIdPart() );
}

private void doesIncludePartitionHeaderInCorsOptions( String requestMethod, String relativePath ) {
final String sdsPartitionHeaderName = sdsProperties.getPartition().getHttpHeader();

HttpResponse resp = executeRequest( corsRequest(requestMethod), relativePath ) ;
assertThat( resp.getStatusLine().getStatusCode(), equalTo( 200 ) ) ;
List<String> valuesOfAccessControlAllowHeaders =
Arrays.stream( resp.getHeaders( "Access-Control-Allow-Headers" ) )
.map( Header::getValue )
.map( v -> v.split(",") )
.flatMap( Arrays::stream )
.map( String::trim )
.map( String::toLowerCase )
.collect( toList() ) ;
assertThat( valuesOfAccessControlAllowHeaders, hasItem( sdsPartitionHeaderName.toLowerCase() ) ) ;
}

private Function<URI,HttpOptions> corsRequest( String requestMethod ) {
final String sdsPartitionHeaderName = sdsProperties.getPartition().getHttpHeader();

return uri -> {
HttpOptions optionsRequest = new HttpOptions( uri ) ;
optionsRequest.addHeader( "Origin", "https://10.0.1.1" ) ;
optionsRequest.addHeader( "Access-Control-Request-Method", requestMethod ) ;
optionsRequest.addHeader( "Access-Control-Request-Headers", "content-type, authorization, " + sdsPartitionHeaderName ) ;
return optionsRequest ;
} ;
}
}
8 changes: 4 additions & 4 deletions src/test/resources/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ hapi:
# partitioning:
# allow_references_across_partitions: false
# partitioning_include_in_search_hashes: false
#cors:
# allow_Credentials: true
cors:
allow_Credentials: true
# These are allowed_origin patterns, see: https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/cors/CorsConfiguration.html#setAllowedOriginPatterns-java.util.List-
# allowed_origin:
# - '*'
allowed_origin:
- '*'

# Search coordinator thread pool sizes
search-coord-core-pool-size: 20
Expand Down

0 comments on commit 73c8542

Please sign in to comment.