From 444870548f1008b9c6b42723ae8fdcb367acbf97 Mon Sep 17 00:00:00 2001 From: vub Date: Mon, 23 Jan 2017 00:50:57 -0500 Subject: [PATCH] Added basic units support --- README.md | 18 +- parser.py | 594 ++++++++++++++++++++++++++++++----------------- test_invalids.py | 169 ++++++++++++++ test_parser.py | 32 +-- 4 files changed, 581 insertions(+), 232 deletions(-) diff --git a/README.md b/README.md index 178911911a..dbed9704fb 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ Viper is an experimental programming language that aims to provide the following * Bounds and overflow checking, both on array accesses and on arithmetic * Support for signed integers and decimal fixed point numbers * Decidability - it's possible to compute a precise upper bound on the gas consumption of any function call -* Strong typing +* Strong typing, including limited support for units (eg. timestamp, timedelta) * Maximally small and understandable compiler code size * Limited support for pure functions - anything marked constant is NOT allowed to change the state @@ -63,6 +63,10 @@ Note that not all programs that satisfy the following are valid; for example, th * `num`: a signed integer strictly between -2\*\*128 and 2\*\*128 * `decimal`: a decimal fixed point value with the integer component being a signed integer strictly between -2\*\*128 and 2\*\*128 and the fractional component being ten decimal places +* `timestamp`: a timestamp value +* `timedelta`: a number of seconds (note: two timedeltas can be added together, as can a timedelta and a timestamp, but not two timestamps) +* `wei_value`: an amount of wei +* `currency_value`: an amount of currency * `address`: an address * `bytes32`: 32 bytes * `bool`: true or false @@ -87,20 +91,20 @@ Code examples can be found in the `test_parser.py` file. * A mini-language for handling num256 and signed256 values and directly / unsafely using opcodes; will be useful for high-performance code segments * Support for sha3, sha256, ecrecover, etc * Smart optimizations, including compile-time computation of arithmetic and clamps, intelligently computing realistic variable ranges, etc -* Special data types for timestamps, timedeltas, currency, etc; support for "units" +* More advanced support for units, including support for "x per y", "x * y", etc types ### Code example - funders = {sender: address, value: num}[num] + funders = {sender: address, value: wei_value}[num] nextFunderIndex = num beneficiary = address - deadline = num - goal = num + deadline = timestamp + goal = wei_value refundIndex = num - timelimit = num + timelimit = timedelta # Setup global variables - def __init__(_beneficiary: address, _goal: num, _timelimit: num): + def __init__(_beneficiary: address, _goal: wei_value, _timelimit: timedelta): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit self.timelimit = _timelimit diff --git a/parser.py b/parser.py index 392db3deb9..f6e32f1e3b 100644 --- a/parser.py +++ b/parser.py @@ -8,6 +8,7 @@ import ast, tokenize, binascii from io import BytesIO from opcodes import opcodes, pseudo_opcodes +import copy try: x = ast.AnnAssign @@ -35,13 +36,107 @@ def hex_to_int(inp): o = o * 256 + b return o +# Available base types +base_types = ['num', 'decimal', 'bytes32', 'num256', 'signed256', 'bool', 'address'] + +# Available types that functions can have as outputs +allowed_func_output_types = ['num', 'bool', 'num256', 'signed256', 'address'] + + +# Data structure for a type +class NodeType(): + pass + +def print_unit(unit): + if unit is None: + return '*' + pos = '' + for k in sorted([x for x in unit.keys() if unit[x] > 0]): + if unit[k] > 1: + pos += '*' + k + '^' + str(unit[k]) + else: + pos += '*' + k + neg = '' + for k in sorted([x for x in unit.keys() if unit[x] < 0]): + if unit[k] < -1: + neg += '/' + k + '^' + str(-unit[k]) + else: + neg += '/' + k + if pos and neg: + return pos[1:] + neg + elif neg: + return '1' + neg + else: + return pos[1:] + +def combine_units(unit1, unit2, div=True): + o = {k: v for k, v in (unit1 or {}).items()} + for k, v in (unit2 or {}).items(): + o[k] = o.get(k, 0) + v * (-1 if div else 1) + return o + +class BaseType(NodeType): + def __init__(self, typ, unit=False, positional=False): + self.typ = typ + self.unit = {} if unit is False else unit + self.positional = positional + + def __eq__(self, other): + return other.__class__ == BaseType and self.typ == other.typ and self.unit == other.unit and self.positional == other.positional + + def __repr__(self): + return '<' + str(self.typ) + ('>' if self.unit == {} else '> (' + print_unit(self.unit) + ')') + (' (positional) ' * self.positional) + +class ListType(NodeType): + def __init__(self, subtype, count): + self.subtype = subtype + self.count = count + + def __eq__(self, other): + return other.__class__ == ListType and other.subtype == self.subtype and other.count == self.count + + def __repr__(self): + return repr(self.subtype) + '[' + str(self.count) + ']' + +class MappingType(NodeType): + def __init__(self, keytype, valuetype): + if not isinstance(keytype, BaseType): + raise Exception("Dictionary keys must be a base type") + self.keytype = keytype + self.valuetype = valuetype + + def __eq__(self, other): + return other.__class__ == MappingType and other.keytype == self.keytype and other.valuetype == self.valuetype + + def __repr__(self): + return repr(self.valuetype) + '[' + repr(self.keytype) + ']' + +class StructType(NodeType): + def __init__(self, members): + self.members = copy.copy(members) + + def __eq__(self, other): + return other.__class__ == StructType and other.members == self.members + + def __repr__(self): + return '{' + ', '.join([k + ': ' + repr(v) for k, v in self.members.items()]) + '}' + +class MixedType(NodeType): + def __eq__(self, other): + return other.__class__ == MixedType + +class NullType(NodeType): + def __eq__(self, other): + return other.__class__ == NullType + # Data structure for LLL parse tree class LLLnode(): - def __init__(self, value, args=[], typ=None, annotation=None): + def __init__(self, value, args=[], typ=None, location=None): self.value = value self.args = args self.typ = typ - self.annotation = annotation + assert isinstance(self.typ, NodeType) or self.typ is None, repr(self.typ) + self.location = location # Determine this node's valency (1 if it pushes a value on the stack, # 0 otherwise) and checks to make sure the number and valencies of # children are correct @@ -102,7 +197,7 @@ def __init__(self, value, args=[], typ=None, annotation=None): # Variables else: self.valency = 1 - elif self.value is None and self.typ == 'null': + elif self.value is None and isinstance(self.typ, NullType): self.valency = 1 else: raise Exception("Invalid value for LLL AST node: %r" % self.value) @@ -125,22 +220,15 @@ def __repr__(self): return self.repr() @classmethod - def from_list(cls, obj, typ=None, annotation=None): + def from_list(cls, obj, typ=None, location=None): + if isinstance(typ, str): + typ = BaseType(typ) if isinstance(obj, LLLnode): return obj elif not isinstance(obj, list): - return cls(obj, [], typ, annotation) + return cls(obj, [], typ, location) else: - return cls(obj[0], [cls.from_list(o) for o in obj[1:]], typ, annotation) - -# Available base types -types = ['num', 'decimal', 'bytes32', 'num256', 'signed256', 'bool', 'address'] - -# Available types that functions can have as inputs -allowed_func_input_types = ['num', 'bool', 'num256', 'signed256', 'address'] - -# Available types that functions can have as outputs -allowed_func_output_types = ['num', 'bool', 'num256', 'signed256', 'address'] + return cls(obj[0], [cls.from_list(o) for o in obj[1:]], typ, location) # A decimal value can store multiples of 1/DECIMAL_DIVISOR DECIMAL_DIVISOR = 10000000000 @@ -155,6 +243,9 @@ def from_list(cls, obj, typ=None, annotation=None): # Convert type into common form used in ABI def canonicalize_type(t): + if not isinstance(t, BaseType): + raise Exception("Cannot canonicalize non-base type: %r" % t) + t = t.typ if t == 'num': return 'int128' elif t == 'bool': @@ -179,7 +270,7 @@ def canonicalize_type(t): # Is a variable name valid? def is_varname_valid(varname): - if varname.lower() in types: + if varname.lower() in base_types: return False if varname.lower() in reserved_words: return False @@ -204,14 +295,30 @@ class StructureException(Exception): class ConstancyViolationException(Exception): pass + +# Special types +special_types = { + 'timestamp': BaseType('num', {'sec': 1}, True), + 'timedelta': BaseType('num', {'sec': 1}, False), + 'currency_value': BaseType('decimal', {'currency': 1}, False), + 'currency1_value': BaseType('decimal', {'currency1': 1}, False), + 'currency2_value': BaseType('decimal', {'currency2': 1}, False), + 'wei_value': BaseType('num', {'wei': 1}, False), +} + +valid_units = ['currency', 'wei', 'currency1', 'currency2', 'sec'] + # Parses an expression representing a type. Annotation refers to whether # the type is to be located in memory or storage -def parse_type(item, annotation): +def parse_type(item, location): # Base types, eg. uint if isinstance(item, ast.Name): - if item.id not in types: + if item.id in base_types: + return BaseType(item.id) + elif item.id in special_types: + return special_types[item.id] + else: raise InvalidTypeException("Invalid type: "+item.id) - return item.id # Subscripts elif isinstance(item, ast.Subscript): if 'value' not in vars(item.slice): @@ -220,12 +327,15 @@ def parse_type(item, annotation): elif isinstance(item.slice.value, ast.Num): if not isinstance(item.slice.value.n, int) or item.slice.value.n <= 0: raise InvalidTypeException("Arrays must have a positive integral number of elements") - return [parse_type(item.value, annotation), item.slice.value.n] + return ListType(parse_type(item.value, location), item.slice.value.n) # Mappings, eg. num[address] - elif isinstance(item.slice.value, ast.Name) and item.slice.value.id in types: - if annotation == 'memory': + elif isinstance(item.slice.value, ast.Name): + if location == 'memory': raise InvalidTypeException("No mappings allowed for in-memory types, only fixed-size arrays") - return [parse_type(item.value, annotation), item.slice.value.id] + keytype = parse_type(item.slice.value, None) + if not isinstance(keytype, BaseType): + raise Exception("Mapping keys must be base types") + return MappingType(keytype, parse_type(item.value, location)) else: raise InvalidTypeException("Arrays must be of the format type[num_of_elements] or type[key_type]") # Dicts, used to represent mappings, eg. {uint: uint}. Key must be a base type @@ -234,22 +344,23 @@ def parse_type(item, annotation): for key, value in zip(item.keys, item.values): if not isinstance(key, ast.Name) or not is_varname_valid(key.id): raise InvalidTypeException("Invalid member variable for struct: %r" % vars(key).get('id', key)) - o[key.id] = parse_type(value, annotation) - return o + o[key.id] = parse_type(value, location) + return StructType(o) else: raise InvalidTypeException("Invalid type: %r" % ast.dump(item)) # Gets the number of memory or storage keys needed to represent a given type def get_size_of_type(typ): - if not isinstance(typ, (list, dict)): + if isinstance(typ, BaseType): return 1 - if isinstance(typ, list): - if isinstance(typ[1], int): - return get_size_of_type(typ[0]) * typ[1] - else: - raise Exception("Type size infinite!") - elif isinstance(typ, dict): - return sum([get_size_of_type(v) for v in typ.values()]) + if isinstance(typ, ListType): + return get_size_of_type(typ.subtype) * typ.count + elif isinstance(typ, MappingType): + raise Exception("Type size infinite!") + elif isinstance(typ, StructType): + return sum([get_size_of_type(v) for v in typ.members.values()]) + else: + raise Exception("Unexpected type: %r" % repr(typ)) # Parse top-level functions and variables def get_defs_and_globals(code): @@ -279,7 +390,7 @@ def mk_initial(): ['mstore', MINNUM_POS, -2**128 + 1], ['mstore', MAXDECIMAL_POS, (2**128 - 1) * DECIMAL_DIVISOR], ['mstore', MINDECIMAL_POS, (-2**128 + 1) * DECIMAL_DIVISOR], - ], typ='null') + ], typ=None) # Get function details def get_func_details(code): @@ -292,18 +403,14 @@ def get_func_details(code): raise VariableDeclarationException("Argument name invalid") if not typ: raise InvalidTypeException("Argument must have type") - if not isinstance(typ, ast.Name) or typ.id not in allowed_func_input_types: - raise InvalidTypeException("Argument type invalid or unsupported") if not is_varname_valid(arg.arg): raise VariableDeclarationException("Argument name invalid or reserved: "+arg.arg) if arg.arg in (x[0] for x in args): raise VariableDeclarationException("Duplicate function argument name: "+arg.arg) if name == '__init__': - args.append((arg.arg, -32 * len(code.args.args) + 32 * len(args), typ.id)) + args.append((arg.arg, -32 * len(code.args.args) + 32 * len(args), parse_type(typ, None))) else: - args.append((arg.arg, 4 + 32 * len(args), typ.id)) - if typ.id not in allowed_func_input_types: - raise InvalidTypeException("Function input type invalid or unsupported: %r" % typ.id) + args.append((arg.arg, 4 + 32 * len(args), parse_type(typ, None))) # Determine the return type and whether or not it's constant. Expects something # of the form: # def foo(): ... @@ -312,17 +419,19 @@ def get_func_details(code): const = False if not code.returns: output_type = None - elif isinstance(code.returns, ast.Name) and code.returns.id in allowed_func_output_types: - output_type = code.returns.id + elif isinstance(code.returns, ast.Name): + output_type = parse_type(code.returns, None) elif isinstance(code.returns, ast.Call) and isinstance(code.returns.func, ast.Name) and \ - code.returns.func.id in allowed_func_output_types and len(code.returns.args) == 1 and \ - isinstance(code.returns.args[0], ast.Name) and code.returns.args[0].id == 'const': - output_type = code.returns.func.id + len(code.returns.args) == 1 and isinstance(code.returns.args[0], ast.Name) and \ + code.returns.args[0].id == 'const': + output_type = parse_type(code.returns.func, None) const = True else: - raise InvalidTypeException("Output type invalid or unsupported: %r" % code.returns) + raise InvalidTypeException("Output type invalid or unsupported: %r" % parse_type(code.returns, None)) + # Output type can only be base type or none + assert isinstance(output_type, (BaseType, (None).__class__)) # Get the four-byte method id - sig = name + '(' + ','.join([canonicalize_type(arg.annotation.id) for arg in code.args.args]) + ')' + sig = name + '(' + ','.join([canonicalize_type(parse_type(arg.annotation, None)) for arg in code.args.args]) + ')' method_id = fourbytes_to_int(sha3_256(bytes(sig, 'utf-8'))[:4]) return name, args, output_type, const, sig, method_id @@ -363,7 +472,7 @@ def parse_func(code, _globals, _vars=None): return LLLnode.from_list(['if', ['eq', ['mload', 0], method_id], ['seq'] + [parse_body(c, context) for c in code.body] - ], typ='null') + ], typ=None) # Get ABI signature def mk_full_signature(code): @@ -411,55 +520,64 @@ def parse_body(code, context): # Take a value representing a storage location, and descend down to an element or member variable def add_variable_offset(parent, key): - typ, annotation = parent.typ, parent.annotation - if isinstance(typ, dict): + typ, location = parent.typ, parent.location + if isinstance(typ, StructType): if not isinstance(key, str): raise TypeMismatchException("Expecting a member variable access; cannot access element %r" % key) - if key not in typ: + if key not in typ.members: raise TypeMismatchException("Object does not have member variable %s" % key) - subtype = typ[key] - attrs = sorted(typ.keys()) + subtype = typ.members[key] + attrs = sorted(typ.members.keys()) if key not in attrs: raise TypeMismatchException("Member %s not found. Only the following available: %s" % (expr.attr, " ".join(attrs))) index = attrs.index(key) - if annotation == 'storage': + if location == 'storage': return LLLnode.from_list(['add', ['sha3_32', parent], index], typ=subtype, - annotation='storage') - elif annotation == 'memory': + location='storage') + elif location == 'memory': offset = 0 for i in range(index): - offset += 32 * get_size_of_type(typ[attrs[i]]) + offset += 32 * get_size_of_type(typ.members[attrs[i]]) return LLLnode.from_list(['add', offset, parent], - typ=typ[key], - annotation='memory') + typ=typ.members[key], + location='memory') else: raise TypeMismatchException("Not expecting a member variable access") - elif isinstance(typ, list): - subtype = typ[0] - if isinstance(typ[1], int): - length, expected_index_type = typ[1], 'num' - sub = ['uclamplt', base_type_conversion(key, key.typ, expected_index_type), length] - elif typ[1] in types: - expected_index_type = typ[1] - sub = base_type_conversion(key, key.typ, expected_index_type) - if annotation == 'storage': + elif isinstance(typ, (ListType, MappingType)): + if isinstance(typ, ListType): + subtype = typ.subtype + sub = ['uclamplt', base_type_conversion(key, key.typ, BaseType('num')), typ.count] + else: + subtype = typ.valuetype + sub = base_type_conversion(key, key.typ, typ.keytype) + if location == 'storage': return LLLnode.from_list(['add', ['sha3_32', parent], sub], typ=subtype, - annotation='storage') - elif annotation == 'memory': - if not isinstance(typ[1], int): + location='storage') + elif location == 'memory': + if isinstance(typ, MappingType): raise TypeMismatchException("Can only have fixed-side arrays in memory, not mappings") offset = 32 * get_size_of_type(subtype) return LLLnode.from_list(['add', ['mul', offset, sub], parent], typ=subtype, - annotation='memory') + location='memory') else: raise TypeMismatchException("Not expecting an array access") else: raise TypeMismatchException("Cannot access the child of a constant variable!") +# Is a type representing a number? +def is_numeric_type(typ): + return isinstance(typ, BaseType) and typ.typ in ('num', 'decimal') + +# Is a type representing some particular base type? +def is_base_type(typ, btypes): + if not isinstance(btypes, tuple): + btypes = (btypes, ) + return isinstance(typ, BaseType) and typ.typ in btypes + # Parse an expression def parse_expr(expr, context): if isinstance(expr, LLLnode): @@ -469,11 +587,11 @@ def parse_expr(expr, context): if isinstance(expr.n, int): if not (-2**127 + 1 <= expr.n <= 2**127 - 1): raise Exception("Number out of range: "+str(expr.n)) - return LLLnode.from_list(expr.n, typ='num') + return LLLnode.from_list(expr.n, typ=BaseType('num', None)) elif isinstance(expr.n, float): if not (-2**127 + 1 <= expr.n <= 2**127 - 1): raise Exception("Number out of range: "+str(expr.n)) - return LLLnode.from_list(int(expr.n * DECIMAL_DIVISOR), typ='decimal') + return LLLnode.from_list(int(expr.n * DECIMAL_DIVISOR), typ=BaseType('decimal', None)) # Addresses and bytes32 objects elif isinstance(expr, ast.Str): if len(expr.s) == 42 and expr.s[:2] == '0x': @@ -482,13 +600,14 @@ def parse_expr(expr, context): return LLLnode.from_list(hex_to_int(expr.s), typ='bytes32') else: raise Exception("Unsupported bytes: "+expr.s) + # True, False, None constants elif isinstance(expr, ast.NameConstant): if expr.value == True: return LLLnode.from_list(1, typ='bool') elif expr.value == False: return LLLnode.from_list(0, typ='bool') elif expr.value == None: - return LLLnode.from_list(None, typ='null') + return LLLnode.from_list(None, typ=NullType()) else: raise Exception("Unknown name constant: %r" % expr.value.value) # Variable names @@ -500,26 +619,26 @@ def parse_expr(expr, context): if expr.id == 'false': return LLLnode.from_list(0, typ='bool') if expr.id == 'null': - return LLLnode.from_list(None, typ='null') + return LLLnode.from_list(None, typ=NullType()) if expr.id in context.args: dataloc, typ = context.args[expr.id] if dataloc >= 0: data_decl = ['calldataload', dataloc] else: data_decl = ['seq', ['codecopy', 192, ['sub', ['codesize'], -dataloc], 32], ['mload', 192]] - if typ == 'num': - return LLLnode.from_list(['clamp', ['mload', MINNUM_POS], data_decl, ['mload', MAXNUM_POS]], typ='num') - elif typ == 'bool': - return LLLnode.from_list(['uclamplt', data_decl, 2], typ='bool') - elif typ == 'address': - return LLLnode.from_list(['uclamplt', data_decl, ['mload', ADDRSIZE_POS]], typ='address') - elif typ == 'num256' or typ == 'signed256' or typ == 'bytes32': + if is_base_type(typ, 'num'): + return LLLnode.from_list(['clamp', ['mload', MINNUM_POS], data_decl, ['mload', MAXNUM_POS]], typ=typ) + elif is_base_type(typ, 'bool'): + return LLLnode.from_list(['uclamplt', data_decl, 2], typ=typ) + elif is_base_type(typ, 'address'): + return LLLnode.from_list(['uclamplt', data_decl, ['mload', ADDRSIZE_POS]], typ=typ) + elif is_base_type(typ, ('num256', 'signed256', 'bytes32')): return LLLnode.from_list(data_decl, typ=typ) else: raise InvalidTypeException("Unsupported type: "+typ) elif expr.id in context.vars: dataloc, typ = context.vars[expr.id] - return LLLnode.from_list(dataloc, typ=typ, annotation='memory') + return LLLnode.from_list(dataloc, typ=typ, location='memory') else: raise VariableDeclarationException("Undeclared variable: "+expr.id) # x.y or x[5] @@ -527,28 +646,28 @@ def parse_expr(expr, context): # x.balance: balance of address x if expr.attr == 'balance': addr = parse_value_expr(expr.value, context) - if addr.typ != 'address': + if not is_base_type(addr.typ, 'address'): raise TypeMismatchException("Type mismatch: balance keyword expects an address as input") - return LLLnode.from_list(['balance', addr], typ='num', annotation=None) + return LLLnode.from_list(['balance', addr], typ=BaseType('num', {'wei': 1}), location=None) # self.x: global attribute elif isinstance(expr.value, ast.Name) and expr.value.id == "self": if expr.attr not in context.globals: raise VariableDeclarationException("Persistent variable undeclared: "+expr.attr) pos, typ = context.globals[expr.attr][0],context.globals[expr.attr][1] - return LLLnode.from_list(pos, typ=typ, annotation='storage') + return LLLnode.from_list(pos, typ=typ, location='storage') # Reserved keywords elif isinstance(expr.value, ast.Name) and expr.value.id in ("msg", "block", "tx"): key = expr.value.id + "." + expr.attr if key == "msg.sender": return LLLnode.from_list(['caller'], typ='address') elif key == "msg.value": - return LLLnode.from_list(['callvalue'], typ='num') + return LLLnode.from_list(['callvalue'], typ=BaseType('num', {'wei': 1})) elif key == "block.difficulty": return LLLnode.from_list(['difficulty'], typ='num') elif key == "block.timestamp": - return LLLnode.from_list(['timestamp'], typ='num') + return LLLnode.from_list(['timestamp'], typ=BaseType('num', {'sec': 1}, True)) elif key == "block.coinbase": - return LLLnode.from_list(['coinbase'], typ='num') + return LLLnode.from_list(['coinbase'], typ='address') elif key == "block.number": return LLLnode.from_list(['number'], typ='num') elif key == "tx.origin": @@ -558,9 +677,9 @@ def parse_expr(expr, context): # Other variables else: sub = parse_variable_location(expr.value, context) - if not isinstance(sub.typ, dict): + if not isinstance(sub.typ, StructType): raise TypeMismatchException("Type mismatch: member variable access not expected: %r" % sub) - attrs = sorted(sub.typ.keys()) + attrs = sorted(sub.typ.members.keys()) if expr.attr not in attrs: raise TypeMismatchException("Member %s not found. Only the following available: %s" % (expr.attr, " ".join(attrs))) return add_variable_offset(sub, expr.attr) @@ -574,56 +693,80 @@ def parse_expr(expr, context): elif isinstance(expr, ast.BinOp): left = parse_value_expr(expr.left, context) right = parse_value_expr(expr.right, context) - for typ in (left.typ, right.typ): - if typ not in ('num', 'decimal'): - raise TypeMismatchException("Unsupported type for arithmetic op: "+typ) + if not is_numeric_type(left.typ) or not is_numeric_type(right.typ): + raise TypeMismatchException("Unsupported type for arithmetic op: "+typ) + ltyp, rtyp = left.typ.typ, right.typ.typ if isinstance(expr.op, (ast.Add, ast.Sub)): + if left.typ.unit != right.typ.unit and left.typ.unit is not None and right.typ.unit is not None: + raise TypeMismatchException("Unit mismatch: %r %r" % (left.typ.unit, right.typ.unit)) + if left.typ.positional and right.typ.positional: + raise TypeMismatchException("Cannot add or subtract two positional units!") + new_unit = left.typ.unit or right.typ.unit + new_positional = left.typ.positional or right.typ.positional op = 'add' if isinstance(expr.op, ast.Add) else 'sub' - if left.typ == right.typ: - o = LLLnode.from_list([op, left, right], typ=left.typ) - elif left.typ == 'num' and right.typ == 'decimal': - o = LLLnode.from_list([op, ['mul', left, DECIMAL_DIVISOR], right], typ='decimal') - elif left.typ == 'decimal' and right.typ == 'num': - o = LLLnode.from_list([op, left, ['mul', right, DECIMAL_DIVISOR]], typ='decimal') + if ltyp == rtyp: + o = LLLnode.from_list([op, left, right], typ=BaseType(ltyp, new_unit, new_positional)) + elif ltyp == 'num' and rtyp == 'decimal': + o = LLLnode.from_list([op, ['mul', left, DECIMAL_DIVISOR], right], + typ=BaseType('decimal', new_unit, new_positional)) + elif ltyp == 'decimal' and rtyp == 'num': + o = LLLnode.from_list([op, left, ['mul', right, DECIMAL_DIVISOR]], + typ=BaseType('decimal', new_unit, new_positional)) else: - raise Exception("How did I get here? %r %r" % (left.typ, right.typ)) + raise Exception("How did I get here? %r %r" % (ltyp, rtyp)) elif isinstance(expr.op, ast.Mult): - if left.typ == right.typ == 'num': - o = LLLnode.from_list(['mul', left, right], typ='num') - elif left.typ == right.typ == 'decimal': + if left.typ.positional or right.typ.positional: + raise TypeMismatchException("Cannot multiply positional values!") + new_unit = combine_units(left.typ.unit, right.typ.unit) + if ltyp == rtyp == 'num': + o = LLLnode.from_list(['mul', left, right], typ=BaseType('num', new_unit)) + elif ltyp == rtyp == 'decimal': o = LLLnode.from_list(['with', 'r', right, ['with', 'l', left, ['with', 'ans', ['mul', 'l', 'r'], ['seq', ['assert', ['or', ['eq', ['sdiv', 'ans', 'l'], 'r'], ['not', 'l']]], - ['sdiv', 'ans', DECIMAL_DIVISOR]]]]], typ='decimal') - elif (left.typ == 'num' and right.typ == 'decimal') or (left.typ == 'decimal' and right.typ == 'num'): + ['sdiv', 'ans', DECIMAL_DIVISOR]]]]], typ=BaseType('decimal', new_unit)) + elif (ltyp == 'num' and rtyp == 'decimal') or (ltyp == 'decimal' and rtyp == 'num'): o = LLLnode.from_list(['with', 'r', right, ['with', 'l', left, ['with', 'ans', ['mul', 'l', 'r'], ['seq', ['assert', ['or', ['eq', ['sdiv', 'ans', 'l'], 'r'], ['not', 'l']]], - 'ans']]]], typ='decimal') + 'ans']]]], typ=BaseType('decimal', new_unit)) elif isinstance(expr.op, ast.Div): - if right.typ == 'num': - o = LLLnode.from_list(['sdiv', left, ['clamp_nonzero', right]], typ=left.typ) - elif left.typ == right.typ == 'decimal': + if left.typ.positional or right.typ.positional: + raise TypeMismatchException("Cannot divide positional values!") + new_unit = combine_units(left.typ.unit, right.typ.unit, div=True) + if rtyp == 'num': + o = LLLnode.from_list(['sdiv', left, ['clamp_nonzero', right]], typ=BaseType(ltyp, new_unit)) + elif ltyp == rtyp == 'decimal': o = LLLnode.from_list(['with', 'l', left, ['with', 'r', ['clamp_nonzero', right], ['sdiv', ['mul', 'l', DECIMAL_DIVISOR], 'r']]], - typ='decimal') - elif left.typ == 'num' and right.typ == 'decimal': - o = LLLnode.from_list(['sdiv', ['mul', left, DECIMAL_DIVISOR ** 2], ['clamp_nonzero', right]], typ='decimal') + typ=BaseType('decimal', new_unit)) + elif ltyp == 'num' and rtyp == 'decimal': + o = LLLnode.from_list(['sdiv', ['mul', left, DECIMAL_DIVISOR ** 2], ['clamp_nonzero', right]], + typ=BaseType('decimal', new_unit)) elif isinstance(expr.op, ast.Mod): - if left.typ == right.typ: - o = LLLnode.from_list(['smod', left, ['clamp_nonzero', right]], typ=left.typ) - elif left.typ == 'decimal' and right.typ == 'num': - o = LLLnode.from_list(['smod', left, ['mul', ['clamp_nonzero', right], DECIMAL_DIVISOR]], typ='decimal') - elif left.typ == 'num' and right.typ == 'decimal': - o = LLLnode.from_list(['smod', ['mul', left, DECIMAL_DIVISOR], right], typ='decimal') + if left.typ.positional or right.typ.positional: + raise TypeMismatchException("Cannot use positional values as modulus arguments!") + if left.typ.unit != right.typ.unit and left.typ.unit is not None and right.typ.unit is not None: + raise TypeMismatchException("Modulus arguments must have same unit") + new_unit = left.typ.unit or right.typ.unit + if ltyp == rtyp: + o = LLLnode.from_list(['smod', left, ['clamp_nonzero', right]], typ=BaseType(ltyp, new_unit)) + elif ltyp == 'decimal' and rtyp == 'num': + o = LLLnode.from_list(['smod', left, ['mul', ['clamp_nonzero', right], DECIMAL_DIVISOR]], + typ=BaseType('decimal', new_unit)) + elif ltyp == 'num' and rtyp == 'decimal': + o = LLLnode.from_list(['smod', ['mul', left, DECIMAL_DIVISOR], right], + typ=BaseType('decimal', new_unit)) else: raise Exception("Unsupported binop: %r" % expr.op) # Comparison operations elif isinstance(expr, ast.Compare): left = parse_value_expr(expr.left, context) right = parse_value_expr(expr.comparators[0], context) + if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ): + raise TypeMismatchException("Can't compare values with different units!") if len(expr.ops) != 1: raise StructureException("Cannot have a comparison with more than two elements") if isinstance(expr.ops[0], ast.Gt): @@ -640,25 +783,25 @@ def parse_expr(expr, context): op = 'ne' else: raise Exception("Unsupported comparison operator") - for typ in (left.typ, right.typ): - if typ not in ('num', 'decimal'): - if op not in ('eq', 'ne'): - raise TypeMismatchException("Invalid type for comparison op: "+typ) - if left.typ == right.typ: + if not is_numeric_type(left.typ) or not is_numeric_type(right.typ): + if op not in ('eq', 'ne'): + raise TypeMismatchException("Invalid type for comparison op: "+typ) + ltyp, rtyp = left.typ.typ, right.typ.typ + if ltyp == rtyp: o = LLLnode.from_list([op, left, right], typ='bool') - elif left.typ == 'decimal' and right.typ == 'num': + elif ltyp == 'decimal' and rtyp == 'num': o = LLLnode.from_list([op, left, ['mul', right, DECIMAL_DIVISOR]], typ='bool') - elif left.typ == 'num' and right.typ == 'decimal': + elif ltyp == 'num' and rtyp == 'decimal': o = LLLnode.from_list([op, ['mul', left, DECIMAL_DIVISOR], right], typ='bool') else: - raise TypeMismatchException("Unsupported types for comparison: %r %r" % (left.typ, right.typ)) + raise TypeMismatchException("Unsupported types for comparison: %r %r" % (ltyp, rtyp)) # Boolean logical operations elif isinstance(expr, ast.BoolOp): if len(expr.values) != 2: raise StructureException("Expected two arguments for a bool op") left = parse_value_expr(expr.values[0], context) right = parse_value_expr(expr.values[1], context) - if left.typ != 'bool' or right.typ != 'bool': + if not is_base_type(left.typ, 'bool') or not is_base_type(right.typ, 'bool'): raise TypeMismatchException("Boolean operations can only be between booleans!") if isinstance(expr.op, ast.And): op = 'and' @@ -675,7 +818,7 @@ def parse_expr(expr, context): # a zero entry represents false, all others represent true o = LLLnode.from_list(["iszero", operand], typ='bool') elif isinstance(expr.op, ast.USub): - if operand.typ not in ('num', 'decimal'): + if not is_numeric_type(operand.typ): raise TypeMismatchException("Unsupported type for negation: %r" % operand.typ) o = LLLnode.from_list(["sub", 0, operand], typ=operand.typ) else: @@ -686,22 +829,28 @@ def parse_expr(expr, context): if len(expr.args) != 1: raise StructureException("Floor expects 1 argument!") sub = parse_value_expr(expr.args[0], context) - if sub.typ in ('num', 'num256', 'signed256'): + if is_base_type(sub.typ, ('num', 'num256', 'signed256')): return sub - elif sub.typ == 'decimal': - return LLLnode.from_list(['sdiv', sub, DECIMAL_DIVISOR], typ='num') + elif is_base_type(sub.typ, 'decimal'): + return LLLnode.from_list(['sdiv', sub, DECIMAL_DIVISOR], typ=BaseType('num', sub.typ.unit, sub.typ.positional)) else: raise TypeMismatchException("Bad type for argument to floor: %r" % sub.typ) elif isinstance(expr.func, ast.Name) and expr.func.id == 'decimal': if len(expr.args) != 1: raise StructureException("Decimal expects 1 argument!") sub = parse_value_expr(expr.args[0], context) - if sub.typ == 'decimal': + if is_base_type(sub.typ, 'decimal'): return sub - elif sub.typ == 'num': - return LLLnode.from_list(['mul', sub, DECIMAL_DIVISOR], typ='decimal') + elif is_base_type(sub.typ, 'num'): + return LLLnode.from_list(['mul', sub, DECIMAL_DIVISOR], typ=BaseType('decimal', sub.typ.unit, sub.typ.positional)) else: raise TypeMismatchException("Bad type for argument to decimal: %r" % sub.typ) + elif isinstance(expr.func, ast.Name) and expr.func.id == "as_number": + sub = parse_value_expr(expr.args[0], context) + if is_base_type(sub.typ, ('num', 'decimal')): + return LLLnode(value=sub.value, args=sub.args, typ=BaseType(sub.typ.typ, {})) + else: + raise TypeMismatchException("as_number only accepts base types") else: raise Exception("Unsupported operator: %r" % ast.dump(expr)) elif isinstance(expr, ast.List): @@ -714,36 +863,36 @@ def parse_expr(expr, context): if not out_type: out_type = o[-1].typ elif len(o) > 1 and o[-1].typ != out_type: - out_type = 'mixed' - return LLLnode.from_list(["multi"] + o, typ=[out_type, len(o)]) + out_type = MixedType() + return LLLnode.from_list(["multi"] + o, typ=ListType(out_type, len(o))) elif isinstance(expr, ast.Dict): o = {} - typ = {} + members = {} for key, value in zip(expr.keys, expr.values): if not isinstance(key, ast.Name) or not is_varname_valid(key.id): raise TypeMismatchException("Invalid member variable for struct: %r" % vars(key).get('id', key)) if key.id in o: raise TypeMismatchException("Member variable duplicated: "+key.id) o[key.id] = parse_expr(value, context) - typ[key.id] = o[key.id].typ - return LLLnode.from_list(["multi"] + [o[key] for key in sorted(list(o.keys()))], typ=typ) + members[key.id] = o[key.id].typ + return LLLnode.from_list(["multi"] + [o[key] for key in sorted(list(o.keys()))], typ=StructType(members)) else: raise Exception("Unsupported operator: %r" % ast.dump(expr)) # Clamp based on variable type - if o.annotation is None and o.typ == 'bool': + if o.location is None and o.typ == 'bool': return o - elif o.annotation is None and o.typ == 'num': + elif o.location is None and o.typ == 'num': return LLLnode.from_list(['clamp', ['mload', MINNUM_POS], o, ['mload', MAXNUM_POS]], typ='num') - elif o.annotation is None and o.typ == 'decimal': + elif o.location is None and o.typ == 'decimal': return LLLnode.from_list(['clamp', ['mload', MINDECIMAL_POS], o, ['mload', MAXDECIMAL_POS]], typ='decimal') else: return o -# Unwrap annotation -def unwrap_annotation(orig): - if orig.annotation == 'memory': +# Unwrap location +def unwrap_location(orig): + if orig.location == 'memory': return LLLnode.from_list(['mload', orig], typ=orig.typ) - elif orig.annotation == 'storage': + elif orig.location == 'storage': return LLLnode.from_list(['sload', orig], typ=orig.typ) else: return orig @@ -751,108 +900,130 @@ def unwrap_annotation(orig): # Parse an expression that represents an address in memory or storage def parse_variable_location(expr, context): o = parse_expr(expr, context) - if not o.annotation: + if not o.location: raise Exception("Looking for a variable location, instead got a value") return o # Parse an expression that results in a value def parse_value_expr(expr, context): - return unwrap_annotation(parse_expr(expr, context)) + return unwrap_location(parse_expr(expr, context)) + +# Checks that the units of frm can be seamlessly converted into the units of to +def are_units_compatible(frm, to): + return frm.unit is None or (frm.unit == to.unit and frm.positional == to.positional) # Convert from one base type to another def base_type_conversion(orig, frm, to): - orig = unwrap_annotation(orig) - if frm == to and isinstance(frm, str): - return orig - elif frm == 'num' and to == 'decimal': - return LLLnode.from_list(['mul', orig, DECIMAL_DIVISOR], typ='decimal') - elif frm == 'null': - return LLLnode.from_list(0 if to in ('num', 'bool', 'num256', 'address', 'bytes32') else None, typ=to) + orig = unwrap_location(orig) + if not isinstance(frm, (BaseType, NullType)) or not isinstance(to, BaseType): + raise TypeMismatchException("Base type conversion from or to non-base type: %r %r" % (frm, to)) + elif is_base_type(frm, to.typ) and are_units_compatible(frm, to): + return LLLnode(orig.value, orig.args, typ=to) + elif is_base_type(frm, 'num') and is_base_type(to, 'decimal') and are_units_compatible(frm, to): + return LLLnode.from_list(['mul', orig, DECIMAL_DIVISOR], typ=BaseType('decimal', to.unit, to.positional)) + elif isinstance(frm, NullType): + if to.typ not in ('num', 'bool', 'num256', 'address', 'bytes32', 'decimal'): + raise TypeMismatchException("Cannot convert null-type object to type %r" % to) + return LLLnode.from_list(0, typ=to) else: raise TypeMismatchException("Typecasting from base type %r to %r unavailable" % (frm, to)) +def set_default_units(typ): + if isinstance(typ, BaseType): + if typ.unit is None: + return BaseType(typ.typ, {}) + else: + return typ + elif isinstance(typ, StructType): + return StructType({k: set_default_units(v) for k, v in typ.members.items()}) + elif isinstance(typ, ListType): + return ListType(set_default_units(typ.subtype), typ.count) + elif isinstance(typ, MappingType): + return MappingType(set_default_units(typ.keytype), set_default_units(typ.valuetype)) + else: + return typ + # Create an x=y statement, where the types may be compound -def make_setter(left, right, annotation): +def make_setter(left, right, location): # Basic types - if isinstance(left.typ, str): + if isinstance(left.typ, BaseType): right = base_type_conversion(right, right.typ, left.typ) - if annotation == 'storage': + if location == 'storage': return LLLnode.from_list(['sstore', left, right], typ=None) - elif annotation == 'memory': + elif location == 'memory': return LLLnode.from_list(['mstore', left, right], typ=None) + # Can't copy mappings + elif isinstance(left.typ, MappingType): + raise TypeMismatchException("Cannot copy mappings; can only copy individual elements") # Arrays - elif isinstance(left.typ, list): + elif isinstance(left.typ, ListType): # Cannot do something like [a, b, c] = [1, 2, 3] if left.value == "multi": raise Exception("Target of set statement must be a single item") - if not isinstance(right.typ, list) and right.typ != 'null': + if not isinstance(right.typ, (ListType, NullType)): raise TypeMismatchException("Setter type mismatch: left side is array, right side is %r" % right.typ) - _, elts = left.typ - left_token = LLLnode.from_list('_L', typ=left.typ, annotation=left.annotation) - if not isinstance(elts, int): - raise TypeMismatchException("Cannot copy mappings; can only copy individual elements") + left_token = LLLnode.from_list('_L', typ=left.typ, location=left.location) # Type checks - if right.typ != 'null': - if not isinstance(right.typ, list): + if not isinstance(right.typ, NullType): + if not isinstance(right.typ, ListType): raise TypeMismatchException("Left side is array, right side is not") - _, elts2 = right.typ - if elts != elts2: + if left.typ.count != right.typ.count: raise TypeMismatchException("Mismatched number of elements") # If the right side is a literal if right.value == "multi": - if len(right.args) != elts: + if len(right.args) != left.typ.count: raise TypeMismatchException("Mismatched number of elements") subs = [] - for i in range(elts): + for i in range(left.typ.count): subs.append(make_setter(add_variable_offset(left_token, LLLnode.from_list(i, typ='num')), - right.args[i], annotation)) + right.args[i], location)) return LLLnode.from_list(['with', '_L', left, ['seq'] + subs], typ=None) # If the right side is a null - elif right.typ == 'null': + elif isinstance(right.typ, NullType): subs = [] - for i in range(elts): + for i in range(left.typ.count): subs.append(make_setter(add_variable_offset(left_token, LLLnode.from_list(i, typ='num')), - LLLnode.from_list(None, typ='null'), annotation)) + LLLnode.from_list(None, typ=NullType()), location)) return LLLnode.from_list(['with', '_L', left, ['seq'] + subs], typ=None) # If the right side is a variable else: - right_token = LLLnode.from_list('_R', typ=right.typ, annotation=right.annotation) + right_token = LLLnode.from_list('_R', typ=right.typ, location=right.location) subs = [] - for i in range(elts): + for i in range(left.typ.count): subs.append(make_setter(add_variable_offset(left_token, LLLnode.from_list(i, typ='num')), - add_variable_offset(right_token, LLLnode.from_list(i, typ='num')), annotation)) + add_variable_offset(right_token, LLLnode.from_list(i, typ='num')), location)) return LLLnode.from_list(['with', '_L', left, ['with', '_R', right, ['seq'] + subs]], typ=None) # Structs - elif isinstance(left.typ, dict): + elif isinstance(left.typ, StructType): if left.value == "multi": raise Exception("Target of set statement must be a single item") - if right.typ != 'null': - if not isinstance(right.typ, dict): + if not isinstance(right.typ, NullType): + if not isinstance(right.typ, StructType): raise TypeMismatchException("Setter type mismatch: left side is %r, right side is %r" % (left.typ, right.typ)) - if sorted(list(left.typ.keys())) != sorted(list(right.typ.keys())): + if sorted(list(left.typ.members.keys())) != sorted(list(right.typ.members.keys())): raise TypeMismatchException("Keys don't match for structs") - left_token = LLLnode.from_list('_L', typ=left.typ, annotation=left.annotation) + left_token = LLLnode.from_list('_L', typ=left.typ, location=left.location) # If the right side is a literal if right.value == "multi": - if len(right.args) != len(list(left.typ.keys())): + if len(right.args) != len(list(left.typ.members.keys())): raise TypeMismatchException("Mismatched number of elements") subs = [] - for i, typ in enumerate(sorted(list(left.typ.keys()))): - subs.append(make_setter(add_variable_offset(left_token, typ), right.args[i], annotation)) + for i, typ in enumerate(sorted(list(left.typ.members.keys()))): + subs.append(make_setter(add_variable_offset(left_token, typ), right.args[i], location)) return LLLnode.from_list(['with', '_L', left, ['seq'] + subs], typ=None) # If the right side is a null - elif right.typ == "null": + elif isinstance(right.typ, NullType): subs = [] - for typ in sorted(list(left.typ.keys())): - subs.append(make_setter(add_variable_offset(left_token, typ), LLLnode.from_list(None, typ='null'), annotation)) + for typ in sorted(list(left.typ.members.keys())): + subs.append(make_setter(add_variable_offset(left_token, typ), LLLnode.from_list(None, typ=NullType()), location)) return LLLnode.from_list(['with', '_L', left, ['seq'] + subs], typ=None) # If the right side is a variable else: - right_token = LLLnode.from_list('_R', typ=right.typ, annotation=right.annotation) + right_token = LLLnode.from_list('_R', typ=right.typ, location=right.location) subs = [] - for typ in sorted(list(left.typ.keys())): - subs.append(make_setter(add_variable_offset(left_token, typ), add_variable_offset(right_token, typ), annotation)) + for typ in sorted(list(left.typ.members.keys())): + subs.append(make_setter(add_variable_offset(left_token, typ), add_variable_offset(right_token, typ), location)) return LLLnode.from_list(['with', '_L', left, ['with', '_R', right, ['seq'] + subs]], typ=None) # Parse a statement (usually one line of code but not always) @@ -862,7 +1033,7 @@ def parse_stmt(stmt, context): elif isinstance(stmt, ast.Pass): return LLLnode.from_list('pass', typ=None) elif isinstance(stmt, ast.AnnAssign): - typ = parse_type(stmt.annotation, annotation='memory') + typ = parse_type(stmt.annotation, location='memory') varname = stmt.target.id pos = context.new_variable(varname, typ) return LLLnode.from_list('pass', typ=None) @@ -872,13 +1043,13 @@ def parse_stmt(stmt, context): raise StructureException("Assignment statement must have one target") sub = parse_expr(stmt.value, context) if isinstance(stmt.targets[0], ast.Name) and stmt.targets[0].id not in context.vars: - pos = context.new_variable(stmt.targets[0].id, sub.typ) - return make_setter(LLLnode.from_list(pos, typ=sub.typ, annotation='memory'), sub, 'memory') + pos = context.new_variable(stmt.targets[0].id, set_default_units(sub.typ)) + return make_setter(LLLnode.from_list(pos, typ=sub.typ, location='memory'), sub, 'memory') else: target = parse_variable_location(stmt.targets[0], context) - if target.annotation == 'storage' and context.is_constant: + if target.location == 'storage' and context.is_constant: raise ConstancyViolationException("Cannot modify storage inside a constant function!") - return make_setter(target, sub, target.annotation) + return make_setter(target, sub, target.location) # If statements elif isinstance(stmt, ast.If): if stmt.orelse: @@ -900,10 +1071,10 @@ def parse_stmt(stmt, context): if len(stmt.args) != 2: raise Exception("Send expects 2 arguments!") to = parse_value_expr(stmt.args[0], context) - if to.typ != "address": + if not is_base_type(to.typ, "address"): raise TypeMismatchException("Expected an address as destination for send") value = parse_value_expr(stmt.args[1], context) - if value.typ != "num" and value.typ != "num256": + if not is_base_type(value.typ, ("num", "num256")): raise TypeMismatchException("Send value must be a number!") else: return LLLnode.from_list(['pop', ['call', 0, to, value, 0, 0, 0, 0]], typ=None) @@ -913,7 +1084,7 @@ def parse_stmt(stmt, context): if context.is_constant: raise ConstancyViolationException("Cannot %s inside a constant function!" % stmt.func.id) sub = parse_value_expr(stmt.args[0], context) - if sub.typ != "address": + if not is_base_type(sub.typ, "address"): raise TypeMismatchException("%s expects an address!" % stmt.func.id) return LLLnode.from_list(['selfdestruct', sub], typ=None) @@ -949,7 +1120,7 @@ def parse_stmt(stmt, context): start = parse_value_expr(stmt.iter.args[0], context) rounds = stmt.iter.args[1].right.n varname = stmt.target.id - pos = context.vars[varname][0] if varname in context.forvars else context.new_variable(varname, 'num') + pos = context.vars[varname][0] if varname in context.forvars else context.new_variable(varname, BaseType('num')) o = LLLnode.from_list(['repeat', pos, start, rounds, parse_body(stmt.body, context)], typ=None) context.forvars[varname] = True return o @@ -960,15 +1131,15 @@ def parse_stmt(stmt, context): sub = base_type_conversion(sub, sub.typ, target.typ) if not isinstance(stmt.op, (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Mod)): raise Exception("Unsupported operator for augassign") - if target.typ not in types: + if not isinstance(target.typ, BaseType): raise TypeMismatchException("Can only use aug-assign operators with simple types!") - if target.annotation == 'storage': + if target.location == 'storage': if context.is_constant: raise ConstancyViolationException("Cannot modify storage inside a constant function!") o = parse_value_expr(ast.BinOp(left=LLLnode.from_list(['sload', '_addr'], typ=target.typ), right=sub, op=stmt.op), context) return LLLnode.from_list(['with', '_addr', target, ['sstore', '_addr', base_type_conversion(o, o.typ, target.typ)]], typ=None) - elif target.annotation == 'memory': + elif target.location == 'memory': o = parse_value_expr(ast.BinOp(left=LLLnode.from_list(['mload', '_addr'], typ=target.typ), right=sub, op=stmt.op), context) return LLLnode.from_list(['with', '_addr', target, ['mstore', '_addr', base_type_conversion(o, o.typ, target.typ)]], typ=None) @@ -984,13 +1155,18 @@ def parse_stmt(stmt, context): if not stmt.value: raise TypeMismatchException("Expecting to return a value") sub = parse_value_expr(stmt.value, context) - if sub.typ == context.return_type or (sub.typ == 'num' and context.return_type == 'signed256'): + if not isinstance(sub.typ, BaseType): + raise TypeMismatchException("Can only return base type!") + elif not are_units_compatible(sub.typ, context.return_type): + raise TypeMismatchException("Return type units mismatch %r %r" % (sub.typ, context.return_type)) + elif is_base_type(sub.typ, context.return_type.typ) or \ + (is_base_type(sub.typ, 'num') and is_base_type(context.return_type, 'signed256')): return LLLnode.from_list(['seq', ['mstore', 0, sub], ['return', 0, 32]], typ=None) - elif sub.typ == 'num' and context.return_type == 'num256': + elif is_base_type(sub.typ, 'num') and is_base_type(context.return_type, 'num256'): return LLLnode.from_list(['seq', ['mstore', 0, sub], ['assert', ['iszero', ['lt', ['mload', 0], 0]]], ['return', 0, 32]], typ=None) else: - raise TypeMismatchException("Unsupported type conversion: %r %r" % (sub.typ, context.return_type)) + raise TypeMismatchException("Unsupported type conversion: %r to %r" % (sub.typ, context.return_type)) else: raise StructureException("Unsupported statement type") diff --git a/test_invalids.py b/test_invalids.py index a36974d6d0..e4c22e5883 100644 --- a/test_invalids.py +++ b/test_invalids.py @@ -716,3 +716,172 @@ def foo(): self.nom.a[135] = {c: 6} self.nom.b = 9 """) + +must_succeed(""" +def foo(x: timestamp) -> timestamp: + return x +""") + +must_succeed(""" +def foo(x: timestamp) -> num(const): + return 5 +""") + +must_succeed(""" +def foo(x: timestamp) -> timestamp(const): + return x +""") + +must_succeed(""" +def foo(x: timestamp) -> timestamp: + y = x + return y +""") + +must_fail(""" +def foo(x: timestamp) -> num: + return x +""", TypeMismatchException) + +must_fail(""" +def foo(x: timestamp) -> timedelta: + return x +""", TypeMismatchException) + +must_succeed(""" +def foo(x: timestamp, y: timestamp) -> bool: + return y > x +""") + +must_succeed(""" +def foo(x: timedelta, y: timedelta) -> bool: + return y == x +""") + +must_fail(""" +def foo(x: timestamp, y: timedelta) -> bool: + return y < x +""", TypeMismatchException) + +must_succeed(""" +def foo(x: timestamp) -> timestamp: + return x + 50 +""") + +must_succeed(""" +def foo() -> timestamp: + return 720 +""") + +must_succeed(""" +def foo() -> timedelta: + return 720 +""") + +must_succeed(""" +def foo(x: timestamp, y: timedelta) -> timestamp: + return x + y +""") + +must_fail(""" +def foo(x: timestamp, y: timedelta) -> timedelta: + return x + y +""", TypeMismatchException) + +must_fail(""" +def foo(x: timestamp, y: timestamp) -> timestamp: + return x + y +""", TypeMismatchException) + +must_succeed(""" +def foo(x: timedelta, y: timedelta) -> timedelta: + return x + y +""") + +must_succeed(""" +def foo(x: timedelta) -> timedelta: + return x * 2 +""") + +must_fail(""" +def foo(x: timestamp) -> timestamp: + return x * 2 +""", TypeMismatchException) + +must_fail(""" +def foo(x: timedelta, y: timedelta) -> timedelta: + return x * y +""", TypeMismatchException) + +must_succeed(""" +def foo(x: timedelta) -> bool: + return x > 50 +""") + +must_succeed(""" +def foo(x: timestamp) -> bool: + return x > 12894712 +""") + +must_succeed(""" +def foo() -> timestamp: + x: timestamp + x = 30 + return x +""") + +must_fail(""" +def foo() -> timestamp: + x = 30 + y: timestamp + return x + y +""", TypeMismatchException) + +must_succeed(""" +a: timestamp[timestamp] + +def add_record(): + self.a[block.timestamp] = block.timestamp + 20 +""") + +must_fail(""" +a: num[timestamp] + +def add_record(): + self.a[block.timestamp] = block.timestamp + 20 +""", TypeMismatchException) + +must_fail(""" +a: timestamp[num] + +def add_record(): + self.a[block.timestamp] = block.timestamp + 20 +""", TypeMismatchException) + +must_fail(""" +def add_record(): + a = {x: block.timestamp} + b = {y: 5} + a.x = b.y +""", TypeMismatchException) + +must_succeed(""" +def add_record(): + a = {x: block.timestamp} + a.x = 5 +""") + +must_succeed(""" +def foo() -> num: + return as_number(block.timestamp) +""") + +must_fail(""" +def foo() -> address: + return as_number(block.coinbase) +""", TypeMismatchException) + +must_fail(""" +def foo() -> address: + return as_number([1, 2, 3]) +""", TypeMismatchException) diff --git a/test_parser.py b/test_parser.py index 06c8e61945..7637c65ec6 100644 --- a/test_parser.py +++ b/test_parser.py @@ -430,15 +430,15 @@ def returnMoose() -> num: crowdfund = """ -funders: {sender: address, value: num}[num] +funders: {sender: address, value: wei_value}[num] nextFunderIndex: num beneficiary: address -deadline: num -goal: num +deadline: timestamp +goal: wei_value refundIndex: num -timelimit: num +timelimit: timedelta -def __init__(_beneficiary: address, _goal: num, _timelimit: num): +def __init__(_beneficiary: address, _goal: wei_value, _timelimit: timedelta): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit self.timelimit = _timelimit @@ -454,13 +454,13 @@ def participate(): def expired() -> bool(const): return block.timestamp >= self.deadline -def timestamp() -> num(const): +def timestamp() -> timestamp(const): return block.timestamp -def deadline() -> num(const): +def deadline() -> timestamp(const): return self.deadline -def timelimit() -> num(const): +def timelimit() -> timedelta(const): return self.timelimit def reached() -> bool(const): @@ -731,15 +731,15 @@ def foq() -> num: crowdfund2 = """ -funders: {sender: address, value: num}[num] +funders: {sender: address, value: wei_value}[num] nextFunderIndex: num beneficiary: address -deadline: num -goal: num +deadline: timestamp +goal: wei_value refundIndex: num -timelimit: num +timelimit: timedelta -def __init__(_beneficiary: address, _goal: num, _timelimit: num): +def __init__(_beneficiary: address, _goal: wei_value, _timelimit: timedelta): self.beneficiary = _beneficiary self.deadline = block.timestamp + _timelimit self.timelimit = _timelimit @@ -754,13 +754,13 @@ def participate(): def expired() -> bool(const): return block.timestamp >= self.deadline -def timestamp() -> num(const): +def timestamp() -> timestamp(const): return block.timestamp -def deadline() -> num(const): +def deadline() -> timestamp(const): return self.deadline -def timelimit() -> num(const): +def timelimit() -> timedelta(const): return self.timelimit def reached() -> bool(const):