Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
poodarchu committed Oct 24, 2023
1 parent 74e4991 commit b4baf00
Show file tree
Hide file tree
Showing 40 changed files with 2,982 additions and 278 deletions.
96 changes: 96 additions & 0 deletions efg/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,21 @@


def load_yaml(file_path):
"""
Load a YAML configuration file. This is a recursive function that loads a YAML configuration file and merges it with the includes defined in the file.
Args:
file_path: Path to the YAML configuration file.
Returns:
A dictionary of configuration values. The keys are the names of the configuration values the values are the values
"""
mapping = OmegaConf.load(file_path)

includes = mapping.get("includes", [])
include_mapping = OmegaConf.create()

# Merge the include mapping from the include file to the current include mapping.
for include in includes:
include = os.path.join("./", include)
current_include_mapping = load_yaml(include)
Expand All @@ -27,6 +37,12 @@ def load_yaml(file_path):

class Configuration:
def __init__(self, args):
"""
Initialize Omega. This is called by the __init__ method of the object. In this case we have a config object that can be used to set options and resolve resolvers
Args:
args: Command line arguments parsed by
"""
self.config = {}
self.args = args
self._register_resolvers()
Expand All @@ -41,78 +57,125 @@ def __init__(self, args):
self.config = self._merge_with_dotlist(self.config, args.opts)

def _build_default_config(self):
"""
Build and return the default configuration. This is a wrapper around _get_default_config_path to allow overriding the path to the default configuration file.
Returns:
A dictionary of config values to be used in the test case's config file or None if there is no default
"""
self.default_config_path = self._get_default_config_path()
default_config = load_yaml(self.default_config_path)
return default_config

def _build_user_config(self, config_path):
"""
Build and return user_config. This is called by __init__ to build the user_config from the config_path passed to the constructor.
Args:
config_path: Path to the config file. If None is passed an empty dict is returned.
Returns:
Dictionary of options to pass to the command line or None if no options were passed. Note that the options will be merged with the options passed
"""
user_config = {}

# Update user_config with opts if passed
self.config_path = config_path
# Load user configuration file.
if self.config_path is not None:
user_config = load_yaml(self.config_path)

return user_config

def get_config(self):
"""
Get the configuration for this service. This is a convenience method that calls : meth : ` register_resolvers ` to register all resolvers in the order they were registered.
Returns:
A dictionary of configuration values keyed by service name. The value may be None if there is no configuration
"""
self._register_resolvers()
return self.config

def _register_resolvers(self):
"""
Clear and register resolvers for device count. This is used to avoid having to re - register every time
"""
OmegaConf.clear_resolvers()
# Device count resolver
device_count = max(1, torch.cuda.device_count())
OmegaConf.register_new_resolver("device_count", lambda: device_count)

def _merge_with_dotlist(self, config, opts):
"""
Merge config with dotlist. This is a helper function for merge_and_dotlist. It will take a config and merge it with a list of options that can be used to set the value of the config.
Args:
config: The OmegaConf node to merge with the opts
opts: A list of options to set
Returns:
The config with the opts merged with the config. If opts is None or empty it will return config
"""
# TODO: To remove technical debt, a possible solution is to use
# struct mode to update with dotlist OmegaConf node. Look into this
# in next iteration
# Set the options to the default values.
if opts is None:
opts = []

# Return the config object if opts is empty.
if len(opts) == 0:
return config

# Support equal e.g. model=visual_bert for better future hydra support
has_equal = opts[0].find("=") != -1

# Returns a list of strings containing the options.
if has_equal:
opt_values = [opt.split("=") for opt in opts]
else:
assert len(opts) % 2 == 0, "Number of opts should be multiple of 2"
opt_values = zip(opts[0::2], opts[1::2])

# Update the configuration for the given option.
for opt, value in opt_values:
splits = opt.split(".")
current = config
# Update the configuration options. The value is updated from the configuration.
for idx, field in enumerate(splits):
array_index = -1
# Find the index of the field in the field.
if field.find("[") != -1 and field.find("]") != -1:
stripped_field = field[: field.find("[")]
array_index = int(field[field.find("[") + 1 : field.find("]")])
else:
stripped_field = field
# If the field is missing from current configuration option raises AttributeError
if stripped_field not in current:
raise AttributeError(
"While updating configuration option {} is missing from configuration at field {}".format(
opt, stripped_field
)
)
# Update the configuration option. If the value is a mapping the value is updated.
if isinstance(current[stripped_field], collections.abc.Mapping):
current = current[stripped_field]
elif isinstance(current[stripped_field], collections.abc.Sequence) and array_index != -1:
current_value = current[stripped_field][array_index]

# Case where array element to be updated is last element
# If the current value is a mapping or sequence it will move the current value to the next value.
if not isinstance(current_value, (collections.abc.Mapping, collections.abc.Sequence)):
print("Overriding option {} to {}".format(opt, value))
current[stripped_field][array_index] = self._decode_value(value)
else:
# Otherwise move on down the chain
current = current_value
else:
# Update the current value of the option.
if idx == len(splits) - 1:
print("Overriding option {} to {}".format(opt, value))
current[stripped_field] = self._decode_value(value)
Expand All @@ -126,10 +189,21 @@ def _merge_with_dotlist(self, config, opts):
return config

def _decode_value(self, value):
"""
Decode a value that may be a string. This is used to determine if we are dealing with a value that has been passed to the command line or not.
Args:
value: The value to decode. If it is a string it is assumed to be a string and the function will return it.
Returns:
The decoded value or the
"""
# https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L400
# Return the value if it s not a string.
if not isinstance(value, str):
return value

# Set the value to None.
if value == "None":
value = None

Expand All @@ -142,15 +216,37 @@ def _decode_value(self, value):
return value

def freeze(self):
"""
Freeze the Omega configuration to prevent further changes to the configuration. This is useful for debugging and to ensure that the configuration is in a consistent state
"""
OmegaConf.set_struct(self.config, True)

def defrost(self):
"""
Defrost the Omega object so it can be used for a new simulation. This is a destructive
"""
OmegaConf.set_struct(self.config, False)

def _convert_node_to_json(self, node):
"""
Convert a node to json. This is a helper method to convert a node to json. The node is passed as a parameter to OmegaConf. to_container and a json. dumps is returned.
Args:
node: The node to convert. Must be a dict
Returns:
A json representation of the
"""
container = OmegaConf.to_container(node, resolve=True)
return json.dumps(container, indent=4, sort_keys=True)

def _get_default_config_path(self):
"""
Get the path to the default config file. This is used to determine where the user's config file should be saved when they create a new instance of the config class.
Returns:
The path to the default config file for the user's config file or None if there is no default
"""
directory = os.path.dirname(os.path.abspath(__file__))
return os.path.join(directory, "..", "config", "default.yaml")
11 changes: 11 additions & 0 deletions efg/data/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,19 @@


def build_processors(pipelines):
"""
Build processors based on pipeline. This is a list of pipelines that can be a list of pipeline names or a dictionary of name : args
Args:
pipelines: A list of pipeline names or a dictionary of name : args
Returns:
A list of processors to be used in the pipeline
"""
transforms = []
# Add transforms to the list of pipelines.
for pipeline in pipelines:
# Add a pipeline to the transforms list.
if isinstance(pipeline, dict):
name, args = pipeline.copy().popitem()
transform = PROCESSORS.get(name)(**args)
Expand Down
Loading

0 comments on commit b4baf00

Please sign in to comment.