diff --git a/.travis.yml b/.travis.yml index 4b3e1a2433..047629ba81 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,7 +18,7 @@ before_install: install: - go get -u github.com/FiloSottile/vendorcheck - - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $GOPATH/bin v1.22.2 + - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $GOPATH/bin v1.27.0 before_script: - ci_scripts/create-ip-aliases.sh diff --git a/cmd/apps/helloworld/helloworld.go b/cmd/apps/helloworld/helloworld.go index a65957e9e5..ab6bce1463 100644 --- a/cmd/apps/helloworld/helloworld.go +++ b/cmd/apps/helloworld/helloworld.go @@ -4,84 +4,165 @@ simple client server app for skywire visor testing package main import ( - "os" + "flag" + "fmt" + "net" + "time" "github.com/SkycoinProject/dmsg/cipher" "github.com/sirupsen/logrus" "github.com/SkycoinProject/skywire-mainnet/pkg/app" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appevent" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" "github.com/SkycoinProject/skywire-mainnet/pkg/routing" "github.com/SkycoinProject/skywire-mainnet/pkg/util/buildinfo" ) const ( - netType = appnet.TypeSkynet + modeServer = "server" + modeClient = "client" +) + +var ( + mode = flag.String("mode", modeServer, fmt.Sprintf("mode of operation: %v", []string{modeServer, modeClient})) + network = flag.String("net", string(appnet.TypeSkynet), fmt.Sprintf("network: %v", []appnet.Type{appnet.TypeSkynet, appnet.TypeDmsg})) + remote = flag.String("remote", "", "remote public key to dial to (client mode only)") + port = flag.Uint("port", 1024, "port to either dial to (client mode), or listen from (server mode)") ) var log = logrus.New() func main() { - appC := app.NewClient() + flag.Parse() + + subs := prepareSubscriptions() + appC := app.NewClient(subs) defer appC.Close() if _, err := buildinfo.Get().WriteTo(log.Writer()); err != nil { - log.Printf("Failed to output build info: %v", err) + log.WithError(err).Info("Failed to output build info.") } - if len(os.Args) == 1 { - port := routing.Port(1024) - l, err := appC.Listen(netType, port) - if err != nil { - log.Fatalf("Error listening network %v on port %d: %v\n", netType, port, err) - } - - log.Println("listening for incoming connections") - for { - conn, err := l.Accept() - if err != nil { - log.Fatalf("Failed to accept conn: %v\n", err) - } - - log.Printf("got new connection from: %v\n", conn.RemoteAddr()) - go func() { - buf := make([]byte, 4) - if _, err := conn.Read(buf); err != nil { - log.Printf("Failed to read remote data: %v\n", err) - // TODO: close conn - } - - log.Printf("Message from %s: %s\n", conn.RemoteAddr().String(), string(buf)) - if _, err := conn.Write([]byte("pong")); err != nil { - log.Printf("Failed to write to a remote visor: %v\n", err) - // TODO: close conn - } - }() - } + switch *mode { + case modeServer: + runServer(appC) + case modeClient: + runClient(appC) + default: + log.WithField("mode", *mode).Fatal("Invalid mode.") } +} - remotePK := cipher.PubKey{} - if err := remotePK.UnmarshalText([]byte(os.Args[1])); err != nil { - log.Fatal("Failed to construct PubKey: ", err, os.Args[1]) - } +func prepareSubscriptions() *appevent.Subscriber { + subs := appevent.NewSubscriber() + + subs.OnTCPDial(func(data appevent.TCPDialData) { + log.WithField("event_type", data.Type()). + WithField("event_data", data). + Info("Received event.") + }) - conn, err := appC.Dial(appnet.Addr{ - Net: netType, - PubKey: remotePK, - Port: 10, + subs.OnTCPClose(func(data appevent.TCPCloseData) { + log.WithField("event_type", data.Type()). + WithField("event_data", data). + Info("Received event.") }) + + return subs +} + +func runServer(appC *app.Client) { + log := log. + WithField("network", *network). + WithField("port", *port) + + lis, err := appC.Listen(appnet.Type(*network), routing.Port(*port)) if err != nil { - log.Fatalf("Failed to open remote conn: %v\n", err) + log.WithError(err).Fatal("Failed to listen.") } + log.Info("Listening for incoming connections.") + + for { + conn, err := lis.Accept() + if err != nil { + log.WithError(err).Fatal("Failed to accept connection.") + } + go handleServerConn(log, conn) + } +} - if _, err := conn.Write([]byte("ping")); err != nil { - log.Fatalf("Failed to write to a remote visor: %v\n", err) +func handleServerConn(log logrus.FieldLogger, conn net.Conn) { + log = log.WithField("remote_addr", conn.RemoteAddr()) + log.Info("Serving connection.") + defer func() { + log.WithError(conn.Close()).Debug("Closed connection.") + }() + + for { + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + log.WithField("n", n).WithError(err). + Error("Failed to read from connection.") + return + } + msg := string(buf[:n]) + log.WithField("n", n).WithField("data", msg).Info("Read from connection.") + + n, err = conn.Write([]byte(fmt.Sprintf("I've got your message: %s", msg))) + if err != nil { + log.WithField("n", n).WithError(err). + Error("Failed to write to connection.") + return + } + log.WithField("n", n).Info("Wrote response message.") } +} - buf := make([]byte, 4) - if _, err = conn.Read(buf); err != nil { - log.Fatalf("Failed to read remote data: %v\n", err) +func runClient(appC *app.Client) { + var remotePK cipher.PubKey + if err := remotePK.UnmarshalText([]byte(*remote)); err != nil { + log.WithError(err).Fatal("Invalid remote public key.") } - log.Printf("Message from %s: %s", conn.RemoteAddr().String(), string(buf)) + var conn net.Conn + + for i := 0; true; i++ { + time.Sleep(time.Second * 2) + + if conn != nil { + log.WithError(conn.Close()).Debug("Connection closed.") + conn = nil + } + + var err error + conn, err = appC.Dial(appnet.Addr{ + Net: appnet.Type(*network), + PubKey: remotePK, + Port: routing.Port(*port), + }) + if err != nil { + log.WithError(err).Error("Failed to dial.") + time.Sleep(time.Second) + continue + } + + n, err := conn.Write([]byte(fmt.Sprintf("Hello world! %d", i))) + if err != nil { + log.WithField("n", n).WithError(err). + Error("Failed to write to connection.") + continue + } + + buf := make([]byte, 1024) + n, err = conn.Read(buf) + if err != nil { + log.WithField("n", n).WithError(err). + Error("Failed to read from connection.") + continue + } + msg := string(buf[:n]) + log.WithField("n", n).WithField("data", msg).Info("Read reply from connection.") + } } diff --git a/cmd/apps/skychat/chat.go b/cmd/apps/skychat/chat.go index adc1281b2d..ea601798f9 100644 --- a/cmd/apps/skychat/chat.go +++ b/cmd/apps/skychat/chat.go @@ -42,7 +42,7 @@ var ( ) func main() { - appC = app.NewClient() + appC = app.NewClient(nil) defer appC.Close() if _, err := buildinfo.Get().WriteTo(os.Stdout); err != nil { diff --git a/cmd/apps/skysocks-client/skysocks-client.go b/cmd/apps/skysocks-client/skysocks-client.go index 0cfa709354..bd7d1594ad 100644 --- a/cmd/apps/skysocks-client/skysocks-client.go +++ b/cmd/apps/skysocks-client/skysocks-client.go @@ -49,7 +49,7 @@ func dialServer(appCl *app.Client, pk cipher.PubKey) (net.Conn, error) { } func main() { - appC := app.NewClient() + appC := app.NewClient(nil) defer appC.Close() skysocks.Log = log diff --git a/cmd/apps/skysocks/skysocks.go b/cmd/apps/skysocks/skysocks.go index c52af208df..1e1914da9d 100644 --- a/cmd/apps/skysocks/skysocks.go +++ b/cmd/apps/skysocks/skysocks.go @@ -25,7 +25,7 @@ const ( var log = logrus.New() func main() { - appC := app.NewClient() + appC := app.NewClient(nil) defer appC.Close() skysocks.Log = log diff --git a/cmd/apps/vpn-client/vpn-client.go b/cmd/apps/vpn-client/vpn-client.go index c764ae029d..1b3a8ba501 100644 --- a/cmd/apps/vpn-client/vpn-client.go +++ b/cmd/apps/vpn-client/vpn-client.go @@ -81,7 +81,7 @@ func main() { noiseCreds = vpn.NewNoiseCredentials(localSK, localPK) } - appClient := app.NewClient() + appClient := app.NewClient(nil) defer appClient.Close() log.Infof("Connecting to VPN server %s", serverPK.String()) diff --git a/cmd/apps/vpn-server/vpn-server.go b/cmd/apps/vpn-server/vpn-server.go index 0f7edb89b6..58e3dba5ab 100644 --- a/cmd/apps/vpn-server/vpn-server.go +++ b/cmd/apps/vpn-server/vpn-server.go @@ -59,7 +59,7 @@ func main() { noiseCreds = vpn.NewNoiseCredentials(localSK, localPK) } - appClient := app.NewClient() + appClient := app.NewClient(nil) defer appClient.Close() osSigs := make(chan os.Signal, 2) diff --git a/go.mod b/go.mod index 6cbc1a3ff8..22eab0d52e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/SkycoinProject/skywire-mainnet go 1.13 require ( - github.com/SkycoinProject/dmsg v0.1.1-0.20200420091742-8c1a3d828a49 + github.com/SkycoinProject/dmsg v0.1.1-0.20200523194607-be73f083a729 github.com/SkycoinProject/skycoin v0.27.0 github.com/SkycoinProject/yamux v0.0.0-20191213015001-a36efeefbf6a github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 diff --git a/go.sum b/go.sum index 1c87a43279..b228f3a60e 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/SkycoinProject/dmsg v0.0.0-20200306152741-acee74fa4514/go.mod h1:DzykXMLlx6Fx0fGjZsCIRas/MIvxW8DZpmDA6f2nCRk= -github.com/SkycoinProject/dmsg v0.1.1-0.20200420091742-8c1a3d828a49 h1:rYqmvSRA+rq6LTne/Ge34T0i4yjSHSwkhk0ER6relWU= -github.com/SkycoinProject/dmsg v0.1.1-0.20200420091742-8c1a3d828a49/go.mod h1:MiX+UG/6fl3g+9rS13/fq7BwUQ2eOlg1yOBOnNf6J6A= +github.com/SkycoinProject/dmsg v0.1.1-0.20200523194607-be73f083a729 h1:Edgnt4ido4MGfNTEJUYqNeXt0AlJ4EHlFCWBrKYPvT4= +github.com/SkycoinProject/dmsg v0.1.1-0.20200523194607-be73f083a729/go.mod h1:MiX+UG/6fl3g+9rS13/fq7BwUQ2eOlg1yOBOnNf6J6A= github.com/SkycoinProject/skycoin v0.26.0/go.mod h1:xqPLOKh5B6GBZlGA7B5IJfQmCy7mwimD9NlqxR3gMXo= github.com/SkycoinProject/skycoin v0.27.0 h1:N3IHxj8ossHOcsxLYOYugT+OaELLncYHJHxbbYLPPmY= github.com/SkycoinProject/skycoin v0.27.0/go.mod h1:xqPLOKh5B6GBZlGA7B5IJfQmCy7mwimD9NlqxR3gMXo= diff --git a/internal/vpn/os.go b/internal/vpn/os.go index 3c221c6066..58a9caecfe 100644 --- a/internal/vpn/os.go +++ b/internal/vpn/os.go @@ -16,6 +16,7 @@ func parseCIDR(ipCIDR string) (ipStr, netmask string, err error) { return ip.String(), fmt.Sprintf("%d.%d.%d.%d", net.Mask[0], net.Mask[1], net.Mask[2], net.Mask[3]), nil } +//nolint:unparam func run(bin string, args ...string) error { cmd := exec.Command(bin, args...) //nolint:gosec diff --git a/internal/vpn/os_server_linux.go b/internal/vpn/os_server_linux.go index 46abba7eb3..7444fae670 100644 --- a/internal/vpn/os_server_linux.go +++ b/internal/vpn/os_server_linux.go @@ -75,6 +75,7 @@ func EnableIPv6Forwarding() error { // EnableIPMasquerading enables IP masquerading for the interface with name `ifcName`. func EnableIPMasquerading(ifcName string) error { cmd := fmt.Sprintf(enableIPMasqueradingCMDFmt, ifcName) + //nolint:gosec if err := exec.Command("sh", "-c", cmd).Run(); err != nil { return fmt.Errorf("error running command %s: %w", cmd, err) } @@ -85,6 +86,7 @@ func EnableIPMasquerading(ifcName string) error { // DisableIPMasquerading disables IP masquerading for the interface with name `ifcName`. func DisableIPMasquerading(ifcName string) error { cmd := fmt.Sprintf(disableIPMasqueradingCMDFmt, ifcName) + //nolint:gosec if err := exec.Command("sh", "-c", cmd).Run(); err != nil { return fmt.Errorf("error running command %s: %w", cmd, err) } diff --git a/pkg/app/appcommon/hello.go b/pkg/app/appcommon/hello.go new file mode 100644 index 0000000000..59775c1da1 --- /dev/null +++ b/pkg/app/appcommon/hello.go @@ -0,0 +1,74 @@ +package appcommon + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" +) + +// Hello represents the first JSON object that an app sends the visor. +type Hello struct { + ProcKey ProcKey `json:"proc_key"` // proc key + EgressNet string `json:"egress_net,omitempty"` // network which hosts the appevent.RPCGateway of the app + EgressAddr string `json:"egress_addr,omitempty"` // address which hosts the appevent.RPCGateway of the app + EventSubs map[string]bool `json:"event_subs,omitempty"` // event subscriptions +} + +// String implements fmt.Stringer +func (h *Hello) String() string { + j, err := json.Marshal(h) + if err != nil { + panic(err) // should never happen + } + return string(j) +} + +// AllowsEventType returns true if the hello object contents allow for an event type. +func (h *Hello) AllowsEventType(eventType string) bool { + if h.EventSubs == nil { + return false + } + return h.EventSubs[eventType] +} + +// ReadHello reads in a hello object from the given reader. +func ReadHello(r io.Reader) (Hello, error) { + sizeRaw := make([]byte, 2) + if _, err := io.ReadFull(r, sizeRaw); err != nil { + return Hello{}, fmt.Errorf("failed to read hello size prefix: %w", err) + } + size := binary.BigEndian.Uint16(sizeRaw) + + helloRaw := make([]byte, size) + if _, err := io.ReadFull(r, helloRaw); err != nil { + return Hello{}, fmt.Errorf("failed to read hello data: %w", err) + } + + var hello Hello + if err := json.Unmarshal(helloRaw, &hello); err != nil { + return Hello{}, fmt.Errorf("failed to unmarshal hello data: %w", err) + } + + return hello, nil +} + +// WriteHello writes a hello object into a given writer. +func WriteHello(w io.Writer, hello Hello) error { + helloRaw, err := json.Marshal(hello) + if err != nil { + panic(err) // should never happen + } + + raw := make([]byte, 2+len(helloRaw)) + size := len(helloRaw) + binary.BigEndian.PutUint16(raw[:2], uint16(size)) + if n := copy(raw[2:], helloRaw); n != size { + panic("hello write does not add up") + } + + if _, err := w.Write(raw); err != nil { + return fmt.Errorf("failed to write hello data: %w", err) + } + return nil +} diff --git a/pkg/app/appevent/broadcaster.go b/pkg/app/appevent/broadcaster.go new file mode 100644 index 0000000000..5e262deec2 --- /dev/null +++ b/pkg/app/appevent/broadcaster.go @@ -0,0 +1,110 @@ +package appevent + +import ( + "context" + "sync" + "time" + + "github.com/SkycoinProject/skycoin/src/util/logging" + "github.com/sirupsen/logrus" +) + +// Broadcaster combines multiple RPCClients (which connects to the RPCGateway of the apps). +// It is responsible for broadcasting events to apps (if the app is subscribed to the event type). +type Broadcaster struct { + timeout time.Duration + + log logrus.FieldLogger + clients map[RPCClient]chan error + closed bool + mx sync.Mutex +} + +// NewBroadcaster instantiates a Broadcaster. +func NewBroadcaster(log logrus.FieldLogger, timeout time.Duration) *Broadcaster { + if log == nil { + log = logging.MustGetLogger("event_broadcaster") + } + return &Broadcaster{ + timeout: timeout, + log: log, + clients: make(map[RPCClient]chan error), + closed: false, + } +} + +// AddClient adds a RPCClient. +func (mc *Broadcaster) AddClient(c RPCClient) { + mc.mx.Lock() + if !mc.closed { + mc.clients[c] = make(chan error, 1) + } + mc.mx.Unlock() +} + +// Broadcast broadcasts an event to all subscribed channels of all rpc gateways. +func (mc *Broadcaster) Broadcast(ctx context.Context, e *Event) error { + if mc.timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(mc.timeout)) + defer cancel() + } + + mc.mx.Lock() + defer mc.mx.Unlock() + + if mc.closed { + return ErrSubscriptionsClosed + } + + if len(mc.clients) == 0 { + return nil + } + + // Notify all clients of event (if client is subscribed to the event type). + for client, errCh := range mc.clients { + go notifyClient(ctx, e, client, errCh) + } + + // Delete inactive clients and associated error channels. + for client, errCh := range mc.clients { + if err := <-errCh; err != nil { + mc.log. + WithError(err). + WithField("close_error", client.Close()). + WithField("hello", client.Hello().String()). + Warn("Events RPC client closed due to error.") + + delete(mc.clients, client) + close(errCh) + } + } + + return nil +} + +// notifyClient notifies a client of a given event if client is subscribed to the event type of the event. +func notifyClient(ctx context.Context, e *Event, client RPCClient, errCh chan error) { + var err error + if client.Hello().AllowsEventType(e.Type) { + err = client.Notify(ctx, e) + } + errCh <- err +} + +// Close implements io.Closer +func (mc *Broadcaster) Close() error { + mc.mx.Lock() + defer mc.mx.Unlock() + + if mc.closed { + return ErrSubscriptionsClosed + } + mc.closed = true + + for c, errCh := range mc.clients { + close(errCh) + delete(mc.clients, c) + } + return nil +} diff --git a/pkg/app/appevent/broadcaster_test.go b/pkg/app/appevent/broadcaster_test.go new file mode 100644 index 0000000000..5a66d68920 --- /dev/null +++ b/pkg/app/appevent/broadcaster_test.go @@ -0,0 +1,133 @@ +package appevent + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" +) + +func TestBroadcaster_Broadcast(t *testing.T) { + const timeout = time.Second * 2 + + // makeMockClient creates a mock RPCClient that appends received events to 'gotEvents'. + makeMockClient := func(subs map[string]bool, gotEvents *[]*Event) RPCClient { + mockC := new(MockRPCClient) + mockC.On("Close").Return(nil) + mockC.On("Hello").Return(&appcommon.Hello{ProcKey: appcommon.RandProcKey(), EventSubs: subs}) + mockC.On("Notify", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + *gotEvents = append(*gotEvents, args.Get(1).(*Event)) + }) + return mockC + } + + // makeEvents makes (n) number of random events. + makeEvents := func(n int) []*Event { + evs := make([]*Event, 0, n) + i := 0 + for { + for t := range AllTypes() { + evs = append(evs, NewEvent(t, struct{}{})) + if i++; i == n { + return evs + } + } + } + } + + // extractEvents returns events that are part of the subs. + extractEvents := func(events []*Event, subs map[string]bool) []*Event { + out := make([]*Event, 0, len(events)) + for _, ev := range events { + if subs[ev.Type] { + out = append(out, ev) + } + } + return out + } + + // Ensure Broadcast correctly broadcasts events to the internal RPCClients. + // Arrange: + // - There is a n(C) number of RPCClients within the Broadcaster. + // - All the aforementioned RPCClients are subscribed to all possible event types. + // Act: + // - Broadcast n(E) number of events using Broadcaster.Broadcast. + // Assert: + // - Each of the n(C) RPCClients should receive n(E) event objects. + // - Received event objects should be in the order of sent. + t.Run("broadcast_events", func(t *testing.T) { + // Arrange: constants. + const nClients = 12 + const nEvents = 52 + + // Arrange: prepare broadcaster. + bc := NewBroadcaster(nil, timeout) + defer func() { assert.NoError(t, bc.Close()) }() + + // Arrange: events to broadcast and results slice. + events := makeEvents(nEvents) + results := make([][]*Event, nClients) + for i := 0; i < nClients; i++ { + bc.AddClient(makeMockClient(AllTypes(), &results[i])) + } + + // Act: broadcast events. + for _, ev := range events { + require.NoError(t, bc.Broadcast(context.Background(), ev)) + } + + // Assert: received events of each RPCClient. + for i, r := range results { + assert.Len(t, r, nEvents, i) + assert.Equal(t, events, r, i) + } + }) + + // Ensure Broadcaster only broadcasts an event to a RPCClient if the RPCClient is subscribed to the event type. + // Arrange: + // - There is a RPCClient and a Broadcaster. + // - The RPCClient is only subscribed to one event type. + // Act: + // - Broadcaster broadcasts all event types. + // Assert: + // - The RPCClient should have only received events that are of subscribed types. + t.Run("broadcast_only_subscribed_events", func(t *testing.T) { + // Arrange: constants/variables + const nEvents = 64 + subs := map[string]bool{TCPDial: true} + + // Arrange: events to broadcast and results slice. + events := makeEvents(nEvents) + result := make([]*Event, 0, nEvents) + + // Arrange: prepare RPCClient. + mockC := makeMockClient(subs, &result) + defer func() { assert.NoError(t, mockC.Close()) }() + + // Arrange: prepare broadcaster. + bc := NewBroadcaster(nil, timeout) + bc.AddClient(mockC) + defer func() { assert.NoError(t, bc.Close()) }() + + // Act: broadcast events. + for _, ev := range events { + require.NoError(t, bc.Broadcast(context.TODO(), ev)) + } + + // Assert: resultant events slice outputted from mock client. + expectedEvents := extractEvents(events, subs) + assert.Len(t, result, len(expectedEvents)) + assert.Equal(t, expectedEvents, result) + expJ, err := json.Marshal(expectedEvents) + require.NoError(t, err) + resJ, err := json.Marshal(result) + require.NoError(t, err) + assert.JSONEq(t, string(expJ), string(resJ)) + }) +} diff --git a/pkg/app/appevent/event.go b/pkg/app/appevent/event.go new file mode 100644 index 0000000000..b67d9ef426 --- /dev/null +++ b/pkg/app/appevent/event.go @@ -0,0 +1,37 @@ +package appevent + +import ( + "encoding/json" +) + +// Event represents an event that is to be broadcasted. +type Event struct { + Type string + Data []byte + done chan struct{} // to be closed once event is dealt with +} + +// NewEvent creates a new Event. +func NewEvent(t string, v interface{}) *Event { + data, err := json.Marshal(v) + if err != nil { + panic(err) // should never happen + } + return &Event{Type: t, Data: data} +} + +// Unmarshal unmarshals the event data to a given object. +func (e *Event) Unmarshal(v interface{}) { + if err := json.Unmarshal(e.Data, v); err != nil { + panic(err) // should never happen + } +} + +// InitDone enables the Done/Wait logic. +func (e *Event) InitDone() { e.done = make(chan struct{}) } + +// Done informs that event is handled. +func (e *Event) Done() { close(e.done) } + +// Wait waits until event is handled. +func (e *Event) Wait() { <-e.done } diff --git a/pkg/app/appevent/handshake.go b/pkg/app/appevent/handshake.go new file mode 100644 index 0000000000..9cb1dbea3b --- /dev/null +++ b/pkg/app/appevent/handshake.go @@ -0,0 +1,74 @@ +package appevent + +import ( + "fmt" + "io" + "net" + "net/rpc" + + "github.com/sirupsen/logrus" + + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" +) + +// DoReqHandshake performs a request handshake which is initiated from an app. +// First, it determines whether we need an egress connection (from the app server which sends events) by seeing if +// there are any subscriptions within 'subs'. If so, a listener is started. +// Then we send a hello object to the app server which contains the proc key and egress connection info (if needed). +func DoReqHandshake(conf appcommon.ProcConfig, subs *Subscriber) (net.Conn, []io.Closer, error) { + var closers []io.Closer + hello := appcommon.Hello{ProcKey: conf.ProcKey} + + // configure and serve event channel subscriptions (if any) + if subs != nil && subs.Count() > 0 { + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, nil, fmt.Errorf("failed to create listener for RPC egress: %w", err) + } + + log := logrus.New().WithField("src", "events_gateway") + + rpcS := rpc.NewServer() + if err := rpcS.RegisterName(conf.ProcKey.String(), NewRPCGateway(log, subs)); err != nil { + panic(err) // should never happen + } + go rpcS.Accept(lis) + + hello.EgressNet = lis.Addr().Network() + hello.EgressAddr = lis.Addr().String() + hello.EventSubs = subs.Subscriptions() + closers = append(closers, lis) + } + + // dial to app server and send hello JSON object + // sending hello will also advertise event subscriptions endpoint (if needed) + conn, err := net.Dial("tcp", conf.AppSrvAddr) + if err != nil { + return nil, nil, fmt.Errorf("failed to dial to app server: %w", err) + } + if err := appcommon.WriteHello(conn, hello); err != nil { + return nil, nil, fmt.Errorf("failed to send hello to app server: %w", err) + } + + return conn, append(closers, conn), nil +} + +// DoRespHandshake performs a response handshake from the app server side. +// It reads the hello object from the app, and connects the app to the events broadcast (if needed). +func DoRespHandshake(ebc *Broadcaster, conn net.Conn) (*appcommon.Hello, error) { + hello, err := appcommon.ReadHello(conn) + if err != nil { + return nil, fmt.Errorf("failed to read hello object: %w", err) + } + + // connect app to events broadcast (if necessary) + if hello.EgressNet != "" && hello.EgressAddr != "" && len(hello.EventSubs) > 0 { + rpcC, err := NewRPCClient(&hello) + if err != nil { + return nil, fmt.Errorf("failed to connect app to events broadcast: %w", err) + } + ebc.AddClient(rpcC) + } + + return &hello, nil +} diff --git a/pkg/app/appevent/mock_rpc_client.go b/pkg/app/appevent/mock_rpc_client.go new file mode 100644 index 0000000000..9ec94adeec --- /dev/null +++ b/pkg/app/appevent/mock_rpc_client.go @@ -0,0 +1,60 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package appevent + +import ( + context "context" + + appcommon "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" + + mock "github.com/stretchr/testify/mock" +) + +// MockRPCClient is an autogenerated mock type for the RPCClient type +type MockRPCClient struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *MockRPCClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Hello provides a mock function with given fields: +func (_m *MockRPCClient) Hello() *appcommon.Hello { + ret := _m.Called() + + var r0 *appcommon.Hello + if rf, ok := ret.Get(0).(func() *appcommon.Hello); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*appcommon.Hello) + } + } + + return r0 +} + +// Notify provides a mock function with given fields: ctx, e +func (_m *MockRPCClient) Notify(ctx context.Context, e *Event) error { + ret := _m.Called(ctx, e) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *Event) error); ok { + r0 = rf(ctx, e) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/app/appevent/rpc.go b/pkg/app/appevent/rpc.go new file mode 100644 index 0000000000..b05c1c1f3e --- /dev/null +++ b/pkg/app/appevent/rpc.go @@ -0,0 +1,96 @@ +package appevent + +import ( + "context" + "fmt" + "io" + "net/rpc" + + "github.com/SkycoinProject/skycoin/src/util/logging" + "github.com/sirupsen/logrus" + + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" +) + +// RPCGateway represents the RPC gateway that opens up an app for incoming events from visor. +type RPCGateway struct { + log logrus.FieldLogger + subs *Subscriber +} + +// NewRPCGateway returns a new RPCGateway. +func NewRPCGateway(log logrus.FieldLogger, subs *Subscriber) *RPCGateway { + if log == nil { + log = logging.MustGetLogger("app_rpc_egress_gateway") + } + if subs == nil { + panic("'subs' input cannot be nil") + } + return &RPCGateway{log: log, subs: subs} +} + +// Notify notifies the app about events. +func (g *RPCGateway) Notify(e *Event, _ *struct{}) (err error) { + return PushEvent(g.subs, e) +} + +//go:generate mockery -name RPCClient -case underscore -inpkg + +// RPCClient describes the RPC client interface that communicates the NewRPCGateway. +type RPCClient interface { + io.Closer + Notify(ctx context.Context, e *Event) error + Hello() *appcommon.Hello +} + +// NewRPCClient constructs a new 'rpcClient'. +func NewRPCClient(hello *appcommon.Hello) (RPCClient, error) { + if hello.EgressNet == "" || hello.EgressAddr == "" { + return &rpcClient{rpcC: nil, hello: hello}, nil + } + + rpcC, err := rpc.Dial(hello.EgressNet, hello.EgressAddr) + if err != nil { + return nil, fmt.Errorf("failed to dial RPC: %w", err) + } + return &rpcClient{rpcC: rpcC, hello: hello}, nil +} + +type rpcClient struct { + rpcC *rpc.Client + hello *appcommon.Hello +} + +// Notify sends a notify to the rpc gateway. +func (c *rpcClient) Notify(ctx context.Context, e *Event) error { + if c.rpcC == nil { + return nil + } + + call := c.rpcC.Go(c.formatMethod("Notify"), e, nil, nil) + select { + case <-call.Done: + return call.Error + case <-ctx.Done(): + return ctx.Err() + } +} + +// Hello returns the internal hello object. +func (c *rpcClient) Hello() *appcommon.Hello { + return c.hello +} + +// Close closes the underlying rpc client (if any). +func (c *rpcClient) Close() error { + if c.rpcC == nil { + return nil + } + return c.rpcC.Close() +} + +// formatMethod formats complete RPC method signature. +func (c *rpcClient) formatMethod(method string) string { + const methodFmt = "%s.%s" + return fmt.Sprintf(methodFmt, c.hello.ProcKey.String(), method) +} diff --git a/pkg/app/appevent/subscriber.go b/pkg/app/appevent/subscriber.go new file mode 100644 index 0000000000..5fe816e7d9 --- /dev/null +++ b/pkg/app/appevent/subscriber.go @@ -0,0 +1,129 @@ +package appevent + +import ( + "errors" + "sync" +) + +// subChanSize is used so that incoming events are kept in order +const subChanSize = 5 + +// Errors associated with the Subscriber type. +var ( + ErrSubscriptionsClosed = errors.New("event subscriptions is closed") +) + +// Subscriber is used by apps and contain subscription channels to different event types. +type Subscriber struct { + chanSize int // config: event channel size + + m map[string]chan *Event + mx sync.RWMutex + closed bool +} + +// NewSubscriber returns a new Subscriber struct. +func NewSubscriber() *Subscriber { + return &Subscriber{ + chanSize: subChanSize, + m: make(map[string]chan *Event), + closed: false, + } +} + +// OnTCPDial subscribes to the OnTCPDial event channel (if not already). +// And triggers the contained action func on each subsequent event. +func (s *Subscriber) OnTCPDial(action func(data TCPDialData)) { + evCh := s.ensureEventChan(TCPDial) + + go func() { + for ev := range evCh { + var data TCPDialData + ev.Unmarshal(&data) + action(data) + ev.Done() + } + }() +} + +// OnTCPClose subscribes to the OnTCPClose event channel (if not already). +// And triggers the contained action func on each subsequent event. +func (s *Subscriber) OnTCPClose(action func(data TCPCloseData)) { + evCh := s.ensureEventChan(TCPClose) + + go func() { + for ev := range evCh { + var data TCPCloseData + ev.Unmarshal(&data) + action(data) + ev.Done() + } + }() +} + +func (s *Subscriber) ensureEventChan(eventType string) chan *Event { + s.mx.Lock() + ch, ok := s.m[eventType] + if !ok { + ch = make(chan *Event, s.chanSize) + s.m[eventType] = ch + } + s.mx.Unlock() + + return ch +} + +// Subscriptions returns a map of all subscribed event types. +func (s *Subscriber) Subscriptions() map[string]bool { + s.mx.RLock() + subs := make(map[string]bool, len(s.m)) + for t := range s.m { + subs[t] = true + } + s.mx.RUnlock() + + return subs +} + +// Count returns the number of subscriptions. +func (s *Subscriber) Count() int { + s.mx.RLock() + n := len(s.m) + s.mx.RUnlock() + return n +} + +// Close implements io.Closer +func (s *Subscriber) Close() error { + s.mx.Lock() + defer s.mx.Unlock() + + if s.closed { + return ErrSubscriptionsClosed + } + + for _, ch := range s.m { + close(ch) + } + s.m = nil + + return nil +} + +// PushEvent pushes an event to the relevant subscription channel. +func PushEvent(s *Subscriber, e *Event) error { + s.mx.RLock() + defer s.mx.RUnlock() + + if s.closed { + return ErrSubscriptionsClosed + } + + if ch, ok := s.m[e.Type]; ok { + e.InitDone() + ch <- e + e.Wait() // wait until event is fully handled by app before returning + } + + return nil +} diff --git a/pkg/app/appevent/types.go b/pkg/app/appevent/types.go new file mode 100644 index 0000000000..4bc8f89988 --- /dev/null +++ b/pkg/app/appevent/types.go @@ -0,0 +1,33 @@ +package appevent + +// AllTypes returns all event types. +func AllTypes() map[string]bool { + return map[string]bool{ + TCPDial: true, + TCPClose: true, + } +} + +// TCPDial represents a dial event. +const TCPDial = "tcp_dial" + +// TCPDialData contains net dial event data. +type TCPDialData struct { + RemoteNet string `json:"remote_net"` + RemoteAddr string `json:"remote_addr"` +} + +// Type returns the TCPDial type. +func (TCPDialData) Type() string { return TCPDial } + +// TCPClose represents a close event. +const TCPClose = "tcp_close" + +// TCPCloseData contains net close event data. +type TCPCloseData struct { + RemoteNet string `json:"remote_net"` + RemoteAddr string `json:"remote_addr"` +} + +// Type returns the TCPClose type. +func (TCPCloseData) Type() string { return TCPClose } diff --git a/pkg/app/mock_rpc_client.go b/pkg/app/appserver/mock_rpc_ingress_client.go similarity index 79% rename from pkg/app/mock_rpc_client.go rename to pkg/app/appserver/mock_rpc_ingress_client.go index f3c62f3238..55593bf90d 100644 --- a/pkg/app/mock_rpc_client.go +++ b/pkg/app/appserver/mock_rpc_ingress_client.go @@ -1,6 +1,6 @@ // Code generated by mockery v1.0.0. DO NOT EDIT. -package app +package appserver import ( mock "github.com/stretchr/testify/mock" @@ -12,13 +12,13 @@ import ( time "time" ) -// MockRPCClient is an autogenerated mock type for the RPCClient type -type MockRPCClient struct { +// MockRPCIngressClient is an autogenerated mock type for the RPCIngressClient type +type MockRPCIngressClient struct { mock.Mock } // Accept provides a mock function with given fields: lisID -func (_m *MockRPCClient) Accept(lisID uint16) (uint16, appnet.Addr, error) { +func (_m *MockRPCIngressClient) Accept(lisID uint16) (uint16, appnet.Addr, error) { ret := _m.Called(lisID) var r0 uint16 @@ -46,7 +46,7 @@ func (_m *MockRPCClient) Accept(lisID uint16) (uint16, appnet.Addr, error) { } // CloseConn provides a mock function with given fields: id -func (_m *MockRPCClient) CloseConn(id uint16) error { +func (_m *MockRPCIngressClient) CloseConn(id uint16) error { ret := _m.Called(id) var r0 error @@ -60,7 +60,7 @@ func (_m *MockRPCClient) CloseConn(id uint16) error { } // CloseListener provides a mock function with given fields: id -func (_m *MockRPCClient) CloseListener(id uint16) error { +func (_m *MockRPCIngressClient) CloseListener(id uint16) error { ret := _m.Called(id) var r0 error @@ -74,7 +74,7 @@ func (_m *MockRPCClient) CloseListener(id uint16) error { } // Dial provides a mock function with given fields: remote -func (_m *MockRPCClient) Dial(remote appnet.Addr) (uint16, routing.Port, error) { +func (_m *MockRPCIngressClient) Dial(remote appnet.Addr) (uint16, routing.Port, error) { ret := _m.Called(remote) var r0 uint16 @@ -102,7 +102,7 @@ func (_m *MockRPCClient) Dial(remote appnet.Addr) (uint16, routing.Port, error) } // Listen provides a mock function with given fields: local -func (_m *MockRPCClient) Listen(local appnet.Addr) (uint16, error) { +func (_m *MockRPCIngressClient) Listen(local appnet.Addr) (uint16, error) { ret := _m.Called(local) var r0 uint16 @@ -123,7 +123,7 @@ func (_m *MockRPCClient) Listen(local appnet.Addr) (uint16, error) { } // Read provides a mock function with given fields: connID, b -func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, error) { +func (_m *MockRPCIngressClient) Read(connID uint16, b []byte) (int, error) { ret := _m.Called(connID, b) var r0 int @@ -144,7 +144,7 @@ func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, error) { } // SetDeadline provides a mock function with given fields: connID, d -func (_m *MockRPCClient) SetDeadline(connID uint16, d time.Time) error { +func (_m *MockRPCIngressClient) SetDeadline(connID uint16, d time.Time) error { ret := _m.Called(connID, d) var r0 error @@ -158,7 +158,7 @@ func (_m *MockRPCClient) SetDeadline(connID uint16, d time.Time) error { } // SetReadDeadline provides a mock function with given fields: connID, d -func (_m *MockRPCClient) SetReadDeadline(connID uint16, d time.Time) error { +func (_m *MockRPCIngressClient) SetReadDeadline(connID uint16, d time.Time) error { ret := _m.Called(connID, d) var r0 error @@ -172,7 +172,7 @@ func (_m *MockRPCClient) SetReadDeadline(connID uint16, d time.Time) error { } // SetWriteDeadline provides a mock function with given fields: connID, d -func (_m *MockRPCClient) SetWriteDeadline(connID uint16, d time.Time) error { +func (_m *MockRPCIngressClient) SetWriteDeadline(connID uint16, d time.Time) error { ret := _m.Called(connID, d) var r0 error @@ -186,7 +186,7 @@ func (_m *MockRPCClient) SetWriteDeadline(connID uint16, d time.Time) error { } // Write provides a mock function with given fields: connID, b -func (_m *MockRPCClient) Write(connID uint16, b []byte) (int, error) { +func (_m *MockRPCIngressClient) Write(connID uint16, b []byte) (int, error) { ret := _m.Called(connID, b) var r0 int diff --git a/pkg/app/appserver/proc.go b/pkg/app/appserver/proc.go index 5b69959261..941e3f7f72 100644 --- a/pkg/app/appserver/proc.go +++ b/pkg/app/appserver/proc.go @@ -38,10 +38,10 @@ type Proc struct { waitMx sync.Mutex waitErr error - rpcGW *RPCGateway // gateway shared over 'conn' - introduced AFTER proc is started - conn net.Conn // connection to proc - introduced AFTER proc is started - connCh chan struct{} // push here when conn is received - protected by 'connOnce' - connOnce sync.Once // ensures we only push to 'connCh' once + rpcGW *RPCIngressGateway // gateway shared over 'conn' - introduced AFTER proc is started + conn net.Conn // connection to proc - introduced AFTER proc is started + connCh chan struct{} // push here when conn is received - protected by 'connOnce' + connOnce sync.Once // ensures we only push to 'connCh' once } // NewProc constructs `Proc`. @@ -183,7 +183,7 @@ func (p *Proc) Stop() error { } if p.cmd.Process != nil { - err := p.cmd.Process.Signal(os.Interrupt) //TODO: panic here. + err := p.cmd.Process.Signal(os.Interrupt) if err != nil { return err } diff --git a/pkg/app/appserver/proc_manager.go b/pkg/app/appserver/proc_manager.go index a9cebd9e26..0ed62798db 100644 --- a/pkg/app/appserver/proc_manager.go +++ b/pkg/app/appserver/proc_manager.go @@ -14,6 +14,7 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appdisc" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appevent" ) //go:generate mockery -name ProcManager -case underscore -inpkg @@ -56,18 +57,24 @@ type procManager struct { procs map[string]*Proc procsByKey map[appcommon.ProcKey]*Proc + // event broadcaster: broadcasts events to apps + eb *appevent.Broadcaster + mx sync.RWMutex done chan struct{} } // NewProcManager constructs `ProcManager`. -func NewProcManager(mLog *logging.MasterLogger, discF *appdisc.Factory, addr string) (ProcManager, error) { +func NewProcManager(mLog *logging.MasterLogger, discF *appdisc.Factory, eb *appevent.Broadcaster, addr string) (ProcManager, error) { if mLog == nil { mLog = logging.NewMasterLogger() } if discF == nil { discF = new(appdisc.Factory) } + if eb == nil { + eb = appevent.NewBroadcaster(mLog.PackageLogger("event_broadcaster"), time.Second) + } lis, err := net.Listen("tcp", addr) if err != nil { @@ -82,6 +89,7 @@ func NewProcManager(mLog *logging.MasterLogger, discF *appdisc.Factory, addr str discF: discF, procs: make(map[string]*Proc), procsByKey: make(map[appcommon.ProcKey]*Proc), + eb: eb, done: make(chan struct{}), } @@ -130,21 +138,18 @@ func (m *procManager) handleConn(conn net.Conn) bool { log := m.log.WithField("remote", conn.RemoteAddr()) log.Debug("Accepting proc conn...") - // Read in and check key. - var key appcommon.ProcKey - if n, err := io.ReadFull(conn, key[:]); err != nil { - log.WithError(err). - WithField("n", n). - Warn("Failed to read proc key.") + hello, err := appevent.DoRespHandshake(m.eb, conn) + if err != nil { + log.WithError(err).Error("Failed to do handshake with proc.") return false } - log = log.WithField("proc_key", key.String()) - log.Debug("Read proc key.") + log = log.WithField("hello", hello.String()) + log.Debug("Read hello from proc.") // Push conn to Proc. m.mx.RLock() - proc, ok := m.procsByKey[key] + proc, ok := m.procsByKey[hello.ProcKey] m.mx.RUnlock() if !ok { log.Error("Failed to find proc of given key.") diff --git a/pkg/app/appserver/proc_manager_test.go b/pkg/app/appserver/proc_manager_test.go index 5fd360cf8e..2d6d70cdb5 100644 --- a/pkg/app/appserver/proc_manager_test.go +++ b/pkg/app/appserver/proc_manager_test.go @@ -8,7 +8,7 @@ import ( ) func TestProcManager_ProcByName(t *testing.T) { - mI, err := NewProcManager(nil, nil, ":0") + mI, err := NewProcManager(nil, nil, nil, ":0") require.NoError(t, err) m, ok := mI.(*procManager) @@ -28,7 +28,7 @@ func TestProcManager_ProcByName(t *testing.T) { } func TestProcManager_Range(t *testing.T) { - mI, err := NewProcManager(nil, nil, ":0") + mI, err := NewProcManager(nil, nil, nil, ":0") require.NoError(t, err) m, ok := mI.(*procManager) @@ -57,7 +57,7 @@ func TestProcManager_Range(t *testing.T) { } func TestProcManager_Pop(t *testing.T) { - mI, err := NewProcManager(nil, nil, ":0") + mI, err := NewProcManager(nil, nil, nil, ":0") require.NoError(t, err) m, ok := mI.(*procManager) diff --git a/pkg/app/rpc_client.go b/pkg/app/appserver/rpc_ingress_client.go similarity index 62% rename from pkg/app/rpc_client.go rename to pkg/app/appserver/rpc_ingress_client.go index b36f56b910..c12b96396e 100644 --- a/pkg/app/rpc_client.go +++ b/pkg/app/appserver/rpc_ingress_client.go @@ -1,4 +1,4 @@ -package app +package appserver import ( "fmt" @@ -7,14 +7,13 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" - "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" "github.com/SkycoinProject/skywire-mainnet/pkg/routing" ) -//go:generate mockery -name RPCClient -case underscore -inpkg +//go:generate mockery -name RPCIngressClient -case underscore -inpkg -// RPCClient describes RPC interface to communicate with the server. -type RPCClient interface { +// RPCIngressClient describes RPC interface to communicate with the server. +type RPCIngressClient interface { Dial(remote appnet.Addr) (connID uint16, localPort routing.Port, err error) Listen(local appnet.Addr) (uint16, error) Accept(lisID uint16) (connID uint16, remote appnet.Addr, err error) @@ -27,23 +26,23 @@ type RPCClient interface { SetWriteDeadline(connID uint16, d time.Time) error } -// rpcClient implements `RPCClient`. -type rpcClient struct { - rpc *rpc.Client - appKey appcommon.ProcKey +// rpcIngressClient implements `RPCIngressClient`. +type rpcIngressClient struct { + rpc *rpc.Client + procKey appcommon.ProcKey } -// NewRPCClient constructs new `rpcClient`. -func NewRPCClient(rpc *rpc.Client, appKey appcommon.ProcKey) RPCClient { - return &rpcClient{ - rpc: rpc, - appKey: appKey, +// NewRPCIngressClient constructs new `rpcIngressClient`. +func NewRPCIngressClient(rpc *rpc.Client, procKey appcommon.ProcKey) RPCIngressClient { + return &rpcIngressClient{ + rpc: rpc, + procKey: procKey, } } // Dial sends `Dial` command to the server. -func (c *rpcClient) Dial(remote appnet.Addr) (connID uint16, localPort routing.Port, err error) { - var resp appserver.DialResp +func (c *rpcIngressClient) Dial(remote appnet.Addr) (connID uint16, localPort routing.Port, err error) { + var resp DialResp if err := c.rpc.Call(c.formatMethod("Dial"), &remote, &resp); err != nil { return 0, 0, err } @@ -52,7 +51,7 @@ func (c *rpcClient) Dial(remote appnet.Addr) (connID uint16, localPort routing.P } // Listen sends `Listen` command to the server. -func (c *rpcClient) Listen(local appnet.Addr) (uint16, error) { +func (c *rpcIngressClient) Listen(local appnet.Addr) (uint16, error) { var lisID uint16 if err := c.rpc.Call(c.formatMethod("Listen"), &local, &lisID); err != nil { return 0, err @@ -62,8 +61,8 @@ func (c *rpcClient) Listen(local appnet.Addr) (uint16, error) { } // Accept sends `Accept` command to the server. -func (c *rpcClient) Accept(lisID uint16) (connID uint16, remote appnet.Addr, err error) { - var acceptResp appserver.AcceptResp +func (c *rpcIngressClient) Accept(lisID uint16) (connID uint16, remote appnet.Addr, err error) { + var acceptResp AcceptResp if err := c.rpc.Call(c.formatMethod("Accept"), &lisID, &acceptResp); err != nil { return 0, appnet.Addr{}, err } @@ -72,13 +71,13 @@ func (c *rpcClient) Accept(lisID uint16) (connID uint16, remote appnet.Addr, err } // Write sends `Write` command to the server. -func (c *rpcClient) Write(connID uint16, b []byte) (int, error) { - req := appserver.WriteReq{ +func (c *rpcIngressClient) Write(connID uint16, b []byte) (int, error) { + req := WriteReq{ ConnID: connID, B: b, } - var resp appserver.WriteResp + var resp WriteResp if err := c.rpc.Call(c.formatMethod("Write"), &req, &resp); err != nil { return 0, err } @@ -87,13 +86,13 @@ func (c *rpcClient) Write(connID uint16, b []byte) (int, error) { } // Read sends `Read` command to the server. -func (c *rpcClient) Read(connID uint16, b []byte) (int, error) { - req := appserver.ReadReq{ +func (c *rpcIngressClient) Read(connID uint16, b []byte) (int, error) { + req := ReadReq{ ConnID: connID, BufLen: len(b), } - var resp appserver.ReadResp + var resp ReadResp if err := c.rpc.Call(c.formatMethod("Read"), &req, &resp); err != nil { return 0, err } @@ -106,18 +105,18 @@ func (c *rpcClient) Read(connID uint16, b []byte) (int, error) { } // CloseConn sends `CloseConn` command to the server. -func (c *rpcClient) CloseConn(id uint16) error { +func (c *rpcIngressClient) CloseConn(id uint16) error { return c.rpc.Call(c.formatMethod("CloseConn"), &id, nil) } // CloseListener sends `CloseListener` command to the server. -func (c *rpcClient) CloseListener(id uint16) error { +func (c *rpcIngressClient) CloseListener(id uint16) error { return c.rpc.Call(c.formatMethod("CloseListener"), &id, nil) } // SetDeadline sends `SetDeadline` command to the server. -func (c *rpcClient) SetDeadline(id uint16, t time.Time) error { - req := appserver.DeadlineReq{ +func (c *rpcIngressClient) SetDeadline(id uint16, t time.Time) error { + req := DeadlineReq{ ConnID: id, Deadline: t, } @@ -126,8 +125,8 @@ func (c *rpcClient) SetDeadline(id uint16, t time.Time) error { } // SetReadDeadline sends `SetReadDeadline` command to the server. -func (c *rpcClient) SetReadDeadline(id uint16, t time.Time) error { - req := appserver.DeadlineReq{ +func (c *rpcIngressClient) SetReadDeadline(id uint16, t time.Time) error { + req := DeadlineReq{ ConnID: id, Deadline: t, } @@ -136,8 +135,8 @@ func (c *rpcClient) SetReadDeadline(id uint16, t time.Time) error { } // SetWriteDeadline sends `SetWriteDeadline` command to the server. -func (c *rpcClient) SetWriteDeadline(id uint16, t time.Time) error { - req := appserver.DeadlineReq{ +func (c *rpcIngressClient) SetWriteDeadline(id uint16, t time.Time) error { + req := DeadlineReq{ ConnID: id, Deadline: t, } @@ -146,7 +145,7 @@ func (c *rpcClient) SetWriteDeadline(id uint16, t time.Time) error { } // formatMethod formats complete RPC method signature. -func (c *rpcClient) formatMethod(method string) string { +func (c *rpcIngressClient) formatMethod(method string) string { const methodFmt = "%s.%s" - return fmt.Sprintf(methodFmt, c.appKey.String(), method) + return fmt.Sprintf(methodFmt, c.procKey.String(), method) } diff --git a/pkg/app/rpc_client_test.go b/pkg/app/appserver/rpc_ingress_client_test.go similarity index 91% rename from pkg/app/rpc_client_test.go rename to pkg/app/appserver/rpc_ingress_client_test.go index b95abe312d..6b7201ee5e 100644 --- a/pkg/app/rpc_client_test.go +++ b/pkg/app/appserver/rpc_ingress_client_test.go @@ -1,4 +1,4 @@ -package app +package appserver import ( "context" @@ -21,7 +21,6 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/routing" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" - "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" ) func TestRPCClient_Dial(t *testing.T) { @@ -30,7 +29,7 @@ func TestRPCClient_Dial(t *testing.T) { rpcL, closeL := prepListener(t) defer closeL() - rpcS := prepRPCServer(t, appserver.NewRPCGateway(nil)) + rpcS := prepRPCServer(t, NewRPCGateway(nil)) go rpcS.Accept(rpcL) rpcC := prepRPCClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) @@ -56,7 +55,7 @@ func TestRPCClient_Dial(t *testing.T) { }) t.Run("dial error", func(t *testing.T) { - s := prepRPCServer(t, appserver.NewRPCGateway(nil)) + s := prepRPCServer(t, NewRPCGateway(nil)) rpcL, lisCleanup := prepListener(t) defer lisCleanup() go s.Accept(rpcL) @@ -86,7 +85,7 @@ func TestRPCClient_Dial(t *testing.T) { func TestRPCClient_Listen(t *testing.T) { t.Run("ok", func(t *testing.T) { - s := prepRPCServer(t, appserver.NewRPCGateway(nil)) + s := prepRPCServer(t, NewRPCGateway(nil)) rpcL, lisCleanup := prepListener(t) defer lisCleanup() go s.Accept(rpcL) @@ -112,7 +111,7 @@ func TestRPCClient_Listen(t *testing.T) { }) t.Run("listen error", func(t *testing.T) { - s := prepRPCServer(t, appserver.NewRPCGateway(nil)) + s := prepRPCServer(t, NewRPCGateway(nil)) rpcL, lisCleanup := prepListener(t) defer lisCleanup() go s.Accept(rpcL) @@ -143,7 +142,7 @@ func TestRPCClient_Accept(t *testing.T) { dmsgLocal, dmsgRemote, local, _ := prepAddrs() t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) lisConn := &appcommon.MockConn{} lisConn.On("LocalAddr").Return(dmsgLocal) @@ -178,7 +177,7 @@ func TestRPCClient_Accept(t *testing.T) { }) t.Run("accept error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) var lisConn net.Conn listenErr := errors.New("accept error") @@ -211,7 +210,7 @@ func TestRPCClient_Write(t *testing.T) { dmsgLocal, dmsgRemote, _, remote := prepAddrs() t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} writeN := 10 @@ -224,7 +223,7 @@ func TestRPCClient_Write(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -241,7 +240,7 @@ func TestRPCClient_Write(t *testing.T) { }) t.Run("write error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} writeN := 0 @@ -254,7 +253,7 @@ func TestRPCClient_Write(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -276,7 +275,7 @@ func TestRPCClient_Read(t *testing.T) { dmsgLocal, dmsgRemote, _, remote := prepAddrs() t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) readBufLen := 10 readBuf := make([]byte, readBufLen) @@ -290,7 +289,7 @@ func TestRPCClient_Read(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -307,7 +306,7 @@ func TestRPCClient_Read(t *testing.T) { }) t.Run("read error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) readBufLen := 10 readBuf := make([]byte, readBufLen) @@ -321,7 +320,7 @@ func TestRPCClient_Read(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -343,7 +342,7 @@ func TestRPCClient_CloseConn(t *testing.T) { dmsgLocal, dmsgRemote, _, remote := prepAddrs() t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) var noErr error @@ -354,7 +353,7 @@ func TestRPCClient_CloseConn(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -370,7 +369,7 @@ func TestRPCClient_CloseConn(t *testing.T) { }) t.Run("close error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) closeErr := errors.New("close error") @@ -381,7 +380,7 @@ func TestRPCClient_CloseConn(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -402,7 +401,7 @@ func TestRPCClient_CloseListener(t *testing.T) { _, _, local, _ := prepAddrs() t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) var noErr error @@ -427,7 +426,7 @@ func TestRPCClient_CloseListener(t *testing.T) { }) t.Run("close error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) closeErr := errors.New("close error") @@ -459,7 +458,7 @@ func TestRPCClient_SetDeadline(t *testing.T) { deadline := time.Now().Add(1 * time.Hour) t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) conn := &appcommon.MockConn{} conn.On("SetDeadline", mock.Anything).Return(func(d time.Time) error { @@ -474,7 +473,7 @@ func TestRPCClient_SetDeadline(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -490,7 +489,7 @@ func TestRPCClient_SetDeadline(t *testing.T) { }) t.Run("set deadline error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) conn := &appcommon.MockConn{} conn.On("SetDeadline", mock.Anything).Return(func(d time.Time) error { @@ -505,7 +504,7 @@ func TestRPCClient_SetDeadline(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -528,7 +527,7 @@ func TestRPCClient_SetReadDeadline(t *testing.T) { deadline := time.Now().Add(1 * time.Hour) t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) conn := &appcommon.MockConn{} conn.On("SetReadDeadline", mock.Anything).Return(func(d time.Time) error { @@ -543,7 +542,7 @@ func TestRPCClient_SetReadDeadline(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -559,7 +558,7 @@ func TestRPCClient_SetReadDeadline(t *testing.T) { }) t.Run("set deadline error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) conn := &appcommon.MockConn{} conn.On("SetReadDeadline", mock.Anything).Return(func(d time.Time) error { @@ -574,7 +573,7 @@ func TestRPCClient_SetReadDeadline(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -597,7 +596,7 @@ func TestRPCClient_SetWriteDeadline(t *testing.T) { deadline := time.Now().Add(1 * time.Hour) t.Run("ok", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) conn := &appcommon.MockConn{} conn.On("SetWriteDeadline", mock.Anything).Return(func(d time.Time) error { @@ -612,7 +611,7 @@ func TestRPCClient_SetWriteDeadline(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -628,7 +627,7 @@ func TestRPCClient_SetWriteDeadline(t *testing.T) { }) t.Run("set deadline error", func(t *testing.T) { - gateway := appserver.NewRPCGateway(nil) + gateway := NewRPCGateway(nil) conn := &appcommon.MockConn{} conn.On("SetWriteDeadline", mock.Anything).Return(func(d time.Time) error { @@ -643,7 +642,7 @@ func TestRPCClient_SetWriteDeadline(t *testing.T) { prepNetworkerWithConn(t, conn, remote) - var dialResp appserver.DialResp + var dialResp DialResp err := gateway.Dial(&remote, &dialResp) require.NoError(t, err) @@ -686,7 +685,7 @@ func prepNetworkerWithConn(t *testing.T, conn *appcommon.MockConn, remote appnet // rpcProcKey is shared by prepRPCServer and prepRPCClient var rpcProcKey = appcommon.RandProcKey() -func prepRPCServer(t *testing.T, gateway *appserver.RPCGateway) *rpc.Server { +func prepRPCServer(t *testing.T, gateway *RPCIngressGateway) *rpc.Server { s := rpc.NewServer() err := s.RegisterName(rpcProcKey.String(), gateway) require.NoError(t, err) @@ -694,11 +693,11 @@ func prepRPCServer(t *testing.T, gateway *appserver.RPCGateway) *rpc.Server { return s } -func prepRPCClient(t *testing.T, network, addr string) RPCClient { +func prepRPCClient(t *testing.T, network, addr string) RPCIngressClient { rpcCl, err := rpc.Dial(network, addr) require.NoError(t, err) - return NewRPCClient(rpcCl, rpcProcKey) + return NewRPCIngressClient(rpcCl, rpcProcKey) } func prepListener(t *testing.T) (lis net.Listener, cleanup func()) { diff --git a/pkg/app/appserver/rpc_gateway.go b/pkg/app/appserver/rpc_ingress_gateway.go similarity index 84% rename from pkg/app/appserver/rpc_gateway.go rename to pkg/app/appserver/rpc_ingress_gateway.go index df45b8f49f..fc017f6084 100644 --- a/pkg/app/appserver/rpc_gateway.go +++ b/pkg/app/appserver/rpc_ingress_gateway.go @@ -55,19 +55,19 @@ func (e *RPCIOErr) ToError() error { } } -// RPCGateway is a RPC interface for the app server. -type RPCGateway struct { +// RPCIngressGateway is a RPC interface for the app server. +type RPCIngressGateway struct { lm *idmanager.Manager // contains listeners associated with their IDs cm *idmanager.Manager // contains connections associated with their IDs log *logging.Logger } // NewRPCGateway constructs new server RPC interface. -func NewRPCGateway(log *logging.Logger) *RPCGateway { +func NewRPCGateway(log *logging.Logger) *RPCIngressGateway { if log == nil { - log = logging.MustGetLogger("app_rpc_gateway") + log = logging.MustGetLogger("app_rpc_ingress_gateway") } - return &RPCGateway{ + return &RPCIngressGateway{ lm: idmanager.New(), cm: idmanager.New(), log: log, @@ -81,7 +81,7 @@ type DialResp struct { } // Dial dials to the remote. -func (r *RPCGateway) Dial(remote *appnet.Addr, resp *DialResp) (err error) { +func (r *RPCIngressGateway) Dial(remote *appnet.Addr, resp *DialResp) (err error) { defer rpcutil.LogCall(r.log, "Dial", remote)(resp, &err) reservedConnID, free, err := r.cm.ReserveNextID() @@ -118,7 +118,7 @@ func (r *RPCGateway) Dial(remote *appnet.Addr, resp *DialResp) (err error) { } // Listen starts listening. -func (r *RPCGateway) Listen(local *appnet.Addr, lisID *uint16) (err error) { +func (r *RPCIngressGateway) Listen(local *appnet.Addr, lisID *uint16) (err error) { defer rpcutil.LogCall(r.log, "Listen", local)(lisID, &err) nextLisID, free, err := r.lm.ReserveNextID() @@ -151,7 +151,7 @@ type AcceptResp struct { } // Accept accepts connection from the listener specified by `lisID`. -func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) (err error) { +func (r *RPCIngressGateway) Accept(lisID *uint16, resp *AcceptResp) (err error) { defer rpcutil.LogCall(r.log, "Accept", lisID)(resp, &err) log := r.log.WithField("func", "Accept") @@ -211,7 +211,7 @@ type WriteResp struct { } // Write writes to the connection. -func (r *RPCGateway) Write(req *WriteReq, resp *WriteResp) error { +func (r *RPCIngressGateway) Write(req *WriteReq, resp *WriteResp) error { conn, err := r.getConn(req.ConnID) if err != nil { return err @@ -238,7 +238,7 @@ type ReadResp struct { } // Read reads data from connection specified by `connID`. -func (r *RPCGateway) Read(req *ReadReq, resp *ReadResp) error { +func (r *RPCIngressGateway) Read(req *ReadReq, resp *ReadResp) error { conn, err := r.getConn(req.ConnID) if err != nil { return err @@ -262,7 +262,7 @@ func (r *RPCGateway) Read(req *ReadReq, resp *ReadResp) error { } // CloseConn closes connection specified by `connID`. -func (r *RPCGateway) CloseConn(connID *uint16, _ *struct{}) (err error) { +func (r *RPCIngressGateway) CloseConn(connID *uint16, _ *struct{}) (err error) { defer rpcutil.LogCall(r.log, "CloseConn", connID)(nil, &err) conn, err := r.popConn(*connID) @@ -274,7 +274,7 @@ func (r *RPCGateway) CloseConn(connID *uint16, _ *struct{}) (err error) { } // CloseListener closes listener specified by `lisID`. -func (r *RPCGateway) CloseListener(lisID *uint16, _ *struct{}) (err error) { +func (r *RPCIngressGateway) CloseListener(lisID *uint16, _ *struct{}) (err error) { defer rpcutil.LogCall(r.log, "CloseConn", lisID)(nil, &err) lis, err := r.popListener(*lisID) @@ -292,7 +292,7 @@ type DeadlineReq struct { } // SetDeadline sets deadline for connection specified by `connID`. -func (r *RPCGateway) SetDeadline(req *DeadlineReq, _ *struct{}) error { +func (r *RPCIngressGateway) SetDeadline(req *DeadlineReq, _ *struct{}) error { conn, err := r.getConn(req.ConnID) if err != nil { return err @@ -302,7 +302,7 @@ func (r *RPCGateway) SetDeadline(req *DeadlineReq, _ *struct{}) error { } // SetReadDeadline sets read deadline for connection specified by `connID`. -func (r *RPCGateway) SetReadDeadline(req *DeadlineReq, _ *struct{}) error { +func (r *RPCIngressGateway) SetReadDeadline(req *DeadlineReq, _ *struct{}) error { conn, err := r.getConn(req.ConnID) if err != nil { return err @@ -312,7 +312,7 @@ func (r *RPCGateway) SetReadDeadline(req *DeadlineReq, _ *struct{}) error { } // SetWriteDeadline sets read deadline for connection specified by `connID`. -func (r *RPCGateway) SetWriteDeadline(req *DeadlineReq, _ *struct{}) error { +func (r *RPCIngressGateway) SetWriteDeadline(req *DeadlineReq, _ *struct{}) error { conn, err := r.getConn(req.ConnID) if err != nil { return err @@ -323,7 +323,7 @@ func (r *RPCGateway) SetWriteDeadline(req *DeadlineReq, _ *struct{}) error { // popListener gets listener from the manager by `lisID` and removes it. // Handles type assertion. -func (r *RPCGateway) popListener(lisID uint16) (net.Listener, error) { +func (r *RPCIngressGateway) popListener(lisID uint16) (net.Listener, error) { lisIfc, err := r.lm.Pop(lisID) if err != nil { return nil, fmt.Errorf("no listener: %v", err) @@ -334,7 +334,7 @@ func (r *RPCGateway) popListener(lisID uint16) (net.Listener, error) { // popConn gets conn from the manager by `connID` and removes it. // Handles type assertion. -func (r *RPCGateway) popConn(connID uint16) (net.Conn, error) { +func (r *RPCIngressGateway) popConn(connID uint16) (net.Conn, error) { connIfc, err := r.cm.Pop(connID) if err != nil { return nil, fmt.Errorf("no conn: %v", err) @@ -344,7 +344,7 @@ func (r *RPCGateway) popConn(connID uint16) (net.Conn, error) { } // getListener gets listener from the manager by `lisID`. Handles type assertion. -func (r *RPCGateway) getListener(lisID uint16) (net.Listener, error) { +func (r *RPCIngressGateway) getListener(lisID uint16) (net.Listener, error) { lisIfc, ok := r.lm.Get(lisID) if !ok { return nil, fmt.Errorf("no listener with key %d", lisID) @@ -354,7 +354,7 @@ func (r *RPCGateway) getListener(lisID uint16) (net.Listener, error) { } // getConn gets conn from the manager by `connID`. Handles type assertion. -func (r *RPCGateway) getConn(connID uint16) (net.Conn, error) { +func (r *RPCIngressGateway) getConn(connID uint16) (net.Conn, error) { connIfc, ok := r.cm.Get(connID) if !ok { return nil, fmt.Errorf("no conn with key %d", connID) diff --git a/pkg/app/appserver/rpc_gateway_test.go b/pkg/app/appserver/rpc_ingress_gateway_test.go similarity index 99% rename from pkg/app/appserver/rpc_gateway_test.go rename to pkg/app/appserver/rpc_ingress_gateway_test.go index 36a01b7f38..d88b6100e9 100644 --- a/pkg/app/appserver/rpc_gateway_test.go +++ b/pkg/app/appserver/rpc_ingress_gateway_test.go @@ -1026,7 +1026,7 @@ func prepAddr(nType appnet.Type) appnet.Addr { } } -func addConn(t *testing.T, rpc *RPCGateway, conn net.Conn) uint16 { +func addConn(t *testing.T, rpc *RPCIngressGateway, conn net.Conn) uint16 { connID, _, err := rpc.cm.ReserveNextID() require.NoError(t, err) @@ -1036,7 +1036,7 @@ func addConn(t *testing.T, rpc *RPCGateway, conn net.Conn) uint16 { return *connID } -func addListener(t *testing.T, rpc *RPCGateway, lis net.Listener) uint16 { +func addListener(t *testing.T, rpc *RPCIngressGateway, lis net.Listener) uint16 { lisID, _, err := rpc.lm.ReserveNextID() require.NoError(t, err) diff --git a/pkg/app/client.go b/pkg/app/client.go index 0b9686177b..ceb5816782 100644 --- a/pkg/app/client.go +++ b/pkg/app/client.go @@ -1,7 +1,7 @@ package app import ( - "fmt" + "io" "net" "net/rpc" "strings" @@ -9,29 +9,32 @@ import ( "github.com/sirupsen/logrus" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appevent" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" "github.com/SkycoinProject/skywire-mainnet/pkg/app/idmanager" "github.com/SkycoinProject/skywire-mainnet/pkg/routing" ) // Client is used by skywire apps. type Client struct { - log logrus.FieldLogger - conf appcommon.ProcConfig - rpc RPCClient - lm *idmanager.Manager // contains listeners associated with their IDs - cm *idmanager.Manager // contains connections associated with their IDs + log logrus.FieldLogger + conf appcommon.ProcConfig + rpcC appserver.RPCIngressClient + lm *idmanager.Manager // contains listeners associated with their IDs + cm *idmanager.Manager // contains connections associated with their IDs + closers []io.Closer // additional things to close on close } // NewClient creates a new Client, panicking on any error. -func NewClient() *Client { +func NewClient(eventSubs *appevent.Subscriber) *Client { log := logrus.New() conf, err := appcommon.ProcConfigFromEnv() if err != nil { log.WithError(err).Fatal("Failed to obtain proc config.") } - client, err := NewClientFromConfig(log, conf) + client, err := NewClientFromConfig(log, conf, eventSubs) if err != nil { log.WithError(err).Panic("Failed to create app client.") } @@ -39,21 +42,19 @@ func NewClient() *Client { } // NewClientFromConfig creates a new client from a given proc config. -func NewClientFromConfig(log logrus.FieldLogger, conf appcommon.ProcConfig) (*Client, error) { - conn, err := net.Dial("tcp", conf.AppSrvAddr) +func NewClientFromConfig(log logrus.FieldLogger, conf appcommon.ProcConfig, subs *appevent.Subscriber) (*Client, error) { + conn, closers, err := appevent.DoReqHandshake(conf, subs) if err != nil { - return nil, fmt.Errorf("failed to dial to app server: %w", err) - } - if _, err := conn.Write(conf.ProcKey[:]); err != nil { - return nil, fmt.Errorf("failed to send proc key back to app server: %w", err) + return nil, err } return &Client{ - log: log, - conf: conf, - rpc: NewRPCClient(rpc.NewClient(conn), conf.ProcKey), - lm: idmanager.New(), - cm: idmanager.New(), + log: log, + conf: conf, + rpcC: appserver.NewRPCIngressClient(rpc.NewClient(conn), conf.ProcKey), + lm: idmanager.New(), + cm: idmanager.New(), + closers: closers, }, nil } @@ -64,14 +65,14 @@ func (c *Client) Config() appcommon.ProcConfig { // Dial dials the remote visor using `remote`. func (c *Client) Dial(remote appnet.Addr) (net.Conn, error) { - connID, localPort, err := c.rpc.Dial(remote) + connID, localPort, err := c.rpcC.Dial(remote) if err != nil { return nil, err } conn := &Conn{ id: connID, - rpc: c.rpc, + rpc: c.rpcC, local: appnet.Addr{ Net: remote.Net, PubKey: c.conf.VisorPK, @@ -109,7 +110,7 @@ func (c *Client) Listen(n appnet.Type, port routing.Port) (net.Listener, error) Port: port, } - lisID, err := c.rpc.Listen(local) + lisID, err := c.rpcC.Listen(local) if err != nil { return nil, err } @@ -117,7 +118,7 @@ func (c *Client) Listen(n appnet.Type, port routing.Port) (net.Listener, error) listener := &Listener{ log: c.log, id: lisID, - rpc: c.rpc, + rpc: c.rpcC, addr: local, cm: idmanager.New(), } @@ -145,41 +146,45 @@ func (c *Client) Listen(n appnet.Type, port routing.Port) (net.Listener, error) // Close closes client/server communication entirely. It closes all open // listeners and connections. func (c *Client) Close() { - var listeners []net.Listener + var ( + listeners []net.Listener + conns []net.Conn + ) + // Fill listeners and connections. c.lm.DoRange(func(_ uint16, v interface{}) bool { lis, err := idmanager.AssertListener(v) if err != nil { c.log.Error(err) return true } - listeners = append(listeners, lis) return true }) - - var conns []net.Conn - c.cm.DoRange(func(_ uint16, v interface{}) bool { conn, err := idmanager.AssertConn(v) if err != nil { c.log.Error(err) return true } - conns = append(conns, conn) return true }) + // Close everything. for _, lis := range listeners { - if err := lis.Close(); err != nil { + if err := lis.Close(); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { c.log.WithError(err).Error("Error closing listener.") } } - for _, conn := range conns { if err := conn.Close(); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - c.log.WithError(err).Error("Unexpected error while closing conn.") + c.log.WithError(err).Error("Error closing conn.") + } + } + for _, v := range c.closers { + if err := v.Close(); err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + c.log.WithError(err).Error("Error closing closer.") } } } diff --git a/pkg/app/client_test.go b/pkg/app/client_test.go index 7faa3e10ce..55a25716de 100644 --- a/pkg/app/client_test.go +++ b/pkg/app/client_test.go @@ -10,6 +10,7 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/app/appcommon" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" "github.com/SkycoinProject/skywire-mainnet/pkg/app/idmanager" "github.com/SkycoinProject/skywire-mainnet/pkg/routing" ) @@ -31,7 +32,7 @@ func TestClient_Dial(t *testing.T) { dialLocalPort := routing.Port(1) var dialErr error - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) cl := prepClient(l, visorPK, rpc) @@ -75,7 +76,7 @@ func TestClient_Dial(t *testing.T) { var closeErr error - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) rpc.On("CloseConn", dialConnID).Return(closeErr) @@ -96,7 +97,7 @@ func TestClient_Dial(t *testing.T) { closeErr := errors.New("close error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) rpc.On("CloseConn", dialConnID).Return(closeErr) @@ -113,7 +114,7 @@ func TestClient_Dial(t *testing.T) { t.Run("dial error", func(t *testing.T) { dialErr := errors.New("dial error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Dial", remote).Return(uint16(0), routing.Port(0), dialErr) cl := prepClient(l, visorPK, rpc) @@ -139,7 +140,7 @@ func TestClient_Listen(t *testing.T) { listenLisID := uint16(1) var listenErr error - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Listen", local).Return(listenLisID, listenErr) cl := prepClient(l, visorPK, rpc) @@ -168,7 +169,7 @@ func TestClient_Listen(t *testing.T) { var closeErr error - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Listen", local).Return(listenLisID, listenErr) rpc.On("CloseListener", listenLisID).Return(closeErr) @@ -188,7 +189,7 @@ func TestClient_Listen(t *testing.T) { closeErr := errors.New("close error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Listen", local).Return(listenLisID, listenErr) rpc.On("CloseListener", listenLisID).Return(closeErr) @@ -205,7 +206,7 @@ func TestClient_Listen(t *testing.T) { t.Run("listen error", func(t *testing.T) { listenErr := errors.New("listen error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Listen", local).Return(uint16(0), listenErr) cl := prepClient(l, visorPK, rpc) @@ -225,7 +226,7 @@ func TestClient_Close(t *testing.T) { closeErr = errors.New("close error") ) - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} lisID1 := uint16(1) lisID2 := uint16(2) @@ -284,7 +285,7 @@ func TestClient_Close(t *testing.T) { require.False(t, ok) } -func prepClient(l *logging.Logger, visorPK cipher.PubKey, rpc RPCClient) *Client { +func prepClient(l *logging.Logger, visorPK cipher.PubKey, rpc appserver.RPCIngressClient) *Client { var procKey appcommon.ProcKey copy(procKey[:], visorPK[:]) return &Client{ @@ -300,8 +301,8 @@ func prepClient(l *logging.Logger, visorPK cipher.PubKey, rpc RPCClient) *Client BinaryLoc: "", LogDBLoc: "", }, - rpc: rpc, - lm: idmanager.New(), - cm: idmanager.New(), + rpcC: rpc, + lm: idmanager.New(), + cm: idmanager.New(), } } diff --git a/pkg/app/conn.go b/pkg/app/conn.go index ee70225b9e..f86aa48a8d 100644 --- a/pkg/app/conn.go +++ b/pkg/app/conn.go @@ -9,13 +9,14 @@ import ( "time" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" ) // Conn is a connection from app client to the server. // Implements `net.Conn`. type Conn struct { id uint16 - rpc RPCClient + rpc appserver.RPCIngressClient local appnet.Addr remote appnet.Addr freeConn func() bool diff --git a/pkg/app/conn_test.go b/pkg/app/conn_test.go index ab72451380..d245ffea04 100644 --- a/pkg/app/conn_test.go +++ b/pkg/app/conn_test.go @@ -45,7 +45,7 @@ func TestConn_Read(t *testing.T) { for _, tc := range tt { tc := tc t.Run(tc.name, func(t *testing.T) { - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readErr) conn := &Conn{ @@ -84,7 +84,7 @@ func TestConn_Write(t *testing.T) { for _, tc := range tt { tc := tc t.Run(tc.name, func(t *testing.T) { - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Write", connID, tc.writeBuff).Return(tc.writeN, tc.writeErr) conn := &Conn{ @@ -105,7 +105,7 @@ func TestConn_Close(t *testing.T) { var noErr error t.Run("ok", func(t *testing.T) { - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("CloseConn", connID).Return(noErr) conn := &Conn{ @@ -121,7 +121,7 @@ func TestConn_Close(t *testing.T) { t.Run("close error", func(t *testing.T) { closeErr := errors.New("close error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("CloseConn", connID).Return(closeErr) conn := &Conn{ @@ -135,7 +135,7 @@ func TestConn_Close(t *testing.T) { }) t.Run("already closed", func(t *testing.T) { - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("CloseConn", connID).Return(noErr) conn := &Conn{ @@ -252,9 +252,9 @@ func TestConn_TestConn(t *testing.T) { conf: appcommon.ProcConfig{ VisorPK: keys[0].PK, }, - rpc: NewRPCClient(rpcCl1, procKey1), - lm: idmanager.New(), - cm: idmanager.New(), + rpcC: appserver.NewRPCIngressClient(rpcCl1, procKey1), + lm: idmanager.New(), + cm: idmanager.New(), } rpcCl2, err := rpc.Dial(rpcL.Addr().Network(), rpcL.Addr().String()) @@ -267,9 +267,9 @@ func TestConn_TestConn(t *testing.T) { conf: appcommon.ProcConfig{ VisorPK: keys[1].PK, }, - rpc: NewRPCClient(rpcCl2, procKey2), - lm: idmanager.New(), - cm: idmanager.New(), + rpcC: appserver.NewRPCIngressClient(rpcCl2, procKey2), + lm: idmanager.New(), + cm: idmanager.New(), } c1, err := cl1.Dial(a2) diff --git a/pkg/app/listener.go b/pkg/app/listener.go index 404ce7b81d..a96877b44f 100644 --- a/pkg/app/listener.go +++ b/pkg/app/listener.go @@ -8,6 +8,7 @@ import ( "github.com/sirupsen/logrus" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" "github.com/SkycoinProject/skywire-mainnet/pkg/app/idmanager" ) @@ -16,7 +17,7 @@ import ( type Listener struct { log logrus.FieldLogger id uint16 - rpc RPCClient + rpc appserver.RPCIngressClient addr appnet.Addr cm *idmanager.Manager // contains conns associated with their IDs freeLis func() bool diff --git a/pkg/app/listener_test.go b/pkg/app/listener_test.go index 08cffa450f..4084f7f3f4 100644 --- a/pkg/app/listener_test.go +++ b/pkg/app/listener_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" "github.com/SkycoinProject/skywire-mainnet/pkg/app/idmanager" "github.com/SkycoinProject/skywire-mainnet/pkg/routing" ) @@ -34,7 +35,7 @@ func TestListener_Accept(t *testing.T) { } var acceptErr error - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) lis := &Listener{ @@ -83,7 +84,7 @@ func TestListener_Accept(t *testing.T) { var closeErr error - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) rpc.On("CloseConn", acceptConnID).Return(closeErr) @@ -115,7 +116,7 @@ func TestListener_Accept(t *testing.T) { closeErr := errors.New("close error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) rpc.On("CloseConn", acceptConnID).Return(closeErr) @@ -140,7 +141,7 @@ func TestListener_Accept(t *testing.T) { acceptRemote := appnet.Addr{} acceptErr := errors.New("accept error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("Accept", lisID).Return(acceptConnID, acceptRemote, acceptErr) lis := &Listener{ @@ -172,7 +173,7 @@ func TestListener_Close(t *testing.T) { var closeNoErr error closeErr := errors.New("close error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("CloseListener", lisID).Return(closeNoErr) cm := idmanager.New() @@ -225,7 +226,7 @@ func TestListener_Close(t *testing.T) { t.Run("close error", func(t *testing.T) { lisCloseErr := errors.New("close error") - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("CloseListener", lisID).Return(lisCloseErr) lis := &Listener{ @@ -244,7 +245,7 @@ func TestListener_Close(t *testing.T) { t.Run("already closed", func(t *testing.T) { var noErr error - rpc := &MockRPCClient{} + rpc := &appserver.MockRPCIngressClient{} rpc.On("CloseListener", lisID).Return(noErr) lis := &Listener{ diff --git a/pkg/snet/network.go b/pkg/snet/network.go index 6dca166d1a..7ec1adbb60 100644 --- a/pkg/snet/network.go +++ b/pkg/snet/network.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appevent" "github.com/SkycoinProject/skywire-mainnet/pkg/snet/stcp" "github.com/SkycoinProject/skycoin/src/util/logging" @@ -81,16 +82,29 @@ type Network struct { } // New creates a network from a config. -func New(conf Config) *Network { +func New(conf Config, eb *appevent.Broadcaster) *Network { var dmsgC *dmsg.Client var stcpC *stcp.Client if conf.Dmsg != nil { - c := &dmsg.Config{ + dmsgConf := &dmsg.Config{ MinSessions: conf.Dmsg.SessionsCount, + Callbacks: &dmsg.ClientCallbacks{ + OnSessionDial: func(network, addr string) error { + data := appevent.TCPDialData{RemoteNet: network, RemoteAddr: addr} + event := appevent.NewEvent(appevent.TCPDial, data) + _ = eb.Broadcast(context.Background(), event) //nolint:errcheck + // @evanlinjin: An error is not returned here as this will cancel the session dial. + return nil + }, + OnSessionDisconnect: func(network, addr string, _ error) { + data := appevent.TCPCloseData{RemoteNet: network, RemoteAddr: addr} + event := appevent.NewEvent(appevent.TCPClose, data) + _ = eb.Broadcast(context.Background(), event) //nolint:errcheck + }, + }, } - - dmsgC = dmsg.NewClient(conf.PubKey, conf.SecKey, disc.NewHTTP(conf.Dmsg.Discovery), c) + dmsgC = dmsg.NewClient(conf.PubKey, conf.SecKey, disc.NewHTTP(conf.Dmsg.Discovery), dmsgConf) dmsgC.SetLogger(logging.MustGetLogger("snet.dmsgC")) } diff --git a/pkg/visor/init.go b/pkg/visor/init.go index 31e4a53919..01eafcbf3c 100644 --- a/pkg/visor/init.go +++ b/pkg/visor/init.go @@ -16,6 +16,7 @@ import ( "github.com/SkycoinProject/skywire-mainnet/internal/utclient" "github.com/SkycoinProject/skywire-mainnet/internal/vpn" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appdisc" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appevent" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" "github.com/SkycoinProject/skywire-mainnet/pkg/app/launcher" "github.com/SkycoinProject/skywire-mainnet/pkg/routefinder/rfclient" @@ -31,6 +32,21 @@ import ( type initFunc func(v *Visor) bool +func initStack() []initFunc { + return []initFunc{ + initUpdater, + initEventBroadcaster, + initSNet, + initDmsgpty, + initTransport, + initRouter, + initLauncher, + initCLI, + initHypervisors, + initUptimeTracker, + } +} + func initUpdater(v *Visor) bool { report := v.makeReporter("updater") @@ -45,15 +61,31 @@ func initUpdater(v *Visor) bool { return report(nil) } +func initEventBroadcaster(v *Visor) bool { + report := v.makeReporter("event_broadcaster") + + log := v.MasterLogger().PackageLogger("event_broadcaster") + const ebcTimeout = time.Second + ebc := appevent.NewBroadcaster(log, ebcTimeout) + + v.pushCloseStack("event_broadcaster", func() bool { + return report(ebc.Close()) + }) + + v.ebc = ebc + return report(nil) +} + func initSNet(v *Visor) bool { report := v.makeReporter("snet") - n := snet.New(snet.Config{ + conf := snet.Config{ PubKey: v.conf.PK, SecKey: v.conf.SK, Dmsg: v.conf.Dmsg, STCP: v.conf.STCP, - }) + } + n := snet.New(conf, v.ebc) if err := n.Init(); err != nil { return report(err) } @@ -177,7 +209,7 @@ func initLauncher(v *Visor) bool { } // Prepare proc manager. - procM, err := appserver.NewProcManager(v.MasterLogger(), &factory, conf.ServerAddr) + procM, err := appserver.NewProcManager(v.MasterLogger(), &factory, v.ebc, conf.ServerAddr) if err != nil { return report(fmt.Errorf("failed to start proc_manager: %w", err)) } @@ -220,11 +252,12 @@ func makeVPNEnvs(conf *visorconfig.V1, n *snet.Network) ([]string, error) { r := netutil.NewRetrier(logrus.New(), 1*time.Second, 10*time.Second, 0, 1) err := r.Do(context.Background(), func() error { - envCfg.DmsgServers = n.Dmsg().ConnectedServers() + for _, ses := range n.Dmsg().AllSessions() { + envCfg.DmsgServers = append(envCfg.DmsgServers, ses.LocalTCPAddr().String()) + } if len(envCfg.DmsgServers) == 0 { - return errors.New("no Dmsg servers found") + return errors.New("no dmsg servers found") } - return nil }) if err != nil { @@ -330,7 +363,7 @@ func initUptimeTracker(v *Visor) bool { ut, err := utclient.NewHTTP(conf.Addr, v.conf.PK, v.conf.SK) if err != nil { // TODO(evanlinjin): We should design utclient to retry automatically instead of returning error. - //return report(err) + // return report(err) v.log.WithError(err).Warn("Failed to connect to uptime tracker.") return true } diff --git a/pkg/visor/init_unix.go b/pkg/visor/init_unix.go index afa10df963..e4b3b04f5d 100644 --- a/pkg/visor/init_unix.go +++ b/pkg/visor/init_unix.go @@ -14,24 +14,6 @@ import ( "github.com/SkycoinProject/dmsg/dmsgpty" ) -func initStack() []initFunc { - return []initFunc{ - initUpdater, - initSNet, - initDmsgpty, - initTransport, - initRouter, - initLauncher, - initCLI, - initHypervisors, - initUptimeTracker, - } -} - -type pty struct { - pty *dmsgpty.Host -} - func initDmsgpty(v *Visor) bool { report := v.makeReporter("dmsgpty") conf := v.conf.Dmsgpty @@ -116,6 +98,5 @@ func initDmsgpty(v *Visor) bool { }) } - v.pty.pty = pty return report(nil) } diff --git a/pkg/visor/init_windows.go b/pkg/visor/init_windows.go index 7c93d6a999..617b5bc89c 100644 --- a/pkg/visor/init_windows.go +++ b/pkg/visor/init_windows.go @@ -2,18 +2,15 @@ package visor -type pty struct { -} +func initDmsgpty(v *Visor) bool { + report := v.makeReporter("dmsgpty") + conf := v.conf.Dmsgpty -func initStack() []initFunc { - return []initFunc{ - initUpdater, - initSNet, - initTransport, - initRouter, - initLauncher, - initCLI, - initHypervisors, - initUptimeTracker, + if conf == nil { + v.log.Info("'dmsgpty' is not configured, skipping.") + return report(nil) } + + v.log.Error("dmsgpty is not supported on windows.") + return report(nil) } diff --git a/pkg/visor/visor.go b/pkg/visor/visor.go index 03d04ec796..82b2be7eae 100644 --- a/pkg/visor/visor.go +++ b/pkg/visor/visor.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/SkycoinProject/skywire-mainnet/pkg/app/appevent" "github.com/SkycoinProject/skywire-mainnet/pkg/visor/visorconfig" "github.com/SkycoinProject/dmsg/cipher" @@ -58,8 +59,9 @@ type Visor struct { restartCtx *restart.Context updater *updater.Updater + ebc *appevent.Broadcaster // event broadcaster + net *snet.Network - pty pty tpM *transport.Manager router router.Router diff --git a/pkg/visor/visorconfig/README.md b/pkg/visor/visorconfig/README.md index 0959981e8a..893b7907fa 100644 --- a/pkg/visor/visorconfig/README.md +++ b/pkg/visor/visorconfig/README.md @@ -14,13 +14,17 @@ - `restart_check_delay` (string) -# V1Launcher +# V1UptimeTracker + +- `addr` (string) -- `discovery` (*[V1AppDisc](#V1AppDisc)) -- `apps` ([][AppConfig](#AppConfig)) -- `server_addr` (string) -- `bin_path` (string) -- `local_path` (string) + +# V1Dmsgpty + +- `port` (uint16) +- `authorization_file` (string) +- `cli_network` (string) +- `cli_address` (string) # V1Transport @@ -30,15 +34,13 @@ - `trusted_visors` () -# V1AppDisc - -- `update_interval` (Duration) -- `proxy_discovery_addr` (string) - - -# V1UptimeTracker +# V1Launcher -- `addr` (string) +- `discovery` (*[V1AppDisc](#V1AppDisc)) +- `apps` ([][AppConfig](#AppConfig)) +- `server_addr` (string) +- `bin_path` (string) +- `local_path` (string) # V1Routing @@ -54,20 +56,16 @@ - `location` (string) -# V1Dmsgpty +# V1AppDisc -- `port` (uint16) -- `authorization_file` (string) -- `cli_network` (string) -- `cli_address` (string) +- `update_interval` (Duration) +- `proxy_discovery_addr` (string) -# AppConfig +# STCPConfig -- `name` (string) -- `args` ([]string) -- `auto_start` (bool) -- `port` (Port) +- `pk_table` () +- `local_address` (string) # DmsgConfig @@ -76,7 +74,9 @@ - `sessions_count` (int) -# STCPConfig +# AppConfig -- `pk_table` () -- `local_address` (string) +- `name` (string) +- `args` ([]string) +- `auto_start` (bool) +- `port` (Port) diff --git a/vendor/github.com/SkycoinProject/dmsg/.gitignore b/vendor/github.com/SkycoinProject/dmsg/.gitignore index 3aa72e870b..7084aa6b1c 100644 --- a/vendor/github.com/SkycoinProject/dmsg/.gitignore +++ b/vendor/github.com/SkycoinProject/dmsg/.gitignore @@ -12,8 +12,8 @@ .idea/ bin/ -dmsg-discovery -dmsg-server -dmsgpty-cli -dmsgpty-host -dmsgpty-ui +/dmsg-discovery +/dmsg-server +/dmsgpty-cli +/dmsgpty-host +/dmsgpty-ui diff --git a/vendor/github.com/SkycoinProject/dmsg/client.go b/vendor/github.com/SkycoinProject/dmsg/client.go index 50816efdaa..c7393662c3 100644 --- a/vendor/github.com/SkycoinProject/dmsg/client.go +++ b/vendor/github.com/SkycoinProject/dmsg/client.go @@ -19,9 +19,32 @@ import ( // TODO(evanlinjin): We should implement exponential backoff at some point. const serveWait = time.Second +// SessionDialCallback is triggered BEFORE a session is dialed to. +// If a non-nil error is returned, the session dial is instantly terminated. +type SessionDialCallback func(network, addr string) (err error) + +// SessionDisconnectCallback triggers after a session is closed. +type SessionDisconnectCallback func(network, addr string, err error) + +// ClientCallbacks contains callbacks which a Client uses. +type ClientCallbacks struct { + OnSessionDial SessionDialCallback + OnSessionDisconnect SessionDisconnectCallback +} + +func (sc *ClientCallbacks) ensure() { + if sc.OnSessionDial == nil { + sc.OnSessionDial = func(network, addr string) (err error) { return nil } + } + if sc.OnSessionDisconnect == nil { + sc.OnSessionDisconnect = func(network, addr string, err error) {} + } +} + // Config configures a dmsg client entity. type Config struct { MinSessions int + Callbacks *ClientCallbacks } // PrintWarnings prints warnings with config. @@ -51,11 +74,7 @@ type Client struct { errCh chan error done chan struct{} once sync.Once - sesMx sync.Mutex - - connectedServersMx sync.Mutex - connectedServers map[string]struct{} } // NewClient creates a dmsg client entity. @@ -82,6 +101,10 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf if conf == nil { conf = DefaultConfig() } + if conf.Callbacks == nil { + conf.Callbacks = new(ClientCallbacks) + } + conf.Callbacks.ensure() c.conf = conf c.conf.PrintWarnings(c.log) @@ -89,8 +112,6 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf c.errCh = make(chan error, 10) c.done = make(chan struct{}) - c.connectedServers = make(map[string]struct{}) - return c } @@ -255,15 +276,14 @@ func (ce *Client) AllSessions() []ClientSession { } // ConnectedServers obtains all the servers client is connected to. +// +// Deprecated: we can now obtain the remote TCP address of a session from the ClientSession struct directly. func (ce *Client) ConnectedServers() []string { - ce.connectedServersMx.Lock() - defer ce.connectedServersMx.Unlock() - - addrs := make([]string, 0, len(ce.connectedServers)) - for addr := range ce.connectedServers { - addrs = append(addrs, addr) + sessions := ce.allClientSessions(ce.porter) + addrs := make([]string, len(sessions)) + for i, s := range sessions { + addrs[i] = s.RemoteTCPAddr().String() } - return addrs } @@ -305,13 +325,27 @@ func (ce *Client) ensureSession(ctx context.Context, entry *disc.Entry) error { // It is expected that the session is created and served before the context cancels, otherwise an error will be returned. // NOTE: This should not be called directly as it may lead to session duplicates. // Only `ensureSession` or `EnsureAndObtainSession` should call this function. -func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (ClientSession, error) { +func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (cs ClientSession, err error) { ce.log.WithField("remote_pk", entry.Static).Info("Dialing session...") - conn, err := net.Dial("tcp", entry.Server.Address) + const network = "tcp" + + // Trigger dial callback. + if err := ce.conf.Callbacks.OnSessionDial(network, entry.Server.Address); err != nil { + return ClientSession{}, fmt.Errorf("session dial is rejected by callback: %w", err) + } + defer func() { + if err != nil { + // Trigger disconnect callback when dial fails. + ce.conf.Callbacks.OnSessionDisconnect(network, entry.Server.Address, err) + } + }() + + conn, err := net.Dial(network, entry.Server.Address) if err != nil { return ClientSession{}, err } + dSes, err := makeClientSession(&ce.EntityCommon, ce.porter, conn, entry.Static) if err != nil { return ClientSession{}, err @@ -321,15 +355,19 @@ func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (ClientSes _ = dSes.Close() //nolint:errcheck return ClientSession{}, errors.New("session already exists") } - ce.connectedServersMx.Lock() - ce.connectedServers[entry.Server.Address] = struct{}{} - ce.connectedServersMx.Unlock() + go func() { ce.log.WithField("remote_pk", dSes.RemotePK()).Info("Serving session.") - if err := dSes.serve(); !isClosed(ce.done) { + err := dSes.serve() + if !isClosed(ce.done) { + // We should only report an error when client is not closed. + // Also, when the client is closed, it will automatically delete all sessions. ce.errCh <- fmt.Errorf("failed to serve dialed session to %s: %v", dSes.RemotePK(), err) ce.delSession(ctx, dSes.RemotePK()) } + + // Trigger disconnect callback. + ce.conf.Callbacks.OnSessionDisconnect(network, entry.Server.Address, err) }() return dSes, nil diff --git a/vendor/github.com/SkycoinProject/dmsg/session_common.go b/vendor/github.com/SkycoinProject/dmsg/session_common.go index b4a5698e4d..db8fc12240 100644 --- a/vendor/github.com/SkycoinProject/dmsg/session_common.go +++ b/vendor/github.com/SkycoinProject/dmsg/session_common.go @@ -20,11 +20,12 @@ type SessionCommon struct { entity *EntityCommon // back reference rPK cipher.PubKey // remote pk - ys *yamux.Session - ns *noise.Noise - nMap noise.NonceMap - rMx sync.Mutex - wMx sync.Mutex + netConn net.Conn // underlying net.Conn (TCP connection to the dmsg server) + ys *yamux.Session + ns *noise.Noise + nMap noise.NonceMap + rMx sync.Mutex + wMx sync.Mutex log logrus.FieldLogger } @@ -55,6 +56,7 @@ func (sc *SessionCommon) initClient(entity *EntityCommon, conn net.Conn, rPK cip sc.entity = entity sc.rPK = rPK + sc.netConn = conn sc.ys = ySes sc.ns = ns sc.nMap = make(noise.NonceMap) @@ -87,6 +89,7 @@ func (sc *SessionCommon) initServer(entity *EntityCommon, conn net.Conn) error { sc.entity = entity sc.rPK = ns.RemoteStatic() + sc.netConn = conn sc.ys = ySes sc.ns = ns sc.nMap = make(noise.NonceMap) @@ -134,6 +137,12 @@ func (sc *SessionCommon) LocalPK() cipher.PubKey { return sc.entity.pk } // RemotePK returns the remote public key of the session. func (sc *SessionCommon) RemotePK() cipher.PubKey { return sc.rPK } +// LocalTCPAddr returns the local address of the underlying TCP connection. +func (sc *SessionCommon) LocalTCPAddr() net.Addr { return sc.netConn.LocalAddr() } + +// RemoteTCPAddr returns the remote address of the underlying TCP connection. +func (sc *SessionCommon) RemoteTCPAddr() net.Addr { return sc.netConn.RemoteAddr() } + // Close closes the session. func (sc *SessionCommon) Close() (err error) { if sc != nil { diff --git a/vendor/modules.txt b/vendor/modules.txt index f7da5e9add..74f3c8da23 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# github.com/SkycoinProject/dmsg v0.1.1-0.20200420091742-8c1a3d828a49 +# github.com/SkycoinProject/dmsg v0.1.1-0.20200523194607-be73f083a729 github.com/SkycoinProject/dmsg github.com/SkycoinProject/dmsg/cipher github.com/SkycoinProject/dmsg/cmdutil