Skip to content

Commit

Permalink
context propagation: lapi register, capi register
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 12, 2024
1 parent 4f04b70 commit 6fa0b5c
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 16 deletions.
8 changes: 4 additions & 4 deletions cmd/crowdsec-cli/clicapi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (cli *cliCapi) NewCommand() *cobra.Command {
return cmd
}

func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error {
func (cli *cliCapi) register(ctx context.Context, capiUserPrefix string, outputFile string) error {
cfg := cli.cfg()

capiUser, err := idgen.GenerateMachineID(capiUserPrefix)
Expand All @@ -73,7 +73,7 @@ func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error {
return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err)
}

_, err = apiclient.RegisterClient(&apiclient.Config{
_, err = apiclient.RegisterClient(ctx, &apiclient.Config{
MachineID: capiUser,
Password: password,
URL: apiurl,
Expand Down Expand Up @@ -134,8 +134,8 @@ func (cli *cliCapi) newRegisterCmd() *cobra.Command {
Short: "Register to Central API (CAPI)",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.register(capiUserPrefix, outputFile)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.register(cmd.Context(), capiUserPrefix, outputFile)
},
}

Expand Down
8 changes: 4 additions & 4 deletions cmd/crowdsec-cli/clilapi/lapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (cli *cliLapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) e
return nil
}

func (cli *cliLapi) register(apiURL string, outputFile string, machine string, token string) error {
func (cli *cliLapi) register(ctx context.Context, apiURL string, outputFile string, machine string, token string) error {
var err error

lapiUser := machine
Expand All @@ -114,7 +114,7 @@ func (cli *cliLapi) register(apiURL string, outputFile string, machine string, t
return fmt.Errorf("parsing api url: %w", err)
}

_, err = apiclient.RegisterClient(&apiclient.Config{
_, err = apiclient.RegisterClient(ctx, &apiclient.Config{
MachineID: lapiUser,
Password: password,
RegistrationToken: token,
Expand Down Expand Up @@ -223,8 +223,8 @@ func (cli *cliLapi) newRegisterCmd() *cobra.Command {
Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`,
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.register(apiURL, outputFile, machine, token)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.register(cmd.Context(), apiURL, outputFile, machine, token)
},
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/apiclient/auth_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ func TestWatcherRegister(t *testing.T) {
VersionPrefix: "v1",
}

client, err := RegisterClient(&clientconfig, &http.Client{})
ctx := context.Background()

client, err := RegisterClient(ctx, &clientconfig, &http.Client{})
require.NoError(t, err)

log.Printf("->%T", client)
Expand All @@ -102,7 +104,7 @@ func TestWatcherRegister(t *testing.T) {
for _, errorCodeToTest := range errorCodesToTest {
clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)

client, err = RegisterClient(&clientconfig, &http.Client{})
client, err = RegisterClient(ctx, &clientconfig, &http.Client{})
require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest)
require.Error(t, err, "error expected for the response code %d", errorCodeToTest)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt
return c, nil
}

func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
func RegisterClient(ctx context.Context, config *Config, client *http.Client) (*ApiClient, error) {
transport, baseURL := createTransport(config.URL)

if client == nil {
Expand Down Expand Up @@ -199,7 +199,7 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
c.Alerts = (*AlertsService)(&c.common)
c.Auth = (*AuthService)(&c.common)

resp, err := c.Auth.RegisterWatcher(context.Background(), models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken})
resp, err := c.Auth.RegisterWatcher(ctx, models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken})
if err != nil {
/*if we have http status, return it*/
if resp != nil && resp.Response != nil {
Expand Down
16 changes: 12 additions & 4 deletions pkg/apiclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ func TestNewClientRegisterKO(t *testing.T) {
apiURL, err := url.Parse("http://127.0.0.1:4242/")
require.NoError(t, err)

_, err = RegisterClient(&Config{
ctx := context.Background()

_, err = RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down Expand Up @@ -272,7 +274,9 @@ func TestNewClientRegisterOK(t *testing.T) {
apiURL, err := url.Parse(urlx + "/")
require.NoError(t, err)

client, err := RegisterClient(&Config{
ctx := context.Background()

client, err := RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down Expand Up @@ -304,7 +308,9 @@ func TestNewClientRegisterOK_UnixSocket(t *testing.T) {
t.Fatalf("parsing api url: %s", apiURL)
}

client, err := RegisterClient(&Config{
ctx := context.Background()

client, err := RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down Expand Up @@ -333,7 +339,9 @@ func TestNewClientBadAnswer(t *testing.T) {
apiURL, err := url.Parse(urlx + "/")
require.NoError(t, err)

_, err = RegisterClient(&Config{
ctx := context.Background()

_, err = RegisterClient(ctx, &Config{
MachineID: "test_login",
Password: "test_password",
URL: apiURL,
Expand Down

0 comments on commit 6fa0b5c

Please sign in to comment.