Skip to content

Commit

Permalink
Fix stage decorator cachers incidentally becoming singleton objects
Browse files Browse the repository at this point in the history
Fixes #109

We fix this by deep copying everything into a `local_cachers`
list inside the wrapper, rather than taking (and manipulating) the
directly provided `cachers` list.

A related issue fixed in this is that artifact representations weren't
actually keeping a copy of their cacher stored on them correctly.

These issues culminated specifically when re-running from cached values
in DAG mode, where all artifacts are injected Lazy instances and thus
rely on cachers (but all cachers for a given artifact of a given
stage were actually pointing to the same cacher object across all records).
  • Loading branch information
WarmCyan committed Oct 23, 2023
1 parent a82d36a commit b8a961f
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 51 deletions.
135 changes: 97 additions & 38 deletions curifactory/staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,28 @@ def wrapper(record: Record, *args, **kwargs):
if outputs is None:
outputs = []

# NOTE: this local_cachers is necessary because otherwise you run
# into the potential for accidental singleton cachers that all
# records running through one stage then incidentally share. I
# believe this happens because there's only technically one "stage"
# instance, so multiple records running the same stage are all using
# the cacher that's defined either directly in the stage header, or
# that is directly replaced in the stage header when given a type
# (that is then initialized below) To mitigate, we create a _copy_
# of any provided cachers, and this works both for cacher types as
# well as already initialized cachers. However, this does mean that
# if, say in a notebook, someone passes in a cacher initialized
# elsewhere and expects it to retain any state that might have been
# manipulated in the cacher's save etc., this will no longer work.
# Instead, they would have to manually get the copy of the cacher
# off of the artifact in manager in order to get the actual cacher
# instance that was used.
local_cachers = None
if cachers is not None:
local_cachers = []
for cacher in cachers:
local_cachers.append(copy.deepcopy(cacher))

record.manager.current_stage_name = name
record.manager.stage_active = True
record.stages.append(name)
Expand All @@ -189,7 +211,7 @@ def wrapper(record: Record, *args, **kwargs):
for output in outputs:
if (
type(output) == Lazy
and cachers is None
and local_cachers is None
and not record.manager.lazy
and not record.manager.ignore_lazy
):
Expand All @@ -208,23 +230,23 @@ def wrapper(record: Record, *args, **kwargs):
# NOTE: since Lazy caching doesn't work without a cacher, we need to ensure
# one if none exists. Pickle is pretty broad, but obviously there are some things
# that don't work, so we need to warn about this
if cachers is None:
if local_cachers is None:
no_cachers = True
logging.warning(
"Stage %s does not have cachers specified, a --lazy run will force caching by applying PickleCachers to anything with none specified, but this can potentially cause errors."
% name
)
cachers = []
local_cachers = []
if no_cachers:
cachers.append(PickleCacher)
local_cachers.append(PickleCacher)
elif record.manager.ignore_lazy:
for index, output in enumerate(outputs):
if type(output) == Lazy:
logging.debug("Disabling lazy cache for '%s'" % output)
outputs[index] = output.name

# check for mismatched amounts of cachers
if cachers is not None and len(cachers) != len(outputs):
if local_cachers is not None and len(local_cachers) != len(outputs):
raise CachersMismatchError(
f"Stage '{name}' - the number of cachers does not match the number of outputs to cache."
)
Expand Down Expand Up @@ -294,24 +316,27 @@ def wrapper(record: Record, *args, **kwargs):
# no, you cannot abstract this into _check_cached_outputs - if you try to reassign to cachers from
# another function, because of the deep voodoo black magic sorcery that is decorators with arguments,
# it considers it different code.
if cachers is not None:
if local_cachers is not None:
# instantiate cachers if not already
for i in range(len(cachers)):
cacher = cachers[i]
for i in range(len(local_cachers)):
cacher = local_cachers[i]
if type(cacher) == type:
cachers[i] = cacher()
local_cachers[i] = cacher()
# set the active record on the cacher as well as provide a default name
# (the name of the output)
cachers[i].set_record(record)
cachers[
local_cachers[i].set_record(record)
local_cachers[
i
].stage = name # set current stage name, so get_path is correct in later stages (particularly for lazy)
if cachers[i].name is None and cachers[i].path_override is None:
if (
local_cachers[i].name is None
and local_cachers[i].path_override is None
):
if type(outputs[i]) == Lazy:
cachers[i].name = outputs[i].name
local_cachers[i].name = outputs[i].name
else:
cachers[i].name = outputs[i]
record.stage_cachers = cachers
local_cachers[i].name = outputs[i]
record.stage_cachers = local_cachers

# at this point we've grabbed all information we would need if we're
# just mapping out the stages, so return at this point.
Expand All @@ -322,7 +347,9 @@ def wrapper(record: Record, *args, **kwargs):
# we need to create pseudo-state-artifact-representations for the outputs. Since
# we obviously can't add the actual artifacts (there are none without running!),
# we just add the string key and a name, so record __repr__ has something to use
_get_output_representations_for_map(record, outputs, cachers, None)
_get_output_representations_for_map(
record, outputs, local_cachers, None
)
record.stage_cachers = None
return record

Expand All @@ -339,7 +366,7 @@ def wrapper(record: Record, *args, **kwargs):
stage_rep = (record.get_record_index(), name)
if stage_rep not in record.manager.map.execution_list:
logging.debug('DAG-indicated stage skip "%s".' % str(stage_rep))
_dag_skip_check_cached_outputs(name, record, outputs, cachers)
_dag_skip_check_cached_outputs(name, record, outputs, local_cachers)

# grab any possible previous reportables so they still end up in report.
_check_cached_reportables(name, record)
Expand All @@ -359,7 +386,9 @@ def wrapper(record: Record, *args, **kwargs):
# So the flow is: if we have a DAG, and it says to execute this stage, or if we don't have
# a DAG, check if we actually need to run this based on cached values.
if record.manager.map is None or execute_stage:
cache_valid = _check_cached_outputs(name, record, outputs, cachers)
cache_valid = _check_cached_outputs(
name, record, outputs, local_cachers
)
if cache_valid:
# get previous reportables if available
_check_cached_reportables(name, record)
Expand Down Expand Up @@ -437,7 +466,7 @@ def wrapper(record: Record, *args, **kwargs):
# handle storing outputs in record
post_cache_time_start = time.perf_counter()
record.manager.lock()
_store_outputs(name, record, outputs, cachers, function_outputs)
_store_outputs(name, record, outputs, local_cachers, function_outputs)
_store_reportables(name, record)
record.store_tracked_paths()
record.manager.unlock()
Expand Down Expand Up @@ -600,6 +629,28 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
if outputs is None:
outputs = []

# NOTE: this local_cachers is necessary because otherwise you run
# into the potential for accidental singleton cachers that all
# records running through one stage then incidentally share. I
# believe this happens because there's only technically one "stage"
# instance, so multiple records running the same stage are all using
# the cacher that's defined either directly in the stage header, or
# that is directly replaced in the stage header when given a type
# (that is then initialized below) To mitigate, we create a _copy_
# of any provided cachers, and this works both for cacher types as
# well as already initialized cachers. However, this does mean that
# if, say in a notebook, someone passes in a cacher initialized
# elsewhere and expects it to retain any state that might have been
# manipulated in the cacher's save etc., this will no longer work.
# Instead, they would have to manually get the copy of the cacher
# off of the artifact in manager in order to get the actual cacher
# instance that was used.
local_cachers = None
if cachers is not None:
local_cachers = []
for cacher in cachers:
local_cachers.append(copy.deepcopy(cacher))

record.manager.current_stage_name = name
record.set_aggregate(records)
record.stages.append(name)
Expand All @@ -619,7 +670,7 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
for output in outputs:
if (
type(output) == Lazy
and cachers is None
and local_cachers is None
and not record.manager.lazy
and not record.manager.ignore_lazy
):
Expand All @@ -638,23 +689,23 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
# NOTE: since Lazy caching doesn't work without a cacher, we need to ensure
# one if none exists. Pickle is pretty broad, but obviously there are some things
# that don't work, so we need to warn about this
if cachers is None:
if local_cachers is None:
no_cachers = True
logging.warning(
"Aggregate stage %s does not have cachers specified, a --lazy run will force caching by applying PickleCachers to anything with none specified, but this can potentially cause errors."
% name
)
cachers = []
local_cachers = []
if no_cachers:
cachers.append(PickleCacher)
local_cachers.append(PickleCacher)
elif record.manager.ignore_lazy:
for index, output in enumerate(outputs):
if type(output) == Lazy:
logging.debug("Disabling lazy cache for '%s'" % output)
outputs[index] = output.name

# check for mismatched amounts of cachers
if cachers is not None and len(cachers) != len(outputs):
if local_cachers is not None and len(local_cachers) != len(outputs):
raise CachersMismatchError(
f"Stage '{name}' - the number of cachers does not match the number of outputs to cache"
)
Expand Down Expand Up @@ -716,24 +767,27 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
)

# see note in stage
if cachers is not None:
if local_cachers is not None:
# instantiate cachers if not already
for i in range(len(cachers)):
cacher = cachers[i]
for i in range(len(local_cachers)):
cacher = local_cachers[i]
if type(cacher) == type:
cachers[i] = cacher()
local_cachers[i] = cacher()
# set the active record on the cacher as well as provide a default name
# (the name of the output)
cachers[i].set_record(record)
cachers[
local_cachers[i].set_record(record)
local_cachers[
i
].stage = name # set current stage name, so get_path is correct in later stages (particularly for lazy)
if cachers[i].name is None and cachers[i].path_override is None:
if (
local_cachers[i].name is None
and local_cachers[i].path_override is None
):
if type(outputs[i]) == Lazy:
cachers[i].name = outputs[i].name
local_cachers[i].name = outputs[i].name
else:
cachers[i].name = outputs[i]
record.stage_cachers = cachers
local_cachers[i].name = outputs[i]
record.stage_cachers = local_cachers

# at this point we've grabbed all information we would need if we're
# just mapping out the stages, so return at this point.
Expand All @@ -744,7 +798,9 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
# we need to create pseudo-state-artifact-representations for the outputs. Since
# we obviously can't add the actual artifacts (there are none without running!),
# we just add the string key and a name, so record __repr__ has something to use
_get_output_representations_for_map(record, outputs, cachers, records)
_get_output_representations_for_map(
record, outputs, local_cachers, records
)
record.stage_cachers = None
return record

Expand All @@ -762,7 +818,7 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
if stage_rep not in record.manager.map.execution_list:
logging.debug('DAG-indicated stage skip "%s".' % str(stage_rep))
_dag_skip_check_cached_outputs(
name, record, outputs, cachers, records
name, record, outputs, local_cachers, records
)

# grab any possible previous reportables so they still end up in report.
Expand All @@ -784,7 +840,7 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
# a DAG, check if we actually need to run this based on cached values.
if record.manager.map is None or execute_stage:
cache_valid = _check_cached_outputs(
name, record, outputs, cachers, records
name, record, outputs, local_cachers, records
)
if cache_valid:
# get previous reportables if available
Expand Down Expand Up @@ -859,7 +915,9 @@ def wrapper(record: Record, records: list[Record] = None, **kwargs):
# handle storing outputs in record
post_cache_time_start = time.perf_counter()
record.manager.lock()
_store_outputs(name, record, outputs, cachers, function_outputs, records)
_store_outputs(
name, record, outputs, local_cachers, function_outputs, records
)
_store_reportables(name, record, records)
record.store_tracked_paths()
record.manager.unlock()
Expand Down Expand Up @@ -1217,6 +1275,7 @@ def _store_outputs(
)
cachers[index].save(output)
artifact.file = cachers[index].get_path()
artifact.cacher = cachers[index]

# generate and save metadata
# note that if we got to this point, we actually ran the stage code, so
Expand Down
16 changes: 13 additions & 3 deletions test/examples/experiments/simple_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

sys.path.append("../")
from dataclasses import dataclass
from test.examples.stages.cache_stages import store_an_output
from test.examples.stages.cache_stages import agg_store_an_output, store_an_output

import curifactory as cf

Expand All @@ -11,13 +11,23 @@
class Params(cf.ExperimentParameters):
    """Parameter set for the simple_cache example experiment."""

    # the two operands summed by the caching example stages
    a: int = 5
    b: int = 6
    # when True, the run uses the aggregate stage variant instead of the
    # plain stage (exercises aggregate-stage cacher copying)
    do_agg: bool = False


def get_params():
    """Return the parameter sets for the simple_cache experiment.

    The last two sets have ``do_agg=True``, exercising the aggregate
    stage path in ``run`` in addition to the plain stage path.
    """
    # NOTE: the scraped diff left the superseded one-line return above the
    # new body; only the new four-set list is kept here (the old return
    # made the list below dead code).
    return [
        Params(name="thing1"),
        Params(name="thing2", b=10),
        Params(name="thing3", b=4, do_agg=True),
        Params(name="thing4", a=2, b=5, do_agg=True),
    ]


def run(param_sets, manager):
    """Execute the experiment: create one record per parameter set.

    Each record runs either the aggregate caching stage (when its
    parameter set requests it via ``do_agg``) or the plain stage.
    """
    for param_set in param_sets:
        r = cf.Record(manager, param_set)

        # NOTE: the scraped diff kept the removed pre-change line
        # ``r = store_an_output(r)`` ahead of this branch, which would have
        # run the plain stage unconditionally and then again in the else
        # branch; only the branched form is correct.
        if param_set.do_agg:
            r = agg_store_an_output(r)
        else:
            r = store_an_output(r)
5 changes: 5 additions & 0 deletions test/examples/stages/cache_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ def filerefcacher_stage(record):
@cf.stage(None, ["my_output"], [JsonCacher])
def store_an_output(record):
    """Stage that computes and JSON-caches the sum of params ``a`` and ``b``."""
    params = record.params
    total = params.a + params.b
    return total


@cf.aggregate(None, ["my_agg_output"], [JsonCacher])
def agg_store_an_output(record, records):
    """Aggregate stage that JSON-caches the sum of the aggregate record's ``a`` and ``b`` params."""
    a_value = record.params.a
    b_value = record.params.b
    return a_value + b_value
2 changes: 1 addition & 1 deletion test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def test_experiment_ls_output(mocker, capfd):
out, err = capfd.readouterr()
assert (
out
== "EXPERIMENTS:\n\tbasic\n\tsubexp.example\n\nPARAMS:\n\tempty\n\tnonarrayargs\n\tparams1\n\tparams2\n\tsubparams.thing\n"
== "EXPERIMENTS:\n\tbasic\n\tsimple_cache\n\tsubexp.example\n\nPARAMS:\n\tempty\n\tnonarrayargs\n\tparams1\n\tparams2\n\tsimple_cache\n\tsubparams.thing\n"
)
18 changes: 14 additions & 4 deletions test/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,12 +1042,20 @@ def do_something_else(record, output):

r0 = cf.Record(configured_test_manager, cf.ExperimentParameters(name="test"))
r0 = output_thing(r0)
assert cacher.stage == "output_thing"
first_path = cacher.get_path()

# we expect the raw cacher to not have been changed
assert cacher.stage is None

r0_cacher = configured_test_manager.artifacts[
r0.state_artifact_reps["output"]
].cacher
assert r0_cacher.stage == "output_thing"
first_path = r0_cacher.get_path()

# new things having happened shouldn't affect the previous cacher
do_something_else(r0)
assert cacher.stage == "output_thing"
second_path = cacher.get_path()
assert r0_cacher.stage == "output_thing"
second_path = r0_cacher.get_path()

assert first_path == second_path

Expand Down Expand Up @@ -1278,6 +1286,8 @@ def output_thing(record):
assert os.path.exists(f"{full_path}/thing.json")
assert os.path.exists(f"{full_path}/what.txt")

cacher = configured_test_manager.artifacts[-1].cacher

assert cacher.load()["message"] == "hello world!"


Expand Down
Loading

0 comments on commit b8a961f

Please sign in to comment.