From e8cc3236387ffb43886fe279a089b295b2d82757 Mon Sep 17 00:00:00 2001 From: Fu Zi Xiang Date: Sun, 5 Nov 2023 04:07:37 +0800 Subject: [PATCH] feat: binary protocol null --- integration_tests/client.zig | 49 +++++++------ src/constants.zig | 42 +++++++++++ src/helper.zig | 125 +++++++++++++++++++++++++++++++-- src/protocol/packet_reader.zig | 2 +- src/result.zig | 33 ++++----- 5 files changed, 203 insertions(+), 48 deletions(-) diff --git a/integration_tests/client.zig b/integration_tests/client.zig index f68e0d1..366c2b8 100644 --- a/integration_tests/client.zig +++ b/integration_tests/client.zig @@ -242,26 +242,33 @@ test "prepare execute - 2" { } } -// test "prepare execute with result" { -// var c = Client.init(test_config); -// defer c.deinit(); -// -// { -// const query = -// \\SELECT 1,2,3,4,5 -// ; -// const prep_res = try c.prepare(allocator, query); -// defer prep_res.deinit(allocator); -// const prep_ok = try expectOk(prep_res.value); -// _ = prep_ok; -// // const query_res = try c.execute(allocator, &prep_ok); -// // defer query_res.deinit(allocator); -// // const rows = (try expectRows(query_res.value)).iter(); -// // while (try rows.next(allocator)) |row| { -// // std.debug.print("row: {any}\n", .{row.value.raw}); -// // defer row.deinit(allocator); -// // } -// } -// } +test "prepare execute with result" { + var c = Client.init(test_config); + defer c.deinit(); + + { + const query = + \\SELECT null + ; + const prep_res = try c.prepare(allocator, query); + defer prep_res.deinit(allocator); + const prep_stmt = try expectOk(prep_res.value); + const query_res = try c.execute(allocator, &prep_stmt); + defer query_res.deinit(allocator); + const rows = (try expectRows(query_res.value)).iter(); + + const MyType = struct { + a: ?u8, + }; + const expected = MyType{ .a = null }; + + var dest: MyType = undefined; + while (try rows.next(allocator)) |row| { + defer row.deinit(allocator); + try row.scanStruct(&dest); + try std.testing.expectEqual(expected, dest); + } + } +} // SELECT CONCAT(?, ?) AS col1 diff --git a/src/constants.zig b/src/constants.zig index f45d1b8..8fc90d0 100644 --- a/src/constants.zig +++ b/src/constants.zig @@ -67,6 +67,48 @@ pub const COM_QUERY :u8 = 0x03; pub const COM_STMT_PREPARE: u8 = 0x16; pub const COM_STMT_EXECUTE: u8 = 0x17; +pub const BINARY_PROTOCOL_RESULTSET_ROW_HEADER: u8 = 0x00; + +// https://dev.mysql.com/doc/dev/mysql-server/latest/field__types_8h_source.html +pub const EnumFieldType = enum(u8) { + MYSQL_TYPE_DECIMAL, + MYSQL_TYPE_TINY, + MYSQL_TYPE_SHORT, + MYSQL_TYPE_LONG, + MYSQL_TYPE_FLOAT, + MYSQL_TYPE_DOUBLE, + MYSQL_TYPE_NULL, + MYSQL_TYPE_TIMESTAMP, + MYSQL_TYPE_LONGLONG, + MYSQL_TYPE_INT24, + MYSQL_TYPE_DATE, + MYSQL_TYPE_TIME, + MYSQL_TYPE_DATETIME, + MYSQL_TYPE_YEAR, + MYSQL_TYPE_NEWDATE, + MYSQL_TYPE_VARCHAR, + MYSQL_TYPE_BIT, + MYSQL_TYPE_TIMESTAMP2, + MYSQL_TYPE_DATETIME2, + MYSQL_TYPE_TIME2, + MYSQL_TYPE_TYPED_ARRAY, + + MYSQL_TYPE_INVALID = 243, + MYSQL_TYPE_BOOL = 244, + MYSQL_TYPE_JSON = 245, + MYSQL_TYPE_NEWDECIMAL = 246, + MYSQL_TYPE_ENUM = 247, + MYSQL_TYPE_SET = 248, + MYSQL_TYPE_TINY_BLOB = 249, + MYSQL_TYPE_MEDIUM_BLOB = 250, + MYSQL_TYPE_LONG_BLOB = 251, + MYSQL_TYPE_BLOB = 252, + MYSQL_TYPE_VAR_STRING = 253, + MYSQL_TYPE_STRING = 254, + MYSQL_TYPE_GEOMETRY = 255 +}; + + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase_utility.html pub const COM_QUIT: u8 = 0x01; pub const COM_INIT_DB: u8 = 0x02; diff --git a/src/helper.zig b/src/helper.zig index 6338fad..25cdad0 100644 --- a/src/helper.zig +++ b/src/helper.zig @@ -1,11 +1,14 @@ /// This file is to store convenience functions and methods for callers. const std = @import("std"); const constants = @import("./constants.zig"); +const EnumFieldType = constants.EnumFieldType; const result = @import("./result.zig"); const ResultSet = result.ResultSet; const TextResultRow = result.TextResultRow; const Options = result.BinaryResultRow.Options; -const PacketReader = @import("./protocol/packet_reader.zig").PacketReader; +const protocol = @import("./protocol.zig"); +const PacketReader = protocol.packet_reader.PacketReader; +const ColumnDefinition41 = protocol.column_definition.ColumnDefinition41; pub fn scanTextResultRow(raw: []const u8, dest: []?[]const u8) !void { var packet_reader = PacketReader.initFromPayload(raw); @@ -24,11 +27,119 @@ pub fn scanTextResultRow(raw: []const u8, dest: []?[]const u8) !void { } } -pub fn scanBinaryResultRow(comptime T: type, raw: []const u8, dest: *T, options: Options) !void { - _ = options; - _ = dest; - _ = raw; - @panic("not implemented"); +// dest is a pointer to a struct +pub fn scanBinResRowtoStruct(dest: anytype, raw: []const u8, col_defs: []ColumnDefinition41) void { + var reader = PacketReader.initFromPayload(raw); + const first = reader.readByte(); + std.debug.assert(first == constants.BINARY_PROTOCOL_RESULTSET_ROW_HEADER); + + // null bitmap + const null_bitmap_len = (col_defs.len + 7 + 2) / 8; + const null_bitmap = reader.readFixedRuntime(null_bitmap_len); + + const child_type = @typeInfo(@TypeOf(dest)).Pointer.child; + const struct_fields = @typeInfo(child_type).Struct.fields; + + std.debug.assert(struct_fields.len == col_defs.len); + + inline for (struct_fields, col_defs, 0..) |field, col_def, i| { + if (!binResIsNull(null_bitmap, i)) { + @field(dest, field.name) = binElemToValue(field.type, &col_def, &reader); + } else { + switch (@typeInfo(field.type)) { + .Optional => @field(dest, field.name) = null, + else => { + std.log.err("field {s} is not optional\n", .{field.name}); + unreachable; + }, + } + } + } + std.debug.assert(reader.finished()); +} + +inline fn binElemToValue(comptime T: type, col_def: *const ColumnDefinition41, reader: *PacketReader) T { + _ = reader; + const col_type: EnumFieldType = @enumFromInt(col_def.column_type); + return switch (col_type) { + else => { + std.log.err("unimplemented col_type: {any}\n", .{col_type}); + unreachable; + }, + // .MYSQL_TYPE_DECIMAL => {}, + // .MYSQL_TYPE_TINY => {}, + // .MYSQL_TYPE_SHORT => {}, + // .MYSQL_TYPE_LONG => {}, + // .MYSQL_TYPE_FLOAT => {}, + // .MYSQL_TYPE_DOUBLE => {}, + // .MYSQL_TYPE_NULL => {}, + // .MYSQL_TYPE_TIMESTAMP => {}, + // .MYSQL_TYPE_LONGLONG => {}, + // .MYSQL_TYPE_INT24 => {}, + // .MYSQL_TYPE_DATE => {}, + // .MYSQL_TYPE_TIME => {}, + // .MYSQL_TYPE_DATETIME => {}, + // .MYSQL_TYPE_YEAR => {}, + // .MYSQL_TYPE_NEWDATE => {}, + // .MYSQL_TYPE_VARCHAR => {}, + // .MYSQL_TYPE_BIT => {}, + // .MYSQL_TYPE_TIMESTAMP2 => {}, + // .MYSQL_TYPE_DATETIME2 => {}, + // .MYSQL_TYPE_TIME2 => {}, + // .MYSQL_TYPE_TYPED_ARRAY => {}, + + // .MYSQL_TYPE_INVALID => {}, + // .MYSQL_TYPE_BOOL => {}, + // .MYSQL_TYPE_JSON => {}, + // .MYSQL_TYPE_NEWDECIMAL => {}, + // .MYSQL_TYPE_ENUM => {}, + // .MYSQL_TYPE_SET => {}, + // .MYSQL_TYPE_TINY_BLOB => {}, + // .MYSQL_TYPE_MEDIUM_BLOB => {}, + // .MYSQL_TYPE_LONG_BLOB => {}, + // .MYSQL_TYPE_BLOB => {}, + // .MYSQL_TYPE_VAR_STRING => {}, + // .MYSQL_TYPE_STRING => {}, + // .MYSQL_TYPE_GEOMETRY => {}, + }; +} + +inline fn binResIsNull(null_bitmap: []const u8, col_idx: usize) bool { + // TODO: optimize: divmod + const byte_idx = (col_idx + 2) / 8; + const bit_idx = (col_idx + 2) % 8; + const byte = null_bitmap[byte_idx]; + return (byte & (1 << bit_idx)) > 0; +} + +test "binResIsNull" { + var tests = .{ + .{ + .null_bitmap = &.{0b00000100}, + .col_idx = 0, + .expected = true, + }, + .{ + .null_bitmap = &.{0b00000000}, + .col_idx = 0, + .expected = false, + }, + .{ + .null_bitmap = &.{ 0b00000000, 0b00000001 }, + .col_idx = 6, + .expected = true, + }, + .{ + .null_bitmap = &.{ 0b10000000, 0b00000000 }, + .col_idx = 5, + .expected = true, + }, + }; + + inline for (tests) |t| { + const actual = binResIsNull(t.null_bitmap, t.col_idx); + try std.testing.expectEqual(t.expected, actual); + } } pub fn ResultSetIter(comptime ResultRowType: type) type { @@ -55,7 +166,7 @@ pub fn ResultSetIter(comptime ResultRowType: type) type { new_row_ptr.* = row; } - const num_cols = iter.text_result_set.column_definitions.len; + const num_cols = iter.text_result_set.col_defs.len; var rows = try allocator.alloc([]?[]const u8, row_acc.items.len); var elems = try allocator.alloc(?[]const u8, row_acc.items.len * num_cols); for (row_acc.items, 0..) |row, i| { diff --git a/src/protocol/packet_reader.zig b/src/protocol/packet_reader.zig index 60423d9..1d67c16 100644 --- a/src/protocol/packet_reader.zig +++ b/src/protocol/packet_reader.zig @@ -32,7 +32,7 @@ pub const PacketReader = struct { return bytes; } - fn readFixedRuntime(packet_reader: *PacketReader, n: usize) []const u8 { + pub fn readFixedRuntime(packet_reader: *PacketReader, n: usize) []const u8 { const bytes = packet_reader.payload[packet_reader.pos..][0..n]; packet_reader.pos += n; return bytes; diff --git a/src/result.zig b/src/result.zig index 0c696b4..f886d08 100644 --- a/src/result.zig +++ b/src/result.zig @@ -34,18 +34,18 @@ pub fn QueryResult(comptime ResultRowType: type) type { pub fn ResultSet(comptime ResultRowType: type) type { return struct { conn: *Conn, - column_packets: []Packet, - column_definitions: []ColumnDefinition41, + col_packets: []Packet, + col_defs: []ColumnDefinition41, pub fn init(allocator: std.mem.Allocator, conn: *Conn, column_count: u64) !ResultSet(ResultRowType) { - var t: ResultSet(ResultRowType) = .{ .conn = conn, .column_packets = &.{}, .column_definitions = &.{} }; + var t: ResultSet(ResultRowType) = .{ .conn = conn, .col_packets = &.{}, .col_defs = &.{} }; errdefer t.deinit(allocator); - t.column_packets = try allocator.alloc(Packet, column_count); - @memset(t.column_packets, Packet.safe_deinit()); - t.column_definitions = try allocator.alloc(ColumnDefinition41, column_count); + t.col_packets = try allocator.alloc(Packet, column_count); + @memset(t.col_packets, Packet.safe_deinit()); + t.col_defs = try allocator.alloc(ColumnDefinition41, column_count); - for (t.column_packets, t.column_definitions) |*pac, *def| { + for (t.col_packets, t.col_defs) |*pac, *def| { pac.* = try conn.readPacket(allocator); def.* = ColumnDefinition41.initFromPacket(pac); } @@ -57,11 +57,11 @@ pub fn ResultSet(comptime ResultRowType: type) type { } fn deinit(t: *const ResultSet(ResultRowType), allocator: std.mem.Allocator) void { - for (t.column_packets) |packet| { + for (t.col_packets) |packet| { packet.deinit(allocator); } - allocator.free(t.column_packets); - allocator.free(t.column_definitions); + allocator.free(t.col_packets); + allocator.free(t.col_defs); } pub fn readRow(t: *const ResultSet(ResultRowType), allocator: std.mem.Allocator) !ResultRowType { @@ -95,7 +95,7 @@ pub const TextResultRow = struct { }, pub fn scan(t: *const TextResultRow, dest: []?[]const u8) !void { - std.debug.assert(dest.len == t.result_set.column_definitions.len); + std.debug.assert(dest.len == t.result_set.col_defs.len); switch (t.value) { .err => |err| return err.asError(), .eof => |eof| return eof.asError(), @@ -119,11 +119,11 @@ pub const BinaryResultRow = struct { const Options = struct {}; - pub fn scanStruct(comptime T: type, t: *const BinaryResultRow, dest: ?*const T, options: Options) !void { - switch (t.value) { + pub fn scanStruct(b: *const BinaryResultRow, dest: anytype) !void { + switch (b.value) { .err => |err| return err.asError(), .eof => |eof| return eof.asError(), - .raw => try helper.scanTextBinaryRow(T, t.value.raw, dest, options), + .raw => |raw| helper.scanBinResRowtoStruct(dest, raw, b.result_set.col_defs), } } @@ -140,11 +140,6 @@ pub const PrepareResult = struct { }, pub fn deinit(p: *const PrepareResult, allocator: std.mem.Allocator) void { - // for (p.value.ok.packets, 0..) |pp, i| { - // std.debug.print("prepare result deinit: packet, {any} ptr: {any}\n", .{ i, @intFromPtr(pp.payload.ptr) }); - // pp.deinit(allocator); - // } - p.packet.deinit(allocator); switch (p.value) { .ok => |prep_stmt| prep_stmt.deinit(allocator),