From 0d4ea5ce689c650dfc5239101d3ed4340ed49a69 Mon Sep 17 00:00:00 2001 From: Zack Fu Zi Xiang Date: Mon, 28 Oct 2024 21:27:23 +0800 Subject: [PATCH] fix: runtime parameters --- integration_tests/conn.zig | 62 +++++++++-------- src/conversion.zig | 18 ++--- src/protocol/packet.zig | 2 +- src/protocol/packet_writer.zig | 2 +- src/protocol/prepared_statements.zig | 100 ++++++++++----------------- src/result.zig | 8 +-- 6 files changed, 86 insertions(+), 106 deletions(-) diff --git a/integration_tests/conn.zig b/integration_tests/conn.zig index 2725bea..e323b35 100644 --- a/integration_tests/conn.zig +++ b/integration_tests/conn.zig @@ -73,34 +73,34 @@ test "query text protocol" { } } } - // { // Iterating over rows, collecting elements into []const ?[]const u8 - // const query_res = try c.queryRows("SELECT 3, 4, null, 6, 7"); - // const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); - // const rows_iter: ResultRowIter(TextResultRow) = rows.iter(); - // while (try rows_iter.next()) |row| { - // const elems: TextElems = try row.textElems(allocator); - // defer elems.deinit(allocator); - - // try std.testing.expectEqualDeep( - // @as([]const ?[]const u8, &.{ "3", "4", null, "6", "7" }), - // elems.elems, - // ); - // } - // } - // { // Iterating over rows, collecting elements into []const []const ?[]const u8 - // const query_res = try c.queryRows("SELECT 8,9 UNION ALL SELECT 10,11"); - // const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); - // const table = try rows.tableTexts(allocator); - // defer table.deinit(allocator); - - // try std.testing.expectEqualDeep( - // @as([]const []const ?[]const u8, &.{ - // &.{ "8", "9" }, - // &.{ "10", "11" }, - // }), - // table.table, - // ); - // } + { // Iterating over rows, collecting elements into []const ?[]const u8 + const query_res = try c.queryRows("SELECT 3, 4, null, 6, 7"); + const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); + const rows_iter: ResultRowIter(TextResultRow) = rows.iter(); + while (try rows_iter.next()) |row| { + const elems: TextElems = try row.textElems(allocator); + defer elems.deinit(allocator); + + try std.testing.expectEqualDeep( + @as([]const ?[]const u8, &.{ "3", "4", null, "6", "7" }), + elems.elems, + ); + } + } + { // Iterating over rows, collecting elements into []const []const ?[]const u8 + const query_res = try c.queryRows("SELECT 8,9 UNION ALL SELECT 10,11"); + const rows: ResultSet(TextResultRow) = try query_res.expect(.rows); + const table = try rows.tableTexts(allocator); + defer table.deinit(allocator); + + try std.testing.expectEqualDeep( + @as([]const []const ?[]const u8, &.{ + &.{ "8", "9" }, + &.{ "10", "11" }, + }), + table.table, + ); + } } test "prepare check" { @@ -696,7 +696,7 @@ test "select concat with params" { const prep_res = try c.prepare(allocator, "SELECT CONCAT(?, ?) AS col1"); defer prep_res.deinit(allocator); const prep_stmt = try prep_res.expect(.stmt); - const res = try c.executeRows(&prep_stmt, .{ "hello", "world" }); + const res = try c.executeRows(&prep_stmt, .{ runtimeValue("hello"), runtimeValue("world") }); const rows: ResultSet(BinaryResultRow) = try res.expect(.rows); const rows_iter = rows.iter(); @@ -707,3 +707,7 @@ test "select concat with params" { try std.testing.expectEqualDeep(expected, structs.struct_list.items); } } + +fn runtimeValue(a: anytype) @TypeOf(a) { + return a; +} diff --git a/src/conversion.zig b/src/conversion.zig index df2b04f..7015be0 100644 --- a/src/conversion.zig +++ b/src/conversion.zig @@ -17,8 +17,8 @@ pub fn scanBinResultRow(dest: anytype, packet: *const Packet, col_defs: []const const null_bitmap_len = (col_defs.len + 7 + 2) / 8; const null_bitmap = reader.readRefRuntime(null_bitmap_len); - const child_type = @typeInfo(@TypeOf(dest)).Pointer.child; - const struct_fields = @typeInfo(child_type).Struct.fields; + const child_type = @typeInfo(@TypeOf(dest)).pointer.child; + const struct_fields = @typeInfo(child_type).@"struct".fields; if (struct_fields.len != col_defs.len) { std.log.err("received {d} columns from mysql, but given {d} fields for struct", .{ struct_fields.len, col_defs.len }); @@ -30,11 +30,11 @@ pub fn scanBinResultRow(dest: anytype, packet: *const Packet, col_defs: []const const isNull = binResIsNull(null_bitmap, i); switch (field_info) { - .Optional => { + .optional => { if (isNull) { @field(dest, field.name) = null; } else { - @field(dest, field.name) = try binElemToValue(field_info.Optional.child, field.name, &col_def, &reader, allocator); + @field(dest, field.name) = try binElemToValue(field_info.optional.child, field.name, &col_def, &reader, allocator); } }, else => { @@ -146,9 +146,9 @@ inline fn binElemToValue( } switch (field_info) { - .Pointer => |pointer| { + .pointer => |pointer| { switch (@typeInfo(pointer.child)) { - .Int => |int| { + .int => |int| { if (int.bits == 8) { switch (col_type) { .MYSQL_TYPE_STRING, @@ -181,7 +181,7 @@ inline fn binElemToValue( else => {}, } }, - .Enum => |e| { + .@"enum" => |e| { switch (col_type) { .MYSQL_TYPE_STRING, .MYSQL_TYPE_VARCHAR, @@ -211,7 +211,7 @@ inline fn binElemToValue( else => {}, } }, - .Int => |int| { + .int => |int| { switch (int.signedness) { .unsigned => { switch (col_type) { @@ -249,7 +249,7 @@ inline fn binElemToValue( }, } }, - .Float => |float| { + .float => |float| { if (float.bits >= 64) { switch (col_type) { .MYSQL_TYPE_DOUBLE => return @as(f64, @bitCast(reader.readInt(u64))), diff --git a/src/protocol/packet.zig b/src/protocol/packet.zig index dcb5f8e..de433bf 100644 --- a/src/protocol/packet.zig +++ b/src/protocol/packet.zig @@ -58,7 +58,7 @@ pub const PayloadReader = struct { } pub fn readInt(p: *PayloadReader, Int: type) Int { - const bytes = p.readRefComptime(@divExact(@typeInfo(Int).Int.bits, 8)); + const bytes = p.readRefComptime(@divExact(@typeInfo(Int).int.bits, 8)); return std.mem.readInt(Int, bytes, .little); } diff --git a/src/protocol/packet_writer.zig b/src/protocol/packet_writer.zig index b3f3f20..8bc6417 100644 --- a/src/protocol/packet_writer.zig +++ b/src/protocol/packet_writer.zig @@ -101,7 +101,7 @@ pub const PacketWriter = struct { } pub fn writeInt(p: *PacketWriter, comptime Int: type, int: Int) !void { - const bytes = try p.advanceComptime(@divExact(@typeInfo(Int).Int.bits, 8)); + const bytes = try p.advanceComptime(@divExact(@typeInfo(Int).int.bits, 8)); std.mem.writeInt(Int, bytes, int, .little); } diff --git a/src/protocol/prepared_statements.zig b/src/protocol/prepared_statements.zig index aafce11..aa95364 100644 --- a/src/protocol/prepared_statements.zig +++ b/src/protocol/prepared_statements.zig @@ -79,10 +79,16 @@ pub const ExecuteRequest = struct { try writer.writeInt(u8, e.flags); try writer.writeInt(u32, e.iteration_count); + const col_defs = e.prep_stmt.params; + if (params.len != col_defs.len) { + std.log.err("expected column count: {d}, but got {d}", .{ col_defs.len, params.len }); + return error.ParamsCountNotMatch; + } + // const has_attributes_to_write = (e.capabilities & constants.CLIENT_QUERY_ATTRIBUTES > 0) and e.attributes.len > 0; - const param_count = e.prep_stmt.prep_ok.num_params; - if (param_count > 0 + // const param_count = params.len; + if (params.len > 0 //or has_attributes_to_write ) { // if (has_attributes_to_write) { @@ -98,12 +104,6 @@ pub const ExecuteRequest = struct { try writeNullBitmap(params, writer); - const col_defs = e.prep_stmt.params; - if (params.len != col_defs.len) { - std.log.err("expected column count: {d}, but got {d}", .{ col_defs.len, params.len }); - return error.ParamsCountNotMatch; - } - // If a statement is re-executed without changing the params types, // the types do not need to be sent to the server again. // send type to server (0 / 1) @@ -111,20 +111,14 @@ pub const ExecuteRequest = struct { //if (e.new_params_bind_flag > 0) { comptime var enum_field_types: [params.len]constants.EnumFieldType = undefined; inline for (params, &enum_field_types) |param, *enum_field_type| { - enum_field_type.* = comptime enumFieldTypeFromParam(param); + enum_field_type.* = comptime enumFieldTypeFromParam(@TypeOf(param)); } - inline for (enum_field_types, params) |enum_field_type, param| { + inline for (params, enum_field_types) |param, enum_field_type| { try writer.writeInt(u8, @intFromEnum(enum_field_type)); - const sign_flag = comptime switch (@typeInfo(@TypeOf(param))) { - .ComptimeInt => switch (enum_field_type) { - .MYSQL_TYPE_TINY => if (param > maxInt(i8)) 0x80 else 0, - .MYSQL_TYPE_SHORT => if (param > maxInt(i16)) 0x80 else 0, - .MYSQL_TYPE_LONG => if (param > maxInt(i32)) 0x80 else 0, - .MYSQL_TYPE_LONGLONG => if (param > maxInt(i64)) 0x80 else 0, - else => 0, - }, - .Int => |int| if (int.signedness == .unsigned) 0x80 else 0, + const sign_flag = switch (@typeInfo(@TypeOf(param))) { + .comptime_int => if (param > maxInt(i64)) 0x80 else 0, + .int => |int| if (int.signedness == .unsigned) 0x80 else 0, else => 0, }; try writer.writeInt(u8, sign_flag); @@ -146,10 +140,11 @@ pub const ExecuteRequest = struct { // TODO: Write params and attr as binary values // Write params as binary values inline for (params, enum_field_types) |param, enum_field_type| { - // if (enum_field_type == constants.EnumFieldType.MYSQL_TYPE_NULL) { - // continue; - // } - try writeParamAsFieldType(writer, enum_field_type, param); + if (isNull(param)) { + try writeParamAsFieldType(writer, constants.EnumFieldType.MYSQL_TYPE_NULL, param); + } else { + try writeParamAsFieldType(writer, enum_field_type, param); + } } // if (has_attributes_to_write) { @@ -161,22 +156,15 @@ pub const ExecuteRequest = struct { } }; -fn enumFieldTypeFromParam(param: anytype) constants.EnumFieldType { - const Param = @TypeOf(param); +fn enumFieldTypeFromParam(Param: type) constants.EnumFieldType { const param_type_info = @typeInfo(Param); return switch (Param) { DateTime => constants.EnumFieldType.MYSQL_TYPE_DATETIME, Duration => constants.EnumFieldType.MYSQL_TYPE_TIME, else => switch (param_type_info) { - .Null => return constants.EnumFieldType.MYSQL_TYPE_NULL, - .Optional => { - if (param) |p| { - return enumFieldTypeFromParam(p); - } else { - return constants.EnumFieldType.MYSQL_TYPE_NULL; - } - }, - .Int => |int| { + .null => return constants.EnumFieldType.MYSQL_TYPE_NULL, + .optional => |o| return enumFieldTypeFromParam(o.child), + .int => |int| { if (int.bits <= 8) { return constants.EnumFieldType.MYSQL_TYPE_TINY; } else if (int.bits <= 16) { @@ -187,30 +175,18 @@ fn enumFieldTypeFromParam(param: anytype) constants.EnumFieldType { return constants.EnumFieldType.MYSQL_TYPE_LONGLONG; } }, - .ComptimeInt => { - if (std.math.minInt(i8) <= param and param <= std.math.maxInt(u8)) { - return constants.EnumFieldType.MYSQL_TYPE_TINY; - } else if (std.math.minInt(i16) <= param and param <= std.math.maxInt(u16)) { - return constants.EnumFieldType.MYSQL_TYPE_SHORT; - } else if (std.math.minInt(i32) <= param and param <= std.math.maxInt(u32)) { - return constants.EnumFieldType.MYSQL_TYPE_LONG; - } else if (std.math.minInt(i64) <= param and param <= std.math.maxInt(u64)) { - return constants.EnumFieldType.MYSQL_TYPE_LONGLONG; - } else { - @compileLog("hello"); - } - }, - .Float => |float| { + .comptime_int => return constants.EnumFieldType.MYSQL_TYPE_LONGLONG, + .float => |float| { if (float.bits <= 32) { return constants.EnumFieldType.MYSQL_TYPE_FLOAT; } else if (float.bits <= 64) { return constants.EnumFieldType.MYSQL_TYPE_DOUBLE; } }, - .ComptimeFloat => return constants.EnumFieldType.MYSQL_TYPE_DOUBLE, // Safer to assume double - .Array => |array| { + .comptime_float => return constants.EnumFieldType.MYSQL_TYPE_DOUBLE, // Safer to assume double + .array => |array| { switch (@typeInfo(array.child)) { - .Int => |int| { + .int => |int| { if (int.bits == 8) { return constants.EnumFieldType.MYSQL_TYPE_STRING; } @@ -218,14 +194,14 @@ fn enumFieldTypeFromParam(param: anytype) constants.EnumFieldType { else => {}, } }, - .Enum => return constants.EnumFieldType.MYSQL_TYPE_STRING, - .Pointer => |pointer| { + .@"enum" => return constants.EnumFieldType.MYSQL_TYPE_STRING, + .pointer => |pointer| { switch (pointer.size) { - .One => return enumFieldTypeFromParam(param.*), + .One => return enumFieldTypeFromParam(pointer.child), else => {}, } switch (@typeInfo(pointer.child)) { - .Int => |int| { + .int => |int| { if (int.bits == 8) { switch (pointer.size) { .Slice, .C, .Many => return constants.EnumFieldType.MYSQL_TYPE_STRING, @@ -237,7 +213,7 @@ fn enumFieldTypeFromParam(param: anytype) constants.EnumFieldType { } }, else => { - @compileLog(param); + @compileLog(Param); @compileError("unsupported type"); }, }, @@ -252,7 +228,7 @@ fn writeParamAsFieldType( param: anytype, ) !void { return switch (@typeInfo(@TypeOf(param))) { - .Optional => if (param) |p| { + .optional => if (param) |p| { return try writeParamAsFieldType(writer, enum_field_type, p); } else { return; @@ -279,13 +255,13 @@ fn writeParamAsFieldType( fn stringCast(param: anytype) []const u8 { switch (@typeInfo(@TypeOf(param))) { - .Pointer => |pointer| { + .pointer => |pointer| { switch (pointer.size) { .C, .Many => return std.mem.span(param), else => {}, } }, - .Enum => return @tagName(param), + .@"enum" => return @tagName(param), else => {}, } @@ -471,9 +447,9 @@ pub fn nullBitsParamsAttrs(params: anytype, start: usize, attrs: []const BinaryP } inline fn isNull(param: anytype) bool { - return switch (@typeInfo(@TypeOf(param))) { - inline .Optional => if (param) |p| isNull(p) else true, - inline .Null => true, + return comptime switch (@typeInfo(@TypeOf(param))) { + inline .optional => if (param) |p| isNull(p) else true, + inline .null => true, inline else => false, }; } diff --git a/src/result.zig b/src/result.zig index d47c7af..74b5e89 100644 --- a/src/result.zig +++ b/src/result.zig @@ -225,7 +225,7 @@ pub const BinaryResultRow = struct { } fn structFreeDynamic(s: anytype, allocator: std.mem.Allocator) void { - const s_ti = @typeInfo(@TypeOf(s)).Struct; + const s_ti = @typeInfo(@TypeOf(s)).@"struct"; inline for (s_ti.fields) |field| { structFreeStr(field.type, @field(s, field.name), allocator); } @@ -233,13 +233,13 @@ pub const BinaryResultRow = struct { fn structFreeStr(comptime StructField: type, value: StructField, allocator: std.mem.Allocator) void { switch (@typeInfo(StructField)) { - .Pointer => |p| switch (@typeInfo(p.child)) { - .Int => |int| if (int.bits == 8) { + .pointer => |p| switch (@typeInfo(p.child)) { + .int => |int| if (int.bits == 8) { allocator.free(value); }, else => {}, }, - .Optional => |o| if (value) |some| structFreeStr(o.child, some, allocator), + .optional => |o| if (value) |some| structFreeStr(o.child, some, allocator), else => {}, } }