diff --git a/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java index 860ed9fc551..24566458e11 100644 --- a/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParser.java @@ -146,6 +146,7 @@ public BeanDefinition parse(Element element, ParserContext pc) { BeanMetadataElement saml2LogoutRequestSuccessHandler = BeanDefinitionBuilder .rootBeanDefinition(Saml2RelyingPartyInitiatedLogoutSuccessHandler.class) .addConstructorArgValue(logoutRequestResolver) + .addPropertyValue("logoutRequestRepository", logoutRequestRepository) .getBeanDefinition(); this.logoutFilter = BeanDefinitionBuilder.rootBeanDefinition(LogoutFilter.class) .addConstructorArgValue(saml2LogoutRequestSuccessHandler) diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java index 3957d416dae..e13bddf7073 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java @@ -484,6 +484,7 @@ public void saml2LogoutResponseWhenCustomLogoutResponseHandlerThenUses() throws verify(getBean(Saml2LogoutResponseValidator.class)).validate(any()); } + // gh-11363 @Test public void saml2LogoutWhenCustomLogoutRequestRepositoryThenUses() throws Exception { this.spring.register(Saml2LogoutComponentsConfig.class).autowire(); diff --git a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java index 152525d4a20..d51349440a5 100644 --- a/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java @@ -63,6 +63,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.verify; @@ -380,6 +381,22 @@ public void saml2LogoutResponseWhenCustomLogoutResponseHandlerThenUses() throws verify(getBean(Saml2LogoutResponseValidator.class)).validate(any()); } + // gh-11363 + @Test + public void saml2LogoutWhenCustomLogoutRequestRepositoryThenUses() throws Exception { + this.spring.configLocations(this.xml("CustomComponents")).autowire(); + RelyingPartyRegistration registration = this.repository.findByRegistrationId("get"); + Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) + .samlRequest(this.rpLogoutRequest) + .id(this.rpLogoutRequestId) + .relayState(this.rpLogoutRequestRelayState) + .parameters((params) -> params.put("Signature", this.rpLogoutRequestSignature)) + .build(); + given(getBean(Saml2LogoutRequestResolver.class).resolve(any(), any())).willReturn(logoutRequest); + this.mvc.perform(post("/logout").with(authentication(this.saml2User)).with(csrf())); + verify(getBean(Saml2LogoutRequestRepository.class)).saveLogoutRequest(eq(logoutRequest), any(), any()); + } + private T getBean(Class clazz) { return this.spring.getContext().getBean(clazz); }