Skip to content

Commit

Permalink
Support @ModelAttribute with suspending function in WebFlux
Browse files Browse the repository at this point in the history
Closes gh-30894
  • Loading branch information
sdeleuze committed Sep 13, 2023
1 parent f5f8eab commit 29a4dab
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package org.springframework.web.reactive.result.method.annotation;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand All @@ -25,6 +26,7 @@
import reactor.core.publisher.Mono;

import org.springframework.core.Conventions;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
Expand All @@ -45,6 +47,7 @@
* default model initialization through {@code @ModelAttribute} methods.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 5.0
*/
class ModelInitializer {
Expand Down Expand Up @@ -119,18 +122,22 @@ private Mono<Void> handleResult(HandlerResult handlerResult, BindingContext bind
Object value = handlerResult.getReturnValue();
if (value != null) {
ResolvableType type = handlerResult.getReturnType();
MethodParameter typeSource = handlerResult.getReturnTypeSource();
ReactiveAdapter adapter = this.adapterRegistry.getAdapter(type.resolve(), value);
if (isAsyncVoidType(type, adapter)) {
if (isAsyncVoidType(type, typeSource, adapter)) {
return Mono.from(adapter.toPublisher(value));
}
String name = getAttributeName(handlerResult.getReturnTypeSource());
String name = getAttributeName(typeSource);
bindingContext.getModel().asMap().putIfAbsent(name, value);
}
return Mono.empty();
}

private boolean isAsyncVoidType(ResolvableType type, @Nullable ReactiveAdapter adapter) {
return (adapter != null && (adapter.isNoValue() || type.resolveGeneric() == Void.class));

private boolean isAsyncVoidType(ResolvableType type, MethodParameter typeSource, @Nullable ReactiveAdapter adapter) {
Method method = typeSource.getMethod();
return (adapter != null && (adapter.isNoValue() || type.resolveGeneric() == Void.class)) ||
(method != null && KotlinDetector.isSuspendingFunction(method) && typeSource.getParameterType() == void.class);
}

private String getAttributeName(MethodParameter param) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.reactive.result.method.annotation

import kotlinx.coroutines.delay
import org.assertj.core.api.Assertions
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.context.support.StaticApplicationContext
import org.springframework.core.ReactiveAdapterRegistry
import org.springframework.ui.Model
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.ModelAttribute
import org.springframework.web.bind.support.ConfigurableWebBindingInitializer
import org.springframework.web.method.HandlerMethod
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest
import org.springframework.web.testfixture.method.ResolvableMethod
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.core.publisher.Mono
import java.time.Duration

/**
* Kotlin test fixture for [ModelInitializer].
*
* @author Sebastien Deleuze
*/
class ModelInitializerKotlinTests {

private val timeout = Duration.ofMillis(5000)

private lateinit var modelInitializer: ModelInitializer

private val exchange: ServerWebExchange = MockServerWebExchange.from(MockServerHttpRequest.get("/path"))

@BeforeEach
fun setup() {
val adapterRegistry = ReactiveAdapterRegistry.getSharedInstance()
val resolverConfigurer = ArgumentResolverConfigurer()
resolverConfigurer.addCustomResolver(ModelMethodArgumentResolver(adapterRegistry))
val methodResolver = ControllerMethodResolver(resolverConfigurer, adapterRegistry, StaticApplicationContext(),
emptyList())
modelInitializer = ModelInitializer(methodResolver, adapterRegistry)
}

@Test
@Suppress("UNCHECKED_CAST")
fun modelAttributeMethods() {
val controller = TestController()
val method = ResolvableMethod.on(TestController::class.java).annotPresent(GetMapping::class.java)
.resolveMethod()
val handlerMethod = HandlerMethod(controller, method)
val context = InitBinderBindingContext(ConfigurableWebBindingInitializer(), emptyList())
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(timeout)
val model = context.model.asMap()
Assertions.assertThat(model).hasSize(2)
val monoValue = model["suspendingReturnValue"] as Mono<TestBean>
Assertions.assertThat(monoValue.block(timeout)!!.name).isEqualTo("Suspending return value")
val value = model["suspendingModelParameter"] as TestBean
Assertions.assertThat(value.name).isEqualTo("Suspending model parameter")
}


private data class TestBean(val name: String)

private class TestController {

@ModelAttribute("suspendingReturnValue")
suspend fun suspendingReturnValue(): TestBean {
delay(1)
return TestBean("Suspending return value")
}

@ModelAttribute
suspend fun suspendingModelParameter(model: Model) {
delay(1)
model.addAttribute("suspendingModelParameter", TestBean("Suspending model parameter"))
}

@GetMapping
fun handleGet() {
}

}

}

0 comments on commit 29a4dab

Please sign in to comment.