Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: deprecated servers older than protocol 41 #8

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/config.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub const Config = struct {

// cfgs from Golang driver
client_found_rows: bool = false, // Return number of matching rows instead of rows changed
tls: bool = false,
ssl: bool = false,
multi_statements: bool = false,

pub fn capability_flags(config: *const Config) u32 {
Expand All @@ -25,7 +25,7 @@ pub const Config = struct {
if (config.client_found_rows) {
flags |= constants.CLIENT_FOUND_ROWS;
}
if (config.tls) {
if (config.ssl) {
flags |= constants.CLIENT_SSL;
}
if (config.multi_statements) {
Expand Down
13 changes: 7 additions & 6 deletions src/conn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub const Conn = struct {
capabilities: u32,
sequence_id: u8,

// Buffer to store metadata of the result set
result_meta: ResultMeta,

// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
Expand All @@ -60,8 +61,8 @@ pub const Conn = struct {
const packet = try conn.readPacket();
const handshake_v10 = switch (packet.payload[0]) {
constants.HANDSHAKE_V10 => HandshakeV10.init(&packet),
constants.ERR => return ErrorPacket.init(&packet, 0).asError(),
else => return packet.asError(conn.capabilities),
constants.ERR => return ErrorPacket.initFirst(&packet).asError(),
else => return packet.asError(),
};
conn.capabilities = handshake_v10.capability_flags() & config.capability_flags();

Expand Down Expand Up @@ -105,7 +106,7 @@ pub const Conn = struct {

switch (packet.payload[0]) {
constants.OK => _ = OkPacket.init(&packet, c.capabilities),
else => return packet.asError(c.capabilities),
else => return packet.asError(),
}
}

Expand Down Expand Up @@ -174,7 +175,7 @@ pub const Conn = struct {
const packet = try c.readPacket();
return switch (packet.payload[0]) {
constants.OK => {},
else => packet.asError(c.capabilities),
else => packet.asError(),
};
}

Expand All @@ -199,7 +200,7 @@ pub const Conn = struct {
const resp_packet = try c.readPacket();
return switch (resp_packet.payload[0]) {
constants.OK => {},
else => resp_packet.asError(c.capabilities),
else => resp_packet.asError(),
};
}

Expand Down Expand Up @@ -242,7 +243,7 @@ pub const Conn = struct {
else => return error.UnsupportedCachingSha2PasswordMoreData,
}
},
else => return packet.asError(c.capabilities),
else => return packet.asError(),
}
}
}
Expand Down
42 changes: 22 additions & 20 deletions src/protocol/generic_response.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,33 @@ const PacketReader = @import("./packet_reader.zig").PacketReader;
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html
pub const ErrorPacket = struct {
error_code: u16,
sql_state_marker: ?u8,
sql_state: ?*const [5]u8,
sql_state_marker: u8,
sql_state: *const [5]u8,
error_message: []const u8,

pub fn init(packet: *const Packet, capabilities: u32) ErrorPacket {
pub fn initFirst(packet: *const Packet) ErrorPacket {
var reader = packet.reader();
const header = reader.readByte();
std.debug.assert(header == constants.ERR);

var error_packet: ErrorPacket = undefined;
error_packet.error_code = reader.readInt(u16);
if (capabilities & constants.CLIENT_PROTOCOL_41 > 0) {
error_packet.sql_state_marker = reader.readByte();
error_packet.sql_state = reader.readRefComptime(5);
} else {
error_packet.sql_state_marker = null;
error_packet.sql_state = null;
}
error_packet.error_message = reader.readRefRemaining();
return error_packet;
}

pub fn init(packet: *const Packet) ErrorPacket {
var reader = packet.reader();
const header = reader.readByte();
std.debug.assert(header == constants.ERR);

var error_packet: ErrorPacket = undefined;
error_packet.error_code = reader.readInt(u16);

// CLIENT_PROTOCOL_41
error_packet.sql_state_marker = reader.readByte();
error_packet.sql_state = reader.readRefComptime(5);

error_packet.error_message = reader.readRefRemaining();
return error_packet;
}
Expand Down Expand Up @@ -57,16 +66,9 @@ pub const OkPacket = struct {
ok_packet.affected_rows = reader.readLengthEncodedInteger();
ok_packet.last_insert_id = reader.readLengthEncodedInteger();

if (capabilities & constants.CLIENT_PROTOCOL_41 > 0) {
ok_packet.status_flags = reader.readInt(u16);
ok_packet.warnings = reader.readInt(u16);
} else if (capabilities & constants.CLIENT_TRANSACTIONS > 0) {
ok_packet.status_flags = reader.readInt(u16);
ok_packet.warnings = null;
} else {
ok_packet.status_flags = null;
ok_packet.warnings = null;
}
// CLIENT_PROTOCOL_41
ok_packet.status_flags = reader.readInt(u16);
ok_packet.warnings = reader.readInt(u16);

ok_packet.session_state_info = null;
if (capabilities & constants.CLIENT_SESSION_TRACK > 0) {
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/handshake_response.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const PacketWriter = @import("./packet_writer.zig").PacketWriter;

pub const HandshakeResponse41 = struct {
client_flag: u32, // capabilities
max_packet_size: u32 = 0,
max_packet_size: u32 = 0, // TODO: support configurable max packet size
character_set: u8,
username: [:0]const u8,
auth_response: []const u8,
Expand Down
4 changes: 2 additions & 2 deletions src/protocol/packet.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ pub const Packet = struct {
return .{ .sequence_id = sequence_id, .payload = payload };
}

pub fn asError(packet: *const Packet, capabilities: u32) error{ UnexpectedPacket, ErrorPacket } {
pub fn asError(packet: *const Packet) error{ UnexpectedPacket, ErrorPacket } {
if (packet.payload[0] == constants.ERR) {
return ErrorPacket.init(packet, capabilities).asError();
return ErrorPacket.init(packet).asError();
}
std.log.warn("unexpected packet: {any}", .{packet});
return error.UnexpectedPacket;
Expand Down
14 changes: 7 additions & 7 deletions src/result.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub const QueryResult = union(enum) {
pub fn init(packet: *const Packet, capabilities: u32) !QueryResult {
return switch (packet.payload[0]) {
constants.OK => .{ .ok = OkPacket.init(packet, capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(packet, capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(packet) },
constants.LOCAL_INFILE_REQUEST => _ = @panic("not implemented"),
else => {
std.log.warn(
Expand Down Expand Up @@ -64,9 +64,9 @@ pub fn QueryResultRows(comptime T: type) type {
\\Unexpected OkPacket: {any}\n,
\\If your query is not expecting a result set, use QueryResult instead.
, .{OkPacket.init(&packet, c.capabilities)});
return packet.asError(c.capabilities);
return packet.asError();
},
constants.ERR => .{ .err = ErrorPacket.init(&packet, c.capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(&packet) },
constants.LOCAL_INFILE_REQUEST => _ = @panic("not implemented"),
else => .{ .rows = try ResultSet(T).init(c, &packet) },
};
Expand Down Expand Up @@ -254,7 +254,7 @@ pub fn ResultRow(comptime T: type) type {
fn init(conn: *Conn, col_defs: []const ColumnDefinition41) !ResultRow(T) {
const packet = try conn.readPacket();
return switch (packet.payload[0]) {
constants.ERR => .{ .err = ErrorPacket.init(&packet, conn.capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(&packet) },
constants.EOF => .{ .ok = OkPacket.init(&packet, conn.capabilities) },
else => .{ .row = .{ .packet = packet, .col_defs = col_defs } },
};
Expand Down Expand Up @@ -299,7 +299,7 @@ fn collectAllRowsPacketUntilEof(conn: *Conn, allocator: std.mem.Allocator) !std.
while (true) {
const packet = try conn.readPacket();
return switch (packet.payload[0]) {
constants.ERR => ErrorPacket.init(&packet, conn.capabilities).asError(),
constants.ERR => ErrorPacket.init(&packet).asError(),
constants.EOF => {
_ = OkPacket.init(&packet, conn.capabilities);
return packet_list;
Expand All @@ -320,9 +320,9 @@ pub const PrepareResult = union(enum) {
pub fn init(c: *Conn, allocator: std.mem.Allocator) !PrepareResult {
const response_packet = try c.readPacket();
return switch (response_packet.payload[0]) {
constants.ERR => .{ .err = ErrorPacket.init(&response_packet, c.capabilities) },
constants.ERR => .{ .err = ErrorPacket.init(&response_packet) },
constants.OK => .{ .stmt = try PreparedStatement.init(&response_packet, c, allocator) },
else => return response_packet.asError(c.capabilities),
else => return response_packet.asError(),
};
}

Expand Down
Loading