diff --git a/src/sema/expression/arithmetic.rs b/src/sema/expression/arithmetic.rs index 8150adc73..16883688a 100644 --- a/src/sema/expression/arithmetic.rs +++ b/src/sema/expression/arithmetic.rs @@ -227,12 +227,19 @@ pub(super) fn shift_left( let (right_length, _) = type_bits_and_sign(&right.ty(), &r.loc(), false, ns, diagnostics)?; let left_type = left.ty().deref_any().clone(); + let right_type = right.ty().deref_any().clone(); Ok(Expression::ShiftLeft { loc: *loc, ty: left_type.clone(), left: Box::new(left.cast(loc, &left_type, true, ns, diagnostics)?), - right: Box::new(cast_shift_arg(loc, right, right_length, &left_type, ns)), + right: Box::new(cast_shift_arg( + loc, + right.cast(loc, &right_type, true, ns, diagnostics)?, + right_length, + &left_type, + ns, + )), }) } @@ -252,6 +259,7 @@ pub(super) fn shift_right( check_var_usage_expression(ns, &left, &right, symtable); let left_type = left.ty().deref_any().clone(); + let right_type = right.ty().deref_any().clone(); // left hand side may be bytes/int/uint // right hand size may be int/uint let _ = type_bits_and_sign(&left_type, &l.loc(), true, ns, diagnostics)?; @@ -261,7 +269,13 @@ pub(super) fn shift_right( loc: *loc, ty: left_type.clone(), left: Box::new(left.cast(loc, &left_type, true, ns, diagnostics)?), - right: Box::new(cast_shift_arg(loc, right, right_length, &left_type, ns)), + right: Box::new(cast_shift_arg( + loc, + right.cast(loc, &right_type, true, ns, diagnostics)?, + right_length, + &left_type, + ns, + )), sign: left_type.is_signed_int(ns), }) } diff --git a/tests/codegen_testcases/solidity/struct_member_shift.sol b/tests/codegen_testcases/solidity/struct_member_shift.sol new file mode 100644 index 000000000..5506b7abc --- /dev/null +++ b/tests/codegen_testcases/solidity/struct_member_shift.sol @@ -0,0 +1,29 @@ +// RUN: --target polkadot --emit cfg +contract c { + struct S { + uint256 a; + } + function test1(S memory s) public pure returns (uint256) { +// CHECK: ty:uint256 %b = (uint256 2 << (load (struct (arg #0) field 0))) + uint256 b = 2 << s.a; + return b; + } + + function test2(S memory s) public pure returns (uint256) { +// CHECK: ty:uint256 %b = ((load (struct (arg #0) field 0)) << uint256 2) + uint256 b = s.a << 2; + return b; + } + + function test3(S memory s) public pure returns (uint256) { +// CHECK: ty:uint256 %b = (uint256 2 >> (load (struct (arg #0) field 0))) + uint256 b = 2 >> s.a; + return b; + } + + function test4(S memory s) public pure returns (uint256) { +// CHECK: ty:uint256 %b = ((load (struct (arg #0) field 0)) >> uint256 2) + uint256 b = s.a >> 2; + return b; + } +}