Skip to content

Commit

Permalink
refactor: auth
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Nov 5, 2023
1 parent 5d1fe08 commit a5ee0b0
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 152 deletions.
4 changes: 3 additions & 1 deletion integration_tests/config.zig
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
const Config = @import("../src/config.zig").Config;

// TODO: use a config with password
pub const test_config: Config = .{};
pub const test_config: Config = .{
.password = "password",
};
69 changes: 67 additions & 2 deletions src/auth_plugin.zig → src/auth.zig
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
const std = @import("std");
const FixedBytes = @import("./utils.zig").FixedBytes;

const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n");

pub const AuthPlugin = enum {
Expand Down Expand Up @@ -43,7 +45,7 @@ pub const DecodedPublicKey = struct {
}
};

pub fn decode_public_key(encoded_bytes: []const u8, allocator: std.mem.Allocator) !DecodedPublicKey {
pub fn decodePublicKey(encoded_bytes: []const u8, allocator: std.mem.Allocator) !DecodedPublicKey {
var decoded_pk: DecodedPublicKey = undefined;

const start_marker = "-----BEGIN PUBLIC KEY-----";
Expand Down Expand Up @@ -101,6 +103,69 @@ test "decode public key" {
\\-----END PUBLIC KEY-----
;

const d = try decode_public_key(pk, std.testing.allocator);
const d = try decodePublicKey(pk, std.testing.allocator);
defer d.deinit(std.testing.allocator);
}

pub fn generate_auth_response(auth_plugin: AuthPlugin, auth_data: []const u8, password: []const u8) !FixedBytes(32) {
var result: FixedBytes(32) = .{};
switch (auth_plugin) {
.caching_sha2_password => if (password.len > 0) {
result.set(&scrambleSHA256Password(auth_data, password));
},
else => {
std.log.warn("Unsupported auth plugin: {any}\n", .{auth_plugin});
return error.UnsupportedAuthPlugin;
},
}
return result;
}

// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
fn scrambleSHA256Password(scramble: []const u8, password: []const u8) [32]u8 {
const Sha256 = std.crypto.hash.sha2.Sha256;

var message1 = blk: {
var hasher = Sha256.init(.{});
hasher.update(password);
break :blk hasher.finalResult();
};
const message2 = blk: {
var hasher = Sha256.init(.{});
hasher.update(&message1);
var temp = hasher.finalResult();

hasher = Sha256.init(.{});
hasher.update(&temp);
hasher.update(scramble);
hasher.final(&temp);
break :blk temp;
};
for (&message1, message2) |*m1, m2| {
m1.* ^= m2;
}
return message1;
}

test "scrambleSHA256Password" {
const scramble = [_]u8{ 10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21 };
const tests = [_]struct {
password: []const u8,
expected: [32]u8,
}{
.{
.password = "secret",
.expected = .{ 244, 144, 231, 111, 102, 217, 216, 102, 101, 206, 84, 217, 140, 120, 208, 172, 254, 47, 176, 176, 139, 66, 61, 168, 7, 20, 72, 115, 211, 11, 49, 44 },
},
.{
.password = "secret2",
.expected = .{ 171, 195, 147, 74, 1, 44, 243, 66, 232, 118, 7, 28, 142, 226, 2, 222, 81, 120, 91, 67, 2, 88, 167, 160, 19, 139, 199, 156, 77, 128, 11, 198 },
},
};

for (tests) |t| {
const actual = scrambleSHA256Password(&scramble, t.password);
// std.debug.print("actual: {x}", .{ std.fmt.fmtSliceHexLower(&actual) });
try std.testing.expectEqual(t.expected, actual);
}
}
185 changes: 46 additions & 139 deletions src/conn.zig
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
const std = @import("std");
const Config = @import("./config.zig").Config;
const constants = @import("./constants.zig");
const auth_plugin = @import("./auth_plugin.zig");
const AuthPlugin = auth_plugin.AuthPlugin;
const auth = @import("./auth.zig");
const generate_auth_response = auth.generate_auth_response;
const AuthPlugin = auth.AuthPlugin;
const protocol = @import("./protocol.zig");
const HandshakeV10 = protocol.handshake_v10.HandshakeV10;
const ErrorPacket = protocol.generic_response.ErrorPacket;
Expand Down Expand Up @@ -143,7 +144,8 @@ pub const Conn = struct {
conn.sequence_id = 0;
conn.client_capabilities = config.capability_flags();

var auth: AuthPlugin = undefined;
var auth_plugin: AuthPlugin = undefined;
var auth_data: [20]u8 = undefined;
{
const packet = try conn.readPacket(allocator);
defer packet.deinit(allocator);
Expand All @@ -153,17 +155,27 @@ pub const Conn = struct {
else => return packet.asError(conn.client_capabilities),
};
conn.server_capabilities = handshake_v10.capability_flags();
auth = handshake_v10.get_auth_plugin();

auth_plugin = handshake_v10.get_auth_plugin();
auth_data = handshake_v10.get_auth_data();

// TODO: TLS handshake if enabled

// send handshake response to server
if (conn.hasCapability(constants.CLIENT_PROTOCOL_41)) {
try conn.sendHandshakeResponse41(
auth,
const auth_resp = try generate_auth_response(
handshake_v10.get_auth_plugin(),
&handshake_v10.get_auth_data(),
config,
config.password,
);
const response: HandshakeResponse41 = .{
.database = config.database,
.client_flag = conn.client_capabilities,
.character_set = config.collation,
.username = config.username,
.auth_response = auth_resp.get(),
};
try conn.sendPacketUsingSmallPacketWriter(response);
} else {
// TODO: handle older protocol
@panic("not implemented");
Expand All @@ -175,42 +187,45 @@ pub const Conn = struct {
defer packet.deinit(allocator);

switch (packet.payload[0]) {
constants.OK => {
_ = OkPacket.initFromPacket(&packet, conn.client_capabilities);
return;
},
constants.OK => return,
constants.AUTH_SWITCH => {
const auth_switch = AuthSwitchRequest.initFromPacket(&packet);
auth = AuthPlugin.fromName(auth_switch.plugin_name);
try conn.sendAuthSwitchResponse(
auth,
auth_switch.plugin_data,
config,
);
auth_plugin = AuthPlugin.fromName(auth_switch.plugin_name);
const auth_resp = try generate_auth_response(auth_plugin, auth_switch.plugin_data, config.password);
try conn.sendBytesAsPacket(auth_resp.get());
},
constants.AUTH_MORE_DATA => {
// more auth exchange based on auth_method
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_authentication_methods.html
const more_data = packet.payload[1..];
switch (auth) {
switch (auth_plugin) {
.caching_sha2_password => {
switch (more_data[0]) {
auth_plugin.caching_sha2_password_fast_auth_success => {
// Fast auth success
},
auth_plugin.caching_sha2_password_full_authentication_start => {
auth.caching_sha2_password_fast_auth_success => return, // success (no more action needed)
auth.caching_sha2_password_full_authentication_start => {
// Full Authentication start

// TODO: Implement sending encrypted password with server's public key
// when we can parse, decrypt and ecrypt data with RSA
//
// try conn.sendAndFlushAsPacket(&[_]u8{auth_plugin.caching_sha2_password_public_key_request});
// const public_key_packet = try conn.readPacket(allocator);
// defer public_key_packet.deinit(allocator);
try conn.sendBytesAsPacket(&[_]u8{auth.caching_sha2_password_public_key_request});
const pk_packet = try conn.readPacket(allocator);
defer pk_packet.deinit(allocator);

const pub_key = try auth.decodePublicKey(pk_packet.payload, allocator);
defer pub_key.deinit(allocator);

// if TLS, send password as plain text
// try conn.sendAndFlushAsPacket(config.password);
return error.NotImplemented;
// TODO: support TLS
// // if TLS, send password as plain text
// try conn.sendBytesAsPacket(config.password);
const auth_resp = try generate_auth_response(.sha256_password, &auth_data, config.password);
try conn.sendBytesAsPacket(auth_resp.get());

const resp_packet = try conn.readPacket(allocator);
defer resp_packet.deinit(allocator);

switch (resp_packet.payload[0]) {
constants.OK => _ = OkPacket.initFromPacket(&resp_packet, conn.client_capabilities),
constants.ERR => return ErrorPacket.initFromPacket(false, &resp_packet, conn.client_capabilities).asError(),
else => return resp_packet.asError(conn.client_capabilities),
}
},
else => return error.UnsupportedCachingSha2PasswordMoreData,
}
Expand All @@ -225,44 +240,6 @@ pub const Conn = struct {
// Server ack
}

fn sendAuthSwitchResponse(
conn: *Conn,
auth: AuthPlugin,
plugin_data: []const u8,
config: *const Config,
) !void {
var auth_response: FixedBytes(32) = .{};
try generate_auth_response(
auth,
plugin_data,
config.password,
&auth_response,
);
try conn.sendBytesAsPacket(auth_response.get());
}

fn sendHandshakeResponse41(conn: *Conn, auth: AuthPlugin, auth_data: []const u8, config: *const Config) !void {
var auth_response: FixedBytes(32) = .{};
try generate_auth_response(
auth,
auth_data,
config.password,
&auth_response,
);
// TODO: support CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
// if (password_resp.len > 250) {
// resp_cap_flag |= constants.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA;
// }
const response: HandshakeResponse41 = .{
.database = config.database,
.client_flag = conn.client_capabilities,
.character_set = config.collation,
.username = config.username,
.auth_response = auth_response.get(),
};
try conn.sendPacketUsingSmallPacketWriter(response);
}

fn sendPacketUsingSmallPacketWriter(conn: *Conn, packet: anytype) !void {
std.debug.assert(conn.state == .connected);
var writer = conn.writer;
Expand Down Expand Up @@ -304,73 +281,3 @@ pub const Conn = struct {
return id;
}
};

fn generate_auth_response(
auth: AuthPlugin,
auth_data: []const u8,
password: []const u8,
out: *FixedBytes(32),
) !void {
switch (auth) {
.caching_sha2_password => {
if (password.len == 0) {
try out.set("");
} else {
try out.set(&scrambleSHA256Password(auth_data, password));
}
},
else => {
std.log.warn("Unsupported auth plugin: {any}\n", .{auth_plugin});
return error.UnsupportedAuthPlugin;
},
}
}

// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
fn scrambleSHA256Password(scramble: []const u8, password: []const u8) [32]u8 {
const Sha256 = std.crypto.hash.sha2.Sha256;

var message1 = blk: {
var hasher = Sha256.init(.{});
hasher.update(password);
break :blk hasher.finalResult();
};
const message2 = blk: {
var hasher = Sha256.init(.{});
hasher.update(&message1);
var temp = hasher.finalResult();

hasher = Sha256.init(.{});
hasher.update(&temp);
hasher.update(scramble);
hasher.final(&temp);
break :blk temp;
};
for (&message1, message2) |*m1, m2| {
m1.* ^= m2;
}
return message1;
}

test "scrambleSHA256Password" {
const scramble = [_]u8{ 10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21 };
const tests = [_]struct {
password: []const u8,
expected: [32]u8,
}{
.{
.password = "secret",
.expected = .{ 244, 144, 231, 111, 102, 217, 216, 102, 101, 206, 84, 217, 140, 120, 208, 172, 254, 47, 176, 176, 139, 66, 61, 168, 7, 20, 72, 115, 211, 11, 49, 44 },
},
.{
.password = "secret2",
.expected = .{ 171, 195, 147, 74, 1, 44, 243, 66, 232, 118, 7, 28, 142, 226, 2, 222, 81, 120, 91, 67, 2, 88, 167, 160, 19, 139, 199, 156, 77, 128, 11, 198 },
},
};

for (tests) |t| {
const actual = scrambleSHA256Password(&scramble, t.password);
// std.debug.print("actual: {x}", .{ std.fmt.fmtSliceHexLower(&actual) });
try std.testing.expectEqual(t.expected, actual);
}
}
2 changes: 1 addition & 1 deletion src/protocol/handshake_v10.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ const std = @import("std");
const Packet = @import("./packet.zig").Packet;
const PacketReader = @import("./packet_reader.zig").PacketReader;
const constants = @import("../constants.zig");
const AuthPlugin = @import("../auth_plugin.zig").AuthPlugin;
const AuthPlugin = @import("../auth.zig").AuthPlugin;

pub const HandshakeV10 = struct {
server_version: [:0]const u8,
Expand Down
17 changes: 8 additions & 9 deletions src/utils.zig
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
const std = @import("std");

// This is a fixed-size byte array to avoid heap allocation.
pub fn FixedBytes(comptime max: usize) type {
return struct {
buf: [max]u8 = undefined,
Expand All @@ -6,15 +9,11 @@ pub fn FixedBytes(comptime max: usize) type {
pub fn get(self: *const FixedBytes(max)) []const u8 {
return self.buf[0..self.len];
}
pub fn set(self: *FixedBytes(max), s: []const u8) !void {
if (s.len > max) {
return error.SourceTooLarge;
}
self.len = 0;
for (s) |c| {
self.buf[self.len] = c;
self.len += 1;
}

pub fn set(self: *FixedBytes(max), src: []const u8) void {
std.debug.assert(src.len <= max);
var dest = self.buf[0..src.len];
@memcpy(dest, src);
}
};
}

0 comments on commit a5ee0b0

Please sign in to comment.