diff --git a/src/avl.zig b/src/avl.zig index 1e27a26..5b44b43 100644 --- a/src/avl.zig +++ b/src/avl.zig @@ -97,28 +97,6 @@ fn makePtrLocationType(comptime K: type, comptime V: type, comptime Tags: type) fn setParent(self: *Self, p: ?Self) void { self.ptr.*.parent = p; } - - fn recalcHeight(self: *Self) bool { - var h: u8 = 0; - if (self.ptr.*.left) |l| { - h = 1 + l.ptr.*.data.h; - } - if (self.ptr.*.right) |r| { - h = @max(h, 1 + r.ptr.*.data.h); - } - return self.data().setHeight(h); - } - - fn balance(self: *const Self) i8 { - var b: i8 = 0; - if (self.ptr.*.right) |right| { - b += 1 + @as(i8, @intCast(right.ptr.*.data.h)); - } - if (self.ptr.*.left) |left| { - b -= 1 + @as(i8, @intCast(left.ptr.*.data.h)); - } - return b; - } }; } @@ -174,6 +152,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: else struct {}; + const KeyType = K; + const ValueType = V; const Cache = locationCache(K, V, Tags); const Location = Cache.Location; const Comparer = Cmp; @@ -327,10 +307,10 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: fn recalcCounts(loc: Location) void { var count: u32 = 0; - if (loc.ptr.*.left) |left| { + if (loc.child(.left)) |left| { count += 1 + left.data().tags.childrenCount; } - if (loc.ptr.*.right) |right| { + if (loc.child(.right)) |right| { count += 1 + right.data().tags.childrenCount; } loc.data().tags.childrenCount = count; @@ -346,12 +326,34 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: } fn leftCount(loc: Location) usize { - if (loc.ptr.*.left) |left| { + if (loc.child(.left)) |left| { return 1 + left.data().tags.childrenCount; } return 0; } + fn recalcHeight(loc: Location) bool { + var h: u8 = 0; + if (loc.child(.left)) |l| { + h = 1 + l.data().h; + } + if (loc.child(.right)) |r| { + h = @max(h, 1 + r.data().h); + } + return loc.data().setHeight(h); + } + + fn balance(loc: Location) i8 { + var b: i8 = 0; + if (loc.child(.right)) |right| { + b += 1 + @as(i8, @intCast(right.data().h)); + } + if (loc.child(.left)) |left| { + b -= 1 + @as(i8, @intCast(left.data().h)); + } + return b; + } + pub const Iterator = struct { tree: *Self, loc: ?Location, @@ -508,7 +510,7 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: } else if (where.dir == .right and l.eq(self.max.?)) { self.max = new_loc; } - if (l.recalcHeight()) { + if (recalcHeight(l)) { if (options.countChildren) { recalcCounts(l); } @@ -601,7 +603,7 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: if (right) |r| { // Russell A. Brown, Optimized Deletion From an AVL Tree. // https://arxiv.org/pdf/2406.05162v5 - if (loc.balance() <= 0) { + if (balance(loc) <= 0) { return goRight(l); } return goLeft(r); @@ -735,18 +737,18 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: var mutLoc = loc; while (true) { var l = mutLoc orelse break; - const heightChanged = l.recalcHeight(); + const heightChanged = recalcHeight(l); const parent = l.parent(); - switch (l.balance()) { + switch (balance(l)) { -2 => { - switch (l.child(.left).?.balance()) { + switch (balance(l.child(.left).?)) { -1, 0 => self.rr(l), 1 => self.lr(l), else => unreachable, } }, 2 => { - switch (l.child(.right).?.balance()) { + switch (balance(l.child(.right).?)) { -1 => self.rl(l), 0, 1 => self.ll(l), else => unreachable, @@ -782,8 +784,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: reparent(l, .left, left_right); reparent(left, .right, l); - _ = l.recalcHeight(); - _ = left.recalcHeight(); + _ = recalcHeight(l); + _ = recalcHeight(left); if (options.countChildren) { recalcCounts(l); @@ -810,9 +812,9 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: reparent(l, .left, left_right_right); reparent(left, .right, left_right_left); - _ = l.recalcHeight(); - _ = left.recalcHeight(); - _ = left_right.recalcHeight(); + _ = recalcHeight(l); + _ = recalcHeight(left); + _ = recalcHeight(left_right); if (options.countChildren) { recalcCounts(l); @@ -841,9 +843,9 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: reparent(l, .right, right_left_left); reparent(right, .left, right_left_right); - _ = l.recalcHeight(); - _ = right.recalcHeight(); - _ = right_left.recalcHeight(); + _ = recalcHeight(l); + _ = recalcHeight(right); + _ = recalcHeight(right_left); if (options.countChildren) { recalcCounts(l); @@ -866,8 +868,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: reparent(l, .right, right_left); reparent(right, .left, l); - _ = l.recalcHeight(); - _ = right.recalcHeight(); + _ = recalcHeight(l); + _ = recalcHeight(right); if (options.countChildren) { recalcCounts(l); @@ -990,10 +992,7 @@ test "tree getOrEmplace" { try std.testing.expect(ir.inserted); try std.testing.expectEqual(i, ir.v.*); try checkHeightAndBalance( - i64, - i64, - TreeType.Location, - TreeType.Comparer, + TreeType, t.root, ); i += 1; @@ -1039,10 +1038,7 @@ test "tree insert" { try std.testing.expectEqual(i, max.?.v.*); try checkHeightAndBalance( - i64, - i64, - TreeType.Location, - TreeType.Comparer, + TreeType, t.root, ); @@ -1063,10 +1059,7 @@ test "tree insert" { try std.testing.expect(!ir.inserted); try std.testing.expectEqual(i * 2, ir.v.*); try checkHeightAndBalance( - i64, - i64, - TreeType.Location, - TreeType.Comparer, + TreeType, t.root, ); i -= 1; @@ -1092,7 +1085,7 @@ test "tree delete" { try std.testing.expect(ir.inserted); var exp: i64 = 0; try std.testing.expectEqual(exp, t.delete(0).?); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); ir = try t.insert(0, 0); try std.testing.expect(ir.inserted); @@ -1100,7 +1093,7 @@ test "tree delete" { try std.testing.expect(ir.inserted); exp_len = 2; try std.testing.expectEqual(exp_len, t.len()); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); exp = 0; try std.testing.expectEqual(exp, t.delete(0).?); exp = -1; @@ -1114,13 +1107,13 @@ test "tree delete" { try std.testing.expect(ir.inserted); exp_len = 2; try std.testing.expectEqual(exp_len, t.len()); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); exp = 1; try std.testing.expectEqual(exp, t.delete(1).?); exp_len = 1; try std.testing.expectEqual(exp_len, t.len()); try std.testing.expectEqual(@as(?i64, null), t.delete(-1)); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); exp = 0; try std.testing.expectEqual(exp, t.delete(0).?); exp_len = 0; @@ -1134,10 +1127,10 @@ test "tree delete" { try std.testing.expectEqual(exp, t.delete(0).?); exp_len = 1; try std.testing.expectEqual(exp_len, t.len()); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); exp = 1; try std.testing.expectEqual(exp, t.delete(1).?); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); exp_len = 0; try std.testing.expectEqual(exp_len, t.len()); @@ -1151,7 +1144,7 @@ test "tree delete" { i = 128; while (i >= 0) { try std.testing.expectEqual(i, t.delete(i).?); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); i -= 1; } } @@ -1405,20 +1398,20 @@ test "tree random" { const ir = try t.insert(val, val); try std.testing.expect(ir.inserted); try std.testing.expectEqual(val, ir.v.*); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); } r.random().shuffle(i64, arr); for (arr) |val| { try std.testing.expectEqual(val, t.delete(val).?); - try checkHeightAndBalance(i64, i64, TreeType.Location, TreeType.Comparer, t.root); + try checkHeightAndBalance(TreeType, t.root); } try std.testing.expectEqual(exp_len, t.len()); i += 1; } } -fn checkHeightAndBalance(comptime K: type, comptime V: type, comptime L: type, comptime Cmp: fn (a: K, b: K) math.Order, loc: ?L) !void { - _ = try recalcHeightAndBalance(K, V, L, Cmp, loc); +fn checkHeightAndBalance(comptime T: type, loc: ?T.Location) !void { + _ = try recalcHeightAndBalance(T, loc); } const recalcResult = struct { @@ -1435,21 +1428,21 @@ const recalcResult = struct { } }; -fn recalcHeightAndBalance(comptime K: type, comptime V: type, comptime L: type, comptime Cmp: fn (a: K, b: K) math.Order, loc: ?L) !recalcResult { +fn recalcHeightAndBalance(comptime T: type, loc: ?T.Location) !recalcResult { var result = recalcResult.init(); var l = loc orelse return result; if (l.child(.left) != null) { - const lRes = try recalcHeightAndBalance(K, V, L, Cmp, l.child(.left)); + const lRes = try recalcHeightAndBalance(T, l.child(.left)); result.height = 1 + lRes.height; result.l_count = lRes.l_count + lRes.r_count + 1; } if (l.child(.right) != null) { - const rRes = try recalcHeightAndBalance(K, V, L, Cmp, l.child(.right)); + const rRes = try recalcHeightAndBalance(T, l.child(.right)); result.height = @max(result.height, 1 + rRes.height); result.r_count = rRes.r_count + rRes.l_count + 1; } try std.testing.expectEqual(result.height, l.data().h); - if (l.balance() < -1 or l.balance() > 1) { + if (T.balance(l) < -1 or T.balance(l) > 1) { return error{ InvalidBalance, }.InvalidBalance;