Skip to content

Commit

Permalink
feat(proxy) implement wildcard subdomains for public_dns
Browse files Browse the repository at this point in the history
An API's `public_dns` can now contain one '*' at the start or at the end
of the string, and on a dot border (same rules as
http://nginx.org/en/docs/http/server_names.html#wildcard_names).

If so, it will get special treatment during the lookup and the resolver
will try to match any given `Host` header against all wildcard
public_dns.

We keep things efficient by guaranteeing O(1) lookup for non wildcard
public_dns. Only wildcard public_dns and paths will be O(n).

Implements #297.
  • Loading branch information
thibaultcha committed Jul 5, 2015
1 parent 5bfa7ca commit ef34bc1
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 71 deletions.
22 changes: 20 additions & 2 deletions kong/dao/schemas/apis.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,30 @@ local function check_public_dns_and_path(value, api_t)
return false, "At least a 'public_dns' or a 'path' must be specified"
end

return true
-- Validate wildcard public_dns
if public_dns then
local _, count = public_dns:gsub("%*", "")
if count > 1 then
return false, "Only one wildcard is allowed: "..public_dns
elseif count > 0 then
local pos = public_dns:find("%*")
local valid
if pos == 1 then
valid = public_dns:match("^%*%.") ~= nil
elseif pos == string.len(public_dns) then
valid = public_dns:match(".%.%*$") ~= nil
end

if not valid then
return false, "Invalid wildcard placement: "..public_dns
end
end
end
end

local function check_path(path, api_t)
local valid, err = check_public_dns_and_path(path, api_t)
if not valid then
if valid == false then
return false, err
end

Expand Down
5 changes: 1 addition & 4 deletions kong/dao/schemas_validation.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
local utils = require "kong.tools.utils"
local stringy = require "stringy"
local DaoError = require "kong.dao.error"
local constants = require "kong.constants"
local error_types = constants.DATABASE_ERROR_TYPES

local POSSIBLE_TYPES = {
id = true,
Expand Down Expand Up @@ -172,7 +169,7 @@ function _M.validate_entity(t, schema, options)
-- [FUNC] Check field against a custom function
-- only if there is no error on that field already.
local ok, err, new_fields = v.func(t[column], t, column)
if not ok and err then
if ok == false and err then
errors = utils.add_error(errors, column, err)
elseif new_fields then
for k, v in pairs(new_fields) do
Expand Down
86 changes: 61 additions & 25 deletions kong/resolver/access.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,49 @@ local responses = require "kong.tools.responses"

local _M = {}

-- Take a public_dns and make it a pattern for wildcard matching.
-- Only do so if the public_dns actually has a wildcard.
local function create_wildcard_pattern(public_dns)
if string.find(public_dns, "*", 1, true) then
local pattern = string.gsub(public_dns, "%.", "%%.")
pattern = string.gsub(pattern, "*", ".+")
pattern = string.format("^%s$", pattern)
return pattern
end
end

-- Load all APIs in memory.
-- Sort the data for faster lookup: dictionary per public_dns, host,
-- and an array of wildcard public_dns.
local function load_apis_in_memory()
local apis, err = dao.apis:find_all()
if err then
return nil, err
end

-- build dictionnaries of public_dns:api and path:apis for efficient O(1) lookup.
-- we only do O(n) lookup for wildcard public_dns that are in an array.
local dns_dic, dns_wildcard, path_dic = {}, {}, {}
for _, api in ipairs(apis) do
if api.public_dns then
local pattern = create_wildcard_pattern(api.public_dns)
if pattern then
-- If the public_dns is a wildcard, we have a pattern and we can
-- store it in an array for later lookup.
table.insert(dns_wildcard, {pattern = pattern, api = api})
else
-- Keep non-wildcard public_dns in a dictionary for faster lookup.
dns_dic[api.public_dns] = api
end
end
if api.path then
path_dic[api.path] = api
end
end

return {by_dns = dns_dic, wildcard_dns = dns_wildcard, by_path = path_dic}
end

local function get_backend_url(api)
local result = api.target_url

Expand Down Expand Up @@ -37,7 +80,8 @@ end
-- matching the API's `public_dns`, either from the `request_uri` matching the API's `path`.
--
-- To perform this, we need to query _ALL_ APIs in memory. It is the only way to compare the `request_uri`
-- as a regex to the values set in DB. We keep APIs in the database cache for a longer time than usual.
-- as a regex to the values set in DB, as well as matching wildcard dns.
-- We keep APIs in the database cache for a longer time than usual.
-- @see https://github.com/Mashape/kong/issues/15 for an improvement on this.
--
-- @param `request_uri` The URI for this request.
Expand All @@ -49,31 +93,14 @@ end
local function find_api(request_uri)
local retrieved_api

-- retrieve all APIs
local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", function()
local apis, err = dao.apis:find_all()
if err then
return nil, err
end

-- build dictionnaries of public_dns:api and path:apis for efficient lookup.
local dns_dic, path_dic = {}, {}
for _, api in ipairs(apis) do
if api.public_dns then
dns_dic[api.public_dns] = api
end
if api.path then
path_dic[api.path] = api
end
end
return {dns = dns_dic, path = path_dic}
end, 60) -- 60 seconds cache
-- Retrieve all APIs
local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", load_apis_in_memory, 60) -- 60 seconds cache, longer than usual

if err then
return err
end

-- find by Host header
-- Find by Host header
local all_hosts = {}
for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do
local hosts = ngx.req.get_headers()[header_name]
Expand All @@ -85,9 +112,18 @@ local function find_api(request_uri)
for _, host in ipairs(hosts) do
host = unpack(stringy.split(host, ":"))
table.insert(all_hosts, host)
if apis_dics.dns[host] then
retrieved_api = apis_dics.dns[host]
break
if apis_dics.by_dns[host] then
retrieved_api = apis_dics.by_dns[host]
--break
else
-- If the API was not found in the dictionary, maybe it is a wildcard public_dns.
-- In that case, we need to loop over all of them.
for _, wildcard_dns in ipairs(apis_dics.wildcard_dns) do
if string.match(host, wildcard_dns.pattern) then
retrieved_api = wildcard_dns.api
break
end
end
end
end
end
Expand All @@ -99,7 +135,7 @@ local function find_api(request_uri)
end

-- Otherwise, we look for it by path. We have to loop over all APIs and compare the requested URI.
for path, api in pairs(apis_dics.path) do
for path, api in pairs(apis_dics.by_path) do
local m, err = ngx.re.match(request_uri, "^"..path)
if err then
ngx.log(ngx.ERR, "[resolver] error matching requested path: "..err)
Expand Down
87 changes: 50 additions & 37 deletions spec/integration/proxy/resolver_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ describe("Resolver", function()
spec_helper.prepare_db()
spec_helper.insert_fixtures {
api = {
{ name = "tests host resolver 1", public_dns = "mockbin.com", target_url = "http://mockbin.com" },
{ name = "tests host resolver 2", public_dns = "mockbin-auth.com", target_url = "http://mockbin.com" },
{ name = "tests path resolver", target_url = "http://mockbin.com", path = "/status/" },
{ name = "tests stripped path resolver", target_url = "http://mockbin.com", path = "/mockbin/", strip_path = true },
{ name = "tests deep path resolver", target_url = "http://mockbin.com", path = "/deep/path/", strip_path = true }
{name = "tests host resolver 1", public_dns = "mockbin.com", target_url = "http://mockbin.com"},
{name = "tests host resolver 2", public_dns = "mockbin-auth.com", target_url = "http://mockbin.com"},
{name = "tests path resolver", target_url = "http://mockbin.com", path = "/status/"},
{name = "tests stripped path resolver", target_url = "http://mockbin.com", path = "/mockbin/", strip_path = true},
{name = "tests deep path resolver", target_url = "http://mockbin.com", path = "/deep/path/", strip_path = true},
{name = "tests wildcard subdomain", target_url = "http://mockbin.com/status/200", public_dns = "*.wildcard.com"},
{name = "tests wildcard subdomain 2", target_url = "http://mockbin.com/status/201", public_dns = "wildcard.*"}
},
plugin_configuration = {
{ name = "keyauth", value = {key_names = {"apikey"} }, __api = 2 }
Expand All @@ -47,8 +49,8 @@ describe("Resolver", function()

it("should return Not Found when the API is not in Kong", function()
local response, status = http_client.get(spec_helper.STUB_GET_URL, nil, { host = "foo.com" })
assert.are.equal(404, status)
assert.are.equal('{"public_dns":["foo.com"],"message":"API not found with these values","path":"\\/request"}\n', response)
assert.equal(404, status)
assert.equal('{"public_dns":["foo.com"],"message":"API not found with these values","path":"\\/request"}\n', response)
end)

end)
Expand All @@ -57,10 +59,10 @@ describe("Resolver", function()

it("should work when calling SSL port", function()
local response, status = http_client.get(STUB_GET_SSL_URL, nil, { host = "mockbin.com" })
assert.are.equal(200, status)
assert.equal(200, status)
assert.truthy(response)
local parsed_response = cjson.decode(response)
assert.are.same("GET", parsed_response.method)
assert.same("GET", parsed_response.method)
end)

it("should work when manually triggering the handshake on default route", function()
Expand All @@ -86,13 +88,13 @@ describe("Resolver", function()

local cert = parse_cert(conn:getpeercertificate())

assert.are.same(6, utils.table_size(cert))
assert.are.same("Kong", cert.organizationName)
assert.are.same("IT", cert.organizationalUnitName)
assert.are.same("US", cert.countryName)
assert.are.same("California", cert.stateOrProvinceName)
assert.are.same("San Francisco", cert.localityName)
assert.are.same("localhost", cert.commonName)
assert.same(6, utils.table_size(cert))
assert.same("Kong", cert.organizationName)
assert.same("IT", cert.organizationalUnitName)
assert.same("US", cert.countryName)
assert.same("California", cert.stateOrProvinceName)
assert.same("San Francisco", cert.localityName)
assert.same("localhost", cert.commonName)

conn:close()
end)
Expand All @@ -103,70 +105,81 @@ describe("Resolver", function()
describe("By Host", function()

it("should proxy when the API is in Kong", function()
local _, status = http_client.get(STUB_GET_URL, nil, { host = "mockbin.com"})
assert.are.equal(200, status)
local _, status = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com"})
assert.equal(200, status)
end)

it("should proxy when the Host header is not trimmed", function()
local _, status = http_client.get(STUB_GET_URL, nil, { host = " mockbin.com "})
assert.are.equal(200, status)
local _, status = http_client.get(STUB_GET_URL, nil, {host = " mockbin.com "})
assert.equal(200, status)
end)

it("should proxy when the request has no Host header but the X-Host-Override header", function()
local _, status = http_client.get(STUB_GET_URL, nil, { ["X-Host-Override"] = "mockbin.com"})
assert.are.equal(200, status)
local _, status = http_client.get(STUB_GET_URL, nil, {["X-Host-Override"] = "mockbin.com"})
assert.equal(200, status)
end)

it("should proxy when the Host header contains a port", function()
local _, status = http_client.get(STUB_GET_URL, nil, { host = "mockbin.com:80"})
assert.are.equal(200, status)
local _, status = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com:80"})
assert.equal(200, status)
end)

describe("with wildcard subdomain", function()

it("should proxy when the public_dns is a wildcard subdomain", function()
local _, status = http_client.get(STUB_GET_URL, nil, {host = "subdomain.wildcard.com"})
assert.equal(200, status)

_, status = http_client.get(STUB_GET_URL, nil, {host = "wildcard.org"})
assert.equal(201, status)
end)

end)
end)

describe("By Path", function()

it("should proxy when no Host is present but the request_uri matches the API's path", function()
local _, status = http_client.get(spec_helper.PROXY_URL.."/status/200")
assert.are.equal(200, status)
assert.equal(200, status)

local _, status = http_client.get(spec_helper.PROXY_URL.."/status/301")
assert.are.equal(301, status)
assert.equal(301, status)
end)

it("should not proxy when the path does not match the start of the request_uri", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/somepath/status/200")
local body = cjson.decode(response)
assert.are.equal("API not found with these values", body.message)
assert.are.equal("/somepath/status/200", body.path)
assert.are.equal(404, status)
assert.equal("API not found with these values", body.message)
assert.equal("/somepath/status/200", body.path)
assert.equal(404, status)
end)

it("should proxy and strip the path if `strip_path` is true", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request")
assert.are.equal(200, status)
assert.equal(200, status)
local body = cjson.decode(response)
assert.are.equal("http://mockbin.com/request", body.url)
assert.equal("http://mockbin.com/request", body.url)
end)

it("should proxy when the path has a deep level", function()
local _, status = http_client.get(spec_helper.PROXY_URL.."/deep/path/status/200")
assert.are.equal(200, status)
assert.equal(200, status)
end)

end)

it("should return the correct Server and Via headers when the request was proxied", function()
local _, status, headers = http_client.get(STUB_GET_URL, nil, { host = "mockbin.com"})
assert.are.equal(200, status)
assert.are.equal("cloudflare-nginx", headers.server)
assert.are.equal(constants.NAME.."/"..constants.VERSION, headers.via)
assert.equal(200, status)
assert.equal("cloudflare-nginx", headers.server)
assert.equal(constants.NAME.."/"..constants.VERSION, headers.via)
end)

it("should return the correct Server and no Via header when the request was NOT proxied", function()
local _, status, headers = http_client.get(STUB_GET_URL, nil, { host = "mockbin-auth.com"})
assert.are.equal(401, status)
assert.are.equal(constants.NAME.."/"..constants.VERSION, headers.server)
assert.equal(401, status)
assert.equal(constants.NAME.."/"..constants.VERSION, headers.server)
assert.falsy(headers.via)
end)

Expand Down
2 changes: 1 addition & 1 deletion spec/unit/dao/cassandra/base_dao_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ describe("Cassandra", function()
assert.are.same("consumer_id "..plugin_t.consumer_id.." does not exist", err.message.consumer_id)
end)

it("should do insert checks for entities with `on_insert`", function()
it("should do insert checks for entities with `self_check`", function()
local api, err = dao_factory.apis:insert(faker:fake_entity("api"))
assert.falsy(err)
assert.truthy(api.id)
Expand Down
Loading

0 comments on commit ef34bc1

Please sign in to comment.