Skip to content

Commit

Permalink
src: code cleanup for Location type and some tests. (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
avdva authored Sep 24, 2024
1 parent 1113ef5 commit 4785db8
Showing 1 changed file with 61 additions and 68 deletions.
129 changes: 61 additions & 68 deletions src/avl.zig
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,6 @@ fn makePtrLocationType(comptime K: type, comptime V: type, comptime Tags: type)
fn setParent(self: *Self, p: ?Self) void {
self.ptr.*.parent = p;
}

fn recalcHeight(self: *Self) bool {
var h: u8 = 0;
if (self.ptr.*.left) |l| {
h = 1 + l.ptr.*.data.h;
}
if (self.ptr.*.right) |r| {
h = @max(h, 1 + r.ptr.*.data.h);
}
return self.data().setHeight(h);
}

fn balance(self: *const Self) i8 {
var b: i8 = 0;
if (self.ptr.*.right) |right| {
b += 1 + @as(i8, @intCast(right.ptr.*.data.h));
}
if (self.ptr.*.left) |left| {
b -= 1 + @as(i8, @intCast(left.ptr.*.data.h));
}
return b;
}
};
}

Expand Down Expand Up @@ -174,6 +152,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
else
struct {};

const KeyType = K;
const ValueType = V;
const Cache = locationCache(K, V, Tags);
const Location = Cache.Location;
const Comparer = Cmp;
Expand Down Expand Up @@ -327,10 +307,10 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:

fn recalcCounts(loc: Location) void {
var count: u32 = 0;
if (loc.ptr.*.left) |left| {
if (loc.child(.left)) |left| {
count += 1 + left.data().tags.childrenCount;
}
if (loc.ptr.*.right) |right| {
if (loc.child(.right)) |right| {
count += 1 + right.data().tags.childrenCount;
}
loc.data().tags.childrenCount = count;
Expand All @@ -346,12 +326,34 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
}

fn leftCount(loc: Location) usize {
if (loc.ptr.*.left) |left| {
if (loc.child(.left)) |left| {
return 1 + left.data().tags.childrenCount;
}
return 0;
}

fn recalcHeight(loc: Location) bool {
var h: u8 = 0;
if (loc.child(.left)) |l| {
h = 1 + l.data().h;
}
if (loc.child(.right)) |r| {
h = @max(h, 1 + r.data().h);
}
return loc.data().setHeight(h);
}

fn balance(loc: Location) i8 {
var b: i8 = 0;
if (loc.child(.right)) |right| {
b += 1 + @as(i8, @intCast(right.data().h));
}
if (loc.child(.left)) |left| {
b -= 1 + @as(i8, @intCast(left.data().h));
}
return b;
}

pub const Iterator = struct {
tree: *Self,
loc: ?Location,
Expand Down Expand Up @@ -508,7 +510,7 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
} else if (where.dir == .right and l.eq(self.max.?)) {
self.max = new_loc;
}
if (l.recalcHeight()) {
if (recalcHeight(l)) {
if (options.countChildren) {
recalcCounts(l);
}
Expand Down Expand Up @@ -601,7 +603,7 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
if (right) |r| {
// Russell A. Brown, Optimized Deletion From an AVL Tree.
// https://arxiv.org/pdf/2406.05162v5
if (loc.balance() <= 0) {
if (balance(loc) <= 0) {
return goRight(l);
}
return goLeft(r);
Expand Down Expand Up @@ -735,18 +737,18 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
var mutLoc = loc;
while (true) {
var l = mutLoc orelse break;
const heightChanged = l.recalcHeight();
const heightChanged = recalcHeight(l);
const parent = l.parent();
switch (l.balance()) {
switch (balance(l)) {
-2 => {
switch (l.child(.left).?.balance()) {
switch (balance(l.child(.left).?)) {
-1, 0 => self.rr(l),
1 => self.lr(l),
else => unreachable,
}
},
2 => {
switch (l.child(.right).?.balance()) {
switch (balance(l.child(.right).?)) {
-1 => self.rl(l),
0, 1 => self.ll(l),
else => unreachable,
Expand Down Expand Up @@ -782,8 +784,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .left, left_right);
reparent(left, .right, l);

_ = l.recalcHeight();
_ = left.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(left);

if (options.countChildren) {
recalcCounts(l);
Expand All @@ -810,9 +812,9 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .left, left_right_right);
reparent(left, .right, left_right_left);

_ = l.recalcHeight();
_ = left.recalcHeight();
_ = left_right.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(left);
_ = recalcHeight(left_right);

if (options.countChildren) {
recalcCounts(l);
Expand Down Expand Up @@ -841,9 +843,9 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .right, right_left_left);
reparent(right, .left, right_left_right);

_ = l.recalcHeight();
_ = right.recalcHeight();
_ = right_left.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(right);
_ = recalcHeight(right_left);

if (options.countChildren) {
recalcCounts(l);
Expand All @@ -866,8 +868,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a:
reparent(l, .right, right_left);
reparent(right, .left, l);

_ = l.recalcHeight();
_ = right.recalcHeight();
_ = recalcHeight(l);
_ = recalcHeight(right);

if (options.countChildren) {
recalcCounts(l);
Expand Down Expand Up @@ -990,10 +992,7 @@ test "tree getOrEmplace" {
try std.testing.expect(ir.inserted);
try std.testing.expectEqual(i, ir.v.*);
try checkHeightAndBalance(
i64,
i64,
TreeType.Location,
TreeType.Comparer,
TreeType,
t.root,
);
i += 1;
Expand Down Expand Up @@ -1039,10 +1038,7 @@ test "tree insert" {
try std.testing.expectEqual(i, max.?.v.*);

try checkHeightAndBalance(
i64,
i64,
TreeType.Location,
TreeType.Comparer,
TreeType,
t.root,
);

Expand All @@ -1063,10 +1059,7 @@ test "tree insert" {
try std.testing.expect(!ir.inserted);
try std.testing.expectEqual(i * 2, ir.v.*);
try checkHeightAndBalance(
i64,
i64,
TreeType.Location,
TreeType.Comparer,
TreeType,
t.root,
);
i -= 1;
Expand All @@ -1092,15 +1085,15 @@ test "tree delete" {
try std.testing.expect(ir.inserted);
var exp: i64 = 0;
try std.testing.expectEqual(exp, t.delete(0).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);

ir = try t.insert(0, 0);
try std.testing.expect(ir.inserted);
ir = try t.insert(-1, -1);
try std.testing.expect(ir.inserted);
exp_len = 2;
try std.testing.expectEqual(exp_len, t.len());
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 0;
try std.testing.expectEqual(exp, t.delete(0).?);
exp = -1;
Expand All @@ -1114,13 +1107,13 @@ test "tree delete" {
try std.testing.expect(ir.inserted);
exp_len = 2;
try std.testing.expectEqual(exp_len, t.len());
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 1;
try std.testing.expectEqual(exp, t.delete(1).?);
exp_len = 1;
try std.testing.expectEqual(exp_len, t.len());
try std.testing.expectEqual(@as(?i64, null), t.delete(-1));
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 0;
try std.testing.expectEqual(exp, t.delete(0).?);
exp_len = 0;
Expand All @@ -1134,10 +1127,10 @@ test "tree delete" {
try std.testing.expectEqual(exp, t.delete(0).?);
exp_len = 1;
try std.testing.expectEqual(exp_len, t.len());
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp = 1;
try std.testing.expectEqual(exp, t.delete(1).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
exp_len = 0;
try std.testing.expectEqual(exp_len, t.len());

Expand All @@ -1151,7 +1144,7 @@ test "tree delete" {
i = 128;
while (i >= 0) {
try std.testing.expectEqual(i, t.delete(i).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
i -= 1;
}
}
Expand Down Expand Up @@ -1405,20 +1398,20 @@ test "tree random" {
const ir = try t.insert(val, val);
try std.testing.expect(ir.inserted);
try std.testing.expectEqual(val, ir.v.*);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
}
r.random().shuffle(i64, arr);
for (arr) |val| {
try std.testing.expectEqual(val, t.delete(val).?);
try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root);
try checkHeightAndBalance(TreeType, t.root);
}
try std.testing.expectEqual(exp_len, t.len());
i += 1;
}
}

fn checkHeightAndBalance(comptime K: type, comptime V: type, comptime L: type, comptime Cmp: fn (a: K, b: K) math.Order, loc: ?L) !void {
_ = try recalcHeightAndBalance(K, V, L, Cmp, loc);
fn checkHeightAndBalance(comptime T: type, loc: ?T.Location) !void {
_ = try recalcHeightAndBalance(T, loc);
}

const recalcResult = struct {
Expand All @@ -1435,21 +1428,21 @@ const recalcResult = struct {
}
};

fn recalcHeightAndBalance(comptime K: type, comptime V: type, comptime L: type, comptime Cmp: fn (a: K, b: K) math.Order, loc: ?L) !recalcResult {
fn recalcHeightAndBalance(comptime T: type, loc: ?T.Location) !recalcResult {
var result = recalcResult.init();
var l = loc orelse return result;
if (l.child(.left) != null) {
const lRes = try recalcHeightAndBalance(K, V, L, Cmp, l.child(.left));
const lRes = try recalcHeightAndBalance(T, l.child(.left));
result.height = 1 + lRes.height;
result.l_count = lRes.l_count + lRes.r_count + 1;
}
if (l.child(.right) != null) {
const rRes = try recalcHeightAndBalance(K, V, L, Cmp, l.child(.right));
const rRes = try recalcHeightAndBalance(T, l.child(.right));
result.height = @max(result.height, 1 + rRes.height);
result.r_count = rRes.r_count + rRes.l_count + 1;
}
try std.testing.expectEqual(result.height, l.data().h);
if (l.balance() < -1 or l.balance() > 1) {
if (T.balance(l) < -1 or T.balance(l) > 1) {
return error{
InvalidBalance,
}.InvalidBalance;
Expand Down

0 comments on commit 4785db8

Please sign in to comment.