Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add util functions for hed strings #725

Merged
merged 2 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions hed/models/string_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from hed.models.hed_string import HedString


def split_base_tags(hed_string, base_tags, remove_group=False):
    """ Pull the tags matching the given base tags out of a HedString.

    The search is recursive and compares against the lower-cased base tag names.
    Anything found is passed to hed_string.remove(), and a new HedString holding
    the found items is built with the same schema.

    Args:
        hed_string (HedString): The string to search. remove() is called on this object
            and this same object is returned as the first element of the result.
        base_tags (list of str): Base tag names to look for (lower-cased before searching).
            This matches the base tag itself, NOT the terms above it.
        remove_group (bool, optional): If True, take the group containing each match rather
            than just the tag, unless that group is a HedString instance, in which case
            only the tag is taken. Defaults to False.

    Returns:
        tuple: A tuple containing two HedString objects:
            - hed_string itself, after the matches were removed.
            - A new HedString whose contents are the removed tags/groups.
    """
    lowered_names = [name.lower() for name in base_tags]
    # 2 makes find_tags yield (tag, group) pairs; 0 is used when only the tags are wanted.
    group_mode = 2 if remove_group else 0
    matches = hed_string.find_tags(lowered_names, recursive=True, include_groups=group_mode)

    if remove_group:
        targets = []
        for tag, parent in matches:
            if isinstance(parent, HedString):
                # The container is a HedString (presumably the top level): take only the tag.
                targets.append(tag)
            else:
                targets.append(parent)
        matches = targets

    if matches:
        hed_string.remove(matches)

    extracted = HedString("", hed_string._schema, _contents=matches)
    return hed_string, extracted


def split_def_tags(hed_string, def_names, remove_group=False):
    """ Pull the Def tags with the given names out of a HedString.

    Each name is turned into a lower-cased "def/<name>" pattern and looked up with a
    recursive wildcard tag search. Anything found is passed to hed_string.remove(), and
    a new HedString holding the found items is built with the same schema.

    This does NOT handle def-expand tags currently.

    Args:
        hed_string (HedString): The string to search. remove() is called on this object
            and this same object is returned as the first element of the result.
        def_names (list of str): Def names to search for. Can optionally include a value.
        remove_group (bool, optional): If True, take the group containing each match rather
            than just the tag, unless that group is a HedString instance, in which case
            only the tag is taken. Defaults to False.

    Returns:
        tuple: A tuple containing two HedString objects:
            - hed_string itself, after the matches were removed.
            - A new HedString whose contents are the removed tags/groups.
    """
    # 2 makes the search yield (tag, group) pairs; 0 is used when only the tags are wanted.
    group_mode = 2 if remove_group else 0
    def_patterns = [f"def/{name}".lower() for name in def_names]
    matches = hed_string.find_wildcard_tags(def_patterns, recursive=True, include_groups=group_mode)

    if remove_group:
        targets = []
        for tag, parent in matches:
            if isinstance(parent, HedString):
                # The container is a HedString (presumably the top level): take only the tag.
                targets.append(tag)
            else:
                targets.append(parent)
        matches = targets

    if matches:
        hed_string.remove(matches)

    extracted = HedString("", hed_string._schema, _contents=matches)
    return hed_string, extracted
30 changes: 15 additions & 15 deletions hed/schema/hed_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def merged(self):
return not self.header_attributes.get(constants.UNMERGED_ATTRIBUTE, "")

@property
def all_tags(self):
def tags(self):
""" Return the tag schema section.

Returns:
HedSchemaTagSection: The tag section.
"""
return self._sections[HedSectionKey.AllTags]
return self._sections[HedSectionKey.Tags]

@property
def unit_classes(self):
Expand Down Expand Up @@ -354,7 +354,7 @@ def check_compliance(self, check_for_warnings=True, name=None, error_handler=Non
from hed.schema import schema_compliance
return schema_compliance.check_compliance(self, check_for_warnings, name, error_handler)

def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags):
""" Return tag entries with the given attribute.

Parameters:
Expand All @@ -370,20 +370,20 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
return self._sections[key_class].get_entries_with_attribute(attribute, return_name_only=True,
schema_namespace=self._namespace)

def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""):
def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""):
""" Return the schema entry for this tag, if one exists.

Parameters:
name (str): Any form of basic tag(or other section entry) to look up.
This will not handle extensions or similar.
If this is a tag, it can have a schema namespace, but it's not required
key_class (HedSectionKey or str): The type of entry to return.
schema_namespace (str): Only used on AllTags. If incorrect, will return None.
schema_namespace (str): Only used on Tags. If incorrect, will return None.

Returns:
HedSchemaEntry: The schema entry for the given tag.
"""
if key_class == HedSectionKey.AllTags:
if key_class == HedSectionKey.Tags:
if schema_namespace != self._namespace:
return None
if name.startswith(self._namespace):
Expand Down Expand Up @@ -415,7 +415,7 @@ def find_tag_entry(self, tag, schema_namespace=""):
# ===============================================
# Private utility functions for getting/finding tags
# ===============================================
def _get_tag_entry(self, name, key_class=HedSectionKey.AllTags):
def _get_tag_entry(self, name, key_class=HedSectionKey.Tags):
""" Return the schema entry for this tag, if one exists.

Parameters:
Expand Down Expand Up @@ -524,7 +524,7 @@ def _validate_remaining_terms(self, tag, working_tag, prefix_tag_adj, current_sl
tag,
index_in_tag=word_start_index,
index_in_tag_end=word_start_index + len(name),
expected_parent_tag=self.all_tags[name].name)
expected_parent_tag=self.tags[name].name)
raise self._TagIdentifyError(error)
word_start_index += len(name) + 1

Expand All @@ -533,7 +533,7 @@ def _validate_remaining_terms(self, tag, working_tag, prefix_tag_adj, current_sl
# ===============================================
def finalize_dictionaries(self):
""" Call to finish loading. """
self._has_duplicate_tags = bool(self.all_tags.duplicate_names)
self._has_duplicate_tags = bool(self.tags.duplicate_names)
self._update_all_entries()

def _update_all_entries(self):
Expand Down Expand Up @@ -568,13 +568,13 @@ def get_desc_iter(self):
if tag_entry.description:
yield tag_entry.name, tag_entry.description

def get_tag_description(self, tag_name, key_class=HedSectionKey.AllTags):
def get_tag_description(self, tag_name, key_class=HedSectionKey.Tags):
""" Return the description associated with the tag.

Parameters:
tag_name (str): A hed tag name(or unit/unit modifier etc) with proper capitalization.
key_class (str): A string indicating type of description (e.g. All tags, Units, Unit modifier).
The default is HedSectionKey.AllTags.
The default is HedSectionKey.Tags.

Returns:
str: A description of the specified tag.
Expand All @@ -595,7 +595,7 @@ def get_all_schema_tags(self, return_last_term=False):

"""
final_list = []
for lower_tag, tag_entry in self.all_tags.items():
for lower_tag, tag_entry in self.tags.items():
if return_last_term:
final_list.append(tag_entry.name.split('/')[-1])
else:
Expand Down Expand Up @@ -636,7 +636,7 @@ def get_tag_attribute_names(self):
and not tag_entry.has_attribute(HedKey.UnitModifierProperty)
and not tag_entry.has_attribute(HedKey.ValueClassProperty)}

def get_all_tag_attributes(self, tag_name, key_class=HedSectionKey.AllTags):
def get_all_tag_attributes(self, tag_name, key_class=HedSectionKey.Tags):
""" Gather all attributes for a given tag name.

Parameters:
Expand Down Expand Up @@ -670,7 +670,7 @@ def _create_empty_sections():
dictionaries[HedSectionKey.Units] = HedSchemaSection(HedSectionKey.Units)
dictionaries[HedSectionKey.UnitClasses] = HedSchemaUnitClassSection(HedSectionKey.UnitClasses)
dictionaries[HedSectionKey.ValueClasses] = HedSchemaSection(HedSectionKey.ValueClasses)
dictionaries[HedSectionKey.AllTags] = HedSchemaTagSection(HedSectionKey.AllTags, case_sensitive=False)
dictionaries[HedSectionKey.Tags] = HedSchemaTagSection(HedSectionKey.Tags, case_sensitive=False)

return dictionaries

Expand Down Expand Up @@ -717,7 +717,7 @@ def _get_attributes_for_section(self, key_class):
dict or HedSchemaSection: A dict of all the attributes and this section.

"""
if key_class == HedSectionKey.AllTags:
if key_class == HedSectionKey.Tags:
return self.get_tag_attribute_names()
elif key_class == HedSectionKey.Attributes:
prop_added_dict = {key: value for key, value in self._sections[HedSectionKey.Properties].items()}
Expand Down
6 changes: 3 additions & 3 deletions hed/schema/hed_schema_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def valid_prefixes(self):
raise NotImplemented("This function must be implemented in the baseclass")

@abstractmethod
def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags):
""" Return tag entries with the given attribute.

Parameters:
Expand All @@ -72,15 +72,15 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):

# todo: maybe tweak this API so you don't have to pass in library namespace?
@abstractmethod
def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""):
def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""):
""" Return the schema entry for this tag, if one exists.

Parameters:
name (str): Any form of basic tag(or other section entry) to look up.
This will not handle extensions or similar.
If this is a tag, it can have a schema namespace, but it's not required
key_class (HedSectionKey or str): The type of entry to return.
schema_namespace (str): Only used on AllTags. If incorrect, will return None.
schema_namespace (str): Only used on Tags. If incorrect, will return None.

Returns:
HedSchemaEntry: The schema entry for the given tag.
Expand Down
2 changes: 1 addition & 1 deletion hed/schema/hed_schema_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ class HedSectionKey(Enum):
""" Keys designating specific sections in a HedSchema object.
"""
# overarching category listing all tags
AllTags = 'tags'
Tags = 'tags'
# Overarching category listing all unit classes
UnitClasses = 'unitClasses'
# Overarching category listing all units(not divided by type)
Expand Down
12 changes: 6 additions & 6 deletions hed/schema/hed_schema_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def check_compliance(self, check_for_warnings=True, name=None, error_handler=Non
issues_list += schema.check_compliance(check_for_warnings, name, error_handler)
return issues_list

def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags):
""" Return tag entries with the given attribute.

Parameters:
Expand All @@ -114,20 +114,20 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
Notes:
- The result is cached so will be fast after first call.
"""
all_tags = set()
tags = set()
for schema in self._schemas.values():
all_tags.update(schema.get_tags_with_attribute(attribute, key_class))
return list(all_tags)
tags.update(schema.get_tags_with_attribute(attribute, key_class))
return list(tags)

def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""):
def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""):
""" Return the schema entry for this tag, if one exists.

Parameters:
name (str): Any form of basic tag(or other section entry) to look up.
This will not handle extensions or similar.
If this is a tag, it can have a schema namespace, but it's not required
key_class (HedSectionKey or str): The type of entry to return.
schema_namespace (str): Only used on AllTags. If incorrect, will return None.
schema_namespace (str): Only used on Tags. If incorrect, will return None.

Returns:
HedSchemaEntry: The schema entry for the given tag.
Expand Down
2 changes: 1 addition & 1 deletion hed/schema/hed_schema_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
HedSectionKey.Units: UnitEntry,
HedSectionKey.UnitClasses: UnitClassEntry,
HedSectionKey.ValueClasses: HedSchemaEntry,
HedSectionKey.AllTags: HedTagEntry,
HedSectionKey.Tags: HedTagEntry,
}


Expand Down
4 changes: 2 additions & 2 deletions hed/schema/schema_attribute_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def tag_exists_check(hed_schema, tag_entry, attribute_name):
possible_tags = tag_entry.attributes.get(attribute_name, "")
split_tags = possible_tags.split(",")
for org_tag in split_tags:
if org_tag and org_tag not in hed_schema.all_tags:
if org_tag and org_tag not in hed_schema.tags:
issues += ErrorHandler.format_error(ValidationErrors.NO_VALID_TAG_FOUND,
org_tag,
index_in_tag=0,
Expand All @@ -72,7 +72,7 @@ def tag_exists_base_schema_check(hed_schema, tag_entry, attribute_name):
"""
issues = []
rooted_tag = tag_entry.attributes.get(attribute_name, "")
if rooted_tag and rooted_tag not in hed_schema.all_tags:
if rooted_tag and rooted_tag not in hed_schema.tags:
issues += ErrorHandler.format_error(ValidationErrors.NO_VALID_TAG_FOUND,
rooted_tag,
index_in_tag=0,
Expand Down
Loading