Skip to content

Commit

Permalink
fix: runtime parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Oct 28, 2024
1 parent 575686d commit 0d4ea5c
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 106 deletions.
62 changes: 33 additions & 29 deletions integration_tests/conn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -73,34 +73,34 @@ test "query text protocol" {
}
}
}
// { // Iterating over rows, collecting elements into []const ?[]const u8
// const query_res = try c.queryRows("SELECT 3, 4, null, 6, 7");
// const rows: ResultSet(TextResultRow) = try query_res.expect(.rows);
// const rows_iter: ResultRowIter(TextResultRow) = rows.iter();
// while (try rows_iter.next()) |row| {
// const elems: TextElems = try row.textElems(allocator);
// defer elems.deinit(allocator);

// try std.testing.expectEqualDeep(
// @as([]const ?[]const u8, &.{ "3", "4", null, "6", "7" }),
// elems.elems,
// );
// }
// }
// { // Iterating over rows, collecting elements into []const []const ?[]const u8
// const query_res = try c.queryRows("SELECT 8,9 UNION ALL SELECT 10,11");
// const rows: ResultSet(TextResultRow) = try query_res.expect(.rows);
// const table = try rows.tableTexts(allocator);
// defer table.deinit(allocator);

// try std.testing.expectEqualDeep(
// @as([]const []const ?[]const u8, &.{
// &.{ "8", "9" },
// &.{ "10", "11" },
// }),
// table.table,
// );
// }
{ // Iterating over rows, collecting elements into []const ?[]const u8
const query_res = try c.queryRows("SELECT 3, 4, null, 6, 7");
const rows: ResultSet(TextResultRow) = try query_res.expect(.rows);
const rows_iter: ResultRowIter(TextResultRow) = rows.iter();
while (try rows_iter.next()) |row| {
const elems: TextElems = try row.textElems(allocator);
defer elems.deinit(allocator);

try std.testing.expectEqualDeep(
@as([]const ?[]const u8, &.{ "3", "4", null, "6", "7" }),
elems.elems,
);
}
}
{ // Iterating over rows, collecting elements into []const []const ?[]const u8
const query_res = try c.queryRows("SELECT 8,9 UNION ALL SELECT 10,11");
const rows: ResultSet(TextResultRow) = try query_res.expect(.rows);
const table = try rows.tableTexts(allocator);
defer table.deinit(allocator);

try std.testing.expectEqualDeep(
@as([]const []const ?[]const u8, &.{
&.{ "8", "9" },
&.{ "10", "11" },
}),
table.table,
);
}
}

test "prepare check" {
Expand Down Expand Up @@ -696,7 +696,7 @@ test "select concat with params" {
const prep_res = try c.prepare(allocator, "SELECT CONCAT(?, ?) AS col1");
defer prep_res.deinit(allocator);
const prep_stmt = try prep_res.expect(.stmt);
const res = try c.executeRows(&prep_stmt, .{ "hello", "world" });
const res = try c.executeRows(&prep_stmt, .{ runtimeValue("hello"), runtimeValue("world") });
const rows: ResultSet(BinaryResultRow) = try res.expect(.rows);
const rows_iter = rows.iter();

Expand All @@ -707,3 +707,7 @@ test "select concat with params" {
try std.testing.expectEqualDeep(expected, structs.struct_list.items);
}
}

fn runtimeValue(a: anytype) @TypeOf(a) {
return a;
}
18 changes: 9 additions & 9 deletions src/conversion.zig
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pub fn scanBinResultRow(dest: anytype, packet: *const Packet, col_defs: []const
const null_bitmap_len = (col_defs.len + 7 + 2) / 8;
const null_bitmap = reader.readRefRuntime(null_bitmap_len);

const child_type = @typeInfo(@TypeOf(dest)).Pointer.child;
const struct_fields = @typeInfo(child_type).Struct.fields;
const child_type = @typeInfo(@TypeOf(dest)).pointer.child;
const struct_fields = @typeInfo(child_type).@"struct".fields;

if (struct_fields.len != col_defs.len) {
std.log.err("received {d} columns from mysql, but given {d} fields for struct", .{ struct_fields.len, col_defs.len });
Expand All @@ -30,11 +30,11 @@ pub fn scanBinResultRow(dest: anytype, packet: *const Packet, col_defs: []const
const isNull = binResIsNull(null_bitmap, i);

switch (field_info) {
.Optional => {
.optional => {
if (isNull) {
@field(dest, field.name) = null;
} else {
@field(dest, field.name) = try binElemToValue(field_info.Optional.child, field.name, &col_def, &reader, allocator);
@field(dest, field.name) = try binElemToValue(field_info.optional.child, field.name, &col_def, &reader, allocator);
}
},
else => {
Expand Down Expand Up @@ -146,9 +146,9 @@ inline fn binElemToValue(
}

switch (field_info) {
.Pointer => |pointer| {
.pointer => |pointer| {
switch (@typeInfo(pointer.child)) {
.Int => |int| {
.int => |int| {
if (int.bits == 8) {
switch (col_type) {
.MYSQL_TYPE_STRING,
Expand Down Expand Up @@ -181,7 +181,7 @@ inline fn binElemToValue(
else => {},
}
},
.Enum => |e| {
.@"enum" => |e| {
switch (col_type) {
.MYSQL_TYPE_STRING,
.MYSQL_TYPE_VARCHAR,
Expand Down Expand Up @@ -211,7 +211,7 @@ inline fn binElemToValue(
else => {},
}
},
.Int => |int| {
.int => |int| {
switch (int.signedness) {
.unsigned => {
switch (col_type) {
Expand Down Expand Up @@ -249,7 +249,7 @@ inline fn binElemToValue(
},
}
},
.Float => |float| {
.float => |float| {
if (float.bits >= 64) {
switch (col_type) {
.MYSQL_TYPE_DOUBLE => return @as(f64, @bitCast(reader.readInt(u64))),
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/packet.zig
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub const PayloadReader = struct {
}

pub fn readInt(p: *PayloadReader, Int: type) Int {
const bytes = p.readRefComptime(@divExact(@typeInfo(Int).Int.bits, 8));
const bytes = p.readRefComptime(@divExact(@typeInfo(Int).int.bits, 8));
return std.mem.readInt(Int, bytes, .little);
}

Expand Down
2 changes: 1 addition & 1 deletion src/protocol/packet_writer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub const PacketWriter = struct {
}

pub fn writeInt(p: *PacketWriter, comptime Int: type, int: Int) !void {
const bytes = try p.advanceComptime(@divExact(@typeInfo(Int).Int.bits, 8));
const bytes = try p.advanceComptime(@divExact(@typeInfo(Int).int.bits, 8));
std.mem.writeInt(Int, bytes, int, .little);
}

Expand Down
100 changes: 38 additions & 62 deletions src/protocol/prepared_statements.zig
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,16 @@ pub const ExecuteRequest = struct {
try writer.writeInt(u8, e.flags);
try writer.writeInt(u32, e.iteration_count);

const col_defs = e.prep_stmt.params;
if (params.len != col_defs.len) {
std.log.err("expected column count: {d}, but got {d}", .{ col_defs.len, params.len });
return error.ParamsCountNotMatch;
}

// const has_attributes_to_write = (e.capabilities & constants.CLIENT_QUERY_ATTRIBUTES > 0) and e.attributes.len > 0;

const param_count = e.prep_stmt.prep_ok.num_params;
if (param_count > 0
// const param_count = params.len;
if (params.len > 0
//or has_attributes_to_write
) {
// if (has_attributes_to_write) {
Expand All @@ -98,33 +104,21 @@ pub const ExecuteRequest = struct {

try writeNullBitmap(params, writer);

const col_defs = e.prep_stmt.params;
if (params.len != col_defs.len) {
std.log.err("expected column count: {d}, but got {d}", .{ col_defs.len, params.len });
return error.ParamsCountNotMatch;
}

// If a statement is re-executed without changing the params types,
// the types do not need to be sent to the server again.
// send type to server (0 / 1)
try writer.writeLengthEncodedInteger(e.new_params_bind_flag);
//if (e.new_params_bind_flag > 0) {
comptime var enum_field_types: [params.len]constants.EnumFieldType = undefined;
inline for (params, &enum_field_types) |param, *enum_field_type| {
enum_field_type.* = comptime enumFieldTypeFromParam(param);
enum_field_type.* = comptime enumFieldTypeFromParam(@TypeOf(param));
}

inline for (enum_field_types, params) |enum_field_type, param| {
inline for (params, enum_field_types) |param, enum_field_type| {
try writer.writeInt(u8, @intFromEnum(enum_field_type));
const sign_flag = comptime switch (@typeInfo(@TypeOf(param))) {
.ComptimeInt => switch (enum_field_type) {
.MYSQL_TYPE_TINY => if (param > maxInt(i8)) 0x80 else 0,
.MYSQL_TYPE_SHORT => if (param > maxInt(i16)) 0x80 else 0,
.MYSQL_TYPE_LONG => if (param > maxInt(i32)) 0x80 else 0,
.MYSQL_TYPE_LONGLONG => if (param > maxInt(i64)) 0x80 else 0,
else => 0,
},
.Int => |int| if (int.signedness == .unsigned) 0x80 else 0,
const sign_flag = switch (@typeInfo(@TypeOf(param))) {
.comptime_int => if (param > maxInt(i64)) 0x80 else 0,
.int => |int| if (int.signedness == .unsigned) 0x80 else 0,
else => 0,
};
try writer.writeInt(u8, sign_flag);
Expand All @@ -146,10 +140,11 @@ pub const ExecuteRequest = struct {
// TODO: Write params and attr as binary values
// Write params as binary values
inline for (params, enum_field_types) |param, enum_field_type| {
// if (enum_field_type == constants.EnumFieldType.MYSQL_TYPE_NULL) {
// continue;
// }
try writeParamAsFieldType(writer, enum_field_type, param);
if (isNull(param)) {
try writeParamAsFieldType(writer, constants.EnumFieldType.MYSQL_TYPE_NULL, param);
} else {
try writeParamAsFieldType(writer, enum_field_type, param);
}
}

// if (has_attributes_to_write) {
Expand All @@ -161,22 +156,15 @@ pub const ExecuteRequest = struct {
}
};

fn enumFieldTypeFromParam(param: anytype) constants.EnumFieldType {
const Param = @TypeOf(param);
fn enumFieldTypeFromParam(Param: type) constants.EnumFieldType {
const param_type_info = @typeInfo(Param);
return switch (Param) {
DateTime => constants.EnumFieldType.MYSQL_TYPE_DATETIME,
Duration => constants.EnumFieldType.MYSQL_TYPE_TIME,
else => switch (param_type_info) {
.Null => return constants.EnumFieldType.MYSQL_TYPE_NULL,
.Optional => {
if (param) |p| {
return enumFieldTypeFromParam(p);
} else {
return constants.EnumFieldType.MYSQL_TYPE_NULL;
}
},
.Int => |int| {
.null => return constants.EnumFieldType.MYSQL_TYPE_NULL,
.optional => |o| return enumFieldTypeFromParam(o.child),
.int => |int| {
if (int.bits <= 8) {
return constants.EnumFieldType.MYSQL_TYPE_TINY;
} else if (int.bits <= 16) {
Expand All @@ -187,45 +175,33 @@ fn enumFieldTypeFromParam(param: anytype) constants.EnumFieldType {
return constants.EnumFieldType.MYSQL_TYPE_LONGLONG;
}
},
.ComptimeInt => {
if (std.math.minInt(i8) <= param and param <= std.math.maxInt(u8)) {
return constants.EnumFieldType.MYSQL_TYPE_TINY;
} else if (std.math.minInt(i16) <= param and param <= std.math.maxInt(u16)) {
return constants.EnumFieldType.MYSQL_TYPE_SHORT;
} else if (std.math.minInt(i32) <= param and param <= std.math.maxInt(u32)) {
return constants.EnumFieldType.MYSQL_TYPE_LONG;
} else if (std.math.minInt(i64) <= param and param <= std.math.maxInt(u64)) {
return constants.EnumFieldType.MYSQL_TYPE_LONGLONG;
} else {
@compileLog("hello");
}
},
.Float => |float| {
.comptime_int => return constants.EnumFieldType.MYSQL_TYPE_LONGLONG,
.float => |float| {
if (float.bits <= 32) {
return constants.EnumFieldType.MYSQL_TYPE_FLOAT;
} else if (float.bits <= 64) {
return constants.EnumFieldType.MYSQL_TYPE_DOUBLE;
}
},
.ComptimeFloat => return constants.EnumFieldType.MYSQL_TYPE_DOUBLE, // Safer to assume double
.Array => |array| {
.comptime_float => return constants.EnumFieldType.MYSQL_TYPE_DOUBLE, // Safer to assume double
.array => |array| {
switch (@typeInfo(array.child)) {
.Int => |int| {
.int => |int| {
if (int.bits == 8) {
return constants.EnumFieldType.MYSQL_TYPE_STRING;
}
},
else => {},
}
},
.Enum => return constants.EnumFieldType.MYSQL_TYPE_STRING,
.Pointer => |pointer| {
.@"enum" => return constants.EnumFieldType.MYSQL_TYPE_STRING,
.pointer => |pointer| {
switch (pointer.size) {
.One => return enumFieldTypeFromParam(param.*),
.One => return enumFieldTypeFromParam(pointer.child),
else => {},
}
switch (@typeInfo(pointer.child)) {
.Int => |int| {
.int => |int| {
if (int.bits == 8) {
switch (pointer.size) {
.Slice, .C, .Many => return constants.EnumFieldType.MYSQL_TYPE_STRING,
Expand All @@ -237,7 +213,7 @@ fn enumFieldTypeFromParam(param: anytype) constants.EnumFieldType {
}
},
else => {
@compileLog(param);
@compileLog(Param);
@compileError("unsupported type");
},
},
Expand All @@ -252,7 +228,7 @@ fn writeParamAsFieldType(
param: anytype,
) !void {
return switch (@typeInfo(@TypeOf(param))) {
.Optional => if (param) |p| {
.optional => if (param) |p| {
return try writeParamAsFieldType(writer, enum_field_type, p);
} else {
return;
Expand All @@ -279,13 +255,13 @@ fn writeParamAsFieldType(

fn stringCast(param: anytype) []const u8 {
switch (@typeInfo(@TypeOf(param))) {
.Pointer => |pointer| {
.pointer => |pointer| {
switch (pointer.size) {
.C, .Many => return std.mem.span(param),
else => {},
}
},
.Enum => return @tagName(param),
.@"enum" => return @tagName(param),
else => {},
}

Expand Down Expand Up @@ -471,9 +447,9 @@ pub fn nullBitsParamsAttrs(params: anytype, start: usize, attrs: []const BinaryP
}

inline fn isNull(param: anytype) bool {
return switch (@typeInfo(@TypeOf(param))) {
inline .Optional => if (param) |p| isNull(p) else true,
inline .Null => true,
return comptime switch (@typeInfo(@TypeOf(param))) {
inline .optional => if (param) |p| isNull(p) else true,
inline .null => true,
inline else => false,
};
}
Expand Down
8 changes: 4 additions & 4 deletions src/result.zig
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,21 @@ pub const BinaryResultRow = struct {
}

fn structFreeDynamic(s: anytype, allocator: std.mem.Allocator) void {
const s_ti = @typeInfo(@TypeOf(s)).Struct;
const s_ti = @typeInfo(@TypeOf(s)).@"struct";
inline for (s_ti.fields) |field| {
structFreeStr(field.type, @field(s, field.name), allocator);
}
}

fn structFreeStr(comptime StructField: type, value: StructField, allocator: std.mem.Allocator) void {
switch (@typeInfo(StructField)) {
.Pointer => |p| switch (@typeInfo(p.child)) {
.Int => |int| if (int.bits == 8) {
.pointer => |p| switch (@typeInfo(p.child)) {
.int => |int| if (int.bits == 8) {
allocator.free(value);
},
else => {},
},
.Optional => |o| if (value) |some| structFreeStr(o.child, some, allocator),
.optional => |o| if (value) |some| structFreeStr(o.child, some, allocator),
else => {},
}
}
Expand Down

0 comments on commit 0d4ea5c

Please sign in to comment.