Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

src: code cleanup for Location type and some tests. #13

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 61 additions & 68 deletions src/avl.zig
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,6 @@ fn makePtrLocationType(comptime K: type, comptime V: type, comptime Tags: type)
fn setParent(self: *Self, p: ?Self) void {
self.ptr.*.parent = p;
}

fn recalcHeight(self: *Self) bool {
var h: u8 = 0;
if (self.ptr.*.left) |l| {
h = 1 + l.ptr.*.data.h;
}
if (self.ptr.*.right) |r| {
h = @max(h, 1 + r.ptr.*.data.h);
}
return self.data().setHeight(h);
}

fn balance(self: *const Self) i8 {
var b: i8 = 0;
if (self.ptr.*.right) |right| {
b += 1 + @as(i8, @intCast(right.ptr.*.data.h));
}
if (self.ptr.*.left) |left| {
b -= 1 + @as(i8, @intCast(left.ptr.*.data.h));
}
return b;
}
};
}

Expand Down Expand Up @@ -174,6 +152,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
else
struct {};

const KeyType = K;
const ValueType = V;
const Cache = locationCache(K, V, Tags);
const Location = Cache.Location;
const Comparer = Cmp;
Expand Down Expand Up @@ -327,10 +307,10 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:

fn recalcCounts(loc: Location) void {
var count: u32 = 0;
if (loc.ptr.*.left) |left| {
if (loc.child(.left)) |left| {
count += 1 + left.data().tags.childrenCount;
}
if (loc.ptr.*.right) |right| {
if (loc.child(.right)) |right| {
count += 1 + right.data().tags.childrenCount;
}
loc.data().tags.childrenCount = count;
Expand All @@ -346,12 +326,34 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
}

fn leftCount(loc: Location) usize {
if (loc.ptr.*.left) |left| {
if (loc.child(.left)) |left| {
return 1 + left.data().tags.childrenCount;
}
return 0;
}

fn recalcHeight(loc: Location) bool {
var h: u8 = 0;
if (loc.child(.left)) |l| {
h = 1 + l.data().h;
}
if (loc.child(.right)) |r| {
h = @max(h, 1 + r.data().h);
}
return loc.data().setHeight(h);
}

fn balance(loc: Location) i8 {
var b: i8 = 0;
if (loc.child(.right)) |right| {
b += 1 + @as(i8, @intCast(right.data().h));
}
if (loc.child(.left)) |left| {
b -= 1 + @as(i8, @intCast(left.data().h));
}
return b;
}

pub const Iterator = struct {
tree: *Self,
loc: ?Location,
Expand Down Expand Up @@ -508,7 +510,7 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
} else if (where.dir == .right and l.eq(self.max.?)) {
self.max = new_loc;
}
if (l.recalcHeight()) {
if (recalcHeight(l)) {
if (options.countChildren) {
recalcCounts(l);
}
Expand Down Expand Up @@ -601,7 +603,7 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
if (right) |r| {
// Russell A. Brown, Optimized Deletion From an AVL Tree.
// https://arxiv.org/pdf/2406.05162v5
if (loc.balance() <= 0) {
if (balance(loc) <= 0) {
return goRight(l);
}
return goLeft(r);
Expand Down Expand Up @@ -735,18 +737,18 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
var mutLoc = loc;
while (true) {
var l = mutLoc orelse break;
const heightChanged = l.recalcHeight();
const heightChanged = recalcHeight(l);
const parent = l.parent();
switch (l.balance()) {
switch (balance(l)) {
-2 => {
switch (l.child(.left).?.balance()) {
switch (balance(l.child(.left).?)) {
-1, 0 => self.rr(l),
1 => self.lr(l),
else => unreachable,
}
},
2 => {
switch (l.child(.right).?.balance()) {
switch (balance(l.child(.right).?)) {
-1 => self.rl(l),
0, 1 => self.ll(l),
else => unreachable,
Expand Down Expand Up @@ -782,8 +784,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .left, left_right);
reparent(left, .right, l);

_ = l.recalcHeight();
_ = left.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(left);

if (options.countChildren) {
recalcCounts(l);
Expand All @@ -810,9 +812,9 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .left, left_right_right);
reparent(left, .right, left_right_left);

_ = l.recalcHeight();
_ = left.recalcHeight();
_ = left_right.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(left);
_ = recalcHeight(left_right);

if (options.countChildren) {
recalcCounts(l);
Expand Down Expand Up @@ -841,9 +843,9 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .right, right_left_left);
reparent(right, .left, right_left_right);

_ = l.recalcHeight();
_ = right.recalcHeight();
_ = right_left.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(right);
_ = recalcHeight(right_left);

if (options.countChildren) {
recalcCounts(l);
Expand All @@ -866,8 +868,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .right, right_left);
reparent(right, .left, l);

_ = l.recalcHeight();
_ = right.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(right);

if (options.countChildren) {
recalcCounts(l);
Expand Down Expand Up @@ -990,10 +992,7 @@ test "tree getOrEmplace" {
try std.testing.expect(ir.inserted);
try std.testing.expectEqual(i, ir.v.*);
try checkHeightAndBalance(
i64,
i64,
TreeType.Location,
TreeType.Comparer,
TreeType,
t.root,
);
i += 1;
Expand Down Expand Up @@ -1039,10 +1038,7 @@ test "tree insert" {
try std.testing.expectEqual(i, max.?.v.*);

try checkHeightAndBalance(
i64,
i64,
TreeType.Location,
TreeType.Comparer,
TreeType,
t.root,
);

Expand All @@ -1063,10 +1059,7 @@ test "tree insert" {
try std.testing.expect(!ir.inserted);
try std.testing.expectEqual(i * 2, ir.v.*);
try checkHeightAndBalance(
i64,
i64,
TreeType.Location,
TreeType.Comparer,
TreeType,
t.root,
);
i -= 1;
Expand All @@ -1092,15 +1085,15 @@ test "tree delete" {
try std.testing.expect(ir.inserted);
var exp: i64 = 0;
try std.testing.expectEqual(exp, t.delete(0).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);

ir = try t.insert(0, 0);
try std.testing.expect(ir.inserted);
ir = try t.insert(-1, -1);
try std.testing.expect(ir.inserted);
exp_len = 2;
try std.testing.expectEqual(exp_len, t.len());
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 0;
try std.testing.expectEqual(exp, t.delete(0).?);
exp = -1;
Expand All @@ -1114,13 +1107,13 @@ test "tree delete" {
try std.testing.expect(ir.inserted);
exp_len = 2;
try std.testing.expectEqual(exp_len, t.len());
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 1;
try std.testing.expectEqual(exp, t.delete(1).?);
exp_len = 1;
try std.testing.expectEqual(exp_len, t.len());
try std.testing.expectEqual(@as(?i64, null), t.delete(-1));
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 0;
try std.testing.expectEqual(exp, t.delete(0).?);
exp_len = 0;
Expand All @@ -1134,10 +1127,10 @@ test "tree delete" {
try std.testing.expectEqual(exp, t.delete(0).?);
exp_len = 1;
try std.testing.expectEqual(exp_len, t.len());
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 1;
try std.testing.expectEqual(exp, t.delete(1).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp_len = 0;
try std.testing.expectEqual(exp_len, t.len());

Expand All @@ -1151,7 +1144,7 @@ test "tree delete" {
i = 128;
while (i >= 0) {
try std.testing.expectEqual(i, t.delete(i).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
i -= 1;
}
}
Expand Down Expand Up @@ -1405,20 +1398,20 @@ test "tree random" {
const ir = try t.insert(val, val);
try std.testing.expect(ir.inserted);
try std.testing.expectEqual(val, ir.v.*);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
}
r.random().shuffle(i64, arr);
for (arr) |val| {
try std.testing.expectEqual(val, t.delete(val).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
}
try std.testing.expectEqual(exp_len, t.len());
i += 1;
}
}

fn checkHeightAndBalance(comptime K: type, comptime V: type, comptime L: type, comptime Cmp: fn (a: K, b: K) math.Order, loc: ?L) !void {
_ = try recalcHeightAndBalance(K, V, L, Cmp, loc);
fn checkHeightAndBalance(comptime T: type, loc: ?T.Location) !void {
_ = try recalcHeightAndBalance(T, loc);
}

const recalcResult = struct {
Expand All @@ -1435,21 +1428,21 @@ const recalcResult = struct {
}
};

fn recalcHeightAndBalance(comptime K: type, comptime V: type, comptime L: type, comptime Cmp: fn (a: K, b: K) math.Order, loc: ?L) !recalcResult {
fn recalcHeightAndBalance(comptime T: type, loc: ?T.Location) !recalcResult {
var result = recalcResult.init();
var l = loc orelse return result;
if (l.child(.left) != null) {
const lRes = try recalcHeightAndBalance(K, V, L, Cmp, l.child(.left));
const lRes = try recalcHeightAndBalance(T, l.child(.left));
result.height = 1 + lRes.height;
result.l_count = lRes.l_count + lRes.r_count + 1;
}
if (l.child(.right) != null) {
const rRes = try recalcHeightAndBalance(K, V, L, Cmp, l.child(.right));
const rRes = try recalcHeightAndBalance(T, l.child(.right));
result.height = @max(result.height, 1 + rRes.height);
result.r_count = rRes.r_count + rRes.l_count + 1;
}
try std.testing.expectEqual(result.height, l.data().h);
if (l.balance() < -1 or l.balance() > 1) {
if (T.balance(l) < -1 or T.balance(l) > 1) {
return error{
InvalidBalance,
}.InvalidBalance;
Expand Down
Loading