Skip to content

Commit

Permalink
Merge pull request #54 from ZKStats/fix/wrong-return-value-when-more-…
Browse files Browse the repository at this point in the history
…than-one-outputs

fix: outputs order is wrong when more than 1 outputs
  • Loading branch information
mhchia authored Jun 21, 2024
2 parents dd9c57e + 13bcab6 commit f0a28ea
Showing 1 changed file with 22 additions and 32 deletions.
54 changes: 22 additions & 32 deletions zkstats/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[tor
if op_class_str not in self.op_dict:
self.precal_witness[op_class_str+"_0"] = [op.result.data.item()]
self.op_dict[op_class_str] = 1
else:
else:
self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item()]
self.op_dict[op_class_str]+=1
elif isinstance(op, Median):
Expand All @@ -177,7 +177,7 @@ def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[tor
if op_class_str not in self.op_dict:
self.precal_witness[op_class_str+"_0"] = [op.result.data.item(), op.data_mean.data.item()]
self.op_dict[op_class_str] = 1
else:
else:
self.precal_witness[op_class_str+"_"+str(self.op_dict[op_class_str])] = [op.result.data.item(), op.data_mean.data.item()]
self.op_dict[op_class_str]+=1
elif isinstance(op, Covariance):
Expand Down Expand Up @@ -229,7 +229,7 @@ def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[tor
current_op_index = self.current_op_index
# Sanity check that current op index is not out of bound
len_ops = len(self.ops)
if current_op_index >= len(self.ops):
if current_op_index >= len_ops:
raise Exception(f"current_op_index out of bound: {current_op_index=} >= {len_ops=}")

op = self.ops[current_op_index]
Expand All @@ -245,28 +245,12 @@ def is_precise() -> IsResultPrecise:
return op.ezkl(x)
self.bools.append(is_precise)

# If this is the last operation, aggregate all `is_precise` in `self.bools`, and return (is_precise_aggregated, result)
# else, return only result

if current_op_index == len_ops - 1:
# print('final op: ', op)
# Sanity check for length of self.ops and self.bools
len_bools = len(self.bools)
if len_ops != len_bools:
raise Exception(f"length mismatch: {len_ops=} != {len_bools=}")
is_precise_aggregated = torch.tensor(1.0)
for i in range(len_bools):
res = self.bools[i]()
is_precise_aggregated = torch.logical_and(is_precise_aggregated, res)
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return is_precise_aggregated, op.result+(x[0]-x[0])[0][0]

elif current_op_index > len_ops - 1:
if current_op_index > len_ops - 1:
# Sanity check that current op index does not exceed the length of ops
raise Exception(f"current_op_index out of bound: {current_op_index=} > {len_ops=}")
else:
return op.result+(x[0]-x[0])[0][0]
if self.isProver:
json.dump(self.precal_witness, open(self.precal_witness_path, 'w'))
return op.result+(x[0]-x[0])[0][0]


class IModel(nn.Module):
Expand Down Expand Up @@ -296,23 +280,29 @@ def computation_to_model(computation: TComputation, precal_witness_path:str, isP
State is a container for intermediate results of computation, which can be useful when debugging.
"""
state = State(error)
# if it's verifier


state.precal_witness_path= precal_witness_path
state.isProver = isProver

class Model(IModel):
def preprocess(self, x: list[torch.Tensor]) -> None:
"""
Calculate the witnesses of the computation and store them in the state.
"""
# In the preprocess step, the operations are calculated and the results are stored in the state.
# So we don't need to get the returned result
computation(state, x)
state.set_ready_for_exporting_onnx()

def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
# print('x sy: ')
result = computation(state, x)
if len(result) ==1:
return (x[0]-x[0])[0][0]+torch.tensor(1.0), result
else:
return result
# print('state:: ', state.aggregate_witness_path)
"""
Called by torch.onnx.export.
"""
result = computation(state, x)
is_computation_result_accurate = state.bools[0]()
for op_precise_check in state.bools[1:]:
is_op_result_accurate = op_precise_check()
is_computation_result_accurate = torch.logical_and(is_computation_result_accurate, is_op_result_accurate)
return is_computation_result_accurate, result
return state, Model

0 comments on commit f0a28ea

Please sign in to comment.