Skip to content

Commit

Permalink
feat: authentication success for password
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Oct 13, 2023
1 parent 45adec5 commit e350532
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 45 deletions.
29 changes: 29 additions & 0 deletions src/auth_plugin.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
const std = @import("std");

pub const AuthPlugin = enum {
unspecified,
mysql_native_password,
sha256_password,
caching_sha2_password,
mysql_clear_password,
unknown,

pub fn fromName(name: []const u8) AuthPlugin {
if (std.mem.eql(u8, name, "mysql_native_password")) {
return .mysql_native_password;
} else if (std.mem.eql(u8, name, "sha256_password")) {
return .sha256_password;
} else if (std.mem.eql(u8, name, "caching_sha2_password")) {
return .caching_sha2_password;
} else if (std.mem.eql(u8, name, "mysql_clear_password")) {
return .mysql_clear_password;
} else {
return .unknown;
}
}
};

pub const caching_sha2_password_public_key_request = 0x01;
pub const caching_sha2_password_public_key_response = 0x02;
pub const caching_sha2_password_scramble_success = 0x03;
pub const caching_sha2_password_scramble_failure = 0x04;
85 changes: 50 additions & 35 deletions src/conn.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
const std = @import("std");
const Config = @import("./config.zig").Config;
const constants = @import("./constants.zig");
const auth_plugin = @import("./auth_plugin.zig");
const AuthPlugin = auth_plugin.AuthPlugin;
const protocol = @import("./protocol.zig");
const HandshakeV10 = protocol.handshake_v10.HandshakeV10;
const ErrorPacket = protocol.generic_response.ErrorPacket;
Expand Down Expand Up @@ -28,7 +30,7 @@ pub const Conn = struct {
reader: stream_buffered.Reader = undefined,
writer: stream_buffered.Writer = undefined,
server_capabilities: u32 = 0,
current_sequence_id: u8 = 0,
sequence_id: u8 = 0,

pub fn close(conn: *Conn) void {
switch (conn.state) {
Expand All @@ -52,13 +54,13 @@ pub const Conn = struct {
}

fn updateSequenceId(conn: *Conn, packet: Packet) !void {
std.debug.assert(packet.sequence_id == conn.current_sequence_id);
conn.current_sequence_id += 1;
std.debug.assert(packet.sequence_id == conn.sequence_id);
conn.sequence_id += 1;
}

fn generateSequenceId(conn: *Conn) u8 {
const id = conn.current_sequence_id;
conn.current_sequence_id += 1;
const id = conn.sequence_id;
conn.sequence_id += 1;
return id;
}

Expand All @@ -67,7 +69,7 @@ pub const Conn = struct {
try conn.dial(config.address);
errdefer conn.close();

var auth_plugin_name: FixedBytes(32) = .{};
var auth: AuthPlugin = undefined;
{
const packet = try conn.readPacket(allocator);
defer packet.deinit(allocator);
Expand All @@ -77,15 +79,17 @@ pub const Conn = struct {
else => return packet.asError(config.capability_flags()),
};
conn.server_capabilities = handshake_v10.capability_flags();
if (handshake_v10.auth_plugin_name) |p| {
try auth_plugin_name.set(p);
}
auth = handshake_v10.get_auth_plugin();

// TODO: TLS handshake if enabled

// send handshake response to server
if (conn.hasCapability(constants.CLIENT_PROTOCOL_41)) {
try conn.sendHandshakeResponse41(handshake_v10, config);
try conn.sendHandshakeResponse41(
auth,
&handshake_v10.get_auth_data(),
config,
);
} else {
// TODO: handle older protocol
@panic("not implemented");
Expand All @@ -103,20 +107,29 @@ pub const Conn = struct {
},
constants.AUTH_SWITCH => {
const auth_switch = AuthSwitchRequest.initFromPacket(&packet);
try auth_plugin_name.set(auth_switch.plugin_name);
auth = AuthPlugin.fromName(auth_switch.plugin_name);
try conn.sendAuthSwitchResponse(
auth_switch.plugin_name,
auth_switch.plugin_name,
auth,
auth_switch.plugin_data,
config,
);
},
constants.AUTH_MORE_DATA => {
// more auth exchange based on auth_method
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_authentication_methods.html
const more_data = packet.payload[1..];
try conn.sendAuthSwitchResponse(
auth_plugin_name.get(),
more_data,
config,
);
switch (auth) {
.caching_sha2_password => {
switch (more_data[0]) {
auth_plugin.caching_sha2_password_scramble_success => {},
auth_plugin.caching_sha2_password_scramble_failure => {
return error.NotImplemented;
},
else => return error.UnsupportedCachingSha2PasswordMoreData,
}
},
else => {},
}
},
else => return packet.asError(config.capability_flags()),
}
Expand All @@ -127,25 +140,25 @@ pub const Conn = struct {

fn sendAuthSwitchResponse(
conn: *Conn,
plugin_name: []const u8,
auth: AuthPlugin,
plugin_data: []const u8,
config: *const Config,
) !void {
var auth_response: FixedBytes(32) = .{};
try generate_auth_response(
plugin_name,
auth,
plugin_data,
config.password,
&auth_response,
);
try conn.sendAndFlushAsPacket(auth_response.get());
}

fn sendHandshakeResponse41(conn: *Conn, handshake_v10: HandshakeV10, config: *const Config) !void {
fn sendHandshakeResponse41(conn: *Conn, auth: AuthPlugin, auth_data: []const u8, config: *const Config) !void {
var auth_response: FixedBytes(32) = .{};
try generate_auth_response(
handshake_v10.get_auth_plugin_name(),
&handshake_v10.get_auth_data(),
auth,
auth_data,
config.password,
&auth_response,
);
Expand All @@ -167,7 +180,7 @@ pub const Conn = struct {
}

pub fn ping(conn: *Conn, allocator: std.mem.Allocator, config: *const Config) !void {
conn.current_sequence_id = 0;
conn.sequence_id = 0;
try conn.sendAndFlushAsPacket(&[_]u8{commands.COM_PING});
const packet = try conn.readPacket(allocator);
defer packet.deinit(allocator);
Expand Down Expand Up @@ -196,21 +209,23 @@ pub const Conn = struct {
};

fn generate_auth_response(
auth_plugin_name: []const u8,
auth: AuthPlugin,
auth_data: []const u8,
password: []const u8,
out: *FixedBytes(32),
) !void {
if (std.mem.eql(u8, auth_plugin_name, constants.caching_sha2_password)) {
if (password.len == 0) {
try out.set("");
} else {
try out.set(&scrambleSHA256Password(auth_data, password));
}
} else {
// TODO: support more
std.log.err("Unsupported auth plugin: |{s}|(contribution are welcome!)\n", .{auth_plugin_name});
return error.UnsupportedAuthPlugin;
switch (auth) {
.caching_sha2_password => {
if (password.len == 0) {
try out.set("");
} else {
try out.set(&scrambleSHA256Password(auth_data, password));
}
},
else => {
std.log.err("Unsupported auth plugin: {any}\n", .{auth_plugin});
return error.UnsupportedAuthPlugin;
},
}
}

Expand Down
6 changes: 0 additions & 6 deletions src/constants.zig
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,3 @@ pub const CLIENT_CAPABILITY_EXTENSION: u32 = 1 << 29;
pub const CLIENT_SSL_VERIFY_SERVER_CERT: u32 = 1 << 30;

pub const MAX_CAPABILITIES: u32 = std.math.maxInt(u32);

// plugin names
pub const mysql_native_password = "mysql_native_password";
pub const sha256_password = "sha256_password";
pub const caching_sha2_password = "caching_sha2_password";
pub const mysql_clear_password = "mysql_clear_password";
2 changes: 1 addition & 1 deletion src/protocol/handshake_response.zig
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub const HandshakeResponse41 = struct {
username: [:0]const u8,
auth_response: []const u8,
database: [:0]const u8,
client_plugin_name: [:0]const u8 = constants.caching_sha2_password,
client_plugin_name: [:0]const u8 = "caching_sha2_password",
key_values: []const [2][]const u8 = &.{},
zstd_compression_level: u8 = 0,

Expand Down
6 changes: 4 additions & 2 deletions src/protocol/handshake_v10.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const std = @import("std");
const Packet = @import("./packet.zig").Packet;
const PacketReader = @import("./packet_reader.zig").PacketReader;
const constants = @import("../constants.zig");
const AuthPlugin = @import("../auth_plugin.zig").AuthPlugin;

pub const HandshakeV10 = struct {
protocol_version: u8,
Expand Down Expand Up @@ -68,8 +69,9 @@ pub const HandshakeV10 = struct {
return f;
}

pub fn get_auth_plugin_name(h: *const HandshakeV10) []const u8 {
return h.auth_plugin_name orelse "mysql_native_password";
pub fn get_auth_plugin(h: *const HandshakeV10) AuthPlugin {
const name = h.auth_plugin_name orelse return .unspecified;
return AuthPlugin.fromName(name);
}

pub fn get_auth_data(h: *const HandshakeV10) [20]u8 {
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/packet.zig
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub const Packet = struct {
if (packet.payload[0] == constants.ERR) {
return ErrorPacket.initFromPacket(false, packet, capabilities).asError();
}
std.log.err("unexpected packet: {}", .{packet.payload[0]});
std.log.err("unexpected packet: {any}", .{packet.payload[0]});
return error.UnexpectedPacket;
}

Expand Down

0 comments on commit e350532

Please sign in to comment.