diff --git a/internal/solver/bench_test.go b/internal/solver/bench_test.go index b82d746..ebce225 100644 --- a/internal/solver/bench_test.go +++ b/internal/solver/bench_test.go @@ -68,20 +68,11 @@ var BenchmarkInput = func() []deppy.Variable { func BenchmarkSolve(b *testing.B) { for i := 0; i < b.N; i++ { - s, err := NewSolver(WithInput(BenchmarkInput)) + s, err := New() if err != nil { b.Fatalf("failed to initialize solver: %s", err) } - _, err = s.Solve() - if err != nil { - b.Fatalf("failed to initialize solver: %s", err) - } - } -} - -func BenchmarkNewInput(b *testing.B) { - for i := 0; i < b.N; i++ { - _, err := NewSolver(WithInput(BenchmarkInput)) + _, err = s.Solve(BenchmarkInput) if err != nil { b.Fatalf("failed to initialize solver: %s", err) } diff --git a/internal/solver/search.go b/internal/solver/search.go index f374465..ba8e1c9 100644 --- a/internal/solver/search.go +++ b/internal/solver/search.go @@ -1,8 +1,6 @@ package solver import ( - "context" - "github.com/go-air/gini/inter" "github.com/go-air/gini/z" @@ -157,7 +155,7 @@ func (h *search) Lits() []z.Lit { return result } -func (h *search) Do(_ context.Context, anchors []z.Lit) (int, []z.Lit, map[z.Lit]struct{}) { +func (h *search) Do(anchors []z.Lit) (int, []z.Lit, map[z.Lit]struct{}) { for _, m := range anchors { h.PushChoiceBack(choice{candidates: []z.Lit{m}}) } diff --git a/internal/solver/search_test.go b/internal/solver/search_test.go index 48dfef5..5fda92f 100644 --- a/internal/solver/search_test.go +++ b/internal/solver/search_test.go @@ -3,7 +3,6 @@ package solver import ( - "context" "testing" "github.com/go-air/gini/inter" @@ -96,7 +95,7 @@ func TestSearch(t *testing.T) { anchors = append(anchors, h.lits.LitOf(id)) } - result, ms, _ := h.Do(context.Background(), anchors) + result, ms, _ := h.Do(anchors) assert.Equal(tt.Result, result) var ids []deppy.Identifier diff --git a/internal/solver/solve.go b/internal/solver/solve.go index d34aaf1..c7f6267 100644 --- a/internal/solver/solve.go +++ b/internal/solver/solve.go @@ -1,7 +1,6 @@ package solver import ( - "context" "errors" "fmt" @@ -12,17 +11,8 @@ import ( "github.com/operator-framework/deppy/pkg/deppy" ) -var ErrIncomplete = errors.New("cancelled before a solution could be found") - -type Solver interface { - Solve() ([]deppy.Variable, error) -} - -type solver struct { - g inter.S - litMap *litMapping +type Solver struct { tracer deppy.Tracer - buffer []z.Lit } const ( @@ -33,81 +23,87 @@ const ( // Solve takes a slice containing all Variables and returns a slice // containing only those Variables that were selected for -// installation. If no solution is possible, or if the provided -// Context times out or is cancelled, an error is returned. -func (s *solver) Solve() ([]deppy.Variable, error) { - result, err := s.solve() +// installation. If no solution is possible an error is returned. +func (s *Solver) Solve(input []deppy.Variable) ([]deppy.Variable, error) { + giniSolver := gini.New() + litMap, err := newLitMapping(input) + if err != nil { + return nil, err + } + + result, err := s.solve(giniSolver, litMap) // This likely indicates a bug, so discard whatever // return values were produced. - if derr := s.litMap.Error(); derr != nil { + if derr := litMap.Error(); derr != nil { return nil, derr } return result, err } -func (s *solver) solve() ([]deppy.Variable, error) { +func (s *Solver) solve(giniSolver inter.S, litMap *litMapping) ([]deppy.Variable, error) { // teach all constraints to the solver - s.litMap.AddConstraints(s.g) + litMap.AddConstraints(giniSolver) // collect literals of all mandatory variables to assume as a baseline - anchors := s.litMap.AnchorIdentifiers() + anchors := litMap.AnchorIdentifiers() assumptions := make([]z.Lit, len(anchors)) for i := range anchors { - assumptions[i] = s.litMap.LitOf(anchors[i]) + assumptions[i] = litMap.LitOf(anchors[i]) } // assume that all constraints hold - s.litMap.AssumeConstraints(s.g) - s.g.Assume(assumptions...) + litMap.AssumeConstraints(giniSolver) + giniSolver.Assume(assumptions...) + var buffer []z.Lit var aset map[z.Lit]struct{} // push a new test scope with the baseline assumptions, to prevent them from being cleared during search - outcome, _ := s.g.Test(nil) + outcome, _ := giniSolver.Test(nil) if outcome != satisfiable && outcome != unsatisfiable { // searcher for solutions in input Order, so that preferences - // can be taken into acount (i.e. prefer one catalog to another) - outcome, assumptions, aset = (&search{s: s.g, lits: s.litMap, tracer: s.tracer}).Do(context.Background(), assumptions) + // can be taken into account (i.e. prefer one catalog to another) + outcome, assumptions, aset = (&search{s: giniSolver, lits: litMap, tracer: s.tracer}).Do(assumptions) } switch outcome { case satisfiable: - s.buffer = s.litMap.Lits(s.buffer) + buffer = litMap.Lits(buffer) var extras, excluded []z.Lit - for _, m := range s.buffer { + for _, m := range buffer { if _, ok := aset[m]; ok { continue } - if !s.g.Value(m) { + if !giniSolver.Value(m) { excluded = append(excluded, m.Not()) continue } extras = append(extras, m) } - s.g.Untest() - cs := s.litMap.CardinalityConstrainer(s.g, extras) - s.g.Assume(assumptions...) - s.g.Assume(excluded...) - s.litMap.AssumeConstraints(s.g) - _, s.buffer = s.g.Test(s.buffer) + giniSolver.Untest() + cs := litMap.CardinalityConstrainer(giniSolver, extras) + giniSolver.Assume(assumptions...) + giniSolver.Assume(excluded...) + litMap.AssumeConstraints(giniSolver) + giniSolver.Test(nil) for w := 0; w <= cs.N(); w++ { - s.g.Assume(cs.Leq(w)) - if s.g.Solve() == satisfiable { - return s.litMap.Variables(s.g), nil + giniSolver.Assume(cs.Leq(w)) + if giniSolver.Solve() == satisfiable { + return litMap.Variables(giniSolver), nil } } // Something is wrong if we can't find a model anymore // after optimizing for cardinality. return nil, fmt.Errorf("unexpected internal error") case unsatisfiable: - return nil, deppy.NotSatisfiable(s.litMap.Conflicts(s.g)) + return nil, deppy.NotSatisfiable(litMap.Conflicts(giniSolver)) } - return nil, ErrIncomplete + return nil, errors.New("cancelled before a solution could be found") } -func NewSolver(options ...Option) (Solver, error) { - s := solver{g: gini.New()} +func New(options ...Option) (*Solver, error) { + s := Solver{} for _, option := range append(options, defaults...) { if err := option(&s); err != nil { return nil, err @@ -116,33 +112,17 @@ func NewSolver(options ...Option) (Solver, error) { return &s, nil } -type Option func(s *solver) error - -func WithInput(input []deppy.Variable) Option { - return func(s *solver) error { - var err error - s.litMap, err = newLitMapping(input) - return err - } -} +type Option func(s *Solver) error func WithTracer(t deppy.Tracer) Option { - return func(s *solver) error { + return func(s *Solver) error { s.tracer = t return nil } } var defaults = []Option{ - func(s *solver) error { - if s.litMap == nil { - var err error - s.litMap, err = newLitMapping(nil) - return err - } - return nil - }, - func(s *solver) error { + func(s *Solver) error { if s.tracer == nil { s.tracer = DefaultTracer{} } diff --git a/internal/solver/solve_test.go b/internal/solver/solve_test.go index 2fd4f61..ac9e0f4 100644 --- a/internal/solver/solve_test.go +++ b/internal/solver/solve_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/operator-framework/deppy/pkg/deppy/constraint" @@ -290,12 +291,12 @@ func TestSolve(t *testing.T) { assert := assert.New(t) var traces bytes.Buffer - s, err := NewSolver(WithInput(tt.Variables), WithTracer(LoggingTracer{Writer: &traces})) + s, err := New(WithTracer(LoggingTracer{Writer: &traces})) if err != nil { t.Fatalf("failed to initialize solver: %s", err) } - installed, err := s.Solve() + installed, err := s.Solve(tt.Variables) var ids []deppy.Identifier for _, variable := range installed { @@ -312,9 +313,12 @@ func TestSolve(t *testing.T) { } func TestDuplicateIdentifier(t *testing.T) { - _, err := NewSolver(WithInput([]deppy.Variable{ + s, err := New() + require.NoError(t, err) + + _, err = s.Solve([]deppy.Variable{ variable("a"), variable("a"), - })) + }) assert.Equal(t, DuplicateIdentifier("a"), err) } diff --git a/pkg/deppy/solver/solver.go b/pkg/deppy/solver/solver.go index b2fb8d6..785254c 100644 --- a/pkg/deppy/solver/solver.go +++ b/pkg/deppy/solver/solver.go @@ -46,12 +46,12 @@ func NewDeppySolver() *DeppySolver { } func (d DeppySolver) Solve(vars []deppy.Variable) (*Solution, error) { - satSolver, err := solver.NewSolver(solver.WithInput(vars)) + satSolver, err := solver.New() if err != nil { return nil, err } - selection, err := satSolver.Solve() + selection, err := satSolver.Solve(vars) if err != nil && !errors.As(err, &deppy.NotSatisfiable{}) { return nil, err }