diff --git a/test/test_sparqlstore.py b/test/test_sparqlstore.py index 8d720de37..8b11a4c40 100644 --- a/test/test_sparqlstore.py +++ b/test/test_sparqlstore.py @@ -3,9 +3,18 @@ from urllib.error import HTTPError import unittest from nose import SkipTest -from http.server import BaseHTTPRequestHandler, HTTPServer +from http.server import BaseHTTPRequestHandler, HTTPServer, SimpleHTTPRequestHandler import socket from threading import Thread +from contextlib import contextmanager +from unittest.mock import MagicMock, Mock, patch +import typing as t +import random +import collections +from urllib.parse import ParseResult, urlparse, parse_qs +from rdflib.namespace import RDF, XSD, XMLNS, FOAF, RDFS +from rdflib.plugins.stores.sparqlstore import SPARQLConnector +import email.message from . import helper @@ -33,9 +42,26 @@ def tearDown(self): def test_Query(self): query = "select distinct ?Concept where {[] a ?Concept} LIMIT 1" - res = helper.query_with_retry(self.graph, query, initNs={}) - for i in res: - assert type(i[0]) == URIRef, i[0].n3() + _query = SPARQLConnector.query + with patch("rdflib.plugins.stores.sparqlstore.SPARQLConnector.query") as mock: + SPARQLConnector.query.side_effect = lambda *args, **kwargs: _query( + self.graph.store, *args, **kwargs + ) + res = helper.query_with_retry(self.graph, query, initNs={}) + count = 0 + for i in res: + count += 1 + assert type(i[0]) == URIRef, i[0].n3() + assert count > 0 + mock.assert_called_once() + args, kwargs = mock.call_args + + def unpacker(query, default_graph=None, named_graph=None): + return query, default_graph, named_graph + + (mquery, _, _) = unpacker(*args, *kwargs) + for _, uri in self.graph.namespaces(): + assert mquery.count(f"<{uri}>") == 1 def test_initNs(self): query = """\ @@ -196,5 +222,134 @@ def do_GET(self): return +def get_random_ip(parts: t.List[str] = None) -> str: + if parts is None: + parts = ["127"] + for index in range(4 - len(parts)): + parts.append(f"{random.randint(0, 255)}") + return ".".join(parts) + + +@contextmanager +def ctx_http_server(handler: t.Type[BaseHTTPRequestHandler]) -> t.Iterator[HTTPServer]: + host = get_random_ip() + server = HTTPServer((host, 0), handler) + server_thread = Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + yield server + server.shutdown() + server.socket.close() + server_thread.join() + + +GenericT = t.TypeVar("GenericT", bound=t.Any) + + +def make_spypair(method: GenericT) -> t.Tuple[GenericT, Mock]: + m = MagicMock() + + def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: + m(*args, **kwargs) + return method(self, *args, **kwargs) + + setattr(wrapper, "mock", m) + return t.cast(GenericT, wrapper), m + + +HeadersT = t.Dict[str, t.List[str]] +PathQueryT = t.Dict[str, t.List[str]] + + +class MockHTTPRequests(t.NamedTuple): + path: str + parsed_path: ParseResult + path_query: PathQueryT + headers: email.message.Message + + +class MockHTTPResponse(t.NamedTuple): + status_code: int + reason_phrase: str + body: bytes + headers: HeadersT = collections.defaultdict(list) + + +class SPARQLMockTests(unittest.TestCase): + requests: t.List[MockHTTPRequests] = [] + responses: t.List[MockHTTPResponse] = [] + + def setUp(self): + _tc = self + + class Handler(SimpleHTTPRequestHandler): + tc = _tc + + def _do_GET(self): + parsed_path = urlparse(self.path) + path_query = parse_qs(parsed_path.query) + request = MockHTTPRequests( + self.path, parsed_path, path_query, self.headers + ) + self.tc.requests.append(request) + + response = self.tc.responses.pop(0) + self.send_response(response.status_code, response.reason_phrase) + for header, values in response.headers.items(): + for value in values: + self.send_header(header, value) + self.end_headers() + + self.wfile.write(response.body) + self.wfile.flush() + return + + (do_GET, do_GET_mock) = make_spypair(_do_GET) + self.Handler = Handler + self.requests.clear() + self.responses.clear() + + def test_query(self): + triples = { + (RDFS.Resource, RDF.type, RDFS.Class), + (RDFS.Resource, RDFS.isDefinedBy, URIRef(RDFS)), + (RDFS.Resource, RDFS.label, Literal("Resource")), + (RDFS.Resource, RDFS.comment, Literal("The class resource, everything.")), + } + rows = "\n".join([f'"{s}","{p}","{o}"' for s, p, o in triples]) + response_body = f"s,p,o\n{rows}".encode() + response = MockHTTPResponse(200, "OK", response_body) + response.headers["Content-Type"].append("text/csv; charset=utf-8") + self.responses.append(response) + + graph = Graph(store="SPARQLStore", identifier="http://example.com") + graph.bind("xsd", XSD) + graph.bind("xml", XMLNS) + graph.bind("foaf", FOAF) + graph.bind("rdf", RDF) + + assert len(list(graph.namespaces())) >= 4 + + with ctx_http_server(self.Handler) as server: + (host, port) = server.server_address + url = f"http://{host}:{port}/query" + graph.open(url) + query_result = graph.query("SELECT ?s ?p ?o WHERE { ?s ?p ?o }") + + rows = set(query_result) + assert len(rows) == len(triples) + for triple in triples: + assert triple in rows + + self.Handler.do_GET_mock.assert_called_once() + assert len(self.requests) == 1 + request = self.requests.pop() + assert len(request.path_query["query"]) == 1 + query = request.path_query["query"][0] + + for _, uri in graph.namespaces(): + assert query.count(f"<{uri}>") == 1 + + if __name__ == "__main__": unittest.main()