diff --git a/kong/plugins/cors/handler.lua b/kong/plugins/cors/handler.lua index 258078d8bd3..cc97e3bb29e 100644 --- a/kong/plugins/cors/handler.lua +++ b/kong/plugins/cors/handler.lua @@ -5,6 +5,7 @@ local url = require "socket.url" local kong = kong local re_find = ngx.re.find +local find = string.find local concat = table.concat local tostring = tostring local ipairs = ipairs @@ -32,21 +33,24 @@ local normalized_req_domains = lrucache.new(10e3) local function normalize_origin(domain) local parsed_obj = assert(url.parse(domain)) if not parsed_obj.host then - return domain + return { + domain = domain, + host = domain, + } end local port = parsed_obj.port - if not port and parsed_obj.scheme then - if parsed_obj.scheme == "http" then - port = 80 - - elseif parsed_obj.scheme == "https" then - port = 443 - end + if (parsed_obj.scheme == "http" and port == "80") + or (parsed_obj.scheme == "https" and port == "443") then + port = nil end - return (parsed_obj.scheme and parsed_obj.scheme .. "://" or "") .. - parsed_obj.host .. (port and ":" .. port or "") + return { + domain = (parsed_obj.scheme and parsed_obj.scheme .. "://" or "") .. + parsed_obj.host .. + (port and ":" .. port or ""), + host = parsed_obj.host, + } end @@ -83,15 +87,40 @@ local function configure_origin(conf) local req_origin = kong.request.get_header("origin") if req_origin then - local normalized_domains = config_cache[conf] - if not normalized_domains then - normalized_domains = {} - - for _, domain in ipairs(conf.origins) do - table.insert(normalized_domains, normalize_origin(domain)) + local cached_domains = config_cache[conf] + if not cached_domains then + cached_domains = {} + + for _, entry in ipairs(conf.origins) do + local domain + local maybe_regex, _, err = re_find(entry, "[^A-Za-z0-9.:/-]", "jo") + if err then + kong.log.err("could not inspect origin for type: ", err) + end + + if maybe_regex then + -- Kong 0.x did not anchor regexes: + -- Perform adjustments to support regexes + -- explicitly anchored by the user. + if entry:sub(-1) ~= "$" then + entry = entry .. "$" + end + + if entry:sub(1, 1) == "^" then + entry = entry:sub(2) + end + + domain = { regex = entry } + + else + domain = normalize_origin(entry) + end + + domain.by_host = not find(entry, ":", 1, true) + table.insert(cached_domains, domain) end - config_cache[conf] = normalized_domains + config_cache[conf] = cached_domains end local normalized_req_origin = normalized_req_domains:get(req_origin) @@ -100,20 +129,31 @@ local function configure_origin(conf) normalized_req_domains:set(req_origin, normalized_req_origin) end - for _, normalized_domain in ipairs(normalized_domains) do - local from, _, err = re_find(normalized_req_origin, - normalized_domain .. "$", "ajo") - if err then - kong.log.err("could not search for domain: ", err) + for _, cached_domain in ipairs(cached_domains) do + local found, _, err + + if cached_domain.regex then + local subject = cached_domain.by_host + and normalized_req_origin.host + or normalized_req_origin.domain + + found, _, err = re_find(subject, cached_domain.regex, "ajo") + if err then + kong.log.err("could not search for domain: ", err) + end + + else + found = (normalized_req_origin.domain == cached_domain.domain) end - if from then - set_header("Access-Control-Allow-Origin", req_origin) + if found then + set_header("Access-Control-Allow-Origin", normalized_req_origin.domain) set_header("Vary", "Origin") return false end end end + return false end diff --git a/spec/03-plugins/13-cors/01-access_spec.lua b/spec/03-plugins/13-cors/01-access_spec.lua index 16ce13fe164..cbe4d7c029c 100644 --- a/spec/03-plugins/13-cors/01-access_spec.lua +++ b/spec/03-plugins/13-cors/01-access_spec.lua @@ -1,10 +1,234 @@ local helpers = require "spec.helpers" local cjson = require "cjson" +local inspect = require "inspect" +local tablex = require "pl.tablex" + + +local function sortedpairs(t) + local ks = tablex.keys(t) + table.sort(ks) + local i = 0 + return function() + i = i + 1 + return ks[i], t[ks[i]] + end +end + for _, strategy in helpers.each_strategy() do describe("Plugin: cors (access) [#" .. strategy .. "]", function() local proxy_client + local regex_testcases = { + { + -- single entry, host only: ignore value, always return configured data + origins = { "foo.test" }, + tests = { + ["http://evil.test"] = "foo.test", + ["http://foo.test"] = "foo.test", + ["http://foo.test.evil.test"] = "foo.test", + ["http://something.foo.test"] = "foo.test", + ["http://evilfoo.test"] = "foo.test", + ["http://foo.test:80"] = "foo.test", + ["http://foo.test:8000"] = "foo.test", + ["https://foo.test:8000"] = "foo.test", + ["http://foo.test:90"] = "foo.test", + ["http://foobtest"] = "foo.test", + ["https://bar.test:1234"] = "foo.test", + }, + }, + { + -- single entry, full domain (not regex): ignore value, always return configured data + origins = { "https://bar.test:1234" }, + tests = { + ["http://evil.test"] = "https://bar.test:1234", + ["http://foo.test"] = "https://bar.test:1234", + ["http://foo.test.evil.test"] = "https://bar.test:1234", + ["http://something.foo.test"] = "https://bar.test:1234", + ["http://evilfoo.test"] = "https://bar.test:1234", + ["http://foo.test:80"] = "https://bar.test:1234", + ["http://foo.test:8000"] = "https://bar.test:1234", + ["https://foo.test:8000"] = "https://bar.test:1234", + ["http://foo.test:90"] = "https://bar.test:1234", + ["http://foobtest"] = "https://bar.test:1234", + ["https://bar.test:1234"] = "https://bar.test:1234", + }, + }, + { + -- single entry, simple regex without ":": anchored match on host only + origins = { "foo\\.test" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = true, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = false, + ["http://evilfoo.test"] = false, + ["http://foo.test:80"] = "http://foo.test", + ["http://foo.test:8000"] = true, + ["https://foo.test:8000"] = true, + ["http://foo.test:90"] = true, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- single entry, subdomain regex without ":": anchored match on host only + origins = { "(.*[./])?foo\\.test" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = true, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = true, + ["http://evilfoo.test"] = false, + ["http://foo.test:80"] = "http://foo.test", + ["http://foo.test:8000"] = true, + ["https://foo.test:8000"] = true, + ["http://foo.test:90"] = true, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- single entry, any-scheme subdomain regex with port: anchored match with scheme and port + origins = { "(.*[./])?foo\\.test:8000" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = false, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = false, + ["http://evilfoo.test"] = false, + ["http://foo.test:80"] = false, + ["http://foo.test:8000"] = true, + ["https://foo.test:8000"] = true, + ["http://foo.test:90"] = false, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- single entry, https subdomain regex with port: anchored match with scheme and port + origins = { "https://(.*[.])?foo\\.test:8000" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = false, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = false, + ["http://foo.test:80"] = false, + ["http://foo.test:8000"] = false, + ["https://foo.test:8000"] = true, + ["http://foo.test:90"] = false, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- single entry, explicitly anchored https subdomain regex with port: anchored match with scheme and port + origins = { "^http://(.*[.])?foo\\.test(:(80|90))?$" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = true, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = true, + ["http://foo.test:80"] = "http://foo.test", + ["http://foo.test:8000"] = false, + ["https://foo.test:8000"] = false, + ["http://foo.test:90"] = true, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- multiple entries, host only (not regex): match on full normalized domain (i.e. all fail) + origins = { "foo.test", "bar.test" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = false, + ["http://foo.test.evil.test"] = false, + ["http://foo.test:80"] = false, + ["http://foo.test:8000"] = false, + ["http://foo.test:90"] = false, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- multiple entries, full domain (not regex): match on full normalized domain + origins = { "http://foo.test", "https://bar.test:1234" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = true, + ["http://foo.test.evil.test"] = false, + ["http://foo.test:80"] = "http://foo.test", + ["http://foo.test:8000"] = false, + ["http://foo.test:90"] = false, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = true, + }, + }, + { + -- multiple entries, simple regex without ":": anchored match on host only + origins = { "bar.test", "foo\\.test" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = true, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = false, + ["http://foo.test:80"] = "http://foo.test", + ["http://foo.test:8000"] = true, + ["http://foo.test:90"] = true, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- multiple entries, subdomain regex without ":": anchored match on host only + origins = { "bar.test", "(.*\\.)?foo\\.test" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = true, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = true, + ["http://foo.test:80"] = "http://foo.test", + ["http://foo.test:8000"] = true, + ["http://foo.test:90"] = true, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- multiple entries, any-scheme subdomain regex with ":": anchored match with scheme and port + origins = { "bar.test", "(.*[./])?foo\\.test:8000" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = false, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = false, + ["http://foo.test:80"] = false, + ["http://foo.test:8000"] = true, + ["https://foo.test:8000"] = true, + ["http://foo.test:90"] = false, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + { + -- multiple entries, https subdomain regex with ":": anchored match with scheme and port + origins = { "bar.test", "https://(.*\\.)?foo\\.test:8000" }, + tests = { + ["http://evil.test"] = false, + ["http://foo.test"] = false, + ["http://foo.test.evil.test"] = false, + ["http://something.foo.test"] = false, + ["http://foo.test:80"] = false, + ["http://foo.test:8000"] = false, + ["https://foo.test:8000"] = true, + ["http://foo.test:90"] = false, + ["http://foobtest"] = false, + ["https://bar.test:1234"] = false, + }, + }, + } + lazy_setup(function() local bp = helpers.get_db_utils(strategy, nil, { "error-generator-last" }) @@ -204,20 +428,73 @@ for _, strategy in helpers.each_strategy() do }, } + for i, testcase in ipairs(regex_testcases) do + local route = bp.routes:insert({ + hosts = { "cors-regex-" .. i .. ".test" }, + }) + + bp.plugins:insert { + name = "cors", + route = { id = route.id }, + config = { + origins = testcase.origins, + } + } + end + assert(helpers.start_kong({ database = strategy, nginx_conf = "spec/fixtures/custom_nginx.template", })) - proxy_client = helpers.proxy_client() end) lazy_teardown(function() - if proxy_client then proxy_client:close() end helpers.stop_kong() end) + before_each(function() + proxy_client = helpers.proxy_client() + end) + + after_each(function() + if proxy_client then proxy_client:close() end + end) + describe("HTTP method: OPTIONS", function() + + for i, testcase in ipairs(regex_testcases) do + local host = "cors-regex-" .. i .. ".test" + for origin, accept in sortedpairs(testcase.tests) do + it("given " .. origin .. ", " .. + inspect(testcase.origins) .. " will " .. + (accept and "accept" or "reject"), function() + + local res = assert(proxy_client:send { + method = "OPTIONS", + headers = { + ["Host"] = host, + ["Origin"] = origin, + } + }) + + assert.res_status(200, res) + + if accept then + assert.equal("GET,HEAD,PUT,PATCH,POST,DELETE", res.headers["Access-Control-Allow-Methods"]) + assert.equal(accept == true and origin or accept, res.headers["Access-Control-Allow-Origin"]) + assert.is_nil(res.headers["Access-Control-Allow-Headers"]) + assert.is_nil(res.headers["Access-Control-Expose-Headers"]) + assert.is_nil(res.headers["Access-Control-Allow-Credentials"]) + assert.is_nil(res.headers["Access-Control-Max-Age"]) + + else + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + end + end) + end + end + it("gives appropriate defaults", function() local res = assert(proxy_client:send { method = "OPTIONS", @@ -543,7 +820,7 @@ for _, strategy in helpers.each_strategy() do } }) assert.res_status(200, res) - assert.equals("http://my-site.com:80", res.headers["Access-Control-Allow-Origin"]) + assert.equals("http://my-site.com", res.headers["Access-Control-Allow-Origin"]) local res = assert(proxy_client:send { method = "GET",