diff --git a/kong/plugins/cors/handler.lua b/kong/plugins/cors/handler.lua index f4b74a2a736..04c32864e16 100644 --- a/kong/plugins/cors/handler.lua +++ b/kong/plugins/cors/handler.lua @@ -1,8 +1,9 @@ local BasePlugin = require "kong.plugins.base_plugin" local responses = require "kong.tools.responses" local lrucache = require "resty.lrucache" +local url = require "socket.url" + -local url = require "socket.url" local req_get_method = ngx.req.get_method local re_find = ngx.re.find local concat = table.concat @@ -17,28 +18,33 @@ CorsHandler.PRIORITY = 2000 CorsHandler.VERSION = "0.1.0" --- per-worker cache of parsed origins -local CACHE_SIZE = 10 ^ 4 -local parsed_domains +-- per-plugin cache of normalized origins for runtime comparison +local mt_cache = { __mode = "k" } +local config_cache = setmetatable({}, mt_cache) -local function parse_origin_domain(domain) - local parsed_obj = url.parse(domain) - if parsed_obj and parsed_obj.host then - local port = parsed_obj.port - if not port and parsed_obj.scheme then - if parsed_obj.scheme == "http" then - port = 80 - elseif parsed_obj.scheme == "https" then - port = 443 - end - end - return (parsed_obj.scheme and parsed_obj.scheme .. "://" or "") .. - parsed_obj.host .. - (port and ":" .. port or "") - else +-- per-worker cache of parsed requests origins with 1000 slots +local normalized_req_domains = lrucache.new(10e3) + + +local function normalize_origin(domain) + local parsed_obj = assert(url.parse(domain)) + if not parsed_obj.host then return domain end + + local port = parsed_obj.port + if not port and parsed_obj.scheme then + if parsed_obj.scheme == "http" then + port = 80 + + elseif parsed_obj.scheme == "https" then + port = 443 + end + end + + return (parsed_obj.scheme and parsed_obj.scheme .. "://" or "") .. + parsed_obj.host .. (port and ":" .. port or "") end @@ -76,21 +82,26 @@ local function configure_origin(ngx, conf) local req_origin = ngx.var.http_origin if req_origin then + local normalized_domains = config_cache[conf] + if not normalized_domains then + normalized_domains = {} - local parsed_req_origin = parse_origin_domain(req_origin) - - for _, domain in ipairs(conf.origins) do - if not parsed_domains then - parsed_domains = lrucache.new(CACHE_SIZE) + for _, domain in ipairs(conf.origins) do + table.insert(normalized_domains, normalize_origin(domain)) end - local parsed_domain = parsed_domains:get(domain) - if not parsed_domain then - parsed_domain = parse_origin_domain(domain) - parsed_domains:set(domain, parsed_domain) - end + config_cache[conf] = normalized_domains + end + + local normalized_req_origin = normalized_req_domains:get(req_origin) + if not normalized_req_origin then + normalized_req_origin = normalize_origin(req_origin) + normalized_req_domains:set(req_origin, normalized_req_origin) + end - local from, _, err = re_find(parsed_req_origin, "^" .. parsed_domain .. "$", "jo") + for _, normalized_domain in ipairs(normalized_domains) do + local from, _, err = re_find(normalized_req_origin, + normalized_domain .. "$", "ajo") if err then ngx.log(ngx.ERR, "[cors] could not search for domain: ", err) end