Skip to content

Commit

Permalink
Detach schema export as a function in a new module
Browse files Browse the repository at this point in the history
  new file:   xmlschema/exports.py

  - Fix location matching (getting locations from schema source)
  - Add a flag to define if an exported schema has been processed
  - Add option remove_residuals=True for remove location hints
    from unused import statements
  • Loading branch information
brunato committed Sep 21, 2023
1 parent 866ed43 commit eac7eb0
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 126 deletions.
13 changes: 13 additions & 0 deletions tests/test_cases/issues/issue_362/dir1/dir2/issue_362_2.xsd
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<xs:schema
xmlns:xs="http://www.w3.org/2001/XMLSchema"
targetNamespace="http://xmlschema.test/tns2"
elementFormDefault="qualified">

<xs:include schemaLocation="../../dir2/issue_362_2.xsd"/>
<xs:import namespace="http://xmlschema.test/tns1" schemaLocation="http://xmlschema.test/tns1"/>
<xs:import namespace="http://xmlschema.test/tns1" schemaLocation="../issue_362_1.xsd"/>

<xs:element name="item2" />

</xs:schema>

11 changes: 11 additions & 0 deletions tests/test_cases/issues/issue_362/dir1/issue_362_1.xsd
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<xs:schema
xmlns:xs="http://www.w3.org/2001/XMLSchema"
targetNamespace="http://xmlschema.test/tns1"
elementFormDefault="qualified">

<xs:include schemaLocation="../issue_362_1.xsd"/>
<xs:import namespace="http://xmlschema.test/tns2" schemaLocation="http://xmlschema.test/tns2"/>

<xs:element name="item1" />

</xs:schema>
12 changes: 12 additions & 0 deletions tests/test_cases/issues/issue_362/dir2/issue_362_2.xsd
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<xs:schema
xmlns:xs="http://www.w3.org/2001/XMLSchema"
targetNamespace="http://xmlschema.test/tns2"
elementFormDefault="qualified">

<xs:include schemaLocation="../dir1/dir2/issue_362_2.xsd"/>
<xs:import namespace="http://xmlschema.test/tns1" schemaLocation="http://xmlschema.test/tns1"/>

<xs:element name="item3" />

</xs:schema>

25 changes: 25 additions & 0 deletions tests/test_cases/issues/issue_362/issue_362_1.xsd
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<!--
A test for export schemas with crossed imports/includes and additional failing remote imports.
-->
<xs:schema
xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:tns1="http://xmlschema.test/tns1"
xmlns:tns2="http://xmlschema.test/tns2"
targetNamespace="http://xmlschema.test/tns1">

<xs:include schemaLocation="./dir1/../dir1/issue_362_1.xsd"/>
<xs:import namespace="http://xmlschema.test/tns2" schemaLocation="http://xmlschema.test/tns2"/>
<xs:import namespace="http://xmlschema.test/tns2" schemaLocation="dir1/dir2/issue_362_2.xsd"/>

<xs:element name="root">
<xs:complexType>
<xs:sequence>
<xs:element ref="tns1:item1" />
<xs:element ref="tns2:item2" />
<xs:element ref="tns2:item3" />
</xs:sequence>
</xs:complexType>
</xs:element>

</xs:schema>

32 changes: 32 additions & 0 deletions tests/validators/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,38 @@ def test_export_other_encoding(self):
self.assertFalse(filecmp.cmp(schema_ascii_file, exported_schema))
self.assertTrue(filecmp.cmp(schema_cp1252_file, exported_schema))

def test_export_more_remote_imports__issue_362(self):
schema_file = self.casepath('issues/issue_362/issue_362_1.xsd')
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
schema = self.schema_class(schema_file)

self.assertIn('{http://xmlschema.test/tns1}root', schema.maps.elements)
self.assertIn('{http://xmlschema.test/tns1}item1', schema.maps.elements)
self.assertIn('{http://xmlschema.test/tns2}item2', schema.maps.elements)
self.assertIn('{http://xmlschema.test/tns2}item3', schema.maps.elements)

with tempfile.TemporaryDirectory() as dirname:
schema.export(target=dirname)

exported_files = set(
str(x.relative_to(dirname)).replace('\\', '/')
for x in pathlib.Path(dirname).glob('**/*.xsd')
)
self.assertSetEqual(
exported_files,
{'issue_362_1.xsd', 'dir2/issue_362_2.xsd', 'dir1/issue_362_1.xsd',
'dir1/dir2/issue_362_2.xsd', 'issue_362_1.xsd', 'dir2/issue_362_2.xsd',
'dir1/issue_362_1.xsd', 'dir1/dir2/issue_362_2.xsd'}
)

schema_file = pathlib.Path(dirname).joinpath('issue_362_1.xsd')
schema = self.schema_class(schema_file)
self.assertIn('{http://xmlschema.test/tns1}root', schema.maps.elements)
self.assertIn('{http://xmlschema.test/tns1}item1', schema.maps.elements)
self.assertIn('{http://xmlschema.test/tns2}item2', schema.maps.elements)
self.assertIn('{http://xmlschema.test/tns2}item3', schema.maps.elements)

def test_pickling_subclassed_schema__issue_263(self):
cases_dir = pathlib.Path(__file__).parent.parent
schema_file = cases_dir.joinpath('test_cases/examples/vehicles/vehicles.xsd')
Expand Down
184 changes: 184 additions & 0 deletions xmlschema/exports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#
# Copyright (c), 2016-2023, SISSA (International School for Advanced Studies).
# All rights reserved.
# This file is distributed under the terms of the MIT License.
# See the file 'LICENSE' in the root directory of the present
# distribution, or http://opensource.org/licenses/MIT.
#
# @author Davide Brunato <brunato@sissa.it>
#
import re
import pathlib
from itertools import chain
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote, urlsplit

from .exceptions import XMLSchemaValueError
from .resources import _PurePath, is_remote_url
from .translation import gettext as _

if TYPE_CHECKING:
from .validators import XMLSchemaBase


def replace_location(text: str, location: str, repl_location: str) -> str:
repl = 'schemaLocation="{}"'.format(repl_location)
pattern = r'\bschemaLocation\s*=\s*[\'\"].*%s.*[\'"]' % re.escape(location)
return re.sub(pattern, repl, text)


def export_schema(obj: 'XMLSchemaBase', target_dir: str,
save_remote: bool = False, remove_residuals: bool = True) -> None:

target_path = pathlib.Path(target_dir)
if target_path.is_dir():
if list(target_path.iterdir()):
msg = _("target directory {} is not empty")
raise XMLSchemaValueError(msg.format(target_dir))
elif target_path.exists():
msg = _("target {} is not a directory")
raise XMLSchemaValueError(msg.format(target_path.parent))
elif not target_path.parent.exists():
msg = _("target parent directory {} does not exist")
raise XMLSchemaValueError(msg.format(target_path.parent))
elif not target_path.parent.is_dir():
msg = _("target parent {} is not a directory")
raise XMLSchemaValueError(msg.format(target_path.parent))

name = obj.name or 'schema.xsd'
exports: Any = {obj: [_PurePath(unquote(name)), obj.get_text(), False]}
path: Any

while True:
current_length = len(exports)

for schema in list(exports):
if exports[schema][2]:
continue # Skip already processed schemas
exports[schema][2] = True

dir_path = exports[schema][0].parent
imports_items = [(x.url, x) for x in schema.imports.values()
if x is not None]

pattern = r'\bschemaLocation\s*=\s*[\'\"](.*)[\'"]'
schema_locations = set(
x.strip() for x in re.findall(pattern, exports[schema][1])
)

for location, ref_schema in chain(schema.includes.items(), imports_items):

# Find matching schema location
if location in schema_locations:
schema_locations.remove(location)
else:
name = ref_schema.name
assert isinstance(name, str)

matching_items = [x for x in schema_locations if x.endswith(name)]
if len(matching_items) == 1:
location = matching_items[0]
schema_locations.remove(location)
elif not matching_items:
continue
else:
for item in matching_items:
item_path = _PurePath.from_uri(item)
if location.endswith(str(item_path).lstrip('.')):
location = item
schema_locations.remove(location)
break
else:
location = matching_items[0]
schema_locations.remove(location)

if is_remote_url(location):
if not save_remote:
continue

parts = urlsplit(unquote(location))
path = _PurePath(parts.scheme). \
joinpath(parts.netloc). \
joinpath(parts.path.lstrip('/'))
else:
if location.startswith('file:/'):
path = _PurePath(unquote(urlsplit(location).path))
else:
path = _PurePath(unquote(location))

if not path.is_absolute():
path = dir_path.joinpath(path).normalize()
if not str(path).startswith('..'):
# A relative path that doesn't exceed the loading schema dir
if ref_schema not in exports:
exports[ref_schema] = [path, ref_schema.get_text(), False]
continue

# Use the absolute schema path
schema_path = ref_schema.filepath
assert schema_path is not None
path = _PurePath(schema_path)

if path.drive:
drive = path.drive.split(':')[0]
path = _PurePath(drive).joinpath('/'.join(path.parts[1:]))

path = _PurePath('file').joinpath(path.as_posix().lstrip('/'))

parts = path.parent.parts
dir_parts = dir_path.parts

k = 0
for item1, item2 in zip(parts, dir_parts):
if item1 != item2:
break
k += 1

if not k:
prefix = '/'.join(['..'] * len(dir_parts))
repl_path = _PurePath(prefix).joinpath(path)
else:
repl_path = _PurePath('/'.join(parts[k:])).joinpath(path.name)
if k < len(dir_parts):
prefix = '/'.join(['..'] * (len(dir_parts) - k))
repl_path = _PurePath(prefix).joinpath(repl_path)

repl = repl_path.as_posix()
exports[schema][1] = replace_location(exports[schema][1], location, repl)
if ref_schema not in exports:
exports[ref_schema] = [path, ref_schema.get_text(), False]

if remove_residuals:
# Deactivate residual redundant imports
for location in filter(lambda x: x not in schema.includes, schema_locations):
exports[schema][1] = replace_location(exports[schema][1], location, '')

if current_length == len(exports):
break

for schema, (path, text, processed) in exports.items():
assert processed

filepath = target_path.joinpath(path)

# Safety check: raise error if filepath is not inside the target path
try:
filepath.resolve(strict=False).relative_to(target_path.resolve(strict=False))
except ValueError:
msg = _("target directory {} violation for exported path {}, {}")
raise XMLSchemaValueError(msg.format(target_dir, str(path), str(filepath)))

if not filepath.parent.exists():
filepath.parent.mkdir(parents=True)

encoding = 'utf-8' # default encoding for XML 1.0

if text.startswith('<?'):
# Get the encoding from XML declaration
xml_declaration = text.split('\n', maxsplit=1)[0]
re_match = re.search('(?<=encoding=["\'])[^"\']+', xml_declaration)
if re_match is not None:
encoding = re_match.group(0).lower()

with filepath.open(mode='w', encoding=encoding) as fp:
fp.write(text)
Loading

0 comments on commit eac7eb0

Please sign in to comment.