diff --git a/integration_tests/client.zig b/integration_tests/client.zig index 2104962..5fa3e56 100644 --- a/integration_tests/client.zig +++ b/integration_tests/client.zig @@ -374,5 +374,62 @@ test "binary data types - float" { } } +test "binary data types - string" { + var c = Client.init(test_config); + defer c.deinit(); + + try queryExpectOk(&c, "CREATE DATABASE test"); + defer queryExpectOk(&c, "DROP DATABASE test") catch {}; + + try queryExpectOk(&c, + \\ + \\CREATE TABLE test.string_types_example ( + \\ varchar_col VARCHAR(255), + \\ not_null_varchar_col VARCHAR(255) NOT NULL, + \\ enum_col ENUM('a', 'b', 'c'), + \\ not_null_enum_col ENUM('a', 'b', 'c') NOT NULL + \\) + ); + defer queryExpectOk(&c, "DROP TABLE test.string_types_example") catch {}; + + const prep_res = try c.prepare(allocator, "INSERT INTO test.string_types_example VALUES (?, ?, ?, ?)"); + defer prep_res.deinit(allocator); + const prep_stmt = try prep_res.expect(.ok); + + const params = .{ + .{ "hello", "world", "a", "b" }, + .{ null, "foo", null, "c" }, + .{ null, "", null, "a" }, + .{ + @as(?*const [3]u8, "baz"), + @as([*:0]const u8, "bar"), + @as(?[]const u8, null), + @as([:0]const u8, "c"), + }, + }; + inline for (params) |param| { + const exe_res = try c.execute(allocator, &prep_stmt, param); + defer exe_res.deinit(allocator); + _ = try exe_res.expect(.ok); + } + + { + const res = try c.query(allocator, "SELECT * FROM test.string_types_example"); + defer res.deinit(allocator); + const rows_iter = (try res.expect(.rows)).iter(); + + const table = try rows_iter.collect(allocator); + defer table.deinit(allocator); + + const expected: []const []const ?[]const u8 = &.{ + &.{ "hello", "world", "a", "b" }, + &.{ null, "foo", null, "c" }, + &.{ null, "", null, "a" }, + &.{ "baz", "bar", null, "c" }, + }; + try std.testing.expectEqualDeep(expected, table.rows); + } +} + // //// SELECT CONCAT(?, ?) AS col1 diff --git a/src/helper.zig b/src/helper.zig index 2dda3d7..b3050d5 100644 --- a/src/helper.zig +++ b/src/helper.zig @@ -32,40 +32,12 @@ pub fn encodeBinaryParam(param: anytype, col_def: *const ColumnDefinition41, wri const col_type: EnumFieldType = @enumFromInt(col_def.column_type); switch (param_type_info) { - .Null => { - switch (col_type) { - .MYSQL_TYPE_LONGLONG => return try writer.writer.advance(8), - .MYSQL_TYPE_LONG, - .MYSQL_TYPE_INT24, - => return try writer.writer.advance(4), - .MYSQL_TYPE_SHORT, - .MYSQL_TYPE_YEAR, - => return try writer.writer.advance(2), - .MYSQL_TYPE_TINY => return try writer.writer.advance(1), - .MYSQL_TYPE_STRING, - .MYSQL_TYPE_VARCHAR, - .MYSQL_TYPE_VAR_STRING, - .MYSQL_TYPE_ENUM, - .MYSQL_TYPE_SET, - .MYSQL_TYPE_LONG_BLOB, - .MYSQL_TYPE_MEDIUM_BLOB, - .MYSQL_TYPE_BLOB, - .MYSQL_TYPE_TINY_BLOB, - .MYSQL_TYPE_GEOMETRY, - .MYSQL_TYPE_BIT, - .MYSQL_TYPE_DECIMAL, - .MYSQL_TYPE_NEWDECIMAL, - => return try packet_writer.writeLengthEncodedString(writer, ""), - .MYSQL_TYPE_FLOAT => return try writer.writer.advance(4), - .MYSQL_TYPE_DOUBLE => return try writer.writer.advance(8), - else => {}, - } - }, + .Null => return, .Optional => { if (param) |p| { return encodeBinaryParam(p, col_def, writer); } else { - return encodeBinaryParam(null, col_def, writer); + return; } }, .Int => |int| { @@ -162,8 +134,37 @@ pub fn encodeBinaryParam(param: anytype, col_def: *const ColumnDefinition41, wri else => {}, } }, - + .Array => |array| { + switch (@typeInfo(array.child)) { + .Int => |int| { + if (int.bits == 8) { + switch (col_type) { + .MYSQL_TYPE_STRING, + .MYSQL_TYPE_VARCHAR, + .MYSQL_TYPE_VAR_STRING, + .MYSQL_TYPE_ENUM, + .MYSQL_TYPE_SET, + .MYSQL_TYPE_LONG_BLOB, + .MYSQL_TYPE_MEDIUM_BLOB, + .MYSQL_TYPE_BLOB, + .MYSQL_TYPE_TINY_BLOB, + .MYSQL_TYPE_GEOMETRY, + .MYSQL_TYPE_BIT, + .MYSQL_TYPE_DECIMAL, + .MYSQL_TYPE_NEWDECIMAL, + => return try packet_writer.writeLengthEncodedString(writer, ¶m), + else => {}, + } + } + }, + else => {}, + } + }, .Pointer => |pointer| { + switch (pointer.size) { + .One => return encodeBinaryParam(param.*, col_def, writer), + else => {}, + } switch (@typeInfo(pointer.child)) { .Int => |int| { if (int.bits == 8) { @@ -181,7 +182,11 @@ pub fn encodeBinaryParam(param: anytype, col_def: *const ColumnDefinition41, wri .MYSQL_TYPE_BIT, .MYSQL_TYPE_DECIMAL, .MYSQL_TYPE_NEWDECIMAL, - => return try packet_writer.writeLengthEncodedString(writer, param), + => switch (pointer.size) { + .C, .Many => return try packet_writer.writeLengthEncodedString(writer, std.mem.span(param)), + .Slice => return try packet_writer.writeLengthEncodedString(writer, param), + else => {}, + }, else => {}, } }