Skip to content

Commit

Permalink
Merge pull request #8 from speed2exe/protocol-41
Browse files Browse the repository at this point in the history
chore: deprecated servers older than protocol 41
  • Loading branch information
speed2exe authored Mar 14, 2024
2 parents 3e1a7f9 + 9b61c9f commit f1e5dc6
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 38 deletions.
4 changes: 2 additions & 2 deletions src/config.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub const Config = struct {

// cfgs from Golang driver
client_found_rows: bool = false, // Return number of matching rows instead of rows changed
tls: bool = false,
ssl: bool = false,
multi_statements: bool = false,

pub fn capability_flags(config: *const Config) u32 {
Expand All @@ -25,7 +25,7 @@ pub const Config = struct {
if (config.client_found_rows) {
flags |= constants.CLIENT_FOUND_ROWS;
}
if (config.tls) {
if (config.ssl) {
flags |= constants.CLIENT_SSL;
}
if (config.multi_statements) {
Expand Down
13 changes: 7 additions & 6 deletions src/conn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub const Conn = struct {
capabilities: u32,
sequence_id: u8,

// Buffer to store metadata of the result set
result_meta: ResultMeta,

// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
Expand All @@ -60,8 +61,8 @@ pub const Conn = struct {
const packet = try conn.readPacket();
const handshake_v10 = switch (packet.payload[0]) {
constants.HANDSHAKE_V10 => HandshakeV10.init(&packet),
constants.ERR => return ErrorPacket.init(&packet, 0).asError(),
else => return packet.asError(conn.capabilities),
constants.ERR => return ErrorPacket.initFirst(&packet).asError(),
else => return packet.asError(),
};
conn.capabilities = handshake_v10.capability_flags() & config.capability_flags();

Expand Down Expand Up @@ -105,7 +106,7 @@ pub const Conn = struct {

switch (packet.payload[0]) {
constants.OK => _ = OkPacket.init(&packet, c.capabilities),
else => return packet.asError(c.capabilities),
else => return packet.asError(),
}
}

Expand Down Expand Up @@ -174,7 +175,7 @@ pub const Conn = struct {
const packet = try c.readPacket();
return switch (packet.payload[0]) {
constants.OK => {},
else => packet.asError(c.capabilities),
else => packet.asError(),
};
}

Expand All @@ -199,7 +200,7 @@ pub const Conn = struct {
const resp_packet = try c.readPacket();
return switch (resp_packet.payload[0]) {
constants.OK => {},
else => resp_packet.asError(c.capabilities),
else => resp_packet.asError(),
};
}

Expand Down Expand Up @@ -242,7 +243,7 @@ pub const Conn = struct {
else => return error.UnsupportedCachingSha2PasswordMoreData,
}
},
else => return packet.asError(c.capabilities),
else => return packet.asError(),
}
}
}
Expand Down
42 changes: 22 additions & 20 deletions src/protocol/generic_response.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,33 @@ const PacketReader = @import("./packet_reader.zig").PacketReader;
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html
pub const ErrorPacket = struct {
error_code: u16,
sql_state_marker: ?u8,
sql_state: ?*const [5]u8,
sql_state_marker: u8,
sql_state: *const [5]u8,
error_message: []const u8,

pub fn init(packet: *const Packet, capabilities: u32) ErrorPacket {
pub fn initFirst(packet: *const Packet) ErrorPacket {
var reader = packet.reader();
const header = reader.readByte();
std.debug.assert(header == constants.ERR);

var error_packet: ErrorPacket = undefined;
error_packet.error_code = reader.readInt(u16);
if (capabilities & constants.CLIENT_PROTOCOL_41 > 0) {
error_packet.sql_state_marker = reader.readByte();
error_packet.sql_state = reader.readRefComptime(5);
} else {
error_packet.sql_state_marker = null;
error_packet.sql_state = null;
}
error_packet.error_message = reader.readRefRemaining();
return error_packet;
}

pub fn init(packet: *const Packet) ErrorPacket {
var reader = packet.reader();
const header = reader.readByte();
std.debug.assert(header == constants.ERR);

var error_packet: ErrorPacket = undefined;
error_packet.error_code = reader.readInt(u16);

// CLIENT_PROTOCOL_41
error_packet.sql_state_marker = reader.readByte();
error_packet.sql_state = reader.readRefComptime(5);

error_packet.error_message = reader.readRefRemaining();
return error_packet;
}
Expand Down Expand Up @@ -57,16 +66,9 @@ pub const OkPacket = struct {
ok_packet.affected_rows = reader.readLengthEncodedInteger();
ok_packet.last_insert_id = reader.readLengthEncodedInteger();

if (capabilities & constants.CLIENT_PROTOCOL_41 > 0) {
ok_packet.status_flags = reader.readInt(u16);
ok_packet.warnings = reader.readInt(u16);
} else if (capabilities & constants.CLIENT_TRANSACTIONS > 0) {
ok_packet.status_flags = reader.readInt(u16);
ok_packet.warnings = null;
} else {
ok_packet.status_flags = null;
ok_packet.warnings = null;
}
// CLIENT_PROTOCOL_41
ok_packet.status_flags = reader.readInt(u16);
ok_packet.warnings = reader.readInt(u16);

ok_packet.session_state_info = null;
if (capabilities & constants.CLIENT_SESSION_TRACK > 0) {
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/handshake_response.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const PacketWriter = @import("./packet_writer.zig").PacketWriter;

pub const HandshakeResponse41 = struct {
client_flag: u32, // capabilities
max_packet_size: u32 = 0,
max_packet_size: u32 = 0, // TODO: support configurable max packet size
character_set: u8,
username: [:0]const u8,
auth_response: []const u8,
Expand Down
4 changes: 2 additions & 2 deletions src/protocol/packet.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ pub const Packet = struct {
return .{ .sequence_id = sequence_id, .payload = payload };
}

pub fn asError(packet: *const Packet, capabilities: u32) error{ UnexpectedPacket, ErrorPacket } {
pub fn asError(packet: *const Packet) error{ UnexpectedPacket, ErrorPacket } {
if (packet.payload[0] == constants.ERR) {
return ErrorPacket.init(packet, capabilities).asError();
return ErrorPacket.init(packet).asError();
}
std.log.warn("unexpected packet: {any}", .{packet});
return error.UnexpectedPacket;
Expand Down
14 changes: 7 additions & 7 deletions src/result.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub const QueryResult = union(enum) {
pub fn init(packet: *const Packet, capabilities: u32) !QueryResult {
return switch (packet.payload[0]) {
constants.OK => .{ .ok = OkPacket.init(packet, capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(packet, capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(packet) },
constants.LOCAL_INFILE_REQUEST => _ = @panic("not implemented"),
else => {
std.log.warn(
Expand Down Expand Up @@ -64,9 +64,9 @@ pub fn QueryResultRows(comptime T: type) type {
\\Unexpected OkPacket: {any}\n,
\\If your query is not expecting a result set, use QueryResult instead.
, .{OkPacket.init(&packet, c.capabilities)});
return packet.asError(c.capabilities);
return packet.asError();
},
constants.ERR => .{ .err = ErrorPacket.init(&packet, c.capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(&packet) },
constants.LOCAL_INFILE_REQUEST => _ = @panic("not implemented"),
else => .{ .rows = try ResultSet(T).init(c, &packet) },
};
Expand Down Expand Up @@ -254,7 +254,7 @@ pub fn ResultRow(comptime T: type) type {
fn init(conn: *Conn, col_defs: []const ColumnDefinition41) !ResultRow(T) {
const packet = try conn.readPacket();
return switch (packet.payload[0]) {
constants.ERR => .{ .err = ErrorPacket.init(&packet, conn.capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(&packet) },
constants.EOF => .{ .ok = OkPacket.init(&packet, conn.capabilities) },
else => .{ .row = .{ .packet = packet, .col_defs = col_defs } },
};
Expand Down Expand Up @@ -299,7 +299,7 @@ fn collectAllRowsPacketUntilEof(conn: *Conn, allocator: std.mem.Allocator) !std.
while (true) {
const packet = try conn.readPacket();
return switch (packet.payload[0]) {
constants.ERR => ErrorPacket.init(&packet, conn.capabilities).asError(),
constants.ERR => ErrorPacket.init(&packet).asError(),
constants.EOF => {
_ = OkPacket.init(&packet, conn.capabilities);
return packet_list;
Expand All @@ -320,9 +320,9 @@ pub const PrepareResult = union(enum) {
pub fn init(c: *Conn, allocator: std.mem.Allocator) !PrepareResult {
const response_packet = try c.readPacket();
return switch (response_packet.payload[0]) {
constants.ERR => .{ .err = ErrorPacket.init(&response_packet, c.capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(&response_packet) },
constants.OK => .{ .stmt = try PreparedStatement.init(&response_packet, c, allocator) },
else => return response_packet.asError(c.capabilities),
else => return response_packet.asError(),
};
}

Expand Down

0 comments on commit f1e5dc6

Please sign in to comment.