diff --git a/.luacheckrc b/.luacheckrc index ad56977b05c..b5652c1fca3 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -20,5 +20,5 @@ files["kong/vendor/resty_http.lua"] = { } files["spec/"] = { - globals = {"describe", "it", "before_each", "setup", "after_each", "teardown", "stub", "mock", "spy", "finally"} + globals = {"describe", "it", "before_each", "setup", "after_each", "teardown", "stub", "mock", "spy", "finally", "pending"} } diff --git a/kong/api/crud_helpers.lua b/kong/api/crud_helpers.lua index a00904de19d..60c859432cd 100644 --- a/kong/api/crud_helpers.lua +++ b/kong/api/crud_helpers.lua @@ -11,8 +11,6 @@ function _M.find_api_by_name_or_id(self, dao_factory, helpers) } self.params.name_or_id = nil - -- TODO: make the base_dao more flexible so we can query find_one with key/values - -- https://github.com/Mashape/kong/issues/103 local data, err = dao_factory.apis:find_by_keys(fetch_keys) if err then return helpers.yield_error(err) @@ -74,13 +72,17 @@ function _M.paginated_set(self, dao_collection) end function _M.put(params, dao_collection) - local new_entity, err - if params.id then - new_entity, err = dao_collection:update(params) - if not err and new_entity then + local res, new_entity, err + + res, err = dao_collection:find_by_primary_key(params) + if err then + return app_helpers.yield_error(err) + end + + if res then + new_entity, err = dao_collection:update(params, true) + if not err then return responses.send_HTTP_OK(new_entity) - elseif not new_entity then - return responses.send_HTTP_NOT_FOUND() end else new_entity, err = dao_collection:insert(params) @@ -104,17 +106,21 @@ function _M.post(params, dao_collection, success) end end -function _M.patch(params, dao_collection) - local new_entity, err = dao_collection:update(params) +function _M.patch(new_entity, old_entity, dao_collection) + for k, v in pairs(new_entity) do + old_entity[k] = v + end + + local updated_entity, err = dao_collection:update(old_entity) if err then return 
app_helpers.yield_error(err) else - return responses.send_HTTP_OK(new_entity) + return responses.send_HTTP_OK(updated_entity) end end -function _M.delete(entity_id, dao_collection) - local ok, err = dao_collection:delete(entity_id) +function _M.delete(where_t, dao_collection) + local ok, err = dao_collection:delete(where_t) if not ok then if err then return app_helpers.yield_error(err) diff --git a/kong/api/routes/apis.lua b/kong/api/routes/apis.lua index af93f0b7e29..d3a93a6eefb 100644 --- a/kong/api/routes/apis.lua +++ b/kong/api/routes/apis.lua @@ -24,12 +24,11 @@ return { end, PATCH = function(self, dao_factory) - self.params.id = self.api.id - crud.patch(self.params, dao_factory.apis) + crud.patch(self.params, self.api, dao_factory.apis) end, DELETE = function(self, dao_factory) - crud.delete(self.api.id, dao_factory.apis) + crud.delete(self.api, dao_factory.apis) end }, @@ -79,12 +78,11 @@ return { end, PATCH = function(self, dao_factory, helpers) - self.params.id = self.plugin.id - crud.patch(self.params, dao_factory.plugins_configurations) + crud.patch(self.params, self.plugin, dao_factory.plugins_configurations) end, DELETE = function(self, dao_factory) - crud.delete(self.plugin.id, dao_factory.plugins_configurations) + crud.delete(self.plugin, dao_factory.plugins_configurations) end } } diff --git a/kong/api/routes/consumers.lua b/kong/api/routes/consumers.lua index 610e0ffd4d3..784fabbd350 100644 --- a/kong/api/routes/consumers.lua +++ b/kong/api/routes/consumers.lua @@ -25,12 +25,11 @@ return { end, PATCH = function(self, dao_factory, helpers) - self.params.id = self.consumer.id - crud.patch(self.params, dao_factory.consumers) + crud.patch(self.params, self.consumer, dao_factory.consumers) end, DELETE = function(self, dao_factory, helpers) - crud.delete(self.consumer.id, dao_factory.consumers) + crud.delete(self.consumer, dao_factory.consumers) end } } diff --git a/kong/api/routes/plugins_configurations.lua b/kong/api/routes/plugins_configurations.lua 
index 8401d6105f2..65f57964f45 100644 --- a/kong/api/routes/plugins_configurations.lua +++ b/kong/api/routes/plugins_configurations.lua @@ -25,7 +25,7 @@ return { ["/plugins_configurations/:id"] = { before = function(self, dao_factory, helpers) local err - self.plugin_conf, err = dao_factory.plugins_configurations:find_one(self.params.id) + self.plugin_conf, err = dao_factory.plugins_configurations:find_by_primary_key({ id = self.params.id }) if err then return helpers.yield_error(err) elseif not self.plugin_conf then @@ -38,12 +38,11 @@ return { end, PATCH = function(self, dao_factory) - self.params.id = self.plugin_conf.id - crud.patch(self.params, dao_factory.plugins_configurations) + crud.patch(self.params, self.plugin_conf, dao_factory.plugins_configurations) end, DELETE = function(self, dao_factory) - crud.delete(self.plugin_conf.id, dao_factory.plugins_configurations) + crud.delete(self.plugin_conf, dao_factory.plugins_configurations) end } } diff --git a/kong/constants.lua b/kong/constants.lua index edda67b766f..b5e359b08e5 100644 --- a/kong/constants.lua +++ b/kong/constants.lua @@ -23,10 +23,6 @@ return { UNIQUE = "unique", FOREIGN = "foreign" }, - DATABASE_TYPES = { - ID = "id", - TIMESTAMP = "timestamp" - }, -- Non standard headers, specific to Kong HEADERS = { HOST_OVERRIDE = "X-Host-Override", diff --git a/kong/dao/cassandra/apis.lua b/kong/dao/cassandra/apis.lua index 4848a35bcbf..e867907e9b6 100644 --- a/kong/dao/cassandra/apis.lua +++ b/kong/dao/cassandra/apis.lua @@ -1,55 +1,19 @@ local BaseDao = require "kong.dao.cassandra.base_dao" local apis_schema = require "kong.dao.schemas.apis" +local query_builder = require "kong.dao.cassandra.query_builder" local Apis = BaseDao:extend() function Apis:new(properties) - self._entity = "API" + self._table = "apis" self._schema = apis_schema - self._queries = { - insert = { - args_keys = { "id", "name", "public_dns", "path", "strip_path", "target_url", "created_at" }, - query = [[ INSERT INTO apis(id, name, 
public_dns, path, strip_path, target_url, created_at) - VALUES(?, ?, ?, ?, ?, ?, ?); ]] - }, - update = { - args_keys = { "name", "public_dns", "path", "strip_path", "target_url", "id" }, - query = [[ UPDATE apis SET name = ?, public_dns = ?, path = ?, strip_path = ?, target_url = ? WHERE id = ?; ]] - }, - select = { - query = [[ SELECT * FROM apis %s; ]] - }, - select_one = { - args_keys = { "id" }, - query = [[ SELECT * FROM apis WHERE id = ?; ]] - }, - delete = { - args_keys = { "id" }, - query = [[ DELETE FROM apis WHERE id = ?; ]] - }, - __unique = { - name = { - args_keys = { "name" }, - query = [[ SELECT id FROM apis WHERE name = ?; ]] - }, - path = { - args_keys = { "path" }, - query = [[ SELECT id FROM apis WHERE path = ?; ]] - }, - public_dns = { - args_keys = { "public_dns" }, - query = [[ SELECT id FROM apis WHERE public_dns = ?; ]] - } - }, - drop = "TRUNCATE apis;" - } - Apis.super.new(self, properties) end function Apis:find_all() local apis = {} - for _, rows, page, err in Apis.super._execute_kong_query(self, self._queries.select.query, nil, {auto_paging=true}) do + local select_q = query_builder.select(self._table) + for _, rows, page, err in Apis.super.execute(self, select_q, nil, nil, {auto_paging=true}) do if err then return nil, err end @@ -63,28 +27,23 @@ function Apis:find_all() end -- @override -function Apis:delete(api_id) - local ok, err = Apis.super.delete(self, api_id) +function Apis:delete(where_t) + local ok, err = Apis.super.delete(self, where_t) if not ok then return false, err end -- delete all related plugins configurations local plugins_dao = self._factory.plugins_configurations - local query, args_keys, errors = plugins_dao:_build_where_query(plugins_dao._queries.select.query, { - api_id = api_id - }) - if errors then - return nil, errors - end + local select_q, columns = query_builder.select(plugins_dao._table, {api_id = where_t.id}, plugins_dao._column_family_details) - for _, rows, page, err in 
plugins_dao:_execute_kong_query({query=query, args_keys=args_keys}, {api_id=api_id}, {auto_paging=true}) do + for _, rows, page, err in plugins_dao:execute(select_q, columns, {api_id = where_t.id}, {auto_paging = true}) do if err then return nil, err end for _, row in ipairs(rows) do - local ok_del_plugin, err = plugins_dao:delete(row.id) + local ok_del_plugin, err = plugins_dao:delete({id = row.id}) if not ok_del_plugin then return nil, err end diff --git a/kong/dao/cassandra/base_dao.lua b/kong/dao/cassandra/base_dao.lua index f1a90bb0434..84188ed140e 100644 --- a/kong/dao/cassandra/base_dao.lua +++ b/kong/dao/cassandra/base_dao.lua @@ -1,12 +1,8 @@ -- Kong's Cassandra base DAO entity. Provides basic functionnalities on top of -- lua-resty-cassandra (https://github.com/jbochi/lua-resty-cassandra) --- --- Entities (APIs, Consumers) having a schema and defined kong_queries can extend --- this object to benefit from methods such as `insert`, `update`, schema validations --- (including UNIQUE and FOREIGN check), marshalling of some properties, etc... 
-local validations = require("kong.dao.schemas_validation") -local validate = validations.validate +local query_builder = require "kong.dao.cassandra.query_builder" +local validations = require "kong.dao.schemas_validation" local constants = require "kong.constants" local cassandra = require "cassandra" local timestamp = require "kong.tools.timestamp" @@ -25,6 +21,23 @@ local BaseDao = Object:extend() uuid.seed() function BaseDao:new(properties) + if self._schema then + self._primary_key = self._schema.primary_key + self._clustering_key = self._schema.clustering_key + local indexes = {} + for field_k, field_v in pairs(self._schema.fields) do + if field_v.queryable then + indexes[field_k] = true + end + end + + self._column_family_details = { + primary_key = self._primary_key, + clustering_key = self._clustering_key, + indexes = indexes + } + end + self._properties = properties self._statements_cache = {} end @@ -41,104 +54,8 @@ function BaseDao:_unmarshall(t) return t end --- Run a statement checking if a row exists (false if it does). --- @param `kong_query` kong_query to execute --- @param `t` args to bind to the statement --- @param `is_updating` If true, will ignore UNIQUE if same entity --- @return `unique` true if doesn't exist (UNIQUE), false otherwise --- @return `error` Error if any during execution -function BaseDao:_check_unique(kong_query, t, is_updating) - local results, err = self:_execute_kong_query(kong_query, t) - if err then - return false, "Error during UNIQUE check: "..err.message - elseif results and #results > 0 then - if not is_updating then - return false - else - -- If we are updating, we ignore UNIQUE values if coming from the same entity - local unique = true - for k,v in ipairs(results) do - if v.id ~= t.id then - unique = false - break - end - end - - return unique - end - else - return true - end -end - --- Run a statement checking if a row exists (true if it does). 
--- @param `kong_query` kong_query to execute --- @param `t` args to bind to the statement --- @return `exists` true if the row exists (FOREIGN), false otherwise --- @return `error` Error if any during the query execution --- @return `results` Results of the statement if `exists` is true (useful for :update() -function BaseDao:_check_foreign(kong_query, t) - local results, err = self:_execute_kong_query(kong_query, t) - if err then - return false, err - elseif not results or #results == 0 then - return false - else - return true, nil, results - end -end - --- Run the FOREIGN exists check on all statements in __foreign. --- @param `t` args to bind to the __foreign statements --- @return `exists` if all results EXIST, false otherwise --- @return `error` Error if any during the query execution --- @return `errors` A table with the list of not existing foreign entities -function BaseDao:_check_all_foreign(t) - if not self._queries.__foreign then return true end - - local errors - for k, kong_query in pairs(self._queries.__foreign) do - if t[k] and t[k] ~= constants.DATABASE_NULL_ID then - local exists, err = self:_check_foreign(kong_query, t) - if err then - return false, err - elseif not exists then - errors = utils.add_error(errors, k, k.." "..t[k].." does not exist") - end - end - end - - return errors == nil, nil, errors -end - --- Run the UNIQUE on all statements in __unique. 
--- @param `t` args to bind to the __unique statements --- @param `is_updating` If true, will ignore UNIQUE if same entity --- @return `unique` true if all results are UNIQUE, false otherwise --- @return `error` Error if any during the query execution --- @return `errors` A table with the list of already existing entities -function BaseDao:_check_all_unique(t, is_updating) - if not self._queries.__unique then return true end - - local errors - for k, statement in pairs(self._queries.__unique) do - if t[k] or k == "self" then - local unique, err = self:_check_unique(statement, t, is_updating) - if err then - return false, err - elseif not unique and k == "self" then - return false, nil, self._entity.." already exists" - elseif not unique then - errors = utils.add_error(errors, k, k.." already exists with value '"..t[k].."'") - end - end - end - - return errors == nil, nil, errors -end - --- Open a Cassandra session on the configured keyspace. --- @param `keyspace` (Optional) Override the keyspace for this session if specified. +-- Open a session on the configured keyspace. +-- @param `keyspace` (Optional) Override the keyspace for this session if specified. -- @return `session` Opened session -- @return `error` Error if any function BaseDao:_open_session(keyspace) @@ -168,7 +85,8 @@ function BaseDao:_open_session(keyspace) return session end --- Close the given opened session. Will try to put the session in the socket pool if supported. +-- Close the given opened session. +-- Will try to put the session in the socket pool if supported. -- @param `session` Cassandra session to close -- @return `error` Error if any function BaseDao:_close_session(session) @@ -186,8 +104,8 @@ end -- Build the array of arguments to pass to lua-resty-cassandra :execute method. 
-- Note: -- Since this method only accepts an ordered list, we build this list from --- the `args_keys` property of all prepared statement, taking into account special --- cassandra values (uuid, timestamps, NULL) +-- the entity `t` and an (ordered) array of parameters for a query, taking +-- into account special cassandra values (uuid, timestamps, NULL). -- @param `schema` A schema with type properties to encode specific values -- @param `t` Values to bind to a statement -- @param `parameters` An ordered list of parameters @@ -196,78 +114,47 @@ end local function encode_cassandra_args(schema, t, args_keys) local args_to_bind = {} local errors + for _, column in ipairs(args_keys) do - local schema_field = schema[column] - local value = t[column] + local schema_field = schema.fields[column] + local arg = t[column] - if schema_field.type == constants.DATABASE_TYPES.ID and value then - if validations.is_valid_uuid(value) then - value = cassandra.uuid(value) + if schema_field.type == "id" and arg then + if validations.is_valid_uuid(arg) then + arg = cassandra.uuid(arg) else - errors = utils.add_error(errors, column, value.." is an invalid uuid") + errors = utils.add_error(errors, column, arg.." 
is an invalid uuid") end - elseif schema_field.type == constants.DATABASE_TYPES.TIMESTAMP and value then - value = cassandra.timestamp(value) - elseif value == nil then - value = cassandra.null + elseif schema_field.type == "timestamp" and arg then + arg = cassandra.timestamp(arg) + elseif arg == nil then + arg = cassandra.null end - table.insert(args_to_bind, value) + table.insert(args_to_bind, arg) end return args_to_bind, errors end -function BaseDao:_build_where_query(query, t) - local args_keys = {} - local where_str = "" - local errors - - -- if t is an args_keys, compute a WHERE statement - if t and utils.table_size(t) > 0 then - local where = {} - for k, v in pairs(t) do - if self._schema[k] and self._schema[k].queryable or k == "id" then - table.insert(where, string.format("%s = ?", k)) - table.insert(args_keys, k) - else - errors = utils.add_error(errors, k, k.." is not queryable.") - end - end - - if errors then - return nil, nil, DaoError(errors, error_types.SCHEMA) - end - - where_str = "WHERE "..table.concat(where, " AND ").." ALLOW FILTERING" - end - - return string.format(query, where_str), args_keys -end - -- Get a statement from the cache or prepare it (and thus insert it in the cache). -- The cache key will be the plain string query representation. --- @param `kong_query` A kong query from the _queries property. 
+-- @param `query` The query to prepare -- @return `statement` The prepared cassandra statement -- @return `cache_key` The cache key used to store it into the cache -- @return `error` Error if any during the query preparation -function BaseDao:_get_or_prepare(kong_query) - local query - if type(kong_query) == "string" then - query = kong_query - elseif kong_query.query then - query = kong_query.query - else +function BaseDao:get_or_prepare_stmt(query) + if type(query) ~= "string" then -- Cannot be prepared (probably a BatchStatement) - return kong_query + return query end local statement, err -- Retrieve the prepared statement from cache or prepare and cache - if self._statements_cache[kong_query.query] then - statement = self._statements_cache[kong_query.query].statement + if self._statements_cache[query] then + statement = self._statements_cache[query] else - statement, err = self:prepare_kong_statement(kong_query) + statement, err = self:prepare_stmt(query) if err then return nil, query, err end @@ -333,77 +220,113 @@ function BaseDao:_execute(statement, args, options, keyspace) end end --- Execute a kong_query (_queries property of DAO entities). +-- Execute a query. -- Will prepare the query before execution and cache the prepared statement. 
-- Will create an arguments array for lua-resty-cassandra's :execute() --- @param `kong_query` The kong_query to execute +-- @param `query` The query to execute -- @param `args_to_bind` Key/value table of arguments to bind -- @param `options` Options to pass to lua-resty-cassandra :execute() -- @return :_execute() -function BaseDao:_execute_kong_query(operation, args_to_bind, options) +function BaseDao:execute(query, columns, args_to_bind, options) -- Prepare query and cache the prepared statement for later call - local statement, cache_key, err = self:_get_or_prepare(operation) + local statement, cache_key, err = self:get_or_prepare_stmt(query) if err then return nil, err end -- Build args array if operation has some local args - if operation.args_keys and args_to_bind then + if columns and args_to_bind then local errors - args, errors = encode_cassandra_args(self._schema, args_to_bind, operation.args_keys) + args, errors = encode_cassandra_args(self._schema, args_to_bind, columns) if errors then return nil, DaoError(errors, error_types.INVALID_TYPE) end end -- Execute statement - local results, err - results, err = self:_execute(statement, args, options) + local results, err = self:_execute(statement, args, options) if err and err.cassandra_err_code == cassandra_constants.error_codes.UNPREPARED then if ngx then ngx.log(ngx.NOTICE, "Cassandra did not recognize prepared statement \""..cache_key.."\". Re-preparing it and re-trying the query. (Error: "..err..")") end -- If the statement was declared unprepared, clear it from the cache, and try again. self._statements_cache[cache_key] = nil - return self:_execute_kong_query(operation, args_to_bind, options) + return self:execute(query, columns, args_to_bind, options) end return results, err end ----------------------- --- PUBLIC INTERFACE -- ----------------------- +-- Check all fields marked with a `unique` in the schema do not already exist. 
+function BaseDao:check_unique_fields(t, is_update) + local errors --- Prepare a statement used by kong and insert it into the statement cache. --- Note: --- Since lua-resty-cassandra doesn't support binding by name yet, we need --- to keep a record of properties to bind for each statement. Thus, a "kong query" --- is an object made of a prepared statement and an array of columns to bind. --- See :_execute_kong_query() for the usage of this args_keys array doing the binding. --- @param `kong_query` The kong_query to prepare and insert into the cache. --- @return `statement` The prepared statement, ready to be used by lua-resty-cassandra. --- @return `error` Error if any during the preparation of the statement -function BaseDao:prepare_kong_statement(kong_query) - -- _queries can contain strings or tables with string + keys of arguments to bind - local query - if type(kong_query) == "string" then - query = kong_query - elseif kong_query.query then - query = kong_query.query + for k, field in pairs(self._schema.fields) do + if field.unique and t[k] ~= nil then + local res, err = self:find_by_keys {[k] = t[k]} + if err then + return false, nil, "Error during UNIQUE check: "..err.message + elseif res and #res > 0 then + local is_self = true + if is_update then + -- If update, check if the retrieved entity is not the entity itself + res = res[1] + for _, key in ipairs(self._primary_key) do + if t[key] ~= res[key] then + is_self = false + break + end + end + else + is_self = false + end + + if not is_self then + errors = utils.add_error(errors, k, k.." already exists with value '"..t[k].."'") + end + end + end + end + + return errors == nil, errors +end + +-- Check all fields marked as `foreign` in the schema exist on other column families. 
+function BaseDao:check_foreign_fields(t) + local errors, foreign_type, foreign_field, res, err + + for k, field in pairs(self._schema.fields) do + if field.foreign ~= nil and type(field.foreign) == "string" then + foreign_type, foreign_field = unpack(stringy.split(field.foreign, ":")) + if foreign_type and foreign_field and self._factory[foreign_type] and t[k] ~= nil and t[k] ~= constants.DATABASE_NULL_ID then + res, err = self._factory[foreign_type]:find_by_keys {[foreign_field] = t[k]} + if err then + return false, nil, "Error during FOREIGN check: "..err.message + elseif not res or #res == 0 then + errors = utils.add_error(errors, k, k.." "..t[k].." does not exist") + end + end + end end - -- handle SELECT queries with %s for dynamic select by keys - local query_to_prepare = string.format(query, "") - query_to_prepare = stringy.strip(query_to_prepare) + return errors == nil, errors +end + +-- Prepare a query and insert it into the statement cache. +-- @param `query` The query to prepare +-- @return `statement` The prepared statement, ready to be used by lua-resty-cassandra. +-- @return `error` Error if any during the preparation of the statement +function BaseDao:prepare_stmt(query) + assert(type(query) == "string", "Query to prepare must be a string") + query = stringy.strip(query) local session, err = self:_open_session() if err then return nil, err end - local prepared_stmt, prepare_err = session:prepare(query_to_prepare) + local prepared_stmt, prepare_err = session:prepare(query) local err = self:_close_session(session) if err then @@ -411,59 +334,60 @@ function BaseDao:prepare_kong_statement(kong_query) end if prepare_err then - return nil, DaoError("Failed to prepare statement: \""..query_to_prepare.."\". "..prepare_err, error_types.DATABASE) + return nil, DaoError("Failed to prepare statement: \""..query.."\". 
"..prepare_err, error_types.DATABASE) else -- cache key is the non-striped/non-formatted query from _queries - self._statements_cache[query] = { - query = query, - args_keys = kong_query.args_keys, - statement = prepared_stmt - } - + self._statements_cache[query] = prepared_stmt return prepared_stmt end end - --- Execute the INSERT kong_query of a DAO entity. --- Validates the entity's schema + UNIQUE values + FOREIGN KEYS. --- Generates id and created_at fields. +-- Insert a row in the DAO's table. +-- Perform schema validation, UNIQUE checks, FOREIGN checks. -- @param `t` A table representing the entity to insert -- @return `result` Inserted entity or nil -- @return `error` Error if any during the execution function BaseDao:insert(t) - local ok, err, errors - if not t then - return nil, DaoError("Cannot insert a nil element", error_types.SCHEMA) + assert(t ~= nil, "Cannot insert a nil element") + assert(type(t) == "table", "Entity to insert must be a table") + + local ok, db_err, errors + + -- Populate the entity with any default/overriden values and validate it + errors = validations.validate(t, self, { + dao_insert = function(field) + if field.type == "id" then + return uuid() + elseif field.type == "timestamp" then + return timestamp.get_utc() + end + end + }) + if errors then + return nil, errors end - -- Override created_at and id by default value - t.created_at = timestamp.get_utc() - t.id = uuid() - - -- Validate schema - ok, errors = validate(t, self._schema) + ok, errors = validations.on_insert(t, self._schema, self._factory) if not ok then - return nil, DaoError(errors, error_types.SCHEMA) + return nil, errors end - -- Check UNIQUE values - ok, err, errors = self:_check_all_unique(t) - if err then - return nil, DaoError(err, error_types.DATABASE) + ok, errors, db_err = self:check_unique_fields(t) + if db_err then + return nil, DaoError(db_err, error_types.DATABASE) elseif not ok then return nil, DaoError(errors, error_types.UNIQUE) end - -- Check 
foreign entities EXIST - ok, err, errors = self:_check_all_foreign(t) - if err then - return nil, DaoError(err, error_types.DATABASE) + ok, errors, db_err = self:check_foreign_fields(t) + if db_err then + return nil, DaoError(db_err, error_types.DATABASE) elseif not ok then return nil, DaoError(errors, error_types.FOREIGN) end - local _, stmt_err = self:_execute_kong_query(self._queries.insert, self:_marshall(t)) + local insert_q, columns = query_builder.insert(self._table, t) + local _, stmt_err = self:execute(insert_q, columns, self:_marshall(t)) if stmt_err then return nil, stmt_err else @@ -471,58 +395,78 @@ function BaseDao:insert(t) end end --- Execute the UPDATE kong_query of a DAO entity. --- Validate entity's schema + UNIQUE values + FOREIGN KEYS. +local function extract_primary_key(t, primary_key, clustering_key) + local t_no_primary_key = utils.deep_copy(t) + local t_primary_key = {} + for _, key in ipairs(primary_key) do + t_primary_key[key] = t[key] + t_no_primary_key[key] = nil + end + if clustering_key then + for _, key in ipairs(clustering_key) do + t_primary_key[key] = t[key] + t_no_primary_key[key] = nil + end + end + return t_primary_key, t_no_primary_key +end + +-- Update a row: find the row with the given PRIMARY KEY and update the other values +-- If `full`, sets to NULL values that are not included in the schema. +-- Performs schema validation, UNIQUE and FOREIGN checks. 
-- @param `t` A table representing the entity to insert +-- @param `full` If `true`, set to NULL any column not in the `t` parameter -- @return `result` Updated entity or nil -- @return `error` Error if any during the execution -function BaseDao:update(t) - local ok, err, errors - if not t then - return nil, DaoError("Cannot update a nil element", error_types.SCHEMA) - end +function BaseDao:update(t, full) + assert(t ~= nil, "Cannot update a nil element") + assert(type(t) == "table", "Entity to update must be a table") + + local ok, db_err, errors - -- Check if exists to prevent upsert and manually set UNSET values (pfffff...) - local results - ok, err, results = self:_check_foreign(self._queries.select_one, t) + -- Check if exists to prevent upsert + local res, err = self:find_by_primary_key(t) if err then - return nil, err - elseif not ok then - return nil - else - -- Set UNSET values to prevent cassandra from setting to NULL - -- @see Test case - -- @see https://issues.apache.org/jira/browse/DATABASE-7304 - for k, v in pairs(results[1]) do - if t[k] == nil then - t[k] = v - end - end + return false, err + elseif not res then + return false end -- Validate schema - ok, errors = validate(t, self._schema, true) - if not ok then - return nil, DaoError(errors, error_types.SCHEMA) + errors = validations.validate(t, self, {partial_update = not full, full_update = full}) + if errors then + return nil, errors end - -- Check UNIQUE with update - ok, err, errors = self:_check_all_unique(t, true) - if err then - return nil, DaoError(err, error_types.DATABASE) + ok, errors, db_err = self:check_unique_fields(t, true) + if db_err then + return nil, DaoError(db_err, error_types.DATABASE) elseif not ok then return nil, DaoError(errors, error_types.UNIQUE) end - -- Check FOREIGN entities - ok, err, errors = self:_check_all_foreign(t) - if err then - return nil, DaoError(err, error_types.DATABASE) + ok, errors, db_err = self:check_foreign_fields(t) + if db_err then + return nil, 
DaoError(db_err, error_types.DATABASE) elseif not ok then return nil, DaoError(errors, error_types.FOREIGN) end - local _, stmt_err = self:_execute_kong_query(self._queries.update, self:_marshall(t)) + -- Extract primary key from the entity + local t_primary_key, t_no_primary_key = extract_primary_key(t, self._primary_key, self._clustering_key) + + -- If full, add `null` values to the SET part of the query for nil columns + if full then + for k, v in pairs(self._schema.fields) do + if not t[k] and not v.immutable then + t_no_primary_key[k] = cassandra.null + end + end + end + + local update_q, columns = query_builder.update(self._table, t_no_primary_key, t_primary_key) + + local _, stmt_err = self:execute(update_q, columns, self:_marshall(t)) if stmt_err then return nil, stmt_err else @@ -530,11 +474,22 @@ function BaseDao:update(t) end end --- Execute the SELECT_ONE kong_query of a DAO entity. --- @param `id` uuid of the entity to select --- @return `result` The first row of the _execute_kong_query() return value -function BaseDao:find_one(id) - local data, err = self:_execute_kong_query(self._queries.select_one, { id = id }) +-- Retrieve a row at given PRIMARY KEY. +-- @param `where_t` A table containing the PRIMARY KEY (columns/values) of the row to retrieve. +-- @return `row` The first row of the result. 
+-- @return `error` +function BaseDao:find_by_primary_key(where_t) + assert(self._primary_key ~= nil and type(self._primary_key) == "table" , "Entity does not have a primary_key") + assert(where_t ~= nil and type(where_t) == "table", "where_t must be a table") + + local t_primary_key = extract_primary_key(where_t, self._primary_key) + + if next(t_primary_key) == nil then + return nil + end + + local select_q, where_columns = query_builder.select(self._table, t_primary_key, self._column_family_details, nil, true) + local data, err = self:execute(select_q, where_columns, t_primary_key) -- Return the 1st and only element of the result set if data and utils.table_size(data) > 0 then @@ -546,51 +501,57 @@ function BaseDao:find_one(id) return data, err end --- Execute the SELECT kong_query of a DAO entity with a special WHERE clause. --- @warning Generated statement will use `ALLOW FILTERING` in their queries. --- @param `t` (Optional) Keys by which to find an entity. +-- Retrieve a set of rows from the given columns/value table. +-- @param `where_t` (Optional) columns/values table by which to find an entity. -- @param `page_size` Size of the page to retrieve (number of rows). -- @param `paging_state` Start page from given offset. See lua-resty-cassandra's :execute() option. 
--- @return _execute_kong_query() -function BaseDao:find_by_keys(t, page_size, paging_state) - local select_where_query, args_keys, errors = self:_build_where_query(self._queries.select.query, t) - if errors then - return nil, errors - end - - return self:_execute_kong_query({ query = select_where_query, args_keys = args_keys }, t, { +-- @return `res` +-- @return `err` +-- @return `filtering` A boolean indicating if ALLOW FILTERING was needed by the query +function BaseDao:find_by_keys(where_t, page_size, paging_state) + local select_q, where_columns, filtering = query_builder.select(self._table, where_t, self._column_family_details) + local res, err = self:execute(select_q, where_columns, where_t, { page_size = page_size, paging_state = paging_state }) + + return res, err, filtering end --- Execute the SELECT kong_query of a DAO entity. --- @param `page_size` Size of the page to retrieve (number of rows). --- @param `paging_state` Start page from given offset. See lua-resty-cassandra's :execute() option. --- @return find_by_keys() +-- Retrieve a page of the table attached to the DAO. +-- @param `page_size` Size of the page to retrieve (number of rows). +-- @param `paging_state` Start page from given offset. See lua-resty-cassandra's :execute() option. +-- @return `find_by_keys()` function BaseDao:find(page_size, paging_state) return self:find_by_keys(nil, page_size, paging_state) end --- Execute the SELECT kong_query of a DAO entity. --- @param `id` uuid of the entity to delete +-- Delete the row at a given PRIMARY KEY. 
+-- @param `where_t` A table containing the PRIMARY KEY (columns/values) of the row to delete -- @return `success` True if deleted, false if otherwise or not found -- @return `error` Error if any during the query execution -function BaseDao:delete(id) - local exists, err = self:_check_foreign(self._queries.select_one, { id = id }) +function BaseDao:delete(where_t) + assert(self._primary_key ~= nil and type(self._primary_key) == "table" , "Entity does not have a primary_key") + assert(where_t ~= nil and type(where_t) == "table", "where_t must be a table") + + -- Test if exists first + local res, err = self:find_by_primary_key(where_t) if err then return false, err - elseif not exists then + elseif not res then return false end - return self:_execute_kong_query(self._queries.delete, { id = id }) + local t_primary_key = extract_primary_key(where_t, self._primary_key, self._clustering_key) + local delete_q, where_columns = query_builder.delete(self._table, t_primary_key) + return self:execute(delete_q, where_columns, where_t) end +-- Truncate the table of this DAO +-- @return `:execute()` function BaseDao:drop() - if self._queries.drop then - return self:_execute_kong_query(self._queries.drop) - end + local truncate_q = query_builder.truncate(self._table) + return self:execute(truncate_q) end return BaseDao diff --git a/kong/dao/cassandra/consumers.lua b/kong/dao/cassandra/consumers.lua index dcf189dbd45..e5488c9cf73 100644 --- a/kong/dao/cassandra/consumers.lua +++ b/kong/dao/cassandra/consumers.lua @@ -1,70 +1,34 @@ local BaseDao = require "kong.dao.cassandra.base_dao" +local query_builder = require "kong.dao.cassandra.query_builder" local consumers_schema = require "kong.dao.schemas.consumers" local Consumers = BaseDao:extend() function Consumers:new(properties) - self._entity = "Consumer" + self._table = "consumers" self._schema = consumers_schema - self._queries = { - insert = { - args_keys = { "id", "custom_id", "username", "created_at" }, - query = [[ INSERT 
INTO consumers(id, custom_id, username, created_at) VALUES(?, ?, ?, ?); ]] - }, - update = { - args_keys = { "custom_id", "username", "created_at", "id" }, - query = [[ UPDATE consumers SET custom_id = ?, username = ?, created_at = ? WHERE id = ?; ]] - }, - select = { - query = [[ SELECT * FROM consumers %s; ]] - }, - select_one = { - args_keys = { "id" }, - query = [[ SELECT * FROM consumers WHERE id = ?; ]] - }, - delete = { - args_keys = { "id" }, - query = [[ DELETE FROM consumers WHERE id = ?; ]] - }, - __unique = { - custom_id ={ - args_keys = { "custom_id" }, - query = [[ SELECT id FROM consumers WHERE custom_id = ?; ]] - }, - username ={ - args_keys = { "username" }, - query = [[ SELECT id FROM consumers WHERE username = ?; ]] - } - }, - drop = "TRUNCATE consumers;" - } Consumers.super.new(self, properties) end -- @override -function Consumers:delete(consumer_id) - local ok, err = Consumers.super.delete(self, consumer_id) +function Consumers:delete(where_t) + local ok, err = Consumers.super.delete(self, where_t) if not ok then return false, err end - -- delete all related plugins configurations local plugins_dao = self._factory.plugins_configurations - local query, args_keys, errors = plugins_dao:_build_where_query(plugins_dao._queries.select.query, { - consumer_id = consumer_id - }) - if errors then - return nil, errors - end + local select_q, columns = query_builder.select(plugins_dao._table, {consumer_id = where_t.id}, plugins_dao._column_family_details) - for _, rows, page, err in plugins_dao:_execute_kong_query({query=query, args_keys=args_keys}, {consumer_id=consumer_id}, {auto_paging=true}) do + -- delete all related plugins configurations + for _, rows, page, err in plugins_dao:execute(select_q, columns, {consumer_id = where_t.id}, {auto_paging = true}) do if err then return nil, err end for _, row in ipairs(rows) do - local ok_del_plugin, err = plugins_dao:delete(row.id) + local ok_del_plugin, err = plugins_dao:delete({id = row.id}) if not 
ok_del_plugin then return nil, err end diff --git a/kong/dao/cassandra/factory.lua b/kong/dao/cassandra/factory.lua index 7553aeea108..9a60c3aebe1 100644 --- a/kong/dao/cassandra/factory.lua +++ b/kong/dao/cassandra/factory.lua @@ -67,7 +67,7 @@ function CassandraFactory:drop() end end --- Prepare all statements of collections `._queries` property and put them +-- Prepare all statements of collections `queries` property and put them -- in a statements cache -- -- Note: @@ -78,24 +78,31 @@ end -- @return error if any function CassandraFactory:prepare() local function prepare_collection(collection, collection_queries) - if not collection_queries then collection_queries = collection._queries end + local err for stmt_name, collection_query in pairs(collection_queries) do - if type(collection_query) == "table" and collection_query.query == nil then - -- Nested queries, let's recurse to prepare them too - prepare_collection(collection, collection_query) - else - local _, err = collection:prepare_kong_statement(collection_query) - if err then - error(err) - end + err = select(2, collection:prepare_stmt(collection_query)) + if err then + error(err) end end end + -- Check cassandra is accessible + local session = cassandra.new() + session:set_timeout(self._properties.timeout) + local ok, co_err = session:connect(self._properties.hosts, self._properties.port) + session:close() + + if not ok then + return DaoError(co_err, constants.DATABASE_ERROR_TYPES.DATABASE) + end + for _, collection in pairs(self.daos) do - local status, err = pcall(function() prepare_collection(collection) end) - if not status then - return err + if collection.queries then + local status, err = pcall(function() prepare_collection(collection, collection.queries) end) + if not status then + return err + end end end end diff --git a/kong/dao/cassandra/migrations.lua b/kong/dao/cassandra/migrations.lua index 92c8d875aa4..00971dc88ff 100644 --- a/kong/dao/cassandra/migrations.lua +++ 
b/kong/dao/cassandra/migrations.lua @@ -4,7 +4,8 @@ local BaseDao = require "kong.dao.cassandra.base_dao" local Migrations = BaseDao:extend() function Migrations:new(properties) - self._queries = { + self._table = "schema_migrations" + self.queries = { add_migration = [[ UPDATE schema_migrations SET migrations = migrations + ? WHERE id = 'migrations'; ]], @@ -27,7 +28,7 @@ end -- @return query result -- @return error if any function Migrations:add_migration(migration_name) - return Migrations.super._execute(self, self._queries.add_migration, + return Migrations.super._execute(self, self.queries.add_migration, { cassandra.list({ migration_name }) }) end @@ -37,7 +38,7 @@ end function Migrations:get_migrations() local rows, err - rows, err = Migrations.super._execute(self, self._queries.get_keyspace, + rows, err = Migrations.super._execute(self, self.queries.get_keyspace, { self._properties.keyspace }, nil, "system") if err then return nil, err @@ -46,7 +47,7 @@ function Migrations:get_migrations() return nil end - rows, err = Migrations.super._execute(self, self._queries.get_migrations) + rows, err = Migrations.super._execute(self, self.queries.get_migrations) if err then return nil, err elseif rows and #rows > 0 then @@ -58,8 +59,12 @@ end -- @return query result -- @return error if any function Migrations:delete_migration(migration_name) - return Migrations.super._execute(self, self._queries.delete_migration, + return Migrations.super._execute(self, self.queries.delete_migration, { cassandra.list({ migration_name }) }) end +function Migrations:drop() + -- never drop this +end + return { migrations = Migrations } diff --git a/kong/dao/cassandra/plugins_configurations.lua b/kong/dao/cassandra/plugins_configurations.lua index 8d89b304eae..b2273af63ad 100644 --- a/kong/dao/cassandra/plugins_configurations.lua +++ b/kong/dao/cassandra/plugins_configurations.lua @@ -1,4 +1,5 @@ local plugins_configurations_schema = require "kong.dao.schemas.plugins_configurations" 
+local query_builder = require "kong.dao.cassandra.query_builder" local constants = require "kong.constants" local BaseDao = require "kong.dao.cassandra.base_dao" local cjson = require "cjson" @@ -6,47 +7,8 @@ local cjson = require "cjson" local PluginsConfigurations = BaseDao:extend() function PluginsConfigurations:new(properties) - self._entity = "Plugin configuration" + self._table = "plugins_configurations" self._schema = plugins_configurations_schema - self._queries = { - insert = { - args_keys = { "id", "api_id", "consumer_id", "name", "value", "enabled", "created_at" }, - query = [[ INSERT INTO plugins_configurations(id, api_id, consumer_id, name, value, enabled, created_at) - VALUES(?, ?, ?, ?, ?, ?, ?); ]] - }, - update = { - args_keys = { "api_id", "consumer_id", "value", "enabled", "created_at", "id", "name" }, - query = [[ UPDATE plugins_configurations SET api_id = ?, consumer_id = ?, value = ?, enabled = ?, created_at = ? WHERE id = ? AND name = ?; ]] - }, - select = { - query = [[ SELECT * FROM plugins_configurations %s; ]] - }, - select_one = { - args_keys = { "id" }, - query = [[ SELECT * FROM plugins_configurations WHERE id = ?; ]] - }, - delete = { - args_keys = { "id" }, - query = [[ DELETE FROM plugins_configurations WHERE id = ?; ]] - }, - __unique = { - self = { - args_keys = { "api_id", "consumer_id", "name" }, - query = [[ SELECT * FROM plugins_configurations WHERE api_id = ? AND consumer_id = ? AND name = ? 
ALLOW FILTERING; ]] - } - }, - __foreign = { - api_id = { - args_keys = { "api_id" }, - query = [[ SELECT id FROM apis WHERE id = ?; ]] - }, - consumer_id = { - args_keys = { "consumer_id" }, - query = [[ SELECT id FROM consumers WHERE id = ?; ]] - } - }, - drop = "TRUNCATE plugins_configurations;" - } PluginsConfigurations.super.new(self, properties) end @@ -74,6 +36,14 @@ function PluginsConfigurations:_unmarshall(t) return t end +-- @override +function PluginsConfigurations:update(t) + if not t.consumer_id then + t.consumer_id = constants.DATABASE_NULL_ID + end + return PluginsConfigurations.super.update(self, t) +end + function PluginsConfigurations:find_distinct() -- Open session local session, err = PluginsConfigurations.super._open_session(self) @@ -81,14 +51,16 @@ function PluginsConfigurations:find_distinct() return nil, err end + local select_q = query_builder.select(self._table) + -- Execute query local distinct_names = {} - for _, rows, page, err in session:execute(string.format(self._queries.select.query, ""), nil, {auto_paging=true}) do + for _, rows, page, err in PluginsConfigurations.super.execute(self, select_q, nil, nil, {auto_paging=true}) do if err then return nil, err end for _, v in ipairs(rows) do - -- Rows also contains other properites, so making sure it's a plugin + -- Rows also contains other properties, so making sure it's a plugin if v.name then distinct_names[v.name] = true end diff --git a/kong/dao/cassandra/query_builder.lua b/kong/dao/cassandra/query_builder.lua new file mode 100644 index 00000000000..ffbb9af2579 --- /dev/null +++ b/kong/dao/cassandra/query_builder.lua @@ -0,0 +1,195 @@ +local _M = {} + +local function trim(s) + return (s:gsub("^%s*(.-)%s*$", "%1")) +end + +local function select_fragment(column_family, select_columns) + if select_columns then + assert(type(select_columns) == "table", "select_columns must be a table") + select_columns = table.concat(select_columns, ", ") + else + select_columns = "*" + end + + return 
string.format("SELECT %s FROM %s", select_columns, column_family) +end + +local function insert_fragment(column_family, insert_values) + local values_placeholders, columns = {}, {} + for column, value in pairs(insert_values) do + table.insert(values_placeholders, "?") + table.insert(columns, column) + end + + local columns_names_str = table.concat(columns, ", ") + values_placeholders = table.concat(values_placeholders, ", ") + + return string.format("INSERT INTO %s(%s) VALUES(%s)", column_family, columns_names_str, values_placeholders), columns +end + +local function update_fragment(column_family, update_values) + local placeholders, update_columns = {}, {} + for column in pairs(update_values) do + table.insert(update_columns, column) + table.insert(placeholders, string.format("%s = ?", column)) + end + + placeholders = table.concat(placeholders, ", ") + + return string.format("UPDATE %s SET %s", column_family, placeholders), update_columns +end + +local function delete_fragment(column_family) + return string.format("DELETE FROM %s", column_family) +end + +local function where_fragment(where_t, column_family_details, no_filtering_check) + if not where_t then where_t = {} end + if not column_family_details then column_family_details = {} end + + assert(type(where_t) == "table", "where_t must be a table") + if next(where_t) == nil then + if not no_filtering_check then + return "" + else + error("where_t must contain keys") + end + end + + for _, prop in ipairs({"primary_key", "clustering_key", "indexes"}) do + if column_family_details[prop] then + assert(type(column_family_details[prop]) == "table", prop.." 
must be a table") + else + column_family_details[prop] = {} + end + end + + local where_parts, columns = {}, {} + local needs_filtering = false + local filtering = "" + local n_indexed_cols = 0 + + for column in pairs(where_t) do + table.insert(where_parts, string.format("%s = ?", column)) + table.insert(columns, column) + + if not no_filtering_check and not needs_filtering then -- as soon as we need it, it's not revertible + -- check if this field belongs to the primary_key + local primary_contains = false + for _, key in ipairs(column_family_details.primary_key) do + if column == key then + primary_contains = true + break + end + end + for _, key in ipairs(column_family_details.clustering_key) do + if column == key then + primary_contains = true + break + end + end + -- check the number of indexed fields being queried. If more than 1, we need filtering + if column_family_details.indexes[column] then + n_indexed_cols = n_indexed_cols + 1 + end + -- if the column is not part of the primary key, nor indexed, or if we have more + -- than one indexed column being queried, we need filtering. + if (not primary_contains and not column_family_details.indexes[column]) or n_indexed_cols > 1 then + needs_filtering = true + end + end + end + + if needs_filtering then + filtering = " ALLOW FILTERING" + else + needs_filtering = false + end + + where_parts = table.concat(where_parts, " AND ") + + return string.format("WHERE %s%s", where_parts, filtering), columns, needs_filtering +end + +-- Generate a SELECT query with an optional WHERE instruction. +-- If building a WHERE instruction, we need some additional informations about the column family. 
+-- @param `column_family` Name of the column family +-- @param `column_family_details` Additional info about the column family (partition key, clustering key, indexes) +-- @param `select_columns` A list of columns to retrieve +-- @return `query` The SELECT query +-- @return `columns` A list of columns to bind for the query, in the order of the placeholder markers (?) +-- @return `needs_filtering` A boolean indicating if ALLOW FILTERING was added to this query or not +function _M.select(column_family, where_t, column_family_details, select_columns) + assert(type(column_family) == "string", "column_family must be a string") + + local select_str = select_fragment(column_family, select_columns) + local where_str, columns, needed_filtering = where_fragment(where_t, column_family_details) + + return trim(string.format("%s %s", select_str, where_str)), columns, needed_filtering +end + +-- Generate an INSERT query. +-- @param `column_family` Name of the column family +-- @param `insert_values` A columns/values table of values to insert +-- @return `query` The INSERT query +-- @return `columns` A list of columns to bind for the query, in the order of the placeholder markers (?) +function _M.insert(column_family, insert_values) + assert(type(column_family) == "string", "column_family must be a string") + assert(type(insert_values) == "table", "insert_values must be a table") + assert(next(insert_values) ~= nil, "insert_values cannot be empty") + + return insert_fragment(column_family, insert_values) +end + +-- Generate an UPDATE query with update values (SET part) and a mandatory WHERE instruction. +-- @param `column_family` Name of the column family +-- @param `update_values` A columns/values table of values to update +-- @param `where_t` A columns/values table to select the row to update +-- @return `query` The UPDATE query +-- @return `columns` A list of columns to bind for the query, in the order of the placeholder markers (?)
+function _M.update(column_family, update_values, where_t) + assert(type(column_family) == "string", "column_family must be a string") + assert(type(update_values) == "table", "update_values must be a table") + assert(next(update_values) ~= nil, "update_values cannot be empty") + + local update_str, update_columns = update_fragment(column_family, update_values) + local where_str, where_columns = where_fragment(where_t, nil, true) + + -- concat columns from SET and WHERE parts of the query + local columns = {} + if update_columns then + columns = update_columns + end + if where_columns then + for _, v in ipairs(where_columns) do + table.insert(columns, v) + end + end + + return trim(string.format("%s %s", update_str, where_str)), columns +end + +-- Generate a DELETE query with a mandatory WHERE instruction. +-- @param `column_family` Name of the column family +-- @param `where_t` A columns/values table to select the row to DELETE +-- @return `columns` A list of columns to bind for the query, in the order of the placeholder markers (?)
+function _M.delete(column_family, where_t) + assert(type(column_family) == "string", "column_family must be a string") + + local delete_str = delete_fragment(column_family) + local where_str, where_columns = where_fragment(where_t, nil, true) + + return trim(string.format("%s %s", delete_str, where_str)), where_columns +end + +-- Generate a TRUNCATE query +-- @param `column_family` Name of the column family +-- @return `query` +function _M.truncate(column_family) + assert(type(column_family) == "string", "column_family must be a string") + + return "TRUNCATE "..column_family +end + +return _M diff --git a/kong/dao/error.lua b/kong/dao/error.lua index 3d5ece57ff7..5ceebf10572 100644 --- a/kong/dao/error.lua +++ b/kong/dao/error.lua @@ -51,7 +51,7 @@ local mt = { -- Cassandra server error if err_type == constants.DATABASE_ERROR_TYPES.DATABASE then - t.message = "Cassandra error: "..t.message + t.message = "Cassandra error: "..t.message -- TODO remove once cassandra driver has nicer error messages t.cassandra_err_code = err.code end diff --git a/kong/dao/schemas/apis.lua b/kong/dao/schemas/apis.lua index d7280f7a87d..5e2d0b22ee0 100644 --- a/kong/dao/schemas/apis.lua +++ b/kong/dao/schemas/apis.lua @@ -55,13 +55,16 @@ local function check_path(path, api_t) end return { - id = { type = "id" }, - name = { type = "string", unique = true, queryable = true, default = function(api_t) return api_t.public_dns end }, - public_dns = { type = "string", unique = true, queryable = true, - func = check_public_dns_and_path, - regex = "([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*)" }, - path = { type = "string", queryable = true, unique = true, func = check_path }, - strip_path = { type = "boolean" }, - target_url = { type = "string", required = true, func = validate_target_url }, - created_at = { type = "timestamp" } + name = "API", + primary_key = {"id"}, + fields = { + id = { type = "id", dao_insert_value = true }, + created_at = { type = "timestamp", dao_insert_value = true }, + name = { 
type = "string", unique = true, queryable = true, default = function(api_t) return api_t.public_dns end }, + public_dns = { type = "string", unique = true, queryable = true, func = check_public_dns_and_path, + regex = "([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*)" }, + path = { type = "string", unique = true, func = check_path }, + strip_path = { type = "boolean" }, + target_url = { type = "string", required = true, func = validate_target_url } + } } diff --git a/kong/dao/schemas/consumers.lua b/kong/dao/schemas/consumers.lua index 69fe5183a8f..958c490bb7d 100644 --- a/kong/dao/schemas/consumers.lua +++ b/kong/dao/schemas/consumers.lua @@ -1,5 +1,4 @@ local stringy = require "stringy" -local constants = require "kong.constants" local function check_custom_id_and_username(value, consumer_t) local username = type(consumer_t.username) == "string" and stringy.strip(consumer_t.username) or "" @@ -13,8 +12,12 @@ local function check_custom_id_and_username(value, consumer_t) end return { - id = { type = constants.DATABASE_TYPES.ID }, - custom_id = { type = "string", unique = true, queryable = true, func = check_custom_id_and_username }, - username = { type = "string", unique = true, queryable = true, func = check_custom_id_and_username }, - created_at = { type = constants.DATABASE_TYPES.TIMESTAMP } + name = "Consumer", + primary_key = {"id"}, + fields = { + id = { type = "id", dao_insert_value = true }, + created_at = { type = "timestamp", dao_insert_value = true }, + custom_id = { type = "string", unique = true, queryable = true, func = check_custom_id_and_username }, + username = { type = "string", unique = true, queryable = true, func = check_custom_id_and_username } + } } diff --git a/kong/dao/schemas/plugins_configurations.lua b/kong/dao/schemas/plugins_configurations.lua index 2ac86d2f981..89764d2f059 100644 --- a/kong/dao/schemas/plugins_configurations.lua +++ b/kong/dao/schemas/plugins_configurations.lua @@ -1,4 +1,5 @@ local utils = require "kong.tools.utils" +local 
DaoError = require "kong.dao.error" local constants = require "kong.constants" local function load_value_schema(plugin_t) @@ -13,11 +14,33 @@ local function load_value_schema(plugin_t) end return { - id = { type = constants.DATABASE_TYPES.ID }, - api_id = { type = constants.DATABASE_TYPES.ID, required = true, foreign = true, queryable = true }, - consumer_id = { type = constants.DATABASE_TYPES.ID, foreign = true, queryable = true, default = constants.DATABASE_NULL_ID }, - name = { type = "string", required = true, queryable = true, immutable = true }, - value = { type = "table", schema = load_value_schema }, - enabled = { type = "boolean", default = true }, - created_at = { type = constants.DATABASE_TYPES.TIMESTAMP } + name = "Plugin configuration", + primary_key = {"id"}, + clustering_key = {"name"}, + fields = { + id = { type = "id", dao_insert_value = true }, + created_at = { type = "timestamp", dao_insert_value = true }, + api_id = { type = "id", required = true, foreign = "apis:id", queryable = true }, + consumer_id = { type = "id", foreign = "consumers:id", queryable = true, default = constants.DATABASE_NULL_ID }, + name = { type = "string", required = true, immutable = true, queryable = true }, + value = { type = "table", schema = load_value_schema }, + enabled = { type = "boolean", default = true } + }, + on_insert = function(plugin_t, dao) + local res, err = dao.plugins_configurations:find_by_keys({ + name = plugin_t.name, + api_id = plugin_t.api_id, + consumer_id = plugin_t.consumer_id + }) + + if err then + return nil, DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) + end + + if res and #res > 0 then + return false, DaoError("Plugin configuration already exists", constants.DATABASE_ERROR_TYPES.UNIQUE) + else + return true + end + end } diff --git a/kong/dao/schemas_validation.lua b/kong/dao/schemas_validation.lua index 43bb1ed2ad6..39a6c141242 100644 --- a/kong/dao/schemas_validation.lua +++ b/kong/dao/schemas_validation.lua @@ -1,6 +1,8 @@ local 
utils = require "kong.tools.utils" local stringy = require "stringy" +local DaoError = require "kong.dao.error" local constants = require "kong.constants" +local error_types = constants.DATABASE_ERROR_TYPES local POSSIBLE_TYPES = { id = true, @@ -12,15 +14,15 @@ timestamp = true } -local types_validation = { - [constants.DATABASE_TYPES.ID] = function(v) return type(v) == "string" end, - [constants.DATABASE_TYPES.TIMESTAMP] = function(v) return type(v) == "number" end, +local custom_types_validation = { + ["id"] = function(v) return type(v) == "string" end, + ["timestamp"] = function(v) return type(v) == "number" end, ["array"] = function(v) return utils.is_array(v) end } local function validate_type(field_type, value) - if types_validation[field_type] then - return types_validation[field_type](value) + if custom_types_validation[field_type] then + return custom_types_validation[field_type](value) end return type(value) == field_type end @@ -30,35 +32,46 @@ local _M = {} -- Validate a table against a given schema -- @param `t` Entity to validate, as a table. -- @param `schema` Schema against which to validate the entity. --- @param `is_update` For an entity update, check immutable fields. Set to true. +-- @param `options` +-- `dao_insert` A function called for each field with a `dao_insert_value` property. +-- `partial_update`/`full_update` For an entity update, check immutable fields. Set to true. -- @return `valid` Success of validation. True or false. -- @return `errors` A list of encountered errors during the validation.
-function _M.validate(t, schema, is_update) +function _M.validate_fields(t, schema, options) + if not options then options = {} end local errors - -- Check the given table against a given schema - for column, v in pairs(schema) do - - -- [DEFAULT] Set default value for the field if given - if t[column] == nil and v.default ~= nil then - if type(v.default) == "function" then - t[column] = v.default(t) - else - t[column] = v.default + if not options.partial_update and not options.full_update then + for column, v in pairs(schema.fields) do + -- [DEFAULT] Set default value for the field if given + if t[column] == nil and v.default ~= nil then + if type(v.default) == "function" then + t[column] = v.default(t) + else + t[column] = v.default + end + end + -- [INSERT_VALUE] + if v.dao_insert_value and type(options.dao_insert) == "function" then + t[column] = options.dao_insert(v) end + end + end + + -- Check the given table against a given schema + for column, v in pairs(schema.fields) do -- [IMMUTABLE] check immutability of a field if updating - elseif is_update and t[column] ~= nil and v.immutable and not v.required then + if (options.partial_update or options.full_update) and t[column] ~= nil and v.immutable and not v.required then errors = utils.add_error(errors, column, column.." cannot be updated") end -- [TYPE] Check if type is valid. 
Boolean and Numbers as strings are accepted and converted - if v.type ~= nil and t[column] ~= nil then + if t[column] ~= nil and v.type ~= nil then local is_valid_type - -- ALIASES: number, timestamp, boolean and array can be passed as strings and will be converted if type(t[column]) == "string" then t[column] = stringy.strip(t[column]) - if v.type == "number" or v .type == constants.DATABASE_TYPES.TIMESTAMP then + if v.type == "number" or v .type == "timestamp" then t[column] = tonumber(t[column]) is_valid_type = t[column] ~= nil elseif v.type == "boolean" then @@ -84,7 +97,7 @@ function _M.validate(t, schema, is_update) end -- [ENUM] Check if the value is allowed in the enum. - if v.enum and t[column] ~= nil then + if t[column] ~= nil and v.enum then local found = false for _, allowed in ipairs(v.enum) do if allowed == t[column] then @@ -105,23 +118,21 @@ function _M.validate(t, schema, is_update) end end - -- [SCHEMA] Validate a sub-schema from a table or retrived by a function + -- [SCHEMA] Validate a sub-schema from a table or retrieved by a function if v.schema then local sub_schema, err if type(v.schema) == "function" then sub_schema, err = v.schema(t) + if err then -- could not retrieve sub schema + errors = utils.add_error(errors, column, err) + end else sub_schema = v.schema end - if err then - -- could not retrieve sub schema - errors = utils.add_error(errors, column, err) - end - if sub_schema then -- Check for sub-schema defaults and required properties in advance - for sub_field_k, sub_field in pairs(sub_schema) do + for sub_field_k, sub_field in pairs(sub_schema.fields) do if t[column] == nil then if sub_field.default then -- Sub-value has a default, be polite and pre-assign the sub-value t[column] = {} @@ -133,7 +144,7 @@ function _M.validate(t, schema, is_update) if t[column] and type(t[column]) == "table" then -- Actually validating the sub-schema - local s_ok, s_errors = _M.validate(t[column], sub_schema, is_update) + local s_ok, s_errors = 
_M.validate_fields(t[column], sub_schema, options) if not s_ok then for s_k, s_v in pairs(s_errors) do errors = utils.add_error(errors, column.."."..s_k, s_v) @@ -143,28 +154,31 @@ function _M.validate(t, schema, is_update) end end - -- [REQUIRED] Check that required fields are set. Now that default and most other checks - -- have been run. - if v.required and (t[column] == nil or t[column] == "") then - errors = utils.add_error(errors, column, column.." is required") - end + if not options.partial_update or t[column] ~= nil then + -- [REQUIRED] Check that required fields are set. + -- Now that default and most other checks have been run. + if v.required and (t[column] == nil or t[column] == "") then + errors = utils.add_error(errors, column, column.." is required") + end - -- [FUNC] Check field against a custom function only if there is no error on that field already - if v.func and type(v.func) == "function" and (errors == nil or errors[column] == nil) then - local ok, err, new_fields = v.func(t[column], t) - if not ok and err then - errors = utils.add_error(errors, column, err) - elseif new_fields then - for k, v in pairs(new_fields) do - t[k] = v + if type(v.func) == "function" and (errors == nil or errors[column] == nil) then + -- [FUNC] Check field against a custom function + -- only if there is no error on that field already. + local ok, err, new_fields = v.func(t[column], t) + if not ok and err then + errors = utils.add_error(errors, column, err) + elseif new_fields then + for k, v in pairs(new_fields) do + t[k] = v + end end end end end -- Check for unexpected fields in the entity - for k, v in pairs(t) do - if schema[k] == nil then + for k in pairs(t) do + if schema.fields[k] == nil then errors = utils.add_error(errors, k, k.." 
is an unknown field") end end @@ -172,6 +186,28 @@ function _M.validate(t, schema, is_update) return errors == nil, errors end +function _M.on_insert(t, schema, dao) + if schema.on_insert and type(schema.on_insert) == "function" then + local valid, err = schema.on_insert(t, dao) + if not valid or err then + return false, err + else + return true + end + else + return true + end +end + +function _M.validate(t, dao, options) + local ok, errors + + ok, errors = _M.validate_fields(t, dao._schema, options) + if not ok then + return DaoError(errors, error_types.SCHEMA) + end +end + local digit = "[0-9a-f]" local uuid_pattern = "^"..table.concat({ digit:rep(8), digit:rep(4), digit:rep(4), digit:rep(4), digit:rep(12) }, '%-').."$" function _M.is_valid_uuid(uuid) diff --git a/kong/plugins/basicauth/access.lua b/kong/plugins/basicauth/access.lua index af6e895aac8..8c0b89299a9 100644 --- a/kong/plugins/basicauth/access.lua +++ b/kong/plugins/basicauth/access.lua @@ -93,7 +93,7 @@ function _M.execute(conf) -- Retrieve consumer local consumer = cache.get_or_set(cache.consumer_key(credential.consumer_id), function() - local result, err = dao.consumers:find_one(credential.consumer_id) + local result, err = dao.consumers:find_by_primary_key({ id = credential.consumer_id }) if err then return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) end diff --git a/kong/plugins/basicauth/api.lua b/kong/plugins/basicauth/api.lua index d72a37cc7cd..9e06ac92215 100644 --- a/kong/plugins/basicauth/api.lua +++ b/kong/plugins/basicauth/api.lua @@ -41,11 +41,11 @@ return { end, PATCH = function(self, dao_factory) - crud.patch(self.params, dao_factory.basicauth_credentials) + crud.patch(self.params, self.credential, dao_factory.basicauth_credentials) end, DELETE = function(self, dao_factory) - crud.delete(self.credential.id, dao_factory.basicauth_credentials) + crud.delete(self.credential, dao_factory.basicauth_credentials) end } } diff --git a/kong/plugins/basicauth/daos.lua 
b/kong/plugins/basicauth/daos.lua index 1ff53ca242d..0a18a2682df 100644 --- a/kong/plugins/basicauth/daos.lua +++ b/kong/plugins/basicauth/daos.lua @@ -1,54 +1,21 @@ local BaseDao = require "kong.dao.cassandra.base_dao" local SCHEMA = { - id = { type = "id" }, - consumer_id = { type = "id", required = true, foreign = true, queryable = true }, - username = { type = "string", required = true, unique = true, queryable = true }, - password = { type = "string" }, - created_at = { type = "timestamp" } + primary_key = {"id"}, + fields = { + id = { type = "id", dao_insert_value = true }, + created_at = { type = "timestamp", dao_insert_value = true }, + consumer_id = { type = "id", required = true, foreign = "consumers:id" }, + username = { type = "string", required = true, unique = true, queryable = true }, + password = { type = "string" } + } } local BasicAuthCredentials = BaseDao:extend() function BasicAuthCredentials:new(properties) + self._table = "basicauth_credentials" self._schema = SCHEMA - self._queries = { - insert = { - args_keys = { "id", "consumer_id", "username", "password", "created_at" }, - query = [[ - INSERT INTO basicauth_credentials(id, consumer_id, username, password, created_at) - VALUES(?, ?, ?, ?, ?); - ]] - }, - update = { - args_keys = { "username", "password", "created_at", "id" }, - query = [[ UPDATE basicauth_credentials SET username = ?, password = ?, created_at = ? 
WHERE id = ?; ]] - }, - select = { - query = [[ SELECT * FROM basicauth_credentials %s; ]] - }, - select_one = { - args_keys = { "id" }, - query = [[ SELECT * FROM basicauth_credentials WHERE id = ?; ]] - }, - delete = { - args_keys = { "id" }, - query = [[ DELETE FROM basicauth_credentials WHERE id = ?; ]] - }, - __foreign = { - consumer_id = { - args_keys = { "consumer_id" }, - query = [[ SELECT id FROM consumers WHERE id = ?; ]] - } - }, - __unique = { - username = { - args_keys = { "username" }, - query = [[ SELECT id FROM basicauth_credentials WHERE username = ?; ]] - } - }, - drop = "TRUNCATE basicauth_credentials;" - } BasicAuthCredentials.super.new(self, properties) end diff --git a/kong/plugins/basicauth/schema.lua b/kong/plugins/basicauth/schema.lua index 3916fb53108..68ec192d395 100644 --- a/kong/plugins/basicauth/schema.lua +++ b/kong/plugins/basicauth/schema.lua @@ -1,3 +1,5 @@ return { - hide_credentials = { type = "boolean", default = false } + fields = { + hide_credentials = { type = "boolean", default = false } + } } diff --git a/kong/plugins/cors/schema.lua b/kong/plugins/cors/schema.lua index 20ccba7bc85..d6cfa8d4b9f 100644 --- a/kong/plugins/cors/schema.lua +++ b/kong/plugins/cors/schema.lua @@ -1,9 +1,11 @@ return { - origin = { type = "string" }, - headers = { type = "string" }, - exposed_headers = { type = "string" }, - methods = { type = "string" }, - max_age = { type = "number" }, - credentials = { type = "boolean", default = false }, - preflight_continue = { type = "boolean", default = false } + fields = { + origin = { type = "string" }, + headers = { type = "string" }, + exposed_headers = { type = "string" }, + methods = { type = "string" }, + max_age = { type = "number" }, + credentials = { type = "boolean", default = false }, + preflight_continue = { type = "boolean", default = false } + } } diff --git a/kong/plugins/filelog/schema.lua b/kong/plugins/filelog/schema.lua index de2798b472e..3d11ae60f88 100644 --- 
a/kong/plugins/filelog/schema.lua +++ b/kong/plugins/filelog/schema.lua @@ -9,10 +9,12 @@ local function validate_file(value) if not exists then os.remove(value) -- Remove the created file if it didn't exist before end - + return true end return { - path = { required = true, type = "string", func = validate_file } -} \ No newline at end of file + fields = { + path = { required = true, type = "string", func = validate_file } + } +} diff --git a/kong/plugins/httplog/schema.lua b/kong/plugins/httplog/schema.lua index aaba8d57335..c748de838aa 100644 --- a/kong/plugins/httplog/schema.lua +++ b/kong/plugins/httplog/schema.lua @@ -1,6 +1,8 @@ return { - http_endpoint = { required = true, type = "string" }, - method = { default = "POST", enum = { "POST", "PUT", "PATCH" } }, - timeout = { default = 10000, type = "number" }, - keepalive = { default = 60000, type = "number" } + fields = { + http_endpoint = { required = true, type = "string" }, + method = { default = "POST", enum = { "POST", "PUT", "PATCH" } }, + timeout = { default = 10000, type = "number" }, + keepalive = { default = 60000, type = "number" } + } } diff --git a/kong/plugins/keyauth/access.lua b/kong/plugins/keyauth/access.lua index 3a4f5a0a972..e4abd5c68b4 100644 --- a/kong/plugins/keyauth/access.lua +++ b/kong/plugins/keyauth/access.lua @@ -146,7 +146,7 @@ function _M.execute(conf) -- Retrieve consumer local consumer = cache.get_or_set(cache.consumer_key(credential.consumer_id), function() - local result, err = dao.consumers:find_one(credential.consumer_id) + local result, err = dao.consumers:find_by_primary_key({ id = credential.consumer_id }) if err then return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) end diff --git a/kong/plugins/keyauth/api.lua b/kong/plugins/keyauth/api.lua index 84aa16503c2..385a9a16f38 100644 --- a/kong/plugins/keyauth/api.lua +++ b/kong/plugins/keyauth/api.lua @@ -30,22 +30,22 @@ return { return helpers.yield_error(err) end - self.plugin = data[1] - if not self.plugin then + 
self.credential = data[1] + if not self.credential then return helpers.responses.send_HTTP_NOT_FOUND() end end, GET = function(self, dao_factory, helpers) - return helpers.responses.send_HTTP_OK(self.plugin) + return helpers.responses.send_HTTP_OK(self.credential) end, PATCH = function(self, dao_factory) - crud.patch(self.params, dao_factory.keyauth_credentials) + crud.patch(self.params, self.credential, dao_factory.keyauth_credentials) end, DELETE = function(self, dao_factory) - crud.delete(self.plugin.id, dao_factory.keyauth_credentials) + crud.delete(self.credential, dao_factory.keyauth_credentials) end } } diff --git a/kong/plugins/keyauth/daos.lua b/kong/plugins/keyauth/daos.lua index c6e04e08348..2b5f12dd740 100644 --- a/kong/plugins/keyauth/daos.lua +++ b/kong/plugins/keyauth/daos.lua @@ -1,53 +1,20 @@ local BaseDao = require "kong.dao.cassandra.base_dao" local SCHEMA = { - id = { type = "id" }, - consumer_id = { type = "id", required = true, foreign = true, queryable = true }, - key = { type = "string", required = true, unique = true, queryable = true }, - created_at = { type = "timestamp" } + primary_key = {"id"}, + fields = { + id = { type = "id", dao_insert_value = true }, + created_at = { type = "timestamp", dao_insert_value = true }, + consumer_id = { type = "id", required = true, foreign = "consumers:id" }, + key = { type = "string", required = true, unique = true, queryable = true } + } } local KeyAuth = BaseDao:extend() function KeyAuth:new(properties) + self._table = "keyauth_credentials" self._schema = SCHEMA - self._queries = { - insert = { - args_keys = { "id", "consumer_id", "key", "created_at" }, - query = [[ - INSERT INTO keyauth_credentials(id, consumer_id, key, created_at) - VALUES(?, ?, ?, ?); - ]] - }, - update = { - args_keys = { "key", "created_at", "id" }, - query = [[ UPDATE keyauth_credentials SET key = ?, created_at = ? 
WHERE id = ?; ]] - }, - select = { - query = [[ SELECT * FROM keyauth_credentials %s; ]] - }, - select_one = { - args_keys = { "id" }, - query = [[ SELECT * FROM keyauth_credentials WHERE id = ?; ]] - }, - delete = { - args_keys = { "id" }, - query = [[ DELETE FROM keyauth_credentials WHERE id = ?; ]] - }, - __foreign = { - consumer_id = { - args_keys = { "consumer_id" }, - query = [[ SELECT id FROM consumers WHERE id = ?; ]] - } - }, - __unique = { - key = { - args_keys = { "key" }, - query = [[ SELECT id FROM keyauth_credentials WHERE key = ?; ]] - } - }, - drop = "TRUNCATE keyauth_credentials;" - } KeyAuth.super.new(self, properties) end diff --git a/kong/plugins/keyauth/schema.lua b/kong/plugins/keyauth/schema.lua index 8b57e2f0bb3..b877a6c8c2d 100644 --- a/kong/plugins/keyauth/schema.lua +++ b/kong/plugins/keyauth/schema.lua @@ -5,6 +5,8 @@ local function default_key_names(t) end return { - key_names = { required = true, type = "array", default = default_key_names }, - hide_credentials = { type = "boolean", default = false } + fields = { + key_names = { required = true, type = "array", default = default_key_names }, + hide_credentials = { type = "boolean", default = false } + } } diff --git a/kong/plugins/ratelimiting/daos.lua b/kong/plugins/ratelimiting/daos.lua index 2e909a7cdec..0bec02a918c 100644 --- a/kong/plugins/ratelimiting/daos.lua +++ b/kong/plugins/ratelimiting/daos.lua @@ -5,7 +5,8 @@ local timestamp = require "kong.tools.timestamp" local RateLimitingMetrics = BaseDao:extend() function RateLimitingMetrics:new(properties) - self._queries = { + self._table = "ratelimiting_metrics" + self.queries = { increment_counter = [[ UPDATE ratelimiting_metrics SET value = value + 1 WHERE api_id = ? AND identifier = ? AND period_date = ? AND @@ -17,8 +18,7 @@ function RateLimitingMetrics:new(properties) delete = [[ DELETE FROM ratelimiting_metrics WHERE api_id = ? AND identifier = ? AND period_date = ? 
AND - period = ?; ]], - drop = "TRUNCATE ratelimiting_metrics;" + period = ?; ]] } RateLimitingMetrics.super.new(self, properties) @@ -29,7 +29,7 @@ function RateLimitingMetrics:increment(api_id, identifier, current_timestamp) local batch = cassandra.BatchStatement(cassandra.batch_types.COUNTER) for period, period_date in pairs(periods) do - batch:add(self._queries.increment_counter, { + batch:add(self.queries.increment_counter, { cassandra.uuid(api_id), identifier, cassandra.timestamp(period_date), @@ -43,7 +43,7 @@ end function RateLimitingMetrics:find_one(api_id, identifier, current_timestamp, period) local periods = timestamp.get_timestamps(current_timestamp) - local metric, err = RateLimitingMetrics.super._execute(self, self._queries.select_one, { + local metric, err = RateLimitingMetrics.super._execute(self, self.queries.select_one, { cassandra.uuid(api_id), identifier, cassandra.timestamp(periods[period]), @@ -60,11 +60,15 @@ function RateLimitingMetrics:find_one(api_id, identifier, current_timestamp, per return metric end +-- Unsupported +function RateLimitingMetrics:find_by_primary_key() + error("ratelimiting_metrics:find_by_primary_key() not yet implemented", 2) +end + function RateLimitingMetrics:delete(api_id, identifier, periods) - error("ratelimiting_metrics:delete() not yet implemented") + error("ratelimiting_metrics:delete() not yet implemented", 2) end --- Unsuported function RateLimitingMetrics:insert() error("ratelimiting_metrics:insert() not supported", 2) end diff --git a/kong/plugins/ratelimiting/schema.lua b/kong/plugins/ratelimiting/schema.lua index 337d60619f1..037458d5ad3 100644 --- a/kong/plugins/ratelimiting/schema.lua +++ b/kong/plugins/ratelimiting/schema.lua @@ -1,6 +1,8 @@ local constants = require "kong.constants" return { - limit = { required = true, type = "number" }, - period = { required = true, type = "string", enum = constants.RATELIMIT.PERIODS } + fields = { + limit = { required = true, type = "number" }, + period = { required 
= true, type = "string", enum = constants.RATELIMIT.PERIODS } + } } diff --git a/kong/plugins/request_transformer/schema.lua b/kong/plugins/request_transformer/schema.lua index 25c8fc34205..00a02e10939 100644 --- a/kong/plugins/request_transformer/schema.lua +++ b/kong/plugins/request_transformer/schema.lua @@ -1,16 +1,22 @@ return { - add = { type = "table", + fields = { + add = { type = "table", schema = { - form = { type = "array" }, - headers = { type = "array" }, - querystring = { type = "array" } + fields = { + form = { type = "array" }, + headers = { type = "array" }, + querystring = { type = "array" } + } } - }, - remove = { type = "table", - schema = { - form = { type = "array" }, - headers = { type = "array" }, - querystring = { type = "array" } + }, + remove = { type = "table", + schema = { + fields = { + form = { type = "array" }, + headers = { type = "array" }, + querystring = { type = "array" } + } + } } } } diff --git a/kong/plugins/requestsizelimiting/schema.lua b/kong/plugins/requestsizelimiting/schema.lua index e3370087de8..affd13e97f3 100644 --- a/kong/plugins/requestsizelimiting/schema.lua +++ b/kong/plugins/requestsizelimiting/schema.lua @@ -1,3 +1,5 @@ return { - allowed_payload_size = { default = 128, type = "number" } + fields = { + allowed_payload_size = { default = 128, type = "number" } + } } diff --git a/kong/plugins/response_transformer/schema.lua b/kong/plugins/response_transformer/schema.lua index 9c110d14819..db024d57f48 100644 --- a/kong/plugins/response_transformer/schema.lua +++ b/kong/plugins/response_transformer/schema.lua @@ -1,12 +1,18 @@ return { - add = { type = "table", schema = { - json = { type = "array" }, - headers = { type = "array" } - } - }, - remove = { type = "table", schema = { - json = { type = "array" }, - headers = { type = "array" } + fields = { + add = { type = "table", schema = { + fields = { + json = { type = "array" }, + headers = { type = "array" } + } + } + }, + remove = { type = "table", schema = { + 
fields = { + json = { type = "array" }, + headers = { type = "array" } + } + } } } } diff --git a/kong/plugins/ssl/schema.lua b/kong/plugins/ssl/schema.lua index 225fa557059..f3a329a23a8 100644 --- a/kong/plugins/ssl/schema.lua +++ b/kong/plugins/ssl/schema.lua @@ -18,11 +18,13 @@ local function validate_key(v) end return { - cert = { required = true, type = "string", func = validate_cert }, - key = { required = true, type = "string", func = validate_key }, - only_https = { required = false, type = "boolean", default = false }, + fields = { + cert = { required = true, type = "string", func = validate_cert }, + key = { required = true, type = "string", func = validate_key }, + only_https = { required = false, type = "boolean", default = false }, - -- Internal use - _cert_der_cache = { type = "string", immutable = true }, - _key_der_cache = { type = "string", immutable = true } + -- Internal use + _cert_der_cache = { type = "string", immutable = true }, + _key_der_cache = { type = "string", immutable = true } + } } diff --git a/kong/plugins/tcplog/schema.lua b/kong/plugins/tcplog/schema.lua index 76847269b23..8186fe17b79 100644 --- a/kong/plugins/tcplog/schema.lua +++ b/kong/plugins/tcplog/schema.lua @@ -1,6 +1,8 @@ return { - host = { required = true, type = "string" }, - port = { required = true, type = "number" }, - timeout = { default = 10000, type = "number" }, - keepalive = { default = 60000, type = "number" } + fields = { + host = { required = true, type = "string" }, + port = { required = true, type = "number" }, + timeout = { default = 10000, type = "number" }, + keepalive = { default = 60000, type = "number" } + } } diff --git a/kong/plugins/udplog/schema.lua b/kong/plugins/udplog/schema.lua index 8d5bdae100f..c24a7800b13 100644 --- a/kong/plugins/udplog/schema.lua +++ b/kong/plugins/udplog/schema.lua @@ -1,5 +1,7 @@ return { - host = { required = true, type = "string" }, - port = { required = true, type = "number" }, - timeout = { default = 10000, type = 
"number" } + fields = { + host = { required = true, type = "string" }, + port = { required = true, type = "number" }, + timeout = { default = 10000, type = "number" } + } } diff --git a/kong/tools/faker.lua b/kong/tools/faker.lua index 6ff30653c1b..23daf5f3618 100644 --- a/kong/tools/faker.lua +++ b/kong/tools/faker.lua @@ -74,7 +74,7 @@ function Faker:insert_from_table(entities_to_insert) end -- Insert in DB - local dao_type = type=="plugin_configuration" and "plugins_configurations" or type.."s" + local dao_type = type == "plugin_configuration" and "plugins_configurations" or type.."s" local res, err = self.dao_factory[dao_type]:insert(entity) if err then local printable_mt = require "kong.tools.printable" diff --git a/kong/tools/migrations.lua b/kong/tools/migrations.lua index 9c1ac5c6406..341c61ac3e2 100644 --- a/kong/tools/migrations.lua +++ b/kong/tools/migrations.lua @@ -18,7 +18,7 @@ end -- Createa migration interface for each database available function Migrations:create(configuration, name, callback) - for k, _ in pairs(configuration.databases_available) do + for k in pairs(configuration.databases_available) do local date_str = os.date("%Y-%m-%d-%H%M%S") local file_path = IO.path:join(self.migrations_path, k) local file_name = date_str.."_"..name @@ -87,7 +87,11 @@ function Migrations:migrate(callback) -- Execute all new migrations, in order for _, file_path in ipairs(diff_migrations) do -- Load our migration script - local migration = loadfile(file_path)() + local migration_file = loadfile(file_path) + if not migration_file then + error("Migration failed: cannot load file at "..file_path) + end + local migration = migration_file() -- Generate UP query from string + options local up_query = migration.up(self.options) diff --git a/spec/integration/admin_api/admin_api_spec.lua b/spec/integration/admin_api/admin_api_spec.lua index e15452809ee..164e66b8282 100644 --- a/spec/integration/admin_api/admin_api_spec.lua +++ 
b/spec/integration/admin_api/admin_api_spec.lua @@ -405,7 +405,8 @@ describe("Admin API", function() end) it("should update the entity if a full body is given", function() - local data = http_client.get(base_url.."/"..CREATED_IDS[endpoint.collection]) + local data, status = http_client.get(base_url.."/"..CREATED_IDS[endpoint.collection]) + assert.are.equal(200, status) local body = json.decode(data) -- Create new body diff --git a/spec/integration/admin_api/apis_routes_spec.lua b/spec/integration/admin_api/apis_routes_spec.lua index 2570387d312..2f49592a272 100644 --- a/spec/integration/admin_api/apis_routes_spec.lua +++ b/spec/integration/admin_api/apis_routes_spec.lua @@ -232,7 +232,7 @@ describe("Admin API", function() assert.equal(201, status) local body = json.decode(response) - local _, err = dao_plugins:delete(body.id) + local _, err = dao_plugins:delete({id = body.id, name = body.name}) assert.falsy(err) response, status = http_client.post(BASE_URL, { @@ -242,7 +242,7 @@ describe("Admin API", function() assert.equal(201, status) body = json.decode(response) - _, err = dao_plugins:delete(body.id) + _, err = dao_plugins:delete({id = body.id, name = body.name}) assert.falsy(err) end) @@ -263,7 +263,7 @@ describe("Admin API", function() assert.equal(201, status) local body = json.decode(response) - local _, err = dao_plugins:delete(body.id) + local _, err = dao_plugins:delete({id = body.id, name = body.name}) assert.falsy(err) response, status = http_client.put(BASE_URL, { @@ -274,7 +274,7 @@ describe("Admin API", function() body = json.decode(response) response, status = http_client.put(BASE_URL, { - id=body.id, + id = body.id, name = "keyauth", value = {key_names={"updated_apikey"}} }, {["content-type"]="application/json"}) @@ -358,7 +358,7 @@ describe("Admin API", function() assert.equal(404, status) end) - it("[SUCCESS] should delete an API", function() + it("[SUCCESS] should delete a plugin configuration", function() local response, status = 
http_client.delete(BASE_URL..plugin.id) assert.equal(204, status) assert.falsy(response) diff --git a/spec/integration/admin_api/consumers_routes_spec.lua b/spec/integration/admin_api/consumers_routes_spec.lua index 0ecb7b82ebf..e6c8efcb8d6 100644 --- a/spec/integration/admin_api/consumers_routes_spec.lua +++ b/spec/integration/admin_api/consumers_routes_spec.lua @@ -21,8 +21,8 @@ describe("Admin API", function() it("[SUCCESS] should create a Consumer", function() send_content_types(BASE_URL, "POST", { - username="consumer POST tests" - }, 201, nil, {drop_db=true}) + username = "consumer POST tests" + }, 201, nil, {drop_db = true}) end) it("[FAILURE] should return proper errors", function() @@ -31,7 +31,7 @@ describe("Admin API", function() '{"custom_id":"At least a \'custom_id\' or a \'username\' must be specified","username":"At least a \'custom_id\' or a \'username\' must be specified"}') send_content_types(BASE_URL, "POST", { - username="consumer POST tests" + username = "consumer POST tests" }, 409, '{"username":"username already exists with value \'consumer POST tests\'"}') end) @@ -41,12 +41,12 @@ describe("Admin API", function() it("[SUCCESS] should create and update", function() local consumer = send_content_types(BASE_URL, "PUT", { - username="consumer PUT tests" + username = "consumer PUT tests" }, 201, nil, {drop_db=true}) consumer = send_content_types(BASE_URL, "PUT", { - id=consumer.id, - username="consumer PUT tests updated", + id = consumer.id, + username = "consumer PUT tests updated", }, 200) assert.equal("consumer PUT tests updated", consumer.username) end) @@ -57,7 +57,7 @@ describe("Admin API", function() '{"custom_id":"At least a \'custom_id\' or a \'username\' must be specified","username":"At least a \'custom_id\' or a \'username\' must be specified"}') send_content_types(BASE_URL, "PUT", { - username="consumer PUT tests updated", + username = "consumer PUT tests updated", }, 409, '{"username":"username already exists with value \'consumer 
PUT tests updated\'"}') end) @@ -112,7 +112,7 @@ describe("Admin API", function() setup(function() spec_helper.drop_db() local fixtures = spec_helper.insert_fixtures { - consumer = {{ username="get_consumer_tests" }} + consumer = {{username = "get_consumer_tests"}} } consumer = fixtures.consumer[1] end) @@ -157,9 +157,9 @@ describe("Admin API", function() local _, status = http_client.patch(BASE_URL.."hello", {username="patch-updated"}) assert.equal(404, status) - local response, status = http_client.patch(BASE_URL..consumer.id, {username=""}) + local response, status = http_client.patch(BASE_URL..consumer.id, {username=" "}) assert.equal(400, status) - assert.equal('{"custom_id":"At least a \'custom_id\' or a \'username\' must be specified","username":"username is not a string"}\n', response) + assert.equal('{"username":"At least a \'custom_id\' or a \'username\' must be specified"}\n', response) end) end) diff --git a/spec/integration/cli/restart_spec.lua b/spec/integration/cli/restart_spec.lua index 4937dd390a3..ae580ce7a2f 100644 --- a/spec/integration/cli/restart_spec.lua +++ b/spec/integration/cli/restart_spec.lua @@ -19,16 +19,16 @@ describe("CLI", function() it("should restart kong when it's running", function() local _, code = spec_helper.stop_kong() assert.are.same(0, code) - local _, code = spec_helper.start_kong() + _, code = spec_helper.start_kong() assert.are.same(0, code) - local _, code = spec_helper.restart_kong() + _, code = spec_helper.restart_kong() assert.are.same(0, code) end) it("should restart kong when it's crashed", function() local kong_pid = IO.read_file(spec_helper.get_env().configuration.pid_file) os.execute("pkill -9 nginx") - while os.execute("kill -0 "..kong_pid) == 0 do + while os.execute("kill -0 "..kong_pid.." 
") == 0 do -- Wait till it's really over end diff --git a/spec/plugins/basicauth/api_spec.lua b/spec/plugins/basicauth/api_spec.lua index b3cd2e520a5..d00fa5cdce1 100644 --- a/spec/plugins/basicauth/api_spec.lua +++ b/spec/plugins/basicauth/api_spec.lua @@ -43,7 +43,7 @@ describe("Basic Auth Credentials API", function() describe("PUT", function() setup(function() - spec_helper.get_env().dao_factory.basicauth_credentials:delete(credential.id) + spec_helper.get_env().dao_factory.basicauth_credentials:delete({id = credential.id}) end) it("[SUCCESS] should create and update", function() diff --git a/spec/plugins/keyauth/api_spec.lua b/spec/plugins/keyauth/api_spec.lua index 64a3097a2de..443866c84e6 100644 --- a/spec/plugins/keyauth/api_spec.lua +++ b/spec/plugins/keyauth/api_spec.lua @@ -43,7 +43,7 @@ describe("Basic Auth Credentials API", function() describe("PUT", function() setup(function() - spec_helper.get_env().dao_factory.keyauth_credentials:delete(credential.id) + spec_helper.get_env().dao_factory.keyauth_credentials:delete({id = credential.id}) end) it("[SUCCESS] should create and update", function() diff --git a/spec/plugins/keyauth/daos_spec.lua b/spec/plugins/keyauth/daos_spec.lua new file mode 100644 index 00000000000..f0d0d70f952 --- /dev/null +++ b/spec/plugins/keyauth/daos_spec.lua @@ -0,0 +1,60 @@ +local spec_helper = require "spec.spec_helpers" +local uuid = require "uuid" + +local env = spec_helper.get_env() +local dao_factory = env.dao_factory +local faker = env.faker + +describe("DAO keyauth Credentials", function() + + setup(function() + spec_helper.prepare_db() + end) + + it("should not insert in DB if consumer does not exist", function() + -- Without a consumer_id, it's a schema error + local app_t = {name = "keyauth", value = {key_names = {"apikey"}}} + local app, err = dao_factory.keyauth_credentials:insert(app_t) + assert.falsy(app) + assert.truthy(err) + assert.True(err.schema) + assert.are.same("consumer_id is required", 
err.message.consumer_id) + + -- With an invalid consumer_id, it's a FOREIGN error + local app_t = {key = "apikey123", consumer_id = uuid()} + local app, err = dao_factory.keyauth_credentials:insert(app_t) + assert.falsy(app) + assert.truthy(err) + assert.True(err.foreign) + assert.equal("consumer_id "..app_t.consumer_id.." does not exist", err.message.consumer_id) + end) + + it("should insert in DB and add generated values", function() + local consumer_t = faker:fake_entity("consumer") + local consumer, err = dao_factory.consumers:insert(consumer_t) + assert.falsy(err) + + local cred_t = {key = "apikey123", consumer_id = consumer.id} + local app, err = dao_factory.keyauth_credentials:insert(cred_t) + assert.falsy(err) + assert.truthy(app.id) + assert.truthy(app.created_at) + end) + + it("should find a Credential by public_key", function() + local app, err = dao_factory.keyauth_credentials:find_by_keys { + key = "user122" + } + assert.falsy(err) + assert.truthy(app) + end) + + it("should handle empty strings", function() + local apps, err = dao_factory.keyauth_credentials:find_by_keys { + key = "" + } + assert.falsy(err) + assert.same({}, apps) + end) + +end) diff --git a/spec/plugins/ratelimiting_spec.lua b/spec/plugins/ratelimiting/access_spec.lua similarity index 100% rename from spec/plugins/ratelimiting_spec.lua rename to spec/plugins/ratelimiting/access_spec.lua diff --git a/spec/plugins/ratelimiting/daos_spec.lua b/spec/plugins/ratelimiting/daos_spec.lua new file mode 100644 index 00000000000..2f71c82289a --- /dev/null +++ b/spec/plugins/ratelimiting/daos_spec.lua @@ -0,0 +1,106 @@ +local spec_helper = require "spec.spec_helpers" +local timestamp = require "kong.tools.timestamp" +local uuid = require "uuid" + +local env = spec_helper.get_env() +local dao_factory = env.dao_factory +local ratelimiting_metrics = dao_factory.ratelimiting_metrics + +describe("Rate Limiting Metrics", function() + local api_id = uuid() + local identifier = uuid() + + 
after_each(function() + spec_helper.drop_db() + end) + + it("should return nil when ratelimiting metrics are not existing", function() + local current_timestamp = 1424217600 + local periods = timestamp.get_timestamps(current_timestamp) + -- Very first select should return nil + for period, period_date in pairs(periods) do + local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) + assert.falsy(err) + assert.are.same(nil, metric) + end + end) + + it("should increment ratelimiting metrics with the given period", function() + local current_timestamp = 1424217600 + local periods = timestamp.get_timestamps(current_timestamp) + + -- First increment + local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp) + assert.falsy(err) + assert.True(ok) + + -- First select + for period, period_date in pairs(periods) do + local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) + assert.falsy(err) + assert.are.same({ + api_id = api_id, + identifier = identifier, + period = period, + period_date = period_date, + value = 1 -- The important part + }, metric) + end + + -- Second increment + local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp) + assert.falsy(err) + assert.True(ok) + + -- Second select + for period, period_date in pairs(periods) do + local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) + assert.falsy(err) + assert.are.same({ + api_id = api_id, + identifier = identifier, + period = period, + period_date = period_date, + value = 2 -- The important part + }, metric) + end + + -- 1 second delay + current_timestamp = 1424217601 + periods = timestamp.get_timestamps(current_timestamp) + + -- Third increment + local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp) + assert.falsy(err) + assert.True(ok) + + -- Third select with 1 second delay + for period, 
period_date in pairs(periods) do + + local expected_value = 3 + + if period == "second" then + expected_value = 1 + end + + local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) + assert.falsy(err) + assert.are.same({ + api_id = api_id, + identifier = identifier, + period = period, + period_date = period_date, + value = expected_value -- The important part + }, metric) + end + end) + + it("should throw errors for non supported methods of the base_dao", function() + assert.has_error(ratelimiting_metrics.find, "ratelimiting_metrics:find() not supported") + assert.has_error(ratelimiting_metrics.insert, "ratelimiting_metrics:insert() not supported") + assert.has_error(ratelimiting_metrics.update, "ratelimiting_metrics:update() not supported") + assert.has_error(ratelimiting_metrics.delete, "ratelimiting_metrics:delete() not yet implemented") + assert.has_error(ratelimiting_metrics.find_by_keys, "ratelimiting_metrics:find_by_keys() not supported") + end) + +end) -- describe rate limiting metrics diff --git a/spec/unit/dao/cassandra/base_dao_spec.lua b/spec/unit/dao/cassandra/base_dao_spec.lua new file mode 100644 index 00000000000..0052a069d1a --- /dev/null +++ b/spec/unit/dao/cassandra/base_dao_spec.lua @@ -0,0 +1,720 @@ +local spec_helper = require "spec.spec_helpers" +local cassandra = require "cassandra" +local constants = require "kong.constants" +local DaoError = require "kong.dao.error" +local utils = require "kong.tools.utils" +local cjson = require "cjson" +local uuid = require "uuid" + +-- Raw session for double-check purposes +local session +-- Load everything we need from the spec_helper +local env = spec_helper.get_env() -- test environment +local faker = env.faker +local dao_factory = env.dao_factory +local configuration = env.configuration +configuration.cassandra = configuration.databases_available[configuration.database].properties + +-- An utility function to apply tests on core collections. 
+local function describe_core_collections(tests_cb) + for type, dao in pairs({ api = dao_factory.apis, + consumer = dao_factory.consumers }) do + local collection = type == "plugin_configuration" and "plugins_configurations" or type.."s" + describe(collection, function() + tests_cb(type, collection) + end) + end +end + +-- An utility function to test if an object is a DaoError. +-- Naming is due to luassert extensibility's restrictions +local function daoError(state, arguments) + local stub_err = DaoError("", "") + return getmetatable(stub_err) == getmetatable(arguments[1]) +end + +local say = require("say") +say:set("assertion.daoError.positive", "Expected %s\nto be a DaoError") +say:set("assertion.daoError.negative", "Expected %s\nto not be a DaoError") +assert:register("assertion", "daoError", daoError, "assertion.daoError.positive", "assertion.daoError.negative") + +-- Let's go +describe("Cassandra", function() + + setup(function() + spec_helper.prepare_db() + + -- Create a parallel session to verify the dao's behaviour + session = cassandra.new() + session:set_timeout(configuration.cassandra.timeout) + + local _, err = session:connect(configuration.cassandra.hosts, configuration.cassandra.port) + assert.falsy(err) + + local _, err = session:set_keyspace("kong_tests") + assert.falsy(err) + end) + + teardown(function() + if session then + local _, err = session:close() + assert.falsy(err) + end + end) + + describe("Base DAO", function() + describe(":insert()", function() + + it("should error if called with invalid parameters", function() + assert.has_error(function() + dao_factory.apis:insert() + end, "Cannot insert a nil element") + + assert.has_error(function() + dao_factory.apis:insert("") + end, "Entity to insert must be a table") + end) + + it("should insert in DB and let the schema validation add generated values", function() + -- API + local api_t = faker:fake_entity("api") + local api, err = dao_factory.apis:insert(api_t) + assert.falsy(err) + 
assert.truthy(api.id) + assert.truthy(api.created_at) + local apis, err = session:execute("SELECT * FROM apis") + assert.falsy(err) + assert.True(#apis > 0) + assert.equal(api.id, apis[1].id) + + -- API + api, err = dao_factory.apis:insert { + public_dns = "test.com", + target_url = "http://mockbin.com" + } + assert.falsy(err) + assert.truthy(api.name) + assert.equal("test.com", api.name) + + -- Consumer + local consumer_t = faker:fake_entity("consumer") + local consumer, err = dao_factory.consumers:insert(consumer_t) + assert.falsy(err) + assert.truthy(consumer.id) + assert.truthy(consumer.created_at) + local consumers, err = session:execute("SELECT * FROM consumers") + assert.falsy(err) + assert.True(#consumers > 0) + assert.equal(consumer.id, consumers[1].id) + + -- Plugin configuration + local plugin_t = { name = "keyauth", api_id = api.id, consumer_id = consumer.id } + local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) + assert.falsy(err) + assert.truthy(plugin) + assert.truthy(plugin.consumer_id) + local plugins, err = session:execute("SELECT * FROM plugins_configurations") + assert.falsy(err) + assert.True(#plugins > 0) + assert.equal(plugin.id, plugins[1].id) + end) + + it("should let the schema validation return errors and not insert", function() + -- Without an api_id, it's a schema error + local plugin_t = faker:fake_entity("plugin_configuration") + local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) + assert.falsy(plugin) + assert.truthy(err) + assert.is_daoError(err) + assert.True(err.schema) + assert.are.same("api_id is required", err.message.api_id) + end) + + it("should ensure fields with `unique` are unique", function() + local api_t = faker:fake_entity("api") + + -- Success + local _, err = dao_factory.apis:insert(api_t) + assert.falsy(err) + + -- Failure + local api, err = dao_factory.apis:insert(api_t) + assert.truthy(err) + assert.is_daoError(err) + assert.True(err.unique) + assert.are.same("name already 
exists with value '"..api_t.name.."'", err.message.name) + assert.falsy(api) + end) + + it("should ensure fields with `foreign` are existing", function() + -- Plugin configuration + local plugin_t = faker:fake_entity("plugin_configuration") + plugin_t.api_id = uuid() + plugin_t.consumer_id = uuid() + + local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) + assert.falsy(plugin) + assert.truthy(err) + assert.is_daoError(err) + assert.True(err.foreign) + assert.are.same("api_id "..plugin_t.api_id.." does not exist", err.message.api_id) + assert.are.same("consumer_id "..plugin_t.consumer_id.." does not exist", err.message.consumer_id) + end) + + it("should do insert checks for entities with `on_insert`", function() + local api, err = dao_factory.apis:insert(faker:fake_entity("api")) + assert.falsy(err) + assert.truthy(api.id) + + local consumer, err = dao_factory.consumers:insert(faker:fake_entity("consumer")) + assert.falsy(err) + assert.truthy(consumer.id) + + local plugin_t = faker:fake_entity("plugin_configuration") + plugin_t.api_id = api.id + plugin_t.consumer_id = consumer.id + + -- Success: plugin doesn't exist yet + local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) + assert.falsy(err) + assert.truthy(plugin) + + -- Failure: the same plugin is already inserted + local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) + assert.falsy(plugin) + assert.truthy(err) + assert.is_daoError(err) + assert.True(err.unique) + assert.are.same("Plugin configuration already exists", err.message) + end) + + end) -- describe :insert() + + describe(":update()", function() + + it("should error if called with invalid parameters", function() + assert.has_error(function() + dao_factory.apis:update() + end, "Cannot update a nil element") + + assert.has_error(function() + dao_factory.apis:update("") + end, "Entity to update must be a table") + end) + + it("should return nil and no error if no entity was found to update in DB", 
function() + local api_t = faker:fake_entity("api") + api_t.id = uuid() + + -- No entity to update + local entity, err = dao_factory.apis:update(api_t) + assert.falsy(entity) + assert.falsy(err) + end) + + it("should consider no entity to be found if an empty table is given to it", function() + local api, err = dao_factory.apis:update({}) + assert.falsy(err) + assert.falsy(api) + end) + + it("should update specified, non-primary fields in DB", function() + -- API + local apis, err = session:execute("SELECT * FROM apis") + assert.falsy(err) + assert.True(#apis > 0) + + local api_t = apis[1] + api_t.name = api_t.name.." updated" + + local api, err = dao_factory.apis:update(api_t) + assert.falsy(err) + assert.truthy(api) + + apis, err = session:execute("SELECT * FROM apis WHERE name = ?", {api_t.name}) + assert.falsy(err) + assert.equal(1, #apis) + assert.equal(api_t.id, apis[1].id) + assert.equal(api_t.name, apis[1].name) + assert.equal(api_t.public_dns, apis[1].public_dns) + assert.equal(api_t.target_url, apis[1].target_url) + + -- Consumer + local consumers, err = session:execute("SELECT * FROM consumers") + assert.falsy(err) + assert.True(#consumers > 0) + + local consumer_t = consumers[1] + consumer_t.custom_id = consumer_t.custom_id.."updated" + + local consumer, err = dao_factory.consumers:update(consumer_t) + assert.falsy(err) + assert.truthy(consumer) + + consumers, err = session:execute("SELECT * FROM consumers WHERE custom_id = ?", {consumer_t.custom_id}) + assert.falsy(err) + assert.equal(1, #consumers) + assert.equal(consumer_t.name, consumers[1].name) + + -- Plugin Configuration + local plugins, err = session:execute("SELECT * FROM plugins_configurations") + assert.falsy(err) + assert.True(#plugins > 0) + + local plugin_t = plugins[1] + plugin_t.value = cjson.decode(plugin_t.value) + plugin_t.enabled = false + local plugin, err = dao_factory.plugins_configurations:update(plugin_t) + assert.falsy(err) + assert.truthy(plugin) + + plugins, err = 
session:execute("SELECT * FROM plugins_configurations WHERE id = ?", {cassandra.uuid(plugin_t.id)}) + assert.falsy(err) + assert.equal(1, #plugins) + end) + + it("should ensure fields with `unique` are unique", function() + local apis, err = session:execute("SELECT * FROM apis") + assert.falsy(err) + assert.True(#apis > 0) + + local api_t = apis[1] + -- Should not work because we're reusing a public_dns + api_t.public_dns = apis[2].public_dns + + local api, err = dao_factory.apis:update(api_t) + assert.truthy(err) + assert.falsy(api) + assert.is_daoError(err) + assert.True(err.unique) + assert.equal("public_dns already exists with value '"..api_t.public_dns.."'", err.message.public_dns) + end) + + describe("full", function() + + it("should set to NULL if a field is not specified", function() + local api_t = faker:fake_entity("api") + api_t.path = "/path" + + local api, err = dao_factory.apis:insert(api_t) + assert.falsy(err) + assert.truthy(api_t.path) + + -- Update + api.path = nil + api, err = dao_factory.apis:update(api, true) + assert.falsy(err) + assert.truthy(api) + assert.falsy(api.path) + + -- Check update + api, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api.id)}) + assert.falsy(err) + assert.falsy(api.path) + end) + + it("should still check the validity of the schema", function() + local api_t = faker:fake_entity("api") + + local api, err = dao_factory.apis:insert(api_t) + assert.falsy(err) + assert.truthy(api_t) + + -- Update + api.public_dns = nil + + local nil_api, err = dao_factory.apis:update(api, true) + assert.truthy(err) + assert.falsy(nil_api) + + -- Check update failed + api, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api.id)}) + assert.falsy(err) + assert.truthy(api[1].name) + assert.truthy(api[1].public_dns) + end) + + end) + end) -- describe :update() + + describe(":find_by_keys()", function() + describe_core_collections(function(type, collection) + + it("should error if called with 
invalid parameters", function() + assert.has_error(function() + dao_factory[collection]:find_by_keys("") + end, "where_t must be a table") + end) + + it("should handle empty search fields", function() + local results, err = dao_factory[collection]:find_by_keys({}) + assert.falsy(err) + assert.truthy(results) + assert.True(#results > 0) + end) + + it("should handle nil search fields", function() + local results, err = dao_factory[collection]:find_by_keys(nil) + assert.falsy(err) + assert.truthy(results) + assert.True(#results > 0) + end) + end) + + it("should query an entity from the given fields and return if filtering was needed", function() + -- Filtering needed + local apis, err = session:execute("SELECT * FROM apis") + assert.falsy(err) + assert.True(#apis > 0) + + local api_t = apis[1] + local apis, err, needs_filtering = dao_factory.apis:find_by_keys(api_t) + assert.falsy(err) + assert.same(api_t, apis[1]) + assert.True(needs_filtering) + + -- No Filtering needed + apis, err, needs_filtering = dao_factory.apis:find_by_keys {public_dns = api_t.public_dns} + assert.falsy(err) + assert.same(api_t, apis[1]) + assert.False(needs_filtering) + end) + + end) -- describe :find_by_keys() + + describe(":find()", function() + + setup(function() + spec_helper.drop_db() + spec_helper.seed_db(10) + end) + + describe_core_collections(function(type, collection) + + it("should find entities", function() + local entities, err = session:execute("SELECT * FROM "..collection) + assert.falsy(err) + assert.truthy(entities) + assert.True(#entities > 0) + + local results, err = dao_factory[collection]:find() + assert.falsy(err) + assert.truthy(results) + assert.same(#entities, #results) + end) + + it("should allow pagination", function() + -- 1st page + local rows_1, err = dao_factory[collection]:find(2) + assert.falsy(err) + assert.truthy(rows_1) + assert.same(2, #rows_1) + assert.truthy(rows_1.next_page) + + -- 2nd page + local rows_2, err = dao_factory[collection]:find(2, 
rows_1.next_page) + assert.falsy(err) + assert.truthy(rows_2) + assert.same(2, #rows_2) + end) + + end) + end) -- describe :find() + + describe(":find_by_primary_key()", function() + describe_core_collections(function(type, collection) + + it("should error if called with invalid parameters", function() + assert.has_error(function() + dao_factory[collection]:find_by_primary_key("") + end, "where_t must be a table") + end) + + it("should return nil (not found) if where_t is empty", function() + local res, err = dao_factory[collection]:find_by_primary_key({}) + assert.falsy(err) + assert.falsy(res) + end) + + end) + + it("should find one entity by its primary key", function() + local apis, err = session:execute("SELECT * FROM apis") + assert.falsy(err) + assert.True(#apis > 0) + + local api, err = dao_factory.apis:find_by_primary_key { id = apis[1].id } + assert.falsy(err) + assert.truthy(apis) + assert.same(apis[1], api) + end) + + it("should handle an invalid uuid value", function() + local apis, err = dao_factory.apis:find_by_primary_key { id = "abcd" } + assert.falsy(apis) + assert.True(err.invalid_type) + assert.equal("abcd is an invalid uuid", err.message.id) + end) + + describe("plugin_configurations", function() + + setup(function() + local fixtures = spec_helper.seed_db(1) + faker:insert_from_table { + plugin_configuration = { + { name = "keyauth", value = {key_names = {"apikey"}}, api_id = fixtures.api[1].id } + } + } + end) + + it("should unmarshall the `value` field", function() + local plugins, err = session:execute("SELECT * FROM plugins_configurations") + assert.falsy(err) + assert.truthy(plugins) + assert.True(#plugins> 0) + + local plugin_t = plugins[1] + + local plugin, err = dao_factory.plugins_configurations:find_by_primary_key { + id = plugin_t.id, + name = plugin_t.name + } + assert.falsy(err) + assert.truthy(plugin) + assert.equal("table", type(plugin.value)) + end) + + end) + end) -- describe :find_by_primary_key() + + describe(":delete()", 
function() + + describe_core_collections(function(type, collection) + + it("should error if called with invalid parameters", function() + assert.has_error(function() + dao_factory[collection]:delete("") + end, "where_t must be a table") + end) + + it("should return false if entity to delete wasn't found", function() + local ok, err = dao_factory[collection]:delete({id = uuid()}) + assert.falsy(err) + assert.False(ok) + end) + + it("should delete an entity based on its primary key", function() + local entities, err = session:execute("SELECT * FROM "..collection) + assert.falsy(err) + assert.truthy(entities) + assert.True(#entities > 0) + + local ok, err = dao_factory[collection]:delete(entities[1]) + assert.falsy(err) + assert.True(ok) + + local entities, err = session:execute("SELECT * FROM "..collection.." WHERE id = ?", {cassandra.uuid(entities[1].id)}) + assert.falsy(err) + assert.truthy(entities) + assert.are.same(0, #entities) + end) + + end) + + describe("APIs", function() + local api, untouched_api + + setup(function() + spec_helper.drop_db() + local fixtures = spec_helper.insert_fixtures { + api = { + { name = "cascade delete", + public_dns = "mockbin.com", + target_url = "http://mockbin.com" }, + { name = "untouched cascade delete", + public_dns = "untouched.com", + target_url = "http://mockbin.com" } + }, + plugin_configuration = { + {name = "keyauth", __api = 1}, + {name = "ratelimiting", value = {period = "minute", limit = 6}, __api = 1}, + {name = "filelog", value = {path = "/tmp/spec.log" }, __api = 1}, + + {name = "keyauth", __api = 2} + } + } + api = fixtures.api[1] + untouched_api = fixtures.api[2] + end) + + teardown(function() + spec_helper.drop_db() + end) + + it("should delete all related plugins_configurations when deleting an API", function() + local ok, err = dao_factory.apis:delete(api) + assert.falsy(err) + assert.True(ok) + + -- Make sure we have 0 matches + local results, err = dao_factory.plugins_configurations:find_by_keys { + api_id = 
api.id + } + assert.falsy(err) + assert.equal(0, #results) + + -- Make sure the untouched API still has its plugins + results, err = dao_factory.plugins_configurations:find_by_keys { + api_id = untouched_api.id + } + assert.falsy(err) + assert.equal(1, #results) + end) + + end) + + describe("Consumers", function() + local consumer, untouched_consumer + + setup(function() + spec_helper.drop_db() + local fixtures = spec_helper.insert_fixtures { + api = { + { name = "cascade delete", + public_dns = "mockbin.com", + target_url = "http://mockbin.com" } + }, + consumer = { + {username = "king kong"}, + {username = "untouched consumer"} + }, + plugin_configuration = { + {name = "keyauth", __api = 1, __consumer = 1}, + {name = "ratelimiting", value = {period = "minute", limit = 6}, __api = 1, __consumer = 1}, + {name = "filelog", value = {path = "/tmp/spec.log" }, __api = 1, __consumer = 1}, + + {name = "keyauth", __api = 1, __consumer = 2} + } + } + consumer = fixtures.consumer[1] + untouched_consumer = fixtures.consumer[2] + end) + + teardown(function() + spec_helper.drop_db() + end) + + it("should delete all related plugins_configurations when deleting a Consumer", function() + local ok, err = dao_factory.consumers:delete(consumer) + assert.True(ok) + assert.falsy(err) + + local results, err = dao_factory.plugins_configurations:find_by_keys { + consumer_id = consumer.id + } + assert.falsy(err) + assert.are.same(0, #results) + + -- Make sure the untouched Consumer still has its plugin + results, err = dao_factory.plugins_configurations:find_by_keys { + consumer_id = untouched_consumer.id + } + assert.falsy(err) + assert.are.same(1, #results) + end) + + end) + end) -- describe :delete() + + -- + -- APIs additional behaviour + -- + + describe("APIs", function() + + setup(function() + spec_helper.seed_db(100) + end) + + describe(":find_all()", function() + local apis, err = dao_factory.apis:find_all() + assert.falsy(err) + assert.truthy(apis) + assert.equal(100, #apis) + 
end) + end) + + -- + -- Plugins configuration additional behaviour + -- + + describe("plugin_configurations", function() + describe(":find_distinct()", function() + it("should find distinct plugins configurations", function() + faker:insert_from_table { + api = { + { name = "tests distinct 1", public_dns = "foo.com", target_url = "http://mockbin.com" }, + { name = "tests distinct 2", public_dns = "bar.com", target_url = "http://mockbin.com" } + }, + plugin_configuration = { + { name = "keyauth", value = {key_names = {"apikey"}, hide_credentials = true}, __api = 1 }, + { name = "ratelimiting", value = {period = "minute", limit = 6}, __api = 1 }, + { name = "ratelimiting", value = {period = "minute", limit = 6}, __api = 2 }, + { name = "filelog", value = { path = "/tmp/spec.log" }, __api = 1 } + } + } + + local res, err = dao_factory.plugins_configurations:find_distinct() + + assert.falsy(err) + assert.truthy(res) + + assert.are.same(3, #res) + assert.truthy(utils.table_contains(res, "keyauth")) + assert.truthy(utils.table_contains(res, "ratelimiting")) + assert.truthy(utils.table_contains(res, "filelog")) + end) + end) + + describe(":insert()", function() + local api_id + local inserted_plugin + it("should insert a plugin and set the consumer_id to a 'null' uuid if none is specified", function() + -- Since we want to specifically select plugins configurations which have _no_ consumer_id sometimes, we cannot rely on using + -- NULL (and thus, not inserting the consumer_id column for the row). To fix this, we use a predefined, nullified uuid... 
+ + -- Create an API + local api_t = faker:fake_entity("api") + local api, err = dao_factory.apis:insert(api_t) + assert.falsy(err) + + local plugin_t = faker:fake_entity("plugin_configuration") + plugin_t.api_id = api.id + + local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) + assert.falsy(err) + assert.truthy(plugin) + assert.falsy(plugin.consumer_id) + + -- for next test + api_id = api.id + inserted_plugin = plugin + inserted_plugin.consumer_id = nil + end) + + it("should select a plugin configuration by 'null' uuid consumer_id and remove the column", function() + -- Now we should be able to select this plugin + local rows, err = dao_factory.plugins_configurations:find_by_keys { + api_id = api_id, + consumer_id = constants.DATABASE_NULL_ID + } + assert.falsy(err) + assert.truthy(rows[1]) + assert.are.same(inserted_plugin, rows[1]) + assert.falsy(rows[1].consumer_id) + end) + end) + + end) -- describe plugins configurations + end) -- describe Base DAO +end) -- describe Cassandra diff --git a/spec/unit/dao/cassandra/factory_spec.lua b/spec/unit/dao/cassandra/factory_spec.lua new file mode 100644 index 00000000000..f2028342a28 --- /dev/null +++ b/spec/unit/dao/cassandra/factory_spec.lua @@ -0,0 +1,23 @@ +local CassandraFactory = require "kong.dao.cassandra.factory" +local spec_helper = require "spec.spec_helpers" + +local env = spec_helper.get_env() -- test environment +local configuration = env.configuration +configuration.cassandra = configuration.databases_available[configuration.database].properties + +describe(":prepare()", function() + + it("should return an error if cannot connect to Cassandra", function() + local new_factory = CassandraFactory({ hosts = "127.0.0.1", + port = 45678, + timeout = 1000, + keyspace = configuration.cassandra.keyspace + }) + + local err = new_factory:prepare() + assert.truthy(err) + assert.True(err.database) + assert.are.same("Cassandra error: connection refused", err.message) + end) + +end) diff --git 
a/spec/unit/dao/cassandra/query_builder_spec.lua b/spec/unit/dao/cassandra/query_builder_spec.lua new file mode 100644 index 00000000000..5d1a3896dbf --- /dev/null +++ b/spec/unit/dao/cassandra/query_builder_spec.lua @@ -0,0 +1,194 @@ +local builder = require "kong.dao.cassandra.query_builder" + +describe("Query Builder", function() + + local apis_details = { + primary_key = {"id"}, + clustering_key = {"cluster_key"}, + indexes = {public_dns = true, name = true} + } + + describe("SELECT", function() + + it("should build a SELECT query", function() + local q = builder.select("apis") + assert.equal("SELECT * FROM apis", q) + end) + + it("should restrict columns to SELECT", function() + local q = builder.select("apis", nil, nil, {"name", "id"}) + assert.equal("SELECT name, id FROM apis", q) + end) + + it("should return the columns of the arguments to bind", function() + local _, columns = builder.select("apis", {name="mockbin", public_dns="mockbin.com"}) + assert.same({"name", "public_dns"}, columns) + end) + + describe("WHERE", function() + it("should not allow filtering if all the queried fields are indexed", function() + local q, _, needs_filtering = builder.select("apis", {name="mockbin"}, apis_details) + assert.equal("SELECT * FROM apis WHERE name = ?", q) + assert.False(needs_filtering) + end) + + it("should not allow filtering if all the queried fields are primary keys", function() + local q, _, needs_filtering = builder.select("apis", {id="1"}, apis_details) + assert.equal("SELECT * FROM apis WHERE id = ?", q) + assert.False(needs_filtering) + end) + + it("should not allow filtering if all the queried fields are primary keys or indexed", function() + local q, _, needs_filtering = builder.select("apis", {id="1", name="mockbin"}, apis_details) + assert.equal("SELECT * FROM apis WHERE name = ? 
AND id = ?", q) + assert.False(needs_filtering) + end) + + it("should not allow filtering if all the queried fields are primary keys or indexed", function() + local q = builder.select("apis", {id="1", name="mockbin", cluster_key="foo"}, apis_details) + assert.equal("SELECT * FROM apis WHERE cluster_key = ? AND name = ? AND id = ?", q) + end) + + it("should enable filtering when more than one indexed field is being queried", function() + local q, _, needs_filtering = builder.select("apis", {name="mockbin", public_dns="mockbin.com"}, apis_details) + assert.equal("SELECT * FROM apis WHERE name = ? AND public_dns = ? ALLOW FILTERING", q) + assert.True(needs_filtering) + end) + end) + + it("should throw an error if no column_family", function() + assert.has_error(function() + builder.select() + end, "column_family must be a string") + end) + + it("should throw an error if select_columns is not a table", function() + assert.has_error(function() + builder.select("apis", {name="mockbin"}, nil, "") + end, "select_columns must be a table") + end) + + it("should throw an error if primary_key is not a table", function() + assert.has_error(function() + builder.select("apis", {name="mockbin"}, {primary_key = ""}) + end, "primary_key must be a table") + end) + + it("should throw an error if indexes is not a table", function() + assert.has_error(function() + builder.select("apis", {name="mockbin"}, {indexes = ""}) + end, "indexes must be a table") + end) + + it("should throw an error if where_key is not a table", function() + assert.has_error(function() + builder.select("apis", "") + end, "where_t must be a table") + end) + + end) + + describe("INSERT", function() + + it("should build an INSERT query", function() + local q = builder.insert("apis", {id="123", name="mockbin"}) + assert.equal("INSERT INTO apis(name, id) VALUES(?, ?)", q) + end) + + it("should return the columns of the arguments to bind", function() + local _, columns = builder.insert("apis", {id="123", 
name="mockbin"}) + assert.same({"name", "id"}, columns) + end) + + it("should throw an error if no column_family", function() + assert.has_error(function() + builder.insert(nil, {"id", "name"}) + end, "column_family must be a string") + end) + + it("should throw an error if no insert_values", function() + assert.has_error(function() + builder.insert("apis") + end, "insert_values must be a table") + end) + + end) + + describe("UPDATE", function() + + it("should build an UPDATE query", function() + local q = builder.update("apis", {name="mockbin"}, {id="1"}, apis_details) + assert.equal("UPDATE apis SET name = ? WHERE id = ?", q) + end) + + it("should return the columns of the arguments to bind", function() + local _, columns = builder.update("apis", {public_dns="1234", name="mockbin"}, {id="1"}, apis_details) + assert.same({"public_dns", "name", "id"}, columns) + end) + + it("should throw an error if no column_family", function() + assert.has_error(function() + builder.update() + end, "column_family must be a string") + end) + + it("should throw an error if no update_values", function() + assert.has_error(function() + builder.update("apis") + end, "update_values must be a table") + + assert.has_error(function() + builder.update("apis", {}) + end, "update_values cannot be empty") + end) + + it("should throw an error if no where_t", function() + assert.has_error(function() + builder.update("apis", {name="foo"}, {}) + end, "where_t must contain keys") + end) + + end) + + describe("DELETE", function() + + it("should build a DELETE query", function() + local q = builder.delete("apis", {id="1234"}) + assert.equal("DELETE FROM apis WHERE id = ?", q) + end) + + it("should return the columns of the arguments to bind", function() + local _, columns = builder.delete("apis", {id="1234"}) + assert.same({"id"}, columns) + end) + + it("should throw an error if no column_family", function() + assert.has_error(function() + builder.delete() + end, "column_family must be a string") + 
end) + + it("should throw an error if no where_t", function() + assert.has_error(function() + builder.delete("apis", {}) + end, "where_t must contain keys") + end) + + end) + + describe("TRUNCATE", function() + + it("should build a TRUNCATE query", function() + local q = builder.truncate("apis") + assert.equal("TRUNCATE apis", q) + end) + + it("should throw an error if no column_family", function() + assert.has_error(function() + builder.truncate() + end, "column_family must be a string") + end) + + end) +end) + diff --git a/spec/unit/dao/cassandra_spec.lua b/spec/unit/dao/cassandra_spec.lua deleted file mode 100644 index 3017e432d86..00000000000 --- a/spec/unit/dao/cassandra_spec.lua +++ /dev/null @@ -1,1060 +0,0 @@ -local CassandraFactory = require "kong.dao.cassandra.factory" -local spec_helper = require "spec.spec_helpers" -local timestamp = require "kong.tools.timestamp" -local cassandra = require "cassandra" -local constants = require "kong.constants" -local DaoError = require "kong.dao.error" -local utils = require "kong.tools.utils" -local cjson = require "cjson" -local uuid = require "uuid" - --- Raw session for double-check purposes -local session --- Load everything we need from the spec_helper -local env = spec_helper.get_env() -- test environment -local faker = env.faker -local dao_factory = env.dao_factory -local configuration = env.configuration -configuration.cassandra = configuration.databases_available[configuration.database].properties - --- An utility function to apply tests on core collections. -local function describe_core_collections(tests_cb) - for type, dao in pairs({ api = dao_factory.apis, - consumer = dao_factory.consumers }) do - local collection = type == "plugin_configuration" and "plugins_configurations" or type.."s" - describe(collection, function() - tests_cb(type, collection) - end) - end -end - --- An utility function to test if an object is a DaoError. 
--- Naming is due to luassert extensibility's restrictions -local function daoError(state, arguments) - local stub_err = DaoError("", "") - return getmetatable(stub_err) == getmetatable(arguments[1]) -end - -local say = require("say") -say:set("assertion.daoError.positive", "Expected %s\nto be a DaoError") -say:set("assertion.daoError.negative", "Expected %s\nto not be a DaoError") -assert:register("assertion", "daoError", daoError, "assertion.daoError.positive", "assertion.daoError.negative") - --- Let's go -describe("Cassandra DAO", function() - - setup(function() - spec_helper.prepare_db() - - -- Create a parallel session to verify the dao's behaviour - session = cassandra.new() - session:set_timeout(configuration.cassandra.timeout) - - local _, err = session:connect(configuration.cassandra.hosts, configuration.cassandra.port) - assert.falsy(err) - - local _, err = session:set_keyspace(configuration.cassandra.keyspace) - assert.falsy(err) - end) - - teardown(function() - if session then - local _, err = session:close() - assert.falsy(err) - end - end) - - describe("Collections schemas", function() - - describe_core_collections(function(type, collection) - - it("should have statements for all unique and foreign schema fields", function() - for column, schema_field in pairs(dao_factory[collection]._schema) do - if schema_field.unique then - assert.truthy(dao_factory[collection]._queries.__unique[column]) - end - if schema_field.foreign then - assert.truthy(dao_factory[collection]._queries.__foreign[column]) - end - end - end) - - end) - end) - - describe("Factory", function() - - describe(":prepare()", function() - - it("should prepare all queries in collection's _queries", function() - local new_factory = CassandraFactory({ hosts = "127.0.0.1", - port = 9042, - timeout = 1000, - keyspace = configuration.cassandra.keyspace - }) - - local err = new_factory:prepare() - assert.falsy(err) - - -- assert collections have prepared statements - for _, collection in 
ipairs({ "apis", "consumers" }) do - for k, v in pairs(new_factory[collection]._queries) do - local cache_key - if type(v) == "string" then - cache_key = v - elseif v.query then - cache_key = v.query - end - - if cache_key then - assert.truthy(new_factory[collection]._statements_cache[cache_key]) - end - end - end - end) - - it("should raise an error if cannot connect to Cassandra", function() - local new_factory = CassandraFactory({ hosts = "127.0.0.1", - port = 45678, - timeout = 1000, - keyspace = configuration.cassandra.keyspace - }) - - local err = new_factory:prepare() - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.database) - assert.are.same("Cassandra error: connection refused", err.message) - end) - - end) - end) -- describe Factory - - -- - -- Core DAO Collections (consumers, apis, plugins_configurations) - -- - - describe("DAO Collections", function() - - describe(":insert()", function() - - describe("APIs", function() - - it("should insert in DB and add generated values", function() - local api_t = faker:fake_entity("api") - local api, err = dao_factory.apis:insert(api_t) - assert.falsy(err) - assert.truthy(api.id) - assert.truthy(api.created_at) - end) - - it("should use the public_dns as the name if none is specified", function() - local api, err = dao_factory.apis:insert { - public_dns = "test.com", - target_url = "http://mockbin.com" - } - assert.falsy(err) - assert.truthy(api.name) - assert.are.same("test.com", api.name) - end) - - it("should not insert an invalid api", function() - -- Nil - local api, err = dao_factory.apis:insert() - assert.falsy(api) - assert.truthy(err) - assert.True(err.schema) - assert.are.same("Cannot insert a nil element", err.message) - - -- Invalid schema UNIQUE error (already existing API name) - local api_rows, err = session:execute("SELECT * FROM apis LIMIT 1;") - assert.falsy(err) - local api_t = faker:fake_entity("api") - api_t.name = api_rows[1].name - - local api, err = 
dao_factory.apis:insert(api_t) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.unique) - assert.are.same("name already exists with value '"..api_t.name.."'", err.message.name) - assert.falsy(api) - - -- Duplicated name - local apis, err = session:execute("SELECT * FROM apis") - assert.falsy(err) - assert.truthy(#apis > 0) - - local api_t = faker:fake_entity("api") - api_t.name = apis[1].name - local api, err = dao_factory.apis:insert(api_t) - assert.falsy(api) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.unique) - assert.are.same("name already exists with value '"..api_t.name.."'", err.message.name) - end) - - end) - - describe("Consumers", function() - - it("should insert an consumer in DB and add generated values", function() - local consumer_t = faker:fake_entity("consumer") - local consumer, err = dao_factory.consumers:insert(consumer_t) - assert.falsy(err) - assert.truthy(consumer.id) - assert.truthy(consumer.created_at) - end) - - end) - - describe("plugin_configurations", function() - - it("should not insert in DB if invalid", function() - -- Without an api_id, it's a schema error - local plugin_t = faker:fake_entity("plugin_configuration") - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(plugin) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.schema) - assert.are.same("api_id is required", err.message.api_id) - - -- With an invalid api_id, it's an FOREIGN error - local plugin_t = faker:fake_entity("plugin_configuration") - plugin_t.api_id = uuid() - - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(plugin) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.foreign) - assert.are.same("api_id "..plugin_t.api_id.." 
does not exist", err.message.api_id) - - -- With invalid api_id and consumer_id, it's an EXISTS error - local plugin_t = faker:fake_entity("plugin_configuration") - plugin_t.api_id = uuid() - plugin_t.consumer_id = uuid() - - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(plugin) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.foreign) - assert.are.same("api_id "..plugin_t.api_id.." does not exist", err.message.api_id) - assert.are.same("consumer_id "..plugin_t.consumer_id.." does not exist", err.message.consumer_id) - end) - - it("should insert a plugin configuration in DB and add generated values", function() - local api_t = faker:fake_entity("api") - local api, err = dao_factory.apis:insert(api_t) - assert.falsy(err) - - local consumers, err = session:execute("SELECT * FROM consumers") - assert.falsy(err) - assert.True(#consumers > 0) - - local plugin_t = faker:fake_entity("plugin_configuration") - plugin_t.api_id = api.id - plugin_t.consumer_id = consumers[1].id - - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(err) - assert.truthy(plugin) - assert.truthy(plugin.consumer_id) - end) - - it("should not insert twice a plugin with same api_id, consumer_id and name", function() - -- Insert a new API for a fresh start - local api, err = dao_factory.apis:insert(faker:fake_entity("api")) - assert.falsy(err) - assert.truthy(api.id) - - local consumers, err = session:execute("SELECT * FROM consumers") - assert.falsy(err) - assert.True(#consumers > 0) - - local plugin_t = faker:fake_entity("plugin_configuration") - plugin_t.api_id = api.id - plugin_t.consumer_id = consumers[#consumers].id - - -- This should work - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(err) - assert.truthy(plugin) - - -- This should fail - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(plugin) - assert.truthy(err) - 
assert.is_daoError(err) - assert.True(err.unique) - assert.are.same("Plugin configuration already exists", err.message) - end) - - it("should not insert a plugin if this plugin doesn't exist (not installed)", function() - local plugin_t = faker:fake_entity("plugin_configuration") - plugin_t.name = "world domination plugin" - - -- This should fail - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(plugin) - assert.truthy(err) - assert.is_daoError(err) - assert.are.same("Plugin \"world domination plugin\" not found", err.message.value) - end) - - it("should validate a plugin value schema", function() - -- Success - -- Insert a new API for a fresh start - local api, err = dao_factory.apis:insert(faker:fake_entity("api")) - assert.falsy(err) - assert.truthy(api.id) - - local consumers, err = session:execute("SELECT * FROM consumers") - assert.falsy(err) - assert.True(#consumers > 0) - - local plugin_t = { - api_id = api.id, - consumer_id = consumers[#consumers].id, - name = "keyauth", - value = { - key_names = { "x-kong-key" } - } - } - - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(err) - assert.truthy(plugin) - - local ok, err = dao_factory.plugins_configurations:delete(plugin.id) - assert.True(ok) - assert.falsy(err) - - -- Failure - plugin_t.name = "ratelimiting" - plugin_t.value = { period = "hello" } - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.truthy(err) - assert.is_daoError(err) - assert.truthy(err.schema) - assert.are.same("\"hello\" is not allowed. 
Allowed values are: \"second\", \"minute\", \"hour\", \"day\", \"month\", \"year\"", err.message["value.period"]) - assert.falsy(plugin) - end) - - end) - end) -- describe :insert() - - describe(":update()", function() - - describe_core_collections(function(type, collection) - - it("should return nil if no entity was found to update in DB", function() - local t = faker:fake_entity(type) - t.id = uuid() - - -- Remove immutable fields - for k,v in pairs(dao_factory[collection]._schema) do - if v.immutable and not v.required then - t[k] = nil - end - end - - -- No entity to update - local entity, err = dao_factory[collection]:update(t) - assert.falsy(entity) - assert.falsy(err) - end) - - end) - - describe("APIs", function() - - -- Cassandra sets to NULL unset fields specified in an UPDATE query - -- https://issues.apache.org/jira/browse/CASSANDRA-7304 - it("should update in DB without setting to NULL unset fields", function() - local apis, err = session:execute("SELECT * FROM apis") - assert.falsy(err) - assert.True(#apis > 0) - - local api_t = apis[1] - api_t.name = api_t.name.." updated" - - -- This should not set those values to NULL in DB - api_t.created_at = nil - api_t.public_dns = nil - api_t.target_url = nil - - local api, err = dao_factory.apis:update(api_t) - assert.falsy(err) - assert.truthy(api) - - local apis, err = session:execute("SELECT * FROM apis WHERE name = '"..api_t.name.."'") - assert.falsy(err) - assert.are.same(1, #apis) - assert.truthy(apis[1].id) - assert.truthy(apis[1].created_at) - assert.truthy(apis[1].public_dns) - assert.truthy(apis[1].target_url) - assert.are.same(api_t.name, apis[1].name) - end) - - it("should prevent the update if the UNIQUE check fails", function() - local apis, err = session:execute("SELECT * FROM apis") - assert.falsy(err) - assert.True(#apis > 0) - - local api_t = apis[1] - api_t.name = api_t.name.." 
unique update attempt" - - -- Should not work because UNIQUE check fails - api_t.public_dns = apis[2].public_dns - - local api, err = dao_factory.apis:update(api_t) - assert.falsy(api) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.unique) - assert.are.same("public_dns already exists with value '"..api_t.public_dns.."'", err.message.public_dns) - end) - - end) - - describe("Consumers", function() - - it("should update in DB if entity can be found", function() - local consumers, err = session:execute("SELECT * FROM consumers") - assert.falsy(err) - assert.True(#consumers > 0) - - local consumer_t = consumers[1] - - -- Should be correctly updated in DB - consumer_t.custom_id = consumer_t.custom_id.."updated" - - local consumer, err = dao_factory.consumers:update(consumer_t) - assert.falsy(err) - assert.truthy(consumer) - - local consumers, err = session:execute("SELECT * FROM consumers WHERE custom_id = '"..consumer_t.custom_id.."'") - assert.falsy(err) - assert.True(#consumers == 1) - assert.are.same(consumer_t.name, consumers[1].name) - end) - - end) - - describe("plugin_configurations", function() - - setup(function() - local fixtures = spec_helper.seed_db(1) - faker:insert_from_table { - plugin_configuration = { - { name = "keyauth", value = {key_names = {"apikey"}}, api_id = fixtures.api[1].id } - } - } - end) - - it("should update in DB if entity can be found", function() - local plugins_configurations, err = session:execute("SELECT * FROM plugins_configurations") - assert.falsy(err) - assert.True(#plugins_configurations > 0) - - local plugin_conf_t = plugins_configurations[1] - plugin_conf_t.value = cjson.decode(plugin_conf_t.value) - plugin_conf_t.enabled = false - local plugin_conf, err = dao_factory.plugins_configurations:update(plugin_conf_t) - assert.falsy(err) - assert.truthy(plugin_conf) - - local plugins_configurations, err = session:execute("SELECT * FROM plugins_configurations WHERE id = ?", { cassandra.uuid(plugin_conf_t.id) }) - 
assert.falsy(err) - assert.are.same(1, #plugins_configurations) - end) - - end) - end) -- describe :update() - - describe(":delete()", function() - - describe_core_collections(function(type, collection) - - it("should return false if there was nothing to delete", function() - local ok, err = dao_factory[collection]:delete(uuid()) - assert.is_not_true(ok) - assert.falsy(err) - end) - - it("should delete an entity if it can be found", function() - local entities, err = session:execute("SELECT * FROM "..collection) - assert.falsy(err) - assert.truthy(entities) - assert.True(#entities > 0) - - local ok, err = dao_factory[collection]:delete(entities[1].id) - assert.falsy(err) - assert.True(ok) - - local entities, err = session:execute("SELECT * FROM "..collection.." WHERE id = "..entities[1].id ) - assert.falsy(err) - assert.truthy(entities) - assert.are.same(0, #entities) - end) - - end) - - describe("APIs", function() - local api, untouched_api - - setup(function() - spec_helper.drop_db() - - -- Insert an API - local _, err - api, err = dao_factory.apis:insert { - name = "cascade delete test", - public_dns = "cascade.com", - target_url = "http://mockbin.com" - } - assert.falsy(err) - - -- Insert some plugins_configurations - _, err = dao_factory.plugins_configurations:insert { - name = "keyauth", value = { key_names = {"apikey"} }, api_id = api.id - } - assert.falsy(err) - - _, err = dao_factory.plugins_configurations:insert { - name = "ratelimiting", value = { period = "minute", limit = 6 }, api_id = api.id - } - assert.falsy(err) - - _, err = dao_factory.plugins_configurations:insert { - name = "filelog", value = { path = "/tmp/spec.log" }, api_id = api.id - } - assert.falsy(err) - - -- Insert an unrelated API + plugin - untouched_api, err = dao_factory.apis:insert { - name = "untouched cascade test api", - public_dns = "untouched.com", - target_url = "http://mockbin.com" - } - assert.falsy(err) - - _, err = dao_factory.plugins_configurations:insert { - name = 
"filelog", value = { path = "/tmp/spec.log" }, api_id = untouched_api.id - } - assert.falsy(err) - - -- Make sure we have 3 matches - local results, err = dao_factory.plugins_configurations:find_by_keys { - api_id = api.id - } - assert.falsy(err) - assert.are.same(3, #results) - end) - - teardown(function() - spec_helper.drop_db() - end) - - it("should delete all related plugins_configurations when deleting an API", function() - local ok, err = dao_factory.apis:delete(api.id) - assert.falsy(err) - assert.True(ok) - - -- Make sure we have 0 matches - local results, err = dao_factory.plugins_configurations:find_by_keys { - api_id = api.id - } - assert.falsy(err) - assert.are.same(0, #results) - - -- Make sure the untouched API still has its plugin - local results, err = dao_factory.plugins_configurations:find_by_keys { - api_id = untouched_api.id - } - assert.falsy(err) - assert.are.same(1, #results) - end) - - end) - - describe("Consumers", function() - local api, consumer, untouched_consumer - - setup(function() - spec_helper.drop_db() - - local _, err - - -- Insert a Consumer - consumer, err = dao_factory.consumers:insert { username = "king kong" } - assert.falsy(err) - - -- Insert an API - api, err = dao_factory.apis:insert { - name = "cascade delete test", - public_dns = "cascade.com", - target_url = "http://mockbin.com" - } - assert.falsy(err) - - -- Insert some plugins_configurations - _, err = dao_factory.plugins_configurations:insert { - name="keyauth", value = { key_names = {"apikey"} }, api_id = api.id, - consumer_id = consumer.id - } - assert.falsy(err) - - _, err = dao_factory.plugins_configurations:insert { - name = "ratelimiting", value = { period = "minute", limit = 6 }, api_id = api.id, - consumer_id = consumer.id - } - assert.falsy(err) - - _, err = dao_factory.plugins_configurations:insert { - name = "filelog", value = { path = "/tmp/spec.log" }, api_id = api.id, - consumer_id = consumer.id - } - assert.falsy(err) - - -- Inser an untouched consumer 
+ plugin - untouched_consumer, err = dao_factory.consumers:insert { username = "untouched consumer" } - assert.falsy(err) - - _, err = dao_factory.plugins_configurations:insert { - name = "filelog", value = { path = "/tmp/spec.log" }, api_id = api.id, - consumer_id = untouched_consumer.id - } - assert.falsy(err) - - local results, err = dao_factory.plugins_configurations:find_by_keys { - consumer_id = consumer.id - } - assert.falsy(err) - assert.are.same(3, #results) - end) - - teardown(function() - spec_helper.drop_db() - end) - - it("should delete all related plugins_configurations when deleting an API", function() - local ok, err = dao_factory.consumers:delete(consumer.id) - assert.True(ok) - assert.falsy(err) - - local results, err = dao_factory.plugins_configurations:find_by_keys { - consumer_id = consumer.id - } - assert.falsy(err) - assert.are.same(0, #results) - - -- Make sure the untouched Consumer still has its plugin - local results, err = dao_factory.plugins_configurations:find_by_keys { - consumer_id = untouched_consumer.id - } - assert.falsy(err) - assert.are.same(1, #results) - end) - - end) - end) -- describe :delete() - - describe(":find()", function() - - setup(function() - spec_helper.drop_db() - spec_helper.seed_db(10) - end) - - describe_core_collections(function(type, collection) - - it("should find entities", function() - local entities, err = session:execute("SELECT * FROM "..collection) - assert.falsy(err) - assert.truthy(entities) - assert.True(#entities > 0) - - local results, err = dao_factory[collection]:find() - assert.falsy(err) - assert.truthy(results) - assert.are.same(#entities, #results) - end) - - it("should allow pagination", function() - -- 1st page - local rows_1, err = dao_factory[collection]:find(2) - assert.falsy(err) - assert.truthy(rows_1) - assert.are.same(2, #rows_1) - assert.truthy(rows_1.next_page) - - -- 2nd page - local rows_2, err = dao_factory[collection]:find(2, rows_1.next_page) - assert.falsy(err) - 
assert.truthy(rows_2) - assert.are.same(2, #rows_2) - end) - - end) - end) -- describe :find() - - describe(":find_one()", function() - - describe_core_collections(function(type, collection) - - it("should find one entity by id", function() - local entities, err = session:execute("SELECT * FROM "..collection) - assert.falsy(err) - assert.truthy(entities) - assert.True(#entities > 0) - - local result, err = dao_factory[collection]:find_one(entities[1].id) - assert.falsy(err) - assert.truthy(result) - end) - - it("should handle an invalid uuid value", function() - local result, err = dao_factory[collection]:find_one("abcd") - assert.falsy(result) - assert.True(err.invalid_type) - assert.are.same("abcd is an invalid uuid", err.message.id) - end) - - end) - - describe("plugin_configurations", function() - - setup(function() - local fixtures = spec_helper.seed_db(1) - faker:insert_from_table { - plugin_configuration = { - { name = "keyauth", value = {key_names = {"apikey"}}, api_id = fixtures.api[1].id } - } - } - end) - - it("should deserialize the table property", function() - local plugins_configurations, err = session:execute("SELECT * FROM plugins_configurations") - assert.falsy(err) - assert.truthy(plugins_configurations) - assert.True(#plugins_configurations > 0) - - local plugin_t = plugins_configurations[1] - - local result, err = dao_factory.plugins_configurations:find_one(plugin_t.id) - assert.falsy(err) - assert.truthy(result) - assert.are.same("table", type(result.value)) - end) - - end) - end) -- describe :find_one() - - describe(":find_by_keys()", function() - - describe_core_collections(function(type, collection) - - it("should refuse non queryable keys", function() - local results, err = session:execute("SELECT * FROM "..collection) - assert.falsy(err) - assert.truthy(results) - assert.True(#results > 0) - - local t = results[1] - - local results, err = dao_factory[collection]:find_by_keys(t) - assert.truthy(err) - assert.is_daoError(err) - 
assert.True(err.schema) - assert.falsy(results) - - -- All those fields are indeed non queryable - for k, v in pairs(err.message) do - assert.is_not_true(dao_factory[collection]._schema[k].queryable) - end - end) - - it("should handle empty search fields", function() - local results, err = dao_factory[collection]:find_by_keys({}) - assert.falsy(err) - assert.truthy(results) - assert.True(#results > 0) - end) - - it("should handle nil search fields", function() - local results, err = dao_factory[collection]:find_by_keys(nil) - assert.falsy(err) - assert.truthy(results) - assert.True(#results > 0) - end) - - it("should query an entity by its queryable fields", function() - local results, err = session:execute("SELECT * FROM "..collection) - assert.falsy(err) - assert.truthy(results) - assert.True(#results > 0) - - local t = results[1] - local q = {} - - -- Remove nonqueryable fields - for k, schema_field in pairs(dao_factory[collection]._schema) do - if schema_field.queryable then - q[k] = t[k] - elseif schema_field.type == "table" then - t[k] = cjson.decode(t[k]) - end - end - - local results, err = dao_factory[collection]:find_by_keys(q) - assert.falsy(err) - assert.truthy(results) - - -- in case of plugins configurations - if t.consumer_id == constants.DATABASE_NULL_ID then - t.consumer_id = nil - end - - assert.are.same(t, results[1]) - end) - - end) - end) -- describe :find_by_keys() - - -- - -- Plugins configuration additional behaviour - -- - - describe("plugin_configurations", function() - local api_id - local inserted_plugin - - it("should find distinct plugins configurations", function() - faker:insert_from_table { - api = { - { name = "tests distinct 1", public_dns = "foo.com", target_url = "http://mockbin.com" }, - { name = "tests distinct 2", public_dns = "bar.com", target_url = "http://mockbin.com" } - }, - plugin_configuration = { - { name = "keyauth", value = {key_names = {"apikey"}, hide_credentials = true}, __api = 1 }, - { name = "ratelimiting", 
value = {period = "minute", limit = 6}, __api = 1 }, - { name = "ratelimiting", value = {period = "minute", limit = 6}, __api = 2 }, - { name = "filelog", value = { path = "/tmp/spec.log" }, __api = 1 } - } - } - - local res, err = dao_factory.plugins_configurations:find_distinct() - - assert.falsy(err) - assert.truthy(res) - - assert.are.same(3, #res) - assert.truthy(utils.table_contains(res, "keyauth")) - assert.truthy(utils.table_contains(res, "ratelimiting")) - assert.truthy(utils.table_contains(res, "filelog")) - end) - - it("should insert a plugin and set the consumer_id to a 'null' uuid if none is specified", function() - -- Since we want to specifically select plugins configurations which have _no_ consumer_id sometimes, we cannot rely on using - -- NULL (and thus, not inserting the consumer_id column for the row). To fix this, we use a predefined, nullified uuid... - - -- Create an API - local api_t = faker:fake_entity("api") - local api, err = dao_factory.apis:insert(api_t) - assert.falsy(err) - - local plugin_t = faker:fake_entity("plugin_configuration") - plugin_t.api_id = api.id - - local plugin, err = dao_factory.plugins_configurations:insert(plugin_t) - assert.falsy(err) - assert.truthy(plugin) - assert.falsy(plugin.consumer_id) - - -- for next test - api_id = api.id - inserted_plugin = plugin - inserted_plugin.consumer_id = nil - end) - - it("should select a plugin configuration by 'null' uuid consumer_id and remove the column", function() - -- Now we should be able to select this plugin - local rows, err = dao_factory.plugins_configurations:find_by_keys { - api_id = api_id, - consumer_id = constants.DATABASE_NULL_ID - } - assert.falsy(err) - assert.truthy(rows[1]) - assert.are.same(inserted_plugin, rows[1]) - assert.falsy(rows[1].consumer_id) - end) - - end) -- describe plugins configurations - end) -- describe DAO Collections - - -- - -- Keyauth plugin collection - -- - - describe("Keyauth", function() - - it("should not insert in DB if consumer 
does not exist", function() - -- Without an consumer_id, it's a schema error - local app_t = { name = "keyauth", value = {key_names = {"apikey"}} } - local app, err = dao_factory.keyauth_credentials:insert(app_t) - assert.falsy(app) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.schema) - assert.are.same("consumer_id is required", err.message.consumer_id) - - -- With an invalid consumer_id, it's a FOREIGN error - local app_t = { key = "apikey123", consumer_id = uuid() } - local app, err = dao_factory.keyauth_credentials:insert(app_t) - assert.falsy(app) - assert.truthy(err) - assert.is_daoError(err) - assert.True(err.foreign) - assert.are.same("consumer_id "..app_t.consumer_id.." does not exist", err.message.consumer_id) - end) - - it("should insert in DB and add generated values", function() - local consumers, err = session:execute("SELECT * FROM consumers") - assert.falsy(err) - assert.truthy(#consumers > 0) - - local app_t = { key = "apikey123", consumer_id = consumers[1].id } - local app, err = dao_factory.keyauth_credentials:insert(app_t) - assert.falsy(err) - assert.truthy(app.id) - assert.truthy(app.created_at) - end) - - it("should find an KeyAuth Credential by public_key", function() - local app, err = dao_factory.keyauth_credentials:find_by_keys { - key = "user122" - } - assert.falsy(err) - assert.truthy(app) - end) - - it("should handle empty strings", function() - local apps, err = dao_factory.keyauth_credentials:find_by_keys { - key = "" - } - assert.falsy(err) - assert.are.same({}, apps) - end) - - end) - - -- - -- Rate Limiting plugin collection - -- - - describe("Rate Limiting Metrics", function() - local ratelimiting_metrics = dao_factory.ratelimiting_metrics - local api_id = uuid() - local identifier = uuid() - - after_each(function() - spec_helper.drop_db() - end) - - it("should return nil when ratelimiting metrics are not existing", function() - local current_timestamp = 1424217600 - local periods = 
timestamp.get_timestamps(current_timestamp) - -- Very first select should return nil - for period, period_date in pairs(periods) do - local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) - assert.falsy(err) - assert.are.same(nil, metric) - end - end) - - it("should increment ratelimiting metrics with the given period", function() - local current_timestamp = 1424217600 - local periods = timestamp.get_timestamps(current_timestamp) - - -- First increment - local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp) - assert.falsy(err) - assert.True(ok) - - -- First select - for period, period_date in pairs(periods) do - local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) - assert.falsy(err) - assert.are.same({ - api_id = api_id, - identifier = identifier, - period = period, - period_date = period_date, - value = 1 -- The important part - }, metric) - end - - -- Second increment - local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp) - assert.falsy(err) - assert.True(ok) - - -- Second select - for period, period_date in pairs(periods) do - local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) - assert.falsy(err) - assert.are.same({ - api_id = api_id, - identifier = identifier, - period = period, - period_date = period_date, - value = 2 -- The important part - }, metric) - end - - -- 1 second delay - current_timestamp = 1424217601 - periods = timestamp.get_timestamps(current_timestamp) - - -- Third increment - local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp) - assert.falsy(err) - assert.True(ok) - - -- Third select with 1 second delay - for period, period_date in pairs(periods) do - - local expected_value = 3 - - if period == "second" then - expected_value = 1 - end - - local metric, err = ratelimiting_metrics:find_one(api_id, identifier, 
current_timestamp, period) - assert.falsy(err) - assert.are.same({ - api_id = api_id, - identifier = identifier, - period = period, - period_date = period_date, - value = expected_value -- The important part - }, metric) - end - end) - - it("should throw errors for non supported methods of the base_dao", function() - assert.has_error(ratelimiting_metrics.find, "ratelimiting_metrics:find() not supported") - assert.has_error(ratelimiting_metrics.insert, "ratelimiting_metrics:insert() not supported") - assert.has_error(ratelimiting_metrics.update, "ratelimiting_metrics:update() not supported") - assert.has_error(ratelimiting_metrics.delete, "ratelimiting_metrics:delete() not yet implemented") - assert.has_error(ratelimiting_metrics.find_by_keys, "ratelimiting_metrics:find_by_keys() not supported") - end) - - end) -- describe rate limiting metrics - -end) diff --git a/spec/unit/dao/entities_schemas_spec.lua b/spec/unit/dao/entities_schemas_spec.lua index 2b8a472aae3..1a47040ab9b 100644 --- a/spec/unit/dao/entities_schemas_spec.lua +++ b/spec/unit/dao/entities_schemas_spec.lua @@ -1,14 +1,31 @@ -local validate = require("kong.dao.schemas_validation").validate local api_schema = require "kong.dao.schemas.apis" local consumer_schema = require "kong.dao.schemas.consumers" +local plugins_configurations_schema = require "kong.dao.schemas.plugins_configurations" +local validate_fields = require("kong.dao.schemas_validation").validate_fields require "kong.tools.ngx_stub" describe("Entities Schemas", function() + + for k, schema in pairs({api = api_schema, + consumer = consumer_schema, + plugins_configurations = plugins_configurations_schema}) do + it(k.." 
schema should have some required properties", function() + assert.truthy(schema.name) + assert.equal("string", type(schema.name)) + + assert.truthy(schema.primary_key) + assert.equal("table", type(schema.primary_key)) + + assert.truthy(schema.fields) + assert.equal("table", type(schema.fields)) + end) + end + describe("APIs", function() it("should return error with wrong target_url", function() - local valid, errors = validate({ + local valid, errors = validate_fields({ public_dns = "mockbin.com", target_url = "asdasd" }, api_schema) @@ -17,7 +34,7 @@ describe("Entities Schemas", function() end) it("should return error with wrong target_url protocol", function() - local valid, errors = validate({ + local valid, errors = validate_fields({ public_dns = "mockbin.com", target_url = "wot://mockbin.com/" }, api_schema) @@ -26,7 +43,7 @@ describe("Entities Schemas", function() end) it("should validate without a path", function() - local valid, errors = validate({ + local valid, errors = validate_fields({ public_dns = "mockbin.com", target_url = "http://mockbin.com" }, api_schema) @@ -35,7 +52,7 @@ describe("Entities Schemas", function() end) it("should validate with upper case protocol", function() - local valid, errors = validate({ + local valid, errors = validate_fields({ public_dns = "mockbin.com", target_url = "HTTP://mockbin.com/world" }, api_schema) @@ -44,14 +61,14 @@ describe("Entities Schemas", function() end) it("should complain if missing `public_dns` and `path`", function() - local valid, errors = validate({ + local valid, errors = validate_fields({ name = "mockbin" }, api_schema) assert.False(valid) assert.equal("At least a 'public_dns' or a 'path' must be specified", errors.path) assert.equal("At least a 'public_dns' or a 'path' must be specified", errors.public_dns) - local valid, errors = validate({ + local valid, errors = validate_fields({ name = "mockbin", path = true }, api_schema) @@ -60,8 +77,17 @@ describe("Entities Schemas", function() 
assert.equal("At least a 'public_dns' or a 'path' must be specified", errors.public_dns) end) + it("should set the name from public_dns if not set", function() + local t = { public_dns = "mockbin.com", target_url = "http://mockbin.com" } + + local valid, errors = validate_fields(t, api_schema) + assert.falsy(errors) + assert.True(valid) + assert.equal("mockbin.com", t.name) + end) + it("should only accept alphanumeric `path`", function() - local valid, errors = validate({ + local valid, errors = validate_fields({ name = "mockbin", path = "/[a-zA-Z]{3}", target_url = "http://mockbin.com" @@ -69,14 +95,14 @@ describe("Entities Schemas", function() assert.equal("path must only contain alphanumeric and '. -, _, ~, /' characters", errors.path) assert.False(valid) - valid = validate({ + valid = validate_fields({ name = "mockbin", path = "/status/", target_url = "http://mockbin.com" }, api_schema) assert.True(valid) - valid = validate({ + valid = validate_fields({ name = "mockbin", path = "/abcd~user-2", target_url = "http://mockbin.com" @@ -86,42 +112,42 @@ describe("Entities Schemas", function() it("should prefix a `path` with a slash and remove trailing slash", function() local api_t = { name = "mockbin", path = "status", target_url = "http://mockbin.com" } - validate(api_t, api_schema) + validate_fields(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "/status" - validate(api_t, api_schema) + validate_fields(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "status/" - validate(api_t, api_schema) + validate_fields(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "/status/" - validate(api_t, api_schema) + validate_fields(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "/deep/nested/status/" - validate(api_t, api_schema) + validate_fields(api_t, api_schema) assert.equal("/deep/nested/status", api_t.path) api_t.path = "deep/nested/status" - validate(api_t, api_schema) + validate_fields(api_t, 
api_schema) assert.equal("/deep/nested/status", api_t.path) -- Strip all leading slashes api_t.path = "//deep/nested/status" - validate(api_t, api_schema) + validate_fields(api_t, api_schema) assert.equal("/deep/nested/status", api_t.path) -- Strip all trailing slashes api_t.path = "/deep/nested/status//" - validate(api_t, api_schema) + validate_fields(api_t, api_schema) assert.equal("/deep/nested/status", api_t.path) -- Error if invalid path api_t.path = "/deep//nested/status" - local _, errors = validate(api_t, api_schema) + local _, errors = validate_fields(api_t, api_schema) assert.equal("path is invalid: /deep//nested/status", errors.path) end) @@ -130,21 +156,46 @@ describe("Entities Schemas", function() describe("Consumers", function() it("should require a `custom_id` or `username`", function() - local valid, errors = validate({}, consumer_schema) + local valid, errors = validate_fields({}, consumer_schema) assert.False(valid) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.username) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.custom_id) - valid, errors = validate({ username = "" }, consumer_schema) + valid, errors = validate_fields({ username = "" }, consumer_schema) assert.False(valid) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.username) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.custom_id) - valid, errors = validate({ username = true }, consumer_schema) + valid, errors = validate_fields({ username = true }, consumer_schema) assert.False(valid) assert.equal("username is not a string", errors.username) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.custom_id) end) end) + + describe("Plugins Configurations", function() + + it("should not validate if the plugin doesn't exist (not installed)", function() + local valid, errors = validate_fields({name = "world domination"}, 
plugins_configurations_schema) + assert.False(valid) + assert.equal("Plugin \"world domination\" not found", errors.value) + end) + + it("should validate a plugin configuration's `value` field", function() + -- Success + local plugin = {name = "keyauth", api_id = "stub", value = {key_names = {"x-kong-key"}}} + local valid = validate_fields(plugin, plugins_configurations_schema) + assert.True(valid) + + -- Failure + plugin = {name = "ratelimiting", api_id = "stub", value = {period = "hello"}} + + local valid, errors = validate_fields(plugin, plugins_configurations_schema) + assert.False(valid) + assert.equal("limit is required", errors["value.limit"]) + assert.equal("\"hello\" is not allowed. Allowed values are: \"second\", \"minute\", \"hour\", \"day\", \"month\", \"year\"", errors["value.period"]) + end) + + end) end) diff --git a/spec/unit/schemas_spec.lua b/spec/unit/schemas_spec.lua index b3e1364a87a..fe9b4e516f8 100644 --- a/spec/unit/schemas_spec.lua +++ b/spec/unit/schemas_spec.lua @@ -1,5 +1,5 @@ local schemas = require "kong.dao.schemas_validation" -local validate = schemas.validate +local validate_fields = schemas.validate_fields require "kong.tools.ngx_stub" @@ -7,46 +7,48 @@ describe("Schemas", function() -- Ok kids, today we're gonna test a custom validation schema, -- grab a pair of glasses, this stuff can literally explode. 
- describe("#validate()", function() + describe("#validate_fields()", function() local schema = { - string = { type = "string", required = true, immutable = true }, - table = { type = "table" }, - number = { type = "number" }, - url = { regex = "(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\\-]*[A-Za-z0-9])" }, - date = { default = 123456, immutable = true }, - allowed = { enum = { "hello", "world" }}, - boolean_val = { type = "boolean" }, - default = { default = function(t) - assert.truthy(t) - return "default" - end }, - custom = { func = function(v, t) - if v then - if t.default == "test_custom_func" then - return true + fields = { + string = { type = "string", required = true, immutable = true}, + table = {type = "table"}, + number = {type = "number"}, + url = {regex = "(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\\-]*[A-Za-z0-9])"}, + date = {default = 123456, immutable = true}, + allowed = {enum = {"hello", "world"}}, + boolean_val = {type = "boolean"}, + default = {default = function(t) + assert.truthy(t) + return "default" + end}, + custom = {func = function(v, t) + if v then + if t.default == "test_custom_func" then + return true + else + return false, "Nah" + end else - return false, "Nah" + return true end - else - return true - end - end } + end} + } } it("should confirm a valid entity is valid", function() - local values = { string = "mockbin entity", url = "mockbin.com" } + local values = {string = "mockbin entity", url = "mockbin.com"} - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) - assert.truthy(valid) + assert.True(valid) end) describe("[required]", function() it("should invalidate entity if required property is missing", function() local values = { url = "mockbin.com" } - local valid, err = validate(values, schema) - assert.falsy(valid) + local valid, err = validate_fields(values, 
schema) + assert.False(valid) assert.truthy(err) assert.are.same("string is required", err.string) end) @@ -57,37 +59,37 @@ describe("Schemas", function() -- Failure local values = { string = "foo", table = "bar" } - local valid, err = validate(values, schema) - assert.falsy(valid) + local valid, err = validate_fields(values, schema) + assert.False(valid) assert.truthy(err) assert.are.same("table is not a table", err.table) -- Success local values = { string = "foo", table = { foo = "bar" }} - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) - assert.truthy(valid) + assert.True(valid) -- Failure local values = { string = 1, table = { foo = "bar" }} - local valid, err = validate(values, schema) - assert.falsy(valid) + local valid, err = validate_fields(values, schema) + assert.False(valid) assert.truthy(err) assert.are.same("string is not a string", err.string) -- Success local values = { string = "foo", number = 10 } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) - assert.truthy(valid) + assert.True(valid) -- Success local values = { string = "foo", number = "10" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("number", type(values.number)) @@ -95,7 +97,7 @@ describe("Schemas", function() -- Success local values = { string = "foo", boolean_val = true } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("boolean", type(values.boolean_val)) @@ -103,7 +105,7 @@ describe("Schemas", function() -- Success local values = { string = "foo", boolean_val = "true" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) end) @@ -111,7 +113,7 @@ 
describe("Schemas", function() it("should return error when an invalid boolean value is passed", function() local values = { string = "test", boolean_val = "ciao" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("boolean_val is not a boolean", err.boolean_val) @@ -120,7 +122,7 @@ describe("Schemas", function() it("should not return an error when a true boolean value is passed", function() local values = { string = "test", boolean_val = true } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) end) @@ -128,35 +130,43 @@ describe("Schemas", function() it("should not return an error when a false boolean value is passed", function() local values = { string = "test", boolean_val = false } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) end) it("should consider `id` and `timestamp` as types", function() - local s = { id = { type = "id" } } + local s = { + fields = { + id = { type = "id" } + } + } local values = { id = "123" } - local valid, err = validate(values, s) + local valid, err = validate_fields(values, s) assert.falsy(err) assert.truthy(valid) end) it("should consider `array` as a type", function() - local s = { array = { type = "array" } } + local s = { + fields = { + array = { type = "array" } + } + } -- Success local values = { array = {"hello", "world"} } - local valid, err = validate(values, s) + local valid, err = validate_fields(values, s) assert.True(valid) assert.falsy(err) -- Failure local values = { array = {hello="world"} } - local valid, err = validate(values, s) + local valid, err = validate_fields(values, s) assert.False(valid) assert.truthy(err) assert.equal("array is not a array", err.array) @@ -166,7 +176,7 @@ describe("Schemas", function() it("should not return an 
error when a `number` is passed as a string", function() local values = { string = "test", number = "10" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) assert.same("number", type(values.number)) @@ -175,19 +185,23 @@ describe("Schemas", function() it("should not return an error when a `boolean` is passed as a string", function() local values = { string = "test", boolean_val = "false" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) assert.same("boolean", type(values.boolean_val)) end) it("should alias a string to `array`", function() - local s = { array = { type = "array" } } + local s = { + fields = { + array = { type = "array" } + } + } -- It should also strip the resulting strings local values = { array = "hello, world" } - local valid, err = validate(values, s) + local valid, err = validate_fields(values, s) assert.True(valid) assert.falsy(err) assert.same({"hello", "world"}, values.array) @@ -200,7 +214,7 @@ describe("Schemas", function() -- Variables local values = { string = "mockbin entity", url = "mockbin.com" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same(123456, values.date) @@ -208,7 +222,7 @@ describe("Schemas", function() -- Functions local values = { string = "mockbin entity", url = "mockbin.com" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("default", values.default) @@ -218,7 +232,7 @@ describe("Schemas", function() -- Variables local values = { string = "mockbin entity", url = "mockbin.com", date = 654321 } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) 
assert.are.same(654321, values.date) @@ -226,18 +240,19 @@ describe("Schemas", function() -- Functions local values = { string = "mockbin entity", url = "mockbin.com", default = "abcdef" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("abcdef", values.default) end) + end) describe("[regex]", function() it("should validate a field against a regex", function() local values = { string = "mockbin entity", url = "mockbin_!" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("url has an invalid value", err.url) @@ -249,14 +264,14 @@ describe("Schemas", function() -- Success local values = { string = "somestring", allowed = "hello" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) -- Failure local values = { string = "somestring", allowed = "hello123" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("\"hello123\" is not allowed. 
Allowed values are: \"hello\", \"world\"", err.allowed) @@ -268,49 +283,62 @@ describe("Schemas", function() -- Success local values = { string = "somestring", custom = true, default = "test_custom_func" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(err) assert.truthy(valid) -- Failure local values = { string = "somestring", custom = true, default = "not the default :O" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("Nah", err.custom) end) end) - describe("[immutable]", function() - it("should prevent immutable properties to be changed if validating a schema that will be updated", function() - -- Success - local values = { string = "somestring", date = 1234 } + describe("[dao_insert_value]", function() + local schema = { + fields = { + string = { type = "string"}, + id = { type = "id", dao_insert_value = true }, + timestamp = { type = "timestamp", dao_insert_value = true } + } + } - local valid, err = validate(values, schema) - assert.falsy(err) - assert.truthy(valid) + it("should call a given function when encountering a field with `dao_insert_value`", function() + local values = {string = "hello", id = "0000"} - -- Failure - local valid, err = validate(values, schema, true) - assert.falsy(valid) - assert.truthy(err) - assert.are.same("date cannot be updated", err.date) + local valid, err = validate_fields(values, schema, {dao_insert = function(field) + if field.type == "id" then + return "1234" + elseif field.type == "timestamp" then + return 0000 + end + end}) + assert.falsy(err) + assert.True(valid) + assert.equal("1234", values.id) + assert.equal(0000, values.timestamp) + assert.equal("hello", values.string) end) - it("should ignore required properties if they are immutable and we are updating", function() - local values = { string = "somestring" } + it("should not raise any error if 
the function is not given", function() + local values = { string = "hello", id = "0000" } - local valid, err = validate(values, schema, true) + local valid, err = validate_fields(values, schema, { dao_insert = true }) -- invalid type assert.falsy(err) - assert.truthy(valid) + assert.True(valid) + assert.equal("0000", values.id) + assert.equal("hello", values.string) + assert.falsy(values.timestamp) end) end) it("should return error when unexpected values are included in the schema", function() local values = { string = "mockbin entity", url = "mockbin.com", unexpected = "abcdef" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(valid) assert.truthy(err) end) @@ -318,7 +346,7 @@ describe("Schemas", function() it("should be able to return multiple errors at once", function() local values = { url = "mockbin.com", unexpected = "abcdef" } - local valid, err = validate(values, schema) + local valid, err = validate_fields(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("string is required", err.string) @@ -327,10 +355,14 @@ describe("Schemas", function() it("should not check a custom function if a `required` condition is false already", function() local f = function() error("should not be called") end -- cannot use a spy which changes the type to table - local schema = { property = { required = true, func = f } } + local schema = { + fields = { + property = { required = true, func = f } + } + } assert.has_no_errors(function() - local valid, err = validate({}, schema) + local valid, err = validate_fields({}, schema) assert.False(valid) assert.are.same("property is required", err.property) end) @@ -338,7 +370,7 @@ describe("Schemas", function() describe("Sub-schemas", function() -- To check wether schema_from_function was called, we will simply use booleans because - -- busted's spy methods create tables and metatable magic, but the validate() function + -- busted's spy methods create 
tables and metatable magic, but the validate_fields() function -- only callse v.schema if the type is a function. Which is not the case with a busted spy. local called, called_with local schema_from_function = function(t) @@ -349,17 +381,21 @@ describe("Schemas", function() return nil, "Error loading the sub-sub-schema" end - return { sub_sub_field_required = { required = true } } + return { fields = {sub_sub_field_required = { required = true }} } end local nested_schema = { - some_required = { required = true }, - sub_schema = { - type = "table", - schema = { - sub_field_required = { required = true }, - sub_field_default = { default = "abcd" }, - sub_field_number = { type = "number" }, - error_loading_sub_sub_schema = {} + fields = { + some_required = { required = true }, + sub_schema = { + type = "table", + schema = { + fields = { + sub_field_required = { required = true }, + sub_field_default = { default = "abcd" }, + sub_field_number = { type = "number" }, + error_loading_sub_sub_schema = {} + } + } } } } @@ -368,7 +404,7 @@ describe("Schemas", function() -- Success local values = { some_required = "somestring", sub_schema = { sub_field_required = "sub value" }} - local valid, err = validate(values, nested_schema) + local valid, err = validate_fields(values, nested_schema) assert.falsy(err) assert.truthy(valid) assert.are.same("abcd", values.sub_schema.sub_field_default) @@ -376,14 +412,14 @@ describe("Schemas", function() -- Failure local values = { some_required = "somestring", sub_schema = { sub_field_default = "" }} - local valid, err = validate(values, nested_schema) + local valid, err = validate_fields(values, nested_schema) assert.truthy(err) assert.falsy(valid) assert.are.same("sub_field_required is required", err["sub_schema.sub_field_required"]) end) it("should validate a property with a sub-schema from a function", function() - nested_schema.sub_schema.schema.sub_sub_schema = { schema = schema_from_function } + 
nested_schema.fields.sub_schema.schema.fields.sub_sub_schema = { schema = schema_from_function } -- Success local values = { some_required = "somestring", sub_schema = { @@ -391,7 +427,7 @@ describe("Schemas", function() sub_sub_schema = { sub_sub_field_required = "test" } }} - local valid, err = validate(values, nested_schema) + local valid, err = validate_fields(values, nested_schema) assert.falsy(err) assert.truthy(valid) @@ -401,7 +437,7 @@ describe("Schemas", function() sub_sub_schema = {} }} - local valid, err = validate(values, nested_schema) + local valid, err = validate_fields(values, nested_schema) assert.truthy(err) assert.falsy(valid) assert.are.same("sub_sub_field_required is required", err["sub_schema.sub_sub_schema.sub_sub_field_required"]) @@ -414,7 +450,7 @@ describe("Schemas", function() sub_sub_schema = { sub_sub_field_required = "test" } }} - local valid, err = validate(values, nested_schema) + local valid, err = validate_fields(values, nested_schema) assert.falsy(err) assert.truthy(valid) assert.True(called) @@ -429,7 +465,7 @@ describe("Schemas", function() sub_sub_schema = { sub_sub_field_required = "test" } }} - local valid, err = validate(values, nested_schema) + local valid, err = validate_fields(values, nested_schema) assert.truthy(err) assert.falsy(valid) assert.are.same("Error loading the sub-sub-schema", err["sub_schema.sub_sub_schema"]) @@ -444,11 +480,13 @@ describe("Schemas", function() end local schema = { - value = { type = "table", schema = {some_property={default="hello"}}, func = validate_value, required = true } + fields = { + value = { type = "table", schema = { fields = {some_property={default="hello"}}}, func = validate_value, required = true } + } } local obj = {} - local valid, err = validate(obj, schema) + local valid, err = validate_fields(obj, schema) assert.falsy(err) assert.True(valid) assert.are.same("hello", obj.value.some_property) @@ -456,16 +494,74 @@ describe("Schemas", function() it("should mark a value 
required if sub-schema has a `required`", function() local schema = { - value = { type = "table", schema = {some_property={required=true}} } + fields = { + value = {type = "table", schema = {fields = {some_property={required=true}}}} + } } local obj = {} - local valid, err = validate(obj, schema) + local valid, err = validate_fields(obj, schema) assert.truthy(err) assert.False(valid) assert.are.same("value.some_property is required", err.value) end) end) + + describe("[partial_update]", function() + it("should ignore required properties and defaults if we are updating because the entity might be partial", function() + local values = {} + + local valid, err = validate_fields(values, schema, {partial_update = true}) + assert.falsy(err) + assert.True(valid) + assert.falsy(values.default) + assert.falsy(values.date) + end) + + it("should still validate set properties", function() + local values = { string = 123 } + + local valid, err = validate_fields(values, schema, {partial_update = true}) + assert.False(valid) + assert.equal("string is not a string", err.string) + end) + + it("should ignore immutable fields if they are required", function() + local values = { string = "somestring" } + + local valid, err = validate_fields(values, schema, {partial_update = true}) + assert.falsy(err) + assert.True(valid) + end) + + it("should prevent immutable fields to be changed", function() + -- Success + local values = {string = "somestring", date = 1234} + + local valid, err = validate_fields(values, schema) + assert.falsy(err) + assert.truthy(valid) + + -- Failure + local valid, err = validate_fields(values, schema, {partial_update = true}) + assert.False(valid) + assert.truthy(err) + assert.equal("date cannot be updated", err.date) + end) + end) + + describe("[full_update]", function() + it("should not ignore required properties and ignore defaults", function() + local values = {} + + local valid, err = validate_fields(values, schema, {full_update = true}) + assert.False(valid) 
+ assert.truthy(err) + assert.equal("string is required", err.string) + assert.falsy(values.default) + end) + end) + end) end) diff --git a/spec/unit/stub_coverage_spec.lua b/spec/unit/stub_coverage_spec.lua index 7fa86a6bb65..4f42c70fa39 100644 --- a/spec/unit/stub_coverage_spec.lua +++ b/spec/unit/stub_coverage_spec.lua @@ -6,7 +6,7 @@ local IO = require "kong.tools.io" -- Stub DAO for lapis controllers _G.dao = {} -local lua_sources = IO.retrieve_files("./kong", { exclude_dir_patterns = {"cli", "vendor", "filelog", "reports"}, file_pattern = ".lua" }) +local lua_sources = IO.retrieve_files("./kong", { exclude_dir_patterns = {"cli", "vendor", "filelog", "reports"}, file_pattern = ".lua$" }) for _, source_link in ipairs(lua_sources) do dofile(source_link)