Skip to content

Commit

Permalink
feat(strm-1198): implement context switching for projects and add pro…
Browse files Browse the repository at this point in the history
…jectId to requests
  • Loading branch information
ivan-p92 authored Jun 17, 2022
1 parent 194b3a5 commit f79d4fe
Show file tree
Hide file tree
Showing 33 changed files with 482 additions and 215 deletions.
5 changes: 4 additions & 1 deletion cmd/strm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strmprivacy/strm/pkg/auth"
"strmprivacy/strm/pkg/bootstrap"
"strmprivacy/strm/pkg/common"
"strmprivacy/strm/pkg/context"
"strmprivacy/strm/pkg/kafkaconsumer"
"strmprivacy/strm/pkg/util"
"strmprivacy/strm/pkg/web_socket"
Expand Down Expand Up @@ -73,7 +74,6 @@ func rootCmdPreRun() func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error {
util.CreateConfigDirAndFileIfNotExists()
err := bootstrap.InitializeConfig(cmd)

log.Infoln(fmt.Sprintf("Executing command: %v", cmd.CommandPath()))
cmd.Flags().Visit(func(flag *pflag.Flag) {
log.Infoln(fmt.Sprintf("flag %v=%v", flag.Name, flag.Value))
Expand All @@ -85,6 +85,8 @@ func rootCmdPreRun() func(cmd *cobra.Command, args []string) error {

if auth.Auth.LoadLogin() == nil {
bootstrap.SetupServiceClients(auth.Auth.GetToken())
context.ResolveProject(cmd.Flags())
log.Infoln("Resolved projectId: " + common.ProjectId)
}

return err
Expand All @@ -103,6 +105,7 @@ func init() {
"Token file that contains an access token (default is $HOME/.config/strmprivacy/credentials-<api-auth-url>.json)")
persistentFlags.String(web_socket.WebSocketUrl, "wss://websocket.strmprivacy.io/ws", "Websocket to receive events from")
persistentFlags.String(kafkaconsumer.KafkaBootstrapHostFlag, "export-bootstrap.kafka.strmprivacy.io:9092", "Kafka bootstrap brokers, separated by comma")
persistentFlags.String(context.ProjectFlag, "", "Project to use (defaults to context-configured project)")
persistentFlags.StringP(common.OutputFormatFlag, common.OutputFormatFlagShort, common.OutputFormatTable, fmt.Sprintf("Output format [%v]", common.OutputFormatFlagAllowedValuesText))

err := RootCmd.RegisterFlagCompletionFunc(common.OutputFormatFlag, func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ require (
github.com/spf13/cobra v1.3.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.10.0
github.com/strmprivacy/api-definitions-go/v2 v2.35.0
github.com/strmprivacy/api-definitions-go/v2 v2.38.0
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
google.golang.org/grpc v1.46.0
Expand Down
156 changes: 2 additions & 154 deletions go.sum

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pkg/auth/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func oAuth2Config() oauth2.Config {
}
}

// Todo: remove (eventually, in STRM-1238), first add projectId from context to most places it's currently used
func (authenticator *Authenticator) BillingId() string {
if authenticator.billingId == nil {
common.MissingIdTokenError()
Expand Down Expand Up @@ -168,6 +169,7 @@ func startBrowserLoginFlow(ready chan string, ctx context.Context) func() error
}
}

// Todo: we can leave this for now for backwards compatibility, to be removed in STRM-1238
func getLegacyBillingId(accessToken string) string {
clientConnection, ctx := common.SetupGrpc(common.ApiHost, &accessToken)
accountClient := account.NewAccountServiceClient(clientConnection)
Expand Down
9 changes: 5 additions & 4 deletions pkg/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"strmprivacy/strm/pkg/cmd"
"strmprivacy/strm/pkg/common"
"strmprivacy/strm/pkg/entity/account"
"strmprivacy/strm/pkg/entity/batch_exporter"
"strmprivacy/strm/pkg/entity/batch_job"
"strmprivacy/strm/pkg/entity/data_connector"
Expand All @@ -19,11 +20,11 @@ import (
"strmprivacy/strm/pkg/entity/kafka_exporter"
"strmprivacy/strm/pkg/entity/kafka_user"
"strmprivacy/strm/pkg/entity/key_stream"
"strmprivacy/strm/pkg/entity/project"
"strmprivacy/strm/pkg/entity/schema"
"strmprivacy/strm/pkg/entity/schema_code"
"strmprivacy/strm/pkg/entity/stream"
"strmprivacy/strm/pkg/entity/usage"
"strmprivacy/strm/pkg/entity/account"
"strmprivacy/strm/pkg/util"
)

Expand Down Expand Up @@ -65,6 +66,7 @@ func SetupServiceClients(accessToken *string) {
usage.SetupClient(clientConnection, ctx)
installation.SetupClient(clientConnection, ctx)
account.SetupClient(clientConnection, ctx)
project.SetupClient(clientConnection, ctx)
}

func ConfigPath() string {
Expand Down Expand Up @@ -115,9 +117,8 @@ func InitializeConfig(cmd *cobra.Command) error {

// When we bind flags to environment variables expect that the
// environment variables are prefixed, e.g. a flag like --number
// binds to an environment variable STING_NUMBER. This helps
// binds to an environment variable STRM_NUMBER. This helps
// avoid conflicts.
// you could set STRM_BILLINGID for instance
viperConfig.SetEnvPrefix(common.EnvPrefix)

// Bind to environment variables
Expand All @@ -137,7 +138,7 @@ func InitializeConfig(cmd *cobra.Command) error {
func bindFlags(cmd *cobra.Command, v *viper.Viper) {
cmd.Flags().VisitAll(func(f *pflag.Flag) {
// Environment variables can't have dashes in them, so bind them to their equivalent
// keys with underscores, e.g. --favorite-color to STING_FAVORITE_COLOR
// keys with underscores, e.g. --favorite-color to STRM_FAVORITE_COLOR
if strings.Contains(f.Name, "-") {
envVarSuffix := strings.ToUpper(strings.ReplaceAll(f.Name, "-", "_"))
err := v.BindEnv(f.Name, fmt.Sprintf("%s_%s", common.EnvPrefix, envVarSuffix))
Expand Down
3 changes: 2 additions & 1 deletion pkg/cmd/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ func init() {
ContextCommand.AddCommand(context.EntityInfo())
ContextCommand.AddCommand(context.BillingIdInfo())
ContextCommand.AddCommand(context.Account())
}
ContextCommand.AddCommand(context.Project())
}
2 changes: 2 additions & 0 deletions pkg/cmd/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strmprivacy/strm/pkg/entity/kafka_exporter"
"strmprivacy/strm/pkg/entity/kafka_user"
"strmprivacy/strm/pkg/entity/key_stream"
"strmprivacy/strm/pkg/entity/project"
"strmprivacy/strm/pkg/entity/schema"
"strmprivacy/strm/pkg/entity/stream"
)
Expand All @@ -34,6 +35,7 @@ func init() {
ListCmd.AddCommand(schema.ListCmd())
ListCmd.AddCommand(event_contract.ListCmd())
ListCmd.AddCommand(installation.ListCmd())
ListCmd.AddCommand(project.ListCmd())

ListCmd.PersistentFlags().BoolP(common.RecursiveFlagName, common.RecursiveFlagShorthand, false, common.RecursiveFlagUsage)
}
2 changes: 2 additions & 0 deletions pkg/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ var ApiAuthHost string
var ApiHost string
var EventAuthHost string

var ProjectId string

func SetupGrpc(host string, token *string) (*grpc.ClientConn, context.Context) {

var err error
Expand Down
5 changes: 4 additions & 1 deletion pkg/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,7 @@ var BillingIdOutputFormatFlagAllowedValues = []string{OutputFormatPlain}
var BillingIdOutputFormatFlagAllowedValuesText = strings.Join(BillingIdOutputFormatFlagAllowedValues, ", ")

var AccountOutputFormatFlagAllowedValues = []string{OutputFormatPlain, OutputFormatJsonRaw}
var AccountOutputFormatFlagAllowedValuesText = strings.Join(AccountOutputFormatFlagAllowedValues, ", ")
var AccountOutputFormatFlagAllowedValuesText = strings.Join(AccountOutputFormatFlagAllowedValues, ", ")

var ProjectOutputFormatFlagAllowedValues = []string{OutputFormatPlain}
var ProjectOutputFormatFlagAllowedValuesText = strings.Join(ProjectOutputFormatFlagAllowedValues, ", ")
34 changes: 33 additions & 1 deletion pkg/context/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const (
entityInfoCommandName = "info"
billingIdInfoCommandName = "billing-id"
accountCommandName = "account"
projectCommandName = "project"
)

func Configuration() *cobra.Command {
Expand Down Expand Up @@ -79,7 +80,7 @@ func Account() *cobra.Command {
printer = configurePrinter(cmd)
},
Run: func(cmd *cobra.Command, args []string) {
getHandle()
showAccountDetails()
},
}
cmd.Flags().StringP(
Expand Down Expand Up @@ -126,6 +127,37 @@ func EntityInfo() *cobra.Command {
return entityInfo
}

func Project() *cobra.Command {
cmd := &cobra.Command{
Use: projectCommandName + " [name]",
Short: "Show or set the active project",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
PreRun: func(cmd *cobra.Command, args []string) {
printer = configurePrinter(cmd)
},
Run: func(cmd *cobra.Command, args []string) {
if len(args) > 0 {
SetActiveProject(args[0])
} else {
showActiveProject()
}
},
}
cmd.Flags().StringP(
common.OutputFormatFlag,
common.OutputFormatFlagShort,
common.OutputFormatPlain,
common.ProjectOutputFormatFlagAllowedValuesText,
)
err := cmd.RegisterFlagCompletionFunc(common.OutputFormatFlag, func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return common.ProjectOutputFormatFlagAllowedValues, cobra.ShellCompDirectiveNoFileComp
})

common.CliExit(err)
return cmd
}

func savedEntitiesCompletion(cmd *cobra.Command, args []string, complete string) ([]string, cobra.ShellCompDirective) {
return listSavedEntities(path.Join(common.ConfigPath, common.SavedEntitiesDirectory)), cobra.ShellCompDirectiveNoFileComp
}
13 changes: 10 additions & 3 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"strmprivacy/strm/pkg/entity/account"
)

const ProjectFlag = "project"

type configuration struct {
ConfigPath string
ConfigFilepath string
Expand Down Expand Up @@ -59,10 +61,11 @@ func entityInfo(args []string) {
printer.Print(entity)
}

func getHandle() {
details := account.GetHandle()
func showAccountDetails() {
details := account.GetAccountDetails()
printer.Print(details)
}

func billingIdInfo() {
b, err := auth.GetBillingId()
if err != nil {
Expand Down Expand Up @@ -108,4 +111,8 @@ func findConfigFile() string {
}

return configFilepath
}
}

func showActiveProject() {
printer.Print("Active project: " + GetActiveProject())
}
14 changes: 11 additions & 3 deletions pkg/context/printers.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ func configurePrinter(command *cobra.Command) util.Printer {
allowedValues = common.ConfigOutputFormatFlagAllowedValuesText
case accountCommandName:
allowedValues = common.ConfigOutputFormatFlagAllowedValuesText
case projectCommandName:
allowedValues = common.ProjectOutputFormatFlagAllowedValuesText
}

common.CliExit(errors.New(fmt.Sprintf("Output format '%v' is not supported for '%v'. Allowed values: %v", command.CommandPath(), outputFormat, allowedValues)))
Expand All @@ -46,22 +48,24 @@ func availablePrinters() map[string]util.Printer {
common.OutputFormatJsonRaw + entityInfoCommandName: jsonRawPrinter{},
common.OutputFormatJson + entityInfoCommandName: jsonPrettyPrinter{},
common.OutputFormatFilepath + entityInfoCommandName: filepathPrinter{},
common.OutputFormatPlain + configCommandName: plainPrinter{},
common.OutputFormatPlain + configCommandName: configPlainPrinter{},
common.OutputFormatJson + configCommandName: configJsonPrinter{},
common.OutputFormatPlain + billingIdInfoCommandName: billingIdPrinter{},
common.OutputFormatJsonRaw + accountCommandName: accountJsonPrinter{},
common.OutputFormatPlain + accountCommandName: accountPlainPrinter{},
common.OutputFormatPlain + projectCommandName: projectPrinter{},
}
}

type filepathPrinter struct{}
type jsonRawPrinter struct{}
type jsonPrettyPrinter struct{}
type plainPrinter struct{}
type configPlainPrinter struct{}
type accountJsonPrinter struct{}
type accountPlainPrinter struct{}
type configJsonPrinter struct{}
type billingIdPrinter struct{}
type projectPrinter struct{}

func (p filepathPrinter) Print(data interface{}) {
entity, _ := (data).(savedEntity)
Expand Down Expand Up @@ -97,7 +101,7 @@ func (p accountPlainPrinter) Print(data interface{}) {

}

func (p plainPrinter) Print(data interface{}) {
func (p configPlainPrinter) Print(data interface{}) {
config, _ := (data).(configuration)

fmt.Println(fmt.Sprintf("Configuration directory: %v", config.ConfigPath))
Expand Down Expand Up @@ -127,3 +131,7 @@ func (p configJsonPrinter) Print(data interface{}) {
func (p billingIdPrinter) Print(data interface{}) {
fmt.Println(data)
}

func (p projectPrinter) Print(data interface{}) {
fmt.Println(data)
}
91 changes: 91 additions & 0 deletions pkg/context/project.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package context

import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/pflag"
"io/ioutil"
"os"
"path"
"strmprivacy/strm/pkg/common"
"strmprivacy/strm/pkg/entity/project"
)

const activeProjectFilename = "active_project"

// ResolveProject resolves the project to use and makes its ID globally available.
// The value passed through the flag takes precedence, then the value stored in the config dir, and finally
// a fallback to default project.
func ResolveProject(f *pflag.FlagSet) {

activeProjectFilePath := path.Join(common.ConfigPath, activeProjectFilename)
projectFlagValue, _ := f.GetString(ProjectFlag)

if _, err := os.Stat(activeProjectFilePath); (os.IsNotExist(err) || GetActiveProject() == "") && projectFlagValue == "" {
initActiveProject()
fmt.Println(fmt.Sprintf("Active project was not yet set, has been set to '%v'. You can set a project "+
"with 'strm context project <project-name>'\n", GetActiveProject()))
}

if projectFlagValue != "" {
resolvedProject := project.GetProject(projectFlagValue)
if resolvedProject == nil {
message := fmt.Sprintf("Specified project '%v' does not exist, or you do not have access to it.", projectFlagValue)
common.CliExit(errors.New(message))
}
common.ProjectId = resolvedProject.Id
} else {
activeProject := GetActiveProject()
resolvedProject := project.GetProject(activeProject)
if resolvedProject == nil {
initActiveProject()
common.CliExit(errors.New(fmt.Sprintf("Active project '%v' does not exist, or you do not have access " +
"to it. The following project has been set instead: %v", activeProject, GetActiveProject())))
}
common.ProjectId = resolvedProject.Id
}
}

func SetActiveProject(projectName string) {
if project.GetProject(projectName) != nil {
saveActiveProject(projectName)
message := "Active project set to: " + projectName
log.Infoln(message)
fmt.Println(message)
} else {
message := fmt.Sprintf("No project '%v' found, or you do not have access to it.", projectName)
log.Warnln(message)
common.CliExit(errors.New(message))
}
}

func GetActiveProject() string {
activeProjectFilePath := path.Join(common.ConfigPath, activeProjectFilename)

bytes, err := ioutil.ReadFile(activeProjectFilePath)
common.CliExit(err)
activeProject := string(bytes)
log.Infoln("Current active project is: " + activeProject)
return activeProject
}

func initActiveProject() {
projects := project.ListProjects()
if len(projects.Projects) == 0 {
common.CliExit(errors.New("you do not have access to any projects; create a project first, or ask to be granted access to one"))
}
firstProjectName := projects.Projects[0].Name
saveActiveProject(firstProjectName)
}

func saveActiveProject(projectName string) {
activeProjectFilepath := path.Join(common.ConfigPath, activeProjectFilename)

err := ioutil.WriteFile(
activeProjectFilepath,
[]byte(projectName),
0644,
)
common.CliExit(err)
}
Loading

0 comments on commit f79d4fe

Please sign in to comment.