diff --git a/apisix/http/route.lua b/apisix/http/route.lua index 6292b577a071..d475646b56c6 100644 --- a/apisix/http/route.lua +++ b/apisix/http/route.lua @@ -92,7 +92,7 @@ function _M.create_radixtree_uri_router(routes, uri_routes, with_parameter) end end - event.push(event.CONST.BUILD_ROUTER, uri_routes) + event.push(event.CONST.BUILD_ROUTER, routes) core.log.info("route items: ", core.json.delay_encode(uri_routes, true)) if with_parameter then diff --git a/apisix/init.lua b/apisix/init.lua index 2030e0241d51..883dbd9abab0 100644 --- a/apisix/init.lua +++ b/apisix/init.lua @@ -348,6 +348,115 @@ local function common_phase(phase_name) end + +function _M.handle_upstream(api_ctx, route, enable_websocket) + local up_id = route.value.upstream_id + + -- used for the traffic-split plugin + if api_ctx.upstream_id then + up_id = api_ctx.upstream_id + end + + if up_id then + local upstream = apisix_upstream.get_by_id(up_id) + if not upstream then + if is_http then + return core.response.exit(502) + end + + return ngx_exit(1) + end + + api_ctx.matched_upstream = upstream + + else + if route.has_domain then + local err + route, err = parse_domain_in_route(route) + if err then + core.log.error("failed to get resolved route: ", err) + return core.response.exit(500) + end + + api_ctx.conf_version = route.modifiedIndex + api_ctx.matched_route = route + end + + local route_val = route.value + + api_ctx.matched_upstream = (route.dns_value and + route.dns_value.upstream) + or route_val.upstream + end + + if api_ctx.matched_upstream and api_ctx.matched_upstream.tls and + api_ctx.matched_upstream.tls.client_cert_id then + + local cert_id = api_ctx.matched_upstream.tls.client_cert_id + local upstream_ssl = router.router_ssl.get_by_id(cert_id) + if not upstream_ssl or upstream_ssl.type ~= "client" then + local err = upstream_ssl and + "ssl type should be 'client'" or + "ssl id [" .. cert_id .. "] not exits" + core.log.error("failed to get ssl cert: ", err) + + if is_http then + return core.response.exit(502) + end + + return ngx_exit(1) + end + + core.log.info("matched ssl: ", + core.json.delay_encode(upstream_ssl, true)) + api_ctx.upstream_ssl = upstream_ssl + end + + if enable_websocket then + api_ctx.var.upstream_upgrade = api_ctx.var.http_upgrade + api_ctx.var.upstream_connection = api_ctx.var.http_connection + core.log.info("enabled websocket for route: ", route.value.id) + end + + -- load balancer is not required by kafka upstream, so the upstream + -- node selection process is intercepted and left to kafka to + -- handle on its own + if api_ctx.matched_upstream and api_ctx.matched_upstream.scheme == "kafka" then + return pubsub_kafka.access(api_ctx) + end + + local code, err = set_upstream(route, api_ctx) + if code then + core.log.error("failed to set upstream: ", err) + core.response.exit(code) + end + + local server, err = load_balancer.pick_server(route, api_ctx) + if not server then + core.log.error("failed to pick server: ", err) + return core.response.exit(502) + end + + api_ctx.picked_server = server + + set_upstream_headers(api_ctx, server) + + -- run the before_proxy method in access phase first to avoid always reinit request + common_phase("before_proxy") + + local up_scheme = api_ctx.upstream_scheme + if up_scheme == "grpcs" or up_scheme == "grpc" then + stash_ngx_ctx() + return ngx.exec("@grpc_pass") + end + + if api_ctx.dubbo_proxy_enabled then + stash_ngx_ctx() + return ngx.exec("@dubbo_pass") + end +end + + function _M.http_access_phase() local ngx_ctx = ngx.ctx @@ -495,110 +604,7 @@ function _M.http_access_phase() plugin.run_plugin("access", plugins, api_ctx) end - local up_id = route.value.upstream_id - - -- used for the traffic-split plugin - if api_ctx.upstream_id then - up_id = api_ctx.upstream_id - end - - if up_id then - local upstream = apisix_upstream.get_by_id(up_id) - if not upstream then - if is_http then - return core.response.exit(502) - end - - return ngx_exit(1) - end - - api_ctx.matched_upstream = upstream - - else - if route.has_domain then - local err - route, err = parse_domain_in_route(route) - if err then - core.log.error("failed to get resolved route: ", err) - return core.response.exit(500) - end - - api_ctx.conf_version = route.modifiedIndex - api_ctx.matched_route = route - end - - local route_val = route.value - - api_ctx.matched_upstream = (route.dns_value and - route.dns_value.upstream) - or route_val.upstream - end - - if api_ctx.matched_upstream and api_ctx.matched_upstream.tls and - api_ctx.matched_upstream.tls.client_cert_id then - - local cert_id = api_ctx.matched_upstream.tls.client_cert_id - local upstream_ssl = router.router_ssl.get_by_id(cert_id) - if not upstream_ssl or upstream_ssl.type ~= "client" then - local err = upstream_ssl and - "ssl type should be 'client'" or - "ssl id [" .. cert_id .. "] not exits" - core.log.error("failed to get ssl cert: ", err) - - if is_http then - return core.response.exit(502) - end - - return ngx_exit(1) - end - - core.log.info("matched ssl: ", - core.json.delay_encode(upstream_ssl, true)) - api_ctx.upstream_ssl = upstream_ssl - end - - if enable_websocket then - api_ctx.var.upstream_upgrade = api_ctx.var.http_upgrade - api_ctx.var.upstream_connection = api_ctx.var.http_connection - core.log.info("enabled websocket for route: ", route.value.id) - end - - -- load balancer is not required by kafka upstream, so the upstream - -- node selection process is intercepted and left to kafka to - -- handle on its own - if api_ctx.matched_upstream and api_ctx.matched_upstream.scheme == "kafka" then - return pubsub_kafka.access(api_ctx) - end - - local code, err = set_upstream(route, api_ctx) - if code then - core.log.error("failed to set upstream: ", err) - core.response.exit(code) - end - - local server, err = load_balancer.pick_server(route, api_ctx) - if not server then - core.log.error("failed to pick server: ", err) - return core.response.exit(502) - end - - api_ctx.picked_server = server - - set_upstream_headers(api_ctx, server) - - -- run the before_proxy method in access phase first to avoid always reinit request - common_phase("before_proxy") - - local up_scheme = api_ctx.upstream_scheme - if up_scheme == "grpcs" or up_scheme == "grpc" then - stash_ngx_ctx() - return ngx.exec("@grpc_pass") - end - - if api_ctx.dubbo_proxy_enabled then - stash_ngx_ctx() - return ngx.exec("@dubbo_pass") - end + _M.handle_upstream(api_ctx, route, enable_websocket) end diff --git a/apisix/plugins/ai.lua b/apisix/plugins/ai.lua index cb69f59a7c68..eeb78ca80420 100644 --- a/apisix/plugins/ai.lua +++ b/apisix/plugins/ai.lua @@ -15,12 +15,19 @@ -- limitations under the License. -- local require = require +local apisix = require("apisix") local core = require("apisix.core") local router = require("apisix.router") local event = require("apisix.core.event") +local load_balancer = require("apisix.balancer") +local balancer = require("ngx.balancer") +local is_http = ngx.config.subsystem == "http" +local enable_keepalive = balancer.enable_keepalive and is_http local ipairs = ipairs local pcall = pcall local loadstring = loadstring +local type = type +local pairs = pairs local get_cache_key_func local get_cache_key_func_def_render @@ -50,14 +57,17 @@ local plugin_name = "ai" local _M = { version = 0.1, - priority = 25000, + priority = 22900, name = plugin_name, schema = schema, scope = "global", } local orig_router_match +local orig_handle_upstream = apisix.handle_upstream +local orig_balancer_run = load_balancer.run +local default_keepalive_pool = {} local function match_route(ctx) orig_router_match(ctx) @@ -100,33 +110,98 @@ local function gen_get_cache_key_func(route_flags) end -local function routes_analyze(routes) - -- TODO: need to add a option in config.yaml to enable this feature(default is true) - local route_flags = core.table.new(0, 2) - for _, route in ipairs(routes) do - if route.methods then - route_flags["methods"] = true - end +local function ai_upstream() + core.log.info("enable sample upstream") +end - if route.host or route.hosts then - route_flags["host"] = true - end - if route.vars then - route_flags["vars"] = true +local pool_opt +local function ai_balancer_run(route) + local server = route.value.upstream.nodes[1] + if enable_keepalive then + local ok, err = balancer.set_current_peer(server.host, server.port or 80, pool_opt) + if not ok then + core.log.error("failed to set server peer [", server.host, ":", + server.port, "] err: ", err) + return ok, err end + balancer.enable_keepalive(default_keepalive_pool.idle_timeout, + default_keepalive_pool.requests) + else + balancer.set_current_peer(server.host, server.port or 80) + end +end - if route.filter_fun then - route_flags["filter_fun"] = true - end +local function routes_analyze(routes) + local route_flags = core.table.new(0, 16) + local route_up_flags = core.table.new(0, 12) + for _, route in ipairs(routes) do + if type(route) == "table" then + for key, value in pairs(route.value) do + -- collect route flags + if key == "methods" then + route_flags["methods"] = true + elseif key == "host" or key == "hosts" then + route_flags["host"] = true + elseif key == "vars" then + route_flags["vars"] = true + elseif key == "filter_fun"then + route_flags["filter_fun"] = true + elseif key == "remote_addr" or key == "remote_addrs" then + route_flags["remote_addr"] = true + elseif key == "service" then + route_flags["service"] = true + elseif key == "enable_websocket" then + route_flags["enable_websocket"] = true + elseif key == "plugins" then + route_flags["plugins"] = true + elseif key == "upstream_id" then + route_flags["upstream_id"] = true + elseif key == "service_id" then + route_flags["service_id"] = true + elseif key == "plugin_config_id" then + route_flags["plugin_config_id"] = true + end - if route.remote_addr or route.remote_addrs then - route_flags["remote_addr"] = true + -- collect upstream flags + if key == "upstream" then + if value.nodes and #value.nodes == 1 then + for k, v in pairs(value) do + if k == "nodes" then + if (not core.utils.parse_ipv4(v[1].host) + and not core.utils.parse_ipv6(v[1].host)) then + route_up_flags["has_domain"] = true + end + elseif k == "pass_host" and v ~= "pass" then + route_up_flags["pass_host"] = true + elseif k == "scheme" and v ~= "http" then + route_up_flags["scheme"] = true + elseif k == "checks" then + route_up_flags["checks"] = true + elseif k == "retries" then + route_up_flags["retries"] = true + elseif k == "timeout" then + route_up_flags["timeout"] = true + elseif k == "tls" then + route_up_flags["tls"] = true + elseif k == "keepalive" then + route_up_flags["keepalive"] = true + elseif k == "service_name" then + route_up_flags["service_name"] = true + end + end + else + route_up_flags["more_nodes"] = true + end + end + end end end if route_flags["vars"] or route_flags["filter_fun"] - or route_flags["remote_addr"] then + or route_flags["remote_addr"] + or route_flags["service_id"] + or route_flags["plugin_config_id"] then router.router_http.match = orig_router_match else core.log.info("use ai plane to match route") @@ -138,11 +213,44 @@ local function routes_analyze(routes) router.router_http.match = orig_router_match end end + + if route_flags["service"] + or route_flags["service_id"] + or route_flags["upstream_id"] + or route_flags["enable_websocket"] + or route_flags["plugins"] + or route_up_flags["has_domain"] + or route_up_flags["pass_host"] + or route_up_flags["scheme"] + or route_up_flags["checks"] + or route_up_flags["retries"] + or route_up_flags["timeout"] + or route_up_flags["tls"] + or route_up_flags["keepalive"] + or route_up_flags["service_name"] + or route_up_flags["more_nodes"] then + apisix.handle_upstream = orig_handle_upstream + load_balancer.run = orig_balancer_run + else + -- replace the upstream module + apisix.handle_upstream = ai_upstream + load_balancer.run = ai_balancer_run + end end function _M.init() event.register(event.CONST.BUILD_ROUTER, routes_analyze) + local local_conf = core.config.local_conf() + local up_keepalive_conf = + core.table.try_read_attr(local_conf, "nginx_config", + "http", "upstream") + default_keepalive_pool.idle_timeout = + core.config_util.parse_time_unit(up_keepalive_conf.keepalive_timeout) + default_keepalive_pool.size = up_keepalive_conf.keepalive + default_keepalive_pool.requests = up_keepalive_conf.keepalive_requests + + pool_opt = { pool_size = default_keepalive_pool.size } end diff --git a/conf/config-default.yaml b/conf/config-default.yaml index bad0e41e4c1d..6e714b577346 100755 --- a/conf/config-default.yaml +++ b/conf/config-default.yaml @@ -389,8 +389,8 @@ graphql: #cmd: ["ls", "-l"] plugins: # plugin list (sorted by priority) - - ai # priority: 25000 - real-ip # priority: 23000 + - ai # priority: 22900 - client-control # priority: 22000 - proxy-control # priority: 21990 - request-id # priority: 12015 diff --git a/t/admin/plugins.t b/t/admin/plugins.t index 74827e437ebf..98e337e57dfc 100644 --- a/t/admin/plugins.t +++ b/t/admin/plugins.t @@ -61,8 +61,8 @@ __DATA__ } --- response_body -ai real-ip +ai client-control proxy-control request-id diff --git a/t/core/config.t b/t/core/config.t index 18191dae724f..29d1cc52dc07 100644 --- a/t/core/config.t +++ b/t/core/config.t @@ -38,7 +38,7 @@ __DATA__ GET /t --- response_body etcd host: http://127.0.0.1:2379 -first plugin: "ai" +first plugin: "real-ip" diff --git a/t/debug/dynamic-hook.t b/t/debug/dynamic-hook.t index 692942d1f9e4..87d4450d569c 100644 --- a/t/debug/dynamic-hook.t +++ b/t/debug/dynamic-hook.t @@ -377,6 +377,11 @@ qr/call\srequire\(\"apisix.plugin\"\).filter\(\)\sreturn.*GET\s\/mysleep\?second === TEST 6: hook function with ctx as param +# ai module would conflict with the debug module +--- extra_yaml_config +plugins: + #ai + - example-plugin --- debug_config basic: enable: true diff --git a/t/plugin/ai.t b/t/plugin/ai.t index 9415771ab629..f695788ec190 100644 --- a/t/plugin/ai.t +++ b/t/plugin/ai.t @@ -620,3 +620,258 @@ route cache key: /hello#GET done --- error_log route cache key: /hello#GET#127.0.0.1 + + + +=== TEST 9: enable sample upstream +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "methods": ["GET"], + "upstream": { + "nodes": { + "127.0.0.1:1980": 1 + }, + "type": "roundrobin" + }, + "uri": "/hello" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.sleep(0.5) + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port .. "/hello" + local httpc = http.new() + local res, err = httpc:request_uri(uri) + assert(res.status == 200) + if not res then + ngx.log(ngx.ERR, err) + return + end + ngx.say("done") + } + } +--- response_body +done +--- error_log +enable sample upstream + + + +=== TEST 10: route has plugins and run before_proxy, disable samply upstream +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "plugins": { + "serverless-pre-function": { + "phase": "before_proxy", + "functions" : ["return function(conf, ctx) ngx.log(ngx.WARN, \"run before_proxy phase balancer_ip : \", ctx.balancer_ip) end"] + } + }, + "upstream": { + "nodes": { + "127.0.0.1:1980": 1 + }, + "type": "roundrobin" + }, + "uri": "/hello" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port .. "/hello" + local httpc = http.new() + local res, err = httpc:request_uri(uri) + assert(res.status == 200) + if not res then + ngx.log(ngx.ERR, err) + return + end + ngx.say("done") + } + } +--- response_body +done +--- error_log +run before_proxy phase balancer_ip : 127.0.0.1 +--- no_error_log +enable sample upstream + + + +=== TEST 11: upstream has more than one nodes, disable sample upstream +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "methods": ["GET"], + "upstream": { + "nodes": { + "127.0.0.1:1980": 1, + "127.0.0.1:1981": 1 + }, + "type": "roundrobin" + }, + "uri": "/hello" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.sleep(0.5) + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port .. "/hello" + local httpc = http.new() + local res, err = httpc:request_uri(uri) + assert(res.status == 200) + if not res then + ngx.log(ngx.ERR, err) + return + end + ngx.say("done") + } + } +--- response_body +done +--- no_error_log +enable sample upstream + + + +=== TEST 12: node has domain, disable sample upstream +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "methods": ["GET"], + "upstream": { + "nodes": { + "admin.apisix.dev:1980": 1 + }, + "type": "roundrobin" + }, + "uri": "/hello" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port .. "/hello" + local httpc = http.new() + local res, err = httpc:request_uri(uri) + assert(res.status == 200) + if not res then + ngx.log(ngx.ERR, err) + return + end + ngx.say("done") + } + } +--- response_body +done +--- no_error_log +enable sample upstream + + + +=== TEST 13: enable --> disable sample upstream +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "methods": ["GET"], + "upstream": { + "nodes": { + "127.0.0.1:1980": 1 + }, + "type": "roundrobin" + }, + "uri": "/hello" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.sleep(0.5) + + local http = require "resty.http" + local uri = "http://127.0.0.1:" .. ngx.var.server_port .. "/hello" + local httpc = http.new() + local res, err = httpc:request_uri(uri) + assert(res.status == 200) + if not res then + ngx.log(ngx.ERR, err) + return + end + + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "methods": ["GET"], + "upstream": { + "nodes": { + "127.0.0.1:1980": 1 + }, + "type": "roundrobin" + }, + "enable_websocket": true, + "uri": "/hello" + }]] + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + ngx.sleep(0.5) + + local uri = "http://127.0.0.1:" .. ngx.var.server_port .. "/hello" + local httpc = http.new() + local res, err = httpc:request_uri(uri) + assert(res.status == 200) + if not res then + ngx.log(ngx.ERR, err) + return + end + + ngx.say("done") + } + } +--- response_body +done +--- grep_error_log eval +qr/enable sample upstream/ +--- grep_error_log_out +enable sample upstream