Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Account to support token and user/password auth #219

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 11 additions & 119 deletions controllers/jetstream/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/nats-io/jsm.go"
Expand Down Expand Up @@ -49,70 +46,9 @@ func (c *Controller) processConsumerObject(cns *apis.Consumer, jsm jsmClientFunc
spec := cns.Spec
ifc := c.ji.Consumers(ns)

var (
remoteClientCert string
remoteClientKey string
remoteRootCA string
accServers []string
accUserCreds string
)
if spec.Account != "" && c.opts.CRDConnect {
// Lookup the account using the REST client.
ctx, done := context.WithTimeout(context.Background(), 5*time.Second)
defer done()
acc, err := c.ji.Accounts(ns).Get(ctx, spec.Account, k8smeta.GetOptions{})
if err != nil {
return err
}

accServers = acc.Spec.Servers

// Lookup the TLS secrets
if acc.Spec.TLS != nil && acc.Spec.TLS.Secret != nil {
secretName := acc.Spec.TLS.Secret.Name
secret, err := c.ki.Secrets(ns).Get(c.ctx, secretName, k8smeta.GetOptions{})
if err != nil {
return err
}

// Write this to the cacheDir
accDir := filepath.Join(c.cacheDir, ns, spec.Account)
if err := os.MkdirAll(accDir, 0755); err != nil {
return err
}

remoteClientCert = filepath.Join(accDir, acc.Spec.TLS.ClientCert)
remoteClientKey = filepath.Join(accDir, acc.Spec.TLS.ClientKey)
remoteRootCA = filepath.Join(accDir, acc.Spec.TLS.RootCAs)

for k, v := range secret.Data {
if err := os.WriteFile(filepath.Join(accDir, k), v, 0o644); err != nil {
return err
}
}
}
// Lookup the UserCredentials.
if acc.Spec.Creds != nil {
secretName := acc.Spec.Creds.Secret.Name
secret, err := c.ki.Secrets(ns).Get(c.ctx, secretName, k8smeta.GetOptions{})
if err != nil {
return err
}

// Write the user credentials to the cache dir.
accDir := filepath.Join(c.cacheDir, ns, spec.Account)
if err := os.MkdirAll(accDir, 0755); err != nil {
return err
}
for k, v := range secret.Data {
if k == acc.Spec.Creds.File {
accUserCreds = filepath.Join(c.cacheDir, ns, spec.Account, k)
if err := os.WriteFile(filepath.Join(accDir, k), v, 0o644); err != nil {
return err
}
}
}
}
acc, err := c.getAccountOverrides(spec.Account, ns)
if err != nil {
return err
}

defer func() {
Expand All @@ -128,58 +64,14 @@ func (c *Controller) processConsumerObject(cns *apis.Consumer, jsm jsmClientFunc
type operator func(ctx context.Context, c jsmClient, spec apis.ConsumerSpec) (err error)

natsClientUtil := func(op operator) error {
servers := spec.Servers
if c.opts.CRDConnect {
// Create a new client
natsCtx := &natsContext{}
// Use JWT/NKEYS based credentials if present.
if spec.Creds != "" {
natsCtx.Credentials = spec.Creds
} else if spec.Nkey != "" {
natsCtx.Nkey = spec.Nkey
}
if spec.TLS.ClientCert != "" && spec.TLS.ClientKey != "" {
natsCtx.TLSCert = spec.TLS.ClientCert
natsCtx.TLSKey = spec.TLS.ClientKey
}

// Use fetched secrets for the account and server if defined.
if remoteClientCert != "" && remoteClientKey != "" {
natsCtx.TLSCert = remoteClientCert
natsCtx.TLSKey = remoteClientKey
}
if remoteRootCA != "" {
natsCtx.TLSCAs = []string{remoteRootCA}
}
if accUserCreds != "" {
natsCtx.Credentials = accUserCreds
}
if len(spec.TLS.RootCAs) > 0 {
natsCtx.TLSCAs = spec.TLS.RootCAs
}

natsServers := strings.Join(append(servers, accServers...), ",")
natsCtx.URL = natsServers
c.normalEvent(cns, "Connecting", "Connecting to new nats-servers")
jsmc, err := jsm(natsCtx)
if err != nil {
return err
}
defer jsmc.Close()

if err := op(c.ctx, jsmc, spec); err != nil {
return err
}
} else {
jsmc, err := jsm(&natsContext{})
if err != nil {
return err
}
if err := op(c.ctx, jsmc, spec); err != nil {
return err
}
}
return nil
return c.runWithJsmc(jsm, acc, &jsmcSpecOverrides{
servers: spec.Servers,
tls: spec.TLS,
creds: spec.Creds,
nkey: spec.Nkey,
vavsab marked this conversation as resolved.
Show resolved Hide resolved
}, cns, func(jsmc jsmClient) error {
return op(c.ctx, jsmc, spec)
})
}

deleteOK := cns.GetDeletionTimestamp() != nil
Expand Down
200 changes: 200 additions & 0 deletions controllers/jetstream/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"

Expand Down Expand Up @@ -414,6 +415,205 @@ func (c *Controller) warningEvent(o runtime.Object, reason, message string) {
}
}

type accountOverrides struct {
remoteClientCert string
remoteClientKey string
remoteRootCA string
servers []string
userCreds string
user string
password string
token string
}

func (c *Controller) getAccountOverrides(account string, ns string) (*accountOverrides, error) {
overrides := &accountOverrides{}

if account == "" || !c.opts.CRDConnect {
return overrides, nil
}

// Lookup the account using the REST client.
ctx, done := context.WithTimeout(context.Background(), 5*time.Second)
defer done()
acc, err := c.ji.Accounts(ns).Get(ctx, account, k8smeta.GetOptions{})
if err != nil {
return nil, err
}

overrides.servers = acc.Spec.Servers

// Lookup the TLS secrets
if acc.Spec.TLS != nil && acc.Spec.TLS.Secret != nil {
secretName := acc.Spec.TLS.Secret.Name
secret, err := c.ki.Secrets(ns).Get(c.ctx, secretName, k8smeta.GetOptions{})
if err != nil {
return nil, err
}

// Write this to the cacheDir.
accDir := filepath.Join(c.cacheDir, ns, account)
if err := os.MkdirAll(accDir, 0o755); err != nil {
return nil, err
}

filesToWrite := make(map[string]string)

getSecretValue := func(key string) string {
value, ok := secret.Data[key]
if !ok {
return ""
}
return string(value)
}

remoteClientCertValue := getSecretValue(acc.Spec.TLS.ClientCert)
remoteClientKeyValue := getSecretValue(acc.Spec.TLS.ClientKey)
if remoteClientCertValue != "" && remoteClientKeyValue != "" {
overrides.remoteClientCert = filepath.Join(accDir, acc.Spec.TLS.ClientCert)
overrides.remoteClientKey = filepath.Join(accDir, acc.Spec.TLS.ClientKey)

filesToWrite[acc.Spec.TLS.ClientCert] = remoteClientCertValue
filesToWrite[acc.Spec.TLS.ClientKey] = remoteClientKeyValue
}

remoteRootCAValue := getSecretValue(acc.Spec.TLS.RootCAs)
if remoteRootCAValue != "" {
overrides.remoteRootCA = filepath.Join(accDir, acc.Spec.TLS.RootCAs)
filesToWrite[acc.Spec.TLS.RootCAs] = remoteRootCAValue
}

for file, v := range filesToWrite {
if err := os.WriteFile(filepath.Join(accDir, file), []byte(v), 0o644); err != nil {
return nil, err
}
}
}
// Lookup the UserCredentials.
if acc.Spec.Creds != nil {
secretName := acc.Spec.Creds.Secret.Name
secret, err := c.ki.Secrets(ns).Get(c.ctx, secretName, k8smeta.GetOptions{})
if err != nil {
return nil, err
}

// Write the user credentials to the cache dir.
accDir := filepath.Join(c.cacheDir, ns, account)
if err := os.MkdirAll(accDir, 0o755); err != nil {
return nil, err
}
for k, v := range secret.Data {
if k == acc.Spec.Creds.File {
overrides.userCreds = filepath.Join(c.cacheDir, ns, account, k)
if err := os.WriteFile(filepath.Join(accDir, k), v, 0o644); err != nil {
return nil, err
}
}
}
}

// Lookup the Token.
if acc.Spec.Token != nil {
secretName := acc.Spec.Token.Secret.Name
secret, err := c.ki.Secrets(ns).Get(c.ctx, secretName, k8smeta.GetOptions{})
if err != nil {
return nil, err
}

for k, v := range secret.Data {
if k == acc.Spec.Token.Token {
overrides.token = string(v)
}
}
}

// Lookup the User.
if acc.Spec.User != nil {
secretName := acc.Spec.User.Secret.Name
secret, err := c.ki.Secrets(ns).Get(c.ctx, secretName, k8smeta.GetOptions{})
if err != nil {
return nil, err
}

for k, v := range secret.Data {
if k == acc.Spec.User.User {
overrides.user = string(v)
}
if k == acc.Spec.User.Password {
overrides.password = string(v)
}
}
}

return overrides, nil
}

type jsmcSpecOverrides struct {
servers []string
tls apis.TLS
creds string
nkey string
}

func (c *Controller) runWithJsmc(jsm jsmClientFunc, acc *accountOverrides, spec *jsmcSpecOverrides, o runtime.Object, op func(jsmClient) error) error {
if !c.opts.CRDConnect {
jsmc, err := jsm(&natsContext{})
if err != nil {
return err
}

return op(jsmc)
}

// Create a new client
natsCtx := &natsContext{}
// Use JWT/NKEYS/user-password/token based credentials if present.
if spec.creds != "" {
natsCtx.Credentials = spec.creds
} else if spec.nkey != "" {
natsCtx.Nkey = spec.nkey
}
if spec.tls.ClientCert != "" && spec.tls.ClientKey != "" {
natsCtx.TLSCert = spec.tls.ClientCert
natsCtx.TLSKey = spec.tls.ClientKey
}

// Use fetched secrets for the account and server if defined.
if acc.remoteClientCert != "" && acc.remoteClientKey != "" {
natsCtx.TLSCert = acc.remoteClientCert
natsCtx.TLSKey = acc.remoteClientKey
}
if acc.remoteRootCA != "" {
natsCtx.TLSCAs = []string{acc.remoteRootCA}
}
if acc.userCreds != "" {
natsCtx.Credentials = acc.userCreds
}

if acc.user != "" && acc.password != "" {
natsCtx.Username = acc.user
natsCtx.Password = acc.password
} else if acc.token != "" {
natsCtx.Token = acc.token
}

if len(spec.tls.RootCAs) > 0 {
natsCtx.TLSCAs = spec.tls.RootCAs
}

natsServers := strings.Join(append(spec.servers, acc.servers...), ",")
natsCtx.URL = natsServers
c.normalEvent(o, "Connecting", "Connecting to new nats-servers")
jsmc, err := jsm(natsCtx)
if err != nil {
return fmt.Errorf("failed to connect to nats-servers(%s): %w", natsServers, err)
}

defer jsmc.Close()

return op(jsmc)
}

func splitNamespaceName(item interface{}) (ns string, name string, err error) {
defer func() {
if err != nil {
Expand Down
Loading