Skip to content

Commit

Permalink
context propagation: appsec, docker, kafka, k8s datasources
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Oct 25, 2024
1 parent d00a6a6 commit d8f15e9
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pkg/acquisition/modules/appsec/appsec.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.
w.logger.Info("Shutting down Appsec server")
// xx let's clean up the appsec runners :)
appsec.AppsecRulesDetails = make(map[int]appsec.RulesDetails)
w.server.Shutdown(context.TODO())
w.server.Shutdown(ctx)
return nil
})
return nil
Expand Down
47 changes: 24 additions & 23 deletions pkg/acquisition/modules/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,9 @@ func (d *DockerSource) SupportedModes() []string {

// OneShotAcquisition reads a set of file and returns when done
func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error {
ctx := context.TODO()
d.logger.Debug("In oneshot")
runningContainer, err := d.Client.ContainerList(context.Background(), dockerTypes.ContainerListOptions{})
runningContainer, err := d.Client.ContainerList(ctx, dockerTypes.ContainerListOptions{})
if err != nil {
return err
}
Expand All @@ -298,10 +299,10 @@ func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er
d.logger.Debugf("container with id %s is already being read from", container.ID)
continue
}
if containerConfig := d.EvalContainer(container); containerConfig != nil {
if containerConfig := d.EvalContainer(ctx, container); containerConfig != nil {
d.logger.Infof("reading logs from container %s", containerConfig.Name)
d.logger.Debugf("logs options: %+v", *d.containerLogsOptions)
dockerReader, err := d.Client.ContainerLogs(context.Background(), containerConfig.ID, *d.containerLogsOptions)
dockerReader, err := d.Client.ContainerLogs(ctx, containerConfig.ID, *d.containerLogsOptions)
if err != nil {
d.logger.Errorf("unable to read logs from container: %+v", err)
return err
Expand Down Expand Up @@ -372,26 +373,26 @@ func (d *DockerSource) CanRun() error {
return nil
}

func (d *DockerSource) getContainerTTY(containerId string) bool {
containerDetails, err := d.Client.ContainerInspect(context.Background(), containerId)
func (d *DockerSource) getContainerTTY(ctx context.Context, containerId string) bool {
containerDetails, err := d.Client.ContainerInspect(ctx, containerId)
if err != nil {
return false
}
return containerDetails.Config.Tty
}

func (d *DockerSource) getContainerLabels(containerId string) map[string]interface{} {
containerDetails, err := d.Client.ContainerInspect(context.Background(), containerId)
func (d *DockerSource) getContainerLabels(ctx context.Context, containerId string) map[string]interface{} {
containerDetails, err := d.Client.ContainerInspect(ctx, containerId)
if err != nil {
return map[string]interface{}{}
}
return parseLabels(containerDetails.Config.Labels)
}

func (d *DockerSource) EvalContainer(container dockerTypes.Container) *ContainerConfig {
func (d *DockerSource) EvalContainer(ctx context.Context, container dockerTypes.Container) *ContainerConfig {
for _, containerID := range d.Config.ContainerID {
if containerID == container.ID {
return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}
return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)}
}
}

Expand All @@ -401,27 +402,27 @@ func (d *DockerSource) EvalContainer(container dockerTypes.Container) *Container
name = name[1:]
}
if name == containerName {
return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}
return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)}
}
}
}

for _, cont := range d.compiledContainerID {
if matched := cont.MatchString(container.ID); matched {
return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}
return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)}
}
}

for _, cont := range d.compiledContainerName {
for _, name := range container.Names {
if matched := cont.MatchString(name); matched {
return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)}
return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)}
}
}
}

if d.Config.UseContainerLabels {
parsedLabels := d.getContainerLabels(container.ID)
parsedLabels := d.getContainerLabels(ctx, container.ID)
if len(parsedLabels) == 0 {
d.logger.Tracef("container has no 'crowdsec' labels set, ignoring container: %s", container.ID)
return nil
Expand Down Expand Up @@ -458,13 +459,13 @@ func (d *DockerSource) EvalContainer(container dockerTypes.Container) *Container
}
d.logger.Errorf("label %s is not a string", k)
}
return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: labels, Tty: d.getContainerTTY(container.ID)}
return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: labels, Tty: d.getContainerTTY(ctx, container.ID)}
}

return nil
}

func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteChan chan *ContainerConfig) error {
func (d *DockerSource) WatchContainer(ctx context.Context, monitChan chan *ContainerConfig, deleteChan chan *ContainerConfig) error {
ticker := time.NewTicker(d.CheckIntervalDuration)
d.logger.Infof("Container watcher started, interval: %s", d.CheckIntervalDuration.String())
for {
Expand All @@ -475,7 +476,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha
case <-ticker.C:
// to track for garbage collection
runningContainersID := make(map[string]bool)
runningContainer, err := d.Client.ContainerList(context.Background(), dockerTypes.ContainerListOptions{})
runningContainer, err := d.Client.ContainerList(ctx, dockerTypes.ContainerListOptions{})
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "cannot connect to the docker daemon at") {
for idx, container := range d.runningContainerState {
Expand All @@ -501,7 +502,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha
if _, ok := d.runningContainerState[container.ID]; ok {
continue
}
if containerConfig := d.EvalContainer(container); containerConfig != nil {
if containerConfig := d.EvalContainer(ctx, container); containerConfig != nil {
monitChan <- containerConfig
}
}
Expand All @@ -524,10 +525,10 @@ func (d *DockerSource) StreamingAcquisition(ctx context.Context, out chan types.
deleteChan := make(chan *ContainerConfig)
d.logger.Infof("Starting docker acquisition")
t.Go(func() error {
return d.DockerManager(monitChan, deleteChan, out)
return d.DockerManager(ctx, monitChan, deleteChan, out)
})

return d.WatchContainer(monitChan, deleteChan)
return d.WatchContainer(ctx, monitChan, deleteChan)
}

func (d *DockerSource) Dump() interface{} {
Expand All @@ -541,9 +542,9 @@ func ReadTailScanner(scanner *bufio.Scanner, out chan string, t *tomb.Tomb) erro
return scanner.Err()
}

func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types.Event, deleteChan chan *ContainerConfig) error {
func (d *DockerSource) TailDocker(ctx context.Context, container *ContainerConfig, outChan chan types.Event, deleteChan chan *ContainerConfig) error {
container.logger.Infof("start tail for container %s", container.Name)
dockerReader, err := d.Client.ContainerLogs(context.Background(), container.ID, *d.containerLogsOptions)
dockerReader, err := d.Client.ContainerLogs(ctx, container.ID, *d.containerLogsOptions)
if err != nil {
container.logger.Errorf("unable to read logs from container: %+v", err)
return err
Expand Down Expand Up @@ -601,7 +602,7 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types
}
}

func (d *DockerSource) DockerManager(in chan *ContainerConfig, deleteChan chan *ContainerConfig, outChan chan types.Event) error {
func (d *DockerSource) DockerManager(ctx context.Context, in chan *ContainerConfig, deleteChan chan *ContainerConfig, outChan chan types.Event) error {
d.logger.Info("DockerSource Manager started")
for {
select {
Expand All @@ -610,7 +611,7 @@ func (d *DockerSource) DockerManager(in chan *ContainerConfig, deleteChan chan *
newContainer.t = &tomb.Tomb{}
newContainer.logger = d.logger.WithField("container_name", newContainer.Name)
newContainer.t.Go(func() error {
return d.TailDocker(newContainer, outChan, deleteChan)
return d.TailDocker(ctx, newContainer, outChan, deleteChan)
})
d.runningContainerState[newContainer.ID] = newContainer
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/acquisition/modules/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ func (k *KafkaSource) Dump() interface{} {
return k
}

func (k *KafkaSource) ReadMessage(out chan types.Event) error {
func (k *KafkaSource) ReadMessage(ctx context.Context, out chan types.Event) error {
// Start processing from latest Offset
k.Reader.SetOffsetAt(context.Background(), time.Now())
k.Reader.SetOffsetAt(ctx, time.Now())
for {
k.logger.Tracef("reading message from topic '%s'", k.Config.Topic)
m, err := k.Reader.ReadMessage(context.Background())
m, err := k.Reader.ReadMessage(ctx)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
Expand Down Expand Up @@ -184,10 +184,10 @@ func (k *KafkaSource) ReadMessage(out chan types.Event) error {
}
}

func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error {
func (k *KafkaSource) RunReader(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
k.logger.Debugf("starting %s datasource reader goroutine with configuration %+v", dataSourceName, k.Config)
t.Go(func() error {
return k.ReadMessage(out)
return k.ReadMessage(ctx, out)
})
//nolint //fp
for {
Expand All @@ -207,7 +207,7 @@ func (k *KafkaSource) StreamingAcquisition(ctx context.Context, out chan types.E

t.Go(func() error {
defer trace.CatchPanic("crowdsec/acquis/kafka/live")
return k.RunReader(out, t)
return k.RunReader(ctx, out, t)
})

return nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/acquisition/modules/kubernetesaudit/k8s_audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out c
})
<-t.Dying()
ka.logger.Infof("Stopping k8s-audit server on %s:%d%s", ka.config.ListenAddr, ka.config.ListenPort, ka.config.WebhookPath)
ka.server.Shutdown(context.TODO())
ka.server.Shutdown(ctx)
return nil
})
return nil
Expand Down

0 comments on commit d8f15e9

Please sign in to comment.