Skip to content

Commit

Permalink
authclient: add support for service accounts (#164)
Browse files Browse the repository at this point in the history
* authclient: add support for service accounts

* fix test name
  • Loading branch information
calebdoxsey authored Nov 16, 2022
1 parent e3b967e commit bd8b1b9
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 60 deletions.
2 changes: 1 addition & 1 deletion api/listener_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s *server) connectTunnelLocked(id string) (net.Addr, error) {
return nil, errNotFound
}

tun, listenAddr, err := newTunnel(rec.GetConn(), s.browserCmd)
tun, listenAddr, err := newTunnel(rec.GetConn(), s.browserCmd, s.serviceAccount, s.serviceAccountFile)
if err != nil {
return nil, err
}
Expand Down
24 changes: 20 additions & 4 deletions api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"sync"

"github.com/golang/groupcache/lru"

pb "github.com/pomerium/cli/proto"
"github.com/pomerium/cli/tcptunnel"
)
Expand All @@ -20,8 +21,7 @@ type ConfigProvider interface {
Save([]byte) error
}

type Config interface {
}
type Config interface{}

// ListenerStatus marks individual records as locked
type ListenerStatus interface {
Expand Down Expand Up @@ -52,8 +52,10 @@ type server struct {
EventBroadcaster
ListenerStatus
*config
browserCmd string
certInfo *lru.Cache
browserCmd string
serviceAccount string
serviceAccountFile string
certInfo *lru.Cache
}

var (
Expand Down Expand Up @@ -113,6 +115,20 @@ func WithBrowserCommand(cmd string) ServerOption {
}
}

func WithServiceAccount(serviceAccount string) ServerOption {
return func(s *server) error {
s.serviceAccount = serviceAccount
return nil
}
}

func WithServiceAccountFile(serviceAccountFile string) ServerOption {
return func(s *server) error {
s.serviceAccountFile = serviceAccountFile
return nil
}
}

// MemCP is in-memory config provider
type MemCP struct {
data []byte
Expand Down
4 changes: 3 additions & 1 deletion api/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"github.com/pomerium/cli/tcptunnel"
)

func newTunnel(conn *pb.Connection, browserCmd string) (Tunnel, string, error) {
func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFile string) (Tunnel, string, error) {
listenAddr := "127.0.0.1:0"
if conn.ListenAddr != nil {
listenAddr = *conn.ListenAddr
Expand All @@ -42,6 +42,8 @@ func newTunnel(conn *pb.Connection, browserCmd string) (Tunnel, string, error) {
return tcptunnel.New(
tcptunnel.WithDestinationHost(conn.GetRemoteAddr()),
tcptunnel.WithProxyHost(pxy.Host),
tcptunnel.WithServiceAccount(serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountFile),
tcptunnel.WithTLSConfig(tlsCfg),
tcptunnel.WithBrowserCommand(browserCmd),
), listenAddr, nil
Expand Down
13 changes: 13 additions & 0 deletions authclient/authclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"os"
"strings"
"time"

"golang.org/x/sync/errgroup"
Expand All @@ -28,6 +29,18 @@ func New(options ...Option) *AuthClient {

// GetJWT retrieves a JWT from Pomerium.
func (client *AuthClient) GetJWT(ctx context.Context, serverURL *url.URL, onOpenBrowser func(string)) (rawJWT string, err error) {
if client.cfg.serviceAccount != "" {
return client.cfg.serviceAccount, nil
}

if client.cfg.serviceAccountFile != "" {
rawJWTBytes, err := os.ReadFile(client.cfg.serviceAccountFile)
if err != nil {
return "", err
}
return strings.TrimSpace(string(rawJWTBytes)), nil
}

li, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", fmt.Errorf("failed to start listener: %w", err)
Expand Down
128 changes: 83 additions & 45 deletions authclient/authclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,104 @@ import (
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"testing"
"time"

"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAuthClient(t *testing.T) {
t.Parallel()

ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
t.Cleanup(clearTimeout)

li, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = li.Close() }()
t.Run("browser", func(t *testing.T) {
t.Parallel()

go func() {
h := chi.NewMux()
h.Get("/.pomerium/api/v1/login", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(r.FormValue("pomerium_redirect_uri")))
})
srv := &http.Server{
BaseContext: func(li net.Listener) context.Context {
return ctx
},
Handler: h,
li, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
_ = srv.Serve(li)
}()
t.Cleanup(func() { li.Close() })

ac := New()
ac.cfg.open = func(input string) error {
u, err := url.Parse(input)
if err != nil {
return err
}
u = u.ResolveReference(&url.URL{
RawQuery: url.Values{
"pomerium_jwt": {"TEST"},
}.Encode(),
})
go func() {
h := chi.NewMux()
h.Get("/.pomerium/api/v1/login", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(r.FormValue("pomerium_redirect_uri")))
})
srv := &http.Server{
BaseContext: func(li net.Listener) context.Context {
return ctx
},
Handler: h,
}
_ = srv.Serve(li)
}()

req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return err
}
ac := New()
ac.cfg.open = func(input string) error {
u, err := url.Parse(input)
if err != nil {
return err
}
u = u.ResolveReference(&url.URL{
RawQuery: url.Values{
"pomerium_jwt": {"TEST"},
}.Encode(),
})

req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return err
}

res, err := http.DefaultClient.Do(req)
if err != nil {
return err
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
_ = res.Body.Close()
return nil
}
_ = res.Body.Close()
return nil
}

rawJWT, err := ac.GetJWT(ctx, &url.URL{
Scheme: "http",
Host: li.Addr().String(),
}, func(_ string) {})
assert.NoError(t, err)
assert.Equal(t, "TEST", rawJWT)
rawJWT, err := ac.GetJWT(ctx, &url.URL{
Scheme: "http",
Host: li.Addr().String(),
}, func(_ string) {})
assert.NoError(t, err)
assert.Equal(t, "TEST", rawJWT)
})

t.Run("service account", func(t *testing.T) {
t.Parallel()

ac := New(WithServiceAccount("SERVICE_ACCOUNT"))
rawJWT, err := ac.GetJWT(ctx, &url.URL{
Scheme: "http",
Host: "example.com",
}, func(_ string) {})
assert.NoError(t, err)
assert.Equal(t, "SERVICE_ACCOUNT", rawJWT)
})

t.Run("service account file", func(t *testing.T) {
t.Parallel()

dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, "service-account"), []byte(`
SERVICE_ACCOUNT
`), 0o600)
require.NoError(t, err)
ac := New(WithServiceAccountFile(filepath.Join(dir, "service-account")))
rawJWT, err := ac.GetJWT(ctx, &url.URL{
Scheme: "http",
Host: "example.com",
}, func(_ string) {})
assert.NoError(t, err)
assert.Equal(t, "SERVICE_ACCOUNT", rawJWT)
})
}
20 changes: 18 additions & 2 deletions authclient/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
)

type config struct {
open func(rawURL string) error
tlsConfig *tls.Config
open func(rawURL string) error
serviceAccount string
serviceAccountFile string
tlsConfig *tls.Config
}

func getConfig(options ...Option) *config {
Expand Down Expand Up @@ -36,6 +38,20 @@ func WithBrowserCommand(browserCommand string) Option {
}
}

// WithServiceAccount sets the service account in the config.
func WithServiceAccount(serviceAccount string) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}

// WithServiceAccountFile sets the service account file in the config.
func WithServiceAccountFile(file string) Option {
return func(cfg *config) {
cfg.serviceAccountFile = file
}
}

// WithTLSConfig returns an option to configure the tls config.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
Expand Down
5 changes: 4 additions & 1 deletion cmd/pomerium-cli/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func apiCommand() *cobra.Command {
if err == nil {
cfgDir = path.Join(cfgDir, "PomeriumDesktop", "config.json")
}
addServiceAccountFlags(&cmd.Command)
flags := cmd.Flags()
flags.StringVar(&cmd.jsonRPCAddr, "json-addr", "127.0.0.1:8900", "address json api server should listen to")
flags.StringVar(&cmd.grpcAddr, "grpc-addr", "127.0.0.1:8800", "address json api server should listen to")
Expand All @@ -59,7 +60,7 @@ func (cmd *apiCmd) makeConfigPath() error {
return fmt.Errorf("config file path could not be determined")
}

return os.MkdirAll(path.Dir(cmd.configPath), 0700)
return os.MkdirAll(path.Dir(cmd.configPath), 0o700)
}

func (cmd *apiCmd) exec(c *cobra.Command, args []string) error {
Expand Down Expand Up @@ -87,6 +88,8 @@ func (cmd *apiCmd) exec(c *cobra.Command, args []string) error {
srv, err := api.NewServer(ctx,
api.WithConfigProvider(api.FileConfigProvider(cmd.configPath)),
api.WithBrowserCommand(cmd.browserCmd),
api.WithServiceAccount(serviceAccountOptions.serviceAccount),
api.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
)
if err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions cmd/pomerium-cli/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

func init() {
addBrowserFlags(kubernetesExecCredentialCmd)
addServiceAccountFlags(kubernetesExecCredentialCmd)
addTLSFlags(kubernetesExecCredentialCmd)
kubernetesCmd.AddCommand(kubernetesExecCredentialCmd)
kubernetesCmd.AddCommand(kubernetesFlushCredentialsCmd)
Expand Down Expand Up @@ -69,6 +70,8 @@ var kubernetesExecCredentialCmd = &cobra.Command{

ac := authclient.New(
authclient.WithBrowserCommand(browserOptions.command),
authclient.WithServiceAccount(serviceAccountOptions.serviceAccount),
authclient.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
authclient.WithTLSConfig(tlsConfig))
rawJWT, err := ac.GetJWT(context.Background(), serverURL, func(s string) {})
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions cmd/pomerium-cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,16 @@ func addBrowserFlags(cmd *cobra.Command) {
flags.StringVar(&browserOptions.command, "browser-cmd", "",
"custom browser command to run when opening a URL")
}

var serviceAccountOptions struct {
serviceAccount string
serviceAccountFile string
}

func addServiceAccountFlags(cmd *cobra.Command) {
flags := cmd.Flags()
flags.StringVar(&serviceAccountOptions.serviceAccount, "service-account", "",
"the service account JWT to use for authentication")
flags.StringVar(&serviceAccountOptions.serviceAccountFile, "service-account-file", "",
"a file containing the service account JWT to use for authentication")
}
3 changes: 3 additions & 0 deletions cmd/pomerium-cli/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var proxyCmdOptions struct {
}

func init() {
addServiceAccountFlags(proxyCmd)
addTLSFlags(proxyCmd)
flags := proxyCmd.Flags()
flags.StringVar(&proxyCmdOptions.listen, "listen", "127.0.0.1:3128",
Expand Down Expand Up @@ -140,6 +141,8 @@ func newTCPTunnel(dstHost string, specificPomeriumURL string) (*tcptunnel.Tunnel
return tcptunnel.New(
tcptunnel.WithDestinationHost(net.JoinHostPort(dstHostname, dstPort)),
tcptunnel.WithProxyHost(pomeriumURL.Host),
tcptunnel.WithServiceAccount(serviceAccountOptions.serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
tcptunnel.WithTLSConfig(tlsConfig),
), nil
}
Expand Down
5 changes: 4 additions & 1 deletion cmd/pomerium-cli/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ var tcpCmdOptions struct {
}

func init() {
addTLSFlags(tcpCmd)
addBrowserFlags(tcpCmd)
addServiceAccountFlags(tcpCmd)
addTLSFlags(tcpCmd)
flags := tcpCmd.Flags()
flags.StringVar(&tcpCmdOptions.listen, "listen", "127.0.0.1:0",
"local address to start a listener on")
Expand Down Expand Up @@ -82,6 +83,8 @@ var tcpCmd = &cobra.Command{
tcptunnel.WithBrowserCommand(browserOptions.command),
tcptunnel.WithDestinationHost(dstHost),
tcptunnel.WithProxyHost(pomeriumURL.Host),
tcptunnel.WithServiceAccount(serviceAccountOptions.serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
tcptunnel.WithTLSConfig(tlsConfig),
)

Expand Down
Loading

0 comments on commit bd8b1b9

Please sign in to comment.