diff --git a/src/ert/config/ensemble_config.py b/src/ert/config/ensemble_config.py index 4f19e236361..02b199d0673 100644 --- a/src/ert/config/ensemble_config.py +++ b/src/ert/config/ensemble_config.py @@ -13,6 +13,7 @@ overload, ) +from ert.config.ext_param_config import ExtParamConfig from ert.field_utils import get_shape from .field import Field @@ -83,6 +84,11 @@ def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig: gen_kw_list = config_dict.get(ConfigKeys.GEN_KW, []) surface_list = config_dict.get(ConfigKeys.SURFACE, []) field_list = config_dict.get(ConfigKeys.FIELD, []) + + # TODO: The EXT_PARAM key is only used internally by Everest and + # therefore not included in the ConfigKeys enum. + ext_param_dict = config_dict.get("EXT_PARAM", {}) + dims = None if grid_file_path is not None: try: @@ -106,10 +112,24 @@ def make_field(field_list: List[str]) -> Field: ) return Field.from_config_list(grid_file_path, dims, field_list) + # TODO: EXT_PARAM is not a true ERT config key, the information in + # ext_param_dict contains a dict populated by everest_to_ert_config. If + # EXT_PARAM is used for other purposes in the future, it may need to be + # promoted to a proper ERT config key. + def make_ext_param( + control_name: str, variables: Union[List[str], Dict[str, List[str]]] + ) -> ExtParamConfig: + return ExtParamConfig( + name=control_name, + input_keys=variables, + output_file=control_name + ".json", + ) + parameter_configs = ( [GenKwConfig.from_config_list(g) for g in gen_kw_list] + [SurfaceConfig.from_config_list(s) for s in surface_list] + [make_field(f) for f in field_list] + + [make_ext_param(n, e) for n, e in ext_param_dict.items()] ) response_configs: List[ResponseConfig] = [] diff --git a/src/ert/config/ext_param_config.py b/src/ert/config/ext_param_config.py index 0e69cd02603..c68b81d3da1 100644 --- a/src/ert/config/ext_param_config.py +++ b/src/ert/config/ext_param_config.py @@ -29,9 +29,7 @@ class ExtParamConfig(ParameterConfig): If a list of strings is given, the order is preserved. """ - input_keys: Union[List[str], Dict[str, List[Tuple[str, str]]]] = field( - default_factory=list - ) + input_keys: Union[List[str], Dict[str, List[str]]] = field(default_factory=list) forward_init: bool = False output_file: str = "" forward_init_file: str = "" @@ -136,16 +134,14 @@ def __contains__(self, key: Union[Tuple[str, str], str]) -> bool: """ if isinstance(self.input_keys, dict) and isinstance(key, tuple): key, suffix = key - return ( - key in self.input_keys and suffix in self.input_keys[key] # type: ignore[comparison-overlap] - ) + return key in self.input_keys and suffix in self.input_keys[key] else: return key in self.input_keys def __repr__(self) -> str: return f"ExtParamConfig(keys={self.input_keys})" - def __getitem__(self, index: str) -> List[Tuple[str, str]]: + def __getitem__(self, index: str) -> List[str]: """Retrieve an item from the configuration If @index is a string, assumes its a key and retrieves the suffixes diff --git a/src/ert/simulator/batch_simulator.py b/src/ert/simulator/batch_simulator.py index 20593b20aeb..78d288c45bd 100644 --- a/src/ert/simulator/batch_simulator.py +++ b/src/ert/simulator/batch_simulator.py @@ -1,10 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import numpy as np -from ert.config import ErtConfig, ExtParamConfig, GenDataConfig +from ert.config import ErtConfig, ExtParamConfig from .batch_simulator_context import BatchContext @@ -16,8 +26,8 @@ class BatchSimulator: def __init__( self, ert_config: ErtConfig, - controls: Dict[str, List[str]], - results: List[str], + controls: Iterable[str], + results: Iterable[str], callback: Optional[Callable[[BatchContext], None]] = None, ): """Will create simulator which can be used to run multiple simulations. @@ -88,39 +98,10 @@ def callback(*args, **kwargs): raise ValueError("The first argument must be valid ErtConfig instance") self.ert_config = ert_config - self.control_keys = set(controls.keys()) + self.control_keys = set(controls) self.result_keys = set(results) self.callback = callback - ens_config = self.ert_config.ensemble_config - for control_name, variables in controls.items(): - ens_config.addNode( - ExtParamConfig( - name=control_name, - input_keys=variables, - output_file=control_name + ".json", - ) - ) - - if "gen_data" not in ens_config: - ens_config.addNode( - GenDataConfig( - keys=results, - input_files=[f"{k}" for k in results], - report_steps_list=[None for _ in results], - ) - ) - else: - existing_gendata = ens_config.response_configs["gen_data"] - existing_keys = existing_gendata.keys - assert isinstance(existing_gendata, GenDataConfig) - - for key in results: - if key not in existing_keys: - existing_gendata.keys.append(key) - existing_gendata.input_files.append(f"{key}") - existing_gendata.report_steps_list.append(None) - def _setup_sim( self, sim_id: int, @@ -143,7 +124,7 @@ def _check_suffix( f"these suffixes: {missingsuffixes}" ) for suffix in assignment: - if suffix not in suffixes: # type: ignore[comparison-overlap] + if suffix not in suffixes: raise KeyError( f"Key {key} has suffixes {suffixes}. " f"Can't find the requested suffix {suffix}" diff --git a/src/everest/simulator/everest_to_ert.py b/src/everest/simulator/everest_to_ert.py index b15d3c015da..514dbadcf13 100644 --- a/src/everest/simulator/everest_to_ert.py +++ b/src/everest/simulator/everest_to_ert.py @@ -3,10 +3,14 @@ import json import logging import os -from typing import Union +from typing import DefaultDict, Dict, List, Union import everest from everest.config import EverestConfig +from everest.config.control_variable_config import ( + ControlVariableConfig, + ControlVariableGuessListConfig, +) from everest.config.install_data_config import InstallDataConfig from everest.config.install_job_config import InstallJobConfig from everest.config.simulator_config import SimulatorConfig @@ -455,6 +459,51 @@ def _extract_seed(ever_config: EverestConfig, ert_config): ert_config["RANDOM_SEED"] = random_seed +def _extract_controls(ever_config: EverestConfig, ert_config): + def _get_variables( + variables: Union[ + List[ControlVariableConfig], List[ControlVariableGuessListConfig] + ], + ) -> Union[List[str], Dict[str, List[str]]]: + if ( + isinstance(variables[0], ControlVariableConfig) + and getattr(variables[0], "index", None) is None + ): + return [var.name for var in variables] + result: DefaultDict[str, list] = collections.defaultdict(list) + for variable in variables: + if isinstance(variable, ControlVariableGuessListConfig): + result[variable.name].extend( + str(index + 1) for index, _ in enumerate(variable.initial_guess) + ) + else: + result[variable.name].append(str(variable.index)) # type: ignore + return dict(result) + + controls = ever_config.controls or [] + # TODO: EXT_PARAM is only used internally by Everest to configure ExtParamConfig + # objects. It is not available to ERT configuration files. For this reason, + # here we simply dump the required information as a dict. + ert_config["EXT_PARAM"] = { + control.name: _get_variables(control.variables) for control in controls + } + + +def _extract_results(ever_config: EverestConfig, ert_config): + objectives_names = [ + objective.name + for objective in ever_config.objective_functions + if objective.alias is None + ] + constraint_names = [ + constraint.name for constraint in (ever_config.output_constraints or []) + ] + gen_data = ert_config.get("GEN_DATA", []) + for name in objectives_names + constraint_names: + gen_data.append((name, f"RESULT_FILE:{name}")) + ert_config["GEN_DATA"] = gen_data + + def everest_to_ert_config(ever_config: EverestConfig, site_config=None): """ Takes as input an Everest configuration, the site-config and converts them @@ -475,5 +524,7 @@ def everest_to_ert_config(ever_config: EverestConfig, site_config=None): _extract_model(ever_config, ert_config) _extract_queue_system(ever_config, ert_config) _extract_seed(ever_config, ert_config) + _extract_controls(ever_config, ert_config) + _extract_results(ever_config, ert_config) return ert_config diff --git a/src/everest/simulator/simulator.py b/src/everest/simulator/simulator.py index 38ec1a94230..425473688f9 100644 --- a/src/everest/simulator/simulator.py +++ b/src/everest/simulator/simulator.py @@ -2,7 +2,7 @@ from collections import defaultdict from datetime import datetime from itertools import count -from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple import numpy as np from numpy import float64 @@ -13,66 +13,50 @@ from ert.config import ErtConfig, HookRuntime from ert.storage import open_storage from everest.config import EverestConfig -from everest.config.control_variable_config import ( - ControlVariableConfig, - ControlVariableGuessListConfig, -) from everest.simulator.everest_to_ert import everest_to_ert_config class Simulator(BatchSimulator): """Everest simulator: BatchSimulator""" - def __init__(self, ever_config: EverestConfig, callback=None): + def __init__(self, ever_config: EverestConfig, callback=None) -> None: self._ert_config = ErtConfig.with_plugins().from_dict( config_dict=everest_to_ert_config( ever_config, site_config=ErtConfig.read_site_config() ) ) - controls_def = self._get_controls_def(ever_config) - results_def = self._get_results_def(ever_config) super(Simulator, self).__init__( - self._ert_config, controls_def, results_def, callback=callback + self._ert_config, + self._get_controls(ever_config), + self._get_results(ever_config), + callback=callback, ) + self._function_aliases = self._get_aliases(ever_config) self._experiment_id = None self._batch = 0 self._cache: Optional[_SimulatorCache] = None if ever_config.simulator is not None and ever_config.simulator.enable_cache: self._cache = _SimulatorCache() - @staticmethod - def _get_variables( - variables: Union[ - List[ControlVariableConfig], List[ControlVariableGuessListConfig] - ], - ) -> Union[List[str], Dict[str, List[str]]]: - if ( - isinstance(variables[0], ControlVariableConfig) - and getattr(variables[0], "index", None) is None - ): - return [var.name for var in variables] - result: DefaultDict[str, list] = defaultdict(list) - for variable in variables: - if isinstance(variable, ControlVariableGuessListConfig): - result[variable.name].extend( - str(index + 1) for index, _ in enumerate(variable.initial_guess) - ) - else: - result[variable.name].append(str(variable.index)) # type: ignore - return dict(result) # { name : [ index ] - - def _get_controls_def( - self, ever_config: EverestConfig - ) -> Dict[str, Union[List[str], Dict[str, List[str]]]]: + def _get_controls(self, ever_config: EverestConfig) -> List[str]: controls = ever_config.controls or [] - return { - control.name: self._get_variables(control.variables) for control in controls - } + return [control.name for control in controls] + + def _get_results(self, ever_config: EverestConfig) -> List[str]: + objectives_names = [ + objective.name + for objective in ever_config.objective_functions + if objective.alias is None + ] + constraint_names = [ + constraint.name for constraint in (ever_config.output_constraints or []) + ] + return objectives_names + constraint_names - def _get_results_def(self, ever_config: EverestConfig): - self._function_aliases = { + def _get_aliases(self, ever_config: EverestConfig) -> Dict[str, str]: + aliases = { objective.name: objective.alias for objective in ever_config.objective_functions if objective.alias is not None @@ -83,19 +67,9 @@ def _get_results_def(self, ever_config: EverestConfig): constraint.upper_bound is not None and constraint.lower_bound is not None ): - self._function_aliases[f"{constraint.name}:lower"] = constraint.name - self._function_aliases[f"{constraint.name}:upper"] = constraint.name - - objectives_names = [ - objective.name - for objective in ever_config.objective_functions - if objective.name not in self._function_aliases - ] - - constraint_names = [ - constraint.name for constraint in (ever_config.output_constraints or []) - ] - return objectives_names + constraint_names + aliases[f"{constraint.name}:lower"] = constraint.name + aliases[f"{constraint.name}:upper"] = constraint.name + return aliases def __call__( self, control_values: NDArray[np.float64], metadata: EvaluatorContext diff --git a/tests/ert/unit_tests/simulator/test_batch_sim.py b/tests/ert/unit_tests/simulator/test_batch_sim.py index c58d3e93725..20f1c6501e2 100644 --- a/tests/ert/unit_tests/simulator/test_batch_sim.py +++ b/tests/ert/unit_tests/simulator/test_batch_sim.py @@ -4,7 +4,7 @@ import pytest -from ert.config import ErtConfig +from ert.config import ErtConfig, ExtParamConfig, GenDataConfig from ert.scheduler import JobState from ert.simulator import BatchContext, BatchSimulator @@ -35,6 +35,51 @@ def batch_sim_example(setup_case): return setup_case("batch_sim", "batch_sim.ert") +# TODO: The batch simulator was recently refactored. It now requires an ERT +# config object that has been generated using the everest_to_ert_config +# function. The resulting ERT config object includes features that cannot be +# specified in an ERT configuration file. This is acceptable since the batch +# simulator is only used by Everest and slated to be replaced in the near future +# with newer ERT functionality. However, the tests in this file assume that the +# batch simulator can be configured independently from an Everest configuration. +# To make the tests work, the batch simulator class is patched here to inject +# the missing functionality. +class PatchedBatchSimulator(BatchSimulator): + def __init__(self, ert_config, controls, results, callback=None): + super().__init__(ert_config, set(controls), results, callback) + ens_config = ert_config.ensemble_config + for control_name, variables in controls.items(): + ens_config.addNode( + ExtParamConfig( + name=control_name, + input_keys=variables, + output_file=control_name + ".json", + ) + ) + + if "gen_data" not in ens_config: + ens_config.addNode( + GenDataConfig( + keys=results, + input_files=[f"{k}" for k in results], + report_steps_list=[None for _ in results], + ) + ) + else: + existing_gendata = ens_config.response_configs["gen_data"] + existing_keys = existing_gendata.keys + assert isinstance(existing_gendata, GenDataConfig) + + for key in results: + if key not in existing_keys: + existing_gendata.keys.append(key) + existing_gendata.input_files.append(f"{key}") + existing_gendata.report_steps_list.append(None) + + +BatchSimulator = PatchedBatchSimulator + + def test_that_simulator_raises_error_when_missing_ertconfig(): with pytest.raises(ValueError, match="The first argument must be valid ErtConfig"): _ = BatchSimulator( diff --git a/tests/everest/test_egg_simulation.py b/tests/everest/test_egg_simulation.py index f810322d9ff..e87992f936b 100644 --- a/tests/everest/test_egg_simulation.py +++ b/tests/everest/test_egg_simulation.py @@ -561,6 +561,23 @@ def _generate_exp_ert_config(config_path, output_dir): "ECLBASE": "eclipse/model/EGG", "RANDOM_SEED": 123456, "SUMMARY": SUM_KEYS, + "GEN_DATA": [("rf", "RESULT_FILE:rf")], + "EXT_PARAM": { + "well_rate": { + "PROD1": ["1"], + "PROD2": ["1"], + "PROD3": ["1"], + "PROD4": ["1"], + "INJECT1": ["1"], + "INJECT2": ["1"], + "INJECT3": ["1"], + "INJECT4": ["1"], + "INJECT5": ["1"], + "INJECT6": ["1"], + "INJECT7": ["1"], + "INJECT8": ["1"], + }, + }, } diff --git a/tests/everest/test_res_initialization.py b/tests/everest/test_res_initialization.py index a006feec338..ee4cdb8a72f 100644 --- a/tests/everest/test_res_initialization.py +++ b/tests/everest/test_res_initialization.py @@ -137,6 +137,24 @@ def make_queue_system(queue_system): elif queue_system == ConfigKeys.SLURM: return slurm_queue_system() + def make_ext_param(): + return { + "group": [ + "W1", + "W2", + "W3", + "W4", + ], + } + + def make_gen_data(): + return [ + ( + "snake_oil_nvp", + "RESULT_FILE:snake_oil_nvp", + ), + ] + ert_config = { "DEFINE": [("", os.path.abspath(SNAKE_CONFIG_DIR))], "RUNPATH": os.path.join( @@ -156,6 +174,8 @@ def make_queue_system(queue_system): os.path.realpath("snake_oil/everest/model"), "everest_output/simulation_results", ), + "EXT_PARAM": make_ext_param(), + "GEN_DATA": make_gen_data(), } ert_config.update(make_queue_system(queue_system)) return ert_config @@ -210,6 +230,32 @@ def build_tutorial_dict(config_dir, output_dir): os.path.realpath("mocked_test_case"), "everest_output/simulation_results", ), + "EXT_PARAM": { + "group": [ + "w00", + "w01", + "w02", + "w03", + "w04", + "w05", + "w06", + "w07", + "w08", + "w09", + "w10", + "w11", + "w12", + "w13", + "w14", + "w15", + ], + }, + "GEN_DATA": [ + ( + "npv_function", + "RESULT_FILE:npv_function", + ), + ], }