Skip to content

Commit

Permalink
feat(x): use endpoint instea of dialer for websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Jan 30, 2025
1 parent 80b6430 commit c597671
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
14 changes: 7 additions & 7 deletions x/websocket/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ import (
// Websockets, with each Write becoming a separate message. Half-close is supported:
// CloseRead will not close the Websocket connection, while CloseWrite sends a Websocket
// close but continues reading until a close is received from the server.
func NewStreamEndpoint(urlStr string, sd transport.StreamDialer, opts ...Option) (func(context.Context) (transport.StreamConn, error), error) {
return newEndpoint(urlStr, sd, func(gc *gorillaConn) transport.StreamConn { return gc }, opts...)
func NewStreamEndpoint(urlStr string, se transport.StreamEndpoint, opts ...Option) (func(context.Context) (transport.StreamConn, error), error) {
return newEndpoint(urlStr, se, func(gc *gorillaConn) transport.StreamConn { return gc }, opts...)
}

// NewPacketEndpoint creates a new Websocket Packet Endpoint. Each packet is exchanged as a Websocket message.
func NewPacketEndpoint(urlStr string, sd transport.StreamDialer, opts ...Option) (func(context.Context) (net.Conn, error), error) {
return newEndpoint(urlStr, sd, func(gc *gorillaConn) net.Conn { return gc }, opts...)
func NewPacketEndpoint(urlStr string, se transport.StreamEndpoint, opts ...Option) (func(context.Context) (net.Conn, error), error) {
return newEndpoint(urlStr, se, func(gc *gorillaConn) net.Conn { return gc }, opts...)
}

type options struct {
Expand All @@ -68,7 +68,7 @@ func WithHTTPHeaders(headers http.Header) Option {
}
}

func newEndpoint[ConnType net.Conn](urlStr string, sd transport.StreamDialer, wsToConn func(*gorillaConn) ConnType, opts ...Option) (func(context.Context) (ConnType, error), error) {
func newEndpoint[ConnType net.Conn](urlStr string, se transport.StreamEndpoint, wsToConn func(*gorillaConn) ConnType, opts ...Option) (func(context.Context) (ConnType, error), error) {
_, err := url.Parse(urlStr)
if err != nil {
return nil, fmt.Errorf("url is invalid: %w", err)
Expand All @@ -84,11 +84,11 @@ func newEndpoint[ConnType net.Conn](urlStr string, sd transport.StreamDialer, ws

wsDialer := &websocket.Dialer{
TLSClientConfig: resolvedOpts.tlsConfig,
NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
NetDialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
return nil, fmt.Errorf("websocket dialer does not support network type %v", network)
}
return sd.DialStream(ctx, addr)
return se.ConnectStream(ctx)
},
}
return func(ctx context.Context) (ConnType, error) {
Expand Down
6 changes: 4 additions & 2 deletions x/websocket/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func Test_NewStreamEndpoint(t *testing.T) {
// },
// }
client := ts.Client()
connect, err := NewStreamEndpoint("wss"+ts.URL[5:]+"/tcp", &transport.TCPDialer{}, WithTLSConfig(client.Transport.(*http.Transport).TLSClientConfig))
endpoint := &transport.TCPEndpoint{Address: ts.Listener.Addr().String()}
connect, err := NewStreamEndpoint("wss"+ts.URL[5:]+"/tcp", endpoint, WithTLSConfig(client.Transport.(*http.Transport).TLSClientConfig))
require.NoError(t, err)
require.NotNil(t, connect)

Expand Down Expand Up @@ -142,7 +143,8 @@ func Test_NewPacketEndpoint(t *testing.T) {
defer ts.Close()

client := ts.Client()
connect, err := NewPacketEndpoint("wss"+ts.URL[5:]+"/udp", &transport.TCPDialer{}, WithTLSConfig(client.Transport.(*http.Transport).TLSClientConfig))
endpoint := &transport.TCPEndpoint{Address: ts.Listener.Addr().String()}
connect, err := NewPacketEndpoint("wss"+ts.URL[5:]+"/udp", endpoint, WithTLSConfig(client.Transport.(*http.Transport).TLSClientConfig))
require.NoError(t, err)
require.NotNil(t, connect)

Expand Down

0 comments on commit c597671

Please sign in to comment.