Skip to content

Commit

Permalink
Merge pull request #114 from near/viktar/shifts
Browse files Browse the repository at this point in the history
zkasm: shl
  • Loading branch information
MCJOHN974 authored Nov 27, 2023
2 parents 8d469f2 + f576640 commit cc91dd0
Show file tree
Hide file tree
Showing 28 changed files with 811 additions and 163 deletions.
24 changes: 24 additions & 0 deletions cranelift/codegen/src/isa/zkasm/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@
(rs1 Reg)
(rs2 Reg))

(Shl64
(rd WritableReg)
(rs1 Reg)
(rs2 Reg))

(Shl32
(rd WritableReg)
(rs1 Reg)
(rs2 Reg))

(MulArith32
(rd WritableReg)
(rs1 Reg)
Expand Down Expand Up @@ -663,6 +673,19 @@
;; RV32M Extension
;; TODO: Enable these instructions only when we have the M extension

(decl zk_shl (XReg XReg) XReg)
(rule (zk_shl rs1 rs2)
(let ((dst WritableXReg (temp_writable_xreg))
(_ Unit (emit (MInst.Shl64 dst rs1 rs2))))
dst))

(decl zk_shl_32 (XReg XReg) XReg)
(rule (zk_shl_32 rs1 rs2)
(let ((dst WritableXReg (temp_writable_xreg))
(_ Unit (emit (MInst.Shl32 dst rs1 rs2))))
dst))


;; Helper for emitting the `mul` instruction.
;; rd ← rs1 × rs2
(decl zk_mul (XReg XReg) XReg)
Expand All @@ -671,6 +694,7 @@
(_ Unit (emit (MInst.MulArith dst rs1 rs2))))
dst))


(decl zk_mul_32 (XReg XReg) XReg)
(rule (zk_mul_32 rs1 rs2)
(let ((dst WritableXReg (temp_writable_xreg))
Expand Down
63 changes: 63 additions & 0 deletions cranelift/codegen/src/isa/zkasm/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,69 @@ impl MachInstEmit for Inst {
sink,
);
}
&Inst::Shl64 { rd, rs1, rs2 } => {
let rs1 = allocs.next(rs1);
let rs2 = allocs.next(rs2);
debug_assert_eq!(rs1, a0());
debug_assert_eq!(rs2, e0());

put_string("A :MSTORE(SP)\n", sink);

put_string("64 => B\n", sink);
put_string("0 => D\n", sink);
put_string("${E / B} => A\n", sink);
put_string("${E % B} => C\n", sink);
put_string("E:ARITH\n", sink);

put_string("$ => A :MLOAD(SP)\n", sink);
put_string("C => E", sink);
// E -- shift amount.
// A -- number.
sink.put_data(format!(";;NEED_INCLUDE: 2-exp\n").as_bytes());
put_string("zkPC + 2 => RR\n", sink);
put_string(" :JMP(@two_power + E)\n", sink);
put_string("0 => D\n", sink);
put_string("0 => C\n", sink);
put_string("$${var _mulShlArith = A * B}\n", sink);
put_string("${_mulShlArith / 18446744073709551616} => D\n", sink);
put_string("${_mulShlArith % 18446744073709551616} => E :ARITH\n", sink);
}
&Inst::Shl32 { rd, rs1, rs2 } => {
let rs1 = allocs.next(rs1);
let rs2 = allocs.next(rs2);
debug_assert_eq!(rs1, a0());
debug_assert_eq!(rs2, e0());

// E /= 2**32
put_string("A :MSTORE(SP)\n", sink);
put_string("0 => D\n", sink);
put_string("4294967296n => B\n", sink);
put_string("${E / B} => A\n", sink);
put_string("${E % B} => C\n", sink);
put_string("E:ARITH\n", sink);
put_string("A => E\n", sink);

// E %= 32
put_string("32 => B\n", sink);
put_string("0 => D\n", sink);
put_string("${E / B} => A\n", sink);
put_string("${E % B} => C\n", sink);
put_string("E:ARITH\n", sink);
put_string("$ => A :MLOAD(SP)\n", sink);
put_string("C => E\n", sink);

// E -- shift amount.
// A -- number.
// E:= (A << E)
sink.put_data(format!(";;NEED_INCLUDE: 2-exp\n").as_bytes());
put_string("zkPC + 2 => RR\n", sink);
put_string(" :JMP(@two_power + E)\n", sink);
put_string("0 => D\n", sink);
put_string("0 => C\n", sink);
put_string("$${var _mulShlArith = A * B}\n", sink);
put_string("${_mulShlArith / 18446744073709551616} => D\n", sink);
put_string("${_mulShlArith % 18446744073709551616} => E :ARITH\n", sink);
}
&Inst::DivArith32 { rd, rs1, rs2 } => {
let rs1 = allocs.next(rs1);
let rs2 = allocs.next(rs2);
Expand Down
32 changes: 32 additions & 0 deletions cranelift/codegen/src/isa/zkasm/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,26 @@ fn zkasm_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut OperandC
collector.reg_fixed_use(rs2, b0());
collector.reg_def(rd);
}
&Inst::Shl64 { rd, rs1, rs2, .. } => {
collector.reg_fixed_use(rs1, a0());
collector.reg_fixed_use(rs2, e0());
let mut clobbered = PRegSet::empty();
clobbered.add(d0().to_real_reg().unwrap().into());
clobbered.add(c0().to_real_reg().unwrap().into());
clobbered.add(b0().to_real_reg().unwrap().into());
collector.reg_clobbers(clobbered);
collector.reg_fixed_def(rd, e0());
}
&Inst::Shl32 { rd, rs1, rs2, .. } => {
collector.reg_fixed_use(rs1, a0());
collector.reg_fixed_use(rs2, e0());
let mut clobbered = PRegSet::empty();
clobbered.add(d0().to_real_reg().unwrap().into());
clobbered.add(c0().to_real_reg().unwrap().into());
clobbered.add(b0().to_real_reg().unwrap().into());
collector.reg_clobbers(clobbered);
collector.reg_fixed_def(rd, e0());
}
&Inst::MulArith { rd, rs1, rs2, .. } => {
collector.reg_fixed_use(rs1, a0());
collector.reg_fixed_use(rs2, b0());
Expand Down Expand Up @@ -987,6 +1007,18 @@ impl Inst {
}
}
}
&Inst::Shl32 { rd, rs1, rs2 } => {
let rs1_s = format_reg(rs1, allocs);
let rs2_s = format_reg(rs2, allocs);
let rd_s = format_reg(rd.to_reg(), allocs);
format!("Shl32 rd = {}, rs1 = {}, rs2 = {}", rd_s, rs1_s, rs2_s)
}
&Inst::Shl64 { rd, rs1, rs2 } => {
let rs1_s = format_reg(rs1, allocs);
let rs2_s = format_reg(rs2, allocs);
let rd_s = format_reg(rd.to_reg(), allocs);
format!("Shl64 rd = {}, rs1 = {}, rs2 = {}", rd_s, rs1_s, rs2_s)
}
&Inst::MulArith { rd, rs1, rs2 } => {
let rs1_s = format_reg(rs1, allocs);
let rs2_s = format_reg(rs2, allocs);
Expand Down
18 changes: 4 additions & 14 deletions cranelift/codegen/src/isa/zkasm/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,6 @@
(rule 1 (lower (has_type $I64 (sextend (has_type $I32 (isub x y)))))
(rv_subw x y))

(rule 1 (lower (has_type $I64 (sextend (has_type $I32 (ishl x y)))))
(rv_sllw x (value_regs_get y 0)))

(rule 1 (lower (has_type $I64 (sextend (has_type $I32 (ushr x y)))))
(rv_srlw x (value_regs_get y 0)))

Expand All @@ -189,18 +186,11 @@

;;;; Rules for `ishl` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; 8/16 bit types need a mask on the shift amount
; (rule 0 (lower (has_type (ty_int (ty_8_or_16 ty)) (ishl x y)))
; (if-let mask (u64_to_imm12 (shift_mask ty)))
; (rv_sllw x (rv_andi (value_regs_get y 0) mask)))

;; Using the 32bit version of `sll` automatically masks the shift amount.
(rule 1 (lower (has_type $I32 (ishl x y)))
(rv_sllw x (value_regs_get y 0)))
(rule (lower (has_type $I64 (ishl x y)))
(zk_shl x y))

;; Similarly, the 64bit version does the right thing.
(rule 1 (lower (has_type $I64 (ishl x y)))
(rv_sll x (value_regs_get y 0)))
(rule (lower (has_type $I32 (ishl x y)))
(zk_shl_32 x y))

;;;; Rules for `ushr` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand Down
38 changes: 33 additions & 5 deletions cranelift/zkasm_data/spectest/i32/generated/shl_1.zkasm
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,43 @@ start:
function_1:
SP + 1 => SP
RR :MSTORE(SP - 1)
SP + 2 => SP
B :MSTORE(SP - 1)
SP + 4 => SP
C :MSTORE(SP - 1)
D :MSTORE(SP - 2)
E :MSTORE(SP - 3)
B :MSTORE(SP - 4)
4294967296n => A
4294967296n => E
A :MSTORE(SP)
0 => D
4294967296n => B
$ => A :sllw
${E / B} => A
${E % B} => C
E:ARITH
A => E
32 => B
0 => D
${E / B} => A
${E % B} => C
E:ARITH
$ => A :MLOAD(SP)
C => E
;;NEED_INCLUDE: 2-exp
zkPC + 2 => RR
:JMP(@two_power + E)
0 => D
0 => C
$${var _mulShlArith = A * B}
${_mulShlArith / 18446744073709551616} => D
${_mulShlArith % 18446744073709551616} => E :ARITH
E => A
8589934592n => B
B :ASSERT
$ => B :MLOAD(SP - 1)
SP - 2 => SP
$ => C :MLOAD(SP - 1)
$ => D :MLOAD(SP - 2)
$ => E :MLOAD(SP - 3)
$ => B :MLOAD(SP - 4)
SP - 4 => SP
$ => RR :MLOAD(SP - 1)
SP - 1 => SP
:JMP(RR)
Expand Down
40 changes: 34 additions & 6 deletions cranelift/zkasm_data/spectest/i32/generated/shl_10.zkasm
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,43 @@ start:
function_1:
SP + 1 => SP
RR :MSTORE(SP - 1)
SP + 2 => SP
B :MSTORE(SP - 1)
SP + 4 => SP
C :MSTORE(SP - 1)
D :MSTORE(SP - 2)
E :MSTORE(SP - 3)
B :MSTORE(SP - 4)
4294967296n => A
18446744069414584320n => B
$ => A :sllw
18446744069414584320n => E
A :MSTORE(SP)
0 => D
4294967296n => B
${E / B} => A
${E % B} => C
E:ARITH
A => E
32 => B
0 => D
${E / B} => A
${E % B} => C
E:ARITH
$ => A :MLOAD(SP)
C => E
;;NEED_INCLUDE: 2-exp
zkPC + 2 => RR
:JMP(@two_power + E)
0 => D
0 => C
$${var _mulShlArith = A * B}
${_mulShlArith / 18446744073709551616} => D
${_mulShlArith % 18446744073709551616} => E :ARITH
E => A
9223372036854775808n => B
B :ASSERT
$ => B :MLOAD(SP - 1)
SP - 2 => SP
$ => C :MLOAD(SP - 1)
$ => D :MLOAD(SP - 2)
$ => E :MLOAD(SP - 3)
$ => B :MLOAD(SP - 4)
SP - 4 => SP
$ => RR :MLOAD(SP - 1)
SP - 1 => SP
:JMP(RR)
Expand Down
40 changes: 34 additions & 6 deletions cranelift/zkasm_data/spectest/i32/generated/shl_11.zkasm
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,43 @@ start:
function_1:
SP + 1 => SP
RR :MSTORE(SP - 1)
SP + 2 => SP
B :MSTORE(SP - 1)
SP + 4 => SP
C :MSTORE(SP - 1)
D :MSTORE(SP - 2)
E :MSTORE(SP - 3)
B :MSTORE(SP - 4)
4294967296n => A
9223372032559808512n => B
$ => A :sllw
9223372032559808512n => E
A :MSTORE(SP)
0 => D
4294967296n => B
${E / B} => A
${E % B} => C
E:ARITH
A => E
32 => B
0 => D
${E / B} => A
${E % B} => C
E:ARITH
$ => A :MLOAD(SP)
C => E
;;NEED_INCLUDE: 2-exp
zkPC + 2 => RR
:JMP(@two_power + E)
0 => D
0 => C
$${var _mulShlArith = A * B}
${_mulShlArith / 18446744073709551616} => D
${_mulShlArith % 18446744073709551616} => E :ARITH
E => A
9223372036854775808n => B
B :ASSERT
$ => B :MLOAD(SP - 1)
SP - 2 => SP
$ => C :MLOAD(SP - 1)
$ => D :MLOAD(SP - 2)
$ => E :MLOAD(SP - 3)
$ => B :MLOAD(SP - 4)
SP - 4 => SP
$ => RR :MLOAD(SP - 1)
SP - 1 => SP
:JMP(RR)
Expand Down
40 changes: 34 additions & 6 deletions cranelift/zkasm_data/spectest/i32/generated/shl_2.zkasm
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,43 @@ start:
function_1:
SP + 1 => SP
RR :MSTORE(SP - 1)
SP + 2 => SP
B :MSTORE(SP - 1)
SP + 4 => SP
C :MSTORE(SP - 1)
D :MSTORE(SP - 2)
E :MSTORE(SP - 3)
B :MSTORE(SP - 4)
4294967296n => A
0n => B
$ => A :sllw
0n => E
A :MSTORE(SP)
0 => D
4294967296n => B
${E / B} => A
${E % B} => C
E:ARITH
A => E
32 => B
0 => D
${E / B} => A
${E % B} => C
E:ARITH
$ => A :MLOAD(SP)
C => E
;;NEED_INCLUDE: 2-exp
zkPC + 2 => RR
:JMP(@two_power + E)
0 => D
0 => C
$${var _mulShlArith = A * B}
${_mulShlArith / 18446744073709551616} => D
${_mulShlArith % 18446744073709551616} => E :ARITH
E => A
4294967296n => B
B :ASSERT
$ => B :MLOAD(SP - 1)
SP - 2 => SP
$ => C :MLOAD(SP - 1)
$ => D :MLOAD(SP - 2)
$ => E :MLOAD(SP - 3)
$ => B :MLOAD(SP - 4)
SP - 4 => SP
$ => RR :MLOAD(SP - 1)
SP - 1 => SP
:JMP(RR)
Expand Down
Loading

0 comments on commit cc91dd0

Please sign in to comment.