Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-7644 Make re-auth status codes configurable #4420

Merged
merged 7 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion buildSrc/src/main/kotlin/test/server/tests/Auth.kt
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ internal fun Application.authTestServer() {
val token = call.request.headers["Authorization"]
if (token.isNullOrEmpty() || token.contains("invalid")) {
call.response.header(HttpHeaders.WWWAuthenticate, "Bearer realm=\"TestServer\"")
call.respond(HttpStatusCode.Unauthorized)
val status = call.request.queryParameters["status"]?.toIntOrNull() ?: 401
wkornewald marked this conversation as resolved.
Show resolved Hide resolved
call.respond(HttpStatusCode.fromValue(status))
return@get
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
public final class io/ktor/client/plugins/auth/AuthConfig {
public fun <init> ()V
public final fun getProviders ()Ljava/util/List;
public final fun isUnauthorizedResponse ()Lkotlin/jvm/functions/Function2;
public final fun reAuthorizeOnResponse (Lkotlin/jvm/functions/Function2;)V
}

public final class io/ktor/client/plugins/auth/AuthKt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ final class io.ktor.client.plugins.auth/AuthConfig { // io.ktor.client.plugins.a

final val providers // io.ktor.client.plugins.auth/AuthConfig.providers|{}providers[0]
final fun <get-providers>(): kotlin.collections/MutableList<io.ktor.client.plugins.auth/AuthProvider> // io.ktor.client.plugins.auth/AuthConfig.providers.<get-providers>|<get-providers>(){}[0]

final var isUnauthorizedResponse // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse|{}isUnauthorizedResponse[0]
final fun <get-isUnauthorizedResponse>(): kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean> // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse.<get-isUnauthorizedResponse>|<get-isUnauthorizedResponse>(){}[0]

final fun reAuthorizeOnResponse(kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean>) // io.ktor.client.plugins.auth/AuthConfig.reAuthorizeOnResponse|reAuthorizeOnResponse(kotlin.coroutines.SuspendFunction1<io.ktor.client.statement.HttpResponse,kotlin.Boolean>){}[0]
}

final val io.ktor.client.plugins.auth/Auth // io.ktor.client.plugins.auth/Auth|{}Auth[0]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
/*
* Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.plugins.auth

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.api.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.auth.*
import io.ktor.util.*
Expand All @@ -23,9 +23,36 @@ private class AtomicCounter {
val atomic = atomic(0)
}

/**
* Configuration used by [Auth] plugin.
*/
@KtorDsl
public class AuthConfig {
/**
* [AuthProvider] list to use.
*/
public val providers: MutableList<AuthProvider> = mutableListOf()

/**
* The currently set function to control whether a response is unauthorized and should trigger a refresh / re-auth.
*
* By default checks against HTTP status 401.
*
* You can set this value via [reAuthorizeOnResponse].
*/
@InternalAPI
public var isUnauthorizedResponse: suspend (HttpResponse) -> Boolean = { it.status == HttpStatusCode.Unauthorized }
private set
Comment on lines +43 to +45
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this an internal API? My intention was to allow access to the current value, so you can extend an existing rule.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to preserve the right to introduce breaking changes to this field. Keeping it public for those who’re ready to further breaking changes.


/**
* Sets a custom function to control whether a response is unauthorized and should trigger a refresh / re-auth.
*
* Use this to change the value of [isUnauthorizedResponse].
*/
public fun reAuthorizeOnResponse(block: suspend (HttpResponse) -> Boolean) {
@OptIn(InternalAPI::class)
isUnauthorizedResponse = block
}
}

/**
Expand All @@ -39,8 +66,9 @@ public val AuthCircuitBreaker: AttributeKey<Unit> = AttributeKey("auth-request")
*
* You can learn more from [Authentication and authorization](https://ktor.io/docs/auth.html).
*
* [providers] - list of auth providers to use.
* @see [AuthConfig] for configuration options.
*/
@OptIn(InternalAPI::class)
public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthConfig) {
val providers = pluginConfig.providers.toList()

Expand All @@ -50,7 +78,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
val tokenVersionsAttributeKey =
AttributeKey<MutableMap<AuthProvider, Int>>("ProviderVersionAttributeKey")

@OptIn(InternalAPI::class)
fun findProvider(
call: HttpClientCall,
candidateProviders: Set<AuthProvider>
Expand All @@ -64,10 +91,10 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
}

authHeaders.isEmpty() -> {
LOGGER.trace(
"401 response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
LOGGER.trace {
"Unauthorized response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
"Can not add or refresh token"
)
}
null
}

Expand All @@ -88,9 +115,9 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
val requestTokenVersion = requestTokenVersions[provider]

if (requestTokenVersion != null && requestTokenVersion >= tokenVersion.atomic.value) {
LOGGER.trace("Refreshing token for ${call.request.url}")
LOGGER.trace { "Refreshing token for ${call.request.url}" }
if (!provider.refreshToken(call.response)) {
LOGGER.trace("Refreshing token failed for ${call.request.url}")
LOGGER.trace { "Refreshing token failed for ${call.request.url}" }
return false
} else {
requestTokenVersions[provider] = tokenVersion.atomic.incrementAndGet()
Expand All @@ -99,7 +126,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
return true
}

@OptIn(InternalAPI::class)
suspend fun Send.Sender.executeWithNewToken(
call: HttpClientCall,
provider: AuthProvider,
Expand All @@ -111,13 +137,13 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
provider.addRequestHeaders(request, authHeader)
request.attributes.put(AuthCircuitBreaker, Unit)

LOGGER.trace("Sending new request to ${call.request.url}")
LOGGER.trace { "Sending new request to ${call.request.url}" }
return proceed(request)
}

onRequest { request, _ ->
providers.filter { it.sendWithoutRequest(request) }.forEach { provider ->
LOGGER.trace("Adding auth headers for ${request.url} from provider $provider")
LOGGER.trace { "Adding auth headers for ${request.url} from provider $provider" }
val tokenVersion = tokenVersions.computeIfAbsent(provider) { AtomicCounter() }
val requestTokenVersions = request.attributes
.computeIfAbsent(tokenVersionsAttributeKey) { mutableMapOf() }
Expand All @@ -128,22 +154,22 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon

on(Send) { originalRequest ->
val origin = proceed(originalRequest)
if (origin.response.status != HttpStatusCode.Unauthorized) return@on origin
if (!pluginConfig.isUnauthorizedResponse(origin.response)) return@on origin
if (origin.request.attributes.contains(AuthCircuitBreaker)) return@on origin

var call = origin

val candidateProviders = HashSet(providers)

while (call.response.status == HttpStatusCode.Unauthorized) {
LOGGER.trace("Received 401 for ${call.request.url}")
while (pluginConfig.isUnauthorizedResponse(call.response)) {
LOGGER.trace { "Unauthorized response for ${call.request.url}" }

val (provider, authHeader) = findProvider(call, candidateProviders) ?: run {
LOGGER.trace("Can not find auth provider for ${call.request.url}")
LOGGER.trace { "Can not find auth provider for ${call.request.url}" }
return@on call
}

LOGGER.trace("Using provider $provider for ${call.request.url}")
LOGGER.trace { "Using provider $provider for ${call.request.url}" }

candidateProviders.remove(provider)
if (!refreshTokenIfNeeded(call, provider, originalRequest)) return@on call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,27 @@ class AuthTest : ClientLoader() {
}
}

@Test
fun testForbiddenBearerAuthWithInvalidAccessAndValidRefreshTokens() = clientTests {
config {
install(Auth) {
reAuthorizeOnResponse { it.status == HttpStatusCode.Forbidden }
bearer {
refreshTokens { BearerTokens("valid", "refresh") }
loadTokens { BearerTokens("invalid", "refresh") }
}
}

expectSuccess = false
}

test { client ->
client.prepareGet("$TEST_SERVER/auth/bearer/test-refresh?status=403").execute {
assertEquals(HttpStatusCode.OK, it.status)
}
}
}

// The return of refreshTokenFun is null, cause it should not be called at all, if loadTokensFun returns valid tokens
@Test
fun testUnauthorizedBearerAuthWithValidAccessTokenAndInvalidRefreshToken() = clientTests {
Expand Down