Skip to content

Commit

Permalink
Create transports with respect to initiator
Browse files Browse the repository at this point in the history
  • Loading branch information
i-hate-nicknames committed Jun 17, 2021
1 parent 8d71486 commit e23cce9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
12 changes: 8 additions & 4 deletions pkg/transport/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ const (
responseInvalidEntry
)

func makeEntryFromTpConn(conn *snet.Conn) Entry {
return MakeEntry(conn.LocalPK(), conn.RemotePK(), conn.Network(), true, LabelUser)
func makeEntryFromTpConn(conn *snet.Conn, isInitiator bool) Entry {
initiator, target := conn.LocalPK(), conn.RemotePK()
if !isInitiator {
initiator, target = target, initiator
}
return MakeEntry(initiator, target, conn.Network(), true, LabelUser)
}

func compareEntries(expected, received *Entry) error {
Expand Down Expand Up @@ -95,7 +99,7 @@ func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn *snet.Co
func MakeSettlementHS(init bool) SettlementHS {
// initiating logic.
initHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) (err error) {
entry := makeEntryFromTpConn(conn)
entry := makeEntryFromTpConn(conn, true)

// TODO(evanlinjin): Probably not needed as this is called in mTp already. Need to double check.
//defer func() {
Expand Down Expand Up @@ -135,7 +139,7 @@ func MakeSettlementHS(init bool) SettlementHS {

// responding logic.
respHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) error {
entry := makeEntryFromTpConn(conn)
entry := makeEntryFromTpConn(conn, false)

// receive, verify and sign entry.
recvSE, err := receiveAndVerifyEntry(conn, &entry, conn.RemotePK())
Expand Down
8 changes: 6 additions & 2 deletions pkg/transport/managed_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,19 @@ type ManagedTransport struct {
}

// NewManagedTransport creates a new ManagedTransport.
func NewManagedTransport(conf ManagedTransportConfig) *ManagedTransport {
func NewManagedTransport(conf ManagedTransportConfig, isInitiator bool) *ManagedTransport {
initiator, target := conf.Net.LocalPK(), conf.RemotePK
if !isInitiator {
initiator, target = target, initiator
}
mt := &ManagedTransport{
log: logging.MustGetLogger(fmt.Sprintf("tp:%s", conf.RemotePK.String()[:6])),
rPK: conf.RemotePK,
netName: conf.NetName,
n: conf.Net,
dc: conf.DC,
ls: conf.LS,
Entry: MakeEntry(conf.Net.LocalPK(), conf.RemotePK, conf.NetName, true, conf.TransportLabel),
Entry: MakeEntry(initiator, target, conf.NetName, true, conf.TransportLabel),
LogEntry: new(LogEntry),
connCh: make(chan struct{}, 1),
done: make(chan struct{}),
Expand Down
11 changes: 6 additions & 5 deletions pkg/transport/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ func (tm *Manager) initTransports(ctx context.Context) {
remote = entry.Entry.RemoteEdge(tm.Conf.PubKey)
tpID = entry.Entry.ID
)
if _, err := tm.saveTransport(remote, tpType, entry.Entry.Label); err != nil {
isInitiator := tm.n.LocalPK() == entry.Entry.Edges[0]
if _, err := tm.saveTransport(remote, isInitiator, tpType, entry.Entry.Label); err != nil {
tm.Logger.Warnf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", tpType, remote, tpID)
} else {
tm.Logger.Debugf("Successfully initialized TP %v", *entry.Entry)
Expand Down Expand Up @@ -230,7 +231,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) erro
NetName: lis.Network(),
AfterClosed: tm.afterTPClosed,
TransportLabel: LabelUser,
})
}, false)

go func() {
mTp.Serve(tm.readCh)
Expand Down Expand Up @@ -303,7 +304,7 @@ func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpTy
}

for {
mTp, err := tm.saveTransport(remote, tpType, label)
mTp, err := tm.saveTransport(remote, true, tpType, label)
if err != nil {
return nil, fmt.Errorf("save transport: %w", err)
}
Expand Down Expand Up @@ -346,7 +347,7 @@ func isSTCPTableError(remotePK cipher.PubKey, err error) bool {
return err.Error() == fmt.Sprintf("pk table: entry of %s does not exist", remotePK.String())
}

func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Label) (*ManagedTransport, error) {
func (tm *Manager) saveTransport(remote cipher.PubKey, initiator bool, netName string, label Label) (*ManagedTransport, error) {
tm.mx.Lock()
defer tm.mx.Unlock()
if !snet.IsKnownNetwork(netName) {
Expand All @@ -372,7 +373,7 @@ func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Lab
NetName: netName,
AfterClosed: afterTPClosed,
TransportLabel: label,
})
}, initiator)

if mTp.netName == tptypes.STCPR {
ar := mTp.n.Conf().ARClient
Expand Down

0 comments on commit e23cce9

Please sign in to comment.