diff --git a/cmd/root.go b/cmd/root.go index 9588afb..22c7537 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,6 +19,7 @@ var ( configPath string dep c.Dependencies ctx c.Context + config c.Config options cli.Options optionsPrint print.Options err error @@ -26,14 +27,16 @@ var ( Version: Version, Use: "ticker", Short: "Terminal stock ticker and stock gain/loss tracker", - Args: cli.Validate(&ctx, &options, &err), + PreRun: initContext, + Args: cli.Validate(&config, &options, &err), Run: cli.Run(ui.Start(&dep, &ctx)), } printCmd = &cobra.Command{ - Use: "print", - Short: "Prints holdings", - Args: cli.Validate(&ctx, &options, &err), - Run: print.Run(&dep, &ctx, &optionsPrint), + Use: "print", + Short: "Prints holdings", + PreRun: initContext, + Args: cli.Validate(&config, &options, &err), + Run: print.Run(&dep, &ctx, &optionsPrint), } ) @@ -63,14 +66,21 @@ func init() { //nolint: gochecknoinits } func initConfig() { - dep, err = cli.GetDependencies() + + dep = cli.GetDependencies() + + config, err = cli.GetConfig(dep, configPath, options) if err != nil { fmt.Println(err) os.Exit(1) } - ctx, err = cli.GetContext(dep, options, configPath) +} + +func initContext(_ *cobra.Command, _ []string) { + + ctx, err = cli.GetContext(dep, config) if err != nil { fmt.Println(err) diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 066ee98..efddf2d 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -50,14 +50,14 @@ func Run(uiStartFn func() error) func(*cobra.Command, []string) { } // Validate checks whether config is valid and returns an error if invalid or if an error was generated earlier -func Validate(ctx *c.Context, options *Options, prevErr *error) func(*cobra.Command, []string) error { +func Validate(config *c.Config, options *Options, prevErr *error) func(*cobra.Command, []string) error { return func(_ *cobra.Command, _ []string) error { if prevErr != nil && *prevErr != nil { return *prevErr } - if len(ctx.Config.Watchlist) == 0 && len(options.Watchlist) == 0 && len(ctx.Config.Lots) == 0 && len(ctx.Config.AssetGroup) == 0 { + if len(config.Watchlist) == 0 && len(options.Watchlist) == 0 && len(config.Lots) == 0 && len(config.AssetGroup) == 0 { return errors.New("invalid config: No watchlist provided") //nolint:goerr113 } @@ -65,41 +65,32 @@ func Validate(ctx *c.Context, options *Options, prevErr *error) func(*cobra.Comm } } -func GetDependencies() (c.Dependencies, error) { - - client := yahooClient.New(resty.New(), resty.New()) - err := yahooClient.RefreshSession(client, resty.New()) - - if err != nil { - return c.Dependencies{}, err - } +func GetDependencies() c.Dependencies { return c.Dependencies{ Fs: afero.NewOsFs(), HttpClients: c.DependenciesHttpClients{ Default: resty.New(), - Yahoo: client, + Yahoo: yahooClient.New(resty.New(), resty.New()), }, - }, nil + } } // GetContext builds the context from the config and reference data -func GetContext(d c.Dependencies, options Options, configPath string) (c.Context, error) { +func GetContext(d c.Dependencies, config c.Config) (c.Context, error) { var ( reference c.Reference - config c.Config groups []c.AssetGroup err error ) - config, err = readConfig(d.Fs, configPath) + err = yahooClient.RefreshSession(d.HttpClients.Yahoo, resty.New()) if err != nil { return c.Context{}, err } - config = getConfig(config, options, &d.HttpClients) groups, err = getGroups(config, *d.HttpClients.Default) if err != nil { @@ -152,15 +143,21 @@ func getReference(config c.Config, assetGroups []c.AssetGroup, client resty.Clie } -func getConfig(config c.Config, options Options, httpClients *c.DependenciesHttpClients) c.Config { +func GetConfig(dep c.Dependencies, configPath string, options Options) (c.Config, error) { + + config, err := readConfig(dep.Fs, configPath) + + if err != nil { + return c.Config{}, err + } if len(options.Watchlist) != 0 { config.Watchlist = strings.Split(strings.ReplaceAll(options.Watchlist, " ", ""), ",") } if len(config.Proxy) > 0 { - httpClients.Default.SetProxy(config.Proxy) - httpClients.Yahoo.SetProxy(config.Proxy) + dep.HttpClients.Default.SetProxy(config.Proxy) + dep.HttpClients.Yahoo.SetProxy(config.Proxy) } config.RefreshInterval = getRefreshInterval(options.RefreshInterval, config.RefreshInterval) @@ -172,7 +169,7 @@ func getConfig(config c.Config, options Options, httpClients *c.DependenciesHttp config.Proxy = getStringOption(options.Proxy, config.Proxy) config.Sort = getStringOption(options.Sort, config.Sort) - return config + return config, nil } func getConfigPath(fs afero.Fs, configPathOption string) (string, error) { diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 65d2bc3..8180c2a 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -50,7 +50,6 @@ var _ = Describe("Cli", func() { var ( options Options dep c.Dependencies - ctx c.Context ) BeforeEach(func() { @@ -72,7 +71,6 @@ var _ = Describe("Cli", func() { Yahoo: client, }, } - ctx = c.Context{} http.MockTickerSymbols() http.MockResponseCurrency() @@ -112,7 +110,8 @@ var _ = Describe("Cli", func() { Describe("GetContext", func() { - Context("options and configuration", func() { + Context("watchlist and groups", func() { + type Case struct { InputOptions cli.Options InputConfigFileContents string @@ -121,12 +120,13 @@ var _ = Describe("Cli", func() { AssertionCtx types.GomegaMatcher } - DescribeTable("config values", + DescribeTable("context values", func(c Case) { if c.InputConfigFileContents != "" { writeConfigFile(dep.Fs, c.InputConfigFileContents) } - outputCtx, outputErr := GetContext(dep, c.InputOptions, c.InputConfigFilePath) + outputConfig, outputErr := cli.GetConfig(dep, c.InputConfigFilePath, c.InputOptions) + outputCtx, outputErr := cli.GetContext(dep, outputConfig) Expect(outputErr).To(c.AssertionErr) Expect(outputCtx).To(c.AssertionCtx) }, @@ -285,16 +285,68 @@ var _ = Describe("Cli", func() { }), }), }), + ) + + }) + + When("there is an error getting ticker symbols", func() { + + It("returns the error", func() { + + http.MockTickerSymbolsError() + + _, outputErr := GetContext(dep, c.Config{}) + + Expect(outputErr).ToNot(BeNil()) + + }) + + }) + + When("there is an error getting reference data", func() { + + PIt("returns the error", func() { + + http.MockResponseCurrencyError() + + _, outputErr := GetContext(dep, c.Config{}) + + Expect(outputErr).ToNot(BeNil()) + + }) + + }) + + }) + + Describe("GetConfig", func() { + + Context("options and configuration", func() { + type Case struct { + InputOptions cli.Options + InputConfigFileContents string + InputConfigFilePath string + AssertionErr types.GomegaMatcher + AssertionConfig types.GomegaMatcher + } + + DescribeTable("config values", + func(c Case) { + if c.InputConfigFileContents != "" { + writeConfigFile(dep.Fs, c.InputConfigFileContents) + } + outputConfig, outputErr := cli.GetConfig(dep, c.InputConfigFilePath, c.InputOptions) + Expect(outputErr).To(c.AssertionErr) + Expect(outputConfig).To(c.AssertionConfig) + }, // option: string (proxy, sort) Entry("when proxy is set in config file", Case{ InputOptions: cli.Options{}, InputConfigFileContents: "proxy: http://myproxy.com:4438", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "Proxy": Equal("http://myproxy.com:4438"), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "Proxy": Equal("http://myproxy.com:4438"), }), }), @@ -302,10 +354,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{Proxy: "http://www.example.org:3128"}, InputConfigFileContents: "", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "Proxy": Equal("http://www.example.org:3128"), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "Proxy": Equal("http://www.example.org:3128"), }), }), @@ -313,10 +363,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{Proxy: "http://www.example.org:3128"}, InputConfigFileContents: "proxy: http://myproxy.com:4438", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "Proxy": Equal("http://www.example.org:3128"), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "Proxy": Equal("http://www.example.org:3128"), }), }), @@ -325,10 +373,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{}, InputConfigFileContents: "interval: 8", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "RefreshInterval": Equal(8), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "RefreshInterval": Equal(8), }), }), @@ -336,10 +382,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{RefreshInterval: 7}, InputConfigFileContents: "", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "RefreshInterval": Equal(7), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "RefreshInterval": Equal(7), }), }), @@ -347,10 +391,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{RefreshInterval: 7}, InputConfigFileContents: "interval: 8", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "RefreshInterval": Equal(7), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "RefreshInterval": Equal(7), }), }), @@ -358,10 +400,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{}, InputConfigFileContents: "", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "RefreshInterval": Equal(5), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "RefreshInterval": Equal(5), }), }), @@ -370,10 +410,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{}, InputConfigFileContents: "show-separator: true", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "Separate": Equal(true), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "Separate": Equal(true), }), }), @@ -381,10 +419,8 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{Separate: true}, InputConfigFileContents: "", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "Separate": Equal(true), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "Separate": Equal(true), }), }), @@ -392,42 +428,12 @@ var _ = Describe("Cli", func() { InputOptions: cli.Options{Separate: false}, InputConfigFileContents: "show-separator: true", AssertionErr: BeNil(), - AssertionCtx: g.MatchFields(g.IgnoreExtras, g.Fields{ - "Config": g.MatchFields(g.IgnoreExtras, g.Fields{ - "Separate": Equal(true), - }), + AssertionConfig: g.MatchFields(g.IgnoreExtras, g.Fields{ + "Separate": Equal(true), }), }), ) - When("there is an error getting ticker symbols", func() { - - It("returns the error", func() { - - http.MockTickerSymbolsError() - - _, outputErr := GetContext(dep, cli.Options{}, "") - - Expect(outputErr).ToNot(BeNil()) - - }) - - }) - - When("there is an error getting reference data", func() { - - PIt("returns the error", func() { - - http.MockResponseCurrencyError() - - _, outputErr := GetContext(dep, cli.Options{}, "") - - Expect(outputErr).ToNot(BeNil()) - - }) - - }) - }) //nolint:errcheck @@ -450,9 +456,9 @@ var _ = Describe("Cli", func() { When("an explicit config file is set", func() { It("should read the config file from disk", func() { inputConfigPath := ".ticker.yaml" - outputCtx, outputErr := GetContext(depLocal, cli.Options{}, inputConfigPath) + outputConfig, outputErr := GetConfig(depLocal, inputConfigPath, cli.Options{}) - Expect(outputCtx.Config.Watchlist).To(Equal([]string{"NOK"})) + Expect(outputConfig.Watchlist).To(Equal([]string{"NOK"})) Expect(outputErr).To(BeNil()) }) }) @@ -463,10 +469,10 @@ var _ = Describe("Cli", func() { inputHome, _ := homedir.Dir() inputConfigPath := "" depLocal.Fs.MkdirAll(inputHome, 0755) - outputCtx, outputErr := GetContext(depLocal, cli.Options{}, inputConfigPath) + outputConfig, outputErr := GetConfig(depLocal, inputConfigPath, cli.Options{}) Expect(outputErr).To(BeNil()) - Expect(outputCtx.Config).To(Equal(c.Config{RefreshInterval: 5})) + Expect(outputConfig).To(Equal(c.Config{RefreshInterval: 5})) }) }) When("there is a config file in the home directory", func() { @@ -476,9 +482,9 @@ var _ = Describe("Cli", func() { depLocal.Fs.MkdirAll(inputHome, 0755) depLocal.Fs.Create(inputHome + "/.ticker.yaml") afero.WriteFile(depLocal.Fs, inputHome+"/.ticker.yaml", []byte("watchlist:\n - AMD"), 0644) - outputCtx, outputErr := GetContext(depLocal, cli.Options{}, inputConfigPath) + outputConfig, outputErr := GetConfig(depLocal, inputConfigPath, cli.Options{}) - Expect(outputCtx.Config.Watchlist).To(Equal([]string{"AMD"})) + Expect(outputConfig.Watchlist).To(Equal([]string{"AMD"})) Expect(outputErr).To(BeNil()) }) }) @@ -489,9 +495,9 @@ var _ = Describe("Cli", func() { depLocal.Fs.MkdirAll(inputCurrentDirectory, 0755) depLocal.Fs.Create(inputCurrentDirectory + "/.ticker.yaml") afero.WriteFile(depLocal.Fs, inputCurrentDirectory+"/.ticker.yaml", []byte("watchlist:\n - JNJ"), 0644) - outputCtx, outputErr := GetContext(depLocal, cli.Options{}, inputConfigPath) + outputConfig, outputErr := GetConfig(depLocal, inputConfigPath, cli.Options{}) - Expect(outputCtx.Config.Watchlist).To(Equal([]string{"JNJ"})) + Expect(outputConfig.Watchlist).To(Equal([]string{"JNJ"})) Expect(outputErr).To(BeNil()) }) }) @@ -504,10 +510,10 @@ var _ = Describe("Cli", func() { depLocal.Fs.MkdirAll(inputConfigHome, 0755) depLocal.Fs.Create(inputConfigHome + "/.ticker.yaml") afero.WriteFile(depLocal.Fs, inputConfigHome+"/.ticker.yaml", []byte("watchlist:\n - ABNB"), 0644) - outputCtx, outputErr := GetContext(depLocal, cli.Options{}, inputConfigPath) + outputConfig, outputErr := GetConfig(depLocal, inputConfigPath, cli.Options{}) os.Unsetenv("XDG_CONFIG_HOME") - Expect(outputCtx.Config.Watchlist).To(Equal([]string{"ABNB"})) + Expect(outputConfig.Watchlist).To(Equal([]string{"ABNB"})) Expect(outputErr).To(BeNil()) }) }) @@ -516,9 +522,9 @@ var _ = Describe("Cli", func() { When("there is an error reading the config file", func() { It("returns the error", func() { inputConfigPath := ".config-file-that-does-not-exist.yaml" - outputCtx, outputErr := GetContext(depLocal, cli.Options{}, inputConfigPath) + outputConfig, outputErr := GetConfig(depLocal, inputConfigPath, cli.Options{}) - Expect(outputCtx.Config).To(Equal(c.Config{})) + Expect(outputConfig).To(Equal(c.Config{})) Expect(outputErr).To(MatchError("invalid config: open .config-file-that-does-not-exist.yaml: file does not exist")) }) }) @@ -527,9 +533,9 @@ var _ = Describe("Cli", func() { It("returns the error", func() { inputConfigPath := ".ticker.yaml" afero.WriteFile(depLocal.Fs, ".ticker.yaml", []byte("watchlist:\n NOK"), 0644) - outputCtx, outputErr := GetContext(depLocal, cli.Options{}, inputConfigPath) + outputConfig, outputErr := GetConfig(depLocal, inputConfigPath, cli.Options{}) - Expect(outputCtx.Config).To(Equal(c.Config{})) + Expect(outputConfig).To(Equal(c.Config{})) Expect(outputErr).To(MatchError("invalid config: yaml: unmarshal errors:\n line 2: cannot unmarshal !!str `NOK` into []string")) }) @@ -539,10 +545,18 @@ var _ = Describe("Cli", func() { Describe("Validate", func() { + var ( + config c.Config + ) + + BeforeEach(func() { + config = c.Config{} + }) + When("a deferred error is passed in", func() { It("validation fails", func() { inputErr := errors.New("some config error") - outputErr := Validate(&c.Context{}, &cli.Options{}, &inputErr)(&cobra.Command{}, []string{}) + outputErr := Validate(&config, &cli.Options{}, &inputErr)(&cobra.Command{}, []string{}) Expect(outputErr).To(MatchError("some config error")) }) }) @@ -551,7 +565,7 @@ var _ = Describe("Cli", func() { When("there is no watchlist in the config file and no watchlist cli argument", func() { It("should return an error", func() { options.Watchlist = "" - outputErr := Validate(&ctx, &options, nil)(&cobra.Command{}, []string{}) + outputErr := Validate(&config, &options, nil)(&cobra.Command{}, []string{}) Expect(outputErr).To(MatchError("invalid config: No watchlist provided")) }) @@ -559,7 +573,7 @@ var _ = Describe("Cli", func() { It("should not return an error", func() { var prevErr error - outputErr := Validate(&ctx, &options, &prevErr)(&cobra.Command{}, []string{}) + outputErr := Validate(&config, &options, &prevErr)(&cobra.Command{}, []string{}) Expect(outputErr).NotTo(HaveOccurred()) }) @@ -568,7 +582,7 @@ var _ = Describe("Cli", func() { When("there are lots set", func() { It("should not return an error", func() { options.Watchlist = "" - ctx.Config = c.Config{ + config = c.Config{ Lots: []c.Lot{ { Symbol: "SYM", @@ -577,7 +591,7 @@ var _ = Describe("Cli", func() { }, }, } - outputErr := Validate(&ctx, &options, nil)(&cobra.Command{}, []string{}) + outputErr := Validate(&config, &options, nil)(&cobra.Command{}, []string{}) Expect(outputErr).NotTo(HaveOccurred()) }) })