Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

feat: Improve ScoreAnalysis debug information #105

Merged
merged 16 commits into from
Jul 5, 2024
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ requires = [
"setuptools>=69.1.1",
"stubgenj>=0.2.5",
"JPype1>=1.5.0",
"wheel"
"wheel",
"multipledispatch>=1.0.0"
]
build-backend = "setuptools.build_meta"
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def find_stub_files(stub_root: str):
python_requires='>=3.10',
install_requires=[
'JPype1>=1.5.0',
'multipledispatch>=1.0.0'
],
cmdclass={'build_py': FetchDependencies},
package_data={
Expand Down
141 changes: 139 additions & 2 deletions tests/test_solution_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@
from timefold.solver.config import *
from timefold.solver.score import *

import inspect
import re

from ai.timefold.solver.core.api.score import ScoreExplanation as JavaScoreExplanation
from ai.timefold.solver.core.api.score.analysis import (
ConstraintAnalysis as JavaConstraintAnalysis,
MatchAnalysis as JavaMatchAnalysis,
ScoreAnalysis as JavaScoreAnalysis)
from ai.timefold.solver.core.api.score.constraint import Indictment as JavaIndictment
from ai.timefold.solver.core.api.score.constraint import (ConstraintRef as JavaConstraintRef,
ConstraintMatch as JavaConstraintMatch,
ConstraintMatchTotal as JavaConstraintMatchTotal)

from dataclasses import dataclass, field
from typing import Annotated, List

Expand All @@ -18,8 +31,8 @@ class Entity:
def my_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.reward(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('package', 'Maximize Value'),
.reward(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('package', 'Maximize Value'),
]


Expand Down Expand Up @@ -127,6 +140,21 @@ def assert_score_analysis(problem: Solution, score_analysis: ScoreAnalysis):
assert_constraint_analysis(problem, constraint_analysis)


def assert_score_analysis_summary(score_analysis: ScoreAnalysis):
summary = score_analysis.summarize
assert "Explanation of score (3):" in summary
assert "Constraint matches:" in summary
assert "3: constraint (Maximize Value) has 3 matches:" in summary
assert "1: justified with" in summary

match = score_analysis.constraint_analyses[0]
match_summary = match.summarize
assert "Explanation of score (3):" in match_summary
assert "Constraint matches:" in match_summary
assert "3: constraint (Maximize Value) has 3 matches:" in match_summary
assert "1: justified with" in match_summary


def assert_solution_manager(solution_manager: SolutionManager[Solution]):
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
assert problem.score is None
Expand All @@ -140,6 +168,9 @@ def assert_solution_manager(solution_manager: SolutionManager[Solution]):
score_analysis = solution_manager.analyze(problem)
assert_score_analysis(problem, score_analysis)

score_analysis = solution_manager.analyze(problem)
assert_score_analysis_summary(score_analysis)


def test_solver_manager_score_manager():
with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
Expand All @@ -148,3 +179,109 @@ def test_solver_manager_score_manager():

def test_solver_factory_score_manager():
assert_solution_manager(SolutionManager.create(SolverFactory.create(solver_config)))


def test_score_manager_solution_initialization():
solution_manager = SolutionManager.create(SolverFactory.create(solver_config))
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
score_analysis = solution_manager.analyze(problem)
assert score_analysis.is_solution_initialized

second_problem: Solution = Solution([Entity('A', None), Entity('B', None), Entity('C', None)], [1, 2, 3])
second_score_analysis = solution_manager.analyze(second_problem)
assert not second_score_analysis.is_solution_initialized


def test_score_manager_diff():
solution_manager = SolutionManager.create(SolverFactory.create(solver_config))
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
score_analysis = solution_manager.analyze(problem)
second_problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1), Entity('D', 1)], [1, 2, 3])
second_score_analysis = solution_manager.analyze(second_problem)
diff = score_analysis.diff(second_score_analysis)
assert diff.score.score == -1

constraint_analyses = score_analysis.constraint_analyses
assert len(constraint_analyses) == 1


def test_score_manager_constraint_analysis_map():
solution_manager = SolutionManager.create(SolverFactory.create(solver_config))
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
score_analysis = solution_manager.analyze(problem)
constraints = score_analysis.constraint_analyses
assert len(constraints) == 1

constraint_analysis = score_analysis.constraint_analysis('package', 'Maximize Value')
assert constraint_analysis.constraint_name == 'Maximize Value'

constraint_analysis = score_analysis.constraint_analysis(ConstraintRef('package', 'Maximize Value'))
assert constraint_analysis.constraint_name == 'Maximize Value'
assert constraint_analysis.match_count == 3


def test_score_manager_constraint_ref():
constraint_ref = ConstraintRef.parse_id('package/Maximize Value')

assert constraint_ref.package_name == 'package'
assert constraint_ref.constraint_name == 'Maximize Value'


ignored_java_functions = {
'equals',
'getClass',
'hashCode',
'notify',
'notifyAll',
'toString',
'wait',
'compareTo',
}

ignored_java_functions_per_class = {
'Indictment': {'getJustification'}, # deprecated
'ConstraintRef': {'of', 'packageName', 'constraintName'}, # built-in constructor and properties with @dataclass
'ConstraintMatch': {
'getConstraintRef', # built-in constructor and properties with @dataclass
'getConstraintPackage', # deprecated
'getConstraintName', # deprecated
'getConstraintId', # deprecated
'getJustificationList', # deprecated
'getJustification', # built-in constructor and properties with @dataclass
'getScore', # built-in constructor and properties with @dataclass
}
}


def test_has_all_methods():
missing = []
for python_type, java_type in ((ScoreExplanation, JavaScoreExplanation),
(ScoreAnalysis, JavaScoreAnalysis),
(ConstraintAnalysis, JavaConstraintAnalysis),
(ScoreExplanation, JavaScoreExplanation),
(ConstraintMatch, JavaConstraintMatch),
(ConstraintRef, JavaConstraintRef),
(Indictment, JavaIndictment)):
type_name = python_type.__name__
ignored_java_functions_type = ignored_java_functions_per_class[
type_name] if type_name in ignored_java_functions_per_class else {}

for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction):
if function_name in ignored_java_functions or function_name in ignored_java_functions_type:
continue

snake_case_name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', function_name)
snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower()
snake_case_name_without_prefix = re.sub('(.)([A-Z][a-z]+)', r'\1_\2',
function_name[3:] if function_name.startswith(
"get") else function_name)
snake_case_name_without_prefix = re.sub('([a-z0-9])([A-Z])', r'\1_\2',
snake_case_name_without_prefix).lower()
if not hasattr(python_type, snake_case_name) and not hasattr(python_type, snake_case_name_without_prefix):
missing.append((java_type, python_type, snake_case_name))

if missing:
assertion_msg = ''
for java_type, python_type, snake_case_name in missing:
assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n'
raise AssertionError(assertion_msg)
Loading
Loading