From e7c6afa602b6cb742b780e2706f8d9ba34970ae6 Mon Sep 17 00:00:00 2001 From: thefosk Date: Wed, 6 Apr 2016 17:30:54 -0700 Subject: [PATCH] LDAP_AUTH (plugin) LDAP Authentication Plugin Plugin authenticates user against LDAP server --- kong-0.8.0rc2-0.rockspec | 10 +- kong/constants.lua | 2 +- kong/plugins/ldap-auth/access.lua | 114 ++++++++ kong/plugins/ldap-auth/asn1.lua | 360 +++++++++++++++++++++++++ kong/plugins/ldap-auth/handler.lua | 17 ++ kong/plugins/ldap-auth/ldap.lua | 144 ++++++++++ kong/plugins/ldap-auth/schema.lua | 14 + kong/tools/database_cache.lua | 7 +- spec/plugins/ldap-auth/access_spec.lua | 127 +++++++++ 9 files changed, 792 insertions(+), 3 deletions(-) create mode 100644 kong/plugins/ldap-auth/access.lua create mode 100644 kong/plugins/ldap-auth/asn1.lua create mode 100644 kong/plugins/ldap-auth/handler.lua create mode 100644 kong/plugins/ldap-auth/ldap.lua create mode 100644 kong/plugins/ldap-auth/schema.lua create mode 100644 spec/plugins/ldap-auth/access_spec.lua diff --git a/kong-0.8.0rc2-0.rockspec b/kong-0.8.0rc2-0.rockspec index a7e1b08fe06..2bb7322e9ec 100644 --- a/kong-0.8.0rc2-0.rockspec +++ b/kong-0.8.0rc2-0.rockspec @@ -13,6 +13,7 @@ description = { dependencies = { "luasec ~> 0.5-2", + "penlight ~> 1.3.2", "lua_uuid ~> 0.2.0-2", "lua_system_constants ~> 0.1.1-0", "luatz ~> 0.3-1", @@ -33,7 +34,8 @@ dependencies = { "lrexlib-pcre ~> 2.7.2-1", "lua-llthreads2 ~> 0.1.3-1", "luacrypto >= 0.3.2-1", - "luasyslog >= 1.0.0-2" + "luasyslog >= 1.0.0-2", + "lua_pack ~> 1.0.4-0" } build = { type = "builtin", @@ -247,6 +249,12 @@ build = { ["kong.plugins.hmac-auth.api"] = "kong/plugins/hmac-auth/api.lua", ["kong.plugins.hmac-auth.daos"] = "kong/plugins/hmac-auth/daos.lua", + ["kong.plugins.ldap-auth.handler"] = "kong/plugins/ldap-auth/handler.lua", + ["kong.plugins.ldap-auth.access"] = "kong/plugins/ldap-auth/access.lua", + ["kong.plugins.ldap-auth.schema"] = "kong/plugins/ldap-auth/schema.lua", + ["kong.plugins.ldap-auth.ldap"] = "kong/plugins/ldap-auth/ldap.lua", + ["kong.plugins.ldap-auth.asn1"] = "kong/plugins/ldap-auth/asn1.lua", + ["kong.plugins.syslog.handler"] = "kong/plugins/syslog/handler.lua", ["kong.plugins.syslog.schema"] = "kong/plugins/syslog/schema.lua", diff --git a/kong/constants.lua b/kong/constants.lua index 0b1af70b5b0..166494364bc 100644 --- a/kong/constants.lua +++ b/kong/constants.lua @@ -18,7 +18,7 @@ return { "http-log", "key-auth", "hmac-auth", "basic-auth", "ip-restriction", "mashape-analytics", "request-transformer", "response-transformer", "request-size-limiting", "rate-limiting", "response-ratelimiting", "syslog", - "loggly", "datadog", "runscope" + "loggly", "datadog", "runscope", "ldap-auth" }, -- Non standard headers, specific to Kong HEADERS = { diff --git a/kong/plugins/ldap-auth/access.lua b/kong/plugins/ldap-auth/access.lua new file mode 100644 index 00000000000..4a75c8aa197 --- /dev/null +++ b/kong/plugins/ldap-auth/access.lua @@ -0,0 +1,114 @@ +local responses = require "kong.tools.responses" +local constants = require "kong.constants" +local cache = require "kong.tools.database_cache" +local base64 = require "base64" +local ldap = require "kong.plugins.ldap-auth.ldap" + +local match = string.match +local ngx_log = ngx.log +local request = ngx.req +local ngx_error = ngx.ERR +local ngx_debug = ngx.DEBUG +local ngx_socket_tcp = ngx.socket.tcp +local tostring = tostring + +local AUTHORIZATION = "authorization" +local PROXY_AUTHORIZATION = "proxy-authorization" + +local _M = {} + +local function retrieve_credentials(authorization_header_value, conf) + local username, password + if authorization_header_value then + local cred = match(authorization_header_value, "%s*[ldap|LDAP]%s+(.*)") + + if cred ~= nil then + local decoded_cred = base64.decode(cred) + username, password = match(decoded_cred, "(.+):(.+)") + end + end + return username, password +end + +local function ldap_authenticate(given_username, given_password, conf) + local is_authenticated + local error, suppressed_err, ok + local who = conf.attribute.."="..given_username..","..conf.base_dn + + local sock = ngx_socket_tcp() + sock:settimeout(conf.timeout) + ok, error = sock:connect(conf.ldap_host, conf.ldap_port) + if not ok then + ngx_log(ngx_error, "[ldap-auth] failed to connect to "..conf.ldap_host..":"..tostring(conf.ldap_port)..": ", error) + return responses.send_HTTP_INTERNAL_SERVER_ERROR(error) + end + + if conf.start_tls then + local success, error = ldap.start_tls(sock) + if not success then + return false, error + end + local _, error = sock:sslhandshake(true, conf.ldap_host, conf.verify_ldap_host) + if error ~= nil then + return false, "failed to do SSL handshake with "..conf.ldap_host..":"..tostring(conf.ldap_port)..": ".. error + end + end + + is_authenticated, error = ldap.bind_request(sock, who, given_password) + + ok, suppressed_err = sock:setkeepalive(conf.keepalive) + if not ok then + ngx_log(ngx_error, "[ldap-auth] failed to keepalive to "..conf.ldap_host..":"..tostring(conf.ldap_port)..": ", suppressed_err) + end + return is_authenticated, error +end + +local function authenticate(conf, given_credentials) + local given_username, given_password = retrieve_credentials(given_credentials) + if given_username == nil then + return false + end + + local credential = cache.get_or_set(cache.ldap_credential_key(given_username), function() + ngx_log(ngx_debug, "[ldap-auth] authenticating user against LDAP server: "..conf.ldap_host..":"..conf.ldap_port) + + local ok, err = ldap_authenticate(given_username, given_password, conf) + if err ~= nil then ngx_log(ngx_error, err) end + if not ok then + return nil + end + return {username = given_username, password = given_password} + end, conf.cache_ttl) + + return credential and credential.password == given_password, credential +end + +function _M.execute(conf) + local authorization_value = request.get_headers()[AUTHORIZATION] + local proxy_authorization_value = request.get_headers()[PROXY_AUTHORIZATION] + + -- If both headers are missing, return 401 + if not (authorization_value or proxy_authorization_value) then + ngx.header["WWW-Authenticate"] = 'LDAP realm="kong"' + return responses.send_HTTP_UNAUTHORIZED() + end + + local is_authorized, credential = authenticate(conf, proxy_authorization_value) + if not is_authorized then + is_authorized, credential = authenticate(conf, authorization_value) + end + + if not is_authorized then + return responses.send_HTTP_FORBIDDEN("Invalid authentication credentials") + end + + if conf.hide_credentials then + request.clear_header(AUTHORIZATION) + request.clear_header(PROXY_AUTHORIZATION) + end + + request.set_header(constants.HEADERS.CREDENTIAL_USERNAME, credential.username) + ngx.ctx.authenticated_credential = credential +end + +return _M diff --git a/kong/plugins/ldap-auth/asn1.lua b/kong/plugins/ldap-auth/asn1.lua new file mode 100644 index 00000000000..a9b5a6ed863 --- /dev/null +++ b/kong/plugins/ldap-auth/asn1.lua @@ -0,0 +1,360 @@ +require "lua_pack" + +local bpack = string.pack +local bunpack = string.unpack +local math = math +local bit = bit +local setmetatable = setmetatable +local table = table +local string_reverse = string.reverse +local string_char = string.char + +local _M = {} + +_M.BERCLASS = { + Universal = 0, + Application = 64, + ContextSpecific = 128, + Private = 192 +} + +_M.ASN1Decoder = { + + new = function(self,o) + o = o or {} + setmetatable(o, self) + self.__index = self + o:registerBaseDecoders() + return o + end, + + decode = function(self, encStr, pos) + local etype, elen + local newpos = pos + newpos, etype = bunpack(encStr, "X1", newpos) + newpos, elen = self.decodeLength(encStr, newpos) + if self.decoder[etype] then + return self.decoder[etype](self, encStr, elen, newpos) + else + return newpos, nil + end + end, + + setStopOnError = function(self, val) + self.stoponerror = val + end, + + registerBaseDecoders = function(self) + self.decoder = {} + + self.decoder["0A"] = function(self, encStr, elen, pos) + return self.decodeInt(encStr, elen, pos) + end + + self.decoder["8A"] = function(self, encStr, elen, pos) + return bunpack(encStr, "A" .. elen, pos) + end + + self.decoder["31"] = function(self, encStr, elen, pos) + return pos, nil + end + + -- Boolean + self.decoder["01"] = function(self, encStr, elen, pos) + local val = bunpack(encStr, "X", pos) + if val ~= "FF" then + return pos, true + else + return pos, false + end + end + + -- Integer + self.decoder["02"] = function(self, encStr, elen, pos) + return self.decodeInt(encStr, elen, pos) + end + + -- Octet String + self.decoder["04"] = function(self, encStr, elen, pos) + return bunpack(encStr, "A" .. elen, pos) + end + + -- Null + self.decoder["05"] = function(self, encStr, elen, pos) + return pos, false + end + + -- Object Identifier + self.decoder["06"] = function(self, encStr, elen, pos) + return self:decodeOID(encStr, elen, pos) + end + + -- Context specific tags + self.decoder["30"] = function(self, encStr, elen, pos) + return self:decodeSeq(encStr, elen, pos) + end + end, + + registerTagDecoders = function(self, tagDecoders) + self:registerBaseDecoders() + for k, v in pairs(tagDecoders) do + self.decoder[k] = v + end + end, + + decodeLength = function(encStr, pos) + local elen + pos, elen = bunpack(encStr, 'C', pos) + if (elen > 128) then + elen = elen - 128 + local elenCalc = 0 + local elenNext + for i = 1, elen do + elenCalc = elenCalc * 256 + pos, elenNext = bunpack(encStr, 'C', pos) + elenCalc = elenCalc + elenNext + end + elen = elenCalc + end + return pos, elen + end, + + decodeSeq = function(self, encStr, len, pos) + local seq = {} + local sPos = 1 + local sStr + pos, sStr = bunpack(encStr, "A" .. len, pos) + while (sPos < len) do + local newSeq + sPos, newSeq = self:decode(sStr, sPos) + if (not(newSeq) and self.stoponerror) then break end + table.insert(seq, newSeq) + end + return pos, seq + end, + + decode_oid_component = function(encStr, pos) + local octet + local n = 0 + + repeat + pos, octet = bunpack(encStr, "b", pos) + n = n * 128 + bit.band(0x7F, octet) + until octet < 128 + + return pos, n + end, + + decodeOID = function(self, encStr, len, pos) + local last + local oid = {} + local octet + + last = pos + len - 1 + if pos <= last then + oid._snmp = '06' + pos, octet = bunpack(encStr, "C", pos) + oid[2] = math.fmod(octet, 40) + octet = octet - oid[2] + oid[1] = octet/40 + end + + while pos <= last do + local c + pos, c = self.decode_oid_component(encStr, pos) + oid[#oid + 1] = c + end + + return pos, oid + end, + + decodeInt = function(encStr, len, pos) + local hexStr + pos, hexStr = bunpack(encStr, "X" .. len, pos) + local value = tonumber(hexStr, 16) + if (value >= math.pow(256, len)/2) then + value = value - math.pow(256, len) + end + return pos, value + end +} + +_M.ASN1Encoder = { + + new = function(self) + local o = {} + setmetatable(o, self) + self.__index = self + o:registerBaseEncoders() + return o + end, + + encodeSeq = function(self, seqData) + return bpack('XAA' , '30', self.encodeLength(#seqData), seqData) + end, + + encode = function(self, val) + local vtype = type(val) + + if self.encoder[vtype] then + return self.encoder[vtype](self,val) + end + end, + + registerTagEncoders = function(self, tagEncoders) + self:registerBaseEncoders() + for k, v in pairs(tagEncoders) do + self.encoder[k] = v + end + end, + + registerBaseEncoders = function(self) + self.encoder = {} + + self.encoder['table'] = function(self, val) + if (val._ldap == '0A') then + local ival = self.encodeInt(val[1]) + local len = self.encodeLength(#ival) + return bpack('XAA', '0A', len, ival) + end + if (val._ldaptype) then + local len + if val[1] == nil or #val[1] == 0 then + return bpack('XC', val._ldaptype, 0) + else + len = self.encodeLength(#val[1]) + return bpack('XAA', val._ldaptype, len, val[1]) + end + end + + local encVal = "" + for _, v in ipairs(val) do + encVal = encVal .. self.encode(v) -- todo: buffer? + end + local tableType = "\x30" + if (val["_snmp"]) then + tableType = bpack("X", val["_snmp"]) + end + return bpack('AAA', tableType, self.encodeLength(#encVal), encVal) + end + + -- Boolean encoder + self.encoder['boolean'] = function(self, val) + if val then + return bpack('X','01 01 FF') + else + return bpack('X', '01 01 00') + end + end + + -- Integer encoder + self.encoder['number'] = function(self, val) + local ival = self.encodeInt(val) + local len = self.encodeLength(#ival) + return bpack('XAA', '02', len, ival) + end + + -- Octet String encoder + self.encoder['string'] = function(self, val) + local len = self.encodeLength(#val) + return bpack('XAA', '04', len, val) + end + + -- Null encoder + self.encoder['nil'] = function(self, val) + return bpack('X', '05 00') + end + + end, + + encode_oid_component = function(n) + local parts = {} + parts[1] = string_char(bit.mod(n, 128)) + while n >= 128 do + n = bit.rshift(n, 7) + parts[#parts + 1] = string_char(bit.mod(n, 128) + 0x80) + end + return string_reverse(table.concat(parts)) + end, + + encodeInt = function(val) + local lsb = 0 + if val > 0 then + local valStr = "" + while (val > 0) do + lsb = math.fmod(val, 256) + valStr = valStr .. bpack('C', lsb) + val = math.floor(val/256) + end + if lsb > 127 then + valStr = valStr .. "\0" + end + + return string_reverse(valStr) + elseif val < 0 then + local i = 1 + local tcval = val + 256 + while tcval <= 127 do + tcval = tcval + (math.pow(256, i) * 255) + i = i+1 + end + local valStr = "" + while (tcval > 0) do + lsb = math.fmod(tcval, 256) + valStr = valStr .. bpack("C", lsb) + tcval = math.floor(tcval/256) + end + return string_reverse(valStr) + else -- val == 0 + return bpack("x") + end + end, + + encodeLength = function(len) + if len < 128 then + return string_char(len) + else + local parts = {} + + while len > 0 do + parts[#parts + 1] = string_char(bit.mod(len, 256)) + len = bit.rshift(len, 8) + end + + return string_char(#parts + 0x80) .. string_reverse(table.concat(parts)) + end + end +} + +function _M.BERtoInt(class, constructed, number) + local asn1_type = class + number + + if constructed == true then + asn1_type = asn1_type + 32 + end + + return asn1_type +end + +function _M.intToBER(i) + local ber = {} + if bit.band(i, _M.BERCLASS.Application) == _M.BERCLASS.Application then + ber.class = _M.BERCLASS.Application + elseif bit.band(i, _M.BERCLASS.ContextSpecific) == _M.BERCLASS.ContextSpecific then + ber.class = _M.BERCLASS.ContextSpecific + elseif bit.band(i, _M.BERCLASS.Private) == _M.BERCLASS.Private then + ber.class = _M.BERCLASS.Private + else + ber.class = _M.BERCLASS.Universal + end + if bit.band(i, 32) == 32 then + ber.constructed = true + ber.number = i - ber.class - 32 + else + ber.primitive = true + ber.number = i - ber.class + end + return ber +end + +return _M diff --git a/kong/plugins/ldap-auth/handler.lua b/kong/plugins/ldap-auth/handler.lua new file mode 100644 index 00000000000..62129d982f5 --- /dev/null +++ b/kong/plugins/ldap-auth/handler.lua @@ -0,0 +1,17 @@ +local access = require "kong.plugins.ldap-auth.access" +local BasePlugin = require "kong.plugins.base_plugin" + +local LdapAuthHandler = BasePlugin:extend() + +function LdapAuthHandler:new() + LdapAuthHandler.super.new(self, "ldap-auth") +end + +function LdapAuthHandler:access(conf) + LdapAuthHandler.super.access(self) + access.execute(conf) +end + +LdapAuthHandler.PRIORITY = 1000 + +return LdapAuthHandler diff --git a/kong/plugins/ldap-auth/ldap.lua b/kong/plugins/ldap-auth/ldap.lua new file mode 100644 index 00000000000..2fe8cf46362 --- /dev/null +++ b/kong/plugins/ldap-auth/ldap.lua @@ -0,0 +1,144 @@ +local asn1 = require "kong.plugins.ldap-auth.asn1" +local bunpack = string.unpack + +local string_format = string.format + +local _M = {} + +local ldapMessageId = 1 + +local ERROR_MSG = { + [1] = "Initialization of LDAP library failed.", + [4] = "Size limit exceeded.", + [13] = "Confidentiality required", + [32] = "No such object", + [34] = "Invalid DN", + [49] = "The supplied credential is invalid." +} + +local APPNO = { + BindRequest = 0, + BindResponse = 1, + UnbindRequest = 2, + ExtendedRequest = 23, + ExtendedResponse = 24 +} + +local function encodeLDAPOp(encoder, appno, isConstructed, data) + local asn1_type = asn1.BERtoInt(asn1.BERCLASS.Application, isConstructed, appno) + return encoder:encode({ _ldaptype = string_format("%X", asn1_type), data }) +end + +local function claculate_payload_length(encStr, pos, socket) + local elen + pos, elen = bunpack(encStr, 'C', pos) + if (elen > 128) then + elen = elen - 128 + local elenCalc = 0 + local elenNext + for i = 1, elen do + elenCalc = elenCalc * 256 + encStr = encStr..socket:receive(1) + pos, elenNext = bunpack(encStr, 'C', pos) + elenCalc = elenCalc + elenNext + end + elen = elenCalc + end + return pos, elen +end + +function _M.bind_request(socket, username, password) + local encoder = asn1.ASN1Encoder:new() + local decoder = asn1.ASN1Decoder:new() + + local ldapAuth = encoder:encode({ _ldaptype = 80, password }) + local bindReq = encoder:encode(3) .. encoder:encode(username) .. ldapAuth + local ldapMsg = encoder:encode(ldapMessageId) .. encodeLDAPOp(encoder, APPNO.BindRequest, true, bindReq) + local packet + local pos, packet_len, tmp, _ + local response = {} + + packet = encoder:encodeSeq(ldapMsg) + ldapMessageId = ldapMessageId +1 + socket:send(packet) + packet = socket:receive(2) + _, packet_len = claculate_payload_length(packet, 2, socket) + + packet = socket:receive(packet_len) + pos, response.messageID = decoder:decode(packet, 1) + pos, tmp = bunpack(packet, "C", pos) + pos = decoder.decodeLength(packet, pos) + response.protocolOp = asn1.intToBER(tmp) + + if response.protocolOp.number ~= APPNO.BindResponse then + return false, string_format("Received incorrect Op in packet: %d, expected %d", response.protocolOp.number, APPNO.BindResponse) + end + + pos, response.resultCode = decoder:decode(packet, pos) + + if (response.resultCode ~= 0) then + local error_msg + pos, response.matchedDN = decoder:decode(packet, pos) + _, response.errorMessage = decoder:decode(packet, pos) + error_msg = ERROR_MSG[response.resultCode] + return false, string_format("\n Error: %s\n Details: %s", + error_msg or "Unknown error occurred (code: " .. response.resultCode .. + ")", response.errorMessage or "") + else + return true + end +end + + +function _M.unbind_request(socket) + local ldapMsg, packet + local encoder = asn1.ASN1Encoder:new() + + ldapMessageId = ldapMessageId +1 + ldapMsg = encoder:encode(ldapMessageId) .. encodeLDAPOp(encoder, APPNO.UnbindRequest, false, nil) + packet = encoder:encodeSeq(ldapMsg) + socket:send(packet) + return true, "" +end + +function _M.start_tls(socket) + + local ldapMsg, pos, packet, packet_len, tmp, _ + local response = {} + local encoder = asn1.ASN1Encoder:new() + local decoder = asn1.ASN1Decoder:new() + + local method_name = encoder:encode({_ldaptype = 80, "1.3.6.1.4.1.1466.20037"}) + ldapMessageId = ldapMessageId +1 + ldapMsg = encoder:encode(ldapMessageId) .. encodeLDAPOp(encoder, APPNO.ExtendedRequest, true, method_name) + packet = encoder:encodeSeq(ldapMsg) + socket:send(packet) + packet = socket:receive(2) + _, packet_len = claculate_payload_length(packet, 2, socket) + + packet = socket:receive(packet_len) + pos, response.messageID = decoder:decode(packet, 1) + pos, tmp = bunpack(packet, "C", pos) + pos = decoder.decodeLength(packet, pos) + response.protocolOp = asn1.intToBER(tmp) + + if response.protocolOp.number ~= APPNO.ExtendedResponse then + return false, string_format("Received incorrect Op in packet: %d, expected %d", response.protocolOp.number, APPNO.ExtendedResponse) + end + + pos, response.resultCode = decoder:decode(packet, pos) + + if (response.resultCode ~= 0) then + local error_msg + pos, response.matchedDN = decoder:decode(packet, pos) + _, response.errorMessage = decoder:decode(packet, pos) + error_msg = ERROR_MSG[response.resultCode] + return false, string_format("\n Error: %s\n Details: %s", + error_msg or "Unknown error occurred (code: " .. response.resultCode .. + ")", response.errorMessage or "") + else + return true + end +end + +return _M; diff --git a/kong/plugins/ldap-auth/schema.lua b/kong/plugins/ldap-auth/schema.lua new file mode 100644 index 00000000000..940477d95f6 --- /dev/null +++ b/kong/plugins/ldap-auth/schema.lua @@ -0,0 +1,14 @@ +return { +fields = { + ldap_host = {required = true, type = "string"}, + ldap_port = {required = true, type = "number"}, + start_tls = {required = true, type = "boolean", default = false}, + verify_ldap_host = {required = true, type = "boolean", default = false}, + base_dn = {required = true, type = "string"}, + attribute = {required = true, type = "string"}, + cache_ttl = {required = true, type = "number", default = 60}, + hide_credentials = {type = "boolean", default = false}, + timeout = {type = "number", default = 10000}, + keepalive = {type = "number", default = 60000}, + } +} diff --git a/kong/tools/database_cache.lua b/kong/tools/database_cache.lua index 01e94e55821..7e77d238e3d 100644 --- a/kong/tools/database_cache.lua +++ b/kong/tools/database_cache.lua @@ -16,7 +16,8 @@ local CACHE_KEYS = { REQUESTS = "requests", AUTOJOIN_RETRIES = "autojoin_retries", TIMERS = "timers", - ALL_APIS_BY_DIC = "ALL_APIS_BY_DIC" + ALL_APIS_BY_DIC = "ALL_APIS_BY_DIC", + LDAP_CREDENTIAL = "ldap_credentials" } local _M = {} @@ -102,6 +103,10 @@ function _M.jwtauth_credential_key(secret) return CACHE_KEYS.JWTAUTH_CREDENTIAL..":"..secret end +function _M.ldap_credential_key(username) + return CACHE_KEYS.LDAP_CREDENTIAL.."/"..username +end + function _M.acls_key(consumer_id) return CACHE_KEYS.ACLS..":"..consumer_id end diff --git a/spec/plugins/ldap-auth/access_spec.lua b/spec/plugins/ldap-auth/access_spec.lua new file mode 100644 index 00000000000..c5bf60b209d --- /dev/null +++ b/spec/plugins/ldap-auth/access_spec.lua @@ -0,0 +1,127 @@ +local spec_helper = require "spec.spec_helpers" +local http_client = require "kong.tools.http_client" +local cjson = require "cjson" +local base64 = require "base64" +local cache = require "kong.tools.database_cache" + +local PROXY_URL = spec_helper.PROXY_URL +local API_URL = spec_helper.API_URL + +describe("LDAP-AUTH Plugin", function() + setup(function() + spec_helper.prepare_db() + spec_helper.insert_fixtures { + api = { + {name = "test-ldap", request_host = "ldap.com", upstream_url = "http://mockbin.com"}, + {name = "test-ldap2", request_host = "ldap2.com", upstream_url = "http://mockbin.com"} + }, + plugin = { + {name = "ldap-auth", config = {ldap_host = "ldap.forumsys.com", ldap_port = "389", start_tls = false, base_dn = "dc=example,dc=com", attribute = "uid"}, __api = 1}, + {name = "ldap-auth", config = {ldap_host = "ldap.forumsys.com", ldap_port = "389", start_tls = false, base_dn = "dc=example,dc=com", attribute = "uid", hide_credentials = true}, __api = 2}, + } + } + + spec_helper.start_kong() + end) + + teardown(function() + spec_helper.stop_kong() + end) + + describe("ldap-auth", function() + it("should return invalid credentials and www-authenticate header when the credential is missing", function() + local response, status, headers = http_client.get(PROXY_URL.."/get", {}, {host = "ldap.com"}) + assert.equal(401, status) + local body = cjson.decode(response) + assert.equal(headers["www-authenticate"], 'LDAP realm="kong"') + assert.equal("Unauthorized", body.message) + end) + + it("should return invalid credentials when credential value is in wrong format in authorization header", function() + local response, status = http_client.get(PROXY_URL.."/get", {}, {host = "ldap.com", authorization = "abcd"}) + local body = cjson.decode(response) + assert.equal(403, status) + assert.equal("Invalid authentication credentials", body.message) + end) + + it("should return invalid credentials when credential value is in wrong format in proxy-authorization header", function() + local response, status = http_client.get(PROXY_URL.."/get", {}, {host = "ldap.com", ["proxy-authorization"] = "abcd"}) + local body = cjson.decode(response) + assert.equal(403, status) + assert.equal("Invalid authentication credentials", body.message) + end) + + it("should return invalid credentials when credential value is missing in authorization header", function() + local _, status = http_client.get(PROXY_URL.."/get", {}, {host = "ldap.com", authorization = "ldap "}) + assert.equal(403, status) + end) + + it("should pass if credential is valid in post request", function() + local _, status = http_client.post(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "ldap "..base64.encode("einstein:password")}) + assert.equal(200, status) + end) + + it("should pass if credential is valid and starts with space in post request", function() + local _, status = http_client.post(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = " ldap "..base64.encode("einstein:password")}) + assert.equal(200, status) + end) + + it("should pass if signature type indicator is in caps and credential is valid in post request", function() + local _, status = http_client.post(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "LDAP "..base64.encode("einstein:password")}) + assert.equal(200, status) + end) + + it("should pass if credential is valid in get request", function() + local response, status = http_client.get(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "ldap "..base64.encode("einstein:password")}) + assert.equal(200, status) + local parsed_response = cjson.decode(response) + assert.truthy(parsed_response.headers["x-credential-username"]) + assert.equal("einstein", parsed_response.headers["x-credential-username"]) + end) + + it("should not pass if credential does not has password encoded in get request", function() + local _, status = http_client.get(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "ldap "..base64.encode("einstein:")}) + assert.equal(403, status) + end) + + it("should not pass if credential has multiple encoded username or password separated by ':' in get request", function() + local _, status = http_client.get(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "ldap "..base64.encode("einstein:password:another_password")}) + assert.equal(403, status) + end) + + it("should not pass if credential is invalid in get request", function() + local _, status = http_client.get(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "ldap "..base64.encode("einstein:wrong_password")}) + assert.equal(403, status) + end) + + it("should not hide credential sent along with authorization header to upstream server", function() + local response, status = http_client.get(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "ldap "..base64.encode("einstein:password")}) + assert.equal(200, status) + local parsed_response = cjson.decode(response) + assert.equal("ldap "..base64.encode("einstein:password"), parsed_response.headers["authorization"]) + end) + + it("should hide credential sent along with authorization header to upstream server", function() + local response, status = http_client.get(PROXY_URL.."/request", {}, {host = "ldap2.com", authorization = "ldap "..base64.encode("einstein:password")}) + assert.equal(200, status) + local parsed_response = cjson.decode(response) + assert.falsy(parsed_response.headers["authorization"]) + end) + + it("should cache LDAP Auth Credential", function() + local _, status = http_client.get(PROXY_URL.."/request", {}, {host = "ldap.com", authorization = "ldap "..base64.encode("einstein:password")}) + assert.equals(200, status) + + -- Check that cache is populated + local cache_key = cache.ldap_credential_key("einstein") + local exists = true + while(exists) do + local _, status = http_client.get(API_URL.."/cache/"..cache_key) + if status ~= 200 then + exists = false + end + end + assert.equals(200, status) + end) + end) +end)