diff --git a/apisix/stream/plugins/mqtt-proxy.lua b/apisix/stream/plugins/mqtt-proxy.lua index fae0eb08f5f4..7c890505aa9b 100644 --- a/apisix/stream/plugins/mqtt-proxy.lua +++ b/apisix/stream/plugins/mqtt-proxy.lua @@ -67,27 +67,39 @@ function _M.check_schema(conf) end -local function parse_mqtt(data) - local res = {} - res.packet_type_flags_byte = str_byte(data, 1, 1) - if res.packet_type_flags_byte < 16 or res.packet_type_flags_byte > 32 then - return nil, "Received unexpected MQTT packet type+flags: " - .. res.packet_type_flags_byte - end - - local parsed_pos = 1 - res.remaining_len = 0 +local function decode_variable_byte_int(data, offset) local multiplier = 1 - for i = 2, 5 do - parsed_pos = i + local len = 0 + local pos + for i = offset, offset + 3 do + pos = i local byte = str_byte(data, i, i) - res.remaining_len = res.remaining_len + bit.band(byte, 127) * multiplier + len = len + bit.band(byte, 127) * multiplier multiplier = multiplier * 128 if bit.band(byte, 128) == 0 then break end end + return len, pos +end + + +local function parse_msg_hdr(data) + local packet_type_flags_byte = str_byte(data, 1, 1) + if packet_type_flags_byte < 16 or packet_type_flags_byte > 32 then + return nil, nil, + "Received unexpected MQTT packet type+flags: " .. packet_type_flags_byte + end + + local len, pos = decode_variable_byte_int(data, 2) + return len, pos +end + + +local function parse_mqtt(data, parsed_pos) + local res = {} + local protocol_len = str_byte(data, parsed_pos + 1, parsed_pos + 1) * 256 + str_byte(data, parsed_pos + 2, parsed_pos + 2) parsed_pos = parsed_pos + 2 @@ -96,10 +108,15 @@ local function parse_mqtt(data) res.protocol_ver = str_byte(data, parsed_pos + 1, parsed_pos + 1) parsed_pos = parsed_pos + 1 - if res.protocol_ver == 4 then - parsed_pos = parsed_pos + 3 - elseif res.protocol_ver == 5 then - parsed_pos = parsed_pos + 9 + + -- skip control flags & keepalive + parsed_pos = parsed_pos + 3 + + if res.protocol_ver == 5 then + -- skip properties + local property_len + property_len, parsed_pos = decode_variable_byte_int(data, parsed_pos + 1) + parsed_pos = parsed_pos + property_len end local client_id_len = str_byte(data, parsed_pos + 1, parsed_pos + 1) * 256 @@ -129,31 +146,29 @@ function _M.preread(conf, ctx) local sock = ngx.req.socket() -- the header format of MQTT CONNECT can be found in -- https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901033 - local data, err = sock:peek(14) + local data, err = sock:peek(5) if not data then - core.log.error("failed to read first 16 bytes: ", err) + core.log.error("failed to read the msg header: ", err) return 503 end - local res, err = parse_mqtt(data) - if not res then - core.log.error("failed to parse the first 16 bytes: ", err) + local remain_len, pos, err = parse_msg_hdr(data) + if not remain_len then + core.log.error("failed to parse the msg header: ", err) return 503 end - if res.expect_len > #data then - data, err = sock:peek(res.expect_len) - if not data then - core.log.error("failed to read ", res.expect_len, " bytes: ", err) - return 503 - end + local data, err = sock:peek(pos + remain_len) + if not data then + core.log.error("failed to read the Connect Command: ", err) + return 503 + end - res = parse_mqtt(data) - if res.expect_len > #data then - core.log.error("failed to parse mqtt request, expect len: ", - res.expect_len, " but got ", #data) - return 503 - end + local res = parse_mqtt(data, pos) + if res.expect_len > #data then + core.log.error("failed to parse mqtt request, expect len: ", + res.expect_len, " but got ", #data) + return 503 end if res.protocol and res.protocol ~= conf.protocol_name then diff --git a/t/stream-plugin/mqtt-proxy.t b/t/stream-plugin/mqtt-proxy.t index ae46fa8cdcc9..3aa5cdfda0f0 100644 --- a/t/stream-plugin/mqtt-proxy.t +++ b/t/stream-plugin/mqtt-proxy.t @@ -328,7 +328,7 @@ mqtt client id: foo === TEST 13: hit route with empty client id --- stream_enable --- stream_request eval -"\x10\x0f\x00\x04\x4d\x51\x54\x54\x04\x02\x00\x3c\x00\x00" +"\x10\x0c\x00\x04\x4d\x51\x54\x54\x04\x02\x00\x3c\x00\x00" --- stream_response hello world --- grep_error_log eval @@ -336,3 +336,74 @@ qr/mqtt client id: \w+/ --- grep_error_log_out --- no_error_log [error] + + + +=== TEST 14: MQTT 5 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/stream_routes/1', + ngx.HTTP_PUT, + [[{ + "remote_addr": "127.0.0.1", + "server_port": 1985, + "plugins": { + "mqtt-proxy": { + "protocol_name": "MQTT", + "protocol_level": 5 + } + }, + "upstream": { + "type": "roundrobin", + "nodes": [{ + "host": "127.0.0.1", + "port": 1995, + "weight": 1 + }] + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- request +GET /t +--- response_body +passed +--- no_error_log +[error] + + + +=== TEST 15: hit route with empty property +--- stream_enable +--- stream_request eval +"\x10\x0d\x00\x04\x4d\x51\x54\x54\x05\x02\x00\x3c\x00\x00\x00" +--- stream_response +hello world +--- grep_error_log eval +qr/mqtt client id: \w+/ +--- grep_error_log_out +--- no_error_log +[error] + + + +=== TEST 16: hit route with property +--- stream_enable +--- stream_request eval +"\x10\x1b\x00\x04\x4d\x51\x54\x54\x05\x02\x00\x3c\x05\x11\x00\x00\x0e\x10\x00\x09\x63\x6c\x69\x6e\x74\x2d\x31\x31\x31" +--- stream_response +hello world +--- grep_error_log eval +qr/mqtt client id: \S+/ +--- grep_error_log_out +mqtt client id: clint-111 +--- no_error_log +[error]