From f825a60aab3dc3e88a27d6cc1057e4ac8859d89b Mon Sep 17 00:00:00 2001 From: Erik van den Brink Date: Tue, 19 Apr 2022 15:31:35 +0200 Subject: [PATCH] Add GetTransactionSigners https://github.com/neo-project/neo/pull/2685 --- neo3/contracts/native/ledger.py | 8 +++ neo3/network/payloads/verification.py | 81 +++++++++++++++++++++++++-- neo3/network/protocol.py | 2 +- tests/network/test_payloads.py | 1 + tests/network/test_witnessrules.py | 10 ++++ 5 files changed, 95 insertions(+), 7 deletions(-) diff --git a/neo3/contracts/native/ledger.py b/neo3/contracts/native/ledger.py index 9441fffa..a7d5867a 100644 --- a/neo3/contracts/native/ledger.py +++ b/neo3/contracts/native/ledger.py @@ -58,6 +58,14 @@ def get_tx_for_contract(self, snapshot: storage.Snapshot, hash_: types.UInt256) return None return tx + @register("getTransactionSigners", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) + def get_tx_signers(self, snapshot: storage.Snapshot, hash_: types.UInt256) -> Optional[payloads.Signer]: + tx = snapshot.transactions.try_get(hash_) + if tx is None: + return None + else: + return tx.signers + @register("getTransactionVMState", contracts.CallFlags.READ_STATES, cpu_price=1 << 15) def get_tx_vmstate(self, snapshot: storage.Snapshot, hash_: types.UInt256) -> vm.VMState: tx = snapshot.transactions.try_get(hash_, read_only=True) diff --git a/neo3/network/payloads/verification.py b/neo3/network/payloads/verification.py index cbc0964d..df894d82 100644 --- a/neo3/network/payloads/verification.py +++ b/neo3/network/payloads/verification.py @@ -3,13 +3,13 @@ import abc import base64 from enum import IntFlag, IntEnum -from neo3.core import serialization, utils, types, cryptography, Size as s, IJson +from neo3.core import serialization, utils, types, cryptography, Size as s, IJson, IInteroperable from neo3.network import payloads -from neo3 import storage, contracts -from typing import List, Dict, Any, no_type_check, Iterator +from neo3 import storage, contracts, vm +from typing import List, Dict, Any, no_type_check, Iterator, cast -class Signer(serialization.ISerializable, IJson): +class Signer(serialization.ISerializable, IJson, IInteroperable): """ A class that specifies who can pass CheckWitness() verifications in a smart contract. """ @@ -137,6 +137,26 @@ def to_json(self) -> dict: return json + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + items: List[vm.StackItem] = [] + items.append(vm.ByteStringStackItem(self.to_array())) + items.append(vm.ByteStringStackItem(self.account.to_array())) + items.append(vm.IntegerStackItem(self.scope.value)) + + contracts_ = vm.ArrayStackItem(reference_counter, + list(map(lambda c: vm.ByteStringStackItem(c.to_array()), self.allowed_contracts)) + ) + items.append(contracts_) + + groups_ = vm.ArrayStackItem(reference_counter, + list(map(lambda g: vm.ByteStringStackItem(g.to_array()), self.allowed_groups))) + items.append(groups_) + + rules_ = vm.ArrayStackItem(reference_counter, + list(map(lambda r: r.to_stack_item(reference_counter), self.rules))) + items.append(rules_) + return vm.ArrayStackItem(reference_counter, items) + @classmethod def from_json(cls, json: dict): """ Create object from JSON """ @@ -306,7 +326,7 @@ def to_csharp_string(self) -> str: return self.name.title() -class WitnessCondition(serialization.ISerializable, IJson): +class WitnessCondition(serialization.ISerializable, IJson, IInteroperable): MAX_SUB_ITEMS = 16 MAX_NESTING_DEPTH = 2 @@ -364,6 +384,11 @@ def to_json(self) -> dict: def from_json(cls, json: dict): raise NotImplementedError() + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + arr = vm.ArrayStackItem(reference_counter) + arr.append(vm.IntegerStackItem(self.type.value)) + return arr + class ConditionAnd(WitnessCondition): _type = WitnessConditionType.AND @@ -397,6 +422,14 @@ def to_json(self) -> dict: json['expressions'] = list(map(lambda exp: exp.to_json(), self.expressions)) return json + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + base = super(ConditionAnd, self).to_stack_item(reference_counter) + expressions = list(map(lambda exp: exp.to_stack_item(reference_counter), self.expressions)) + array = vm.ArrayStackItem(reference_counter, expressions) + base = cast(vm.ArrayStackItem, base) + base.append(array) + return base + @classmethod def _serializable_init(cls): return cls([]) @@ -430,6 +463,12 @@ def to_json(self) -> dict: json['expression'] = self.value return json + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + base = super(ConditionBool, self).to_stack_item(reference_counter) + base = cast(vm.ArrayStackItem, base) + base.append(vm.BooleanStackItem(self.value)) + return base + @classmethod def _serializable_init(cls): return cls(False) @@ -465,6 +504,12 @@ def to_json(self) -> dict: json["expression"] = self.expression.to_json() return json + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + base = super(ConditionNot, self).to_stack_item(reference_counter) + base = cast(vm.ArrayStackItem, base) + base.append(self.expression.to_stack_item(reference_counter)) + return base + @classmethod def _serializable_init(cls): return cls(ConditionBool(False)) @@ -502,6 +547,14 @@ def to_json(self) -> dict: json['expressions'] = list(map(lambda exp: exp.to_json(), self.expressions)) return json + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + base = super(ConditionOr, self).to_stack_item(reference_counter) + expressions = list(map(lambda exp: exp.to_stack_item(reference_counter), self.expressions)) + array = vm.ArrayStackItem(reference_counter, expressions) + base = cast(vm.ArrayStackItem, base) + base.append(array) + return base + @classmethod def _serializable_init(cls): return cls([]) @@ -530,6 +583,12 @@ def to_json(self) -> dict: json["hash"] = f"0x{self.hash_}" return json + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + base = super(ConditionCalledByContract, self).to_stack_item(reference_counter) + base = cast(vm.ArrayStackItem, base) + base.append(vm.ByteStringStackItem(self.hash_.to_array())) + return base + @classmethod def _serializable_init(cls): return cls(types.UInt160.zero()) @@ -585,6 +644,12 @@ def to_json(self) -> dict: json["group"] = str(self.group) return json + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + base = super(ConditionCalledByGroup, self).to_stack_item(reference_counter) + base = cast(vm.ArrayStackItem, base) + base.append(vm.ByteStringStackItem(self.group.to_array())) + return base + @classmethod def _serializable_init(cls): return cls(cryptography.ECPoint.deserialize_from_bytes(b'\x00')) @@ -606,7 +671,7 @@ def match(self, engine: contracts.ApplicationEngine) -> bool: return engine.current_scripthash == self.hash_ -class WitnessRule(serialization.ISerializable, IJson): +class WitnessRule(serialization.ISerializable, IJson, IInteroperable): def __init__(self, action: WitnessRuleAction, condition: WitnessCondition): self.action = action self.condition = condition @@ -628,6 +693,10 @@ def to_json(self) -> dict: 'condition': self.condition.to_json() } + def to_stack_item(self, reference_counter: vm.ReferenceCounter) -> vm.StackItem: + return vm.ArrayStackItem(reference_counter, [vm.IntegerStackItem(self.action.value), + self.condition.to_stack_item(reference_counter)]) + @classmethod def from_json(cls, json: dict): raise NotImplementedError() diff --git a/neo3/network/protocol.py b/neo3/network/protocol.py index a0248546..fff29ea3 100644 --- a/neo3/network/protocol.py +++ b/neo3/network/protocol.py @@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs): def connection_made(self, transport: asyncio.transports.BaseTransport) -> None: super().connection_made(transport) - self._stream_writer = StreamWriter(transport, self, self._stream_reader_orig, self._loop) + self._stream_writer = StreamWriter(transport, self, self._stream_reader_orig, self._loop) # type: ignore if self.client: asyncio.create_task(self.client.connection_made(transport)) diff --git a/tests/network/test_payloads.py b/tests/network/test_payloads.py index 7847b900..e87c75c9 100644 --- a/tests/network/test_payloads.py +++ b/tests/network/test_payloads.py @@ -262,6 +262,7 @@ def setUpClass(cls) -> None: ECPoint p = ECPoint.Parse("026241e7e26b38bb7154b8ad49458b97fb1c4797443dc921c5ca5774f511a2bbfc", ECCurve.Secp256r1); co.AllowedGroups = new ECPoint[] { p }; + co.Rules = new WitnessRule[0]; Console.WriteLine($"{co.Size}"); Console.WriteLine($"{BitConverter.ToString(co.ToArray()).Replace("-","")}"); diff --git a/tests/network/test_witnessrules.py b/tests/network/test_witnessrules.py index ada6131a..c081248b 100644 --- a/tests/network/test_witnessrules.py +++ b/tests/network/test_witnessrules.py @@ -3,6 +3,7 @@ from neo3.network import payloads from neo3.core import types, cryptography + class WitnessRuleTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -222,6 +223,15 @@ def test_by_entry(self): self.assertEqual(c, deserialized_c) def test_called_by_group(self): + """ + var c = new CalledByGroupCondition() + { + Group = ECPoint.Parse("02158c4a4810fa2a6a12f7d33d835680429e1a68ae61161c5b3fbc98c7f1f17765", ECCurve.Secp256r1) + }; + Console.WriteLine(((ISerializable)c).Size); + Console.WriteLine(c.ToArray().ToHexString()); + Console.WriteLine(c.ToJson()); + """ expected_len = 34 expected_data = bytes.fromhex("2902158c4a4810fa2a6a12f7d33d835680429e1a68ae61161c5b3fbc98c7f1f17765") expected_json = {"type":"CalledByGroup","group":"02158c4a4810fa2a6a12f7d33d835680429e1a68ae61161c5b3fbc98c7f1f17765"}