Skip to content

Commit

Permalink
feat: binary protocol null
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Nov 4, 2023
1 parent e84f912 commit e8cc323
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 48 deletions.
49 changes: 28 additions & 21 deletions integration_tests/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -242,26 +242,33 @@ test "prepare execute - 2" {
}
}

// test "prepare execute with result" {
// var c = Client.init(test_config);
// defer c.deinit();
//
// {
// const query =
// \\SELECT 1,2,3,4,5
// ;
// const prep_res = try c.prepare(allocator, query);
// defer prep_res.deinit(allocator);
// const prep_ok = try expectOk(prep_res.value);
// _ = prep_ok;
// // const query_res = try c.execute(allocator, &prep_ok);
// // defer query_res.deinit(allocator);
// // const rows = (try expectRows(query_res.value)).iter();
// // while (try rows.next(allocator)) |row| {
// // std.debug.print("row: {any}\n", .{row.value.raw});
// // defer row.deinit(allocator);
// // }
// }
// }
test "prepare execute with result" {
var c = Client.init(test_config);
defer c.deinit();

{
const query =
\\SELECT null
;
const prep_res = try c.prepare(allocator, query);
defer prep_res.deinit(allocator);
const prep_stmt = try expectOk(prep_res.value);
const query_res = try c.execute(allocator, &prep_stmt);
defer query_res.deinit(allocator);
const rows = (try expectRows(query_res.value)).iter();

const MyType = struct {
a: ?u8,
};
const expected = MyType{ .a = null };

var dest: MyType = undefined;
while (try rows.next(allocator)) |row| {
defer row.deinit(allocator);
try row.scanStruct(&dest);
try std.testing.expectEqual(expected, dest);
}
}
}

// SELECT CONCAT(?, ?) AS col1
42 changes: 42 additions & 0 deletions src/constants.zig
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,48 @@ pub const COM_QUERY :u8 = 0x03;
pub const COM_STMT_PREPARE: u8 = 0x16;
pub const COM_STMT_EXECUTE: u8 = 0x17;

pub const BINARY_PROTOCOL_RESULTSET_ROW_HEADER: u8 = 0x00;

// https://dev.mysql.com/doc/dev/mysql-server/latest/field__types_8h_source.html
pub const EnumFieldType = enum(u8) {
MYSQL_TYPE_DECIMAL,
MYSQL_TYPE_TINY,
MYSQL_TYPE_SHORT,
MYSQL_TYPE_LONG,
MYSQL_TYPE_FLOAT,
MYSQL_TYPE_DOUBLE,
MYSQL_TYPE_NULL,
MYSQL_TYPE_TIMESTAMP,
MYSQL_TYPE_LONGLONG,
MYSQL_TYPE_INT24,
MYSQL_TYPE_DATE,
MYSQL_TYPE_TIME,
MYSQL_TYPE_DATETIME,
MYSQL_TYPE_YEAR,
MYSQL_TYPE_NEWDATE,
MYSQL_TYPE_VARCHAR,
MYSQL_TYPE_BIT,
MYSQL_TYPE_TIMESTAMP2,
MYSQL_TYPE_DATETIME2,
MYSQL_TYPE_TIME2,
MYSQL_TYPE_TYPED_ARRAY,

MYSQL_TYPE_INVALID = 243,
MYSQL_TYPE_BOOL = 244,
MYSQL_TYPE_JSON = 245,
MYSQL_TYPE_NEWDECIMAL = 246,
MYSQL_TYPE_ENUM = 247,
MYSQL_TYPE_SET = 248,
MYSQL_TYPE_TINY_BLOB = 249,
MYSQL_TYPE_MEDIUM_BLOB = 250,
MYSQL_TYPE_LONG_BLOB = 251,
MYSQL_TYPE_BLOB = 252,
MYSQL_TYPE_VAR_STRING = 253,
MYSQL_TYPE_STRING = 254,
MYSQL_TYPE_GEOMETRY = 255
};


// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_command_phase_utility.html
pub const COM_QUIT: u8 = 0x01;
pub const COM_INIT_DB: u8 = 0x02;
Expand Down
125 changes: 118 additions & 7 deletions src/helper.zig
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
/// This file is to store convenience functions and methods for callers.
const std = @import("std");
const constants = @import("./constants.zig");
const EnumFieldType = constants.EnumFieldType;
const result = @import("./result.zig");
const ResultSet = result.ResultSet;
const TextResultRow = result.TextResultRow;
const Options = result.BinaryResultRow.Options;
const PacketReader = @import("./protocol/packet_reader.zig").PacketReader;
const protocol = @import("./protocol.zig");
const PacketReader = protocol.packet_reader.PacketReader;
const ColumnDefinition41 = protocol.column_definition.ColumnDefinition41;

pub fn scanTextResultRow(raw: []const u8, dest: []?[]const u8) !void {
var packet_reader = PacketReader.initFromPayload(raw);
Expand All @@ -24,11 +27,119 @@ pub fn scanTextResultRow(raw: []const u8, dest: []?[]const u8) !void {
}
}

pub fn scanBinaryResultRow(comptime T: type, raw: []const u8, dest: *T, options: Options) !void {
_ = options;
_ = dest;
_ = raw;
@panic("not implemented");
// dest is a pointer to a struct
pub fn scanBinResRowtoStruct(dest: anytype, raw: []const u8, col_defs: []ColumnDefinition41) void {
var reader = PacketReader.initFromPayload(raw);
const first = reader.readByte();
std.debug.assert(first == constants.BINARY_PROTOCOL_RESULTSET_ROW_HEADER);

// null bitmap
const null_bitmap_len = (col_defs.len + 7 + 2) / 8;
const null_bitmap = reader.readFixedRuntime(null_bitmap_len);

const child_type = @typeInfo(@TypeOf(dest)).Pointer.child;
const struct_fields = @typeInfo(child_type).Struct.fields;

std.debug.assert(struct_fields.len == col_defs.len);

inline for (struct_fields, col_defs, 0..) |field, col_def, i| {
if (!binResIsNull(null_bitmap, i)) {
@field(dest, field.name) = binElemToValue(field.type, &col_def, &reader);
} else {
switch (@typeInfo(field.type)) {
.Optional => @field(dest, field.name) = null,
else => {
std.log.err("field {s} is not optional\n", .{field.name});
unreachable;
},
}
}
}
std.debug.assert(reader.finished());
}

inline fn binElemToValue(comptime T: type, col_def: *const ColumnDefinition41, reader: *PacketReader) T {
_ = reader;
const col_type: EnumFieldType = @enumFromInt(col_def.column_type);
return switch (col_type) {
else => {
std.log.err("unimplemented col_type: {any}\n", .{col_type});
unreachable;
},
// .MYSQL_TYPE_DECIMAL => {},
// .MYSQL_TYPE_TINY => {},
// .MYSQL_TYPE_SHORT => {},
// .MYSQL_TYPE_LONG => {},
// .MYSQL_TYPE_FLOAT => {},
// .MYSQL_TYPE_DOUBLE => {},
// .MYSQL_TYPE_NULL => {},
// .MYSQL_TYPE_TIMESTAMP => {},
// .MYSQL_TYPE_LONGLONG => {},
// .MYSQL_TYPE_INT24 => {},
// .MYSQL_TYPE_DATE => {},
// .MYSQL_TYPE_TIME => {},
// .MYSQL_TYPE_DATETIME => {},
// .MYSQL_TYPE_YEAR => {},
// .MYSQL_TYPE_NEWDATE => {},
// .MYSQL_TYPE_VARCHAR => {},
// .MYSQL_TYPE_BIT => {},
// .MYSQL_TYPE_TIMESTAMP2 => {},
// .MYSQL_TYPE_DATETIME2 => {},
// .MYSQL_TYPE_TIME2 => {},
// .MYSQL_TYPE_TYPED_ARRAY => {},

// .MYSQL_TYPE_INVALID => {},
// .MYSQL_TYPE_BOOL => {},
// .MYSQL_TYPE_JSON => {},
// .MYSQL_TYPE_NEWDECIMAL => {},
// .MYSQL_TYPE_ENUM => {},
// .MYSQL_TYPE_SET => {},
// .MYSQL_TYPE_TINY_BLOB => {},
// .MYSQL_TYPE_MEDIUM_BLOB => {},
// .MYSQL_TYPE_LONG_BLOB => {},
// .MYSQL_TYPE_BLOB => {},
// .MYSQL_TYPE_VAR_STRING => {},
// .MYSQL_TYPE_STRING => {},
// .MYSQL_TYPE_GEOMETRY => {},
};
}

inline fn binResIsNull(null_bitmap: []const u8, col_idx: usize) bool {
// TODO: optimize: divmod
const byte_idx = (col_idx + 2) / 8;
const bit_idx = (col_idx + 2) % 8;
const byte = null_bitmap[byte_idx];
return (byte & (1 << bit_idx)) > 0;
}

test "binResIsNull" {
var tests = .{
.{
.null_bitmap = &.{0b00000100},
.col_idx = 0,
.expected = true,
},
.{
.null_bitmap = &.{0b00000000},
.col_idx = 0,
.expected = false,
},
.{
.null_bitmap = &.{ 0b00000000, 0b00000001 },
.col_idx = 6,
.expected = true,
},
.{
.null_bitmap = &.{ 0b10000000, 0b00000000 },
.col_idx = 5,
.expected = true,
},
};

inline for (tests) |t| {
const actual = binResIsNull(t.null_bitmap, t.col_idx);
try std.testing.expectEqual(t.expected, actual);
}
}

pub fn ResultSetIter(comptime ResultRowType: type) type {
Expand All @@ -55,7 +166,7 @@ pub fn ResultSetIter(comptime ResultRowType: type) type {
new_row_ptr.* = row;
}

const num_cols = iter.text_result_set.column_definitions.len;
const num_cols = iter.text_result_set.col_defs.len;
var rows = try allocator.alloc([]?[]const u8, row_acc.items.len);
var elems = try allocator.alloc(?[]const u8, row_acc.items.len * num_cols);
for (row_acc.items, 0..) |row, i| {
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/packet_reader.zig
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub const PacketReader = struct {
return bytes;
}

fn readFixedRuntime(packet_reader: *PacketReader, n: usize) []const u8 {
pub fn readFixedRuntime(packet_reader: *PacketReader, n: usize) []const u8 {
const bytes = packet_reader.payload[packet_reader.pos..][0..n];
packet_reader.pos += n;
return bytes;
Expand Down
33 changes: 14 additions & 19 deletions src/result.zig
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ pub fn QueryResult(comptime ResultRowType: type) type {
pub fn ResultSet(comptime ResultRowType: type) type {
return struct {
conn: *Conn,
column_packets: []Packet,
column_definitions: []ColumnDefinition41,
col_packets: []Packet,
col_defs: []ColumnDefinition41,

pub fn init(allocator: std.mem.Allocator, conn: *Conn, column_count: u64) !ResultSet(ResultRowType) {
var t: ResultSet(ResultRowType) = .{ .conn = conn, .column_packets = &.{}, .column_definitions = &.{} };
var t: ResultSet(ResultRowType) = .{ .conn = conn, .col_packets = &.{}, .col_defs = &.{} };
errdefer t.deinit(allocator);

t.column_packets = try allocator.alloc(Packet, column_count);
@memset(t.column_packets, Packet.safe_deinit());
t.column_definitions = try allocator.alloc(ColumnDefinition41, column_count);
t.col_packets = try allocator.alloc(Packet, column_count);
@memset(t.col_packets, Packet.safe_deinit());
t.col_defs = try allocator.alloc(ColumnDefinition41, column_count);

for (t.column_packets, t.column_definitions) |*pac, *def| {
for (t.col_packets, t.col_defs) |*pac, *def| {
pac.* = try conn.readPacket(allocator);
def.* = ColumnDefinition41.initFromPacket(pac);
}
Expand All @@ -57,11 +57,11 @@ pub fn ResultSet(comptime ResultRowType: type) type {
}

fn deinit(t: *const ResultSet(ResultRowType), allocator: std.mem.Allocator) void {
for (t.column_packets) |packet| {
for (t.col_packets) |packet| {
packet.deinit(allocator);
}
allocator.free(t.column_packets);
allocator.free(t.column_definitions);
allocator.free(t.col_packets);
allocator.free(t.col_defs);
}

pub fn readRow(t: *const ResultSet(ResultRowType), allocator: std.mem.Allocator) !ResultRowType {
Expand Down Expand Up @@ -95,7 +95,7 @@ pub const TextResultRow = struct {
},

pub fn scan(t: *const TextResultRow, dest: []?[]const u8) !void {
std.debug.assert(dest.len == t.result_set.column_definitions.len);
std.debug.assert(dest.len == t.result_set.col_defs.len);
switch (t.value) {
.err => |err| return err.asError(),
.eof => |eof| return eof.asError(),
Expand All @@ -119,11 +119,11 @@ pub const BinaryResultRow = struct {

const Options = struct {};

pub fn scanStruct(comptime T: type, t: *const BinaryResultRow, dest: ?*const T, options: Options) !void {
switch (t.value) {
pub fn scanStruct(b: *const BinaryResultRow, dest: anytype) !void {
switch (b.value) {
.err => |err| return err.asError(),
.eof => |eof| return eof.asError(),
.raw => try helper.scanTextBinaryRow(T, t.value.raw, dest, options),
.raw => |raw| helper.scanBinResRowtoStruct(dest, raw, b.result_set.col_defs),
}
}

Expand All @@ -140,11 +140,6 @@ pub const PrepareResult = struct {
},

pub fn deinit(p: *const PrepareResult, allocator: std.mem.Allocator) void {
// for (p.value.ok.packets, 0..) |pp, i| {
// std.debug.print("prepare result deinit: packet, {any} ptr: {any}\n", .{ i, @intFromPtr(pp.payload.ptr) });
// pp.deinit(allocator);
// }

p.packet.deinit(allocator);
switch (p.value) {
.ok => |prep_stmt| prep_stmt.deinit(allocator),
Expand Down

0 comments on commit e8cc323

Please sign in to comment.