Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move schema scripts over to hed-python #929

Merged
merged 5 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added hed/scripts/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions hed/scripts/convert_and_update_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from hed.schema import load_schema_version
from hed.scripts.script_util import sort_base_schemas, validate_all_schemas, add_extension
from hed.schema.schema_io.df2schema import load_dataframes
from hed.schema.schema_io.ontology_util import update_dataframes_from_schema, save_dataframes
from hed.schema.hed_schema_io import load_schema, from_dataframes
import argparse


def convert_and_update(filenames, set_ids):
    """ Validate, convert, and update as needed all schemas listed in filenames.

    If any schema fails to validate, no schemas will be updated.

    Parameters:
        filenames(list of str): A list of filenames that have been updated.
        set_ids(bool): If True, assign missing hedIds.

    Returns:
        int: 0 on success, 1 when validation failed and nothing was attempted.
    """
    # Group the changed files by schema basename, then validate every format.
    grouped_schemas = sort_base_schemas(filenames)
    if validate_all_schemas(grouped_schemas) or not grouped_schemas:
        print("Did not attempt to update schemas due to validation failures")
        return 1

    updated = []
    # Validation passed; if several formats of one schema changed, they were verified equal.
    for basename, extensions in grouped_schemas.items():
        if "prerelease" not in basename:
            print(f"Skipping updates on {basename}, not in a prerelease folder.")
            continue

        # Any changed format works as the load source, since they are all the same.
        source_filename = add_extension(basename, list(extensions)[0])

        # Keep the original casing of a changed .tsv extension, if one was changed.
        # todo: more properly decide how we want to handle non lowercase extensions.
        tsv_extension = ".tsv"
        for changed_ext in extensions:
            if changed_ext.lower() == ".tsv":
                tsv_extension = changed_ext

        source_df_filename = add_extension(basename, tsv_extension)
        schema = load_schema(source_filename)
        print(f"Trying to convert/update file {source_filename}")
        source_dataframes = load_dataframes(source_df_filename)
        # todo: We need a more robust system for if some files are missing
        #       (especially for library schemas which will probably lack some)
        if any(value is None for value in source_dataframes.values()):
            source_dataframes = schema.get_as_dataframes()

        result = update_dataframes_from_schema(source_dataframes, schema, assign_missing_ids=set_ids)

        # Round-trip through the dataframes so every saved format agrees with them.
        schema_reloaded = from_dataframes(result)
        schema_reloaded.save_as_mediawiki(basename + ".mediawiki")
        schema_reloaded.save_as_xml(basename + ".xml")

        save_dataframes(source_df_filename, result)
        updated.append(basename)

    for basename in updated:
        print(f"Schema {basename} updated.")

    if not updated:
        print("Did not update any schemas")
    return 0


def main():
    """Command-line entry point: regenerate the other schema formats from the changed one."""
    parser = argparse.ArgumentParser(description='Update other schema formats based on the changed one.')
    parser.add_argument('filenames', nargs='*', help='List of files to process')
    parser.add_argument('--set-ids', action='store_true', help='Set IDs for each file')
    args = parser.parse_args()

    # Trigger a local cache hit (this ensures trying to load withStandard schemas will work properly)
    _ = load_schema_version("8.2.0")

    return convert_and_update(args.filenames, args.set_ids)


# Script entry point: process exit code mirrors convert_and_update's return value.
if __name__ == "__main__":
    exit(main())
153 changes: 153 additions & 0 deletions hed/scripts/script_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os.path
from collections import defaultdict
from hed.schema import from_string, load_schema
from hed.errors import get_printable_issue_string, HedFileError, SchemaWarnings

# The three on-disk formats each schema is maintained in (used by the format-equality checks).
all_extensions = [".tsv", ".mediawiki", ".xml"]


def validate_schema(file_path):
    """ Validates the given schema, ensuring it can save/load as well as validates.

    This is probably overkill...

    Parameters:
        file_path(str): Path of the schema file to check.

    Returns:
        list of str: Printable error messages; empty when the schema is clean.
    """
    found_problems = []
    try:
        base_schema = load_schema(file_path)
        # Prerelease-version warnings are expected during development and intentionally dropped.
        issues = [issue for issue in base_schema.check_compliance()
                  if issue["code"] != SchemaWarnings.SCHEMA_PRERELEASE_VERSION_USED]
        if issues:
            found_problems.append(get_printable_issue_string(issues, title=file_path))

        # Round-trip the schema through each text format and verify nothing changed.
        round_trips = ((".mediawiki", base_schema.get_as_mediawiki_string),
                       (".xml", base_schema.get_as_xml_string))
        for schema_format, exporter in round_trips:
            reloaded_schema = from_string(exporter(), schema_format=schema_format)
            if reloaded_schema != base_schema:
                found_problems.append(
                    f"Failed to reload {file_path} as {schema_format[1:]}. "
                    f"There is either a problem with the source file, or the saving/loading code.")
    except HedFileError as e:
        print(f"Saving/loading error: {file_path} {e.message}")
        # Prefer the detailed issue list when one is attached to the error.
        if e.issues:
            found_problems.append(get_printable_issue_string(e.issues, title=file_path))
        else:
            found_problems.append(e.message)

    return found_problems


def add_extension(basename, extension):
    """Generate the final name for a given extension. Only .tsv varies notably.

    A .tsv "file" is really a folder of spreadsheets: <parent>/hedtsv/<name>,
    with no extension appended.  Every other extension is appended directly.
    """
    if extension.lower() != ".tsv":
        return basename + extension
    parent_path, stem = os.path.split(basename)
    return os.path.join(parent_path, "hedtsv", stem)


def sort_base_schemas(filenames):
    """ Sort and group the changed files based on basename

    Example input: ["test_schema.mediawiki", "hedtsv/test_schema/test_schema_Tag.tsv", "other_schema.xml"]
    Example output:
        {
            "test_schema": {".mediawiki", ".tsv"},
            "other_schema": {".xml"}
        }

    Parameters:
        filenames(list or container): The changed filenames

    Returns:
        sorted_files(dict): A dictionary where keys are the basename, and the values are a set of extensions modified
                            Can include tsv, mediawiki, and xml.
    """
    grouped = defaultdict(set)
    for file_path in filenames:
        base, ext = os.path.splitext(file_path)
        lowered = ext.lower()
        if lowered in (".xml", ".mediawiki"):
            grouped[base].add(ext)
        elif lowered == ".tsv":
            # Individual .tsv files live at <root>/hedtsv/<schema>/<schema>_<suffix>.tsv;
            # strip the _<suffix> and walk two directories up to recover the schema name.
            trimmed = base.rpartition("_")[0]
            schema_dir, name_from_file = os.path.split(trimmed)
            hedtsv_parent, name_from_dir = os.path.split(schema_dir)
            root, hedtsv_folder = os.path.split(hedtsv_parent)
            if hedtsv_folder != "hedtsv":
                print(f"Ignoring file {file_path}. .tsv files must be in an 'hedtsv' subfolder.")
                continue
            if name_from_file != name_from_dir:
                print(f"Ignoring file {file_path}. .tsv files must be in a subfolder with the same name.")
                continue
            grouped[os.path.join(root, name_from_file)].add(ext)
        else:
            print(f"Ignoring file {file_path}")

    return grouped


def validate_all_schema_formats(basename):
    """ Validate all 3 versions of the given schema.

    Parameters:
        basename(str): a schema to check all 3 formats are identical of.

    Returns:
        issue_list(list): A non-empty list if there are any issues.
    """
    # Note if more than one is changed, it intentionally checks all 3 even if one wasn't changed.
    # todo: this needs to be updated to handle capital letters in the extension.
    try:
        loaded = [load_schema(add_extension(basename, extension)) for extension in all_extensions]
        first = loaded[0]
        if any(other != first for other in loaded[1:]):
            return [
                f"Multiple schemas of type {basename} were modified, and are not equal.\n"
                f"Only modify one source schema type at a time(mediawiki, xml, tsv), or modify all 3 at once."]
    except HedFileError as e:
        return [f"Error loading schema: {e.message}"]

    return []


def validate_all_schemas(schema_files):
    """Validates all the schema files/formats in the schema dict

    If multiple formats were edited, ensures all 3 formats exist and match.

    Parameters:
        schema_files(dict of sets): basename:[extensions] dictionary for all files changed

    Returns:
        issues(list of str): Any issues found validating or loading schemas.
    """
    all_issues = []
    for basename, extensions in schema_files.items():
        schema_issues = []
        for extension in extensions:
            schema_issues += validate_schema(add_extension(basename, extension))

        # When more than one format of a prerelease schema changed, also confirm the formats agree.
        if not schema_issues and len(extensions) > 1 and "prerelease" in basename:
            schema_issues += validate_all_schema_formats(basename)

        print(f"Validating: {basename}...")
        print(f"Extensions: {extensions}")
        for issue in schema_issues:
            print(issue)

        all_issues += schema_issues
    return all_issues
22 changes: 22 additions & 0 deletions hed/scripts/validate_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
from hed.schema import load_schema_version
from hed.scripts.script_util import validate_all_schemas, sort_base_schemas


def main(arg_list=None):
    """Validate every schema named in arg_list (or on the command line).

    Returns:
        int: 0 when all schemas validate cleanly, 1 otherwise.
    """
    # Trigger a local cache hit
    _ = load_schema_version("8.2.0")

    # Fall back to command-line arguments when no explicit list is supplied.
    files = arg_list if arg_list else sys.argv[1:]

    issues = validate_all_schemas(sort_base_schemas(files))
    return 1 if issues else 0


# Script entry point: process exit code mirrors main's return value.
if __name__ == "__main__":
    exit(main())
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ dependencies = [
run_remodel = "hed.tools.remodeling.cli.run_remodel:main"
run_remodel_backup = "hed.tools.remodeling.cli.run_remodel_backup:main"
run_remodel_restore = "hed.tools.remodeling.cli.run_remodel_restore:main"
hed_validate_schemas = "hed.scripts.validate_schemas:main"
hed_update_schemas = "hed.scripts.convert_and_update_schema:main"

[tool.versioneer]
VCS = "git"
Expand Down
2 changes: 1 addition & 1 deletion tests/schema/test_schema_attribute_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_deprecatedFrom(self):
self.assertTrue(schema_attribute_validators.tag_is_deprecated_check(self.hed_schema, tag_entry, attribute_name))
del tag_entry.attributes["deprecatedFrom"]

unit_class_entry = self.hed_schema.unit_classes["temperatureUnits"]
unit_class_entry = copy.deepcopy(self.hed_schema.unit_classes["temperatureUnits"])
# This should raise an issue because it assumes the attribute is set
self.assertTrue(schema_attribute_validators.tag_is_deprecated_check(self.hed_schema, unit_class_entry, attribute_name))
unit_class_entry.attributes["deprecatedFrom"] = "8.1.0"
Expand Down
Empty file added tests/scripts/__init__.py
Empty file.
90 changes: 90 additions & 0 deletions tests/scripts/test_convert_and_update_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import unittest
import os
import shutil
import copy
from hed import load_schema, load_schema_version
from hed.schema import HedSectionKey, HedKey
from hed.scripts.script_util import add_extension
from hed.scripts.convert_and_update_schema import convert_and_update


class TestConvertAndUpdate(unittest.TestCase):
    """End-to-end tests for convert_and_update, exercising all three schema formats on disk.

    Files are written under a 'prerelease' folder because convert_and_update
    only processes schemas located in a prerelease folder.
    """
    @classmethod
    def setUpClass(cls):
        # Create a temporary directory for schema files
        cls.base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'schemas_update', 'prerelease')
        if not os.path.exists(cls.base_path):
            os.makedirs(cls.base_path)

    def test_schema_conversion_and_update(self):
        """Editing one format should regenerate the other two so all reload as equal schemas."""
        # Load a known schema, modify it if necessary, and save it
        schema = load_schema_version("8.3.0")
        original_name = os.path.join(self.base_path, "test_schema.mediawiki")
        schema.save_as_mediawiki(original_name)

        # Assume filenames updated includes just the original schema file for simplicity
        filenames = [original_name]
        result = convert_and_update(filenames, set_ids=False)

        # Verify no error from convert_and_update and the correct schema version was saved
        self.assertEqual(result, 0)

        # The regenerated .tsv (a folder) and .xml should both round-trip to the source schema.
        tsv_filename = add_extension(os.path.join(self.base_path, "test_schema"), ".tsv")
        schema_reload1 = load_schema(tsv_filename)
        schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml"))

        self.assertEqual(schema, schema_reload1)
        self.assertEqual(schema, schema_reload2)

        # Now verify after doing this again with a new schema, they're still the same.
        # This time the edit originates from the .tsv side instead of the .mediawiki side.
        schema = load_schema_version("8.2.0")
        schema.save_as_dataframes(tsv_filename)

        filenames = [os.path.join(tsv_filename, "test_schema_Tag.tsv")]
        result = convert_and_update(filenames, set_ids=False)

        # Verify no error from convert_and_update and the correct schema version was saved
        self.assertEqual(result, 0)

        schema_reload1 = load_schema(os.path.join(self.base_path, "test_schema.mediawiki"))
        schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml"))

        self.assertEqual(schema, schema_reload1)
        self.assertEqual(schema, schema_reload2)

    def test_schema_adding_tag(self):
        """A newly added tag should propagate to the other formats; set_ids should assign a hedId."""
        # Write a baseline schema in all three formats.
        schema = load_schema_version("8.3.0")
        basename = os.path.join(self.base_path, "test_schema_edited")
        schema.save_as_mediawiki(add_extension(basename, ".mediawiki"))
        schema.save_as_xml(add_extension(basename, ".xml"))
        schema.save_as_dataframes(add_extension(basename, ".tsv"))

        # Add a tag (with no hedId) to a copy and overwrite only the mediawiki version.
        schema_edited = copy.deepcopy(schema)
        test_tag_name = "NewTagWithoutID"
        new_entry = schema_edited._create_tag_entry(test_tag_name, HedSectionKey.Tags)
        schema_edited._add_tag_to_dict(test_tag_name, new_entry, HedSectionKey.Tags)

        schema_edited.save_as_mediawiki(add_extension(basename, ".mediawiki"))

        # Assume filenames updated includes just the original schema file for simplicity
        filenames = [add_extension(basename, ".mediawiki")]
        result = convert_and_update(filenames, set_ids=False)
        self.assertEqual(result, 0)

        # With set_ids=False the new tag propagates but no hedId is assigned yet.
        schema_reloaded = load_schema(add_extension(basename, ".xml"))

        self.assertEqual(schema_reloaded, schema_edited)

        # Re-run with set_ids=True; the new tag should now receive a hedId.
        result = convert_and_update(filenames, set_ids=True)
        self.assertEqual(result, 0)

        schema_reloaded = load_schema(add_extension(basename, ".xml"))

        reloaded_entry = schema_reloaded.tags[test_tag_name]
        self.assertTrue(reloaded_entry.has_attribute(HedKey.HedID))


    @classmethod
    def tearDownClass(cls):
        # Clean up the directory created for testing
        shutil.rmtree(cls.base_path)
Loading