Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-7644 Make re-auth status codes configurable #4420

Merged
merged 7 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion buildSrc/src/main/kotlin/test/server/tests/Auth.kt
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ internal fun Application.authTestServer() {
val token = call.request.headers["Authorization"]
if (token.isNullOrEmpty() || token.contains("invalid")) {
call.response.header(HttpHeaders.WWWAuthenticate, "Bearer realm=\"TestServer\"")
call.respond(HttpStatusCode.Unauthorized)
val status = call.request.queryParameters["status"]?.toIntOrNull() ?: 401
wkornewald marked this conversation as resolved.
Show resolved Hide resolved
call.respond(HttpStatusCode.fromValue(status))
return@get
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
public final class io/ktor/client/plugins/auth/AuthConfig {
public fun <init> ()V
public final fun getProviders ()Ljava/util/List;
public final fun isUnauthorizedResponse ()Lkotlin/jvm/functions/Function2;
public final fun reAuthorizeOnResponse (Lkotlin/jvm/functions/Function2;)V
}

public final class io/ktor/client/plugins/auth/AuthKt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ final class io.ktor.client.plugins.auth/AuthConfig { // io.ktor.client.plugins.a

final val providers // io.ktor.client.plugins.auth/AuthConfig.providers|{}providers[0]
final fun <get-providers>(): kotlin.collections/MutableList<io.ktor.client.plugins.auth/AuthProvider> // io.ktor.client.plugins.auth/AuthConfig.providers.<get-providers>|<get-providers>(){}[0]

final var isUnauthorizedResponse // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse|{}isUnauthorizedResponse[0]
final fun <get-isUnauthorizedResponse>(): kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean> // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse.<get-isUnauthorizedResponse>|<get-isUnauthorizedResponse>(){}[0]

final fun reAuthorizeOnResponse(kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean>) // io.ktor.client.plugins.auth/AuthConfig.reAuthorizeOnResponse|reAuthorizeOnResponse(kotlin.coroutines.SuspendFunction1<io.ktor.client.statement.HttpResponse,kotlin.Boolean>){}[0]
}

final val io.ktor.client.plugins.auth/Auth // io.ktor.client.plugins.auth/Auth|{}Auth[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ package io.ktor.client.plugins.auth

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.api.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.auth.*
import io.ktor.util.*
Expand All @@ -23,9 +23,34 @@ private class AtomicCounter {
val atomic = atomic(0)
}

/**
* Configuration used by [Auth] plugin.
*/
@KtorDsl
public class AuthConfig {
/**
* [AuthProvider] list to use.
*/
public val providers: MutableList<AuthProvider> = mutableListOf()

/**
* The currently set function to control whether a response is unauthorized and should trigger a refresh / re-auth.
*
* By default checks against HTTP status 401.
*
* You can set this value via [reAuthorizeOnResponse].
*/
public var isUnauthorizedResponse: suspend (HttpResponse) -> Boolean = { it.status == HttpStatusCode.Unauthorized }
private set

/**
* Sets a custom function to control whether a response is unauthorized and should trigger a refresh / re-auth.
*
* Use this to change the value of [isUnauthorizedResponse].
*/
public fun reAuthorizeOnResponse(block: suspend (HttpResponse) -> Boolean) {
isUnauthorizedResponse = block
}
}

/**
Expand All @@ -39,7 +64,7 @@ public val AuthCircuitBreaker: AttributeKey<Unit> = AttributeKey("auth-request")
*
* You can learn more from [Authentication and authorization](https://ktor.io/docs/auth.html).
*
* [providers] - list of auth providers to use.
* @see [AuthConfig] for configuration options.
*/
public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthConfig) {
val providers = pluginConfig.providers.toList()
Expand All @@ -64,10 +89,10 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
}

authHeaders.isEmpty() -> {
LOGGER.trace(
"401 response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
LOGGER.trace {
"Unauthorized response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
"Can not add or refresh token"
)
}
null
}

Expand All @@ -88,9 +113,9 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
val requestTokenVersion = requestTokenVersions[provider]

if (requestTokenVersion != null && requestTokenVersion >= tokenVersion.atomic.value) {
LOGGER.trace("Refreshing token for ${call.request.url}")
LOGGER.trace { "Refreshing token for ${call.request.url}" }
if (!provider.refreshToken(call.response)) {
LOGGER.trace("Refreshing token failed for ${call.request.url}")
LOGGER.trace { "Refreshing token failed for ${call.request.url}" }
return false
} else {
requestTokenVersions[provider] = tokenVersion.atomic.incrementAndGet()
Expand All @@ -111,13 +136,13 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
provider.addRequestHeaders(request, authHeader)
request.attributes.put(AuthCircuitBreaker, Unit)

LOGGER.trace("Sending new request to ${call.request.url}")
LOGGER.trace { "Sending new request to ${call.request.url}" }
return proceed(request)
}

onRequest { request, _ ->
providers.filter { it.sendWithoutRequest(request) }.forEach { provider ->
LOGGER.trace("Adding auth headers for ${request.url} from provider $provider")
LOGGER.trace { "Adding auth headers for ${request.url} from provider $provider" }
val tokenVersion = tokenVersions.computeIfAbsent(provider) { AtomicCounter() }
val requestTokenVersions = request.attributes
.computeIfAbsent(tokenVersionsAttributeKey) { mutableMapOf() }
Expand All @@ -128,22 +153,22 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon

on(Send) { originalRequest ->
val origin = proceed(originalRequest)
if (origin.response.status != HttpStatusCode.Unauthorized) return@on origin
if (!pluginConfig.isUnauthorizedResponse(origin.response)) return@on origin
if (origin.request.attributes.contains(AuthCircuitBreaker)) return@on origin

var call = origin

val candidateProviders = HashSet(providers)

while (call.response.status == HttpStatusCode.Unauthorized) {
LOGGER.trace("Received 401 for ${call.request.url}")
while (pluginConfig.isUnauthorizedResponse(call.response)) {
LOGGER.trace { "Unauthorized response for ${call.request.url}" }

val (provider, authHeader) = findProvider(call, candidateProviders) ?: run {
LOGGER.trace("Can not find auth provider for ${call.request.url}")
LOGGER.trace { "Can not find auth provider for ${call.request.url}" }
return@on call
}

LOGGER.trace("Using provider $provider for ${call.request.url}")
LOGGER.trace { "Using provider $provider for ${call.request.url}" }

candidateProviders.remove(provider)
if (!refreshTokenIfNeeded(call, provider, originalRequest)) return@on call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,27 @@ class AuthTest : ClientLoader() {
}
}

@Test
fun testForbiddenBearerAuthWithInvalidAccessAndValidRefreshTokens() = clientTests {
config {
install(Auth) {
reAuthorizeOnResponse { it.status == HttpStatusCode.Forbidden }
bearer {
refreshTokens { BearerTokens("valid", "refresh") }
loadTokens { BearerTokens("invalid", "refresh") }
}
}

expectSuccess = false
}

test { client ->
client.prepareGet("$TEST_SERVER/auth/bearer/test-refresh?status=403").execute {
assertEquals(HttpStatusCode.OK, it.status)
}
}
}

// The return of refreshTokenFun is null, cause it should not be called at all, if loadTokensFun returns valid tokens
@Test
fun testUnauthorizedBearerAuthWithValidAccessTokenAndInvalidRefreshToken() = clientTests {
Expand Down