diff --git a/kong/plugins/cors/handler.lua b/kong/plugins/cors/handler.lua index 60fb84e2ac9..f4b74a2a736 100644 --- a/kong/plugins/cors/handler.lua +++ b/kong/plugins/cors/handler.lua @@ -1,7 +1,8 @@ local BasePlugin = require "kong.plugins.base_plugin" local responses = require "kong.tools.responses" +local lrucache = require "resty.lrucache" - +local url = require "socket.url" local req_get_method = ngx.req.get_method local re_find = ngx.re.find local concat = table.concat @@ -16,6 +17,31 @@ CorsHandler.PRIORITY = 2000 CorsHandler.VERSION = "0.1.0" +-- per-worker cache of parsed origins +local CACHE_SIZE = 10 ^ 4 +local parsed_domains + + +local function parse_origin_domain(domain) + local parsed_obj = url.parse(domain) + if parsed_obj and parsed_obj.host then + local port = parsed_obj.port + if not port and parsed_obj.scheme then + if parsed_obj.scheme == "http" then + port = 80 + elseif parsed_obj.scheme == "https" then + port = 443 + end + end + return (parsed_obj.scheme and parsed_obj.scheme .. "://" or "") .. + parsed_obj.host .. + (port and ":" .. port or "") + else + return domain + end +end + + local function configure_origin(ngx, conf) local n_origins = conf.origins ~= nil and #conf.origins or 0 @@ -50,8 +76,21 @@ local function configure_origin(ngx, conf) local req_origin = ngx.var.http_origin if req_origin then + + local parsed_req_origin = parse_origin_domain(req_origin) + for _, domain in ipairs(conf.origins) do - local from, _, err = re_find(req_origin, domain, "jo") + if not parsed_domains then + parsed_domains = lrucache.new(CACHE_SIZE) + end + + local parsed_domain = parsed_domains:get(domain) + if not parsed_domain then + parsed_domain = parse_origin_domain(domain) + parsed_domains:set(domain, parsed_domain) + end + + local from, _, err = re_find(parsed_req_origin, "^" .. parsed_domain .. "$", "jo") if err then ngx.log(ngx.ERR, "[cors] could not search for domain: ", err) end diff --git a/spec/03-plugins/14-cors/01-access_spec.lua b/spec/03-plugins/14-cors/01-access_spec.lua index 43c4712bbe3..4bcd663b43a 100644 --- a/spec/03-plugins/14-cors/01-access_spec.lua +++ b/spec/03-plugins/14-cors/01-access_spec.lua @@ -44,6 +44,14 @@ for _, strategy in helpers.each_strategy() do hosts = { "cors9.com" }, }) + local route10 = bp.routes:insert({ + hosts = { "cors10.com" }, + }) + + local route11 = bp.routes:insert({ + hosts = { "cors11.com" }, + }) + bp.plugins:insert { name = "cors", route_id = route1.id, @@ -132,6 +140,22 @@ for _, strategy in helpers.each_strategy() do } } + bp.plugins:insert { + name = "cors", + route_id = route10.id, + config = { + origins = { "http://my-site.com", "http://my-other-site.com" }, + } + } + + bp.plugins:insert { + name = "cors", + route_id = route11.id, + config = { + origins = { "http://my-site.com", "https://my-other-site.com:9000" }, + } + } + assert(helpers.start_kong({ database = strategy, nginx_conf = "spec/fixtures/custom_nginx.template", @@ -245,6 +269,44 @@ for _, strategy in helpers.each_strategy() do assert.res_status(204, res) assert.equal("origin,accepts", res.headers["Access-Control-Allow-Headers"]) end) + + it("properly validates flat strings", function() + -- Legitimate origins + local res = assert(proxy_client:send { + method = "OPTIONS", + headers = { + ["Host"] = "cors10.com", + ["Origin"] = "http://my-site.com" + } + }) + + assert.res_status(204, res) + assert.equal("http://my-site.com", res.headers["Access-Control-Allow-Origin"]) + + -- Illegitimate origins + res = assert(proxy_client:send { + method = "OPTIONS", + headers = { + ["Host"] = "cors10.com", + ["Origin"] = "http://bad-guys.com" + } + }) + + assert.res_status(204, res) + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + + -- Tricky illegitimate origins + res = assert(proxy_client:send { + method = "OPTIONS", + headers = { + ["Host"] = "cors10.com", + ["Origin"] = "http://my-site.com.bad-guys.com" + } + }) + + assert.res_status(204, res) + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + end) end) describe("HTTP method: others", function() @@ -322,11 +384,11 @@ for _, strategy in helpers.each_strategy() do method = "GET", headers = { ["Host"] = "cors6.com", - ["Origin"] = "http://www.example.com" + ["Origin"] = "example.com" } }) assert.res_status(200, res) - assert.equal("http://www.example.com", res.headers["Access-Control-Allow-Origin"]) + assert.equal("example.com", res.headers["Access-Control-Allow-Origin"]) assert.equal("Origin", res.headers["Vary"]) local domains = { @@ -352,6 +414,91 @@ for _, strategy in helpers.each_strategy() do end end) + it("does not automatically parse the host", function() + local res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors6.com", + ["Origin"] = "http://example.com" + } + }) + assert.res_status(200, res) + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + + -- With a different transport too + local res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors6.com", + ["Origin"] = "https://example.com" + } + }) + assert.res_status(200, res) + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + end) + + it("validates scheme and port", function() + local res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors11.com", + ["Origin"] = "http://my-site.com" + } + }) + assert.res_status(200, res) + assert.equals("http://my-site.com", res.headers["Access-Control-Allow-Origin"]) + + local res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors11.com", + ["Origin"] = "http://my-site.com:80" + } + }) + assert.res_status(200, res) + assert.equals("http://my-site.com:80", res.headers["Access-Control-Allow-Origin"]) + + local res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors11.com", + ["Origin"] = "http://my-site.com:8000" + } + }) + assert.res_status(200, res) + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + + res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors11.com", + ["Origin"] = "https://my-site.com" + } + }) + assert.res_status(200, res) + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + + local res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors11.com", + ["Origin"] = "https://my-other-site.com:9000" + } + }) + assert.res_status(200, res) + assert.equals("https://my-other-site.com:9000", res.headers["Access-Control-Allow-Origin"]) + + local res = assert(proxy_client:send { + method = "GET", + headers = { + ["Host"] = "cors11.com", + ["Origin"] = "https://my-other-site.com:9001" + } + }) + assert.res_status(200, res) + assert.is_nil(res.headers["Access-Control-Allow-Origin"]) + end) + it("does not sets CORS orgin if origin host is not in origin_domains list", function() local res = assert(proxy_client:send { method = "GET",