Skip to content

Commit

Permalink
fix mixins of custom tasks (I hope)
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm committed Nov 25, 2024
1 parent 09a4ab4 commit 5defd46
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 31 deletions.
17 changes: 4 additions & 13 deletions hbw/tasks/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
CalibratorsMixin,
ProducersMixin,
MLModelTrainingMixin,
MLModelMixin,
MLModelsMixin,
MLModelDataMixin,
SelectorStepsMixin,
Expand Down Expand Up @@ -403,13 +402,9 @@ def workflow_run(self):


class MLEvaluationSingleFold(
# NOTE: this should probably be a MLModelTrainingMixin, but I'll postpone this until the MultiConfigTask
# is implemented
# NOTE: mixins might need fixing, needs to be checked
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
RemoteWorkflow,
):
Expand Down Expand Up @@ -541,13 +536,9 @@ def run(self):


class PlotMLResultsSingleFold(
# NOTE: this should probably be a MLModelTrainingMixin, but I'll postpone this until the MultiConfigTask
# is implemented
# NOTE: mixins might need fixing, needs to be checked
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
RemoteWorkflow,
):
Expand Down
29 changes: 13 additions & 16 deletions hbw/tasks/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

from columnflow.tasks.framework.base import Requirements
from columnflow.tasks.framework.mixins import (
SelectorMixin,
CalibratorsMixin,
ProducersMixin,
MLModelMixin,
# SelectorMixin,
# CalibratorsMixin,
# ProducersMixin,
# MLModelMixin,
MLModelTrainingMixin,
)
# from columnflow.tasks.framework.remote import RemoteWorkflow
from columnflow.util import DotDict
Expand Down Expand Up @@ -97,12 +98,11 @@ def run(self):


class Optimizer(
# NOTE: mixins might need fixing, needs to be tested
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
# RemoteWorkflow,
):
"""
Workflow that runs optimization. Needs to be run from within the sandbox
Expand Down Expand Up @@ -191,12 +191,11 @@ def run(self):


class Objective(
# NOTE: mixins might need fixing, needs to be tested
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
# RemoteWorkflow,
):
"""
Objective to optimize.
Expand Down Expand Up @@ -281,11 +280,9 @@ def run(self):


class DummyObjective(
# NOTE: mixins might need fixing, needs to be tested
HBWTask,
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
MLModelTrainingMixin,
law.LocalWorkflow,
# RemoteWorkflow,
):
Expand Down
6 changes: 4 additions & 2 deletions hbw/tasks/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import law
import order as od

from columnflow.tasks.framework.base import Requirements, ShiftTask
from columnflow.tasks.framework.base import Requirements, MultiConfigTask
from columnflow.tasks.framework.mixins import (
CalibratorsMixin, SelectorStepsMixin, ProducersMixin, MLModelsMixin,
CategoriesMixin,
)
from columnflow.tasks.framework.plotting import (
PlotBase, PlotBase1D, ProcessPlotSettingMixin, VariablePlotSettingMixin,
PlotShiftMixin,
)
from columnflow.tasks.framework.decorators import view_output_plots
from columnflow.tasks.framework.remote import RemoteWorkflow
Expand Down Expand Up @@ -108,7 +109,7 @@ def plot_multi_weight_producer(

class PlotVariablesMultiWeightProducer(
HBWTask,
ShiftTask,
PlotShiftMixin,
VariablePlotSettingMixin,
ProcessPlotSettingMixin,
PlotBase1D,
Expand All @@ -117,6 +118,7 @@ class PlotVariablesMultiWeightProducer(
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
MultiConfigTask,
law.LocalWorkflow,
RemoteWorkflow,
):
Expand Down
3 changes: 3 additions & 0 deletions hbw/tasks/postfit_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import luigi
import order as od

from columnflow.tasks.framework.base import ConfigTask
from columnflow.tasks.framework.mixins import (
InferenceModelMixin, MLModelsMixin, ProducersMixin, SelectorStepsMixin,
CalibratorsMixin,
Expand Down Expand Up @@ -107,6 +108,7 @@ def plot_postfit_shapes(


class PlotPostfitShapes(
# NOTE: mixins might be wrong and could (should?) be extended to MultiConfigTask
HBWTask,
PlotBase1D,
# to correctly setup our InferenceModel, we need all these mixins, but hopefully, all these
Expand All @@ -116,6 +118,7 @@ class PlotPostfitShapes(
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
ConfigTask,
):
"""
Task that creates Postfit shape plots based on a fit_diagnostics file.
Expand Down

0 comments on commit 5defd46

Please sign in to comment.