From c414b45fece169b65ea803ba6ba7d35ea5d7f77a Mon Sep 17 00:00:00 2001 From: Marcel R Date: Sun, 16 Jul 2023 15:07:28 +0200 Subject: [PATCH] Add make_set to utils. --- law/util.py | 76 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/law/util.py b/law/util.py index 1f16d5b8..e20458c8 100644 --- a/law/util.py +++ b/law/util.py @@ -10,8 +10,8 @@ "is_number", "is_float", "try_int", "round_discrete", "str_to_int", "flag_to_bool", "empty_context", "common_task_params", "colored", "uncolored", "query_choice", "is_pattern", "brace_expand", "range_expand", "range_join", "multi_match", "is_iterable", "is_lazy_iterable", - "make_list", "make_tuple", "make_unique", "is_nested", "flatten", "merge_dicts", "unzip", - "which", "map_verbose", "map_struct", "mask_struct", "tmp_file", "perf_counter", + "make_list", "make_tuple", "make_set", "make_unique", "is_nested", "flatten", "merge_dicts", + "unzip", "which", "map_verbose", "map_struct", "mask_struct", "tmp_file", "perf_counter", "interruptable_popen", "readable_popen", "create_hash", "create_random_string", "copy_no_perm", "makedirs", "user_owns_file", "iter_chunks", "human_bytes", "parse_bytes", "human_duration", "parse_duration", "is_file_exists_error", "send_mail", "DotDict", "ShorthandDict", @@ -801,12 +801,11 @@ def make_list(obj, cast=True): """ if isinstance(obj, list): return list(obj) - elif is_lazy_iterable(obj): + if is_lazy_iterable(obj): return list(obj) - elif isinstance(obj, (tuple, set)) and cast: + if isinstance(obj, (tuple, set)) and cast: return list(obj) - else: - return [obj] + return [obj] def make_tuple(obj, cast=True): @@ -816,12 +815,25 @@ def make_tuple(obj, cast=True): """ if isinstance(obj, tuple): return obj - elif is_lazy_iterable(obj): + if is_lazy_iterable(obj): return tuple(obj) - elif isinstance(obj, (list, set)) and cast: + if isinstance(obj, (list, set)) and cast: return tuple(obj) - else: - return (obj,) + return (obj,) + + +def make_set(obj, cast=True): + """ + Converts an object *obj* to a set and returns it. Objects of types *list* and *tuple* are + converted if *cast* is *True*. Otherwise, and for all other types, *obj* is put in a new set. + """ + if isinstance(obj, set): + return obj + if is_lazy_iterable(obj): + return set(obj) + if isinstance(obj, (list, tuple)) and cast: + return set(obj) + return {obj} def make_unique(obj): @@ -832,10 +844,9 @@ def make_unique(obj): raised. """ if not isinstance(obj, (list, tuple)): - if is_iterable(obj) or is_lazy_iterable(obj): - obj = list(obj) - else: + if not is_iterable(obj) and not is_lazy_iterable(obj): raise TypeError("object is neither list, tuple, nor generic iterable") + obj = list(obj) ret = sorted(obj.__class__(set(obj)), key=lambda elem: obj.index(elem)) @@ -857,28 +868,29 @@ def flatten(*structs, **kwargs): """ if len(structs) == 0: return [] - elif len(structs) > 1: + + if len(structs) > 1: return flatten(structs, **kwargs) - else: - struct = structs[0] - - flatten_seq = lambda seq: sum((flatten(obj, **kwargs) for obj in seq), []) - if isinstance(struct, dict): - if kwargs.get("flatten_dict", True): - return flatten_seq(struct.values()) - elif isinstance(struct, list): - if kwargs.get("flatten_list", True): - return flatten_seq(struct) - elif isinstance(struct, tuple): - if kwargs.get("flatten_tuple", True): - return flatten_seq(struct) - elif isinstance(struct, set): - if kwargs.get("flatten_set", True): - return flatten_seq(struct) - elif is_lazy_iterable(struct): + + struct = structs[0] + + flatten_seq = lambda seq: sum((flatten(obj, **kwargs) for obj in seq), []) + if isinstance(struct, dict): + if kwargs.get("flatten_dict", True): + return flatten_seq(struct.values()) + elif isinstance(struct, list): + if kwargs.get("flatten_list", True): + return flatten_seq(struct) + elif isinstance(struct, tuple): + if kwargs.get("flatten_tuple", True): + return flatten_seq(struct) + elif isinstance(struct, set): + if kwargs.get("flatten_set", True): return flatten_seq(struct) + elif is_lazy_iterable(struct): + return flatten_seq(struct) - return [struct] + return [struct] def merge_dicts(*dicts, **kwargs):