diff --git a/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderFilter.java b/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderFilter.java index 490b5bcce07..77cd2878431 100644 --- a/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderFilter.java +++ b/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderFilter.java @@ -15,6 +15,7 @@ import com.enonic.xp.portal.idprovider.IdProviderControllerExecutionParams; import com.enonic.xp.portal.idprovider.IdProviderControllerService; import com.enonic.xp.security.auth.AuthenticationInfo; +import com.enonic.xp.web.dispatch.DispatchConstants; import com.enonic.xp.web.filter.OncePerRequestFilter; @Component(immediate = true, service = Filter.class, property = {"connector=xp", "connector=api"}) @@ -48,8 +49,12 @@ protected void doHandle( final HttpServletRequest req, final HttpServletResponse } //Wraps the response to handle 403 errors - final IdProviderResponseWrapper responseWrapper = new IdProviderResponseWrapper( idProviderControllerService, req, res ); + final HttpServletResponse response = DispatchConstants.XP_CONNECTOR.equals( req.getAttribute( DispatchConstants.CONNECTOR_ATTRIBUTE ) ) + ? new IdProviderResponseWrapper( idProviderControllerService, req, res ) + : res; + final IdProviderRequestWrapper requestWrapper = new IdProviderRequestWrapper( req ); - chain.doFilter( requestWrapper, responseWrapper ); + + chain.doFilter( requestWrapper, response ); } } diff --git a/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderResponseWrapper.java b/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderResponseWrapper.java index 3caff8dad94..38c5433df50 100644 --- a/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderResponseWrapper.java +++ b/modules/portal/portal-impl/src/main/java/com/enonic/xp/portal/impl/idprovider/IdProviderResponseWrapper.java @@ -2,7 +2,8 @@ import java.io.IOException; import java.io.PrintWriter; -import java.io.StringWriter; +import java.io.UncheckedIOException; +import java.io.Writer; import javax.servlet.ServletOutputStream; import javax.servlet.WriteListener; @@ -13,9 +14,6 @@ import com.enonic.xp.context.ContextAccessor; import com.enonic.xp.portal.idprovider.IdProviderControllerExecutionParams; import com.enonic.xp.portal.idprovider.IdProviderControllerService; -import com.enonic.xp.security.auth.AuthenticationInfo; -import com.enonic.xp.util.Exceptions; - public class IdProviderResponseWrapper extends HttpServletResponseWrapper @@ -40,7 +38,14 @@ public IdProviderResponseWrapper( final IdProviderControllerService idProviderCo @Override public void setStatus( final int sc ) { - handleError( sc ); + try + { + handleError( sc ); + } + catch ( IOException e ) + { + throw new UncheckedIOException( e ); + } if ( !errorHandled ) { @@ -54,7 +59,7 @@ public PrintWriter getWriter() { if ( errorHandled ) { - return new PrintWriter( new StringWriter() ); + return new PrintWriter( Writer.nullWriter() ); } return super.getWriter(); } @@ -124,43 +129,30 @@ public void sendError( final int sc, final String msg ) } private void handleError( final int sc ) + throws IOException { if ( !errorHandled && isUnauthorizedError( sc ) && !isErrorAlreadyHandled() ) { - try + final IdProviderControllerExecutionParams executionParams = IdProviderControllerExecutionParams.create() + .functionName( "handle401" ) + .servletRequest( request ) + .response( response ) + .build(); + final boolean responseSerialized = idProviderControllerService.execute( executionParams ) != null; + if ( responseSerialized ) { - final IdProviderControllerExecutionParams executionParams = IdProviderControllerExecutionParams.create(). - functionName( "handle401" ). - servletRequest( request ). - response( response ). - build(); - final boolean responseSerialized = idProviderControllerService.execute( executionParams ) != null; - if ( responseSerialized ) - { - errorHandled = true; - } - } - catch ( IOException e ) - { - throw Exceptions.unchecked( e ); + errorHandled = true; } } } private boolean isUnauthorizedError( final int sc ) { - return 401 == sc || ( 403 == sc && !isAuthenticated() ); + return 401 == sc || 403 == sc && !ContextAccessor.current().getAuthInfo().isAuthenticated(); } private boolean isErrorAlreadyHandled() { return Boolean.TRUE.equals( request.getAttribute( "error.handled" ) ); } - - - private boolean isAuthenticated() - { - final AuthenticationInfo authInfo = ContextAccessor.current().getAuthInfo(); - return authInfo.isAuthenticated(); - } } diff --git a/modules/web/web-impl/build.gradle b/modules/web/web-impl/build.gradle index e247053fc69..0702611bc47 100644 --- a/modules/web/web-impl/build.gradle +++ b/modules/web/web-impl/build.gradle @@ -3,6 +3,7 @@ dependencies { implementation project( ':core:core-internal' ) testImplementation( testFixtures( project(":web:web-jetty") ) ) + testImplementation( testFixtures( project(":core:core-api") ) ) } jar { diff --git a/modules/web/web-impl/src/main/java/com/enonic/xp/web/impl/auth/AuthRequiredFilter.java b/modules/web/web-impl/src/main/java/com/enonic/xp/web/impl/auth/AuthRequiredFilter.java new file mode 100644 index 00000000000..8ef9decfb62 --- /dev/null +++ b/modules/web/web-impl/src/main/java/com/enonic/xp/web/impl/auth/AuthRequiredFilter.java @@ -0,0 +1,35 @@ +package com.enonic.xp.web.impl.auth; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.annotation.WebFilter; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.osgi.service.component.annotations.Component; + +import com.google.common.net.HttpHeaders; + +import com.enonic.xp.annotation.Order; +import com.enonic.xp.web.filter.OncePerRequestFilter; + +@Component(immediate = true, service = Filter.class, property = {"connector=api"}) +@Order(-20) +@WebFilter("/*") +public class AuthRequiredFilter + extends OncePerRequestFilter +{ + @Override + protected void doHandle( final HttpServletRequest req, final HttpServletResponse res, final FilterChain chain ) + throws Exception + { + if ( req.getUserPrincipal() == null ) + { + res.addHeader( HttpHeaders.WWW_AUTHENTICATE, "Basic"); + res.addHeader( HttpHeaders.WWW_AUTHENTICATE, "Bearer"); + res.sendError( HttpServletResponse.SC_UNAUTHORIZED ); + return; + } + chain.doFilter( req, res ); + } +} diff --git a/modules/web/web-impl/src/main/java/com/enonic/xp/web/impl/auth/BasicAuthFilter.java b/modules/web/web-impl/src/main/java/com/enonic/xp/web/impl/auth/BasicAuthFilter.java index cac8492e895..c01c8b824a7 100644 --- a/modules/web/web-impl/src/main/java/com/enonic/xp/web/impl/auth/BasicAuthFilter.java +++ b/modules/web/web-impl/src/main/java/com/enonic/xp/web/impl/auth/BasicAuthFilter.java @@ -49,6 +49,12 @@ protected void doHandle( final HttpServletRequest req, final HttpServletResponse private void login( final HttpServletRequest req ) { + final AuthenticationInfo authInfo = ContextAccessor.current().getAuthInfo(); + if ( authInfo.isAuthenticated() ) + { + return; + } + final String header = req.getHeader( HttpHeaders.AUTHORIZATION ); if ( header == null ) { @@ -75,8 +81,8 @@ private static String[] parseHeader( final String header ) return null; } - final String type = header.substring( 0, 5 ).toUpperCase(); - if ( !type.equals( HttpServletRequest.BASIC_AUTH ) ) + final String type = header.substring( 0, 5 ); + if ( !type.equalsIgnoreCase( HttpServletRequest.BASIC_AUTH ) ) { return null; } @@ -84,14 +90,14 @@ private static String[] parseHeader( final String header ) final String val = header.substring( 6 ); final String decoded = new String( Base64.getDecoder().decode( val ), StandardCharsets.UTF_8 ); - final String[] parts = decoded.split( ":" ); - if ( parts.length != 2 ) + int pos = decoded.indexOf( ':' ); + if ( pos == -1 ) { return null; } - return parts; + return new String[]{decoded.substring( 0, pos ), decoded.substring( pos + 1 )}; } private AuthenticationInfo authenticate( final String user, final String password ) diff --git a/modules/web/web-impl/src/test/java/com/enonic/xp/web/impl/auth/AuthRequiredFilterTest.java b/modules/web/web-impl/src/test/java/com/enonic/xp/web/impl/auth/AuthRequiredFilterTest.java new file mode 100644 index 00000000000..070ea154e38 --- /dev/null +++ b/modules/web/web-impl/src/test/java/com/enonic/xp/web/impl/auth/AuthRequiredFilterTest.java @@ -0,0 +1,67 @@ +package com.enonic.xp.web.impl.auth; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import com.google.common.net.HttpHeaders; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class AuthRequiredFilterTest +{ + private AuthRequiredFilter authRequiredFilter; + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private FilterChain filterChain; + + @BeforeEach + void setUp() + { + authRequiredFilter = new AuthRequiredFilter(); + } + + @Test + void doHandle_whenUserPrincipalIsNull_sendsUnauthorizedError() + throws Exception + { + when( request.getUserPrincipal() ).thenReturn( null ); + + authRequiredFilter.doHandle( request, response, filterChain ); + + verify( response ).addHeader( HttpHeaders.WWW_AUTHENTICATE, "Basic" ); + verify( response ).addHeader( HttpHeaders.WWW_AUTHENTICATE, "Bearer" ); + verify( response ).sendError( HttpServletResponse.SC_UNAUTHORIZED ); + verify( filterChain, never() ).doFilter( request, response ); + } + + @Test + void doHandle_whenUserPrincipalIsNotNull_proceedsWithFilterChain() + throws Exception + { + when( request.getUserPrincipal() ).thenReturn( () -> "user" ); + + authRequiredFilter.doHandle( request, response, filterChain ); + + verify( response, never() ).addHeader( anyString(), anyString() ); + verify( response, never() ).sendError( anyInt() ); + verify( filterChain ).doFilter( request, response ); + } +} diff --git a/modules/web/web-impl/src/test/java/com/enonic/xp/web/impl/auth/BasicAuthFilterTest.java b/modules/web/web-impl/src/test/java/com/enonic/xp/web/impl/auth/BasicAuthFilterTest.java index 6f1c565950c..e9977f32505 100644 --- a/modules/web/web-impl/src/test/java/com/enonic/xp/web/impl/auth/BasicAuthFilterTest.java +++ b/modules/web/web-impl/src/test/java/com/enonic/xp/web/impl/auth/BasicAuthFilterTest.java @@ -6,12 +6,14 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mockito; import com.google.common.net.HttpHeaders; +import com.enonic.xp.context.ContextAccessor; import com.enonic.xp.security.IdProvider; import com.enonic.xp.security.IdProviderKey; import com.enonic.xp.security.IdProviders; @@ -19,7 +21,12 @@ import com.enonic.xp.security.SecurityService; import com.enonic.xp.security.User; import com.enonic.xp.security.auth.AuthenticationInfo; +import com.enonic.xp.session.SessionMock; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class BasicAuthFilterTest @@ -37,6 +44,7 @@ public class BasicAuthFilterTest @BeforeEach public void setup() { + ContextAccessor.current().getLocalScope().setSession( new SessionMock() ); this.request = Mockito.mock( HttpServletRequest.class ); this.response = Mockito.mock( HttpServletResponse.class ); this.chain = Mockito.mock( FilterChain.class ); @@ -50,15 +58,26 @@ public void setup() when( this.securityService.getIdProviders() ).thenReturn( idProviders ); } - private void rightAuthentication() + @AfterEach + public void tearDown() + { + ContextAccessor.current().getLocalScope().setSession( null ); + } + + private AuthenticationInfo goodAuthenticationInfo() { final User user = User.create().login( "user" ).key( PrincipalKey.ofUser( IdProviderKey.from( "store" ), "user" ) ).build(); - when( this.securityService.authenticate( Mockito.any() ) ).thenReturn( AuthenticationInfo.create().user( user ).build() ); + return AuthenticationInfo.create().user( user ).build(); + } + + private void rightAuthentication() + { + when( this.securityService.authenticate( any() ) ).thenReturn( goodAuthenticationInfo() ); } private void wrongAuthentication() { - when( this.securityService.authenticate( Mockito.any() ) ).thenReturn( AuthenticationInfo.unAuthenticated() ); + when( this.securityService.authenticate( any() ) ).thenReturn( AuthenticationInfo.unAuthenticated() ); } private void doFilter() @@ -70,7 +89,7 @@ private void doFilter() private void verifyChain() throws Exception { - Mockito.verify( this.chain, Mockito.times( 1 ) ).doFilter( this.request, this.response ); + verify( this.chain, Mockito.times( 1 ) ).doFilter( this.request, this.response ); } private void setAuthHeader( final String value ) @@ -78,6 +97,11 @@ private void setAuthHeader( final String value ) when( request.getHeader( HttpHeaders.AUTHORIZATION ) ).thenReturn( value ); } + private static void verifyAuthenticated( boolean yes ) + { + assertEquals( yes, ContextAccessor.current().getAuthInfo().isAuthenticated() ); + } + private String base64( final String value ) { return Base64.getEncoder().encodeToString( value.getBytes() ); @@ -89,6 +113,7 @@ public void noHeader() { doFilter(); verifyChain(); + verifyAuthenticated( false ); } @Test @@ -98,6 +123,7 @@ public void header_wrongFormat() setAuthHeader( "some-value" ); doFilter(); verifyChain(); + verifyAuthenticated( false ); } @Test @@ -107,6 +133,7 @@ public void header_noCredentials() setAuthHeader( "BASIC" ); doFilter(); verifyChain(); + verifyAuthenticated( false ); } @Test @@ -116,6 +143,7 @@ public void header_noPassword() setAuthHeader( "BASIC " + base64( "user" ) ); doFilter(); verifyChain(); + verifyAuthenticated( false ); } @Test @@ -126,6 +154,7 @@ public void header_defaultIdProvider_noAccess() setAuthHeader( "BASIC " + base64( "user:wrong" ) ); doFilter(); verifyChain(); + verifyAuthenticated( false ); } @Test @@ -136,5 +165,29 @@ public void header_defaultIdProvider_authenticated() setAuthHeader( "BASIC " + base64( "user:password" ) ); doFilter(); verifyChain(); + verifyAuthenticated( true ); + } + + @Test + public void header_defaultIdProvider_authenticated_colon_allowed() + throws Exception + { + rightAuthentication(); + setAuthHeader( "BASIC " + base64( "user:pass:word" ) ); + doFilter(); + verifyChain(); + verifyAuthenticated( true ); + } + + @Test + public void preAuthenticated() + throws Exception + { + setAuthHeader( "BASIC " + base64( "user:password" ) ); + ContextAccessor.current().getLocalScope().setAttribute( goodAuthenticationInfo() ); + doFilter(); + verifyChain(); + verifyAuthenticated( true ); + verify( this.securityService, never() ).authenticate( any() ); } }