Skip to content

Commit

Permalink
Hid Wire's members. All uses changed to use setters/getters.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Aug 4, 2023
1 parent 554fbb1 commit 9eff652
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 213 deletions.
16 changes: 9 additions & 7 deletions circuit/stream_garble.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func NewStreaming(key []byte, inputs []Wire, conn *p2p.Conn) (
r: r,
}

stream.ensureWires(inputs)
stream.ensureWires(maxWire(0, inputs))

// Assing all input wires.
for i := 0; i < len(inputs); i++ {
Expand All @@ -72,17 +72,20 @@ func NewStreaming(key []byte, inputs []Wire, conn *p2p.Conn) (
return stream, nil
}

func (stream *Streaming) ensureWires(wires []Wire) {
// Verify that wires is big enough.
var max Wire
func maxWire(max Wire, wires []Wire) Wire {
for _, w := range wires {
if w > max {
max = w
}
}
return max
}

func (stream *Streaming) ensureWires(max Wire) {
// Verify that wires is big enough.
if len(stream.wires) <= int(max) {
var i int
for i = 1024; i <= int(max); i <<= 1 {
for i = 65536; i <= int(max); i <<= 1 {
}
n := make([]ot.Wire, i)
copy(n, stream.wires)
Expand All @@ -91,8 +94,7 @@ func (stream *Streaming) ensureWires(wires []Wire) {
}

func (stream *Streaming) initCircuit(c *Circuit, in, out []Wire) {
stream.ensureWires(in)
stream.ensureWires(out)
stream.ensureWires(maxWire(maxWire(0, in), out))

if len(stream.tmp) < c.NumWires {
stream.tmp = make([]ot.Wire, c.NumWires)
Expand Down
4 changes: 2 additions & 2 deletions compiler/circuits/circuits_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//
// circuits_test.go
//
// Copyright (c) 2019 Markku Rossi
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
Expand Down Expand Up @@ -29,7 +29,7 @@ func makeWires(count int, output bool) []*Wire {
var result []*Wire
for i := 0; i < count; i++ {
w := NewWire()
w.Output = output
w.SetOutput(output)
result = append(result, w)
}
return result
Expand Down
190 changes: 44 additions & 146 deletions compiler/circuits/compiler.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright (c) 2019-2022 Markku Rossi
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
Expand All @@ -8,12 +8,10 @@ package circuits

import (
"fmt"
"math"
"time"

"github.com/markkurossi/mpc/circuit"
"github.com/markkurossi/mpc/compiler/utils"
"github.com/markkurossi/mpc/types"
)

// Builtin implements a buitin circuit that uses input wires a and b
Expand Down Expand Up @@ -72,7 +70,7 @@ func (c *Compiler) ZeroWire() *Wire {
c.zeroWire = NewWire()
c.AddGate(NewBinary(circuit.AND, c.InputWires[0], c.InvI0Wire(),
c.zeroWire))
c.zeroWire.Value = Zero
c.zeroWire.SetValue(Zero)
}
return c.zeroWire
}
Expand All @@ -83,7 +81,7 @@ func (c *Compiler) OneWire() *Wire {
c.oneWire = NewWire()
c.AddGate(NewBinary(circuit.OR, c.InputWires[0], c.InvI0Wire(),
c.oneWire))
c.oneWire.Value = One
c.oneWire.SetValue(One)
}
return c.oneWire
}
Expand Down Expand Up @@ -178,94 +176,94 @@ func (c *Compiler) ConstPropagate() {
for _, g := range c.Gates {
switch g.Op {
case circuit.XOR:
if (g.A.Value == Zero && g.B.Value == Zero) ||
(g.A.Value == One && g.B.Value == One) {
g.O.Value = Zero
if (g.A.Value() == Zero && g.B.Value() == Zero) ||
(g.A.Value() == One && g.B.Value() == One) {
g.O.SetValue(Zero)
stats[g.Op]++
} else if (g.A.Value == Zero && g.B.Value == One) ||
(g.A.Value == One && g.B.Value == Zero) {
g.O.Value = One
} else if (g.A.Value() == Zero && g.B.Value() == One) ||
(g.A.Value() == One && g.B.Value() == Zero) {
g.O.SetValue(One)
stats[g.Op]++
} else if g.A.Value == Zero {
} else if g.A.Value() == Zero {
// O = B
stats[g.Op]++
g.ShortCircuit(g.B)
} else if g.B.Value == Zero {
} else if g.B.Value() == Zero {
// O = A
stats[g.Op]++
g.ShortCircuit(g.A)
}

case circuit.XNOR:
if (g.A.Value == Zero && g.B.Value == Zero) ||
(g.A.Value == One && g.B.Value == One) {
g.O.Value = One
if (g.A.Value() == Zero && g.B.Value() == Zero) ||
(g.A.Value() == One && g.B.Value() == One) {
g.O.SetValue(One)
stats[g.Op]++
} else if (g.A.Value == Zero && g.B.Value == One) ||
(g.A.Value == One && g.B.Value == Zero) {
g.O.Value = Zero
} else if (g.A.Value() == Zero && g.B.Value() == One) ||
(g.A.Value() == One && g.B.Value() == Zero) {
g.O.SetValue(Zero)
stats[g.Op]++
}

case circuit.AND:
if g.A.Value == Zero || g.B.Value == Zero {
g.O.Value = Zero
if g.A.Value() == Zero || g.B.Value() == Zero {
g.O.SetValue(Zero)
stats[g.Op]++
} else if g.A.Value == One && g.B.Value == One {
g.O.Value = One
} else if g.A.Value() == One && g.B.Value() == One {
g.O.SetValue(One)
stats[g.Op]++
} else if g.A.Value == One {
} else if g.A.Value() == One {
// O = B
stats[g.Op]++
g.ShortCircuit(g.B)
} else if g.B.Value == One {
} else if g.B.Value() == One {
// O = A
stats[g.Op]++
g.ShortCircuit(g.A)
}

case circuit.OR:
if g.A.Value == One || g.B.Value == One {
g.O.Value = One
if g.A.Value() == One || g.B.Value() == One {
g.O.SetValue(One)
stats[g.Op]++
} else if g.A.Value == Zero && g.B.Value == Zero {
g.O.Value = Zero
} else if g.A.Value() == Zero && g.B.Value() == Zero {
g.O.SetValue(Zero)
stats[g.Op]++
} else if g.A.Value == Zero {
} else if g.A.Value() == Zero {
// O = B
stats[g.Op]++
g.ShortCircuit(g.B)
} else if g.B.Value == Zero {
} else if g.B.Value() == Zero {
// O = A
stats[g.Op]++
g.ShortCircuit(g.A)
}

case circuit.INV:
if g.A.Value == One {
g.O.Value = Zero
if g.A.Value() == One {
g.O.SetValue(Zero)
stats[g.Op]++
} else if g.A.Value == Zero {
g.O.Value = One
} else if g.A.Value() == Zero {
g.O.SetValue(One)
stats[g.Op]++
}
}

if g.A.Value == Zero {
if g.A.Value() == Zero {
g.A.RemoveOutput(g)
g.A = c.ZeroWire()
g.A.AddOutput(g)
} else if g.A.Value == One {
} else if g.A.Value() == One {
g.A.RemoveOutput(g)
g.A = c.OneWire()
g.A.AddOutput(g)
}
if g.B != nil {
if g.B.Value == Zero {
if g.B.Value() == Zero {
g.B.RemoveOutput(g)
g.B = c.ZeroWire()
g.B.AddOutput(g)
} else if g.B.Value == One {
} else if g.B.Value() == One {
g.B.RemoveOutput(g)
g.B = c.OneWire()
g.B.AddOutput(g)
Expand Down Expand Up @@ -293,20 +291,20 @@ func (c *Compiler) ShortCircuitXORZero() {
if g.Op != circuit.XOR {
continue
}
if g.A.Value == Zero && !g.B.IsInput() &&
len(g.B.Input.O.Outputs) == 1 {
if g.A.Value() == Zero && !g.B.IsInput() &&
g.B.Input().O.NumOutputs() == 1 {

g.B.Input.ResetOutput(g.O)
g.B.Input().ResetOutput(g.O)

// Disconnect gate's output wire.
g.O = NewWire()

stats[g.Op]++
}
if g.B.Value == Zero && !g.A.IsInput() &&
len(g.A.Input.O.Outputs) == 1 {
if g.B.Value() == Zero && !g.A.IsInput() &&
g.A.Input().O.NumOutputs() == 1 {

g.A.Input.ResetOutput(g.O)
g.A.Input().ResetOutput(g.O)

// Disconnect gate's output wire.
g.O = NewWire()
Expand Down Expand Up @@ -372,7 +370,7 @@ func (c *Compiler) Compile() *circuit.Circuit {
panic("Output already assigned")
}
} else {
w.ID = c.NextWireID()
w.SetID(c.NextWireID())
}
}

Expand All @@ -397,103 +395,3 @@ func (c *Compiler) Compile() *circuit.Circuit {

return result
}

const (
// UnassignedID identifies an unassigned wire ID.
UnassignedID uint32 = math.MaxUint32
)

// Wire implements a wire connecting binary gates.
type Wire struct {
Output bool
Value WireValue
ID uint32
NumOutputs uint32
Input *Gate
Outputs []*Gate
}

// WireValue defines wire values.
type WireValue uint8

// Possible wire values.
const (
Unknown WireValue = iota
Zero
One
)

func (v WireValue) String() string {
switch v {
case Zero:
return "0"
case One:
return "1"
default:
return "?"
}
}

// Assigned tests if the wire is assigned with an unique ID.
func (w *Wire) Assigned() bool {
return w.ID != UnassignedID
}

// NewWire creates an unassigned wire.
func NewWire() *Wire {
return &Wire{
ID: UnassignedID,
Outputs: make([]*Gate, 0, 1),
}
}

// MakeWires creates bits number of wires.
func MakeWires(bits types.Size) []*Wire {
result := make([]*Wire, bits)
for i := 0; i < int(bits); i++ {
result[i] = NewWire()
}
return result
}

func (w *Wire) String() string {
return fmt.Sprintf("Wire{%x, Input:%s, Value:%s, Outputs:%v, Output=%v}",
w.ID, w.Input, w.Value, w.Outputs, w.Output)
}

// IsInput tests if the wire is an input wire.
func (w *Wire) IsInput() bool {
return w.Input == nil
}

// Assign assings wire ID.
func (w *Wire) Assign(c *Compiler) {
if w.Output {
return
}
if !w.Assigned() {
w.ID = c.NextWireID()
}
for _, output := range w.Outputs {
output.Visit(c)
}
}

// SetInput sets the wire's input gate.
func (w *Wire) SetInput(gate *Gate) {
if w.Input != nil {
panic("Input gate already set")
}
w.Input = gate
}

// AddOutput adds gate to the wire's output gates.
func (w *Wire) AddOutput(gate *Gate) {
w.Outputs = append(w.Outputs, gate)
w.NumOutputs++
}

// RemoveOutput removes gate from the wire's output gates.
func (w *Wire) RemoveOutput(gate *Gate) {
w.NumOutputs--
}
Loading

0 comments on commit 9eff652

Please sign in to comment.