Skip to content

Commit

Permalink
context propagation: cscli {capi,lapi,papi} (#3228)
Browse files Browse the repository at this point in the history
* context propagation: lapi status, capi status, papi status

* context propagation: lapi register, capi register

* lint
  • Loading branch information
mmetc authored Sep 12, 2024
1 parent 6810b41 commit 8a74fae
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 48 deletions.
20 changes: 10 additions & 10 deletions cmd/crowdsec-cli/clicapi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (cli *cliCapi) NewCommand() *cobra.Command {
return cmd
}

func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error {
func (cli *cliCapi) register(ctx context.Context, capiUserPrefix string, outputFile string) error {
cfg := cli.cfg()

capiUser, err := idgen.GenerateMachineID(capiUserPrefix)
Expand All @@ -73,7 +73,7 @@ func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error {
return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err)
}

_, err = apiclient.RegisterClient(&apiclient.Config{
_, err = apiclient.RegisterClient(ctx, &apiclient.Config{
MachineID: capiUser,
Password: password,
URL: apiurl,
Expand Down Expand Up @@ -134,8 +134,8 @@ func (cli *cliCapi) newRegisterCmd() *cobra.Command {
Short: "Register to Central API (CAPI)",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.register(capiUserPrefix, outputFile)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.register(cmd.Context(), capiUserPrefix, outputFile)
},
}

Expand All @@ -148,7 +148,7 @@ func (cli *cliCapi) newRegisterCmd() *cobra.Command {
}

// queryCAPIStatus checks if the Central API is reachable, and if the credentials are correct. It then checks if the instance is enrolle in the console.
func queryCAPIStatus(hub *cwhub.Hub, credURL string, login string, password string) (bool, bool, error) {
func queryCAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, bool, error) {
apiURL, err := url.Parse(credURL)
if err != nil {
return false, false, err
Expand Down Expand Up @@ -186,7 +186,7 @@ func queryCAPIStatus(hub *cwhub.Hub, credURL string, login string, password stri
Scenarios: itemsForAPI,
}

authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), t)
authResp, _, err := client.Auth.AuthenticateWatcher(ctx, t)
if err != nil {
return false, false, err
}
Expand All @@ -200,7 +200,7 @@ func queryCAPIStatus(hub *cwhub.Hub, credURL string, login string, password stri
return true, false, nil
}

func (cli *cliCapi) Status(out io.Writer, hub *cwhub.Hub) error {
func (cli *cliCapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error {
cfg := cli.cfg()

if err := require.CAPIRegistered(cfg); err != nil {
Expand All @@ -212,7 +212,7 @@ func (cli *cliCapi) Status(out io.Writer, hub *cwhub.Hub) error {
fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Server.OnlineClient.CredentialsFilePath)
fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL)

auth, enrolled, err := queryCAPIStatus(hub, cred.URL, cred.Login, cred.Password)
auth, enrolled, err := queryCAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password)
if err != nil {
return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err)
}
Expand All @@ -234,13 +234,13 @@ func (cli *cliCapi) newStatusCmd() *cobra.Command {
Short: "Check status with the Central API (CAPI)",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
hub, err := require.Hub(cli.cfg(), nil, nil)
if err != nil {
return err
}

return cli.Status(color.Output, hub)
return cli.Status(cmd.Context(), color.Output, hub)
},
}

Expand Down
30 changes: 15 additions & 15 deletions cmd/crowdsec-cli/clilapi/lapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func New(cfg configGetter) *cliLapi {
}

// queryLAPIStatus checks if the Local API is reachable, and if the credentials are correct.
func queryLAPIStatus(hub *cwhub.Hub, credURL string, login string, password string) (bool, error) {
func queryLAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, error) {
apiURL, err := url.Parse(credURL)
if err != nil {
return false, err
Expand Down Expand Up @@ -76,15 +76,15 @@ func queryLAPIStatus(hub *cwhub.Hub, credURL string, login string, password stri
return true, nil
}

func (cli *cliLapi) Status(out io.Writer, hub *cwhub.Hub) error {
func (cli *cliLapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error {
cfg := cli.cfg()

cred := cfg.API.Client.Credentials

fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Client.CredentialsFilePath)
fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL)

_, err := queryLAPIStatus(hub, cred.URL, cred.Login, cred.Password)
_, err := queryLAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password)
if err != nil {
return fmt.Errorf("failed to authenticate to Local API (LAPI): %w", err)
}
Expand All @@ -94,7 +94,7 @@ func (cli *cliLapi) Status(out io.Writer, hub *cwhub.Hub) error {
return nil
}

func (cli *cliLapi) register(apiURL string, outputFile string, machine string, token string) error {
func (cli *cliLapi) register(ctx context.Context, apiURL string, outputFile string, machine string, token string) error {
var err error

lapiUser := machine
Expand All @@ -114,7 +114,7 @@ func (cli *cliLapi) register(apiURL string, outputFile string, machine string, t
return fmt.Errorf("parsing api url: %w", err)
}

_, err = apiclient.RegisterClient(&apiclient.Config{
_, err = apiclient.RegisterClient(ctx, &apiclient.Config{
MachineID: lapiUser,
Password: password,
RegistrationToken: token,
Expand Down Expand Up @@ -195,13 +195,13 @@ func (cli *cliLapi) newStatusCmd() *cobra.Command {
Short: "Check authentication to Local API (LAPI)",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
hub, err := require.Hub(cli.cfg(), nil, nil)
if err != nil {
return err
}

return cli.Status(color.Output, hub)
return cli.Status(cmd.Context(), color.Output, hub)
},
}

Expand All @@ -223,8 +223,8 @@ func (cli *cliLapi) newRegisterCmd() *cobra.Command {
Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`,
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.register(apiURL, outputFile, machine, token)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.register(cmd.Context(), apiURL, outputFile, machine, token)
},
}

Expand Down Expand Up @@ -513,14 +513,14 @@ func detectStaticField(grokStatics []parser.ExtraField) []string {

for _, static := range grokStatics {
if static.Parsed != "" {
fieldName := fmt.Sprintf("evt.Parsed.%s", static.Parsed)
fieldName := "evt.Parsed." + static.Parsed
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
}

if static.Meta != "" {
fieldName := fmt.Sprintf("evt.Meta.%s", static.Meta)
fieldName := "evt.Meta." + static.Meta
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
Expand All @@ -546,7 +546,7 @@ func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string {

if node.Grok.RunTimeRegexp != nil {
for _, capturedField := range node.Grok.RunTimeRegexp.Names() {
fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField)
fieldName := "evt.Parsed." + capturedField
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
Expand All @@ -558,7 +558,7 @@ func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string {
// ignore error (parser does not exist?)
if err == nil {
for _, capturedField := range grokCompiled.Names() {
fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField)
fieldName := "evt.Parsed." + capturedField
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
Expand Down Expand Up @@ -593,7 +593,7 @@ func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string {
for _, subnode := range node.LeavesNodes {
if subnode.Grok.RunTimeRegexp != nil {
for _, capturedField := range subnode.Grok.RunTimeRegexp.Names() {
fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField)
fieldName := "evt.Parsed." + capturedField
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
Expand All @@ -605,7 +605,7 @@ func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string {
if err == nil {
// ignore error (parser does not exist?)
for _, capturedField := range grokCompiled.Names() {
fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField)
fieldName := "evt.Parsed." + capturedField
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
Expand Down
11 changes: 7 additions & 4 deletions cmd/crowdsec-cli/clipapi/papi.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clipapi

import (
"context"
"fmt"
"io"
"time"
Expand Down Expand Up @@ -55,7 +56,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command {
return cmd
}

func (cli *cliPapi) Status(out io.Writer, db *database.Client) error {
func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error {
cfg := cli.cfg()

apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
Expand All @@ -68,7 +69,7 @@ func (cli *cliPapi) Status(out io.Writer, db *database.Client) error {
return fmt.Errorf("unable to initialize PAPI client: %w", err)
}

perms, err := papi.GetPermissions()
perms, err := papi.GetPermissions(ctx)
if err != nil {
return fmt.Errorf("unable to get PAPI permissions: %w", err)
}
Expand Down Expand Up @@ -103,12 +104,14 @@ func (cli *cliPapi) newStatusCmd() *cobra.Command {
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error {
cfg := cli.cfg()
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
ctx := cmd.Context()

db, err := require.DBClient(ctx, cfg.DbConfig)
if err != nil {
return err
}

return cli.Status(color.Output, db)
return cli.Status(ctx, color.Output, db)
},
}

Expand Down
18 changes: 9 additions & 9 deletions cmd/crowdsec-cli/clisupport/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,13 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error {
return nil
}

func (cli *cliSupport) dumpLAPIStatus(zw *zip.Writer, hub *cwhub.Hub) error {
func (cli *cliSupport) dumpLAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error {
log.Info("Collecting LAPI status")

out := new(bytes.Buffer)
cl := clilapi.New(cli.cfg)

err := cl.Status(out, hub)
err := cl.Status(ctx, out, hub)
if err != nil {
fmt.Fprintf(out, "%s\n", err)
}
Expand All @@ -249,13 +249,13 @@ func (cli *cliSupport) dumpLAPIStatus(zw *zip.Writer, hub *cwhub.Hub) error {
return nil
}

func (cli *cliSupport) dumpCAPIStatus(zw *zip.Writer, hub *cwhub.Hub) error {
func (cli *cliSupport) dumpCAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error {
log.Info("Collecting CAPI status")

out := new(bytes.Buffer)
cc := clicapi.New(cli.cfg)

err := cc.Status(out, hub)
err := cc.Status(ctx, out, hub)
if err != nil {
fmt.Fprintf(out, "%s\n", err)
}
Expand All @@ -267,13 +267,13 @@ func (cli *cliSupport) dumpCAPIStatus(zw *zip.Writer, hub *cwhub.Hub) error {
return nil
}

func (cli *cliSupport) dumpPAPIStatus(zw *zip.Writer, db *database.Client) error {
func (cli *cliSupport) dumpPAPIStatus(ctx context.Context, zw *zip.Writer, db *database.Client) error {
log.Info("Collecting PAPI status")

out := new(bytes.Buffer)
cp := clipapi.New(cli.cfg)

err := cp.Status(out, db)
err := cp.Status(ctx, out, db)
if err != nil {
fmt.Fprintf(out, "%s\n", err)
}
Expand Down Expand Up @@ -534,17 +534,17 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
}

if !skipCAPI {
if err = cli.dumpCAPIStatus(zipWriter, hub); err != nil {
if err = cli.dumpCAPIStatus(ctx, zipWriter, hub); err != nil {
log.Warnf("could not collect CAPI status: %s", err)
}

if err = cli.dumpPAPIStatus(zipWriter, db); err != nil {
if err = cli.dumpPAPIStatus(ctx, zipWriter, db); err != nil {
log.Warnf("could not collect PAPI status: %s", err)
}
}

if !skipLAPI {
if err = cli.dumpLAPIStatus(zipWriter, hub); err != nil {
if err = cli.dumpLAPIStatus(ctx, zipWriter, hub); err != nil {
log.Warnf("could not collect LAPI status: %s", err)
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/apiclient/auth_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ func TestWatcherRegister(t *testing.T) {
VersionPrefix: "v1",
}

client, err := RegisterClient(&clientconfig, &http.Client{})
ctx := context.Background()

client, err := RegisterClient(ctx, &clientconfig, &http.Client{})
require.NoError(t, err)

log.Printf("->%T", client)
Expand All @@ -102,7 +104,7 @@ func TestWatcherRegister(t *testing.T) {
for _, errorCodeToTest := range errorCodesToTest {
clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)

client, err = RegisterClient(&clientconfig, &http.Client{})
client, err = RegisterClient(ctx, &clientconfig, &http.Client{})
require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest)
require.Error(t, err, "error expected for the response code %d", errorCodeToTest)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt
return c, nil
}

func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
func RegisterClient(ctx context.Context, config *Config, client *http.Client) (*ApiClient, error) {
transport, baseURL := createTransport(config.URL)

if client == nil {
Expand Down Expand Up @@ -199,7 +199,7 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
c.Alerts = (*AlertsService)(&c.common)
c.Auth = (*AuthService)(&c.common)

resp, err := c.Auth.RegisterWatcher(context.Background(), models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken})
resp, err := c.Auth.RegisterWatcher(ctx, models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken})
if err != nil {
/*if we have http status, return it*/
if resp != nil && resp.Response != nil {
Expand Down
16 changes: 12 additions & 4 deletions pkg/apiclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ func TestNewClientRegisterKO(t *testing.T) {
apiURL, err := url.Parse("http://127.0.0.1:4242/")
require.NoError(t, err)

_, err = RegisterClient(&Config{
ctx := context.Background()

_, err = RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down Expand Up @@ -272,7 +274,9 @@ func TestNewClientRegisterOK(t *testing.T) {
apiURL, err := url.Parse(urlx + "/")
require.NoError(t, err)

client, err := RegisterClient(&Config{
ctx := context.Background()

client, err := RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down Expand Up @@ -304,7 +308,9 @@ func TestNewClientRegisterOK_UnixSocket(t *testing.T) {
t.Fatalf("parsing api url: %s", apiURL)
}

client, err := RegisterClient(&Config{
ctx := context.Background()

client, err := RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down Expand Up @@ -333,7 +339,9 @@ func TestNewClientBadAnswer(t *testing.T) {
apiURL, err := url.Parse(urlx + "/")
require.NoError(t, err)

_, err = RegisterClient(&Config{
ctx := context.Background()

_, err = RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down
Loading

0 comments on commit 8a74fae

Please sign in to comment.