diff --git a/src/check_jsonschema/schema_loader/main.py b/src/check_jsonschema/schema_loader/main.py index 4ce95c9e5..88ff6a0f2 100644 --- a/src/check_jsonschema/schema_loader/main.py +++ b/src/check_jsonschema/schema_loader/main.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import pathlib import typing as t import urllib.error @@ -130,11 +131,21 @@ def get_validator( instance_doc: dict[str, t.Any], format_opts: FormatOptions, fill_defaults: bool, + ) -> jsonschema.protocols.Validator: + return self._get_validator(format_opts, fill_defaults) + + @functools.lru_cache + def _get_validator( + self, + format_opts: FormatOptions, + fill_defaults: bool, ) -> jsonschema.protocols.Validator: retrieval_uri = self.get_schema_retrieval_uri() schema = self.get_schema() schema_dialect = schema.get("$schema") + if schema_dialect is not None and not isinstance(schema_dialect, str): + schema_dialect = None # format checker (which may be None) format_checker = make_format_checker(format_opts, schema_dialect) diff --git a/src/check_jsonschema/schema_loader/resolver.py b/src/check_jsonschema/schema_loader/resolver.py index c63b7bb4d..5084328a5 100644 --- a/src/check_jsonschema/schema_loader/resolver.py +++ b/src/check_jsonschema/schema_loader/resolver.py @@ -79,8 +79,8 @@ def retrieve_reference(uri: str) -> referencing.Resource[Schema]: else: full_uri = uri - if full_uri in cache._cache: - return cache[uri] + if full_uri in cache: + return cache[full_uri] full_uri_scheme = urllib.parse.urlsplit(full_uri).scheme if full_uri_scheme in ("http", "https"): @@ -100,8 +100,8 @@ def validation_callback(content: bytes) -> None: else: parsed_object = get_local_file(full_uri) - cache[uri] = parsed_object - return cache[uri] + cache[full_uri] = parsed_object + return cache[full_uri] return retrieve_reference diff --git a/tests/acceptance/test_remote_ref_resolution.py b/tests/acceptance/test_remote_ref_resolution.py index d95fba555..3dafc4c8a 100644 --- a/tests/acceptance/test_remote_ref_resolution.py +++ b/tests/acceptance/test_remote_ref_resolution.py @@ -244,3 +244,60 @@ def test_ref_resolution_with_custom_base_uri(run_line, tmp_path, check_passes): assert result.exit_code == 0, output else: assert result.exit_code == 1, output + + +@pytest.mark.parametrize("num_instances", (1, 2, 10)) +@pytest.mark.parametrize("check_passes", (True, False)) +def test_remote_ref_resolution_callout_count_is_scale_free_in_instancefiles( + run_line, tmp_path, num_instances, check_passes +): + """ + Test that for any N > 1, validation of a schema with a ref against N instance files + has exactly the same number of callouts as validation when N=1 + + This proves that the validator and caching are working correctly, and we aren't + repeating callouts to rebuild state. + """ + schema_uri = "https://example.org/schemas/main.json" + ref_uri = "https://example.org/schemas/title_schema.json" + + main_schema = { + "$id": schema_uri, + "$schema": "http://json-schema.org/draft-07/schema", + "properties": { + "title": {"$ref": "./title_schema.json"}, + }, + "additionalProperties": False, + } + title_schema = {"type": "string"} + responses.add("GET", schema_uri, json=main_schema) + responses.add("GET", ref_uri, json=title_schema) + + # write N documents + instance_doc = {"title": "doc one" if check_passes else 2} + instance_paths = [] + for i in range(num_instances): + instance_path = tmp_path / f"instance{i}.json" + instance_path.write_text(json.dumps(instance_doc)) + instance_paths.append(str(instance_path)) + + result = run_line( + [ + "check-jsonschema", + "--schemafile", + schema_uri, + ] + + instance_paths + ) + output = f"\nstdout:\n{result.stdout}\n\nstderr:\n{result.stderr}" + if check_passes: + assert result.exit_code == 0, output + else: + assert result.exit_code == 1, output + + # this is the moment of the "real" test run here: + # no matter how many instances there were, there should only have been two calls + # (one for the schema and one for the $ref) + assert len(responses.calls) == 2 + assert len([c for c in responses.calls if c.request.url == schema_uri]) == 1 + assert len([c for c in responses.calls if c.request.url == ref_uri]) == 1