Skip to content

Commit

Permalink
feat(strm-1922): store and use zed tokens for all calls (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
trietsch authored May 24, 2023
1 parent c60d4b0 commit 3598600
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 78 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ dist/${target}: ${source_files} Makefile
clean:
rm -f dist/${target}

# Make sure the .env containing all `STRM_TEST_*` variables is present in the ./test directory
# godotenv loads the .env file from that directory when running the tests
test: dist/${target}
go clean -testcache
go test ./test -v
Expand Down
3 changes: 2 additions & 1 deletion cmd/strm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strmprivacy/strm/pkg/bootstrap"
"strmprivacy/strm/pkg/common"
"strmprivacy/strm/pkg/context"
"strmprivacy/strm/pkg/user_project"
"strmprivacy/strm/pkg/util"
)

Expand Down Expand Up @@ -78,7 +79,7 @@ func rootCmdPreRun(cmd *cobra.Command, args []string) error {
common.ApiAuthHost = util.GetStringAndErr(cmd.Flags(), auth.ApiAuthUrlFlag)

if auth.Auth.LoadLogin() == nil {
bootstrap.SetupServiceClients(auth.Auth.GetToken())
bootstrap.SetupServiceClients(auth.Auth.GetToken(), user_project.GetZedToken())
splitCommand := strings.Split(cmd.CommandPath(), " ")
if splitCommand[1] != "auth" && !(splitCommand[1] == "create" && splitCommand[2] == "project") {
context.ResolveProject(cmd.Flags())
Expand Down
69 changes: 67 additions & 2 deletions pkg/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package bootstrap

import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"strings"
"strmprivacy/strm/pkg/cmd"
"strmprivacy/strm/pkg/common"
Expand All @@ -30,6 +35,12 @@ import (
"strmprivacy/strm/pkg/entity/user"
"strmprivacy/strm/pkg/logs"
"strmprivacy/strm/pkg/monitor"
"strmprivacy/strm/pkg/user_project"
)

const (
cliVersionHeader = "strm-cli-version"
zedTokenHeader = "strm-zed-token"
)

/*
Expand Down Expand Up @@ -62,8 +73,8 @@ func SetupVerbs(rootCmd *cobra.Command) {
rootCmd.AddCommand(cmd.EvaluateCmd)
}

func SetupServiceClients(accessToken *string) {
clientConnection, ctx := common.SetupGrpc(common.ApiHost, accessToken)
func SetupServiceClients(accessToken *string, zedToken *string) {
clientConnection, ctx := SetupGrpc(common.ApiHost, accessToken, zedToken)

stream.SetupClient(clientConnection, ctx)
kafka_exporter.SetupClient(clientConnection, ctx)
Expand Down Expand Up @@ -146,3 +157,57 @@ func bindFlags(cmd *cobra.Command, v *viper.Viper) {
}
})
}

func SetupGrpc(host string, token *string, zedToken *string) (*grpc.ClientConn, context.Context) {

var err error
var creds grpc.DialOption

if strings.Contains(host, ":50051") {
creds = grpc.WithTransportCredentials(insecure.NewCredentials())
} else {
creds = grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""))
}

clientConnection, err := grpc.Dial(host, creds, grpc.WithUnaryInterceptor(clientInterceptor))
common.CliExit(err)

var mdMap = map[string]string{cliVersionHeader: common.Version}

if token != nil {
mdMap["authorization"] = "Bearer " + *token
}
if zedToken != nil {
mdMap[zedTokenHeader] = *zedToken
}

return clientConnection, metadata.NewOutgoingContext(context.Background(), metadata.New(mdMap))
}

func clientInterceptor(
ctx context.Context,
method string,
req interface{},
reply interface{},
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
zedToken := user_project.GetZedToken()

if zedToken != nil {
ctx = metadata.AppendToOutgoingContext(ctx, zedTokenHeader, *zedToken)
}

var header metadata.MD
opts = append(opts, grpc.Header(&header))
err := invoker(ctx, method, req, reply, cc, opts...)

zedTokenValue := header.Get(zedTokenHeader)

if len(zedTokenValue) > 0 {
user_project.SetZedToken(zedTokenValue[0])
}

return err
}
31 changes: 0 additions & 31 deletions pkg/common/common.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
package common

import (
"context"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"gopkg.in/natefinch/lumberjack.v2"
"os"
"runtime"
"strings"
)

var RootCommandName = "strm"
Expand All @@ -24,31 +18,6 @@ var ApiHost string

var ProjectId string

func SetupGrpc(host string, token *string) (*grpc.ClientConn, context.Context) {

var err error
var creds grpc.DialOption

if strings.Contains(host, ":50051") {
creds = grpc.WithTransportCredentials(insecure.NewCredentials())
} else {
creds = grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""))
}

clientConnection, err := grpc.Dial(host, creds)
CliExit(err)

var md metadata.MD
if token != nil {
md = metadata.New(map[string]string{"authorization": "Bearer " + *token, "strm-cli-version": Version})
} else {
md = metadata.New(map[string]string{"strm-cli-version": Version})
}

ctx := metadata.NewOutgoingContext(context.Background(), md)
return clientConnection, ctx
}

func CliExit(err error) {
if err != nil {
_, file, line, _ := runtime.Caller(1)
Expand Down
28 changes: 3 additions & 25 deletions pkg/context/project.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
package context

import (
"encoding/json"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/pflag"
"os"
"path"
"strmprivacy/strm/pkg/common"
"strmprivacy/strm/pkg/entity/project"
"strmprivacy/strm/pkg/user_project"
)

const activeProjectFilename = "active_projects.json"

// ResolveProject resolves the project to use and makes its ID globally available.
// The value passed through the flag takes precedence, then the value stored in the config dir, and finally
// a fallback to default project.
func ResolveProject(f *pflag.FlagSet) {

activeProjectFilePath := path.Join(common.ConfigPath(), activeProjectFilename)
projectFlagValue, _ := f.GetString(common.ProjectNameFlag)

if _, err := os.Stat(activeProjectFilePath); os.IsNotExist(err) && projectFlagValue == "" {
if _, err := os.Stat(user_project.ActiveProjectFilepath); os.IsNotExist(err) && projectFlagValue == "" {
initActiveProject()
fmt.Println(fmt.Sprintf("Active project was not yet set, has been set to '%v'. You can set a project "+
"with 'strm context project <project-name>'\n", user_project.GetActiveProject()))
Expand Down Expand Up @@ -54,7 +48,7 @@ func ResolveProject(f *pflag.FlagSet) {

func SetActiveProject(projectName string) {
if len(project.GetProject(projectName).Projects) != 0 {
saveActiveProject(projectName)
user_project.Projects.SetActiveProject(projectName)
message := "Active project set to: " + projectName
log.Infoln(message)
fmt.Println(message)
Expand All @@ -75,21 +69,5 @@ func getFirstProject() string {

func initActiveProject() {
firstProjectName := getFirstProject()
saveActiveProject(firstProjectName)
}

func saveActiveProject(projectName string) {
activeProjectFilepath := path.Join(common.ConfigPath(), activeProjectFilename)
user_project.Projects.SetActiveProject(projectName)
projects, err := json.Marshal(user_project.Projects)
if err != nil {
common.CliExit(err)
}

err = os.WriteFile(
activeProjectFilepath,
projects,
0644,
)
common.CliExit(err)
user_project.Projects.SetActiveProject(firstProjectName)
}
73 changes: 56 additions & 17 deletions pkg/user_project/user_projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,25 @@ import (
"strmprivacy/strm/pkg/common"
)

const ActiveProjectFilename = "active_projects.json"
const activeProjectFilename = "active_projects.json"

var Projects UsersProjects
var ActiveProjectFilepath = path.Join(common.ConfigPath(), activeProjectFilename)

// UsersProjects is the printed json format of the different active projects
var Projects *UsersProjectsContext

// UsersProjectsContext is the printed json format of the different active projects
// per past or currently logged-in user
type UsersProjects struct {
Users []UserProject `json:"users"`
type UsersProjectsContext struct {
Users []UserProjectContext `json:"users"`
}

type UserProject struct {
type UserProjectContext struct {
Email string `json:"email"`
ActiveProject string `json:"active_project"`
ZedToken string `json:"zed_token"`
}

func (projects *UsersProjects) GetCurrentProjectByEmail() string {
func (projects *UsersProjectsContext) GetCurrentProjectByEmail() string {
activeProject := ""
email := GetUserEmail()
for _, user := range projects.Users {
Expand All @@ -35,7 +38,7 @@ func (projects *UsersProjects) GetCurrentProjectByEmail() string {
return activeProject
}

func (projects *UsersProjects) SetActiveProject(project string) {
func (projects *UsersProjectsContext) SetActiveProject(project string) {
email := GetUserEmail()
added := false
for index, user := range projects.Users {
Expand All @@ -46,11 +49,13 @@ func (projects *UsersProjects) SetActiveProject(project string) {
}

if !added {
projects.Users = append(projects.Users, UserProject{
projects.Users = append(projects.Users, UserProjectContext{
Email: email,
ActiveProject: project,
})
}

storeUserProjectContext()
}

func GetUserEmail() string {
Expand All @@ -63,19 +68,53 @@ func GetUserEmail() string {
return auth.Auth.Email
}

func LoadActiveProject() {
activeProjectFilePath := path.Join(common.ConfigPath(), ActiveProjectFilename)
func initializeUsersProjectsContext() {
if Projects == nil {
activeProjectFilePath := path.Join(common.ConfigPath(), activeProjectFilename)

bytes, err := os.ReadFile(activeProjectFilePath)
common.CliExit(err)
activeProjects := UsersProjects{}
_ = json.Unmarshal(bytes, &activeProjects)
Projects = activeProjects
bytes, err := os.ReadFile(activeProjectFilePath)
common.CliExit(err)
activeProjects := UsersProjectsContext{}
_ = json.Unmarshal(bytes, &activeProjects)
Projects = &activeProjects
}
}

func GetActiveProject() string {
LoadActiveProject()
initializeUsersProjectsContext()
activeProject := Projects.GetCurrentProjectByEmail()
log.Infoln("Current active project is: " + activeProject)
return activeProject
}

func SetZedToken(zedToken string) {
initializeUsersProjectsContext()
email := GetUserEmail()
for index, user := range Projects.Users {
// If there is no entry for the user, a zed token will be added the next time, when it is present
if user.Email == email {
(*Projects).Users[index].ZedToken = zedToken
}
}

storeUserProjectContext()
}

func GetZedToken() *string {
initializeUsersProjectsContext()
email := GetUserEmail()
for _, user := range Projects.Users {
if user.Email == email && user.ZedToken != "" {
return &user.ZedToken
}
}
return nil
}

func storeUserProjectContext() {
projects, err := json.Marshal(Projects)
common.CliExit(err)

err = os.WriteFile(ActiveProjectFilepath, projects, 0644)
common.CliExit(err)
}
11 changes: 9 additions & 2 deletions test/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ var _testConfig TestConfig

func testConfig() *TestConfig {
if (TestConfig{}) == _testConfig {
_ = godotenv.Load()
err := godotenv.Load()

if err != nil && os.Getenv("GITHUB_ACTION") == "" {
fmt.Fprintf(os.Stderr, "Error loading .env file: %v\n", err)
os.Exit(1)
}

_testConfig = TestConfig{
projectId: os.Getenv("STRM_TEST_PROJECT_ID"),
email: os.Getenv("STRM_TEST_USER_EMAIL"),
Expand Down Expand Up @@ -64,7 +70,7 @@ func newConfigDir() string {
_ = os.Setenv("STRM_API_AUTH_URL", "https://accounts.dev.strmprivacy.io")
_ = os.Setenv("STRM_API_HOST", "api.dev.strmprivacy.io:443")
_ = os.Setenv("STRM_HEADLESS", "true")
_ = os.WriteFile(configDir+"/active_project", []byte("default"), 0644)
_ = os.WriteFile(configDir+"/active_projects.json", []byte(fmt.Sprintf(`{"users":[{"email":"%s","active_project":"default"}]}`, os.Getenv("STRM_TEST_USER_EMAIL"))), 0644)
return configDir
}

Expand Down Expand Up @@ -212,6 +218,7 @@ func ExecuteAndVerify(t *testing.T, expected proto.Message, args ...string) {
out, err := TryLoad(outputMessage, output)
if err != nil {
fmt.Println("Can't execute", args)
fmt.Fprintln(os.Stderr, err)
t.Fail()
}
assertProtoEquals(t, out, expected)
Expand Down

0 comments on commit 3598600

Please sign in to comment.