Skip to content

Commit

Permalink
refactor data types
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Dec 28, 2023
1 parent ab238d5 commit 7df9d96
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 107 deletions.
23 changes: 15 additions & 8 deletions integration_tests/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ test "query text protocol" {
var dest = [_]?[]const u8{undefined};
while (try rows.next(allocator)) |row| {
defer row.deinit(allocator);
try row.scan(&dest);
const data = try row.expect(.data);
try data.scan(&dest);
try std.testing.expectEqualSlices(u8, "1", dest[0].?);
}
}
Expand All @@ -70,7 +71,8 @@ test "query text protocol" {
var dest = [_]?[]const u8{ undefined, undefined };
while (try rows.next(allocator)) |row| {
defer row.deinit(allocator);
try row.scan(&dest);
const data = try row.expect(.data);
try data.scan(&dest);
try std.testing.expectEqualSlices(u8, "3", dest[0].?);
try std.testing.expectEqualSlices(u8, "4", dest[1].?);
}
Expand All @@ -82,7 +84,8 @@ test "query text protocol" {
var dest = [_]?[]const u8{ undefined, undefined, undefined };
while (try rows.next(allocator)) |row| {
defer row.deinit(allocator);
try row.scan(&dest);
const data = try row.expect(.data);
try data.scan(&dest);
try std.testing.expectEqualSlices(u8, "5", dest[0].?);
try std.testing.expectEqual(@as(?[]const u8, null), dest[1]);
try std.testing.expectEqualSlices(u8, "7", dest[2].?);
Expand All @@ -97,14 +100,16 @@ test "query text protocol" {
var dest = [_]?[]const u8{ undefined, undefined };
const row = try rows.readRow(allocator);
defer row.deinit(std.testing.allocator);
try row.scan(&dest);
const data = try row.expect(.data);
try data.scan(&dest);
try std.testing.expectEqualSlices(u8, "8", dest[0].?);
try std.testing.expectEqualSlices(u8, "9", dest[1].?);
}
{
const row = try rows.readRow(allocator);
defer row.deinit(std.testing.allocator);
const dest = try row.scanAlloc(allocator);
const data = try row.expect(.data);
const dest = try data.scanAlloc(allocator);
defer allocator.free(dest);
try std.testing.expectEqualSlices(u8, "10", dest[0].?);
try std.testing.expectEqualSlices(u8, "11", dest[1].?);
Expand All @@ -115,7 +120,7 @@ test "query text protocol" {
switch (row.value) {
.eof => {},
.err => |err| return err.asError(),
.raw => @panic("unexpected raw"),
.data => @panic("unexpected data"),
}
}
}
Expand Down Expand Up @@ -245,11 +250,13 @@ test "prepare execute with result" {
defer row.deinit(allocator);
{
var dest: MyType = undefined;
try row.scan(&dest);
const data = try row.expect(.data);
try data.scan(&dest);
try std.testing.expectEqualDeep(expected, dest);
}
{
const dest = try row.scanAlloc(MyType, allocator);
const data = try row.expect(.data);
const dest = try data.scanAlloc(MyType, allocator);
defer allocator.destroy(dest);
try std.testing.expectEqualDeep(&expected, dest);
}
Expand Down
8 changes: 4 additions & 4 deletions src/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ const Conn = conn.Conn;
const result = @import("./result.zig");
const QueryResult = result.QueryResult;
const PrepareResult = result.PrepareResult;
const TextResultRow = result.TextResultRow;
const BinaryResultRow = result.BinaryResultRow;
const PreparedStatement = result.PreparedStatement;
const TextResultData = result.TextResultData;
const BinaryResultData = result.BinaryResultData;

pub const Client = struct {
config: Config,
Expand All @@ -30,7 +30,7 @@ pub const Client = struct {
try client.conn.ping(allocator, &client.config);
}

pub fn query(client: *Client, allocator: std.mem.Allocator, query_string: []const u8) !QueryResult(TextResultRow) {
pub fn query(client: *Client, allocator: std.mem.Allocator, query_string: []const u8) !QueryResult(TextResultData) {
try client.connectIfNotConnected(allocator);
return client.conn.query(allocator, query_string);
}
Expand All @@ -40,7 +40,7 @@ pub const Client = struct {
return client.conn.prepare(allocator, query_string);
}

pub fn execute(client: *Client, allocator: std.mem.Allocator, prep_stmt: *const PreparedStatement, params: anytype) !QueryResult(BinaryResultRow) {
pub fn execute(client: *Client, allocator: std.mem.Allocator, prep_stmt: *const PreparedStatement, params: anytype) !QueryResult(BinaryResultData) {
try client.connectIfNotConnected(allocator);
return client.conn.execute(allocator, prep_stmt, params);
}
Expand Down
12 changes: 6 additions & 6 deletions src/conn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ const result = @import("./result.zig");
const QueryResult = result.QueryResult;
const PrepareResult = result.PrepareResult;
const PreparedStatement = result.PreparedStatement;
const TextResultRow = result.TextResultRow;
const BinaryResultRow = result.BinaryResultRow;
const TextResultData = result.TextResultData;
const BinaryResultData = result.BinaryResultData;
const ResultSet = result.ResultSet;
const ColumnDefinition41 = protocol.column_definition.ColumnDefinition41;

Expand All @@ -49,12 +49,12 @@ pub const Conn = struct {

// TODO: add options
/// caller must consume the result by switching on the result's value
pub fn query(conn: *Conn, allocator: std.mem.Allocator, query_string: []const u8) !QueryResult(TextResultRow) {
pub fn query(conn: *Conn, allocator: std.mem.Allocator, query_string: []const u8) !QueryResult(TextResultData) {
std.debug.assert(conn.state == .connected);
conn.sequence_id = 0;
const query_request: QueryRequest = .{ .query = query_string };
try conn.sendPacketUsingSmallPacketWriter(query_request);
return QueryResult(TextResultRow).init(conn, allocator);
return QueryResult(TextResultData).init(conn, allocator);
}

// TODO: add options
Expand All @@ -74,15 +74,15 @@ pub const Conn = struct {
};
}

pub fn execute(conn: *Conn, allocator: std.mem.Allocator, prep_stmt: *const PreparedStatement, params: anytype) !QueryResult(BinaryResultRow) {
pub fn execute(conn: *Conn, allocator: std.mem.Allocator, prep_stmt: *const PreparedStatement, params: anytype) !QueryResult(BinaryResultData) {
std.debug.assert(conn.state == .connected);
conn.sequence_id = 0;
const execute_request: ExecuteRequest = .{
.capabilities = conn.client_capabilities,
.prep_stmt = prep_stmt,
};
try conn.sendPacketUsingSmallPacketWriterWithParams(execute_request, params);
return QueryResult(BinaryResultRow).init(conn, allocator);
return QueryResult(BinaryResultData).init(conn, allocator);
}

pub fn close(conn: *Conn) void {
Expand Down
48 changes: 25 additions & 23 deletions src/helper.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ const constants = @import("./constants.zig");
const EnumFieldType = constants.EnumFieldType;
const result = @import("./result.zig");
const ResultSet = result.ResultSet;
const TextResultRow = result.TextResultRow;
const BinaryResultRow = result.BinaryResultRow;
const Options = result.BinaryResultRow.Options;
const protocol = @import("./protocol.zig");
const PacketReader = protocol.packet_reader.PacketReader;
const packet_writer = protocol.packet_writer;
const ColumnDefinition41 = protocol.column_definition.ColumnDefinition41;
const DateTime = @import("./temporal.zig").DateTime;
const Duration = @import("./temporal.zig").Duration;
const ResultRow = result.ResultRow;
const TextResultData = result.TextResultData;
const BinaryResultData = result.BinaryResultData;

fn comptimeIntToUInt(
comptime Unsigned: type,
Expand Down Expand Up @@ -226,7 +226,7 @@ pub fn encodeBinaryParam(param: anytype, col_def: *const ColumnDefinition41, wri
// comptime FieldType: type, field_name: []const u8, col_name: []const u8, col_type: EnumFieldType)
// TODO: insert field name if struct is passed in
logConversionError(@TypeOf(param), "", col_def.name, col_type);
return error.InvalidConversion;
return error.EncodeBinaryParam;
}

// To save space the packet can be compressed:
Expand Down Expand Up @@ -287,7 +287,7 @@ fn encodeDuration(d: Duration, writer: anytype) !void {
}
}

pub fn scanTextResultRow(raw: []const u8, dest: []?[]const u8) !void {
pub fn scanTextResultRow(dest: []?[]const u8, raw: []const u8) !void {
var packet_reader = PacketReader.initFromPayload(raw);
for (dest) |*d| {
d.* = blk: {
Expand All @@ -305,7 +305,7 @@ pub fn scanTextResultRow(raw: []const u8, dest: []?[]const u8) !void {
}

// dest is a pointer to a struct
pub fn scanBinResultRow(dest: anytype, raw: []const u8, col_defs: []const ColumnDefinition41) void {
pub fn scanBinResultRow(dest: anytype, raw: []const u8, col_defs: []const ColumnDefinition41) !void {
var reader = PacketReader.initFromPayload(raw);
const first = reader.readByte();
std.debug.assert(first == constants.BINARY_PROTOCOL_RESULTSET_ROW_HEADER);
Expand All @@ -328,15 +328,15 @@ pub fn scanBinResultRow(dest: anytype, raw: []const u8, col_defs: []const Column
if (isNull) {
@field(dest, field.name) = null;
} else {
@field(dest, field.name) = binElemToValue(field_info.Optional.child, field.name, &col_def, &reader);
@field(dest, field.name) = try binElemToValue(field_info.Optional.child, field.name, &col_def, &reader);
}
},
else => {
if (isNull) {
std.log.err("column: {s} value is null, but field: {s} is not nullable\n", .{ col_def.name, field.name });
unreachable;
}
@field(dest, field.name) = binElemToValue(field.type, field.name, &col_def, &reader);
@field(dest, field.name) = try binElemToValue(field.type, field.name, &col_def, &reader);
},
}
}
Expand All @@ -345,12 +345,12 @@ pub fn scanBinResultRow(dest: anytype, raw: []const u8, col_defs: []const Column

inline fn logConversionError(comptime FieldType: type, field_name: []const u8, col_name: []const u8, col_type: EnumFieldType) void {
std.log.err(
"MySQL Column: (name: {s}, type: {any}), Zig Value: (name: {s}, type: {any})\n",
"Conversion Error: MySQL Column: (name: {s}, type: {any}), Zig Value: (name: {s}, type: {any})\n",
.{ col_name, col_type, field_name, FieldType },
);
}

inline fn binElemToValue(comptime FieldType: type, field_name: []const u8, col_def: *const ColumnDefinition41, reader: *PacketReader) FieldType {
inline fn binElemToValue(comptime FieldType: type, field_name: []const u8, col_def: *const ColumnDefinition41, reader: *PacketReader) !FieldType {
const field_info = @typeInfo(FieldType);
const col_type: EnumFieldType = @enumFromInt(col_def.column_type);

Expand Down Expand Up @@ -400,7 +400,7 @@ inline fn binElemToValue(comptime FieldType: type, field_name: []const u8, col_d
}

logConversionError(FieldType, field_name, col_def.name, col_type);
unreachable;
return error.BinElemToValue;
}

inline fn binResIsNull(null_bitmap: []const u8, col_idx: usize) bool {
Expand Down Expand Up @@ -441,12 +441,12 @@ test "binResIsNull" {
}
}

pub fn ResultSetIter(comptime ResultRowType: type) type {
pub fn ResultSetIter(comptime T: type) type {
return struct {
result_set: *const ResultSet(ResultRowType),
result_set: *const ResultSet(T),

pub fn next(i: *const ResultSetIter(ResultRowType), allocator: std.mem.Allocator) !?ResultRowType {
const row = try i.result_set.readRow(allocator);
pub fn next(iter: *const ResultSetIter(T), allocator: std.mem.Allocator) !?ResultRow(T) {
const row = try iter.result_set.readRow(allocator);
return switch (row.value) {
.eof => {
// need to deinit as caller would not know to do so
Expand All @@ -458,8 +458,8 @@ pub fn ResultSetIter(comptime ResultRowType: type) type {
};
}

pub fn collectTexts(iter: *const ResultSetIter(TextResultRow), allocator: std.mem.Allocator) !TableTexts {
var row_acc = std.ArrayList(TextResultRow).init(allocator);
pub fn collectTexts(iter: *const ResultSetIter(TextResultData), allocator: std.mem.Allocator) !TableTexts {
var row_acc = std.ArrayList(ResultRow(TextResultData)).init(allocator);
while (try iter.next(allocator)) |row| {
const new_row_ptr = try row_acc.addOne();
new_row_ptr.* = row;
Expand All @@ -470,7 +470,8 @@ pub fn ResultSetIter(comptime ResultRowType: type) type {
var elems = try allocator.alloc(?[]const u8, row_acc.items.len * num_cols);
for (row_acc.items, 0..) |row, i| {
const dest_row = elems[i * num_cols .. (i + 1) * num_cols];
try row.scan(dest_row);
const data = try row.expect(.data);
try data.scan(dest_row);
rows[i] = dest_row;
}

Expand All @@ -481,16 +482,17 @@ pub fn ResultSetIter(comptime ResultRowType: type) type {
};
}

pub fn collectStructs(iter: *const ResultSetIter(BinaryResultRow), comptime Struct: type, allocator: std.mem.Allocator) !TableStructs(Struct) {
var row_acc = std.ArrayList(BinaryResultRow).init(allocator);
pub fn collectStructs(iter: *const ResultSetIter(BinaryResultData), comptime Struct: type, allocator: std.mem.Allocator) !TableStructs(Struct) {
var row_acc = std.ArrayList(ResultRow(BinaryResultData)).init(allocator);
while (try iter.next(allocator)) |row| {
const new_row_ptr = try row_acc.addOne();
new_row_ptr.* = row;
}

const structs = try allocator.alloc(Struct, row_acc.items.len);
for (row_acc.items, structs) |row, *s| {
try row.scan(s);
const data = try row.expect(.data);
try data.scan(s);
}

return .{
Expand All @@ -502,7 +504,7 @@ pub fn ResultSetIter(comptime ResultRowType: type) type {
}

pub const TableTexts = struct {
result_rows: []const TextResultRow,
result_rows: []const ResultRow(TextResultData),
elems: []const ?[]const u8,
rows: []const []const ?[]const u8,

Expand Down Expand Up @@ -531,7 +533,7 @@ pub const TableTexts = struct {

pub fn TableStructs(comptime Struct: type) type {
return struct {
result_rows: []const BinaryResultRow,
result_rows: []const ResultRow(BinaryResultData),
rows: []const Struct,

pub fn deinit(t: *const TableStructs(Struct), allocator: std.mem.Allocator) void {
Expand Down
Loading

0 comments on commit 7df9d96

Please sign in to comment.