Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mqtt): handle properties for MQTT 5 #5916

Merged
merged 1 commit into from
Dec 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 49 additions & 34 deletions apisix/stream/plugins/mqtt-proxy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,39 @@ function _M.check_schema(conf)
end


local function parse_mqtt(data)
local res = {}
res.packet_type_flags_byte = str_byte(data, 1, 1)
if res.packet_type_flags_byte < 16 or res.packet_type_flags_byte > 32 then
return nil, "Received unexpected MQTT packet type+flags: "
.. res.packet_type_flags_byte
end

local parsed_pos = 1
res.remaining_len = 0
local function decode_variable_byte_int(data, offset)
local multiplier = 1
for i = 2, 5 do
parsed_pos = i
local len = 0
local pos
for i = offset, offset + 3 do
pos = i
local byte = str_byte(data, i, i)
res.remaining_len = res.remaining_len + bit.band(byte, 127) * multiplier
len = len + bit.band(byte, 127) * multiplier
multiplier = multiplier * 128
if bit.band(byte, 128) == 0 then
break
end
end

return len, pos
end


local function parse_msg_hdr(data)
local packet_type_flags_byte = str_byte(data, 1, 1)
if packet_type_flags_byte < 16 or packet_type_flags_byte > 32 then
return nil, nil,
"Received unexpected MQTT packet type+flags: " .. packet_type_flags_byte
end

local len, pos = decode_variable_byte_int(data, 2)
return len, pos
end


local function parse_mqtt(data, parsed_pos)
local res = {}

local protocol_len = str_byte(data, parsed_pos + 1, parsed_pos + 1) * 256
+ str_byte(data, parsed_pos + 2, parsed_pos + 2)
parsed_pos = parsed_pos + 2
Expand All @@ -96,10 +108,15 @@ local function parse_mqtt(data)

res.protocol_ver = str_byte(data, parsed_pos + 1, parsed_pos + 1)
parsed_pos = parsed_pos + 1
if res.protocol_ver == 4 then
parsed_pos = parsed_pos + 3
elseif res.protocol_ver == 5 then
parsed_pos = parsed_pos + 9

-- skip control flags & keepalive
parsed_pos = parsed_pos + 3

if res.protocol_ver == 5 then
-- skip properties
local property_len
property_len, parsed_pos = decode_variable_byte_int(data, parsed_pos + 1)
parsed_pos = parsed_pos + property_len
end

local client_id_len = str_byte(data, parsed_pos + 1, parsed_pos + 1) * 256
Expand Down Expand Up @@ -129,31 +146,29 @@ function _M.preread(conf, ctx)
local sock = ngx.req.socket()
-- the header format of MQTT CONNECT can be found in
-- https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901033
local data, err = sock:peek(14)
local data, err = sock:peek(5)
if not data then
core.log.error("failed to read first 16 bytes: ", err)
core.log.error("failed to read the msg header: ", err)
return 503
end

local res, err = parse_mqtt(data)
if not res then
core.log.error("failed to parse the first 16 bytes: ", err)
local remain_len, pos, err = parse_msg_hdr(data)
if not remain_len then
core.log.error("failed to parse the msg header: ", err)
return 503
end

if res.expect_len > #data then
data, err = sock:peek(res.expect_len)
if not data then
core.log.error("failed to read ", res.expect_len, " bytes: ", err)
return 503
end
local data, err = sock:peek(pos + remain_len)
if not data then
core.log.error("failed to read the Connect Command: ", err)
return 503
end

res = parse_mqtt(data)
if res.expect_len > #data then
core.log.error("failed to parse mqtt request, expect len: ",
res.expect_len, " but got ", #data)
return 503
end
local res = parse_mqtt(data, pos)
if res.expect_len > #data then
core.log.error("failed to parse mqtt request, expect len: ",
res.expect_len, " but got ", #data)
return 503
end

if res.protocol and res.protocol ~= conf.protocol_name then
Expand Down
73 changes: 72 additions & 1 deletion t/stream-plugin/mqtt-proxy.t
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,82 @@ mqtt client id: foo
=== TEST 13: hit route with empty client id
--- stream_enable
--- stream_request eval
"\x10\x0f\x00\x04\x4d\x51\x54\x54\x04\x02\x00\x3c\x00\x00"
"\x10\x0c\x00\x04\x4d\x51\x54\x54\x04\x02\x00\x3c\x00\x00"
--- stream_response
hello world
--- grep_error_log eval
qr/mqtt client id: \w+/
--- grep_error_log_out
--- no_error_log
[error]



=== TEST 14: MQTT 5
--- config
location /t {
content_by_lua_block {
local t = require("lib.test_admin").test
local code, body = t('/apisix/admin/stream_routes/1',
ngx.HTTP_PUT,
[[{
"remote_addr": "127.0.0.1",
"server_port": 1985,
"plugins": {
"mqtt-proxy": {
"protocol_name": "MQTT",
"protocol_level": 5
}
},
"upstream": {
"type": "roundrobin",
"nodes": [{
"host": "127.0.0.1",
"port": 1995,
"weight": 1
}]
}
}]]
)

if code >= 300 then
ngx.status = code
end
ngx.say(body)
}
}
--- request
GET /t
--- response_body
passed
--- no_error_log
[error]



=== TEST 15: hit route with empty property
--- stream_enable
--- stream_request eval
"\x10\x0d\x00\x04\x4d\x51\x54\x54\x05\x02\x00\x3c\x00\x00\x00"
--- stream_response
hello world
--- grep_error_log eval
qr/mqtt client id: \w+/
--- grep_error_log_out
--- no_error_log
[error]



=== TEST 16: hit route with property
--- stream_enable
--- stream_request eval
"\x10\x1b\x00\x04\x4d\x51\x54\x54\x05\x02\x00\x3c\x05\x11\x00\x00\x0e\x10\x00\x09\x63\x6c\x69\x6e\x74\x2d\x31\x31\x31"
--- stream_response
hello world
--- grep_error_log eval
qr/mqtt client id: \S+/
--- grep_error_log_out
mqtt client id: clint-111
--- no_error_log
[error]