diff --git a/v2/pkg/stratus/runner/runner.go b/v2/pkg/stratus/runner/runner.go index ecaaa2b28..7c2d3d6c0 100644 --- a/v2/pkg/stratus/runner/runner.go +++ b/v2/pkg/stratus/runner/runner.go @@ -14,7 +14,7 @@ import ( const StratusRunnerForce = true const StratusRunnerNoForce = false -type Runner struct { +type runnerImpl struct { Technique *stratus.AttackTechnique TechniqueState stratus.AttackTechniqueState TerraformDir string @@ -25,10 +25,21 @@ type Runner struct { UniqueCorrelationID uuid.UUID } +type Runner interface { + WarmUp() (map[string]string, error) + Detonate() error + Revert() error + CleanUp() error + GetState() stratus.AttackTechniqueState + GetUniqueExecutionId() string +} + +var _ Runner = &runnerImpl{} + func NewRunner(technique *stratus.AttackTechnique, force bool) Runner { stateManager := state.NewFileSystemStateManager(technique) uuid := uuid.New() - runner := Runner{ + runner := &runnerImpl{ Technique: technique, ShouldForce: force, StateManager: stateManager, @@ -42,7 +53,7 @@ func NewRunner(technique *stratus.AttackTechnique, force bool) Runner { return runner } -func (m *Runner) initialize() { +func (m *runnerImpl) initialize() { m.TerraformDir = filepath.Join(m.StateManager.GetRootDirectory(), m.Technique.ID) m.TechniqueState = m.StateManager.GetTechniqueState() if m.TechniqueState == "" { @@ -51,7 +62,7 @@ func (m *Runner) initialize() { m.ProviderFactory = stratus.CloudProvidersImpl{UniqueCorrelationID: m.UniqueCorrelationID} } -func (m *Runner) WarmUp() (map[string]string, error) { +func (m *runnerImpl) WarmUp() (map[string]string, error) { // No prerequisites to spin-up if m.Technique.PrerequisitesTerraformCode == nil { return map[string]string{}, nil @@ -100,7 +111,7 @@ func (m *Runner) WarmUp() (map[string]string, error) { return outputs, err } -func (m *Runner) Detonate() error { +func (m *runnerImpl) Detonate() error { willWarmUp := true var err error var outputs map[string]string @@ -137,7 +148,7 @@ func (m *Runner) Detonate() error { return nil } -func (m *Runner) Revert() error { +func (m *runnerImpl) Revert() error { if m.GetState() != stratus.AttackTechniqueStatusDetonated && !m.ShouldForce { return errors.New(m.Technique.ID + " is not in DETONATED state and should not need to be reverted, use --force to force") } @@ -161,7 +172,7 @@ func (m *Runner) Revert() error { return nil } -func (m *Runner) CleanUp() error { +func (m *runnerImpl) CleanUp() error { // Has the technique already been cleaned up? if m.TechniqueState == stratus.AttackTechniqueStatusCold && !m.ShouldForce { return errors.New(m.Technique.ID + " is already COLD and should already be clean, use --force to force cleanup") @@ -201,11 +212,11 @@ func (m *Runner) CleanUp() error { return nil } -func (m *Runner) GetState() stratus.AttackTechniqueState { +func (m *runnerImpl) GetState() stratus.AttackTechniqueState { return m.TechniqueState } -func (m *Runner) setState(state stratus.AttackTechniqueState) { +func (m *runnerImpl) setState(state stratus.AttackTechniqueState) { err := m.StateManager.SetTechniqueState(state) if err != nil { log.Println("Warning: unable to set technique state: " + err.Error()) @@ -214,7 +225,7 @@ func (m *Runner) setState(state stratus.AttackTechniqueState) { } // GetUniqueExecutionId returns an unique execution ID, unique for each runner instance -func (m *Runner) GetUniqueExecutionId() string { +func (m *runnerImpl) GetUniqueExecutionId() string { return m.UniqueCorrelationID.String() } diff --git a/v2/pkg/stratus/runner/runner_test.go b/v2/pkg/stratus/runner/runner_test.go index ee4ba0fc5..91c2ba0a3 100644 --- a/v2/pkg/stratus/runner/runner_test.go +++ b/v2/pkg/stratus/runner/runner_test.go @@ -119,7 +119,7 @@ func TestRunnerWarmUp(t *testing.T) { state.On("WriteTerraformOutputs", mock.Anything).Return(nil) state.On("SetTechniqueState", mock.Anything).Return(nil) - runner := Runner{ + runner := runnerImpl{ Technique: scenario[i].Technique, ShouldForce: scenario[i].ShouldForce, TerraformManager: terraform, @@ -211,7 +211,7 @@ func TestRunnerDetonate(t *testing.T) { state.On("SetTechniqueState", mock.Anything).Return(nil) var wasDetonated = false - runner := Runner{ + runner := runnerImpl{ Technique: &stratus.AttackTechnique{ ID: "sample-technique", Detonate: func(map[string]string, stratus.CloudProviders) error { @@ -300,7 +300,7 @@ func TestRunnerRevert(t *testing.T) { state.On("SetTechniqueState", mock.Anything).Return(nil) var wasReverted = false - runner := Runner{ + runner := runnerImpl{ Technique: &stratus.AttackTechnique{ ID: "foo", Detonate: func(map[string]string, stratus.CloudProviders) error { return nil }, @@ -446,7 +446,7 @@ func TestRunnerCleanup(t *testing.T) { return errors.New("nope") } } - runner := Runner{ + runner := runnerImpl{ Technique: scenario[i].Technique, ShouldForce: scenario[i].ShouldForce, TerraformManager: terraform,