-
Notifications
You must be signed in to change notification settings - Fork 1
/
listener_test.go
185 lines (160 loc) · 4.94 KB
/
listener_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
package tomtp
import (
"crypto/ed25519"
"fmt"
"github.com/stretchr/testify/assert"
"net"
"sync"
"testing"
"time"
)
// Deterministic ed25519 identities shared by the tests below. Each identity is
// derived from a fixed 32-byte seed; its public key is also rendered as a
// "0x"-prefixed lowercase hex string for APIs that take a textual key.
var (
	// First test identity.
	testPrivateSeed1 = [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
	testPrivateKey1  = ed25519.NewKeyFromSeed(testPrivateSeed1[:])
	hexPublicKey1    = "0x" + fmt.Sprintf("%x", testPrivateKey1.Public())

	// Second test identity.
	testPrivateSeed2 = [32]byte{2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
	testPrivateKey2  = ed25519.NewKeyFromSeed(testPrivateSeed2[:])
	hexPublicKey2    = "0x" + fmt.Sprintf("%x", testPrivateKey2.Public())
)
// TestNewListener checks that Listen succeeds on a valid address and fails on
// an address whose port is out of range.
func TestNewListener(t *testing.T) {
	// Test case 1: Create a new listener with a valid address.
	addr := "localhost:8080"
	listener, err := Listen(addr, func(s *Stream) {}, WithSeed(testPrivateSeed1))
	// Check the result before deferring Close: on failure there is no usable
	// listener to close.
	if err != nil {
		t.Fatalf("Expected no error, but got: %v", err)
	}
	if listener == nil {
		t.Fatalf("Expected a listener, but got nil")
	}
	defer listener.Close()

	// Test case 2: Create a new listener with an invalid address (port > 65535).
	invalidAddr := "localhost:99999"
	invalidListener, err := Listen(invalidAddr, func(s *Stream) {}, WithSeed(testPrivateSeed1))
	if err == nil {
		t.Errorf("Expected an error, but got nil")
		// Release the unexpectedly created listener so it does not outlive the test.
		if invalidListener != nil {
			invalidListener.Close()
		}
	}
}
// TestNewStream checks that Dial returns a multi-stream for a resolvable
// remote address and returns no multi-stream for an out-of-range port.
func TestNewStream(t *testing.T) {
	listener, err := Listen("localhost:9080", func(s *Stream) {}, WithSeed(testPrivateSeed1))
	// Stop on failure: Dial and Close need a usable listener.
	if !assert.NoError(t, err) {
		return
	}
	defer listener.Close()

	// Test case 1: Create a new multi-stream with a valid remote address.
	conn, err := listener.Dial("localhost:9081", hexPublicKey1)
	if conn == nil {
		t.Errorf("Expected a multi-stream, but got nil (err: %v)", err)
	}

	// Test case 2: Create a new multi-stream with an invalid remote address.
	conn, err = listener.Dial("localhost:99999", hexPublicKey1)
	if conn != nil {
		t.Errorf("Expected nil, but got a multi-stream (err: %v)", err)
	}
}
// TestClose checks that a listener with an open multi-stream closes without
// error.
func TestClose(t *testing.T) {
	listener, err := Listen("localhost:9080", func(s *Stream) {}, WithSeed(testPrivateSeed1))
	// Stop on failure: Dial and Close below need a usable listener.
	if !assert.NoError(t, err) {
		return
	}

	// Open a multi-stream so Close has live state to tear down. A failed Dial
	// is reported rather than ignored; otherwise this would silently test
	// closing an idle listener instead.
	conn, dialErr := listener.Dial("localhost:9081", hexPublicKey1)
	if conn == nil {
		t.Errorf("Expected a multi-stream, but got nil (err: %v)", dialErr)
	}

	err, _ = listener.Close()
	if err != nil {
		t.Errorf("Expected no error, but got: %v", err)
	}
}
// ChannelNetworkConn is an in-memory stand-in for a UDP socket used by these
// tests. NewTestChannel pairs two instances: a datagram written on one end's
// out channel is moved by a forwarding goroutine into the other end's in
// channel, where ReadFromUDP picks it up.
//
// cond is bound to &mu, so a ChannelNetworkConn must be used through a
// pointer and never copied.
type ChannelNetworkConn struct {
	in             chan []byte              // datagrams delivered to this end; drained by ReadFromUDP
	out            chan *SndSegment[[]byte] // datagrams written by this end; drained by the forwarding goroutine
	localAddr      net.Addr                 // returned by LocalAddr; also used to build the synthetic peer address
	readDeadline   time.Time                // recorded by SetReadDeadline; not consulted when reading
	messageCounter int                      // datagrams delivered INTO this end's in channel (incremented by the forwarder)
	cond           *sync.Cond               // broadcast on every delivery; WaitRcv waits on it
	mu             sync.Mutex               // guards messageCounter; the Locker behind cond
}
// TestAddr is a minimal net.Addr used by the in-memory test connection. It
// reports exactly the network name and address text it was built with.
type TestAddr struct {
	network, address string
}

// Network returns the stored network name (e.g. "udp").
func (ta TestAddr) Network() string { return ta.network }

// String returns the stored address text.
func (ta TestAddr) String() string { return ta.address }
// ReadFromUDP performs a non-blocking read of the next datagram queued on c.in.
//
// If a datagram is queued, it is copied into p and the number of bytes
// actually copied is returned. A datagram longer than p is truncated, so the
// count never exceeds len(p). If nothing is queued, it returns 0 immediately
// with a nil error; the deadline recorded by SetReadDeadline is not consulted.
// In both cases the reported sender is a synthetic TestAddr derived from this
// end's local address.
func (c *ChannelNetworkConn) ReadFromUDP(p []byte) (int, net.Addr, error) {
	// The channel carries no sender information, so fabricate a peer address
	// from our own.
	remote := TestAddr{
		network: "remote-of-" + c.localAddr.Network(),
		address: "remote-of-" + c.localAddr.String(),
	}
	select {
	case msg := <-c.in:
		// Report what was copied, not len(msg): returning more than len(p)
		// would make callers slice p out of range.
		return copy(p, msg), remote, nil
	default:
		return 0, remote, nil
	}
}
// WriteToUDP queues a copy of p on this end's out channel for delivery to the
// paired end and reports len(p) bytes written with a nil error.
//
// addr is ignored: every write goes to the peer this end was paired with. The
// call blocks while the out channel has no free buffer space, and it panics if
// the out channel has already been closed by Close.
func (c *ChannelNetworkConn) WriteToUDP(p []byte, addr net.Addr) (int, error) {
	// Queue a private copy. A real socket has consumed the payload by the time
	// the write returns, so callers may reuse p immediately. Queuing p itself
	// would let that reuse corrupt the datagram before the peer reads it.
	data := append([]byte(nil), p...)
	c.out <- &SndSegment[[]byte]{data: data}
	return len(p), nil
}
// Close closes this end's out channel, which ends the forwarding goroutine
// that ranges over it. It always returns nil and must be called at most once,
// since a second call would close the channel again.
//
// c.in is deliberately left open. That channel is sent to by the peer's
// forwarding goroutine, and closing it from the receiving side would make that
// goroutine panic with "send on closed channel" if the peer is still writing.
// Reads after Close still return (0, addr, nil), just as they did when a
// closed c.in yielded a nil datagram.
func (c *ChannelNetworkConn) Close() error {
	close(c.out)
	return nil
}
// SetReadDeadline stores deadline on the connection and always returns nil.
// The stored value is not read anywhere in this type; presumably it exists so
// the type can stand in for a socket whose callers set deadlines.
func (c *ChannelNetworkConn) SetReadDeadline(deadline time.Time) error {
	c.readDeadline = deadline
	return nil
}
// LocalAddr reports the address this end was created with.
func (c *ChannelNetworkConn) LocalAddr() net.Addr { return c.localAddr }
// NewTestChannel creates two connected ChannelNetworkConn ends with the given
// local addresses. Each end owns a single-slot in channel and a single-slot
// out channel. One forwarding goroutine per direction moves datagrams from an
// end's out channel into the other end's in channel. Each goroutine runs until
// the end it reads from is closed.
func NewTestChannel(localAddr1, localAddr2 net.Addr) (*ChannelNetworkConn, *ChannelNetworkConn) {
	// newEnd builds one end with its own buffered channels and a condition
	// variable bound to its mutex.
	newEnd := func(addr net.Addr) *ChannelNetworkConn {
		end := &ChannelNetworkConn{
			localAddr: addr,
			in:        make(chan []byte, 1),
			out:       make(chan *SndSegment[[]byte], 1),
		}
		end.cond = sync.NewCond(&end.mu)
		return end
	}

	first, second := newEnd(localAddr1), newEnd(localAddr2)

	// Wire both directions: first.out -> second.in and second.out -> first.in.
	go forwardMessages(first, second)
	go forwardMessages(second, first)

	return first, second
}
// forwardMessages moves every segment written to sender.out into receiver.in
// until sender.out is closed.
//
// Delivery never blocks. If receiver.in has no free buffer space, the datagram
// is dropped and not counted. Each successful delivery increments
// receiver.messageCounter under receiver.mu and wakes all WaitRcv callers.
func forwardMessages(sender, receiver *ChannelNetworkConn) {
	// tryDeliver attempts a non-blocking send and reports whether it succeeded.
	tryDeliver := func(payload []byte) bool {
		select {
		case receiver.in <- payload:
			return true
		default:
			return false
		}
	}

	for seg := range sender.out {
		if !tryDeliver(seg.data) {
			// Receiver buffer full: drop the datagram (presumably mimicking
			// UDP loss). Consider logging here if drops need diagnosing.
			continue
		}
		receiver.mu.Lock()
		receiver.messageCounter++
		receiver.cond.Broadcast()
		receiver.mu.Unlock()
	}
}
// WaitRcv blocks until at least nr datagrams have been delivered into this
// end's in channel, as counted by forwardMessages. There is no timeout.
// Datagrams that forwardMessages drops because c.in was full are never
// counted, so waiting for them blocks forever.
func (c *ChannelNetworkConn) WaitRcv(nr int) {
	c.mu.Lock()
	// Re-check after every wake-up: Broadcast fires on each delivery, not only
	// when the target count is reached.
	for nr > c.messageCounter {
		c.cond.Wait()
	}
	c.mu.Unlock()
}