-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #929 from IanCa/develop
Move schema scripts over to hed-python
- Loading branch information
Showing
9 changed files
with
474 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from hed.schema import load_schema_version | ||
from hed.scripts.script_util import sort_base_schemas, validate_all_schemas, add_extension | ||
from hed.schema.schema_io.df2schema import load_dataframes | ||
from hed.schema.schema_io.ontology_util import update_dataframes_from_schema, save_dataframes | ||
from hed.schema.hed_schema_io import load_schema, from_dataframes | ||
import argparse | ||
|
||
|
||
def convert_and_update(filenames, set_ids):
    """ Validate, convert, and update as needed all schemas listed in filenames.

    If any schema fails to validate, no schemas will be updated.

    Parameters:
        filenames(list of str): A list of filenames that have been updated
        set_ids(bool): If True, assign missing hedIds

    Returns:
        int: 1 if validation failed (nothing was updated), 0 otherwise.
    """
    # Find and group the changed files by schema basename -> set of extensions
    schema_files = sort_base_schemas(filenames)
    all_issues = validate_all_schemas(schema_files)

    # NOTE(review): this message also fires when schema_files is empty
    # (nothing recognized to update), not only on validation failures.
    if all_issues or not schema_files:
        print("Did not attempt to update schemas due to validation failures")
        return 1

    updated = []
    # If we are here, we have validated the schemas(and if there's more than one version changed, that they're the same)
    for basename, extensions in schema_files.items():
        # Skip any with multiple extensions or not in pre-release
        if "prerelease" not in basename:
            print(f"Skipping updates on {basename}, not in a prerelease folder.")
            continue
        source_filename = add_extension(basename,
                                        list(extensions)[0])  # Load any changed schema version, they're all the same

        # Preserve the original case of the .tsv extension if one was changed.
        # todo: more properly decide how we want to handle non lowercase extensions.
        tsv_extension = ".tsv"
        for extension in extensions:
            if extension.lower() == ".tsv":
                tsv_extension = extension

        source_df_filename = add_extension(basename, tsv_extension)
        schema = load_schema(source_filename)
        print(f"Trying to convert/update file {source_filename}")
        source_dataframes = load_dataframes(source_df_filename)
        # todo: We need a more robust system for if some files are missing
        #       (especially for library schemas which will probably lack some)
        if any(value is None for value in source_dataframes.values()):
            # Some tsv files were missing on disk: regenerate the full set
            # of dataframes from the schema we just loaded instead.
            source_dataframes = schema.get_as_dataframes()

        result = update_dataframes_from_schema(source_dataframes, schema, assign_missing_ids=set_ids)

        # Round-trip through the updated dataframes so every saved format
        # (mediawiki, xml, tsv) reflects the same post-update content.
        schema_reloaded = from_dataframes(result)
        schema_reloaded.save_as_mediawiki(basename + ".mediawiki")
        schema_reloaded.save_as_xml(basename + ".xml")

        save_dataframes(source_df_filename, result)
        updated.append(basename)

    for basename in updated:
        print(f"Schema {basename} updated.")

    if not updated:
        print("Did not update any schemas")

    return 0
|
||
|
||
def main():
    """Parse the command line and update the other schema formats to match the changed one."""
    arg_parser = argparse.ArgumentParser(description='Update other schema formats based on the changed one.')
    arg_parser.add_argument('filenames', nargs='*', help='List of files to process')
    arg_parser.add_argument('--set-ids', action='store_true', help='Set IDs for each file')
    parsed = arg_parser.parse_args()

    # Trigger a local cache hit (this ensures trying to load withStandard schemas will work properly)
    _ = load_schema_version("8.2.0")

    return convert_and_update(parsed.filenames, parsed.set_ids)


if __name__ == "__main__":
    exit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import os.path | ||
from collections import defaultdict | ||
from hed.schema import from_string, load_schema | ||
from hed.errors import get_printable_issue_string, HedFileError, SchemaWarnings | ||
|
||
all_extensions = [".tsv", ".mediawiki", ".xml"] | ||
|
||
|
||
def validate_schema(file_path):
    """ Validate one schema file: compliance check plus mediawiki/xml round-trips.

        This is probably overkill...

    Parameters:
        file_path(str): Path to the schema file to check.

    Returns:
        list of str: Printable issue descriptions; empty when the schema is clean.
    """
    found_problems = []
    try:
        schema = load_schema(file_path)
        # Prerelease-version warnings are expected here and filtered out.
        compliance_issues = [
            issue for issue in schema.check_compliance()
            if issue["code"] != SchemaWarnings.SCHEMA_PRERELEASE_VERSION_USED
        ]
        if compliance_issues:
            found_problems.append(get_printable_issue_string(compliance_issues, title=file_path))

        # Serialize to each text format and confirm reloading reproduces the schema.
        round_trips = (
            ("mediawiki", schema.get_as_mediawiki_string),
            ("xml", schema.get_as_xml_string),
        )
        for format_name, serialize in round_trips:
            reloaded = from_string(serialize(), schema_format="." + format_name)
            if reloaded != schema:
                found_problems.append(
                    f"Failed to reload {file_path} as {format_name}. "
                    f"There is either a problem with the source file, or the saving/loading code.")
    except HedFileError as e:
        print(f"Saving/loading error: {file_path} {e.message}")
        message = e.message
        if e.issues:
            message = get_printable_issue_string(e.issues, title=file_path)
        found_problems.append(message)

    return found_problems
|
||
|
||
def add_extension(basename, extension):
    """Return the on-disk name for *basename* in the given format.

    Most formats simply append the extension; ".tsv" schemas instead live in
    an "hedtsv" subfolder named after the schema, with no extension appended.
    """
    if extension.lower() != ".tsv":
        return basename + extension
    folder, schema_name = os.path.split(basename)
    return os.path.join(folder, "hedtsv", schema_name)
|
||
|
||
def sort_base_schemas(filenames):
    """ Group the changed filenames by schema basename.

    Example input: ["test_schema.mediawiki", "hedtsv/test_schema/test_schema_Tag.tsv", "other_schema.xml"]
    Example output:
    {
        "test_schema": {".mediawiki", ".tsv"},
        "other_schema": {".xml"}
    }

    Parameters:
        filenames(list or container): The changed filenames

    Returns:
        sorted_files(dict): A dictionary where keys are the basename, and the values are a set of extensions modified
                            Can include tsv, mediawiki, and xml.
    """
    schema_files = defaultdict(set)
    for file_path in filenames:
        basename, extension = os.path.splitext(file_path)
        ext_lower = extension.lower()
        if ext_lower in (".xml", ".mediawiki"):
            schema_files[basename].add(extension)
        elif ext_lower == ".tsv":
            # tsv schemas are split files named <schema>_<section>.tsv living
            # in <parent>/hedtsv/<schema>/ -- walk the path upward to verify.
            section_free = basename.rpartition("_")[0]
            schema_folder, schema_name = os.path.split(section_free)
            hedtsv_parent, folder_name = os.path.split(schema_folder)
            parent_path, hedtsv_name = os.path.split(hedtsv_parent)
            if hedtsv_name != "hedtsv":
                print(f"Ignoring file {file_path}. .tsv files must be in an 'hedtsv' subfolder.")
                continue
            if schema_name != folder_name:
                print(f"Ignoring file {file_path}. .tsv files must be in a subfolder with the same name.")
                continue
            schema_files[os.path.join(parent_path, schema_name)].add(extension)
        else:
            print(f"Ignoring file {file_path}")

    return schema_files
|
||
|
||
def validate_all_schema_formats(basename):
    """ Validate all 3 versions of the given schema.

    Parameters:
        basename(str): a schema to check all 3 formats are identical of.

    Returns:
        issue_list(list): A non-empty list if there are any issues.
    """
    # Note if more than one is changed, it intentionally checks all 3 even if one wasn't changed.
    # todo: this needs to be updated to handle capital letters in the extension.
    try:
        loaded = [load_schema(add_extension(basename, ext)) for ext in all_extensions]
        first = loaded[0]
        if any(schema != first for schema in loaded[1:]):
            return [
                f"Multiple schemas of type {basename} were modified, and are not equal.\n"
                f"Only modify one source schema type at a time(mediawiki, xml, tsv), or modify all 3 at once."]
    except HedFileError as e:
        return [f"Error loading schema: {e.message}"]

    return []
|
||
|
||
def validate_all_schemas(schema_files):
    """Validates all the schema files/formats in the schema dict.

    If multiple formats were edited, ensures all 3 formats exist and match.

    Parameters:
        schema_files(dict of sets): basename:[extensions] dictionary for all files changed

    Returns:
        issues(list of str): Any issues found validating or loading schemas.
    """
    all_issues = []
    for basename, extensions in schema_files.items():
        # Announce the schema BEFORE validating it.  Previously this header was
        # printed after validation ran, so errors printed inside
        # validate_schema (e.g. "Saving/loading error: ...") appeared in the
        # log ahead of the header that attributes them.
        print(f"Validating: {basename}...")
        print(f"Extensions: {extensions}")

        single_schema_issues = []
        for extension in extensions:
            full_path = add_extension(basename, extension)
            single_schema_issues += validate_schema(full_path)

        # Cross-check the three formats only for prerelease schemas where more
        # than one format was touched and each validated individually.
        if len(extensions) > 1 and not single_schema_issues and "prerelease" in basename:
            single_schema_issues += validate_all_schema_formats(basename)

        for issue in single_schema_issues:
            print(issue)

        all_issues += single_schema_issues
    return all_issues
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import sys | ||
from hed.schema import load_schema_version | ||
from hed.scripts.script_util import validate_all_schemas, sort_base_schemas | ||
|
||
|
||
def main(arg_list=None):
    """Validate every schema file named in arg_list (or on the command line).

    Parameters:
        arg_list(list of str or None): Filenames to validate; when None,
            sys.argv[1:] is used.

    Returns:
        int: 0 if all schemas validated cleanly, 1 otherwise.
    """
    # Trigger a local cache hit
    _ = load_schema_version("8.2.0")

    # Use an explicit None sentinel: the previous truthiness test (`if not
    # arg_list`) made an explicitly-passed empty list silently fall back to
    # sys.argv, which is surprising for programmatic callers.
    if arg_list is None:
        arg_list = sys.argv[1:]

    schema_files = sort_base_schemas(arg_list)
    issues = validate_all_schemas(schema_files)

    if issues:
        return 1
    return 0


if __name__ == "__main__":
    exit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import unittest | ||
import os | ||
import shutil | ||
import copy | ||
from hed import load_schema, load_schema_version | ||
from hed.schema import HedSectionKey, HedKey | ||
from hed.scripts.script_util import add_extension | ||
from hed.scripts.convert_and_update_schema import convert_and_update | ||
|
||
|
||
class TestConvertAndUpdate(unittest.TestCase):
    """Tests for convert_and_update: regenerating schema formats from a changed one."""

    @classmethod
    def setUpClass(cls):
        # Create a temporary directory for schema files.
        # The path deliberately contains "prerelease" since convert_and_update
        # only processes schemas inside a prerelease folder.
        cls.base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'schemas_update', 'prerelease')
        if not os.path.exists(cls.base_path):
            os.makedirs(cls.base_path)

    def test_schema_conversion_and_update(self):
        # A changed .mediawiki file should regenerate matching tsv/xml copies,
        # and a changed tsv should regenerate matching mediawiki/xml copies.
        # Load a known schema, modify it if necessary, and save it
        schema = load_schema_version("8.3.0")
        original_name = os.path.join(self.base_path, "test_schema.mediawiki")
        schema.save_as_mediawiki(original_name)

        # Assume filenames updated includes just the original schema file for simplicity
        filenames = [original_name]
        result = convert_and_update(filenames, set_ids=False)

        # Verify no error from convert_and_update and the correct schema version was saved
        self.assertEqual(result, 0)

        tsv_filename = add_extension(os.path.join(self.base_path, "test_schema"), ".tsv")
        schema_reload1 = load_schema(tsv_filename)
        schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml"))

        self.assertEqual(schema, schema_reload1)
        self.assertEqual(schema, schema_reload2)

        # Now verify after doing this again with a new schema, they're still the same.
        schema = load_schema_version("8.2.0")
        schema.save_as_dataframes(tsv_filename)

        filenames = [os.path.join(tsv_filename, "test_schema_Tag.tsv")]
        result = convert_and_update(filenames, set_ids=False)

        # Verify no error from convert_and_update and the correct schema version was saved
        self.assertEqual(result, 0)

        schema_reload1 = load_schema(os.path.join(self.base_path, "test_schema.mediawiki"))
        schema_reload2 = load_schema(os.path.join(self.base_path, "test_schema.xml"))

        self.assertEqual(schema, schema_reload1)
        self.assertEqual(schema, schema_reload2)

    def test_schema_adding_tag(self):
        # Adding a tag propagates to the other formats; with set_ids=True the
        # new tag is assigned a hedId.
        schema = load_schema_version("8.3.0")
        basename = os.path.join(self.base_path, "test_schema_edited")
        schema.save_as_mediawiki(add_extension(basename, ".mediawiki"))
        schema.save_as_xml(add_extension(basename, ".xml"))
        schema.save_as_dataframes(add_extension(basename, ".tsv"))

        schema_edited = copy.deepcopy(schema)
        test_tag_name = "NewTagWithoutID"
        # NOTE(review): injects a tag via private schema APIs — presumably no
        # public mutation API exists; confirm against the hed schema package.
        new_entry = schema_edited._create_tag_entry(test_tag_name, HedSectionKey.Tags)
        schema_edited._add_tag_to_dict(test_tag_name, new_entry, HedSectionKey.Tags)

        schema_edited.save_as_mediawiki(add_extension(basename, ".mediawiki"))

        # Assume filenames updated includes just the original schema file for simplicity
        filenames = [add_extension(basename, ".mediawiki")]
        result = convert_and_update(filenames, set_ids=False)
        self.assertEqual(result, 0)

        schema_reloaded = load_schema(add_extension(basename, ".xml"))

        self.assertEqual(schema_reloaded, schema_edited)

        # Run again with set_ids=True; the new tag should now receive a hedId.
        result = convert_and_update(filenames, set_ids=True)
        self.assertEqual(result, 0)

        schema_reloaded = load_schema(add_extension(basename, ".xml"))

        reloaded_entry = schema_reloaded.tags[test_tag_name]
        self.assertTrue(reloaded_entry.has_attribute(HedKey.HedID))

    @classmethod
    def tearDownClass(cls):
        # Clean up the directory created for testing
        shutil.rmtree(cls.base_path)
Oops, something went wrong.