Skip to content

Commit

Permalink
Inherit parent context in coRouter DSL
Browse files Browse the repository at this point in the history
This commit also allows context override, as it
is useful for the nested router use case.

Closes gh-31831
  • Loading branch information
sdeleuze committed Dec 13, 2023
1 parent 8d4deca commit a01c6d5
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
* @see RouterFunctions.nest
*/
fun RequestPredicate.nest(r: (CoRouterFunctionDsl.() -> Unit)) {
builder.add(nest(this, CoRouterFunctionDsl(r).build()))
builder.add(nest(this, CoRouterFunctionDsl(r).also { it.contextProvider = contextProvider }.build()))
}


Expand Down Expand Up @@ -628,9 +628,6 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
* @since 6.1
*/
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
if (this.contextProvider != null) {
throw IllegalStateException("The Coroutine context provider should not be defined more than once")
}
this.contextProvider = provider
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,45 @@ class CoRouterFunctionDslTests {
.verifyComplete()
}

@Test
fun nestedContextProvider() {
val mockRequest = get("https://example.com/nested/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}

@Test
fun nestedContextProviderWithOverride() {
val mockRequest = get("https://example.com/nested/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(nestedRouterWithContextProviderOverride.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}

@Test
fun doubleNestedContextProvider() {
val mockRequest = get("https://example.com/nested/nested/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(nestedRouterWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}

@Test
fun contextProviderAndFilter() {
val mockRequest = get("https://example.com/")
Expand Down Expand Up @@ -323,6 +362,36 @@ class CoRouterFunctionDslTests {
}
}

private val nestedRouterWithContextProvider = coRouter {
context {
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
}
"/nested".nest {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
"/nested".nest {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}
}
}

private val nestedRouterWithContextProviderOverride = coRouter {
context {
CoroutineName("parent-context")
}
"/nested".nest {
context {
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
}
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}
}

private val routerWithoutContext = coRouter {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
Expand Down

0 comments on commit a01c6d5

Please sign in to comment.