diff --git a/LICENSE b/LICENSE index 7252472..5018ee3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,7 @@ BSD 3-Clause License -Copyright (c) 2024, conda +Copyright (c) 2012, Anaconda, Inc. +All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/conda_classic_solver/__init__.py b/conda_classic_solver/__init__.py index 6f13c8d..8acdb4d 100644 --- a/conda_classic_solver/__init__.py +++ b/conda_classic_solver/__init__.py @@ -1,8 +1,8 @@ -# Copyright (C) 2022 Anaconda, Inc +# Copyright (C) 2012 Anaconda, Inc # Copyright (C) 2023 conda # SPDX-License-Identifier: BSD-3-Clause """ The conda_classic_solver package """ -from .solver import ClassicSolver # noqa +from .solve import ClassicSolver # noqa diff --git a/conda_classic_solver/_logic.py b/conda_classic_solver/_logic.py new file mode 100644 index 0000000..35380bb --- /dev/null +++ b/conda_classic_solver/_logic.py @@ -0,0 +1,780 @@ +# Copyright (C) 2012 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause +import sys +from array import array +from itertools import combinations +from logging import DEBUG, getLogger + +from conda.common.constants import TRACE + +log = getLogger(__name__) + + +TRUE = sys.maxsize +FALSE = -TRUE + + +class _ClauseList: + """Storage for the CNF clauses, represented as a list of tuples of ints.""" + + def __init__(self): + self._clause_list = [] + # Methods append and extend are directly bound for performance reasons, + # to avoid call overhead and lookups. + self.append = self._clause_list.append + self.extend = self._clause_list.extend + + def get_clause_count(self): + """Return number of stored clauses.""" + return len(self._clause_list) + + def save_state(self): + """ + Get state information to be able to revert temporary additions of + supplementary clauses. _ClauseList: state is simply the number of clauses. + """ + return len(self._clause_list) + + def restore_state(self, saved_state): + """ + Restore state saved via `save_state`. + Removes clauses that were added after the state has been saved. + """ + len_clauses = saved_state + self._clause_list[len_clauses:] = [] + + def as_list(self): + """Return clauses as a list of tuples of ints.""" + return self._clause_list + + def as_array(self): + """Return clauses as a flat int array, each clause being terminated by 0.""" + clause_array = array("i") + for c in self._clause_list: + clause_array.extend(c) + clause_array.append(0) + return clause_array + + +class _ClauseArray: + """ + Storage for the CNF clauses, represented as a flat int array. + Each clause is terminated by int(0). + """ + + def __init__(self): + self._clause_array = array("i") + # Methods append and extend are directly bound for performance reasons, + # to avoid call overhead and lookups. + self._array_append = self._clause_array.append + self._array_extend = self._clause_array.extend + + def extend(self, clauses): + for clause in clauses: + self.append(clause) + + def append(self, clause): + self._array_extend(clause) + self._array_append(0) + + def get_clause_count(self): + """ + Return number of stored clauses. + This is an O(n) operation since we don't store the number of clauses + explicitly due to performance reasons (Python interpreter overhead in + self.append). + """ + return self._clause_array.count(0) + + def save_state(self): + """ + Get state information to be able to revert temporary additions of + supplementary clauses. 
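# Illustrative aside, not part of the patch: the same small CNF stored the two
# ways this module supports -- _ClauseList keeps a list of int tuples, while
# _ClauseArray flattens the clauses into one int array, terminating each
# clause with 0.
from array import array

clause_list = [(1, -2), (2, 3)]        # (x1 OR NOT x2) AND (x2 OR x3)
flat = array("i")
for clause in clause_list:
    flat.extend(clause)
    flat.append(0)                     # 0 marks the end of a clause
assert list(flat) == [1, -2, 0, 2, 3, 0]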
_ClauseArray: state is the length of the int + array, NOT number of clauses. + """ + return len(self._clause_array) + + def restore_state(self, saved_state): + """ + Restore state saved via `save_state`. + Removes clauses that were added after the state has been saved. + """ + len_clause_array = saved_state + self._clause_array[len_clause_array:] = array("i") + + def as_list(self): + """Return clauses as a list of tuples of ints.""" + clause = [] + for v in self._clause_array: + if v == 0: + yield tuple(clause) + clause.clear() + else: + clause.append(v) + + def as_array(self): + """Return clauses as a flat int array, each clause being terminated by 0.""" + return self._clause_array + + +class _SatSolver: + """Simple wrapper to call a SAT solver given a _ClauseList/_ClauseArray instance.""" + + def __init__(self, **run_kwargs): + self._run_kwargs = run_kwargs or {} + self._clauses = _ClauseList() + # Bind some methods of _clauses to reduce lookups and call overhead. + self.add_clause = self._clauses.append + self.add_clauses = self._clauses.extend + + def get_clause_count(self): + return self._clauses.get_clause_count() + + def as_list(self): + return self._clauses.as_list() + + def save_state(self): + return self._clauses.save_state() + + def restore_state(self, saved_state): + return self._clauses.restore_state(saved_state) + + def run(self, m, **kwargs): + run_kwargs = self._run_kwargs.copy() + run_kwargs.update(kwargs) + solver = self.setup(m, **run_kwargs) + sat_solution = self.invoke(solver) + solution = self.process_solution(sat_solution) + return solution + + def setup(self, m, **kwargs): + """Create a solver instance, add the clauses to it, and return it.""" + raise NotImplementedError() + + def invoke(self, solver): + """Start the actual SAT solving and return the calculated solution.""" + raise NotImplementedError() + + def process_solution(self, sat_solution): + """ + Process the solution returned by self.invoke. + Returns a list of satisfied variables or None if no solution is found. + """ + raise NotImplementedError() + + +class _PycoSatSolver(_SatSolver): + def setup(self, m, limit=0, **kwargs): + from pycosat import itersolve + + # NOTE: The iterative solving isn't actually used here, we just call + # itersolve to separate setup from the actual run. + return itersolve(self._clauses.as_list(), vars=m, prop_limit=limit) + # If we add support for passing the clauses as an integer stream to the + # solvers, we could also use self._clauses.as_array like this: + # return itersolve(self._clauses.as_array(), vars=m, prop_limit=limit) + + def invoke(self, iter_sol): + try: + sat_solution = next(iter_sol) + except StopIteration: + sat_solution = "UNSAT" + del iter_sol + return sat_solution + + def process_solution(self, sat_solution): + if sat_solution in ("UNSAT", "UNKNOWN"): + return None + return sat_solution + + +class _PyCryptoSatSolver(_SatSolver): + def setup(self, m, threads=1, **kwargs): + from pycryptosat import Solver + + solver = Solver(threads=threads) + solver.add_clauses(self._clauses.as_list()) + return solver + + def invoke(self, solver): + sat, sat_solution = solver.solve() + if not sat: + sat_solution = None + return sat_solution + + def process_solution(self, solution): + if not solution: + return None + # The first element of the solution is always None. 
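# Illustrative aside, not part of the patch: pycryptosat reports a model as a
# tuple indexed by variable number, with index 0 unused (None).  The
# comprehension just below turns such a tuple into the list of variables
# assigned True.
model = (None, True, False, True)      # made-up model for variables 1..3
true_vars = [i for i, b in enumerate(model) if b]
assert true_vars == [1, 3]             # variables 1 and 3 are True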
+ solution = [i for i, b in enumerate(solution) if b] + return solution + + +class _PySatSolver(_SatSolver): + def setup(self, m, **kwargs): + from pysat.solvers import Glucose4 + + solver = Glucose4() + solver.append_formula(self._clauses.as_list()) + return solver + + def invoke(self, solver): + if not solver.solve(): + sat_solution = None + else: + sat_solution = solver.get_model() + solver.delete() + return sat_solution + + def process_solution(self, sat_solution): + if sat_solution is None: + solution = None + else: + solution = sat_solution + return solution + + +_sat_solver_str_to_cls = { + "pycosat": _PycoSatSolver, + "pycryptosat": _PyCryptoSatSolver, + "pysat": _PySatSolver, +} + +_sat_solver_cls_to_str = {cls: string for string, cls in _sat_solver_str_to_cls.items()} + + +# Code that uses special cases (generates no clauses) is in ADTs/FEnv.h in +# minisatp. Code that generates clauses is in Hardware_clausify.cc (and are +# also described in the paper, "Translating Pseudo-Boolean Constraints into +# SAT," Eén and Sörensson). +class Clauses: + def __init__(self, m=0, sat_solver_str=_sat_solver_cls_to_str[_PycoSatSolver]): + self.unsat = False + self.m = m + + try: + sat_solver_cls = _sat_solver_str_to_cls[sat_solver_str] + except KeyError: + raise NotImplementedError(f"Unknown SAT solver: {sat_solver_str}") + self._sat_solver = sat_solver_cls() + + # Bind some methods of _sat_solver to reduce lookups and call overhead. + self.add_clause = self._sat_solver.add_clause + self.add_clauses = self._sat_solver.add_clauses + + def get_clause_count(self): + return self._sat_solver.get_clause_count() + + def as_list(self): + return self._sat_solver.as_list() + + def new_var(self): + m = self.m + 1 + self.m = m + return m + + def assign(self, vals): + if isinstance(vals, tuple): + x = self.new_var() + self.add_clauses((-x,) + y for y in vals[0]) + self.add_clauses((x,) + y for y in vals[1]) + return x + return vals + + def Combine(self, args, polarity): + if any(v == FALSE for v in args): + return FALSE + args = [v for v in args if v != TRUE] + nv = len(args) + if nv == 0: + return TRUE + if nv == 1: + return args[0] + if all(isinstance(v, tuple) for v in args): + return (sum((v[0] for v in args), []), sum((v[1] for v in args), [])) + else: + return self.All(map(self.assign, args), polarity) + + def Eval(self, func, args, polarity): + saved_state = self._sat_solver.save_state() + vals = func(*args, polarity=polarity) + # eval without assignment: + if isinstance(vals, tuple): + self.add_clauses(vals[0]) + self.add_clauses(vals[1]) + elif vals not in {TRUE, FALSE}: + self.add_clause((vals if polarity else -vals,)) + else: + self._sat_solver.restore_state(saved_state) + self.unsat = self.unsat or (vals == TRUE) != polarity + + def Prevent(self, func, *args): + self.Eval(func, args, polarity=False) + + def Require(self, func, *args): + self.Eval(func, args, polarity=True) + + def Not(self, x, polarity=None, add_new_clauses=False): + return -x + + def And(self, f, g, polarity, add_new_clauses=False): + if f == FALSE or g == FALSE: + return FALSE + if f == TRUE: + return g + if g == TRUE: + return f + if f == g: + return f + if f == -g: + return FALSE + if g < f: + f, g = g, f + if add_new_clauses: + # This is equivalent to running self.assign(pval, nval) on + # the (pval, nval) tuple we return below. Duplicating the code here + # is an important performance tweak to avoid the costly generator + # expressions and tuple additions in self.assign. 
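# Illustrative aside, not part of the patch: the clause pattern the branch
# below emits for a fresh variable x standing for (f AND g), i.e. x <-> (f & g).
# With polarity=None both directions are encoded; literal numbers are made up.
f, g, x = 1, 2, 3
forward = [(-x, f), (-x, g)]           # x -> f  and  x -> g
backward = [(x, -f, -g)]               # (f AND g) -> x
cnf = forward + backward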
+ x = self.new_var() + if polarity in (True, None): + self.add_clauses( + [ + ( + -x, + f, + ), + ( + -x, + g, + ), + ] + ) + if polarity in (False, None): + self.add_clauses([(x, -f, -g)]) + return x + pval = [(f,), (g,)] if polarity in (True, None) else [] + nval = [(-f, -g)] if polarity in (False, None) else [] + return pval, nval + + def Or(self, f, g, polarity, add_new_clauses=False): + if f == TRUE or g == TRUE: + return TRUE + if f == FALSE: + return g + if g == FALSE: + return f + if f == g: + return f + if f == -g: + return TRUE + if g < f: + f, g = g, f + if add_new_clauses: + x = self.new_var() + if polarity in (True, None): + self.add_clauses([(-x, f, g)]) + if polarity in (False, None): + self.add_clauses( + [ + ( + x, + -f, + ), + ( + x, + -g, + ), + ] + ) + return x + pval = [(f, g)] if polarity in (True, None) else [] + nval = [(-f,), (-g,)] if polarity in (False, None) else [] + return pval, nval + + def Xor(self, f, g, polarity, add_new_clauses=False): + if f == FALSE: + return g + if f == TRUE: + return self.Not(g, polarity, add_new_clauses=add_new_clauses) + if g == FALSE: + return f + if g == TRUE: + return -f + if f == g: + return FALSE + if f == -g: + return TRUE + if g < f: + f, g = g, f + if add_new_clauses: + x = self.new_var() + if polarity in (True, None): + self.add_clauses([(-x, f, g), (-x, -f, -g)]) + if polarity in (False, None): + self.add_clauses([(x, -f, g), (x, f, -g)]) + return x + pval = [(f, g), (-f, -g)] if polarity in (True, None) else [] + nval = [(-f, g), (f, -g)] if polarity in (False, None) else [] + return pval, nval + + def ITE(self, c, t, f, polarity, add_new_clauses=False): + if c == TRUE: + return t + if c == FALSE: + return f + if t == TRUE: + return self.Or(c, f, polarity, add_new_clauses=add_new_clauses) + if t == FALSE: + return self.And(-c, f, polarity, add_new_clauses=add_new_clauses) + if f == FALSE: + return self.And(c, t, polarity, add_new_clauses=add_new_clauses) + if f == TRUE: + return self.Or(t, -c, polarity, add_new_clauses=add_new_clauses) + if t == c: + return self.Or(c, f, polarity, add_new_clauses=add_new_clauses) + if t == -c: + return self.And(-c, f, polarity, add_new_clauses=add_new_clauses) + if f == c: + return self.And(c, t, polarity, add_new_clauses=add_new_clauses) + if f == -c: + return self.Or(t, -c, polarity, add_new_clauses=add_new_clauses) + if t == f: + return t + if t == -f: + return self.Xor(c, f, polarity, add_new_clauses=add_new_clauses) + if t < f: + t, f, c = f, t, -c + # Basically, c ? t : f is equivalent to (c AND t) OR (NOT c AND f) + # The third clause in each group is redundant but assists the unit + # propagation in the SAT solver. 
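# Illustrative aside, not part of the patch: the clause pattern the code below
# emits for a fresh variable x standing for ITE(c, t, f).  The third clause in
# each group is the resolvent of the first two, so it adds no new information
# but speeds up unit propagation; literal numbers are made up.
c, t, f, x = 1, 2, 3, 4
positive = [(-x, -c, t), (-x, c, f), (-x, t, f)]     # x -> ITE(c, t, f)
negative = [(x, -c, -t), (x, c, -f), (x, -t, -f)]    # ITE(c, t, f) -> x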
+ if add_new_clauses: + x = self.new_var() + if polarity in (True, None): + self.add_clauses([(-x, -c, t), (-x, c, f), (-x, t, f)]) + if polarity in (False, None): + self.add_clauses([(x, -c, -t), (x, c, -f), (x, -t, -f)]) + return x + pval = [(-c, t), (c, f), (t, f)] if polarity in (True, None) else [] + nval = [(-c, -t), (c, -f), (-t, -f)] if polarity in (False, None) else [] + return pval, nval + + def All(self, iter, polarity=None): + vals = set() + for v in iter: + if v == TRUE: + continue + if v == FALSE or -v in vals: + return FALSE + vals.add(v) + nv = len(vals) + if nv == 0: + return TRUE + elif nv == 1: + return next(v for v in vals) + pval = [(v,) for v in vals] if polarity in (True, None) else [] + nval = [tuple(-v for v in vals)] if polarity in (False, None) else [] + return pval, nval + + def Any(self, iter, polarity): + vals = set() + for v in iter: + if v == FALSE: + continue + elif v == TRUE or -v in vals: + return TRUE + vals.add(v) + nv = len(vals) + if nv == 0: + return FALSE + elif nv == 1: + return next(v for v in vals) + pval = [tuple(vals)] if polarity in (True, None) else [] + nval = [(-v,) for v in vals] if polarity in (False, None) else [] + return pval, nval + + def AtMostOne_NSQ(self, vals, polarity): + combos = [] + for v1, v2 in combinations(map(self.Not, vals), 2): + combos.append(self.Or(v1, v2, polarity)) + return self.Combine(combos, polarity) + + def AtMostOne_BDD(self, vals, polarity=None): + lits = list(vals) + coeffs = [1] * len(lits) + return self.LinearBound(lits, coeffs, 0, 1, True, polarity) + + def ExactlyOne_NSQ(self, vals, polarity): + vals = list(vals) + v1 = self.AtMostOne_NSQ(vals, polarity) + v2 = self.Any(vals, polarity) + return self.Combine((v1, v2), polarity) + + def ExactlyOne_BDD(self, vals, polarity): + lits = list(vals) + coeffs = [1] * len(lits) + return self.LinearBound(lits, coeffs, 1, 1, True, polarity) + + def LB_Preprocess(self, lits, coeffs): + equation = [] + offset = 0 + for coeff, lit in zip(coeffs, lits): + if lit == TRUE: + offset += coeff + continue + if lit == FALSE or coeff == 0: + continue + if coeff < 0: + offset += coeff + coeff, lit = -coeff, -lit + equation.append((coeff, lit)) + coeffs, lits = tuple(zip(*sorted(equation))) or ((), ()) + return lits, coeffs, offset + + def BDD(self, lits, coeffs, nterms, lo, hi, polarity): + # The equation (coeffs x lits) is sorted in + # order of increasing coefficients. 
+ # Then we take advantage of the following recurrence: + # l <= S + cN xN <= u + # => IF xN THEN l - cN <= S <= u - cN + # ELSE l <= S <= u + # we use memoization to prune common subexpressions + total = sum(c for c in coeffs[:nterms]) + target = (nterms - 1, 0, total) + call_stack = [target] + ret = {} + call_stack_append = call_stack.append + call_stack_pop = call_stack.pop + ret_get = ret.get + ITE = self.ITE + + csum = 0 + while call_stack: + ndx, csum, total = call_stack[-1] + lower_limit = lo - csum + upper_limit = hi - csum + if lower_limit <= 0 and upper_limit >= total: + ret[call_stack_pop()] = TRUE + continue + if lower_limit > total or upper_limit < 0: + ret[call_stack_pop()] = FALSE + continue + LA = lits[ndx] + LC = coeffs[ndx] + ndx -= 1 + total -= LC + hi_key = (ndx, csum if LA < 0 else csum + LC, total) + thi = ret_get(hi_key) + if thi is None: + call_stack_append(hi_key) + continue + lo_key = (ndx, csum + LC if LA < 0 else csum, total) + tlo = ret_get(lo_key) + if tlo is None: + call_stack_append(lo_key) + continue + # NOTE: The following ITE call is _the_ hotspot of the Python-side + # computations for the overall minimization run. For performance we + # avoid calling self.assign here via add_new_clauses=True. + # If we want to translate parts of the code to a compiled language, + # self.BDD (+ its downward call stack) is the prime candidate! + ret[call_stack_pop()] = ITE( + abs(LA), thi, tlo, polarity, add_new_clauses=True + ) + return ret[target] + + def LinearBound(self, lits, coeffs, lo, hi, preprocess, polarity): + if preprocess: + lits, coeffs, offset = self.LB_Preprocess(lits, coeffs) + lo -= offset + hi -= offset + nterms = len(coeffs) + if nterms and coeffs[-1] > hi: + nprune = sum(c > hi for c in coeffs) + log.log( + TRACE, "Eliminating %d/%d terms for bound violation", nprune, nterms + ) + nterms -= nprune + else: + nprune = 0 + # Tighten bounds + total = sum(c for c in coeffs[:nterms]) + if preprocess: + lo = max([lo, 0]) + hi = min([hi, total]) + if lo > hi: + return FALSE + if nterms == 0: + res = TRUE if lo == 0 else FALSE + else: + res = self.BDD(lits, coeffs, nterms, lo, hi, polarity) + if nprune: + prune = self.All([-a for a in lits[nterms:]], polarity) + res = self.Combine((res, prune), polarity) + return res + + def _run_sat(self, m, limit=0): + if log.isEnabledFor(DEBUG): + log.debug("Invoking SAT with clause count: %s", self.get_clause_count()) + solution = self._sat_solver.run(m, limit=limit) + return solution + + def sat(self, additional=None, includeIf=False, limit=0): + """ + Calculate a SAT solution for the current clause set. + + Returned is the list of those solutions. When the clauses are + unsatisfiable, an empty list is returned. + + """ + if self.unsat: + return None + if not self.m: + return [] + saved_state = self._sat_solver.save_state() + if additional: + + def preproc(eqs): + def preproc_(cc): + for c in cc: + if c == FALSE: + continue + yield c + if c == TRUE: + break + + for cc in eqs: + cc = tuple(preproc_(cc)) + if not cc: + yield cc + break + if cc[-1] != TRUE: + yield cc + + additional = list(preproc(additional)) + if additional: + if not additional[-1]: + return None + self.add_clauses(additional) + solution = self._run_sat(self.m, limit=limit) + if additional and (solution is None or not includeIf): + self._sat_solver.restore_state(saved_state) + return solution + + def minimize(self, lits, coeffs, bestsol=None, trymax=False): + """ + Minimize the objective function given by (coeff, integer) pairs in + zip(coeffs, lits). 
+ The actual minimization is multiobjective: first, we minimize the + largest active coefficient value, then we minimize the sum. + """ + if bestsol is None or len(bestsol) < self.m: + log.debug("Clauses added, recomputing solution") + bestsol = self.sat() + if bestsol is None or self.unsat: + log.debug("Constraints are unsatisfiable") + return bestsol, sum(abs(c) for c in coeffs) + 1 if coeffs else 1 + if not coeffs: + log.debug("Empty objective, trivial solution") + return bestsol, 0 + + lits, coeffs, offset = self.LB_Preprocess(lits, coeffs) + maxval = max(coeffs) + + def peak_val(sol, objective_dict): + return max(objective_dict.get(s, 0) for s in sol) + + def sum_val(sol, objective_dict): + return sum(objective_dict.get(s, 0) for s in sol) + + lo = 0 + try0 = 0 + for peak in (True, False) if maxval > 1 else (False,): + if peak: + log.log(TRACE, "Beginning peak minimization") + objval = peak_val + else: + log.log(TRACE, "Beginning sum minimization") + objval = sum_val + + objective_dict = {a: c for c, a in zip(coeffs, lits)} + bestval = objval(bestsol, objective_dict) + + # If we got lucky and the initial solution is optimal, we still + # need to generate the constraints at least once + hi = bestval + m_orig = self.m + if log.isEnabledFor(DEBUG): + # This is only used for the log message below. + nz = self.get_clause_count() + saved_state = self._sat_solver.save_state() + if trymax and not peak: + try0 = hi - 1 + + log.log(TRACE, "Initial range (%d,%d)", lo, hi) + while True: + if try0 is None: + mid = (lo + hi) // 2 + else: + mid = try0 + if peak: + prevent = tuple(a for c, a in zip(coeffs, lits) if c > mid) + require = tuple(a for c, a in zip(coeffs, lits) if lo <= c <= mid) + self.Prevent(self.Any, prevent) + if require: + self.Require(self.Any, require) + else: + self.Require(self.LinearBound, lits, coeffs, lo, mid, False) + + if log.isEnabledFor(DEBUG): + log.log( + TRACE, + "Bisection attempt: (%d,%d), (%d+%d) clauses", + lo, + mid, + nz, + self.get_clause_count() - nz, + ) + newsol = self.sat() + if newsol is None: + lo = mid + 1 + log.log(TRACE, "Bisection failure, new range=(%d,%d)", lo, hi) + if lo > hi: + # FIXME: This is not supposed to happen! + # TODO: Investigate and fix the cause. + break + # If this was a failure of the first test after peak minimization, + # then it means that the peak minimizer is "tight" and we don't need + # any further constraints. + else: + done = lo == mid + bestsol = newsol + bestval = objval(newsol, objective_dict) + hi = bestval + log.log(TRACE, "Bisection success, new range=(%d,%d)", lo, hi) + if done: + break + self.m = m_orig + # Since we only ever _add_ clauses and only remove then via + # restore_state, it's fine to test on equality only. + if self._sat_solver.save_state() != saved_state: + self._sat_solver.restore_state(saved_state) + self.unsat = False + try0 = None + + log.debug("Final %s objective: %d" % ("peak" if peak else "sum", bestval)) + if bestval == 0: + break + elif peak: + # Now that we've minimized the peak value, we can drop any terms + # with coefficients larger than this. Furthermore, since we know + # at least one peak will be active, our lower bound for the sum + # equals the peak. 
+ lits = [a for c, a in zip(coeffs, lits) if c <= bestval] + coeffs = [c for c in coeffs if c <= bestval] + try0 = sum_val(bestsol, objective_dict) + lo = bestval + else: + log.debug("New peak objective: %d" % peak_val(bestsol, objective_dict)) + + return bestsol, bestval diff --git a/conda_classic_solver/logic.py b/conda_classic_solver/logic.py new file mode 100644 index 0000000..f378b30 --- /dev/null +++ b/conda_classic_solver/logic.py @@ -0,0 +1,313 @@ +# Copyright (C) 2012 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause +""" +The basic idea to nest logical expressions is instead of trying to denest +things via distribution, we add new variables. So if we have some logical +expression expr, we replace it with x and add expr <-> x to the clauses, +where x is a new variable, and expr <-> x is recursively evaluated in the +same way, so that the final clauses are ORs of atoms. + +To use this, create a new Clauses object with the max var, for instance, if you +already have [[1, 2, -3]], you would use C = Clause(3). All functions return +a new literal, which represents that function, or True or False if the expression +can be resolved fully. They may also add new clauses to C.clauses, which +will then be delivered to the SAT solver. + +All functions take atoms as arguments (an atom is an integer, representing a +literal or a negated literal, or boolean constants True or False; that is, +it is the callers' responsibility to do the conversion of expressions +recursively. This is done because we do not have data structures +representing the various logical classes, only atoms. + +The polarity argument can be set to True or False if you know that the literal +being used will only be used in the positive or the negative, respectively +(e.g., you will only use x, not -x). This will generate fewer clauses. It +is probably best if you do not take advantage of this directly, but rather +through the Require and Prevent functions. + +""" + +from itertools import chain + +from ._logic import FALSE, TRUE +from ._logic import Clauses as _Clauses + +# TODO: We may want to turn the user-facing {TRUE,FALSE} values into an Enum and +# hide the _logic.{TRUE,FALSE} values as an implementation detail. +# We then have to handle the {TRUE,FALSE} -> _logic.{TRUE,FALSE} conversion +# in Clauses._convert and the inverse _logic.{TRUE,FALSE} -> {TRUE,FALSE} +# conversion in Clauses._eval. +TRUE = TRUE +FALSE = FALSE + +PycoSatSolver = "pycosat" +PyCryptoSatSolver = "pycryptosat" +PySatSolver = "pysat" + + +class Clauses: + def __init__(self, m=0, sat_solver=PycoSatSolver): + self.names = {} + self.indices = {} + self._clauses = _Clauses(m=m, sat_solver_str=sat_solver) + + @property + def m(self): + return self._clauses.m + + @property + def unsat(self): + return self._clauses.unsat + + def get_clause_count(self): + return self._clauses.get_clause_count() + + def as_list(self): + return self._clauses.as_list() + + def _check_variable(self, variable): + if 0 < abs(variable) <= self.m: + return variable + raise ValueError(f"SAT variable out of bounds: {variable} (max_var: {self.m})") + + def _check_literal(self, literal): + if literal in {TRUE, FALSE}: + return literal + return self._check_variable(literal) + + def add_clause(self, clause): + self._clauses.add_clause(map(self._check_variable, self._convert(clause))) + + def add_clauses(self, clauses): + for clause in clauses: + self.add_clause(clause) + + def name_var(self, m, name): + self._check_literal(m) + nname = "!" 
+ name + self.names[name] = m + self.names[nname] = -m + if m not in {TRUE, FALSE} and m not in self.indices: + self.indices[m] = name + self.indices[-m] = nname + return m + + def new_var(self, name=None): + m = self._clauses.new_var() + if name: + self.name_var(m, name) + return m + + def from_name(self, name): + return self.names.get(name) + + def from_index(self, m): + return self.indices.get(m) + + def _assign(self, vals, name=None): + x = self._clauses.assign(vals) + if not name: + return x + if vals in {TRUE, FALSE}: + x = self._clauses.new_var() + self._clauses.add_clause((x,) if vals else (-x,)) + return self.name_var(x, name) + + def _convert(self, x): + if isinstance(x, (tuple, list)): + return type(x)(map(self._convert, x)) + if isinstance(x, int): + return self._check_literal(x) + name = x + try: + return self.names[name] + except KeyError: + raise ValueError(f"Unregistered SAT variable name: {name}") + + def _eval(self, func, args, no_literal_args, polarity, name): + args = self._convert(args) + if name is False: + self._clauses.Eval(func, args + no_literal_args, polarity) + return None + vals = func(*(args + no_literal_args), polarity=polarity) + return self._assign(vals, name) + + def Prevent(self, what, *args): + return what.__get__(self, Clauses)(*args, polarity=False, name=False) + + def Require(self, what, *args): + return what.__get__(self, Clauses)(*args, polarity=True, name=False) + + def Not(self, x, polarity=None, name=None): + return self._eval(self._clauses.Not, (x,), (), polarity, name) + + def And(self, f, g, polarity=None, name=None): + return self._eval(self._clauses.And, (f, g), (), polarity, name) + + def Or(self, f, g, polarity=None, name=None): + return self._eval(self._clauses.Or, (f, g), (), polarity, name) + + def Xor(self, f, g, polarity=None, name=None): + return self._eval(self._clauses.Xor, (f, g), (), polarity, name) + + def ITE(self, c, t, f, polarity=None, name=None): + """If c Then t Else f. + + In this function, if any of c, t, or f are True and False the resulting + expression is resolved. 
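# Illustrative aside, not part of the patch: a minimal end-to-end use of the
# named-variable wrapper defined in this module.  It assumes the package is
# importable as conda_classic_solver and that a SAT backend such as pycosat
# is installed.
from conda_classic_solver.logic import Clauses

C = Clauses()
a = C.new_var("a")
b = C.new_var("b")
C.Require(C.And, a, b)       # assert that both a and b must be true
print(C.sat(names=True))     # -> {'a', 'b'} (a set, so order may vary)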
+ """ + return self._eval(self._clauses.ITE, (c, t, f), (), polarity, name) + + def All(self, iter, polarity=None, name=None): + return self._eval(self._clauses.All, (iter,), (), polarity, name) + + def Any(self, vals, polarity=None, name=None): + return self._eval(self._clauses.Any, (list(vals),), (), polarity, name) + + def AtMostOne_NSQ(self, vals, polarity=None, name=None): + return self._eval( + self._clauses.AtMostOne_NSQ, (list(vals),), (), polarity, name + ) + + def AtMostOne_BDD(self, vals, polarity=None, name=None): + return self._eval( + self._clauses.AtMostOne_BDD, (list(vals),), (), polarity, name + ) + + def AtMostOne(self, vals, polarity=None, name=None): + vals = list(vals) + nv = len(vals) + if nv < 5 - (polarity is not True): + what = self.AtMostOne_NSQ + else: + what = self.AtMostOne_BDD + return self._eval(what, (vals,), (), polarity, name) + + def ExactlyOne_NSQ(self, vals, polarity=None, name=None): + return self._eval( + self._clauses.ExactlyOne_NSQ, (list(vals),), (), polarity, name + ) + + def ExactlyOne_BDD(self, vals, polarity=None, name=None): + return self._eval( + self._clauses.ExactlyOne_BDD, (list(vals),), (), polarity, name + ) + + def ExactlyOne(self, vals, polarity=None, name=None): + vals = list(vals) + nv = len(vals) + if nv < 2: + what = self.ExactlyOne_NSQ + else: + what = self.ExactlyOne_BDD + return self._eval(what, (vals,), (), polarity, name) + + def LinearBound(self, equation, lo, hi, preprocess=True, polarity=None, name=None): + if not isinstance(equation, dict): + # in case of duplicate literal -> coefficient mappings, always take the last one + equation = {named_lit: coeff for coeff, named_lit in equation} + named_literals = list(equation.keys()) + coefficients = list(equation.values()) + return self._eval( + self._clauses.LinearBound, + (named_literals,), + (coefficients, lo, hi, preprocess), + polarity, + name, + ) + + def sat(self, additional=None, includeIf=False, names=False, limit=0): + """ + Calculate a SAT solution for the current clause set. + + Returned is the list of those solutions. When the clauses are + unsatisfiable, an empty list is returned. + + """ + if self.unsat: + return None + if not self.m: + return set() if names else [] + if additional: + additional = (tuple(self.names.get(c, c) for c in cc) for cc in additional) + solution = self._clauses.sat( + additional=additional, includeIf=includeIf, limit=limit + ) + if solution is None: + return None + if names: + return { + nm + for nm in (self.indices.get(s) for s in solution) + if nm and nm[0] != "!" + } + return solution + + def itersolve(self, constraints=None, m=None): + exclude = [] + if m is None: + m = self.m + while True: + # We don't use pycosat.itersolve because it is more + # important to limit the number of terms added to the + # exclusion list, in our experience. Once we update + # pycosat to do this, this can use it. 
+ sol = self.sat(chain(constraints, exclude)) + if sol is None: + return + yield sol + exclude.append([-k for k in sol if -m <= k <= m]) + + def minimize(self, objective, bestsol=None, trymax=False): + if not isinstance(objective, dict): + # in case of duplicate literal -> coefficient mappings, always take the last one + objective = {named_lit: coeff for coeff, named_lit in objective} + literals = self._convert(list(objective.keys())) + coeffs = list(objective.values()) + + return self._clauses.minimize(literals, coeffs, bestsol=bestsol, trymax=trymax) + + +def minimal_unsatisfiable_subset(clauses, sat, explicit_specs): + """ + Given a set of clauses, find a minimal unsatisfiable subset (an + unsatisfiable core) + + A set is a minimal unsatisfiable subset if no proper subset is + unsatisfiable. A set of clauses may have many minimal unsatisfiable + subsets of different sizes. + + sat should be a function that takes a tuple of clauses and returns True if + the clauses are satisfiable and False if they are not. The algorithm will + work with any order-reversing function (reversing the order of subset and + the order False < True), that is, any function where (A <= B) iff (sat(B) + <= sat(A)), where A <= B means A is a subset of B and False < True). + + """ + working_set = set() + found_conflicts = set() + + if sat(explicit_specs, True) is None: + found_conflicts = set(explicit_specs) + else: + # we succeeded, so we'll add the spec to our future constraints + working_set = set(explicit_specs) + + for spec in set(clauses) - working_set: + if ( + sat( + working_set + | { + spec, + }, + True, + ) + is None + ): + found_conflicts.add(spec) + else: + # we succeeded, so we'll add the spec to our future constraints + working_set.add(spec) + + return found_conflicts diff --git a/conda_classic_solver/plugin.py b/conda_classic_solver/plugin.py index 9ab7195..e0947ba 100644 --- a/conda_classic_solver/plugin.py +++ b/conda_classic_solver/plugin.py @@ -9,7 +9,7 @@ from conda.plugins import CondaSolver, hookimpl -from .solver import ClassicSolver +from .solve import ClassicSolver @hookimpl diff --git a/conda_classic_solver/resolve.py b/conda_classic_solver/resolve.py new file mode 100644 index 0000000..d396ca9 --- /dev/null +++ b/conda_classic_solver/resolve.py @@ -0,0 +1,1671 @@ +# Copyright (C) 2012 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause +"""Low-level SAT solver wrapper/interface for the classic solver. + +See conda.core.solver.Solver for the high-level API. 
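# Illustrative aside, not part of the patch: the greedy idea behind
# minimal_unsatisfiable_subset (defined in logic.py above) on toy data;
# `sat` here is a stand-in predicate, not conda's solver.
def toy_conflicts(specs, sat):
    working, conflicts = set(), set()
    for spec in specs:
        if sat(working | {spec}):
            working.add(spec)          # satisfiable: keep as a future constraint
        else:
            conflicts.add(spec)        # this spec breaks the working set
    return conflicts

sat = lambda group: not ({"a", "not-a"} <= group)      # a and not-a cannot coexist
assert toy_conflicts(["a", "b", "not-a"], sat) == {"not-a"}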
+""" + +from __future__ import annotations + +import copy +import itertools +from collections import defaultdict, deque +from functools import lru_cache +from logging import DEBUG, getLogger + +from conda.auxlib.decorators import memoizemethod +from conda.base.constants import MAX_CHANNEL_PRIORITY, ChannelPriority, SatSolverChoice +from conda.base.context import context +from conda.common.compat import on_win +from conda.common.io import dashlist, time_recorder +from conda.common.iterators import groupby_to_dict as groupby +from conda.common.toposort import toposort +from conda.exceptions import ( + CondaDependencyError, + InvalidSpec, + ResolvePackageNotFound, + UnsatisfiableError, +) +from conda.models.channel import Channel, MultiChannel +from conda.models.enums import NoarchType, PackageType +from conda.models.match_spec import MatchSpec +from conda.models.records import PackageRecord +from conda.models.version import VersionOrder +from tqdm import tqdm + +try: + from frozendict import frozendict +except ImportError: + from conda._vendor.frozendict import FrozenOrderedDict as frozendict + +from .logic import ( + TRUE, + Clauses, + PycoSatSolver, + PyCryptoSatSolver, + PySatSolver, + minimal_unsatisfiable_subset, +) + +log = getLogger(__name__) +stdoutlog = getLogger("conda.stdoutlog") + +# used in conda build +Unsatisfiable = UnsatisfiableError +ResolvePackageNotFound = ResolvePackageNotFound + +_sat_solvers = { + SatSolverChoice.PYCOSAT: PycoSatSolver, + SatSolverChoice.PYCRYPTOSAT: PyCryptoSatSolver, + SatSolverChoice.PYSAT: PySatSolver, +} + + +@lru_cache(maxsize=None) +def _get_sat_solver_cls(sat_solver_choice=SatSolverChoice.PYCOSAT): + def try_out_solver(sat_solver): + c = Clauses(sat_solver=sat_solver) + required = {c.new_var(), c.new_var()} + c.Require(c.And, *required) + solution = set(c.sat()) + if not required.issubset(solution): + raise RuntimeError(f"Wrong SAT solution: {solution}. Required: {required}") + + sat_solver = _sat_solvers[sat_solver_choice] + try: + try_out_solver(sat_solver) + except Exception as e: + log.warning( + "Could not run SAT solver through interface '%s'.", sat_solver_choice + ) + log.debug("SAT interface error due to: %s", e, exc_info=True) + else: + log.debug("Using SAT solver interface '%s'.", sat_solver_choice) + return sat_solver + for sat_solver in _sat_solvers.values(): + try: + try_out_solver(sat_solver) + except Exception as e: + log.debug( + "Attempted SAT interface '%s' but unavailable due to: %s", + sat_solver_choice, + e, + ) + else: + log.debug("Falling back to SAT solver interface '%s'.", sat_solver_choice) + return sat_solver + raise CondaDependencyError( + "Cannot run solver. No functioning SAT implementations available." + ) + + +def exactness_and_number_of_deps(resolve_obj, ms): + """Sorting key to emphasize packages that have more strict + requirements. More strict means the reduced index can be reduced + more, so we want to consider these more constrained deps earlier in + reducing the index. 
+ """ + if ms.strictness == 3: + prec = resolve_obj.find_matches(ms) + value = 3 + if prec: + for dep in prec[0].depends: + value += MatchSpec(dep).strictness + else: + value = ms.strictness + return value + + +class Resolve: + def __init__(self, index, processed=False, channels=()): + self.index = index + + self.channels = channels + self._channel_priorities_map = ( + self._make_channel_priorities(channels) if channels else {} + ) + self._channel_priority = context.channel_priority + self._solver_ignore_timestamps = context.solver_ignore_timestamps + + groups = groupby(lambda x: x.name, index.values()) + trackers = defaultdict(list) + + for name in groups: + unmanageable_precs = [prec for prec in groups[name] if prec.is_unmanageable] + if unmanageable_precs: + log.debug("restricting to unmanageable packages: %s", name) + groups[name] = unmanageable_precs + tf_precs = (prec for prec in groups[name] if prec.track_features) + for prec in tf_precs: + for feature_name in prec.track_features: + trackers[feature_name].append(prec) + + self.groups = groups # dict[package_name, list[PackageRecord]] + self.trackers = trackers # dict[track_feature, set[PackageRecord]] + self._cached_find_matches = {} # dict[MatchSpec, set[PackageRecord]] + self.ms_depends_ = {} # dict[PackageRecord, list[MatchSpec]] + self._reduced_index_cache = {} + self._pool_cache = {} + self._strict_channel_cache = {} + + self._system_precs = { + _ + for _ in index + if ( + hasattr(_, "package_type") + and _.package_type == PackageType.VIRTUAL_SYSTEM + ) + } + + # sorting these in reverse order is effectively prioritizing + # constraint behavior from newer packages. It is applying broadening + # reduction based on the latest packages, which may reduce the space + # more, because more modern packages utilize constraints in more sane + # ways (for example, using run_exports in conda-build 3) + for name, group in self.groups.items(): + self.groups[name] = sorted(group, key=self.version_key, reverse=True) + + def __hash__(self): + return ( + super().__hash__() + ^ hash(frozenset(self.channels)) + ^ hash(frozendict(self._channel_priorities_map)) + ^ hash(self._channel_priority) + ^ hash(self._solver_ignore_timestamps) + ^ hash(frozendict((k, tuple(v)) for k, v in self.groups.items())) + ^ hash(frozendict((k, tuple(v)) for k, v in self.trackers.items())) + ^ hash(frozendict((k, tuple(v)) for k, v in self.ms_depends_.items())) + ) + + def default_filter(self, features=None, filter=None): + # TODO: fix this import; this is bad + from conda.core.subdir_data import make_feature_record + + if filter is None: + filter = {} + else: + filter.clear() + + filter.update( + {make_feature_record(fstr): False for fstr in self.trackers.keys()} + ) + if features: + filter.update({make_feature_record(fstr): True for fstr in features}) + return filter + + def valid(self, spec_or_prec, filter, optional=True): + """Tests if a package, MatchSpec, or a list of both has satisfiable + dependencies, assuming cyclic dependencies are always valid. + + Args: + spec_or_prec: a package record, a MatchSpec, or an iterable of these. + filter: a dictionary of (fkey,valid) pairs, used to consider a subset + of dependencies, and to eliminate repeated searches. + optional: if True (default), do not enforce optional specifications + when considering validity. If False, enforce them. + + Returns: + True if the full set of dependencies can be satisfied; False otherwise. + If filter is supplied and update is True, it will be updated with the + search results. 
+ """ + + def v_(spec): + return v_ms_(spec) if isinstance(spec, MatchSpec) else v_fkey_(spec) + + def v_ms_(ms): + return ( + optional + and ms.optional + or any(v_fkey_(fkey) for fkey in self.find_matches(ms)) + ) + + def v_fkey_(prec): + val = filter.get(prec) + if val is None: + filter[prec] = True + try: + depends = self.ms_depends(prec) + except InvalidSpec: + val = filter[prec] = False + else: + val = filter[prec] = all(v_ms_(ms) for ms in depends) + return val + + result = v_(spec_or_prec) + return result + + def valid2(self, spec_or_prec, filter_out, optional=True): + def is_valid(_spec_or_prec): + if isinstance(_spec_or_prec, MatchSpec): + return is_valid_spec(_spec_or_prec) + else: + return is_valid_prec(_spec_or_prec) + + @memoizemethod + def is_valid_spec(_spec): + return ( + optional + and _spec.optional + or any(is_valid_prec(_prec) for _prec in self.find_matches(_spec)) + ) + + def is_valid_prec(prec): + val = filter_out.get(prec) + if val is None: + filter_out[prec] = False + try: + has_valid_deps = all( + is_valid_spec(ms) for ms in self.ms_depends(prec) + ) + except InvalidSpec: + val = filter_out[prec] = "invalid dep specs" + else: + val = filter_out[prec] = ( + False if has_valid_deps else "invalid depends specs" + ) + return not val + + return is_valid(spec_or_prec) + + def invalid_chains(self, spec, filter, optional=True): + """Constructs a set of 'dependency chains' for invalid specs. + + A dependency chain is a tuple of MatchSpec objects, starting with + the requested spec, proceeding down the dependency tree, ending at + a specification that cannot be satisfied. + + Args: + spec: a package key or MatchSpec + filter: a dictionary of (prec, valid) pairs to be used when + testing for package validity. + + Returns: + A tuple of tuples, empty if the MatchSpec is valid. + """ + + def chains_(spec, names): + if spec.name in names: + return + names.add(spec.name) + if self.valid(spec, filter, optional): + return + precs = self.find_matches(spec) + found = False + + conflict_deps = set() + for prec in precs: + for m2 in self.ms_depends(prec): + for x in chains_(m2, names): + found = True + yield (spec,) + x + else: + conflict_deps.add(m2) + if not found: + conflict_groups = groupby(lambda x: x.name, conflict_deps) + for group in conflict_groups.values(): + yield (spec,) + MatchSpec.union(group) + + return chains_(spec, set()) + + def verify_specs(self, specs): + """Perform a quick verification that specs and dependencies are reasonable. + + Args: + specs: An iterable of strings or MatchSpec objects to be tested. + + Returns: + Nothing, but if there is a conflict, an error is thrown. + + Note that this does not attempt to resolve circular dependencies. 
+ """ + non_tf_specs = [] + bad_deps = [] + feature_names = set() + for ms in specs: + _feature_names = ms.get_exact_value("track_features") + if _feature_names: + feature_names.update(_feature_names) + else: + non_tf_specs.append(ms) + bad_deps.extend( + (spec,) + for spec in non_tf_specs + if (not spec.optional and not self.find_matches(spec)) + ) + if bad_deps: + raise ResolvePackageNotFound(bad_deps) + return tuple(non_tf_specs), feature_names + + def _classify_bad_deps( + self, bad_deps, specs_to_add, history_specs, strict_channel_priority + ): + classes = { + "python": set(), + "request_conflict_with_history": set(), + "direct": set(), + "virtual_package": set(), + } + specs_to_add = {MatchSpec(_) for _ in specs_to_add or []} + history_specs = {MatchSpec(_) for _ in history_specs or []} + for chain in bad_deps: + # sometimes chains come in as strings + if ( + len(chain) > 1 + and chain[-1].name == "python" + and not any(_.name == "python" for _ in specs_to_add) + and any(_[0] for _ in bad_deps if _[0].name == "python") + ): + python_first_specs = [_[0] for _ in bad_deps if _[0].name == "python"] + if python_first_specs: + python_spec = python_first_specs[0] + if not ( + set(self.find_matches(python_spec)) + & set(self.find_matches(chain[-1])) + ): + classes["python"].add( + ( + tuple([chain[0], chain[-1]]), + str(MatchSpec(python_spec, target=None)), + ) + ) + elif chain[-1].name.startswith("__"): + version = [_ for _ in self._system_precs if _.name == chain[-1].name] + virtual_package_version = ( + version[0].version if version else "not available" + ) + classes["virtual_package"].add((tuple(chain), virtual_package_version)) + elif chain[0] in specs_to_add: + match = False + for spec in history_specs: + if spec.name == chain[-1].name: + classes["request_conflict_with_history"].add( + (tuple(chain), str(MatchSpec(spec, target=None))) + ) + match = True + + if not match: + classes["direct"].add( + (tuple(chain), str(MatchSpec(chain[0], target=None))) + ) + else: + if len(chain) > 1 or any( + len(c) >= 1 and c[0] == chain[0] for c in bad_deps + ): + classes["direct"].add( + (tuple(chain), str(MatchSpec(chain[0], target=None))) + ) + + if classes["python"]: + # filter out plain single-entry python conflicts. The python section explains these. + classes["direct"] = [ + _ + for _ in classes["direct"] + if _[1].startswith("python ") or len(_[0]) > 1 + ] + return classes + + def find_matches_with_strict(self, ms, strict_channel_priority): + matches = self.find_matches(ms) + if not strict_channel_priority: + return matches + sole_source_channel_name = self._get_strict_channel(ms.name) + return tuple(f for f in matches if f.channel.name == sole_source_channel_name) + + def find_conflicts(self, specs, specs_to_add=None, history_specs=None): + if context.unsatisfiable_hints: + if not context.json: + print( + "\nFound conflicts! Looking for incompatible packages.\n" + "This can take several minutes. Press CTRL-C to abort." 
+ ) + bad_deps = self.build_conflict_map(specs, specs_to_add, history_specs) + else: + bad_deps = {} + strict_channel_priority = context.channel_priority == ChannelPriority.STRICT + raise UnsatisfiableError(bad_deps, strict=strict_channel_priority) + + def breadth_first_search_for_dep_graph( + self, root_spec, target_name, dep_graph, num_targets=1 + ): + """Return shorted path from root_spec to target_name""" + queue = [] + queue.append([root_spec]) + visited = [] + target_paths = [] + while queue: + path = queue.pop(0) + node = path[-1] + if node in visited: + continue + visited.append(node) + if node.name == target_name: + if len(target_paths) == 0: + target_paths.append(path) + if len(target_paths[-1]) == len(path): + last_spec = MatchSpec.union((path[-1], target_paths[-1][-1]))[0] + target_paths[-1][-1] = last_spec + else: + target_paths.append(path) + + found_all_targets = len(target_paths) == num_targets and any( + len(_) != len(path) for _ in queue + ) + if len(queue) == 0 or found_all_targets: + return target_paths + sub_graph = dep_graph + for p in path[0:-1]: + sub_graph = sub_graph[p] + children = [_ for _ in sub_graph.get(node, {})] + if children is None: + continue + for adj in children: + if len(target_paths) < num_targets: + new_path = list(path) + new_path.append(adj) + queue.append(new_path) + return target_paths + + def build_graph_of_deps(self, spec): + dep_graph = {spec: {}} + all_deps = set() + queue = [[spec]] + while queue: + path = queue.pop(0) + sub_graph = dep_graph + for p in path: + sub_graph = sub_graph[p] + parent_node = path[-1] + matches = self.find_matches(parent_node) + for mat in matches: + if len(mat.depends) > 0: + for i in mat.depends: + new_node = MatchSpec(i) + sub_graph.update({new_node: {}}) + all_deps.add(new_node) + new_path = list(path) + new_path.append(new_node) + if len(new_path) <= context.unsatisfiable_hints_check_depth: + queue.append(new_path) + return dep_graph, all_deps + + def build_conflict_map(self, specs, specs_to_add=None, history_specs=None): + """Perform a deeper analysis on conflicting specifications, by attempting + to find the common dependencies that might be the cause of conflicts. + + Args: + specs: An iterable of strings or MatchSpec objects to be tested. + It is assumed that the specs conflict. + + Returns: + bad_deps: A list of lists of bad deps + + Strategy: + If we're here, we know that the specs conflict. This could be because: + - One spec conflicts with another; e.g. + ['numpy 1.5*', 'numpy >=1.6'] + - One spec conflicts with a dependency of another; e.g. + ['numpy 1.5*', 'scipy 0.12.0b1'] + - Each spec depends on *the same package* but in a different way; e.g., + ['A', 'B'] where A depends on numpy 1.5, and B on numpy 1.6. + Technically, all three of these cases can be boiled down to the last + one if we treat the spec itself as one of the "dependencies". There + might be more complex reasons for a conflict, but this code only + considers the ones above. + + The purpose of this code, then, is to identify packages (like numpy + above) that all of the specs depend on *but in different ways*. We + then identify the dependency chains that lead to those packages. 
+ """ + # if only a single package matches the spec use the packages depends + # rather than the spec itself + strict_channel_priority = context.channel_priority == ChannelPriority.STRICT + + specs = set(specs) | (specs_to_add or set()) + # Remove virtual packages + specs = {spec for spec in specs if not spec.name.startswith("__")} + if len(specs) == 1: + matches = self.find_matches(next(iter(specs))) + if len(matches) == 1: + specs = set(self.ms_depends(matches[0])) + specs.update({_.to_match_spec() for _ in self._system_precs}) + for spec in specs: + self._get_package_pool((spec,)) + + dep_graph = {} + dep_list = {} + with tqdm( + total=len(specs), + desc="Building graph of deps", + leave=False, + disable=context.json, + ) as t: + for spec in specs: + t.set_description(f"Examining {spec}") + t.update() + dep_graph_for_spec, all_deps_for_spec = self.build_graph_of_deps(spec) + dep_graph.update(dep_graph_for_spec) + if dep_list.get(spec.name): + dep_list[spec.name].append(spec) + else: + dep_list[spec.name] = [spec] + for dep in all_deps_for_spec: + if dep_list.get(dep.name): + dep_list[dep.name].append(spec) + else: + dep_list[dep.name] = [spec] + + chains = [] + conflicting_pkgs_pkgs = {} + for k, v in dep_list.items(): + set_v = frozenset(v) + # Packages probably conflicts if many specs depend on it + if len(set_v) > 1: + if conflicting_pkgs_pkgs.get(set_v) is None: + conflicting_pkgs_pkgs[set_v] = [k] + else: + conflicting_pkgs_pkgs[set_v].append(k) + # Conflict if required virtual package is not present + elif k.startswith("__") and any(s for s in set_v if s.name != k): + conflicting_pkgs_pkgs[set_v] = [k] + + with tqdm( + total=len(specs), + desc="Determining conflicts", + leave=False, + disable=context.json, + ) as t: + for roots, nodes in conflicting_pkgs_pkgs.items(): + t.set_description( + "Examining conflict for {}".format(" ".join(_.name for _ in roots)) + ) + t.update() + lroots = [_ for _ in roots] + current_shortest_chain = [] + shortest_node = None + requested_spec_unsat = frozenset(nodes).intersection( + {_.name for _ in roots} + ) + if requested_spec_unsat: + chains.append([_ for _ in roots if _.name in requested_spec_unsat]) + shortest_node = chains[-1][0] + for root in roots: + if root != chains[0][0]: + search_node = shortest_node.name + num_occurances = dep_list[search_node].count(root) + c = self.breadth_first_search_for_dep_graph( + root, search_node, dep_graph, num_occurances + ) + chains.extend(c) + else: + for node in nodes: + num_occurances = dep_list[node].count(lroots[0]) + chain = self.breadth_first_search_for_dep_graph( + lroots[0], node, dep_graph, num_occurances + ) + chains.extend(chain) + if len(current_shortest_chain) == 0 or len(chain) < len( + current_shortest_chain + ): + current_shortest_chain = chain + shortest_node = node + for root in lroots[1:]: + num_occurances = dep_list[shortest_node].count(root) + c = self.breadth_first_search_for_dep_graph( + root, shortest_node, dep_graph, num_occurances + ) + chains.extend(c) + + bad_deps = self._classify_bad_deps( + chains, specs_to_add, history_specs, strict_channel_priority + ) + return bad_deps + + def _get_strict_channel(self, package_name): + channel_name = None + try: + channel_name = self._strict_channel_cache[package_name] + except KeyError: + if package_name in self.groups: + all_channel_names = { + prec.channel.name for prec in self.groups[package_name] + } + by_cp = { + self._channel_priorities_map.get(cn, 1): cn + for cn in all_channel_names + } + highest_priority = sorted(by_cp)[ + 0 + ] # 
highest priority is the lowest number + channel_name = self._strict_channel_cache[package_name] = by_cp[ + highest_priority + ] + return channel_name + + @memoizemethod + def _broader(self, ms, specs_by_name): + """Prevent introduction of matchspecs that broaden our selection of choices.""" + if not specs_by_name: + return False + return ms.strictness < specs_by_name[0].strictness + + def _get_package_pool(self, specs): + specs = frozenset(specs) + if specs in self._pool_cache: + pool = self._pool_cache[specs] + else: + pool = self.get_reduced_index(specs) + grouped_pool = groupby(lambda x: x.name, pool) + pool = {k: set(v) for k, v in grouped_pool.items()} + self._pool_cache[specs] = pool + return pool + + @time_recorder(module_name=__name__) + def get_reduced_index( + self, explicit_specs, sort_by_exactness=True, exit_on_conflict=False + ): + # TODO: fix this import; this is bad + from conda.core.subdir_data import make_feature_record + + strict_channel_priority = context.channel_priority == ChannelPriority.STRICT + + cache_key = strict_channel_priority, tuple(explicit_specs) + if cache_key in self._reduced_index_cache: + return self._reduced_index_cache[cache_key] + + if log.isEnabledFor(DEBUG): + log.debug( + "Retrieving packages for: %s", + dashlist(sorted(str(s) for s in explicit_specs)), + ) + + explicit_specs, features = self.verify_specs(explicit_specs) + filter_out = { + prec: False if val else "feature not enabled" + for prec, val in self.default_filter(features).items() + } + snames = set() + top_level_spec = None + cp_filter_applied = set() # values are package names + if sort_by_exactness: + # prioritize specs that are more exact. Exact specs will evaluate to 3, + # constrained specs will evaluate to 2, and name only will be 1 + explicit_specs = sorted( + list(explicit_specs), + key=lambda x: (exactness_and_number_of_deps(self, x), x.dist_str()), + reverse=True, + ) + # tuple because it needs to be hashable + explicit_specs = tuple(explicit_specs) + + explicit_spec_package_pool = {} + for s in explicit_specs: + explicit_spec_package_pool[s.name] = explicit_spec_package_pool.get( + s.name, set() + ) | set(self.find_matches(s)) + + def filter_group(_specs): + # all _specs should be for the same package name + name = next(iter(_specs)).name + group = self.groups.get(name, ()) + + # implement strict channel priority + if group and strict_channel_priority and name not in cp_filter_applied: + sole_source_channel_name = self._get_strict_channel(name) + for prec in group: + if prec.channel.name != sole_source_channel_name: + filter_out[prec] = "removed due to strict channel priority" + cp_filter_applied.add(name) + + # Prune packages that don't match any of the patterns, + # have unsatisfiable dependencies, or conflict with the explicit specs + nold = nnew = 0 + for prec in group: + if not filter_out.setdefault(prec, False): + nold += 1 + if (not self.match_any(_specs, prec)) or ( + explicit_spec_package_pool.get(name) + and prec not in explicit_spec_package_pool[name] + ): + filter_out[prec] = ( + f"incompatible with required spec {top_level_spec}" + ) + continue + unsatisfiable_dep_specs = set() + for ms in self.ms_depends(prec): + if not ms.optional and not any( + rec + for rec in self.find_matches(ms) + if not filter_out.get(rec, False) + ): + unsatisfiable_dep_specs.add(ms) + if unsatisfiable_dep_specs: + filter_out[prec] = "unsatisfiable dependencies {}".format( + " ".join(str(s) for s in unsatisfiable_dep_specs) + ) + continue + filter_out[prec] = False + nnew += 1 + + reduced 
= nnew < nold + if reduced: + log.debug("%s: pruned from %d -> %d" % (name, nold, nnew)) + if any(ms.optional for ms in _specs): + return reduced + elif nnew == 0: + # Indicates that a conflict was found; we can exit early + return None + + # Perform the same filtering steps on any dependencies shared across + # *all* packages in the group. Even if just one of the packages does + # not have a particular dependency, it must be ignored in this pass. + # Otherwise, we might do more filtering than we should---and it is + # better to have extra packages here than missing ones. + if reduced or name not in snames: + snames.add(name) + + _dep_specs = groupby( + lambda s: s.name, + ( + dep_spec + for prec in group + if not filter_out.get(prec, False) + for dep_spec in self.ms_depends(prec) + if not dep_spec.optional + ), + ) + _dep_specs.pop("*", None) # discard track_features specs + + for deps_name, deps in sorted( + _dep_specs.items(), key=lambda x: any(_.optional for _ in x[1]) + ): + if len(deps) >= nnew: + res = filter_group(set(deps)) + if res: + reduced = True + elif res is None: + # Indicates that a conflict was found; we can exit early + return None + + return reduced + + # Iterate on pruning until no progress is made. We've implemented + # what amounts to "double-elimination" here; packages get one additional + # chance after their first "False" reduction. This catches more instances + # where one package's filter affects another. But we don't have to be + # perfect about this, so performance matters. + pruned_to_zero = set() + for _ in range(2): + snames.clear() + slist = deque(explicit_specs) + while slist: + s = slist.popleft() + if filter_group([s]): + slist.append(s) + else: + pruned_to_zero.add(s) + + if pruned_to_zero and exit_on_conflict: + return {} + + # Determine all valid packages in the dependency graph + reduced_index2 = { + prec: prec for prec in (make_feature_record(fstr) for fstr in features) + } + specs_by_name_seed = {} + for s in explicit_specs: + specs_by_name_seed[s.name] = specs_by_name_seed.get(s.name, []) + [s] + for explicit_spec in explicit_specs: + add_these_precs2 = tuple( + prec + for prec in self.find_matches(explicit_spec) + if prec not in reduced_index2 and self.valid2(prec, filter_out) + ) + + if strict_channel_priority and add_these_precs2: + strict_channel_name = self._get_strict_channel(add_these_precs2[0].name) + + add_these_precs2 = tuple( + prec + for prec in add_these_precs2 + if prec.channel.name == strict_channel_name + ) + reduced_index2.update((prec, prec) for prec in add_these_precs2) + + for pkg in add_these_precs2: + # what we have seen is only relevant within the context of a single package + # that is picked up because of an explicit spec. We don't want the + # broadening check to apply across packages at the explicit level; only + # at the level of deps below that explicit package. 
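# Illustrative aside, not part of the patch: how the loop below keeps the
# strictest spec for each dependency name in front, so the _broader check can
# compare new dependencies against it.  Strictness is modelled as a plain int
# here (1 = name only, 2 = name and version, 3 = name, version and build).
specs_by_name = {"numpy": [2]}                 # already saw a "numpy 1.5*"-style spec
for dep in (1, 3):                             # a bare "numpy" dep, then an exact one
    specs = specs_by_name.setdefault("numpy", [])
    if not specs or dep >= specs[0]:
        specs.insert(0, dep)                   # strictest spec stays in front
assert specs_by_name["numpy"] == [3, 2]        # the broadening bare-name dep was skipped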
+ seen_specs = set() + specs_by_name = copy.deepcopy(specs_by_name_seed) + + dep_specs = set(self.ms_depends(pkg)) + for dep in dep_specs: + specs = specs_by_name.get(dep.name, []) + if dep not in specs and ( + not specs or dep.strictness >= specs[0].strictness + ): + specs.insert(0, dep) + specs_by_name[dep.name] = specs + + while dep_specs: + # used for debugging + # size_index = len(reduced_index2) + # specs_added = [] + ms = dep_specs.pop() + seen_specs.add(ms) + for dep_pkg in ( + _ for _ in self.find_matches(ms) if _ not in reduced_index2 + ): + if not self.valid2(dep_pkg, filter_out): + continue + + # expand the reduced index if not using strict channel priority, + # or if using it and this package is in the appropriate channel + if not strict_channel_priority or ( + self._get_strict_channel(dep_pkg.name) + == dep_pkg.channel.name + ): + reduced_index2[dep_pkg] = dep_pkg + + # recurse to deps of this dep + new_specs = set(self.ms_depends(dep_pkg)) - seen_specs + for new_ms in new_specs: + # We do not pull packages into the reduced index due + # to a track_features dependency. Remember, a feature + # specifies a "soft" dependency: it must be in the + # environment, but it is not _pulled_ in. The SAT + # logic doesn't do a perfect job of capturing this + # behavior, but keeping these packags out of the + # reduced index helps. Of course, if _another_ + # package pulls it in by dependency, that's fine. + if "track_features" not in new_ms and not self._broader( + new_ms, + tuple(specs_by_name.get(new_ms.name, ())), + ): + dep_specs.add(new_ms) + # if new_ms not in dep_specs: + # specs_added.append(new_ms) + else: + seen_specs.add(new_ms) + # debugging info - see what specs are bringing in the largest blobs + # if size_index != len(reduced_index2): + # print("MS {} added {} pkgs to index".format(ms, + # len(reduced_index2) - size_index)) + # if specs_added: + # print("MS {} added {} specs to further examination".format(ms, + # specs_added)) + + reduced_index2 = frozendict(reduced_index2) + self._reduced_index_cache[cache_key] = reduced_index2 + return reduced_index2 + + def match_any(self, mss, prec): + return any(ms.match(prec) for ms in mss) + + def find_matches(self, spec: MatchSpec) -> tuple[PackageRecord]: + res = self._cached_find_matches.get(spec, None) + if res is not None: + return res + + spec_name = spec.get_exact_value("name") + if spec_name: + candidate_precs = self.groups.get(spec_name, ()) + elif spec.get_exact_value("track_features"): + feature_names = spec.get_exact_value("track_features") + candidate_precs = itertools.chain.from_iterable( + self.trackers.get(feature_name, ()) for feature_name in feature_names + ) + else: + candidate_precs = self.index.values() + + res = tuple(p for p in candidate_precs if spec.match(p)) + self._cached_find_matches[spec] = res + return res + + def ms_depends(self, prec: PackageRecord) -> list[MatchSpec]: + deps = self.ms_depends_.get(prec) + if deps is None: + deps = [MatchSpec(d) for d in prec.combined_depends] + deps.extend(MatchSpec(track_features=feat) for feat in prec.features) + self.ms_depends_[prec] = deps + return deps + + def version_key(self, prec, vtype=None): + channel = prec.channel + channel_priority = self._channel_priorities_map.get( + channel.name, 1 + ) # TODO: ask @mcg1969 why the default value is 1 here # NOQA + valid = 1 if channel_priority < MAX_CHANNEL_PRIORITY else 0 + version_comparator = VersionOrder(prec.get("version", "")) + build_number = prec.get("build_number", 0) + build_string = prec.get("build") + 
noarch = -int(prec.subdir == "noarch") + if self._channel_priority != ChannelPriority.DISABLED: + vkey = [valid, -channel_priority, version_comparator, build_number, noarch] + else: + vkey = [valid, version_comparator, -channel_priority, build_number, noarch] + if self._solver_ignore_timestamps: + vkey.append(build_string) + else: + vkey.extend((prec.get("timestamp", 0), build_string)) + return vkey + + @staticmethod + def _make_channel_priorities(channels): + priorities_map = {} + for priority_counter, chn in enumerate( + itertools.chain.from_iterable( + (Channel(cc) for cc in c._channels) + if isinstance(c, MultiChannel) + else (c,) + for c in (Channel(c) for c in channels) + ) + ): + channel_name = chn.name + if channel_name in priorities_map: + continue + priorities_map[channel_name] = min( + priority_counter, MAX_CHANNEL_PRIORITY - 1 + ) + return priorities_map + + def get_pkgs(self, ms, emptyok=False): # pragma: no cover + # legacy method for conda-build + ms = MatchSpec(ms) + precs = self.find_matches(ms) + if not precs and not emptyok: + raise ResolvePackageNotFound([(ms,)]) + return sorted(precs, key=self.version_key) + + @staticmethod + def to_sat_name(val): + # val can be a PackageRecord or MatchSpec + if isinstance(val, PackageRecord): + return val.dist_str() + elif isinstance(val, MatchSpec): + return "@s@" + str(val) + ("?" if val.optional else "") + else: + raise NotImplementedError() + + @staticmethod + def to_feature_metric_id(prec_dist_str, feat): + return f"@fm@{prec_dist_str}@{feat}" + + def push_MatchSpec(self, C, spec): + spec = MatchSpec(spec) + sat_name = self.to_sat_name(spec) + m = C.from_name(sat_name) + if m is not None: + # the spec has already been pushed onto the clauses stack + return sat_name + + simple = spec._is_single() + nm = spec.get_exact_value("name") + tf = frozenset( + _tf + for _tf in (f.strip() for f in spec.get_exact_value("track_features") or ()) + if _tf + ) + + if nm: + tgroup = libs = self.groups.get(nm, []) + elif tf: + assert len(tf) == 1 + k = next(iter(tf)) + tgroup = libs = self.trackers.get(k, []) + else: + tgroup = libs = self.index.keys() + simple = False + if not simple: + libs = [fkey for fkey in tgroup if spec.match(fkey)] + if len(libs) == len(tgroup): + if spec.optional: + m = TRUE + elif not simple: + ms2 = MatchSpec(track_features=tf) if tf else MatchSpec(nm) + m = C.from_name(self.push_MatchSpec(C, ms2)) + if m is None: + sat_names = [self.to_sat_name(prec) for prec in libs] + if spec.optional: + ms2 = MatchSpec(track_features=tf) if tf else MatchSpec(nm) + sat_names.append("!" 
+ self.to_sat_name(ms2)) + m = C.Any(sat_names) + C.name_var(m, sat_name) + return sat_name + + @time_recorder(module_name=__name__) + def gen_clauses(self): + C = Clauses(sat_solver=_get_sat_solver_cls(context.sat_solver)) + for name, group in self.groups.items(): + group = [self.to_sat_name(prec) for prec in group] + # Create one variable for each package + for sat_name in group: + C.new_var(sat_name) + # Create one variable for the group + m = C.new_var(self.to_sat_name(MatchSpec(name))) + + # Exactly one of the package variables, OR + # the negation of the group variable, is true + C.Require(C.ExactlyOne, group + [C.Not(m)]) + + # If a package is installed, its dependencies must be as well + for prec in self.index.values(): + nkey = C.Not(self.to_sat_name(prec)) + for ms in self.ms_depends(prec): + # Virtual packages can't be installed, we ignore them + if not ms.name.startswith("__"): + C.Require(C.Or, nkey, self.push_MatchSpec(C, ms)) + + if log.isEnabledFor(DEBUG): + log.debug( + "gen_clauses returning with clause count: %d", C.get_clause_count() + ) + return C + + def generate_spec_constraints(self, C, specs): + result = [(self.push_MatchSpec(C, ms),) for ms in specs] + if log.isEnabledFor(DEBUG): + log.debug( + "generate_spec_constraints returning with clause count: %d", + C.get_clause_count(), + ) + return result + + def generate_feature_count(self, C): + result = { + self.push_MatchSpec(C, MatchSpec(track_features=name)): 1 + for name in self.trackers.keys() + } + if log.isEnabledFor(DEBUG): + log.debug( + "generate_feature_count returning with clause count: %d", + C.get_clause_count(), + ) + return result + + def generate_update_count(self, C, specs): + return { + "!" + ms.target: 1 for ms in specs if ms.target and C.from_name(ms.target) + } + + def generate_feature_metric(self, C): + eq = {} # a C.minimize() objective: dict[varname, coeff] + # Given a pair (prec, feature), assign a "1" score IF: + # - The prec is installed + # - The prec does NOT require the feature + # - At least one package in the group DOES require the feature + # - A package that tracks the feature is installed + for name, group in self.groups.items(): + prec_feats = {self.to_sat_name(prec): set(prec.features) for prec in group} + active_feats = set.union(*prec_feats.values()).intersection(self.trackers) + for feat in active_feats: + clause_id_for_feature = self.push_MatchSpec( + C, MatchSpec(track_features=feat) + ) + for prec_sat_name, features in prec_feats.items(): + if feat not in features: + feature_metric_id = self.to_feature_metric_id( + prec_sat_name, feat + ) + C.name_var( + C.And(prec_sat_name, clause_id_for_feature), + feature_metric_id, + ) + eq[feature_metric_id] = 1 + return eq + + def generate_removal_count(self, C, specs): + return {"!" 
+ self.push_MatchSpec(C, ms.name): 1 for ms in specs} + + def generate_install_count(self, C, specs): + return {self.push_MatchSpec(C, ms.name): 1 for ms in specs if ms.optional} + + def generate_package_count(self, C, missing): + return {self.push_MatchSpec(C, nm): 1 for nm in missing} + + def generate_version_metrics(self, C, specs, include0=False): + # each of these are weights saying how well packages match the specs + # format for each: a C.minimize() objective: dict[varname, coeff] + eqc = {} # channel + eqv = {} # version + eqb = {} # build number + eqa = {} # arch/noarch + eqt = {} # timestamp + + sdict = {} # dict[package_name, PackageRecord] + + for s in specs: + s = MatchSpec(s) # needed for testing + sdict.setdefault(s.name, []) + # # TODO: this block is important! can't leave it commented out + # rec = sdict.setdefault(s.name, []) + # if s.target: + # dist = Dist(s.target) + # if dist in self.index: + # if self.index[dist].get('priority', 0) < MAX_CHANNEL_PRIORITY: + # rec.append(dist) + + for name, targets in sdict.items(): + pkgs = [(self.version_key(p), p) for p in self.groups.get(name, [])] + pkey = None + # keep in mind that pkgs is already sorted according to version_key (a tuple, + # so composite sort key). Later entries in the list are, by definition, + # greater in some way, so simply comparing with != suffices. + for version_key, prec in pkgs: + if targets and any(prec == t for t in targets): + continue + if pkey is None: + ic = iv = ib = it = ia = 0 + # valid package, channel priority + elif pkey[0] != version_key[0] or pkey[1] != version_key[1]: + ic += 1 + iv = ib = it = ia = 0 + # version + elif pkey[2] != version_key[2]: + iv += 1 + ib = it = ia = 0 + # build number + elif pkey[3] != version_key[3]: + ib += 1 + it = ia = 0 + # arch/noarch + elif pkey[4] != version_key[4]: + ia += 1 + it = 0 + elif not self._solver_ignore_timestamps and pkey[5] != version_key[5]: + it += 1 + + prec_sat_name = self.to_sat_name(prec) + if ic or include0: + eqc[prec_sat_name] = ic + if iv or include0: + eqv[prec_sat_name] = iv + if ib or include0: + eqb[prec_sat_name] = ib + if ia or include0: + eqa[prec_sat_name] = ia + if it or include0: + eqt[prec_sat_name] = it + pkey = version_key + + return eqc, eqv, eqb, eqa, eqt + + def dependency_sort( + self, + must_have: dict[str, PackageRecord], + ) -> list[PackageRecord]: + assert isinstance(must_have, dict) + + digraph = {} # dict[str, set[dependent_package_names]] + for package_name, prec in must_have.items(): + if prec in self.index: + digraph[package_name] = {ms.name for ms in self.ms_depends(prec)} + + # There are currently at least three special cases to be aware of. + # 1. The `toposort()` function, called below, contains special case code to remove + # any circular dependency between python and pip. + # 2. conda/plan.py has special case code for menuinst + # Always link/unlink menuinst first/last on windows in case a subsequent + # package tries to import it to create/remove a shortcut + # 3. On windows, python noarch packages need an implicit dependency on conda added, if + # conda is in the list of packages for the environment. Python noarch packages + # that have entry points use conda's own conda.exe python entry point binary. If conda + # is going to be updated during an operation, the unlink / link order matters. + # See issue #6057. 
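+        # The block below implements special case 3: when conda is part of the
+        # graph on Windows, python noarch packages get an explicit "conda" edge so
+        # that toposort orders conda ahead of them.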
+ + if on_win and "conda" in digraph: + for package_name, dist in must_have.items(): + record = self.index.get(prec) + if hasattr(record, "noarch") and record.noarch == NoarchType.python: + digraph[package_name].add("conda") + + sorted_keys = toposort(digraph) + must_have = must_have.copy() + # Take all of the items in the sorted keys + # Don't fail if the key does not exist + result = [must_have.pop(key) for key in sorted_keys if key in must_have] + # Take any key that were not sorted + result.extend(must_have.values()) + return result + + def environment_is_consistent(self, installed): + log.debug("Checking if the current environment is consistent") + if not installed: + return None, [] + sat_name_map = {} # dict[sat_name, PackageRecord] + specs = [] + for prec in installed: + sat_name_map[self.to_sat_name(prec)] = prec + specs.append(MatchSpec(f"{prec.name} {prec.version} {prec.build}")) + r2 = Resolve({prec: prec for prec in installed}, True, channels=self.channels) + C = r2.gen_clauses() + constraints = r2.generate_spec_constraints(C, specs) + solution = C.sat(constraints) + return bool(solution) + + def get_conflicting_specs(self, specs, explicit_specs): + if not specs: + return () + + all_specs = set(specs) | set(explicit_specs) + reduced_index = self.get_reduced_index(all_specs) + + # Check if satisfiable + def mysat(specs, add_if=False): + constraints = r2.generate_spec_constraints(C, specs) + return C.sat(constraints, add_if) + + if reduced_index: + r2 = Resolve(reduced_index, True, channels=self.channels) + C = r2.gen_clauses() + solution = mysat(all_specs, True) + else: + solution = None + + if solution: + final_unsat_specs = () + elif context.unsatisfiable_hints: + r2 = Resolve(self.index, True, channels=self.channels) + C = r2.gen_clauses() + # This first result is just a single unsatisfiable core. There may be several. 
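+            # minimal_unsatisfiable_subset re-invokes mysat() on subsets of `specs`,
+            # so this branch can be slow; it is only reached when
+            # context.unsatisfiable_hints is enabled.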
+ final_unsat_specs = tuple( + minimal_unsatisfiable_subset( + specs, sat=mysat, explicit_specs=explicit_specs + ) + ) + else: + final_unsat_specs = None + return final_unsat_specs + + def bad_installed(self, installed, new_specs): + log.debug("Checking if the current environment is consistent") + if not installed: + return None, [] + sat_name_map = {} # dict[sat_name, PackageRecord] + specs = [] + for prec in installed: + sat_name_map[self.to_sat_name(prec)] = prec + specs.append(MatchSpec(f"{prec.name} {prec.version} {prec.build}")) + new_index = {prec: prec for prec in sat_name_map.values()} + name_map = {p.name: p for p in new_index} + if "python" in name_map and "pip" not in name_map: + python_prec = new_index[name_map["python"]] + if "pip" in python_prec.depends: + # strip pip dependency from python if not installed in environment + new_deps = [d for d in python_prec.depends if d != "pip"] + python_prec.depends = new_deps + r2 = Resolve(new_index, True, channels=self.channels) + C = r2.gen_clauses() + constraints = r2.generate_spec_constraints(C, specs) + solution = C.sat(constraints) + limit = xtra = None + if not solution or xtra: + + def get_(name, snames): + if name not in snames: + snames.add(name) + for fn in self.groups.get(name, []): + for ms in self.ms_depends(fn): + get_(ms.name, snames) + + # New addition: find the largest set of installed packages that + # are consistent with each other, and include those in the + # list of packages to maintain consistency with + snames = set() + eq_optional_c = r2.generate_removal_count(C, specs) + solution, _ = C.minimize(eq_optional_c, C.sat()) + snames.update( + sat_name_map[sat_name]["name"] + for sat_name in (C.from_index(s) for s in solution) + if sat_name and sat_name[0] != "!" and "@" not in sat_name + ) + # Existing behavior: keep all specs and their dependencies + for spec in new_specs: + get_(MatchSpec(spec).name, snames) + if len(snames) < len(sat_name_map): + limit = snames + xtra = [ + rec + for sat_name, rec in sat_name_map.items() + if rec["name"] not in snames + ] + log.debug( + "Limiting solver to the following packages: %s", ", ".join(limit) + ) + if xtra: + log.debug("Packages to be preserved: %s", xtra) + return limit, xtra + + def restore_bad(self, pkgs, preserve): + if preserve: + sdict = {prec.name: prec for prec in pkgs} + pkgs.extend(p for p in preserve if p.name not in sdict) + + def install_specs(self, specs, installed, update_deps=True): + specs = list(map(MatchSpec, specs)) + snames = {s.name for s in specs} + log.debug("Checking satisfiability of current install") + limit, preserve = self.bad_installed(installed, specs) + for prec in installed: + if prec not in self.index: + continue + name, version, build = prec.name, prec.version, prec.build + schannel = prec.channel.canonical_name + if name in snames or limit is not None and name not in limit: + continue + # If update_deps=True, set the target package in MatchSpec so that + # the solver can minimize the version change. If update_deps=False, + # fix the version and build so that no change is possible. 
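+            # Note that the frozen form below also pins the channel, so an otherwise
+            # identical build from a different channel would not satisfy it.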
+ if update_deps: + # TODO: fix target here + spec = MatchSpec(name=name, target=prec.dist_str()) + else: + spec = MatchSpec( + name=name, version=version, build=build, channel=schannel + ) + specs.insert(0, spec) + return tuple(specs), preserve + + def install(self, specs, installed=None, update_deps=True, returnall=False): + specs, preserve = self.install_specs(specs, installed or [], update_deps) + pkgs = [] + if specs: + pkgs = self.solve(specs, returnall=returnall, _remove=False) + self.restore_bad(pkgs, preserve) + return pkgs + + def remove_specs(self, specs, installed): + nspecs = [] + # There's an imperfect thing happening here. "specs" nominally contains + # a list of package names or track_feature values to be removed. But + # because of add_defaults_to_specs it may also contain version constraints + # like "python 2.7*", which are *not* asking for python to be removed. + # We need to separate these two kinds of specs here. + for s in map(MatchSpec, specs): + # Since '@' is an illegal version number, this ensures that all of + # these matches will never match an actual package. Combined with + # optional=True, this has the effect of forcing their removal. + if s._is_single(): + nspecs.append(MatchSpec(s, version="@", optional=True)) + else: + nspecs.append(MatchSpec(s, optional=True)) + snames = {s.name for s in nspecs if s.name} + limit, _ = self.bad_installed(installed, nspecs) + preserve = [] + for prec in installed: + nm, ver = prec.name, prec.version + if nm in snames: + continue + elif limit is not None: + preserve.append(prec) + else: + # TODO: fix target here + nspecs.append( + MatchSpec( + name=nm, + version=">=" + ver if ver else None, + optional=True, + target=prec.dist_str(), + ) + ) + return nspecs, preserve + + def remove(self, specs, installed): + specs, preserve = self.remove_specs(specs, installed) + pkgs = self.solve(specs, _remove=True) + self.restore_bad(pkgs, preserve) + return pkgs + + @time_recorder(module_name=__name__) + def solve( + self, + specs: list, + returnall: bool = False, + _remove=False, + specs_to_add=None, + history_specs=None, + should_retry_solve=False, + ) -> list[PackageRecord]: + if specs and not isinstance(specs[0], MatchSpec): + specs = tuple(MatchSpec(_) for _ in specs) + + specs = set(specs) + if log.isEnabledFor(DEBUG): + dlist = dashlist( + str("%i: %s target=%s optional=%s" % (i, s, s.target, s.optional)) + for i, s in enumerate(specs) + ) + log.debug("Solving for: %s", dlist) + + if not specs: + return () + + # Find the compliant packages + log.debug("Solve: Getting reduced index of compliant packages") + len0 = len(specs) + + reduced_index = self.get_reduced_index( + specs, exit_on_conflict=not context.unsatisfiable_hints + ) + if not reduced_index: + # something is intrinsically unsatisfiable - either not found or + # not the right version + not_found_packages = set() + wrong_version_packages = set() + for s in specs: + if not self.find_matches(s): + if s.name in self.groups: + wrong_version_packages.add(s) + else: + not_found_packages.add(s) + if not_found_packages: + raise ResolvePackageNotFound(not_found_packages) + elif wrong_version_packages: + raise UnsatisfiableError( + [[d] for d in wrong_version_packages], chains=False + ) + if should_retry_solve: + # We don't want to call find_conflicts until our last try. 
+ # This jumps back out to conda/cli/install.py, where the + # retries happen + raise UnsatisfiableError({}) + else: + self.find_conflicts(specs, specs_to_add, history_specs) + + # Check if satisfiable + log.debug("Solve: determining satisfiability") + + def mysat(specs, add_if=False): + constraints = r2.generate_spec_constraints(C, specs) + return C.sat(constraints, add_if) + + # Return a solution of packages + def clean(sol): + return [ + q + for q in (C.from_index(s) for s in sol) + if q and q[0] != "!" and "@" not in q + ] + + def is_converged(solution): + """Determine if the SAT problem has converged to a single solution. + + This is determined by testing for a SAT solution with the current + clause set and a clause in which at least one of the packages in + the current solution is excluded. If a solution exists the problem + has not converged as multiple solutions still exist. + """ + psolution = clean(solution) + nclause = tuple(C.Not(C.from_name(q)) for q in psolution) + if C.sat((nclause,), includeIf=False) is None: + return True + return False + + r2 = Resolve(reduced_index, True, channels=self.channels) + C = r2.gen_clauses() + solution = mysat(specs, True) + if not solution: + if should_retry_solve: + # we don't want to call find_conflicts until our last try + raise UnsatisfiableError({}) + else: + self.find_conflicts(specs, specs_to_add, history_specs) + + speco = [] # optional packages + specr = [] # requested packages + speca = [] # all other packages + specm = set(r2.groups) # missing from specs + for k, s in enumerate(specs): + if s.name in specm: + specm.remove(s.name) + if not s.optional: + (speca if s.target or k >= len0 else specr).append(s) + elif any(r2.find_matches(s)): + s = MatchSpec(s.name, optional=True, target=s.target) + speco.append(s) + speca.append(s) + speca.extend(MatchSpec(s) for s in specm) + + if log.isEnabledFor(DEBUG): + log.debug("Requested specs: %s", dashlist(sorted(str(s) for s in specr))) + log.debug("Optional specs: %s", dashlist(sorted(str(s) for s in speco))) + log.debug("All other specs: %s", dashlist(sorted(str(s) for s in speca))) + log.debug("missing specs: %s", dashlist(sorted(str(s) for s in specm))) + + # Removed packages: minimize count + log.debug("Solve: minimize removed packages") + if _remove: + eq_optional_c = r2.generate_removal_count(C, speco) + solution, obj7 = C.minimize(eq_optional_c, solution) + log.debug("Package removal metric: %d", obj7) + + # Requested packages: maximize versions + log.debug("Solve: maximize versions of requested packages") + eq_req_c, eq_req_v, eq_req_b, eq_req_a, eq_req_t = r2.generate_version_metrics( + C, specr + ) + solution, obj3a = C.minimize(eq_req_c, solution) + solution, obj3 = C.minimize(eq_req_v, solution) + log.debug("Initial package channel/version metric: %d/%d", obj3a, obj3) + + # Track features: minimize feature count + log.debug("Solve: minimize track_feature count") + eq_feature_count = r2.generate_feature_count(C) + solution, obj1 = C.minimize(eq_feature_count, solution) + log.debug("Track feature count: %d", obj1) + + # Featured packages: minimize number of featureless packages + # installed when a featured alternative is feasible. + # For example, package name foo exists with two built packages. One with + # 'track_features: 'feat1', and one with 'track_features': 'feat2'. + # The previous "Track features" minimization pass has chosen 'feat1' for the + # environment, but not 'feat2'. In this case, the 'feat2' version of foo is + # considered "featureless." 
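+        # generate_feature_metric() scores 1 for every selected package that lacks a
+        # feature which another build of the same name provides and which an installed
+        # package tracks; minimizing that score steers the solver toward featured builds.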
+ eq_feature_metric = r2.generate_feature_metric(C) + solution, obj2 = C.minimize(eq_feature_metric, solution) + log.debug("Package misfeature count: %d", obj2) + + # Requested packages: maximize builds + log.debug("Solve: maximize build numbers of requested packages") + solution, obj4 = C.minimize(eq_req_b, solution) + log.debug("Initial package build metric: %d", obj4) + + # prefer arch packages where available for requested specs + log.debug("Solve: prefer arch over noarch for requested packages") + solution, noarch_obj = C.minimize(eq_req_a, solution) + log.debug("Noarch metric: %d", noarch_obj) + + # Optional installations: minimize count + if not _remove: + log.debug("Solve: minimize number of optional installations") + eq_optional_install = r2.generate_install_count(C, speco) + solution, obj49 = C.minimize(eq_optional_install, solution) + log.debug("Optional package install metric: %d", obj49) + + # Dependencies: minimize the number of packages that need upgrading + log.debug("Solve: minimize number of necessary upgrades") + eq_u = r2.generate_update_count(C, speca) + solution, obj50 = C.minimize(eq_u, solution) + log.debug("Dependency update count: %d", obj50) + + # Remaining packages: maximize versions, then builds + log.debug( + "Solve: maximize versions and builds of indirect dependencies. " + "Prefer arch over noarch where equivalent." + ) + eq_c, eq_v, eq_b, eq_a, eq_t = r2.generate_version_metrics(C, speca) + solution, obj5a = C.minimize(eq_c, solution) + solution, obj5 = C.minimize(eq_v, solution) + solution, obj6 = C.minimize(eq_b, solution) + solution, obj6a = C.minimize(eq_a, solution) + log.debug( + "Additional package channel/version/build/noarch metrics: %d/%d/%d/%d", + obj5a, + obj5, + obj6, + obj6a, + ) + + # Prune unnecessary packages + log.debug("Solve: prune unnecessary packages") + eq_c = r2.generate_package_count(C, specm) + solution, obj7 = C.minimize(eq_c, solution, trymax=True) + log.debug("Weak dependency count: %d", obj7) + + if not is_converged(solution): + # Maximize timestamps + eq_t.update(eq_req_t) + solution, obj6t = C.minimize(eq_t, solution) + log.debug("Timestamp metric: %d", obj6t) + + log.debug("Looking for alternate solutions") + nsol = 1 + psolutions = [] + psolution = clean(solution) + psolutions.append(psolution) + while True: + nclause = tuple(C.Not(C.from_name(q)) for q in psolution) + solution = C.sat((nclause,), True) + if solution is None: + break + nsol += 1 + if nsol > 10: + log.debug("Too many solutions; terminating") + break + psolution = clean(solution) + psolutions.append(psolution) + + if nsol > 1: + psols2 = list(map(set, psolutions)) + common = set.intersection(*psols2) + diffs = [sorted(set(sol) - common) for sol in psols2] + if not context.json: + stdoutlog.info( + "\nWarning: {} possible package resolutions " + "(only showing differing packages):{}{}".format( + ">10" if nsol > 10 else nsol, + dashlist(", ".join(diff) for diff in diffs), + "\n ... 
and others" if nsol > 10 else "", + ) + ) + + # def stripfeat(sol): + # return sol.split('[')[0] + + new_index = {self.to_sat_name(prec): prec for prec in self.index.values()} + + if returnall: + if len(psolutions) > 1: + raise RuntimeError() + # TODO: clean up this mess + # return [sorted(Dist(stripfeat(dname)) for dname in psol) for psol in psolutions] + # return [sorted((new_index[sat_name] for sat_name in psol), key=lambda x: x.name) + # for psol in psolutions] + + # return sorted(Dist(stripfeat(dname)) for dname in psolutions[0]) + return sorted( + (new_index[sat_name] for sat_name in psolutions[0]), key=lambda x: x.name + ) diff --git a/conda_classic_solver/solve.py b/conda_classic_solver/solve.py new file mode 100644 index 0000000..2ce3737 --- /dev/null +++ b/conda_classic_solver/solve.py @@ -0,0 +1,1225 @@ +# Copyright (C) 2012 Anaconda, Inc +# Copyright (C) 2023 conda +# SPDX-License-Identifier: BSD-3-Clause +"""The classic solver implementation.""" + +from __future__ import annotations + +import copy +import sys +from itertools import chain +from logging import DEBUG, getLogger +from textwrap import dedent +from typing import TYPE_CHECKING + +from boltons.setutils import IndexedSet +from conda import CondaError +from conda import __version__ as CONDA_VERSION +from conda.auxlib.decorators import memoizedproperty +from conda.auxlib.ish import dals +from conda.base.constants import ( + REPODATA_FN, + UNKNOWN_CHANNEL, + DepsModifier, + UpdateModifier, +) +from conda.base.context import context +from conda.common.constants import NULL, TRACE +from conda.common.io import Spinner, dashlist, time_recorder +from conda.common.iterators import groupby_to_dict as groupby +from conda.common.path import get_major_minor_version, paths_equal +from conda.core.index import _supplement_index_with_system, get_reduced_index +from conda.core.prefix_data import PrefixData +from conda.core.solve import Solver, get_pinned_specs +from conda.core.subdir_data import SubdirData +from conda.exceptions import ( + PackagesNotFoundError, + SpecsConfigurationConflictError, + UnsatisfiableError, +) +from conda.history import History +from conda.models.channel import Channel +from conda.models.match_spec import MatchSpec +from conda.models.prefix_graph import PrefixGraph +from conda.models.version import VersionOrder + +try: + from frozendict import frozendict +except ImportError: + from conda.auxlib.collection import frozendict + +from .resolve import Resolve + +if TYPE_CHECKING: + from conda.models.records import PackageRecord + +log = getLogger(__name__) + + +class ClassicSolver(Solver): + """ + High-level logic for the 'classic' (pycosat) solver in conda. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._index = None + self._r = None + self._prepared = False + self._pool_cache = {} + + def solve_final_state( + self, + update_modifier=NULL, + deps_modifier=NULL, + prune=NULL, + ignore_pinned=NULL, + force_remove=NULL, + should_retry_solve=False, + ) -> tuple[PackageRecord]: + """Gives the final, solved state of the environment. + + Args: + update_modifier (UpdateModifier): + An optional flag directing how updates are handled regarding packages already + existing in the environment. + + deps_modifier (DepsModifier): + An optional flag indicating special solver handling for dependencies. 
The + default solver behavior is to be as conservative as possible with dependency + updates (in the case the dependency already exists in the environment), while + still ensuring all dependencies are satisfied. Options include + * NO_DEPS + * ONLY_DEPS + * UPDATE_DEPS + * UPDATE_DEPS_ONLY_DEPS + * FREEZE_INSTALLED + prune (bool): + If ``True``, the solution will not contain packages that were + previously brought into the environment as dependencies but are no longer + required as dependencies and are not user-requested. + ignore_pinned (bool): + If ``True``, the solution will ignore pinned package configuration + for the prefix. + force_remove (bool): + Forces removal of a package without removing packages that depend on it. + should_retry_solve (bool): + Indicates whether this solve will be retried. This allows us to control + whether to call find_conflicts (slow) in ssc.r.solve + + Returns: + tuple[PackageRecord]: + In sorted dependency order from roots to leaves, the package references for + the solved state of the environment. + + """ + if prune and update_modifier == UpdateModifier.FREEZE_INSTALLED: + update_modifier = NULL + if update_modifier is NULL: + update_modifier = context.update_modifier + else: + update_modifier = UpdateModifier(str(update_modifier).lower()) + if deps_modifier is NULL: + deps_modifier = context.deps_modifier + else: + deps_modifier = DepsModifier(str(deps_modifier).lower()) + ignore_pinned = ( + context.ignore_pinned if ignore_pinned is NULL else ignore_pinned + ) + force_remove = context.force_remove if force_remove is NULL else force_remove + + log.debug( + "solving prefix %s\n" + " specs_to_remove: %s\n" + " specs_to_add: %s\n" + " prune: %s", + self.prefix, + self.specs_to_remove, + self.specs_to_add, + prune, + ) + + retrying = hasattr(self, "ssc") + + if not retrying: + ssc = SolverStateContainer( + self.prefix, + update_modifier, + deps_modifier, + prune, + ignore_pinned, + force_remove, + should_retry_solve, + ) + self.ssc = ssc + else: + ssc = self.ssc + ssc.update_modifier = update_modifier + ssc.deps_modifier = deps_modifier + ssc.should_retry_solve = should_retry_solve + + # force_remove is a special case where we return early + if self.specs_to_remove and force_remove: + if self.specs_to_add: + raise NotImplementedError() + solution = tuple( + prec + for prec in ssc.solution_precs + if not any(spec.match(prec) for spec in self.specs_to_remove) + ) + return IndexedSet(PrefixGraph(solution).graph) + + # Check if specs are satisfied by current environment. If they are, exit early. + if ( + update_modifier == UpdateModifier.SPECS_SATISFIED_SKIP_SOLVE + and not self.specs_to_remove + and not prune + ): + for spec in self.specs_to_add: + if not next(ssc.prefix_data.query(spec), None): + break + else: + # All specs match a package in the current environment. + # Return early, with a solution that should just be PrefixData().iter_records() + return IndexedSet(PrefixGraph(ssc.solution_precs).graph) + + if not ssc.r: + with Spinner( + f"Collecting package metadata ({self._repodata_fn})", + not context.verbose and not context.quiet and not retrying, + context.json, + ): + ssc = self._collect_all_metadata(ssc) + + if should_retry_solve and update_modifier == UpdateModifier.FREEZE_INSTALLED: + fail_message = ( + "unsuccessful initial attempt using frozen solve. 
Retrying" + " with flexible solve.\n" + ) + elif self._repodata_fn != REPODATA_FN: + fail_message = ( + f"unsuccessful attempt using repodata from {self._repodata_fn}, retrying" + " with next repodata source.\n" + ) + else: + fail_message = "failed\n" + + with Spinner( + "Solving environment", + not context.verbose and not context.quiet, + context.json, + fail_message=fail_message, + ): + ssc = self._remove_specs(ssc) + ssc = self._add_specs(ssc) + solution_precs = copy.copy(ssc.solution_precs) + + pre_packages = self.get_request_package_in_solution( + ssc.solution_precs, ssc.specs_map + ) + ssc = self._find_inconsistent_packages(ssc) + # this will prune precs that are deps of precs that get removed due to conflicts + ssc = self._run_sat(ssc) + post_packages = self.get_request_package_in_solution( + ssc.solution_precs, ssc.specs_map + ) + + if ssc.update_modifier == UpdateModifier.UPDATE_SPECS: + constrained = self.get_constrained_packages( + pre_packages, post_packages, ssc.index.keys() + ) + if len(constrained) > 0: + for spec in constrained: + self.determine_constricting_specs(spec, ssc.solution_precs) + + # if there were any conflicts, we need to add their orphaned deps back in + if ssc.add_back_map: + orphan_precs = ( + set(solution_precs) + - set(ssc.solution_precs) + - set(ssc.add_back_map) + ) + solution_prec_names = [_.name for _ in ssc.solution_precs] + ssc.solution_precs.extend( + [ + _ + for _ in orphan_precs + if _.name not in ssc.specs_map + and _.name not in solution_prec_names + ] + ) + + ssc = self._post_sat_handling(ssc) + + time_recorder.log_totals() + + ssc.solution_precs = IndexedSet(PrefixGraph(ssc.solution_precs).graph) + log.debug( + "solved prefix %s\n solved_linked_dists:\n %s\n", + self.prefix, + "\n ".join(prec.dist_str() for prec in ssc.solution_precs), + ) + + return ssc.solution_precs + + def determine_constricting_specs(self, spec, solution_precs): + highest_version = [ + VersionOrder(sp.version) for sp in solution_precs if sp.name == spec.name + ][0] + constricting = [] + for prec in solution_precs: + if any(j for j in prec.depends if spec.name in j): + for dep in prec.depends: + m_dep = MatchSpec(dep) + if ( + m_dep.name == spec.name + and m_dep.version is not None + and (m_dep.version.exact_value or "<" in m_dep.version.spec) + ): + if "," in m_dep.version.spec: + constricting.extend( + [ + (prec.name, MatchSpec(f"{m_dep.name} {v}")) + for v in m_dep.version.tup + if "<" in v.spec + ] + ) + else: + constricting.append((prec.name, m_dep)) + + hard_constricting = [ + i for i in constricting if i[1].version.matcher_vo <= highest_version + ] + if len(hard_constricting) == 0: + return None + + print(f"\n\nUpdating {spec.name} is constricted by \n") + for const in hard_constricting: + print(f"{const[0]} -> requires {const[1]}") + print( + "\nIf you are sure you want an update of your package either try " + "`conda update --all` or install a specific version of the " + "package you want using `conda install =`\n" + ) + return hard_constricting + + def get_request_package_in_solution(self, solution_precs, specs_map): + requested_packages = {} + for pkg in self.specs_to_add: + update_pkg_request = pkg.name + + requested_packages[update_pkg_request] = [ + (i.name, str(i.version)) + for i in solution_precs + if i.name == update_pkg_request and i.version is not None + ] + requested_packages[update_pkg_request].extend( + [ + (v.name, str(v.version)) + for k, v in specs_map.items() + if k == update_pkg_request and v.version is not None + ] + ) + + return 
requested_packages + + def get_constrained_packages(self, pre_packages, post_packages, index_keys): + update_constrained = set() + + def empty_package_list(pkg): + for k, v in pkg.items(): + if len(v) == 0: + return True + return False + + if empty_package_list(pre_packages) or empty_package_list(post_packages): + return update_constrained + + for pkg in self.specs_to_add: + if pkg.name.startswith("__"): # ignore virtual packages + continue + current_version = max(i[1] for i in pre_packages[pkg.name]) + if current_version == max( + i.version for i in index_keys if i.name == pkg.name + ): + continue + else: + if post_packages == pre_packages: + update_constrained = update_constrained | {pkg} + return update_constrained + + @time_recorder(module_name=__name__) + def _collect_all_metadata(self, ssc): + if ssc.prune: + # When pruning DO NOT consider history of already installed packages when solving. + prepared_specs = {*self.specs_to_remove, *self.specs_to_add} + else: + # add in historically-requested specs + ssc.specs_map.update(ssc.specs_from_history_map) + + # these are things that we want to keep even if they're not explicitly specified. This + # is to compensate for older installers not recording these appropriately for them + # to be preserved. + for pkg_name in ( + "anaconda", + "conda", + "conda-build", + "python.app", + "console_shortcut", + "powershell_shortcut", + ): + if pkg_name not in ssc.specs_map and ssc.prefix_data.get( + pkg_name, None + ): + ssc.specs_map[pkg_name] = MatchSpec(pkg_name) + + # Add virtual packages so they are taken into account by the solver + virtual_pkg_index = {} + _supplement_index_with_system(virtual_pkg_index) + virtual_pkgs = [p.name for p in virtual_pkg_index.keys()] + for virtual_pkgs_name in virtual_pkgs: + if virtual_pkgs_name not in ssc.specs_map: + ssc.specs_map[virtual_pkgs_name] = MatchSpec(virtual_pkgs_name) + + for prec in ssc.prefix_data.iter_records(): + # first check: add everything if we have no history to work with. + # This happens with "update --all", for example. + # + # second check: add in aggressively updated packages + # + # third check: add in foreign stuff (e.g. from pip) into the specs + # map. We add it so that it can be left alone more. This is a + # declaration that it is manually installed, much like the + # history map. It may still be replaced if it is in conflict, + # but it is not just an indirect dep that can be pruned. + if ( + not ssc.specs_from_history_map + or MatchSpec(prec.name) in context.aggressive_update_packages + or prec.subdir == "pypi" + ): + ssc.specs_map.update({prec.name: MatchSpec(prec.name)}) + + prepared_specs = { + *self.specs_to_remove, + *self.specs_to_add, + *ssc.specs_from_history_map.values(), + } + + index, r = self._prepare(prepared_specs) + ssc.set_repository_metadata(index, r) + return ssc + + def _remove_specs(self, ssc): + if self.specs_to_remove: + # In a previous implementation, we invoked SAT here via `r.remove()` to help with + # spec removal, and then later invoking SAT again via `r.solve()`. Rather than invoking + # SAT for spec removal determination, we can use the PrefixGraph and simple tree + # traversal if we're careful about how we handle features. We still invoke sat via + # `r.solve()` later. 
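+            # Collect the feature names being removed up front; they are used below to
+            # decide whether a removed record's history spec should be kept in specs_map
+            # with only its "features" component stripped.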
+ _track_fts_specs = ( + spec for spec in self.specs_to_remove if "track_features" in spec + ) + feature_names = set( + chain.from_iterable( + spec.get_raw_value("track_features") for spec in _track_fts_specs + ) + ) + graph = PrefixGraph(ssc.solution_precs, ssc.specs_map.values()) + + all_removed_records = [] + no_removed_records_specs = [] + for spec in self.specs_to_remove: + # If the spec was a track_features spec, then we need to also remove every + # package with a feature that matches the track_feature. The + # `graph.remove_spec()` method handles that for us. + log.log(TRACE, "using PrefixGraph to remove records for %s", spec) + removed_records = graph.remove_spec(spec) + if removed_records: + all_removed_records.extend(removed_records) + else: + no_removed_records_specs.append(spec) + + # ensure that each spec in specs_to_remove is actually associated with removed records + unmatched_specs_to_remove = tuple( + spec + for spec in no_removed_records_specs + if not any(spec.match(rec) for rec in all_removed_records) + ) + if unmatched_specs_to_remove: + raise PackagesNotFoundError( + tuple(sorted(str(s) for s in unmatched_specs_to_remove)) + ) + + for rec in all_removed_records: + # We keep specs (minus the feature part) for the non provides_features packages + # if they're in the history specs. Otherwise, we pop them from the specs_map. + rec_has_a_feature = set(rec.features or ()) & feature_names + if rec_has_a_feature and rec.name in ssc.specs_from_history_map: + spec = ssc.specs_map.get(rec.name, MatchSpec(rec.name)) + spec._match_components = frozendict( + { + key: value + for key, value in spec._match_components.items() + if key != "features" + } + ) + ssc.specs_map[spec.name] = spec + else: + ssc.specs_map.pop(rec.name, None) + + ssc.solution_precs = tuple(graph.graph) + return ssc + + @time_recorder(module_name=__name__) + def _find_inconsistent_packages(self, ssc): + # We handle as best as possible environments in inconsistent states. To do this, + # we remove now from consideration the set of packages causing inconsistencies, + # and then we add them back in following the main SAT call. + _, inconsistent_precs = ssc.r.bad_installed(ssc.solution_precs, ()) + if inconsistent_precs: + # It is possible that the package metadata is incorrect, for example when + # un-patched metadata from the Miniconda or Anaconda installer is present, see: + # https://github.com/conda/conda/issues/8076 + # Update the metadata with information from the index and see if that makes the + # environment consistent. + ssc.solution_precs = tuple(ssc.index.get(k, k) for k in ssc.solution_precs) + _, inconsistent_precs = ssc.r.bad_installed(ssc.solution_precs, ()) + if log.isEnabledFor(DEBUG): + log.debug( + "inconsistent precs: %s", + dashlist(inconsistent_precs) if inconsistent_precs else "None", + ) + if inconsistent_precs: + print( + dedent( + """ + The environment is inconsistent, please check the package plan carefully + The following packages are causing the inconsistency:""" + ), + file=sys.stderr, + ) + print(dashlist(inconsistent_precs), file=sys.stderr) + for prec in inconsistent_precs: + # pop and save matching spec in specs_map + spec = ssc.specs_map.pop(prec.name, None) + ssc.add_back_map[prec.name] = (prec, spec) + # let the package float. 
This is essential to keep the package's dependencies + # in the solution + ssc.specs_map[prec.name] = MatchSpec(prec.name, target=prec.dist_str()) + # inconsistent environments should maintain the python version + # unless explicitly requested by the user. This along with the logic in + # _add_specs maintains the major.minor version + if prec.name == "python" and spec: + ssc.specs_map["python"] = spec + ssc.solution_precs = tuple( + prec for prec in ssc.solution_precs if prec not in inconsistent_precs + ) + return ssc + + def _package_has_updates(self, ssc, spec, installed_pool): + installed_prec = installed_pool.get(spec.name) + has_update = False + + if installed_prec: + installed_prec = installed_prec[0] + for prec in ssc.r.groups.get(spec.name, []): + if prec.version > installed_prec.version: + has_update = True + break + elif ( + prec.version == installed_prec.version + and prec.build_number > installed_prec.build_number + ): + has_update = True + break + # let conda determine the latest version by just adding a name spec + return ( + MatchSpec(spec.name, version=prec.version, build_number=prec.build_number) + if has_update + else spec + ) + + def _should_freeze( + self, ssc, target_prec, conflict_specs, explicit_pool, installed_pool + ): + # never, ever freeze anything if we have no history. + if not ssc.specs_from_history_map: + return False + # never freeze if not in FREEZE_INSTALLED mode + if ssc.update_modifier != UpdateModifier.FREEZE_INSTALLED: + return False + + # if all package specs have overlapping package choices (satisfiable in at least one way) + pkg_name = target_prec.name + no_conflict = pkg_name not in conflict_specs and ( + pkg_name not in explicit_pool or target_prec in explicit_pool[pkg_name] + ) + + return no_conflict + + def _add_specs(self, ssc): + # For the remaining specs in specs_map, add target to each spec. `target` is a reference + # to the package currently existing in the environment. Setting target instructs the + # solver to not disturb that package if it's not necessary. + # If the spec.name is being modified by inclusion in specs_to_add, we don't set `target`, + # since we *want* the solver to modify/update that package. + # + # TLDR: when working with MatchSpec objects, + # - to minimize the version change, set MatchSpec(name=name, target=prec.dist_str()) + # - to freeze the package, set all the components of MatchSpec individually + + installed_pool = groupby(lambda x: x.name, ssc.prefix_data.iter_records()) + + # the only things we should consider freezing are things that don't conflict with the new + # specs being added. + explicit_pool = ssc.r._get_package_pool(self.specs_to_add) + if ssc.prune: + # Ignore installed specs on prune. + installed_specs = () + else: + installed_specs = [ + record.to_match_spec() for record in ssc.prefix_data.iter_records() + ] + + conflict_specs = ( + ssc.r.get_conflicting_specs(installed_specs, self.specs_to_add) or tuple() + ) + conflict_specs = {spec.name for spec in conflict_specs} + + for pkg_name, spec in ssc.specs_map.items(): + matches_for_spec = tuple( + prec for prec in ssc.solution_precs if spec.match(prec) + ) + if matches_for_spec: + if len(matches_for_spec) != 1: + raise CondaError( + dals( + """ + Conda encountered an error with your environment. Please report an issue + at https://github.com/conda/conda/issues. In your report, please include + the output of 'conda info' and 'conda list' for the active environment, along + with the command you invoked that resulted in this error. 
+ pkg_name: %s + spec: %s + matches_for_spec: %s + """ + ) + % ( + pkg_name, + spec, + dashlist((str(s) for s in matches_for_spec), indent=4), + ) + ) + target_prec = matches_for_spec[0] + if target_prec.is_unmanageable: + ssc.specs_map[pkg_name] = target_prec.to_match_spec() + elif MatchSpec(pkg_name) in context.aggressive_update_packages: + ssc.specs_map[pkg_name] = MatchSpec(pkg_name) + elif self._should_freeze( + ssc, target_prec, conflict_specs, explicit_pool, installed_pool + ): + ssc.specs_map[pkg_name] = target_prec.to_match_spec() + elif pkg_name in ssc.specs_from_history_map: + ssc.specs_map[pkg_name] = MatchSpec( + ssc.specs_from_history_map[pkg_name], + target=target_prec.dist_str(), + ) + else: + ssc.specs_map[pkg_name] = MatchSpec( + pkg_name, target=target_prec.dist_str() + ) + + pin_overrides = set() + for s in ssc.pinned_specs: + if s.name in explicit_pool: + if s.name not in self.specs_to_add_names and not ssc.ignore_pinned: + ssc.specs_map[s.name] = MatchSpec(s, optional=False) + elif explicit_pool[s.name] & ssc.r._get_package_pool([s]).get( + s.name, set() + ): + ssc.specs_map[s.name] = MatchSpec(s, optional=False) + pin_overrides.add(s.name) + else: + log.warning( + "pinned spec %s conflicts with explicit specs. " + "Overriding pinned spec.", + s, + ) + + # we want to freeze any packages in the env that are not conflicts, so that the + # solve goes faster. This is kind of like an iterative solve, except rather + # than just providing a starting place, we are preventing some solutions. + # A true iterative solve would probably be better in terms of reaching the + # optimal output all the time. It would probably also get rid of the need + # to retry with an unfrozen (UPDATE_SPECS) solve. + if ssc.update_modifier == UpdateModifier.FREEZE_INSTALLED: + precs = [ + _ for _ in ssc.prefix_data.iter_records() if _.name not in ssc.specs_map + ] + for prec in precs: + if prec.name not in conflict_specs: + ssc.specs_map[prec.name] = prec.to_match_spec() + else: + ssc.specs_map[prec.name] = MatchSpec( + prec.name, target=prec.to_match_spec(), optional=True + ) + log.debug("specs_map with targets: %s", ssc.specs_map) + + # If we're in UPDATE_ALL mode, we need to drop all the constraints attached to specs, + # so they can all float and the solver can find the most up-to-date solution. In the case + # of UPDATE_ALL, `specs_map` wasn't initialized with packages from the current environment, + # but *only* historically-requested specs. This lets UPDATE_ALL drop dependencies if + # they're no longer needed, and their presence would otherwise prevent the updated solution + # the user most likely wants. + if ssc.update_modifier == UpdateModifier.UPDATE_ALL: + # history is preferable because it has explicitly installed stuff in it. + # that simplifies our solution. + if ssc.specs_from_history_map: + ssc.specs_map = dict( + (spec, MatchSpec(spec)) + if MatchSpec(spec).name not in (_.name for _ in ssc.pinned_specs) + else (MatchSpec(spec).name, ssc.specs_map[MatchSpec(spec).name]) + for spec in ssc.specs_from_history_map + ) + for prec in ssc.prefix_data.iter_records(): + # treat pip-installed stuff as explicitly installed, too. 
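+                    # PrefixData reports pip-installed distributions with
+                    # subdir == "pypi", which is what the check below keys on.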
+ if prec.subdir == "pypi": + ssc.specs_map.update({prec.name: MatchSpec(prec.name)}) + else: + ssc.specs_map = { + prec.name: ( + MatchSpec(prec.name) + if prec.name not in (_.name for _ in ssc.pinned_specs) + else ssc.specs_map[prec.name] + ) + for prec in ssc.prefix_data.iter_records() + } + + # ensure that our self.specs_to_add are not being held back by packages in the env. + # This factors in pins and also ignores specs from the history. It is unfreezing only + # for the indirect specs that otherwise conflict with update of the immediate request + elif ssc.update_modifier == UpdateModifier.UPDATE_SPECS: + skip = lambda x: ( + ( + x.name not in pin_overrides + and any(x.name == _.name for _ in ssc.pinned_specs) + and not ssc.ignore_pinned + ) + or x.name in ssc.specs_from_history_map + ) + + specs_to_add = tuple( + self._package_has_updates(ssc, _, installed_pool) + for _ in self.specs_to_add + if not skip(_) + ) + # the index is sorted, so the first record here gives us what we want. + conflicts = ssc.r.get_conflicting_specs( + tuple(MatchSpec(_) for _ in ssc.specs_map.values()), specs_to_add + ) + for conflict in conflicts or (): + # neuter the spec due to a conflict + if ( + conflict.name in ssc.specs_map + and ( + # add optional because any pinned specs will include it + MatchSpec(conflict, optional=True) not in ssc.pinned_specs + or ssc.ignore_pinned + ) + and conflict.name not in ssc.specs_from_history_map + ): + ssc.specs_map[conflict.name] = MatchSpec(conflict.name) + + # As a business rule, we never want to update python beyond the current minor version, + # unless that's requested explicitly by the user (which we actively discourage). + py_in_prefix = any(_.name == "python" for _ in ssc.solution_precs) + py_requested_explicitly = any(s.name == "python" for s in self.specs_to_add) + if py_in_prefix and not py_requested_explicitly: + python_prefix_rec = ssc.prefix_data.get("python") + freeze_installed = ssc.update_modifier == UpdateModifier.FREEZE_INSTALLED + if "python" not in conflict_specs and freeze_installed: + ssc.specs_map["python"] = python_prefix_rec.to_match_spec() + else: + # will our prefix record conflict with any explicit spec? If so, don't add + # anything here - let python float when it hasn't been explicitly specified + python_spec = ssc.specs_map.get("python", MatchSpec("python")) + if not python_spec.get("version"): + pinned_version = ( + get_major_minor_version(python_prefix_rec.version) + ".*" + ) + python_spec = MatchSpec(python_spec, version=pinned_version) + + spec_set = (python_spec,) + tuple(self.specs_to_add) + if ssc.r.get_conflicting_specs(spec_set, self.specs_to_add): + if self._command != "install" or ( + self._repodata_fn == REPODATA_FN + and (not ssc.should_retry_solve or not freeze_installed) + ): + # raises a hopefully helpful error message + ssc.r.find_conflicts(spec_set) + else: + raise UnsatisfiableError({}) + ssc.specs_map["python"] = python_spec + + # For the aggressive_update_packages configuration parameter, we strip any target + # that's been set. 
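+        # context.aggressive_update_packages is a tuple of MatchSpecs, so assigning one
+        # directly replaces the existing specs_map entry along with any target or other
+        # constraints attached above; this is skipped entirely in offline mode.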
+ if not context.offline: + for spec in context.aggressive_update_packages: + if spec.name in ssc.specs_map: + ssc.specs_map[spec.name] = spec + + # add in explicitly requested specs from specs_to_add + # this overrides any name-matching spec already in the spec map + ssc.specs_map.update( + (s.name, s) for s in self.specs_to_add if s.name not in pin_overrides + ) + + # As a business rule, we never want to downgrade conda below the current version, + # unless that's requested explicitly by the user (which we actively discourage). + if "conda" in ssc.specs_map and paths_equal(self.prefix, context.conda_prefix): + conda_prefix_rec = ssc.prefix_data.get("conda") + if conda_prefix_rec: + version_req = f">={conda_prefix_rec.version}" + conda_requested_explicitly = any( + s.name == "conda" for s in self.specs_to_add + ) + conda_spec = ssc.specs_map["conda"] + conda_in_specs_to_add_version = ssc.specs_map.get("conda", {}).get( + "version" + ) + if not conda_in_specs_to_add_version: + conda_spec = MatchSpec(conda_spec, version=version_req) + if context.auto_update_conda and not conda_requested_explicitly: + conda_spec = MatchSpec("conda", version=version_req, target=None) + ssc.specs_map["conda"] = conda_spec + + return ssc + + @time_recorder(module_name=__name__) + def _run_sat(self, ssc): + final_environment_specs = IndexedSet( + ( + *ssc.specs_map.values(), + *ssc.track_features_specs, + # pinned specs removed here - added to specs_map in _add_specs instead + ) + ) + + absent_specs = [s for s in ssc.specs_map.values() if not ssc.r.find_matches(s)] + if absent_specs: + raise PackagesNotFoundError(absent_specs) + + # We've previously checked `solution` for consistency (which at that point was the + # pre-solve state of the environment). Now we check our compiled set of + # `final_environment_specs` for the possibility of a solution. If there are conflicts, + # we can often avoid them by neutering specs that have a target (e.g. removing version + # constraint) and also making them optional. The result here will be less cases of + # `UnsatisfiableError` handed to users, at the cost of more packages being modified + # or removed from the environment. + # + # get_conflicting_specs() returns a "minimal unsatisfiable subset" which + # may not be the only unsatisfiable subset. We may have to call get_conflicting_specs() + # several times, each time making modifications to loosen constraints. + + conflicting_specs = set( + ssc.r.get_conflicting_specs( + tuple(final_environment_specs), self.specs_to_add + ) + or [] + ) + while conflicting_specs: + specs_modified = False + if log.isEnabledFor(DEBUG): + log.debug( + "conflicting specs: %s", + dashlist(s.target or s for s in conflicting_specs), + ) + + # Are all conflicting specs in specs_map? If not, that means they're in + # track_features_specs or pinned_specs, which we should raise an error on. + specs_map_set = set(ssc.specs_map.values()) + grouped_specs = groupby(lambda s: s in specs_map_set, conflicting_specs) + # force optional to true. 
This is what it is originally in + # pinned_specs, but we override that in _add_specs to make it + # non-optional when there's a name match in the explicit package + # pool + conflicting_pinned_specs = groupby( + lambda s: MatchSpec(s, optional=True) in ssc.pinned_specs, + conflicting_specs, + ) + + if conflicting_pinned_specs.get(True): + in_specs_map = grouped_specs.get(True, ()) + pinned_conflicts = conflicting_pinned_specs.get(True, ()) + in_specs_map_or_specs_to_add = ( + set(in_specs_map) | set(self.specs_to_add) + ) - set(pinned_conflicts) + + raise SpecsConfigurationConflictError( + sorted(s.__str__() for s in in_specs_map_or_specs_to_add), + sorted(s.__str__() for s in {s for s in pinned_conflicts}), + self.prefix, + ) + for spec in conflicting_specs: + if spec.target and not spec.optional: + specs_modified = True + final_environment_specs.remove(spec) + if spec.get("version"): + neutered_spec = MatchSpec(spec.name, version=spec.version) + else: + neutered_spec = MatchSpec(spec.name) + final_environment_specs.add(neutered_spec) + ssc.specs_map[spec.name] = neutered_spec + if specs_modified: + conflicting_specs = set( + ssc.r.get_conflicting_specs( + tuple(final_environment_specs), self.specs_to_add + ) + ) + else: + # Let r.solve() use r.find_conflicts() to report conflict chains. + break + + # Finally! We get to call SAT. + if log.isEnabledFor(DEBUG): + log.debug( + "final specs to add: %s", + dashlist(sorted(str(s) for s in final_environment_specs)), + ) + + # this will raise for unsatisfiable stuff. We can + if not conflicting_specs or context.unsatisfiable_hints: + ssc.solution_precs = ssc.r.solve( + tuple(final_environment_specs), + specs_to_add=self.specs_to_add, + history_specs=ssc.specs_from_history_map, + should_retry_solve=ssc.should_retry_solve, + ) + else: + # shortcut to raise an unsat error without needing another solve step when + # unsatisfiable_hints is off + raise UnsatisfiableError({}) + + self.neutered_specs = tuple( + v + for k, v in ssc.specs_map.items() + if k in ssc.specs_from_history_map + and v.strictness < ssc.specs_from_history_map[k].strictness + ) + + # add back inconsistent packages to solution + if ssc.add_back_map: + for name, (prec, spec) in ssc.add_back_map.items(): + # spec here will only be set if the conflicting prec was in the original specs_map + # if it isn't there, then we restore the conflict. If it is there, though, + # we keep the new, consistent solution + if not spec: + # filter out solution precs and reinsert the conflict. Any resolution + # of the conflict should be explicit (i.e. it must be in ssc.specs_map) + ssc.solution_precs = [ + _ for _ in ssc.solution_precs if _.name != name + ] + ssc.solution_precs.append(prec) + final_environment_specs.add(spec) + + ssc.final_environment_specs = final_environment_specs + return ssc + + def _post_sat_handling(self, ssc): + # Special case handling for various DepsModifier flags. + final_environment_specs = ssc.final_environment_specs + if ssc.deps_modifier == DepsModifier.NO_DEPS: + # In the NO_DEPS case, we need to start with the original list of packages in the + # environment, and then only modify packages that match specs_to_add or + # specs_to_remove. + # + # Help information notes that use of NO_DEPS is expected to lead to broken + # environments. 
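+            # Start from the records already in the prefix, drop those matched by
+            # specs_to_remove, then splice in the solver's picks for specs_to_add,
+            # replacing any same-named records that were present before.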
+ _no_deps_solution = IndexedSet(ssc.prefix_data.iter_records()) + only_remove_these = { + prec + for spec in self.specs_to_remove + for prec in _no_deps_solution + if spec.match(prec) + } + _no_deps_solution -= only_remove_these + + only_add_these = { + prec + for spec in self.specs_to_add + for prec in ssc.solution_precs + if spec.match(prec) + } + remove_before_adding_back = {prec.name for prec in only_add_these} + _no_deps_solution = IndexedSet( + prec + for prec in _no_deps_solution + if prec.name not in remove_before_adding_back + ) + _no_deps_solution |= only_add_these + ssc.solution_precs = _no_deps_solution + + # TODO: check if solution is satisfiable, and emit warning if it's not + + elif ( + ssc.deps_modifier == DepsModifier.ONLY_DEPS + and ssc.update_modifier != UpdateModifier.UPDATE_DEPS + ): + # Using a special instance of PrefixGraph to remove youngest child nodes that match + # the original specs_to_add. It's important to remove only the *youngest* child nodes, + # because a typical use might be `conda install --only-deps python=2 flask`, and in + # that case we'd want to keep python. + # + # What are we supposed to do if flask was already in the environment? + # We can't be removing stuff here that's already in the environment. + # + # What should be recorded for the user-requested specs in this case? Probably all + # direct dependencies of flask. + graph = PrefixGraph(ssc.solution_precs, self.specs_to_add) + removed_nodes = graph.remove_youngest_descendant_nodes_with_specs() + self.specs_to_add = set(self.specs_to_add) + for prec in removed_nodes: + for dep in prec.depends: + dep = MatchSpec(dep) + if dep.name not in ssc.specs_map: + self.specs_to_add.add(dep) + # unfreeze + self.specs_to_add = frozenset(self.specs_to_add) + + # Add back packages that are already in the prefix. + specs_to_remove_names = {spec.name for spec in self.specs_to_remove} + add_back = tuple( + ssc.prefix_data.get(node.name, None) + for node in removed_nodes + if node.name not in specs_to_remove_names + ) + ssc.solution_precs = tuple( + PrefixGraph((*graph.graph, *filter(None, add_back))).graph + ) + + # TODO: check if solution is satisfiable, and emit warning if it's not + + elif ssc.update_modifier == UpdateModifier.UPDATE_DEPS: + # Here we have to SAT solve again :( It's only now that we know the dependency + # chain of specs_to_add. + # + # UPDATE_DEPS is effectively making each spec in the dependency chain a user-requested + # spec. We don't modify pinned_specs, track_features_specs, or specs_to_add. For + # all other specs, we drop all information but name, drop target, and add them to + # the specs_to_add that gets recorded in the history file. + # + # It's like UPDATE_ALL, but only for certain dependency chains. + graph = PrefixGraph(ssc.solution_precs) + update_names = set() + for spec in self.specs_to_add: + node = graph.get_node_by_name(spec.name) + update_names.update( + ancest_rec.name for ancest_rec in graph.all_ancestors(node) + ) + specs_map = {name: MatchSpec(name) for name in update_names} + + # Remove pinned_specs and any python spec (due to major-minor pinning business rule). + # Add in the original specs_to_add on top. 
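+            # Illustrative example: with python 3.10.4 installed, the bare
+            # MatchSpec("python") collected above is narrowed to "python 3.10.*"
+            # below, so UPDATE_DEPS never crosses a python minor version.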
+ for spec in ssc.pinned_specs: + specs_map.pop(spec.name, None) + if "python" in specs_map: + python_rec = ssc.prefix_data.get("python") + py_ver = ".".join(python_rec.version.split(".")[:2]) + ".*" + specs_map["python"] = MatchSpec(name="python", version=py_ver) + specs_map.update({spec.name: spec for spec in self.specs_to_add}) + new_specs_to_add = tuple(specs_map.values()) + + # It feels wrong/unsafe to modify this instance, but I guess let's go with it for now. + self.specs_to_add = new_specs_to_add + ssc.solution_precs = self.solve_final_state( + update_modifier=UpdateModifier.UPDATE_SPECS, + deps_modifier=ssc.deps_modifier, + prune=ssc.prune, + ignore_pinned=ssc.ignore_pinned, + force_remove=ssc.force_remove, + ) + ssc.prune = False + + if ssc.prune: + graph = PrefixGraph(ssc.solution_precs, final_environment_specs) + graph.prune() + ssc.solution_precs = tuple(graph.graph) + + return ssc + + def _notify_conda_outdated(self, link_precs): + if not context.notify_outdated_conda or context.quiet: + return + current_conda_prefix_rec = PrefixData(context.conda_prefix).get("conda", None) + if current_conda_prefix_rec: + channel_name = current_conda_prefix_rec.channel.canonical_name + if channel_name == UNKNOWN_CHANNEL: + channel_name = "defaults" + + # only look for a newer conda in the channel conda is currently installed from + conda_newer_spec = MatchSpec(f"{channel_name}::conda>{CONDA_VERSION}") + + if paths_equal(self.prefix, context.conda_prefix): + if any(conda_newer_spec.match(prec) for prec in link_precs): + return + + conda_newer_precs = sorted( + SubdirData.query_all( + conda_newer_spec, + self.channels, + self.subdirs, + repodata_fn=self._repodata_fn, + ), + key=lambda x: VersionOrder(x.version), + # VersionOrder is fine here rather than r.version_key because all precs + # should come from the same channel + ) + if conda_newer_precs: + latest_version = conda_newer_precs[-1].version + # If conda comes from defaults, ensure we're giving instructions to users + # that should resolve release timing issues between defaults and conda-forge. + print( + dedent( + f""" + + ==> WARNING: A newer version of conda exists. <== + current version: {CONDA_VERSION} + latest version: {latest_version} + + Please update conda by running + + $ conda update -n base -c {channel_name} conda + + Or to minimize the number of packages updated during conda update use + + conda install conda={latest_version} + + """ + ), + file=sys.stderr, + ) + + def _prepare(self, prepared_specs): + # All of this _prepare() method is hidden away down here. Someday we may want to further + # abstract away the use of `index` or the Resolve object. + + if self._prepared and prepared_specs == self._prepared_specs: + return self._index, self._r + + if hasattr(self, "_index") and self._index: + # added in install_actions for conda-build back-compat + self._prepared_specs = prepared_specs + _supplement_index_with_system(self._index) + self._r = Resolve(self._index, channels=self.channels) + else: + # add in required channels that aren't explicitly given in the channels list + # For correctness, we should probably add to additional_channels any channel that + # is given by PrefixData(self.prefix).all_subdir_urls(). However that causes + # usability problems with bad / expired tokens. 
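+            # Illustrative example: a spec such as "conda-forge::numpy" in
+            # specs_to_add contributes Channel("conda-forge") here even if the
+            # user did not pass that channel explicitly; any spec carrying an
+            # exact channel value behaves the same way.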
+ + additional_channels = set() + for spec in self.specs_to_add: + # TODO: correct handling for subdir isn't yet done + channel = spec.get_exact_value("channel") + if channel: + additional_channels.add(Channel(channel)) + + self.channels.update(additional_channels) + + reduced_index = get_reduced_index( + self.prefix, + self.channels, + self.subdirs, + prepared_specs, + self._repodata_fn, + ) + _supplement_index_with_system(reduced_index) + + self._prepared_specs = prepared_specs + self._index = reduced_index + self._r = Resolve(reduced_index, channels=self.channels) + + self._prepared = True + return self._index, self._r + + +class SolverStateContainer: + """ + A mutable container with defined attributes to help keep method signatures clean + and also keep track of important state variables. + """ + + def __init__( + self, + prefix, + update_modifier, + deps_modifier, + prune, + ignore_pinned, + force_remove, + should_retry_solve, + ): + # prefix, channels, subdirs, specs_to_add, specs_to_remove + # self.prefix = prefix + # self.channels = channels + # self.subdirs = subdirs + # self.specs_to_add = specs_to_add + # self.specs_to_remove = specs_to_remove + + # Group 1. Behavior flags + self.update_modifier = update_modifier + self.deps_modifier = deps_modifier + self.prune = prune + self.ignore_pinned = ignore_pinned + self.force_remove = force_remove + self.should_retry_solve = should_retry_solve + + # Group 2. System state + self.prefix = prefix + # self.prefix_data = None + # self.specs_from_history_map = None + # self.track_features_specs = None + # self.pinned_specs = None + + # Group 3. Repository metadata + self.index = None + self.r = None + + # Group 4. Mutable working containers + self.specs_map = {} + self.solution_precs = None + self._init_solution_precs() + self.add_back_map = {} # name: (prec, spec) + self.final_environment_specs = None + + @memoizedproperty + def prefix_data(self): + return PrefixData(self.prefix) + + @memoizedproperty + def specs_from_history_map(self): + return History(self.prefix).get_requested_specs_map() + + @memoizedproperty + def track_features_specs(self): + return tuple(MatchSpec(x + "@") for x in context.track_features) + + @memoizedproperty + def pinned_specs(self): + return () if self.ignore_pinned else get_pinned_specs(self.prefix) + + def set_repository_metadata(self, index, r): + self.index, self.r = index, r + + def _init_solution_precs(self): + if self.prune: + # DO NOT add existing prefix data to solution on prune + self.solution_precs = tuple() + else: + self.solution_precs = tuple(self.prefix_data.iter_records()) + + def working_state_reset(self): + self.specs_map = {} + self._init_solution_precs() + self.add_back_map = {} # name: (prec, spec) + self.final_environment_specs = None diff --git a/conda_classic_solver/solver.py b/conda_classic_solver/solver.py deleted file mode 100644 index f6a3871..0000000 --- a/conda_classic_solver/solver.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (C) 2022 Anaconda, Inc -# Copyright (C) 2023 conda -# SPDX-License-Identifier: BSD-3-Clause -""" -The classic solver implementation. 
-""" - -from conda.core.solve import Solver - - -class ClassicSolver(Solver): - pass diff --git a/tests/test_logic.py b/tests/test_logic.py new file mode 100644 index 0000000..85d8f08 --- /dev/null +++ b/tests/test_logic.py @@ -0,0 +1,421 @@ +# Copyright (C) 2012 Anaconda, Inc +# SPDX-License-Identifier: BSD-3-Clause +from itertools import chain, combinations, permutations, product + +import pytest +from conda.testing.helpers import raises + +from conda_classic_solver.logic import FALSE, TRUE, Clauses, minimal_unsatisfiable_subset + +# These routines implement logical tests with short-circuiting +# and propagation of unknown values: +# - positive integers are variables +# - negative integers are negations of positive variables +# - lowercase True and False are fixed values +# - None reprents an indeterminate value +# If a fixed result is not determinable, the result is None, which +# propagates through the result. +# +# To ensure correctness, the only logic functions we have implemented +# directly are NOT and OR. The rest are implemented in terms of these. +# Performance is not an issue. + + +def my_NOT(x): + if isinstance(x, int): + return -x + if isinstance(x, str): + return x[1:] if x[0] == "!" else "!" + x + return None + + +def my_ABS(x): + if isinstance(x, int): + return abs(x) + if isinstance(x, str): + return x[1:] if x[0] == "!" else x + return None + + +def my_OR(*args): + """Implements a logical OR according to the logic: + - positive integers are variables + - negative integers are negations of positive variables + - TRUE and FALSE are fixed values + - None is an unknown value + TRUE OR x -> TRUE + FALSE OR x -> FALSE + None OR x -> None + x OR y -> None""" + if any(v == TRUE for v in args): + return TRUE + args = {v for v in args if v != FALSE} + if len(args) == 0: + return FALSE + if len(args) == 1: + return next(v for v in args) + if len({v if v is None else my_ABS(v) for v in args}) < len(args): + return TRUE + return None + + +def my_AND(*args): + args = list(map(my_NOT, args)) + return my_NOT(my_OR(*args)) + + +def my_XOR(i, j): + return my_OR(my_AND(i, my_NOT(j)), my_AND(my_NOT(i), j)) + + +def my_ITE(c, t, f): + return my_OR(my_AND(c, t), my_AND(my_NOT(c), f)) + + +def my_AMONE(*args): + args = [my_NOT(v) for v in args] + return my_AND(*[my_OR(v1, v2) for v1, v2 in combinations(args, 2)]) + + +def my_XONE(*args): + return my_AND(my_OR(*args), my_AMONE(*args)) + + +def my_SOL(ij, sol): + return (TRUE if v in sol or v == TRUE else FALSE for v in ij) + + +def _evaluate_eq(eq, sol): + if not isinstance(eq, dict): + eq = {c: v for v, c in eq if c not in {TRUE, FALSE}} + return sum(eq.get(s, 0) for s in sol if s not in {TRUE, FALSE}) + + +def my_EVAL(eq, sol): + # _evaluate_eq doesn't handle TRUE/FALSE entries + return _evaluate_eq(eq, sol) + sum(c for c, a in eq if a == TRUE) + + +# Testing strategy: mechanically construct a all possible permutations of +# True, False, variables from 1 to m, and their negations, in order to exercise +# all logical branches of the function. Test negative, positive, and full +# polarities for each. 
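+# For example, with m == 2 each operator is exercised against every pair drawn
+# from (TRUE, FALSE, 1, -1, 2, -2), i.e. 36 argument combinations.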
+ + +def my_TEST(Mfunc, Cfunc, mmin, mmax, is_iter): + for m in range(mmin, mmax + 1): + if m == 0: + ijprod = [()] + else: + ijprod = (TRUE, FALSE) + sum(((k, my_NOT(k)) for k in range(1, m + 1)), ()) + ijprod = product(ijprod, repeat=m) + for ij in ijprod: + C = Clauses() + Cpos = Clauses() + Cneg = Clauses() + for k in range(1, m + 1): + nm = "x%d" % k + C.new_var(nm) + Cpos.new_var(nm) + Cneg.new_var(nm) + ij2 = tuple( + C.from_index(k) if isinstance(k, int) and k not in {TRUE, FALSE} else k + for k in ij + ) + if is_iter: + x = Cfunc.__get__(C, Clauses)(ij2) + Cpos.Require(Cfunc.__get__(Cpos, Clauses), ij) + Cneg.Prevent(Cfunc.__get__(Cneg, Clauses), ij) + else: + x = Cfunc.__get__(C, Clauses)(*ij2) + Cpos.Require(Cfunc.__get__(Cpos, Clauses), *ij) + Cneg.Prevent(Cfunc.__get__(Cneg, Clauses), *ij) + tsol = Mfunc(*ij) + if tsol in {TRUE, FALSE}: + assert x == tsol, (ij2, Cfunc.__name__, C.as_list()) + assert Cpos.unsat == (tsol != TRUE) and not Cpos.as_list(), ( + ij, + "Require(%s)", + ) + assert Cneg.unsat == (tsol == TRUE) and not Cneg.as_list(), ( + ij, + "Prevent(%s)", + ) + continue + for sol in C.itersolve([(x,)]): + qsol = Mfunc(*my_SOL(ij, sol)) + assert qsol == TRUE, (ij2, sol, Cfunc.__name__, C.as_list()) + for sol in Cpos.itersolve([]): + qsol = Mfunc(*my_SOL(ij, sol)) + assert qsol == TRUE, ( + ij, + sol, + f"Require({Cfunc.__name__})", + Cpos.as_list(), + ) + for sol in C.itersolve([(C.Not(x),)]): + qsol = Mfunc(*my_SOL(ij, sol)) + assert qsol == FALSE, (ij2, sol, Cfunc.__name__, C.as_list()) + for sol in Cneg.itersolve([]): + qsol = Mfunc(*my_SOL(ij, sol)) + assert qsol == FALSE, ( + ij, + sol, + f"Prevent({Cfunc.__name__})", + Cneg.as_list(), + ) + + +def test_NOT(): + my_TEST(my_NOT, Clauses.Not, 1, 1, False) + + +def test_AND(): + my_TEST(my_AND, Clauses.And, 2, 2, False) + + +@pytest.mark.integration # only because this test is slow +def test_ALL(): + my_TEST(my_AND, Clauses.All, 0, 4, True) + + +def test_OR(): + my_TEST(my_OR, Clauses.Or, 2, 2, False) + + +@pytest.mark.integration # only because this test is slow +def test_ANY(): + my_TEST(my_OR, Clauses.Any, 0, 4, True) + + +def test_XOR(): + my_TEST(my_XOR, Clauses.Xor, 2, 2, False) + + +def test_ITE(): + my_TEST(my_ITE, Clauses.ITE, 3, 3, False) + + +def test_AMONE(): + my_TEST(my_AMONE, Clauses.AtMostOne_NSQ, 0, 3, True) + my_TEST(my_AMONE, Clauses.AtMostOne_BDD, 0, 3, True) + my_TEST(my_AMONE, Clauses.AtMostOne, 0, 3, True) + C1 = Clauses(10) + x1 = C1.AtMostOne_BDD(tuple(range(1, 11))) + C2 = Clauses(10) + x2 = C2.AtMostOne(tuple(range(1, 11))) + assert x1 == x2 and C1.as_list() == C2.as_list() + + +@pytest.mark.integration # only because this test is slow +def test_XONE(): + my_TEST(my_XONE, Clauses.ExactlyOne_NSQ, 0, 3, True) + my_TEST(my_XONE, Clauses.ExactlyOne_BDD, 0, 3, True) + my_TEST(my_XONE, Clauses.ExactlyOne, 0, 3, True) + + +@pytest.mark.integration # only because this test is slow +def test_LinearBound(): + L = [ + ([], [0, 1], 10), + ([], [1, 2], 10), + ({"x1": 2, "x2": 2}, [3, 3], 10), + ({"x1": 2, "x2": 2}, [0, 1], 1000), + ({"x1": 1, "x2": 2}, [0, 2], 1000), + ({"x1": 2, "!x2": 2}, [0, 2], 1000), + ([(1, 1), (2, 2), (3, 3)], [3, 3], 1000), + ([(0, 1), (1, 2), (2, 3), (0, 4), (1, 5), (0, 6), (1, 7)], [0, 2], 1000), + ( + [ + (0, 1), + (1, 2), + (2, 3), + (0, 4), + (1, 5), + (0, 6), + (1, 7), + (3, FALSE), + (2, TRUE), + ], + [2, 4], + 1000, + ), + ( + [ + (1, 15), + (2, 16), + (3, 17), + (4, 18), + (5, 6), + (5, 19), + (6, 7), + (6, 20), + (7, 8), + (7, 21), + (7, 28), + (8, 9), + (8, 22), + 
(8, 29), + (8, 41), + (9, 10), + (9, 23), + (9, 30), + (9, 42), + (10, 1), + (10, 11), + (10, 24), + (10, 31), + (10, 34), + (10, 37), + (10, 43), + (10, 46), + (10, 50), + (11, 2), + (11, 12), + (11, 25), + (11, 32), + (11, 35), + (11, 38), + (11, 44), + (11, 47), + (11, 51), + (12, 3), + (12, 4), + (12, 5), + (12, 13), + (12, 14), + (12, 26), + (12, 27), + (12, 33), + (12, 36), + (12, 39), + (12, 40), + (12, 45), + (12, 48), + (12, 49), + (12, 52), + (12, 53), + (12, 54), + ], + [192, 204], + 100, + ), + ] + for eq, rhs, max_iter in L: + if isinstance(eq, dict): + N = len(eq) + else: + N = max([0] + [a for c, a in eq if a != TRUE and a != FALSE]) + C = Clauses(N) + Cpos = Clauses(N) + Cneg = Clauses(N) + if isinstance(eq, dict): + for k in range(1, N + 1): + nm = "x%d" % k + C.name_var(k, nm) + Cpos.name_var(k, nm) + Cneg.name_var(k, nm) + eq2 = [(v, C.from_name(c)) for c, v in eq.items()] + else: + eq2 = eq + x = C.LinearBound(eq, rhs[0], rhs[1]) + Cpos.Require(Cpos.LinearBound, eq, rhs[0], rhs[1]) + Cneg.Prevent(Cneg.LinearBound, eq, rhs[0], rhs[1]) + if x != FALSE: + for _, sol in zip( + range(max_iter), C.itersolve([] if x == TRUE else [(x,)], N) + ): + assert rhs[0] <= my_EVAL(eq2, sol) <= rhs[1], C.as_list() + if x != TRUE: + for _, sol in zip( + range(max_iter), C.itersolve([] if x == TRUE else [(C.Not(x),)], N) + ): + assert not (rhs[0] <= my_EVAL(eq2, sol) <= rhs[1]), C.as_list() + for _, sol in zip(range(max_iter), Cpos.itersolve([], N)): + assert rhs[0] <= my_EVAL(eq2, sol) <= rhs[1], ("Cpos", Cpos.as_list()) + for _, sol in zip(range(max_iter), Cneg.itersolve([], N)): + assert not (rhs[0] <= my_EVAL(eq2, sol) <= rhs[1]), ("Cneg", Cneg.as_list()) + + +def test_sat(): + C = Clauses() + C.new_var("x1") + C.new_var("x2") + assert C.sat() is not None + assert C.sat([]) is not None + assert C.sat([()]) is None + assert C.sat([(FALSE,)]) is None + assert C.sat([(TRUE,), ()]) is None + assert C.sat([(TRUE, FALSE, -1)]) is not None + assert C.sat([(+1, FALSE), (+2,), (TRUE,)], names=True) == {"x1", "x2"} + assert C.sat([(-1, FALSE), (TRUE,), (+2,)], names=True) == {"x2"} + assert C.sat([(TRUE,), (-1,), (-2, FALSE)], names=True) == set() + assert C.sat([(+1,), (-1, FALSE)], names=True) is None + C._clauses.unsat = True + assert C.sat() is None + assert C.sat([]) is None + assert C.sat([(TRUE,)]) is None + assert len(Clauses(10).sat([[1]])) == 10 + + +def test_minimize(): + # minimize x1 + 2 x2 + 3 x3 + 4 x4 + 5 x5 + # subject to x1 + x2 + x3 + x4 + x5 == 1 + C = Clauses(15) + C.Require(C.ExactlyOne, range(1, 6)) + sol = C.sat() + C._clauses.unsat = True + # Unsatisfiable constraints + assert C.minimize([(k, k) for k in range(1, 6)], sol)[1] == 16 + C._clauses.unsat = False + sol, sval = C.minimize([(k, k) for k in range(1, 6)], sol) + assert sval == 1 + C.Require(C.ExactlyOne, range(6, 11)) + # Supply an initial vector that is too short, forcing recalculation + sol, sval = C.minimize([(k, k) for k in range(6, 11)], sol) + assert sval == 6 + C.Require(C.ExactlyOne, range(11, 16)) + # Don't supply an initial vector + sol, sval = C.minimize([(k, k) for k in range(11, 16)]) + assert sval == 11 + + +@pytest.mark.xfail( + reason="Broke this with reworking minimal_unsatisfiable_set. Not sure how to fix. minimal_unsatisfiable_subset function is otherwise working well." 
+) +def test_minimal_unsatisfiable_subset(): + def sat(val): + return Clauses(max(abs(v) for v in chain(*val))).sat(val) + + assert raises(ValueError, lambda: minimal_unsatisfiable_subset([[1]], sat)) + + clauses = [ + [-10], + [1], + [5], + [2, 3], + [3, 4], + [5, 2], + [-7], + [2], + [3], + [-2, -3, 5], + [7, 8, 9, 10], + [-8], + [-9], + ] + res = minimal_unsatisfiable_subset(clauses, sat) + assert sorted(res) == [[-10], [-9], [-8], [-7], [7, 8, 9, 10]] + assert not sat(res) + + clauses = [[1, 3], [2, 3], [-1], [4], [3], [-3]] + for perm in permutations(clauses): + res = minimal_unsatisfiable_subset(clauses, sat) + assert sorted(res) == [[-3], [3]] + assert not sat(res) + + clauses = [[1], [-1], [2], [-2], [3, 4], [4]] + for perm in permutations(clauses): + res = minimal_unsatisfiable_subset(perm, sat) + assert sorted(res) in ([[-1], [1]], [[-2], [2]]) + assert not sat(res) diff --git a/tests/test_solvers.py b/tests/test_solvers.py new file mode 100644 index 0000000..bbe6e58 --- /dev/null +++ b/tests/test_solvers.py @@ -0,0 +1,46 @@ +# Copyright (C) 2012 Anaconda, Inc +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import TYPE_CHECKING + +from conda.testing.solver_helpers import SolverTests + +from conda_classic_solver.solve import ClassicSolver + +if TYPE_CHECKING: + from conda.core.solve import Solver + + +class TestClassicSolver(SolverTests): + @property + def solver_class(self) -> type[Solver]: + return ClassicSolver + + +class TestLibMambaSolver(SolverTests): + @property + def solver_class(self) -> type[Solver]: + from conda_libmamba_solver.solver import LibMambaSolver + + return LibMambaSolver + + @property + def tests_to_skip(self): + return { + "conda-libmamba-solver does not support features": [ + "test_iopro_mkl", + "test_iopro_nomkl", + "test_mkl", + "test_accelerate", + "test_scipy_mkl", + "test_pseudo_boolean", + "test_no_features", + "test_surplus_features_1", + "test_surplus_features_2", + "test_remove", + # this one below only fails reliably on windows; + # it passes Linux on CI, but not locally? + "test_unintentional_feature_downgrade", + ], + }