diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/UrlHandlerFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/UrlHandlerFilter.java new file mode 100644 index 000000000000..5d1d2c62f0eb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/UrlHandlerFilter.java @@ -0,0 +1,318 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter.reactive; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.server.PathContainer; +import org.springframework.http.server.RequestPath; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.util.pattern.PathPattern; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * {@link org.springframework.web.server.WebFilter} that modifies the URL, and + * then redirects or wraps the request to apply the change. + * + *

To create an instance, you can use the following: + * + *

+ * UrlHandlerFilter filter = UrlHandlerFilter
+ *    .trailingSlashHandler("/path1/**").redirect(HttpStatus.PERMANENT_REDIRECT)
+ *    .trailingSlashHandler("/path2/**").mutateRequest()
+ *    .build();
+ * 
+ * + *

This {@code WebFilter} should be ordered ahead of security filters. + * + * @author Rossen Stoyanchev + * @since 6.2 + */ +public final class UrlHandlerFilter implements WebFilter { + + private static final Log logger = LogFactory.getLog(UrlHandlerFilter.class); + + + private final MultiValueMap handlers; + + + private UrlHandlerFilter(MultiValueMap handlers) { + this.handlers = new LinkedMultiValueMap<>(handlers); + } + + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + RequestPath path = exchange.getRequest().getPath(); + for (Map.Entry> entry : this.handlers.entrySet()) { + if (!entry.getKey().canHandle(exchange)) { + continue; + } + for (PathPattern pattern : entry.getValue()) { + if (pattern.matches(path)) { + return entry.getKey().handle(exchange, chain); + } + } + } + return chain.filter(exchange); + } + + /** + * Create a builder by adding a handler for URL's with a trailing slash. + * @param pathPatterns path patterns to map the handler to, e.g. + * "/path/*", "/path/**", + * "/path/foo/". + * @return a spec to configure the trailing slash handler with + * @see Builder#trailingSlashHandler(String...) + */ + public static Builder.TrailingSlashSpec trailingSlashHandler(String... pathPatterns) { + return new DefaultBuilder().trailingSlashHandler(pathPatterns); + } + + + /** + * Builder for {@link UrlHandlerFilter}. + */ + public interface Builder { + + /** + * Add a handler for URL's with a trailing slash. + * @param pathPatterns path patterns to map the handler to, e.g. + * "/path/*", "/path/**", + * "/path/foo/". + * @return a spec to configure the handler with + */ + TrailingSlashSpec trailingSlashHandler(String... pathPatterns); + + /** + * Build the {@link UrlHandlerFilter} instance. + */ + UrlHandlerFilter build(); + + + /** + * A spec to configure a trailing slash handler. + */ + interface TrailingSlashSpec { + + /** + * Configure a request interceptor to be called just before the handler + * is invoked when a URL with a trailing slash is matched. + */ + TrailingSlashSpec intercept(Function> interceptor); + + /** + * Handle requests by sending a redirect to the same URL but the + * trailing slash trimmed. + * @param statusCode the redirect status to use + * @return the top level {@link Builder}, which allows adding more + * handlers and then building the Filter instance. + */ + Builder redirect(HttpStatusCode statusCode); + + /** + * Handle the request by wrapping it in order to trim the trailing + * slash, and delegating to the rest of the filter chain. + * @return the top level {@link Builder}, which allows adding more + * handlers and then building the Filter instance. + */ + Builder mutateRequest(); + } + } + + + /** + * Default {@link Builder} implementation. + */ + private static final class DefaultBuilder implements Builder { + + private final PathPatternParser patternParser = new PathPatternParser(); + + private final MultiValueMap handlers = new LinkedMultiValueMap<>(); + + @Override + public TrailingSlashSpec trailingSlashHandler(String... patterns) { + return new DefaultTrailingSlashSpec(patterns); + } + + private DefaultBuilder addHandler(List pathPatterns, Handler handler) { + pathPatterns.forEach(pattern -> this.handlers.add(handler, pattern)); + return this; + } + + @Override + public UrlHandlerFilter build() { + return new UrlHandlerFilter(this.handlers); + } + + + private final class DefaultTrailingSlashSpec implements TrailingSlashSpec { + + private final List pathPatterns; + + @Nullable + private List>> interceptors; + + private DefaultTrailingSlashSpec(String[] patterns) { + this.pathPatterns = Arrays.stream(patterns) + .map(pattern -> pattern.endsWith("**") || pattern.endsWith("/") ? pattern : pattern + "/") + .map(patternParser::parse) + .toList(); + } + + @Override + public TrailingSlashSpec intercept(Function> interceptor) { + this.interceptors = (this.interceptors != null ? this.interceptors : new ArrayList<>()); + this.interceptors.add(interceptor); + return this; + } + + @Override + public Builder redirect(HttpStatusCode statusCode) { + Handler handler = new RedirectTrailingSlashHandler(statusCode, this.interceptors); + return DefaultBuilder.this.addHandler(this.pathPatterns, handler); + } + + @Override + public Builder mutateRequest() { + Handler handler = new RequestWrappingTrailingSlashHandler(this.interceptors); + return DefaultBuilder.this.addHandler(this.pathPatterns, handler); + } + } + } + + + /** + * Internal handler to encapsulate different ways to handle a request. + */ + private interface Handler { + + /** + * Whether the handler handles the given request. + */ + boolean canHandle(ServerWebExchange exchange); + + /** + * Handle the request, possibly delegating to the rest of the filter chain. + */ + Mono handle(ServerWebExchange exchange, WebFilterChain chain); + } + + + /** + * Base class for trailing slash {@link Handler} implementations. + */ + private abstract static class AbstractTrailingSlashHandler implements Handler { + + private static final List>> defaultInterceptors = + List.of(request -> { + if (logger.isTraceEnabled()) { + logger.trace("Handling trailing slash URL: " + request.getMethod() + " " + request.getURI()); + } + return Mono.empty(); + }); + + private final List>> interceptors; + + protected AbstractTrailingSlashHandler(@Nullable List>> interceptors) { + this.interceptors = (interceptors != null ? new ArrayList<>(interceptors) : defaultInterceptors); + } + + @Override + public boolean canHandle(ServerWebExchange exchange) { + List elements = exchange.getRequest().getPath().elements(); + return (elements.size() > 1 && elements.get(elements.size() - 1).value().equals("/")); + } + + @Override + public Mono handle(ServerWebExchange exchange, WebFilterChain chain) { + List> monos = new ArrayList<>(this.interceptors.size()); + this.interceptors.forEach(interceptor -> monos.add(interceptor.apply(exchange.getRequest()))); + return Flux.concat(monos).then(Mono.defer(() -> handleInternal(exchange, chain))); + } + + protected abstract Mono handleInternal(ServerWebExchange exchange, WebFilterChain chain); + + protected String trimTrailingSlash(ServerHttpRequest request) { + String path = request.getURI().getRawPath(); + int index = (StringUtils.hasLength(path) ? path.lastIndexOf('/') : -1); + return (index != -1 ? path.substring(0, index) : path); + } + } + + + /** + * Path handler that sends a redirect. + */ + private static final class RedirectTrailingSlashHandler extends AbstractTrailingSlashHandler { + + private final HttpStatusCode statusCode; + + RedirectTrailingSlashHandler( + HttpStatusCode statusCode, @Nullable List>> interceptors) { + + super(interceptors); + this.statusCode = statusCode; + } + + @Override + public Mono handleInternal(ServerWebExchange exchange, WebFilterChain chain) { + ServerHttpResponse response = exchange.getResponse(); + response.setStatusCode(this.statusCode); + response.getHeaders().set(HttpHeaders.LOCATION, trimTrailingSlash(exchange.getRequest())); + return Mono.empty(); + } + } + + + /** + * Path handler that mutates the request and continues processing. + */ + private static final class RequestWrappingTrailingSlashHandler extends AbstractTrailingSlashHandler { + + RequestWrappingTrailingSlashHandler(@Nullable List>> interceptors) { + super(interceptors); + } + + @Override + public Mono handleInternal(ServerWebExchange exchange, WebFilterChain chain) { + ServerHttpRequest request = exchange.getRequest(); + ServerHttpRequest mutatedRequest = request.mutate().path(trimTrailingSlash(request)).build(); + return chain.filter(exchange.mutate().request(mutatedRequest).build()); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/reactive/UrlHandlerFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/reactive/UrlHandlerFilterTests.java new file mode 100644 index 000000000000..a92f2194af9d --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/reactive/UrlHandlerFilterTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter.reactive; + +import java.net.URI; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebHandler; +import org.springframework.web.server.handler.DefaultWebFilterChain; +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; +import org.springframework.web.testfixture.server.MockServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link UrlHandlerFilter}. + * + * @author Rossen Stoyanchev + */ +public class UrlHandlerFilterTests { + + @Test + void requestMutation() { + UrlHandlerFilter filter = UrlHandlerFilter.trailingSlashHandler("/path/**").mutateRequest().build(); + + String path = "/path/123"; + MockServerHttpRequest original = MockServerHttpRequest.get(path + "/").build(); + ServerWebExchange exchange = MockServerWebExchange.from(original); + + ServerHttpRequest actual = invokeFilter(filter, exchange); + + assertThat(actual).isNotNull().isNotSameAs(original); + assertThat(actual.getPath().value()).isEqualTo(path); + } + + @Test + void redirect() { + HttpStatus status = HttpStatus.PERMANENT_REDIRECT; + UrlHandlerFilter filter = UrlHandlerFilter.trailingSlashHandler("/path/*").redirect(status).build(); + + String path = "/path/123"; + MockServerHttpRequest original = MockServerHttpRequest.get(path + "/").build(); + ServerWebExchange exchange = MockServerWebExchange.from(original); + + assertThatThrownBy(() -> invokeFilter(filter, exchange)) + .hasMessageContaining("No argument value was captured"); + + assertThat(exchange.getResponse().getStatusCode()).isEqualTo(status); + assertThat(exchange.getResponse().getHeaders().getLocation()).isEqualTo(URI.create(path)); + } + + @Test + void noUrlHandling() { + testNoUrlHandling("/path/**", "/path/123"); + testNoUrlHandling("/path/*", "/path/123"); + testNoUrlHandling("/**", "/"); // gh-33444 + } + + private static void testNoUrlHandling(String pattern, String path) { + + // No request mutation + UrlHandlerFilter filter = UrlHandlerFilter.trailingSlashHandler(pattern).mutateRequest().build(); + + MockServerHttpRequest request = MockServerHttpRequest.get(path).build(); + ServerWebExchange exchange = MockServerWebExchange.from(request); + + ServerHttpRequest actual = invokeFilter(filter, exchange); + + assertThat(actual).isNotNull().isSameAs(request); + assertThat(actual.getPath().value()).isEqualTo(path); + + // No redirect + HttpStatus status = HttpStatus.PERMANENT_REDIRECT; + filter = UrlHandlerFilter.trailingSlashHandler(pattern).redirect(status).build(); + + request = MockServerHttpRequest.get(path).build(); + exchange = MockServerWebExchange.from(request); + + actual = invokeFilter(filter, exchange); + + assertThat(actual).isNotNull().isSameAs(request); + assertThat(exchange.getResponse().getStatusCode()).isNull(); + assertThat(exchange.getResponse().getHeaders().getLocation()).isNull(); + } + + private static ServerHttpRequest invokeFilter(UrlHandlerFilter filter, ServerWebExchange exchange) { + WebHandler handler = mock(WebHandler.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(ServerWebExchange.class); + given(handler.handle(captor.capture())).willReturn(Mono.empty()); + + WebFilterChain chain = new DefaultWebFilterChain(handler, List.of(filter)); + filter.filter(exchange, chain).block(); + + return captor.getValue().getRequest(); + } + +}