diff --git a/vyper/builtin_functions/functions.py b/vyper/builtin_functions/functions.py index 183694d57c..1b7c68ca23 100644 --- a/vyper/builtin_functions/functions.py +++ b/vyper/builtin_functions/functions.py @@ -2199,69 +2199,47 @@ class ISqrt(BuiltinFunction): @process_inputs def build_IR(self, expr, args, kwargs, context): - # TODO check out this import - from vyper.builtin_functions.utils import generate_inline_function + # calculate isqrt using the babylonian method + y, z = "y", "z" arg = args[0] - sqrt_code = """ -y: uint256 = x -z: uint256 = 181 -if y >= 2**(128 + 8): - y = unsafe_div(y, 2**128) - z = unsafe_mul(z, 2**64) -if y >= 2**(64 + 8): - y = unsafe_div(y, 2**64) - z = unsafe_mul(z, 2**32) -if y >= 2**(32 + 8): - y = unsafe_div(y, 2**32) - z = unsafe_mul(z, 2**16) -if y >= 2**(16 + 8): - y = unsafe_div(y, 2**16) - z = unsafe_mul(z, 2**8) - -z = unsafe_div(unsafe_mul(z, unsafe_add(y, 65536)), 2**18) - -z = unsafe_div(unsafe_add(unsafe_div(x, z), z), 2) -z = unsafe_div(unsafe_add(unsafe_div(x, z), z), 2) -z = unsafe_div(unsafe_add(unsafe_div(x, z), z), 2) -z = unsafe_div(unsafe_add(unsafe_div(x, z), z), 2) -z = unsafe_div(unsafe_add(unsafe_div(x, z), z), 2) -z = unsafe_div(unsafe_add(unsafe_div(x, z), z), 2) -z = unsafe_div(unsafe_add(unsafe_div(x, z), z), 2) - -# Performance note: If ``x+1`` is a perfect square, then the Babylonian -# algorithm oscillates between floor(sqrt(x)) and ceil(sqrt(x)) in -# consecutive iterations. ``isqrt`` has a final check that returns -# the floor value always, but this increases costs by approximately 10% : - -z = min(z, unsafe_div(x, z)) - """ + with arg.cache_when_complex("x") as (b1, x): + ret = [ + "seq", + [ + "if", + ["ge", y, 2 ** (128 + 8)], + ["seq", ["set", y, shr(128, y)], ["set", z, shl(64, z)]], + ], + [ + "if", + ["ge", y, 2 ** (64 + 8)], + ["seq", ["set", y, shr(64, y)], ["set", z, shl(32, z)]], + ], + [ + "if", + ["ge", y, 2 ** (32 + 8)], + ["seq", ["set", y, shr(32, y)], ["set", z, shl(16, z)]], + ], + [ + "if", + ["ge", y, 2 ** (16 + 8)], + ["seq", ["set", y, shr(16, y)], ["set", z, shl(8, z)]], + ], + ] + ret.append(["set", z, ["div", ["mul", z, ["add", y, 2 ** 16]], 2 ** 18]]) - x_type = BaseType("uint256") - placeholder_copy = ["pass"] - # Steal current position if variable is already allocated. - if arg.value == "mload": - new_var_pos = arg.args[0] - # Other locations need to be copied. - else: - new_var_pos = context.new_internal_variable(x_type) - placeholder_copy = ["mstore", new_var_pos, arg] - # Create input variables. - variables = {"x": VariableRecord(name="x", pos=new_var_pos, typ=x_type, mutable=False)} - # Dictionary to update new (i.e. typecheck) namespace - variables_2 = {"x": Uint256Definition()} - # Generate inline IR. - new_ctx, sqrt_ir = generate_inline_function( - code=sqrt_code, - variables=variables, - variables_2=variables_2, - memory_allocator=context.memory_allocator, - ) - return IRnode.from_list( - ["seq", placeholder_copy, sqrt_ir, new_ctx.vars["z"].pos], # load x variable - typ=BaseType("uint256"), - location=MEMORY, - ) + for _ in range(7): + ret.append(["set", z, ["div", ["add", ["div", x, z], z], 2]]) + + # note: If ``x+1`` is a perfect square, then the Babylonian + # algorithm oscillates between floor(sqrt(x)) and ceil(sqrt(x)) in + # consecutive iterations. return the floor value always. + + ret.append(["with", "t", ["div", x, z], ["select", ["lt", z, "t"], z, "t"]]) + + ret = ["with", y, x, ["with", z, 181, ret]] + return b1.resolve(IRnode.from_list(ret, typ=BaseType("uint256"))) class Empty(BuiltinFunction):