Skip to content

Commit

Permalink
feat: execute input params
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Nov 25, 2023
1 parent c760886 commit af11144
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 34 deletions.
58 changes: 42 additions & 16 deletions integration_tests/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -252,31 +252,57 @@ test "binary data types" {
try queryExpectOk(&c, "CREATE DATABASE test");
defer queryExpectOk(&c, "DROP DATABASE test") catch {};

try queryExpectOk(
&c,
try queryExpectOk(&c,
\\
\\CREATE TABLE test.int_types_example (
\\ tinyint_col TINYINT,
\\ smallint_col SMALLINT,
\\ mediumint_col MEDIUMINT,
\\ int_col INT,
\\ bigint_col BIGINT,
\\ tinyint_unsigned_col TINYINT UNSIGNED,
\\ smallint_unsigned_col SMALLINT UNSIGNED,
\\ mediumint_unsigned_col MEDIUMINT UNSIGNED,
\\ int_unsigned_col INT UNSIGNED,
\\ bigint_unsigned_col BIGINT UNSIGNED
\\ int_col INT
// \\ bigint_col BIGINT,
// \\ tinyint_unsigned_col TINYINT UNSIGNED,
// \\ smallint_unsigned_col SMALLINT UNSIGNED,
// \\ mediumint_unsigned_col MEDIUMINT UNSIGNED,
// \\ int_unsigned_col INT UNSIGNED,
// \\ bigint_unsigned_col BIGINT UNSIGNED
\\)
,
);
defer queryExpectOk(&c, "DROP TABLE test.int_types_example") catch {};

// const prep_res = try c.prepare(allocator, "INSERT INTO int_types_example VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
// defer prep_res.deinit(allocator);
// const prep_stmt = try prep_res.expect(.ok);
// const exe_res = try c.execute(allocator, prep_stmt, .{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
// defer exe_res.deinit(allocator);
// _ = try exe_res.expect(.ok);
const prep_res = try c.prepare(
allocator,
"INSERT INTO test.int_types_example VALUES (?, ?, ?, ?)",
);
defer prep_res.deinit(allocator);
const prep_stmt = try prep_res.expect(.ok);

const params = .{
.{ -128, -32768, -8388608, -2147483648 },
.{ 0, 0, 0, 0 },
.{ 127, 32767, 8388607, 2147483647 },
};
inline for (params) |param| {
const exe_res = try c.execute(allocator, &prep_stmt, param);
defer exe_res.deinit(allocator);
_ = try exe_res.expect(.ok);
}

{
const res = try c.query(allocator, "SELECT * FROM test.int_types_example");
defer res.deinit(allocator);
const rows_iter = (try res.expect(.rows)).iter();

const table = try rows_iter.collect(allocator);
defer table.deinit(allocator);

const expected: []const []const ?[]const u8 = &.{
&.{ "-128", "-32768", "-8388608", "-2147483648" },
&.{ "0", "0", "0", "0" },
&.{ "127", "32767", "8388607", "2147483647" },
};
// std.debug.print("\n{?s}\n", .{table.rows[2][2]});
try std.testing.expectEqualDeep(expected, table.rows);
}
}

//
Expand Down
2 changes: 2 additions & 0 deletions src/conn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,15 @@ pub const Conn = struct {
fn sendPacketUsingSmallPacketWriter(conn: *Conn, packet: anytype) !void {
std.debug.assert(conn.state == .connected);
var small_packet_writer = stream_buffered.SmallPacketWriter.init(&conn.writer, conn.generateSequenceId());
errdefer conn.writer.reset();
try packet.write(&small_packet_writer);
try small_packet_writer.flush();
}

fn sendPacketUsingSmallPacketWriterWithParams(conn: *Conn, packet: anytype, params: anytype) !void {
std.debug.assert(conn.state == .connected);
var small_packet_writer = stream_buffered.SmallPacketWriter.init(&conn.writer, conn.generateSequenceId());
errdefer conn.writer.reset();
try packet.writeWithParams(&small_packet_writer, params);
try small_packet_writer.flush();
}
Expand Down
110 changes: 109 additions & 1 deletion src/helper.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,116 @@ const TextResultRow = result.TextResultRow;
const Options = result.BinaryResultRow.Options;
const protocol = @import("./protocol.zig");
const PacketReader = protocol.packet_reader.PacketReader;
const packet_writer = protocol.packet_writer;
const ColumnDefinition41 = protocol.column_definition.ColumnDefinition41;

fn comptimeIntToUInt(
comptime Unsigned: type,
comptime Signed: type,
comptime int: comptime_int,
) Unsigned {
return blk: {
if (comptime (int < 0)) {
break :blk @bitCast(@as(Signed, int));
} else {
break :blk int;
}
};
}

// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row_value
// https://mariadb.com/kb/en/com_stmt_execute/#binary-parameter-encoding
pub fn encodeBinaryParam(param: anytype, col_def: *const ColumnDefinition41, writer: anytype) !void {
const param_type_info = @typeInfo(@TypeOf(param));
const col_type: EnumFieldType = @enumFromInt(col_def.column_type);

switch (param_type_info) {
.Int => |int| {
switch (col_type) {
.MYSQL_TYPE_LONGLONG => {
if (int.bits == 64) {
return try packet_writer.writeUInt64(writer, @bitCast(param));
}
},
.MYSQL_TYPE_LONG,
.MYSQL_TYPE_INT24,
=> {
if (int.bits == 32) {
return try packet_writer.writeUInt32(writer, @bitCast(param));
}
},
.MYSQL_TYPE_SHORT,
.MYSQL_TYPE_YEAR,
=> {
if (int.bits == 16) {
return try packet_writer.writeUInt16(writer, @bitCast(param));
}
},
.MYSQL_TYPE_TINY => {
if (int.bits == 8) {
return try packet_writer.writeUInt8(writer, @bitCast(param));
}
},
else => {},
}
},
.ComptimeInt => {
switch (col_type) {
.MYSQL_TYPE_LONGLONG => {
const value: u64 = comptimeIntToUInt(u64, i64, param);
return try packet_writer.writeUInt64(writer, value);
},
.MYSQL_TYPE_LONG,
.MYSQL_TYPE_INT24,
=> {
const value: u32 = comptimeIntToUInt(u32, i32, param);
return try packet_writer.writeUInt32(writer, value);
},
// .MYSQL_TYPE_SHORT,
// .MYSQL_TYPE_YEAR,
// => {
// const value: u16 = comptimeIntToUInt(u16, i16, param);
// return try packet_writer.writeUInt16(writer, value);
// },
else => {},
}
},

.Pointer => |pointer| {
switch (@typeInfo(pointer.child)) {
.Int => |int| {
if (int.bits == 8) {
switch (col_type) {
.MYSQL_TYPE_STRING,
.MYSQL_TYPE_VARCHAR,
.MYSQL_TYPE_VAR_STRING,
.MYSQL_TYPE_ENUM,
.MYSQL_TYPE_SET,
.MYSQL_TYPE_LONG_BLOB,
.MYSQL_TYPE_MEDIUM_BLOB,
.MYSQL_TYPE_BLOB,
.MYSQL_TYPE_TINY_BLOB,
.MYSQL_TYPE_GEOMETRY,
.MYSQL_TYPE_BIT,
.MYSQL_TYPE_DECIMAL,
.MYSQL_TYPE_NEWDECIMAL,
=> return try packet_writer.writeLengthEncodedString(writer, param),
else => {},
}
}
},
else => {},
}
},
else => {},
}

// comptime FieldType: type, field_name: []const u8, col_name: []const u8, col_type: EnumFieldType)
// TODO: insert field name if struct is passed in
logConversionError(@TypeOf(param), "", col_def.name, col_type);
return error.InvalidConversion;
}

pub fn scanTextResultRow(raw: []const u8, dest: []?[]const u8) !void {
var packet_reader = PacketReader.initFromPayload(raw);
for (dest) |*d| {
Expand Down Expand Up @@ -68,7 +176,7 @@ pub fn scanBinResRowtoStruct(dest: anytype, raw: []const u8, col_defs: []ColumnD

inline fn logConversionError(comptime FieldType: type, field_name: []const u8, col_name: []const u8, col_type: EnumFieldType) void {
std.log.err(
"cannot convert from column(name: {s}, type: {any}) to field(name: {s}, type: {any})\n",
"MySQL Column: (name: {s}, type: {any}), Zig Value: (name: {s}, type: {any})\n",
.{ col_name, col_type, field_name, FieldType },
);
}
Expand Down
37 changes: 20 additions & 17 deletions src/protocol/prepared_statements.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ const constants = @import("../constants.zig");
const Packet = @import("./packet.zig").Packet;
const PacketReader = @import("./packet_reader.zig").PacketReader;
const PreparedStatement = @import("./../result.zig").PreparedStatement;
const ColumnDefinition41 = @import("./column_definition.zig").ColumnDefinition41;
const helper = @import("./../helper.zig");

// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html
pub const PrepareRequest = struct {
Expand Down Expand Up @@ -95,10 +97,11 @@ pub const ExecuteRequest = struct {
// send type to server (0 / 1)
try packet_writer.writeLengthEncodedInteger(writer, e.new_params_bind_flag);
if (e.new_params_bind_flag > 0) {
for (e.prep_stmt.params) |p| {
try packet_writer.writeUInt16(writer, p.flags);
for (e.prep_stmt.params) |col_def| {
try packet_writer.writeUInt8(writer, col_def.column_type);
try packet_writer.writeUInt8(writer, 0);
if (e.capabilities & constants.CLIENT_QUERY_ATTRIBUTES > 0) {
try packet_writer.writeLengthEncodedString(writer, p.name);
try packet_writer.writeLengthEncodedString(writer, col_def.name);
}
}
if (has_attributes_to_write) {
Expand All @@ -110,24 +113,24 @@ pub const ExecuteRequest = struct {
}

// TODO: Write params and attr as binary values
// // Write params as binary values
// for (params) |b| {
// try writeBinaryParam(b, writer);
// }
// if (has_attributes_to_write) {
// for (e.attributes) |b| {
// try writeBinaryParam(b, writer);
// }
// }
// Write params as binary values
inline for (params, e.prep_stmt.params) |param, *col_def| {
try helper.encodeBinaryParam(param, col_def, writer);
}
if (has_attributes_to_write) {
for (e.attributes) |b| {
try writeAttr(b, writer);
}
}
}
}
};

// fn writeBinaryParam(param: BinaryParam, writer: anytype) !void {
// _ = writer;
// _ = param;
// @panic("TODO");
// }
fn writeAttr(param: BinaryParam, writer: anytype) !void {
_ = writer;
_ = param;
@panic("TODO: support mysql attributes");
}

fn writeNullBitmap(params: anytype, attributes: []const BinaryParam, writer: anytype) !void {
const byte_count = (params.len + attributes.len + 7) / 8;
Expand Down
5 changes: 5 additions & 0 deletions src/stream_buffered.zig
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ pub const Writer = struct {
len: usize = 0,
stream: std.net.Stream,

// invalidates all previous writes
pub fn reset(w: *Writer) void {
w.len = 0;
}

// write all behavior
pub fn write(w: *Writer, buffer: []const u8) !void {
var already_written: usize = 0;
Expand Down

0 comments on commit af11144

Please sign in to comment.