Skip to content

Commit

Permalink
update handling for ports in proxy commands (#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaWilkes authored Nov 4, 2024
1 parent a64c6b7 commit cebcab0
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 29 deletions.
4 changes: 2 additions & 2 deletions pkg/granted/eks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func OpenKubeConfig() (*api.Config, string, error) {
return config, kubeConfigPath, nil
}

func AddContextToConfig(ensureAccessOutput *proxy.EnsureAccessOutput[*accessv1alpha1.AWSEKSProxyOutput], port string) error {
func AddContextToConfig(ensureAccessOutput *proxy.EnsureAccessOutput[*accessv1alpha1.AWSEKSProxyOutput], port int) error {

kc, kubeConfigPath, err := OpenKubeConfig()
if err != nil {
Expand All @@ -59,7 +59,7 @@ func AddContextToConfig(ensureAccessOutput *proxy.EnsureAccessOutput[*accessv1al
delete(kc.AuthInfos, username)

newCluster := api.NewCluster()
newCluster.Server = fmt.Sprintf("http://localhost:%s", port)
newCluster.Server = fmt.Sprintf("http://localhost:%d", port)
newCluster.InsecureSkipTLSVerify = true
//add the new cluster and context back in
kc.Clusters[clusterName] = newCluster
Expand Down
6 changes: 3 additions & 3 deletions pkg/granted/proxy/initiateconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
type InitiateSessionConnectionInput struct {
GrantID string
RequestURL string
LocalPort string
LocalPort int
}

// InitiateSessionConnection starts a new tcp connection to through the SSM port forward and completes a handshake with the proxy server
Expand All @@ -24,8 +24,8 @@ func InitiateSessionConnection(cfg *config.Context, input InitiateSessionConnect
// First dial the local SSM portforward, which will be running on a randomly chosen port
// or the local proxy server instance if it's local dev mode
// this establishes the initial connection to the Proxy server
clio.Debugw("dialing proxy server", "host", "localhost:"+input.LocalPort)
rawServerConn, err := net.Dial("tcp", "localhost:"+input.LocalPort)
clio.Debugw("dialing proxy server", "host", fmt.Sprintf("localhost:%d", input.LocalPort))
rawServerConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", input.LocalPort))
if err != nil {
return nil, nil, clierr.New("failed to establish a connection to the remote proxy server", clierr.Error(err), clierr.Infof("Your grant may have expired, you can check the status here: %s and retry connecting", input.RequestURL))
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/granted/proxy/listenandproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (

// ListenAndProxy will listen for new client connections and start a stream over the established proxy server session.
// if the proxy server terminates the session, like when a grant expires, this listener will detect it and terminate the CLI commmand with an error explaining what happened
func ListenAndProxy(ctx context.Context, yamuxStreamConnection *yamux.Session, clientConnectionPort string, requestURL string) error {
ln, err := net.Listen("tcp", "localhost:"+clientConnectionPort)
func ListenAndProxy(ctx context.Context, yamuxStreamConnection *yamux.Session, clientConnectionPort int, requestURL string) error {
ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", clientConnectionPort))
if err != nil {
return fmt.Errorf("failed to start listening for connections on port: %s. %w", clientConnectionPort, err)
return fmt.Errorf("failed to start listening for connections on port: %d. %w", clientConnectionPort, err)
}
defer ln.Close()

Expand Down
17 changes: 8 additions & 9 deletions pkg/granted/proxy/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,36 @@ package proxy

import (
"net"
"strconv"
)

// Returns the proxy port to connect to and a local port to send client connections to
// in production, an SSM portforward process is running which is used to connect to the proxy server
// and over the top of this connection, a handshake process takes place and connection multiplexing is used to handle multiple database clients
func Ports(isLocalMode bool) (serverPort, localPort string, err error) {
func Ports(isLocalMode bool) (serverPort, localPort int, err error) {
// in local mode the SSM port forward is not used can skip using ssm and just use a local port forward instead
if isLocalMode {
return "7070", "7070", nil
return 7070, 7070, nil
}
// find an unused local port to use for the ssm server
// the user doesn't directly connect to this, they connect through our local proxy
// which adds authentication
ssmPortforwardLocalPort, err := GrabUnusedPort()
if err != nil {
return "", "", err
return 0, 0, err
}
return "8080", ssmPortforwardLocalPort, nil
return 8080, ssmPortforwardLocalPort, nil
}

func GrabUnusedPort() (string, error) {
func GrabUnusedPort() (int, error) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
return "", err
return 0, err
}

port := listener.Addr().(*net.TCPAddr).Port
err = listener.Close()
if err != nil {
return "", err
return 0, err
}
return strconv.Itoa(port), nil
return port, nil
}
11 changes: 6 additions & 5 deletions pkg/granted/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"strconv"
"time"

awsConfig "github.com/aws/aws-sdk-go-v2/config"
Expand Down Expand Up @@ -39,8 +40,8 @@ type AWSConfig struct {
NoCache bool
}
type ConnectionOpts struct {
ServerPort string
LocalPort string
ServerPort int
LocalPort int
}
type WaitForSSMConnectionToProxyServerOpts struct {
AWSConfig AWSConfig
Expand Down Expand Up @@ -89,8 +90,8 @@ func WaitForSSMConnectionToProxyServer(ctx context.Context, opts WaitForSSMConne
Target: &opts.AWSConfig.SSMSessionTarget,
DocumentName: &documentName,
Parameters: map[string][]string{
"portNumber": {opts.ConnectionOpts.ServerPort},
"localPortNumber": {opts.ConnectionOpts.LocalPort},
"portNumber": {strconv.Itoa(opts.ConnectionOpts.ServerPort)},
"localPortNumber": {strconv.Itoa(opts.ConnectionOpts.LocalPort)},
},
Reason: grab.Ptr(fmt.Sprintf("Session started for Granted %s connection with Common Fate. GrantID: %s, AccessRequestID: %s", opts.DisplayOpts.SessionType, opts.GrantID, opts.RequestID)),
}
Expand All @@ -109,7 +110,7 @@ func WaitForSSMConnectionToProxyServer(ctx context.Context, opts WaitForSSMConne
SessionId: *sessionOutput.SessionId,
TokenValue: *sessionOutput.TokenValue,
IsAwsCliUpgradeNeeded: false,
Endpoint: "localhost:" + opts.ConnectionOpts.LocalPort,
Endpoint: fmt.Sprintf("localhost:%d", opts.ConnectionOpts.LocalPort),
DataChannel: &datachannel.DataChannel{},
ClientId: clientId,
}
Expand Down
16 changes: 9 additions & 7 deletions pkg/granted/rds/rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ var proxyCommand = cli.Command{
return err
}

printConnectionParameters(connectionString, cliString, clientConnectionPort, ensuredAccess.GrantOutput.RdsDatabase.Engine)
printConnectionParameters(connectionString, cliString, ensuredAccess.GrantOutput.RdsDatabase.Engine, clientConnectionPort)

return proxy.ListenAndProxy(ctx, yamuxStreamConnection, clientConnectionPort, requestURL)
},
Expand Down Expand Up @@ -210,36 +210,38 @@ func promptForDatabaseAndUser(ctx context.Context, cfg *config.Context) (*access
return selectorVal.(*accessv1alpha1.Entitlement), nil
}

func clientConnectionParameters(c *cli.Context, ensuredAccess *proxy.EnsureAccessOutput[*accessv1alpha1.AWSRDSOutput]) (connectionString, cliString, port string, err error) {
func clientConnectionParameters(c *cli.Context, ensuredAccess *proxy.EnsureAccessOutput[*accessv1alpha1.AWSRDSOutput]) (connectionString, cliString string, port int, err error) {
// Print the connection information to the user based on the database they are connecting to
// the passwords are always 'password' while the username and database will match that of the target being connected to
yellow := color.New(color.FgYellow)
switch ensuredAccess.GrantOutput.RdsDatabase.Engine {
case "postgres", "aurora-postgresql":
port := getLocalPort(getLocalPortInput{
port = getLocalPort(getLocalPortInput{
OverrideFlag: c.Int("port"),
DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort),
Fallback: 5432,
})

connectionString = yellow.Sprintf("postgresql://%s:password@127.0.0.1:%d/%s?sslmode=disable", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
cliString = yellow.Sprintf(`psql "postgresql://%s:password@127.0.0.1:%d/%s?sslmode=disable"`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
case "mysql", "aurora-mysql":
port := getLocalPort(getLocalPortInput{
port = getLocalPort(getLocalPortInput{
OverrideFlag: c.Int("port"),
DefaultFromServer: int(ensuredAccess.GrantOutput.DefaultLocalPort),
Fallback: 3306,
})

connectionString = yellow.Sprintf("%s:password@tcp(127.0.0.1:%d)/%s", ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
cliString = yellow.Sprintf(`mysql -u %s -p'password' -h 127.0.0.1 -P %d %s`, ensuredAccess.GrantOutput.User.Username, port, ensuredAccess.GrantOutput.RdsDatabase.Database)
default:
return "", "", "", fmt.Errorf("unsupported database engine: %s, maybe you need to update your `cf` cli", ensuredAccess.GrantOutput.RdsDatabase.Engine)
return "", "", 0, fmt.Errorf("unsupported database engine: %s, maybe you need to update your `cf` cli", ensuredAccess.GrantOutput.RdsDatabase.Engine)
}
return
}

func printConnectionParameters(connectionString, cliString, port, engine string) {
func printConnectionParameters(connectionString, cliString, engine string, port int) {
clio.NewLine()
clio.Infof("Database proxy ready for connections on 127.0.0.1:%s", port)
clio.Infof("Database proxy ready for connections on 127.0.0.1:%d", port)
clio.NewLine()

clio.Infof("You can connect now using this connection string: %s", connectionString)
Expand Down

0 comments on commit cebcab0

Please sign in to comment.