diff --git a/kong/constants.lua b/kong/constants.lua index 364f29302c1..57d66b777e1 100644 --- a/kong/constants.lua +++ b/kong/constants.lua @@ -32,11 +32,14 @@ return { PROXY_TIME = "X-Kong-Proxy-Time", API_TIME = "X-Kong-Api-Time", CONSUMER_ID = "X-Consumer-ID", + CONSUMER_CUSTOM_ID = "X-Consumer-Custom-ID", + CONSUMER_USERNAME = "X-Consumer-Username", RATELIMIT_LIMIT = "X-RateLimit-Limit", RATELIMIT_REMAINING = "X-RateLimit-Remaining" }, CACHE = { APIS = "apis", + CONSUMERS = "consumers", PLUGINS_CONFIGURATIONS = "plugins_configurations", BASICAUTH_CREDENTIAL = "basicauth_credentials", KEYAUTH_CREDENTIAL = "keyauth_credentials", diff --git a/kong/plugins/basicauth/access.lua b/kong/plugins/basicauth/access.lua index b6c4604d6af..cb7c8d21a4c 100644 --- a/kong/plugins/basicauth/access.lua +++ b/kong/plugins/basicauth/access.lua @@ -85,7 +85,18 @@ function _M.execute(conf) return responses.send_HTTP_FORBIDDEN("Invalid authentication credentials") end - ngx.req.set_header(constants.HEADERS.CONSUMER_ID, credential.consumer_id) + -- Retrieve consumer + local consumer = cache.get_and_set(cache.consumer_key(credential.consumer_id), function() + local result, err = dao.consumers:find_one(credential.consumer_id) + if err then + return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) + end + return result + end) + + ngx.req.set_header(constants.HEADERS.CONSUMER_ID, consumer.id) + ngx.req.set_header(constants.HEADERS.CONSUMER_CUSTOM_ID, consumer.custom_id) + ngx.req.set_header(constants.HEADERS.CONSUMER_USERNAME, consumer.username) ngx.ctx.authenticated_entity = credential end diff --git a/kong/plugins/keyauth/access.lua b/kong/plugins/keyauth/access.lua index b8197c08678..ae054642dcb 100644 --- a/kong/plugins/keyauth/access.lua +++ b/kong/plugins/keyauth/access.lua @@ -139,7 +139,18 @@ function _M.execute(conf) return responses.send_HTTP_FORBIDDEN("Invalid authentication credentials") end - ngx.req.set_header(constants.HEADERS.CONSUMER_ID, credential.consumer_id) + -- Retrieve consumer + local consumer = cache.get_and_set(cache.consumer_key(credential.consumer_id), function() + local result, err = dao.consumers:find_one(credential.consumer_id) + if err then + return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) + end + return result + end) + + ngx.req.set_header(constants.HEADERS.CONSUMER_ID, consumer.id) + ngx.req.set_header(constants.HEADERS.CONSUMER_CUSTOM_ID, consumer.custom_id) + ngx.req.set_header(constants.HEADERS.CONSUMER_USERNAME, consumer.username) ngx.ctx.authenticated_entity = credential end diff --git a/kong/tools/database_cache.lua b/kong/tools/database_cache.lua index 9777d5cec94..721e2af2b6f 100644 --- a/kong/tools/database_cache.lua +++ b/kong/tools/database_cache.lua @@ -44,6 +44,10 @@ function _M.api_key(host) return constants.CACHE.APIS.."/"..host end +function _M.consumer_key(id) + return constants.CACHE.CONSUMERS.."/"..id +end + function _M.plugin_configuration_key(name, api_id, consumer_id) return constants.CACHE.PLUGINS_CONFIGURATIONS.."/"..name.."/"..api_id..(consumer_id and "/"..consumer_id or "") end diff --git a/spec/plugins/authentication_spec.lua b/spec/plugins/authentication_spec.lua index 607070606b3..d96aaffa708 100644 --- a/spec/plugins/authentication_spec.lua +++ b/spec/plugins/authentication_spec.lua @@ -16,7 +16,7 @@ describe("Authentication Plugin", function() { name = "tests auth 3", public_dns = "test3.com", target_url = "http://mockbin.com" } }, consumer = { - { username = "auth_tests_consuser" } + { username = "auth_tests_consumer" } }, plugin_configuration = { { name = "keyauth", value = { key_names = { "apikey" }}, __api = 1 }, @@ -89,6 +89,15 @@ describe("Authentication Plugin", function() assert.are.equal("Invalid authentication credentials", body.message) end) + it("should set right headers", function() + local response, status = http_client.post(STUB_POST_URL, {apikey = "apikey123"}, {host = "test1.com"}) + assert.are.equal(200, status) + local parsed_response = cjson.decode(response) + assert.truthy(parsed_response.headers["x-consumer-id"]) + assert.truthy(parsed_response.headers["x-consumer-username"]) + assert.are.equal("auth_tests_consumer", parsed_response.headers["x-consumer-username"]) + end) + describe("Hide credentials", function() it("should pass with POST and hide credentials", function() @@ -206,5 +215,14 @@ describe("Authentication Plugin", function() assert.are.equal("Basic dXNlcm5hbWU6cGFzc3dvcmQ=", parsed_response.headers.authorization) end) + it("should set right headers", function() + local response, status = http_client.post(STUB_POST_URL, {}, {host = "test2.com", authorization = "Basic dXNlcm5hbWU6cGFzc3dvcmQ="}) + assert.are.equal(200, status) + local parsed_response = cjson.decode(response) + assert.truthy(parsed_response.headers["x-consumer-id"]) + assert.truthy(parsed_response.headers["x-consumer-username"]) + assert.are.equal("auth_tests_consumer", parsed_response.headers["x-consumer-username"]) + end) + end) end)