Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Go client authentication support #3584

Merged
merged 11 commits into from
Mar 24, 2023
15 changes: 13 additions & 2 deletions go/internal/test_tools/test_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func CheckError(t *testing.T, msg string, err error) {
}

// GetHost returns the host to connect to for the tests.
// By default it is localhost, but can be overriden by setting the DH_HOST environment variable.
// By default it is localhost, but can be overridden by setting the DH_HOST environment variable.
func GetHost() string {
host := os.Getenv("DH_HOST")
if host == "" {
Expand All @@ -79,7 +79,7 @@ func GetHost() string {
}

// GetPort returns the port to connect to for the tests.
// By default it is 10000, but can be overriden by setting the DH_PORT environment variable.
// By default it is 10000, but can be overridden by setting the DH_PORT environment variable.
func GetPort() string {
port := os.Getenv("DH_PORT")
if port == "" {
Expand All @@ -88,3 +88,14 @@ func GetPort() string {
return port
}
}

// GetAuth returns the auth string to connect to for the tests.
// By default it is Anonymous, but can be overridden by setting the DH_AUTH environment variable.
func GetAuth() string {
auth := os.Getenv("DH_AUTH")
if auth == "" {
return "Anonymous"
chipkent marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be Constants.DefaultAuth (or whatever) rather than the literal string "Anonymous"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be using the constant. This is set to match the DH instance that is spun up to test against. We have no guarantee that the instance will be using the default connection mechanism, so I'm not completely comfortable linking them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My opinion is weak on this, so if you think it should be changed, I will.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not deeply invested in either choice. I assumed it was an oversight. If not, that's fine.

} else {
return auth
}
}
94 changes: 59 additions & 35 deletions go/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"sync"

apppb2 "github.com/deephaven/deephaven-core/go/internal/proto/application"
configpb2 "github.com/deephaven/deephaven-core/go/internal/proto/config"
consolepb2 "github.com/deephaven/deephaven-core/go/internal/proto/console"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)

// ErrClosedClient is returned as an error when trying to perform a network operation on a client that has been closed.
Expand All @@ -43,54 +43,80 @@ type Client struct {

suppressTableLeakWarning bool // When true, this disables the TableHandle finalizer warning.

sessionStub
consoleStub
flightStub
tableStub
inputTableStub
*sessionStub
*consoleStub
*flightStub
*tableStub
*inputTableStub

appServiceClient apppb2.ApplicationServiceClient
ticketFact ticketFactory
ticketFact *ticketFactory
tokenMgr *tokenManager
}

// NewClient starts a connection to a Deephaven server.
//
// The client should be closed using Close() after it is done being used.
//
// Keepalive messages are sent automatically by the client to the server at a regular interval (~30 seconds)
// Keepalive messages are sent automatically by the client to the server at a regular interval
// so that the connection remains open. The provided context is saved and used to send keepalive messages.
//
// host, port, and auth are used to connect to the Deephaven server. host and port are the Deephaven server host and port.
// auth is the authorization string used to get the first token. Examples:
// - "Anonymous" is used for anonymous authentication.
// - "io.deephaven.authentication.psk.PskAuthenticationHandler <password>" is used for PSK authentication
//
// If auth is set to an empty string, DefaultAuth authentication is used.
// To see what authentication methods are available on the Deephaven server, navigate to: http://<host>:<port>/jsapi/authentication/.
//
// The option arguments can be used to specify other settings for the client.
// See the With<XYZ> methods (e.g. WithConsole) for details on what options are available.
func NewClient(ctx context.Context, host string, port string, options ...ClientOption) (*Client, error) {
func NewClient(ctx context.Context, host string, port string, auth string, options ...ClientOption) (client *Client, err error) {
defer func() {
if err != nil && client != nil {
e := client.Close()
if e != nil {
log.Println("Error when closing failed client: ", e)
}
}
}()

grpcChannel, err := grpc.Dial(host+":"+port, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}

opts := newClientOptions(options...)

client := &Client{grpcChannel: grpcChannel, isOpen: true}
client = &Client{grpcChannel: grpcChannel, isOpen: true}

client.suppressTableLeakWarning = opts.suppressTableLeakWarning

client.ticketFact = newTicketFactory()

client.sessionStub, err = newSessionStub(ctx, client)
client.flightStub, err = newFlightStub(client, host, port)
if err != nil {
client.Close()
return nil, err
}

client.consoleStub, err = newConsoleStub(ctx, client, opts.scriptLanguage)
cfgClient := configpb2.NewConfigServiceClient(grpcChannel)

if auth == "" {
auth = DefaultAuth
}

client.tokenMgr, err = newTokenManager(ctx, client.flightStub, cfgClient, auth)
if err != nil {
client.Close()
return nil, err
}

client.flightStub, err = newFlightStub(client, host, port)
client.sessionStub, err = newSessionStub(client)
if err != nil {
return nil, err
}

client.consoleStub, err = newConsoleStub(ctx, client, opts.scriptLanguage)
if err != nil {
client.Close()
return nil, err
}

Expand Down Expand Up @@ -170,32 +196,30 @@ func (client *Client) Close() error {

client.isOpen = false

client.sessionStub.Close()

if client.grpcChannel != nil {
client.grpcChannel.Close()
client.grpcChannel = nil
if client.tokenMgr != nil {
err := client.tokenMgr.Close()
if err != nil {
log.Println("unable to close client:", err.Error())
return err
}
}

// This is logged because most of the time this method is used with defer,
// which will discard the error value.
err := client.flightStub.Close()
if err != nil {
log.Println("unable to close client:", err.Error())
if client.flightStub != nil {
err := client.flightStub.Close()
if err != nil {
log.Println("unable to close client:", err.Error())
return err
}
}

return err
}

// withToken attaches the current session token to a context as metadata.
func (client *Client) withToken(ctx context.Context) (context.Context, error) {
tok, err := client.getToken()

if err != nil {
return nil, err
if client.grpcChannel != nil {
client.grpcChannel.Close()
client.grpcChannel = nil
}

return metadata.NewOutgoingContext(ctx, metadata.Pairs("authorization", string(tok))), nil
return nil
}

// RunScript executes a script on the deephaven server.
Expand All @@ -207,7 +231,7 @@ func (client *Client) RunScript(ctx context.Context, script string) error {
return ErrNoConsole
}

ctx, err := client.consoleStub.client.withToken(ctx)
ctx, err := client.consoleStub.client.tokenMgr.withToken(ctx)
if err != nil {
return err
}
Expand Down
25 changes: 17 additions & 8 deletions go/pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@ import (
func TestConnectError(t *testing.T) {
ctx := context.Background()

_, err := client.NewClient(ctx, "foobar", "1234")
_, err := client.NewClient(ctx, "foobar", "1234", test_tools.GetAuth())
if err == nil {
t.Fatalf("client did not fail to connect")
}
}

func TestAuthError(t *testing.T) {
ctx := context.Background()

_, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), "garbage in")
if err == nil {
t.Fatalf("client did not fail to connect")
}
Expand All @@ -21,7 +30,7 @@ func TestConnectError(t *testing.T) {
func TestClosedClient(t *testing.T) {
ctx := context.Background()

c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort())
c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth())
if err != nil {
t.Fatalf("NewClient err %s", err.Error())
}
Expand All @@ -40,7 +49,7 @@ func TestClosedClient(t *testing.T) {
func TestMismatchedScript(t *testing.T) {
ctx := context.Background()

_, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), client.WithConsole("groovy"))
_, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth(), client.WithConsole("groovy"))
if err == nil {
t.Fatalf("client did not fail to connect")
}
Expand All @@ -52,7 +61,7 @@ func TestEmptyTable(t *testing.T) {

ctx := context.Background()

c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort())
c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth())
if err != nil {
t.Fatalf("NewClient err %s", err.Error())
}
Expand Down Expand Up @@ -83,7 +92,7 @@ func TestEmptyTable(t *testing.T) {
func TestTimeTable(t *testing.T) {
ctx := context.Background()

c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort())
c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth())
if err != nil {
t.Fatalf("NewClient err %s", err.Error())
}
Expand All @@ -110,7 +119,7 @@ func TestTableUpload(t *testing.T) {
defer r.Release()

ctx := context.Background()
s, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort())
s, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth())
if err != nil {
t.Fatalf("NewClient err %s", err.Error())
return
Expand Down Expand Up @@ -196,7 +205,7 @@ func waitForTable(ctx context.Context, cl *client.Client, names []string, timeou
func TestFieldSync(t *testing.T) {
ctx := context.Background()

client1, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), client.WithConsole("python"))
client1, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth(), client.WithConsole("python"))
test_tools.CheckError(t, "NewClient", err)
defer client1.Close()

Expand All @@ -206,7 +215,7 @@ gotesttable1 = None
`)
test_tools.CheckError(t, "RunScript", err)

client2, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), client.WithConsole("python"))
client2, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth(), client.WithConsole("python"))
test_tools.CheckError(t, "NewClient", err)
defer client2.Close()

Expand Down
12 changes: 6 additions & 6 deletions go/pkg/client/console_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ type consoleStub struct {
//
// If sessionType is non-empty, it will start a console for use with scripts.
// The sessionType determines what language the scripts will use. It can be either "python" or "groovy" and must match the server language.
func newConsoleStub(ctx context.Context, client *Client, sessionType string) (consoleStub, error) {
ctx, err := client.withToken(ctx)
func newConsoleStub(ctx context.Context, client *Client, sessionType string) (*consoleStub, error) {
ctx, err := client.tokenMgr.withToken(ctx)
if err != nil {
return consoleStub{}, err
return nil, err
}

stub := consolepb2.NewConsoleServiceClient(client.grpcChannel)
Expand All @@ -41,20 +41,20 @@ func newConsoleStub(ctx context.Context, client *Client, sessionType string) (co
req := consolepb2.StartConsoleRequest{ResultId: &reqTicket, SessionType: sessionType}
resp, err := stub.StartConsole(ctx, &req)
if err != nil {
return consoleStub{}, err
return nil, err
}

consoleId = resp.ResultId
}

return consoleStub{client: client, stub: stub, consoleId: consoleId}, nil
return &consoleStub{client: client, stub: stub, consoleId: consoleId}, nil
}

// BindToVariable binds a table reference to a given name on the server so that it can be referenced by other clients or the web UI.
//
// If WithConsole was not passed when creating the client, this will return ErrNoConsole.
func (console *consoleStub) BindToVariable(ctx context.Context, name string, table *TableHandle) error {
ctx, err := console.client.withToken(ctx)
ctx, err := console.client.tokenMgr.withToken(ctx)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions go/pkg/client/console_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
func TestOpenTable(t *testing.T) {
ctx := context.Background()

c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), client.WithConsole("python"))
c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth(), client.WithConsole("python"))
test_tools.CheckError(t, "NewClient", err)

err = c.RunScript(ctx,
Expand All @@ -38,7 +38,7 @@ gotesttable = empty_table(42)
func TestNoConsole(t *testing.T) {
ctx := context.Background()

c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort())
c, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth())
test_tools.CheckError(t, "NewClient", err)
defer c.Close()

Expand Down
7 changes: 7 additions & 0 deletions go/pkg/client/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package client

// DefaultAuth is the default authentication method.
const DefaultAuth = "Anonymous"

// TokenTimeoutConfigConstant is the configuration constant specifying the token timeout interval.
const TokenTimeoutConfigConstant = "http.session.durationMs"
4 changes: 2 additions & 2 deletions go/pkg/client/example_fetch_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func Example_fetchTable() {

// Let's start a client connection using python as the script language ("groovy" is the other option).
// Note that the client language must match the language the server was started with.
cl, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), client.WithConsole("python"))
cl, err := client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth(), client.WithConsole("python"))
if err != nil {
fmt.Println("error when connecting to server:", err.Error())
return
Expand All @@ -44,7 +44,7 @@ func Example_fetchTable() {
cl.Close()

// Now let's make a new connection, completely unrelated to the old one.
cl, err = client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort())
cl, err = client.NewClient(ctx, test_tools.GetHost(), test_tools.GetPort(), test_tools.GetAuth())
if err != nil {
fmt.Println("error when connecting to localhost port 10000:", err.Error())
return
Expand Down
Loading