Skip to content

Commit

Permalink
Refactor BatchSimulator
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Oct 9, 2024
1 parent 45c8364 commit 5168146
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 95 deletions.
20 changes: 20 additions & 0 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
overload,
)

from ert.config.ext_param_config import ExtParamConfig
from ert.field_utils import get_shape

from .field import Field
Expand Down Expand Up @@ -83,6 +84,11 @@ def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig:
gen_kw_list = config_dict.get(ConfigKeys.GEN_KW, [])
surface_list = config_dict.get(ConfigKeys.SURFACE, [])
field_list = config_dict.get(ConfigKeys.FIELD, [])

# TODO: The EXT_PARAM key is only used internally by Everest and
# therefore not included in the ConfigKeys enum.
ext_param_dict = config_dict.get("EXT_PARAM", {})

dims = None
if grid_file_path is not None:
try:
Expand All @@ -106,10 +112,24 @@ def make_field(field_list: List[str]) -> Field:
)
return Field.from_config_list(grid_file_path, dims, field_list)

# TODO: EXT_PARAM is not a true ERT config key, the information in
# ext_param_dict contains a dict populated by everest_to_ert_config. If
# EXT_PARAM is used for other purposes in the future, it may need to be
# promoted to a proper ERT config key.
def make_ext_param(
control_name: str, variables: Union[List[str], Dict[str, List[str]]]
) -> ExtParamConfig:
return ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)

parameter_configs = (
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
+ [SurfaceConfig.from_config_list(s) for s in surface_list]
+ [make_field(f) for f in field_list]
+ [make_ext_param(n, e) for n, e in ext_param_dict.items()]
)

response_configs: List[ResponseConfig] = []
Expand Down
10 changes: 3 additions & 7 deletions src/ert/config/ext_param_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ class ExtParamConfig(ParameterConfig):
If a list of strings is given, the order is preserved.
"""

input_keys: Union[List[str], Dict[str, List[Tuple[str, str]]]] = field(
default_factory=list
)
input_keys: Union[List[str], Dict[str, List[str]]] = field(default_factory=list)
forward_init: bool = False
output_file: str = ""
forward_init_file: str = ""
Expand Down Expand Up @@ -136,16 +134,14 @@ def __contains__(self, key: Union[Tuple[str, str], str]) -> bool:
"""
if isinstance(self.input_keys, dict) and isinstance(key, tuple):
key, suffix = key
return (
key in self.input_keys and suffix in self.input_keys[key] # type: ignore[comparison-overlap]
)
return key in self.input_keys and suffix in self.input_keys[key]
else:
return key in self.input_keys

def __repr__(self) -> str:
return f"ExtParamConfig(keys={self.input_keys})"

def __getitem__(self, index: str) -> List[Tuple[str, str]]:
def __getitem__(self, index: str) -> List[str]:
"""Retrieve an item from the configuration
If @index is a string, assumes its a key and retrieves the suffixes
Expand Down
51 changes: 16 additions & 35 deletions src/ert/simulator/batch_simulator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)

import numpy as np

from ert.config import ErtConfig, ExtParamConfig, GenDataConfig
from ert.config import ErtConfig, ExtParamConfig

from .batch_simulator_context import BatchContext

Expand All @@ -16,8 +26,8 @@ class BatchSimulator:
def __init__(
self,
ert_config: ErtConfig,
controls: Dict[str, List[str]],
results: List[str],
controls: Iterable[str],
results: Iterable[str],
callback: Optional[Callable[[BatchContext], None]] = None,
):
"""Will create simulator which can be used to run multiple simulations.
Expand Down Expand Up @@ -88,39 +98,10 @@ def callback(*args, **kwargs):
raise ValueError("The first argument must be valid ErtConfig instance")

self.ert_config = ert_config
self.control_keys = set(controls.keys())
self.control_keys = set(controls)
self.result_keys = set(results)
self.callback = callback

ens_config = self.ert_config.ensemble_config
for control_name, variables in controls.items():
ens_config.addNode(
ExtParamConfig(
name=control_name,
input_keys=variables,
output_file=control_name + ".json",
)
)

if "gen_data" not in ens_config:
ens_config.addNode(
GenDataConfig(
keys=results,
input_files=[f"{k}" for k in results],
report_steps_list=[None for _ in results],
)
)
else:
existing_gendata = ens_config.response_configs["gen_data"]
existing_keys = existing_gendata.keys
assert isinstance(existing_gendata, GenDataConfig)

for key in results:
if key not in existing_keys:
existing_gendata.keys.append(key)
existing_gendata.input_files.append(f"{key}")
existing_gendata.report_steps_list.append(None)

def _setup_sim(
self,
sim_id: int,
Expand All @@ -143,7 +124,7 @@ def _check_suffix(
f"these suffixes: {missingsuffixes}"
)
for suffix in assignment:
if suffix not in suffixes: # type: ignore[comparison-overlap]
if suffix not in suffixes:
raise KeyError(
f"Key {key} has suffixes {suffixes}. "
f"Can't find the requested suffix {suffix}"
Expand Down
53 changes: 52 additions & 1 deletion src/everest/simulator/everest_to_ert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import json
import logging
import os
from typing import Union
from typing import DefaultDict, Dict, List, Union

import everest
from everest.config import EverestConfig
from everest.config.control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from everest.config.install_data_config import InstallDataConfig
from everest.config.install_job_config import InstallJobConfig
from everest.config.simulator_config import SimulatorConfig
Expand Down Expand Up @@ -455,6 +459,51 @@ def _extract_seed(ever_config: EverestConfig, ert_config):
ert_config["RANDOM_SEED"] = random_seed


def _extract_controls(ever_config: EverestConfig, ert_config):
def _get_variables(
variables: Union[
List[ControlVariableConfig], List[ControlVariableGuessListConfig]
],
) -> Union[List[str], Dict[str, List[str]]]:
if (
isinstance(variables[0], ControlVariableConfig)
and getattr(variables[0], "index", None) is None
):
return [var.name for var in variables]
result: DefaultDict[str, list] = collections.defaultdict(list)
for variable in variables:
if isinstance(variable, ControlVariableGuessListConfig):
result[variable.name].extend(
str(index + 1) for index, _ in enumerate(variable.initial_guess)
)
else:
result[variable.name].append(str(variable.index)) # type: ignore
return dict(result)

controls = ever_config.controls or []
# TODO: EXT_PARAM is only used internally by Everest to configure ExtParamConfig
# objects. It is not available to ERT configuration files. For this reason,
# here we simply dump the required information as a dict.
ert_config["EXT_PARAM"] = {
control.name: _get_variables(control.variables) for control in controls
}


def _extract_results(ever_config: EverestConfig, ert_config):
objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.alias is None
]
constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
gen_data = ert_config.get("GEN_DATA", [])
for name in objectives_names + constraint_names:
gen_data.append((name, f"RESULT_FILE:{name}"))
ert_config["GEN_DATA"] = gen_data


def everest_to_ert_config(ever_config: EverestConfig, site_config=None):
"""
Takes as input an Everest configuration, the site-config and converts them
Expand All @@ -475,5 +524,7 @@ def everest_to_ert_config(ever_config: EverestConfig, site_config=None):
_extract_model(ever_config, ert_config)
_extract_queue_system(ever_config, ert_config)
_extract_seed(ever_config, ert_config)
_extract_controls(ever_config, ert_config)
_extract_results(ever_config, ert_config)

return ert_config
76 changes: 25 additions & 51 deletions src/everest/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from datetime import datetime
from itertools import count
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple

import numpy as np
from numpy import float64
Expand All @@ -13,66 +13,50 @@
from ert.config import ErtConfig, HookRuntime
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.config.control_variable_config import (
ControlVariableConfig,
ControlVariableGuessListConfig,
)
from everest.simulator.everest_to_ert import everest_to_ert_config


class Simulator(BatchSimulator):
"""Everest simulator: BatchSimulator"""

def __init__(self, ever_config: EverestConfig, callback=None):
def __init__(self, ever_config: EverestConfig, callback=None) -> None:
self._ert_config = ErtConfig.with_plugins().from_dict(
config_dict=everest_to_ert_config(
ever_config, site_config=ErtConfig.read_site_config()
)
)
controls_def = self._get_controls_def(ever_config)
results_def = self._get_results_def(ever_config)

super(Simulator, self).__init__(
self._ert_config, controls_def, results_def, callback=callback
self._ert_config,
self._get_controls(ever_config),
self._get_results(ever_config),
callback=callback,
)

self._function_aliases = self._get_aliases(ever_config)
self._experiment_id = None
self._batch = 0
self._cache: Optional[_SimulatorCache] = None
if ever_config.simulator is not None and ever_config.simulator.enable_cache:
self._cache = _SimulatorCache()

@staticmethod
def _get_variables(
variables: Union[
List[ControlVariableConfig], List[ControlVariableGuessListConfig]
],
) -> Union[List[str], Dict[str, List[str]]]:
if (
isinstance(variables[0], ControlVariableConfig)
and getattr(variables[0], "index", None) is None
):
return [var.name for var in variables]
result: DefaultDict[str, list] = defaultdict(list)
for variable in variables:
if isinstance(variable, ControlVariableGuessListConfig):
result[variable.name].extend(
str(index + 1) for index, _ in enumerate(variable.initial_guess)
)
else:
result[variable.name].append(str(variable.index)) # type: ignore
return dict(result) # { name : [ index ]

def _get_controls_def(
self, ever_config: EverestConfig
) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
def _get_controls(self, ever_config: EverestConfig) -> List[str]:
controls = ever_config.controls or []
return {
control.name: self._get_variables(control.variables) for control in controls
}
return [control.name for control in controls]

def _get_results(self, ever_config: EverestConfig) -> List[str]:
objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.alias is None
]
constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
return objectives_names + constraint_names

def _get_results_def(self, ever_config: EverestConfig):
self._function_aliases = {
def _get_aliases(self, ever_config: EverestConfig) -> Dict[str, str]:
aliases = {
objective.name: objective.alias
for objective in ever_config.objective_functions
if objective.alias is not None
Expand All @@ -83,19 +67,9 @@ def _get_results_def(self, ever_config: EverestConfig):
constraint.upper_bound is not None
and constraint.lower_bound is not None
):
self._function_aliases[f"{constraint.name}:lower"] = constraint.name
self._function_aliases[f"{constraint.name}:upper"] = constraint.name

objectives_names = [
objective.name
for objective in ever_config.objective_functions
if objective.name not in self._function_aliases
]

constraint_names = [
constraint.name for constraint in (ever_config.output_constraints or [])
]
return objectives_names + constraint_names
aliases[f"{constraint.name}:lower"] = constraint.name
aliases[f"{constraint.name}:upper"] = constraint.name
return aliases

def __call__(
self, control_values: NDArray[np.float64], metadata: EvaluatorContext
Expand Down
Loading

0 comments on commit 5168146

Please sign in to comment.