diff --git a/src/avl.zig b/src/avl.zig index 426f64b..c7d8585 100644 --- a/src/avl.zig +++ b/src/avl.zig @@ -780,34 +780,57 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: } } + fn treeRotated(self: *Self, parent: ?Location, oldRoot: Location, newRoot: Location) void { + if (parent) |p| { + reparent(p, childDir(p, oldRoot), newRoot); + } else { + self.setRoot(newRoot); + } + } + fn checkBalance(self: *Self, loc: ?Location, all_way_up: bool) void { var mutLoc = loc; - while (mutLoc) |*l| { + while (mutLoc) |*mlPtr| { + const l = mlPtr.*; const parent = l.parent(); - switch (balance(l.*)) { + switch (balance(l)) { -2 => { - switch (balance(l.*.child(.left).?)) { - -1, 0 => self.rr(l.*), - 1 => self.lr(l.*), - else => unreachable, - } + const subRoot = blk: { + switch (balance(l.child(.left).?)) { + -1, 0 => { + break :blk rr(l); + }, + 1 => { + break :blk lr(l); + }, + else => unreachable, + } + }; + self.treeRotated(parent, l, subRoot); }, 2 => { - switch (balance(l.*.child(.right).?)) { - -1 => self.rl(l.*), - 0, 1 => self.ll(l.*), - else => unreachable, - } + const subRoot = blk: { + switch (balance(l.child(.right).?)) { + -1 => { + break :blk rl(l); + }, + 0, 1 => { + break :blk ll(l); + }, + else => unreachable, + } + }; + self.treeRotated(parent, l, subRoot); }, else => { - if (!recalcHeight(l.*) and !all_way_up) { + if (!recalcHeight(l) and !all_way_up) { if (options.countChildren) { - updateCounts(l.*); + updateCounts(l); } return; } if (options.countChildren) { - recalcCounts(l.*); + recalcCounts(l); } }, } @@ -815,16 +838,10 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: } } - fn rr(self: *Self, loc: Location) void { + fn rr(loc: Location) Location { var l = loc; - var left = l.child(.left) orelse unreachable; + var left = l.child(.left).?; const left_right = left.child(.right); - const parent = l.parent(); - if (parent) |p| { - reparent(parent, childDir(p, l), left); - } else { - self.setRoot(left); - } reparent(l, .left, left_right); reparent(left, .right, l); @@ -836,18 +853,14 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: recalcCounts(l); recalcCounts(left); } + + return left; } - fn lr(self: *Self, loc: Location) void { + fn lr(loc: Location) Location { var l = loc; - var left = l.child(.left) orelse unreachable; - var left_right = left.child(.right) orelse unreachable; - const parent = l.parent(); - if (parent) |p| { - reparent(parent, childDir(p, l), left_right); - } else { - self.setRoot(left_right); - } + var left = l.child(.left).?; + var left_right = left.child(.right).?; const left_right_right = left_right.child(.right); const left_right_left = left_right.child(.left); @@ -866,18 +879,14 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: recalcCounts(left); recalcCounts(left_right); } + + return left_right; } - fn rl(self: *Self, loc: Location) void { + fn rl(loc: Location) Location { var l = loc; - var right = l.child(.right) orelse unreachable; - var right_left = right.child(.left) orelse unreachable; - const parent = l.parent(); - if (parent) |p| { - reparent(parent, childDir(p, l), right_left); - } else { - self.setRoot(right_left); - } + var right = l.child(.right).?; + var right_left = right.child(.left).?; const right_left_left = right_left.child(.left); const right_left_right = right_left.child(.right); @@ -897,18 +906,14 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: recalcCounts(right); recalcCounts(right_left); } + + return right_left; } - fn ll(self: *Self, loc: Location) void { + fn ll(loc: Location) Location { var l = loc; - var right = l.child(.right) orelse unreachable; + var right = l.child(.right).?; const right_left = right.child(.left); - const parent = l.parent(); - if (parent) |p| { - reparent(parent, childDir(p, l), right); - } else { - self.setRoot(right); - } reparent(l, .right, right_left); reparent(right, .left, l); @@ -920,6 +925,8 @@ pub fn TreeWithOptions(comptime K: type, comptime V: type, comptime Cmp: fn (a: recalcCounts(l); recalcCounts(right); } + + return right; } fn locate(self: *Self, k: K) LocateResult {