diff --git a/rdflib/plugins/sparql/operators.py b/rdflib/plugins/sparql/operators.py index 0d12a6a26..2e9fb66a0 100644 --- a/rdflib/plugins/sparql/operators.py +++ b/rdflib/plugins/sparql/operators.py @@ -582,14 +582,65 @@ def Builtin_EXISTS(e, ctx): return Literal(not exists) -def Function(e, ctx): +_CUSTOM_FUNCTIONS = {} + +def register_custom_function(uri, func, override=False, raw=False): """ - Custom functions (and casts!) + Register a custom SPARQL function. + + By default, the function will be passed the RDF terms in the argument list. + If raw is True, the function will be passed an Expression and a Context. + + The function must return an RDF term, or raise a SparqlError. """ + if not override and uri in _CUSTOM_FUNCTIONS: + raise ValueError("A function is already registered as %s" % uri.n3()) + _CUSTOM_FUNCTIONS[uri] = (func, raw) - if e.iri in XSD_DTs: - # a cast +def custom_function(uri, override=False, raw=False): + """ + Decorator version of :func:`register_custom_function`. + """ + def decorator(func): + register_custom_function(uri, func, override=override, raw=raw) + return func + return decorator + +def unregister_custom_function(uri, func): + if _CUSTOM_FUNCTIONS.get(uri, (None, None))[0] != func: + raise ValueError("This function is not registered as %s" % uri.n3()) + del _CUSTOM_FUNCTIONS[uri] + +def Function(e, ctx): + """ + Custom functions and casts + """ + pair =_CUSTOM_FUNCTIONS.get(e.iri) + if pair is None: + # no such function is registered + raise SPARQLError('Unknown function %r' % e.iri) + func, raw = pair + if raw: + # function expects expression and context + return func(e, ctx) + else: + # function expects the argument list + try: + return func(*e.expr) + except TypeError as ex: + # wrong argument number + raise SPARQLError(*ex.args) + + +@custom_function(XSD.string, raw=True) +@custom_function(XSD.dateTime, raw=True) +@custom_function(XSD.float, raw=True) +@custom_function(XSD.double, raw=True) +@custom_function(XSD.decimal, raw=True) +@custom_function(XSD.integer, raw=True) +@custom_function(XSD.boolean, raw=True) +def default_cast(e, ctx): if not e.expr: raise SPARQLError("Nothing given to cast.") if len(e.expr) > 1: @@ -654,15 +705,7 @@ def Function(e, ctx): return Literal(True) if x.lower() in ("0", "false"): return Literal(False) - raise SPARQLError("Cannot interpret '%r' as bool" % x) - else: - raise Exception("I do not know how to cast to %r" % e.iri) - - else: - raise SPARQLError('Unknown function %r"%e.iri') - - # TODO: Custom functions! def UnaryNot(expr, ctx): @@ -768,7 +811,7 @@ def RelationalExpression(e, ctx): try: if x == expr: return Literal(True ^ res) - except SPARQLError, e: + except SPARQLError as e: error = e if not error: return Literal(False ^ res) @@ -802,7 +845,7 @@ def RelationalExpression(e, ctx): r = ops[op](expr, other) if r == NotImplemented: raise SPARQLError('Error when comparing') - except TypeError, te: + except TypeError as te: raise SPARQLError(*te.args) return Literal(r) @@ -841,7 +884,7 @@ def ConditionalOrExpression(e, ctx): try: if EBV(x): return Literal(True) - except SPARQLError, e: + except SPARQLError as e: error = e if error: raise error diff --git a/test/test_issue274.py b/test/test_issue274.py new file mode 100644 index 000000000..3affa33e3 --- /dev/null +++ b/test/test_issue274.py @@ -0,0 +1,168 @@ +from nose.tools import assert_raises +from nose.tools import eq_ +from unittest import TestCase + +from rdflib import BNode, Graph, Literal, Namespace, RDFS, XSD +from rdflib.plugins.sparql.operators import register_custom_function, unregister_custom_function + +EX = Namespace('http://example.org/') +G = Graph() +G.add((BNode(), RDFS.label, Literal("bnode"))) +NS = { + 'ex': EX, + 'rdfs': RDFS, + 'xsd': XSD, +} + +def query(querystr, initNs=NS, initBindings=None): + return G.query(querystr, initNs=initNs, initBindings=initBindings) + +def setup(): + pass + +def teardown(): + pass + + +def test_cast_string_to_string(): + res = query('''SELECT (xsd:string("hello") as ?x) {}''') + eq_(list(res)[0][0], Literal("hello", datatype=XSD.string)) + +def test_cast_int_to_string(): + res = query('''SELECT (xsd:string(42) as ?x) {}''') + eq_(list(res)[0][0], Literal("42", datatype=XSD.string)) + +def test_cast_float_to_string(): + res = query('''SELECT (xsd:string(3.14) as ?x) {}''') + eq_(list(res)[0][0], Literal("3.14", datatype=XSD.string)) + +def test_cast_bool_to_string(): + res = query('''SELECT (xsd:string(true) as ?x) {}''') + eq_(list(res)[0][0], Literal("true", datatype=XSD.string)) + +def test_cast_iri_to_string(): + res = query('''SELECT (xsd:string() as ?x) {}''') + eq_(list(res)[0][0], Literal("http://example.org/", datatype=XSD.string)) + +def test_cast_datetime_to_datetime(): + res = query('''SELECT (xsd:dateTime("1970-01-01T00:00:00Z"^^xsd:dateTime) as ?x) {}''') + eq_(list(res)[0][0], Literal("1970-01-01T00:00:00Z", datatype=XSD.dateTime)) + +def test_cast_string_to_datetime(): + res = query('''SELECT (xsd:dateTime("1970-01-01T00:00:00Z"^^xsd:string) as ?x) {}''') + eq_(list(res)[0][0], Literal("1970-01-01T00:00:00Z", datatype=XSD.dateTime)) + +def test_cast_string_to_float(): + res = query('''SELECT (xsd:float("0.5") as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float)) + +def test_cast_int_to_float(): + res = query('''SELECT (xsd:float(1) as ?x) {}''') + eq_(list(res)[0][0], Literal("1", datatype=XSD.float)) + +def test_cast_float_to_float(): + res = query('''SELECT (xsd:float("0.5"^^xsd:float) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float)) + +def test_cast_double_to_float(): + res = query('''SELECT (xsd:float("0.5"^^xsd:double) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float)) + +def test_cast_decimal_to_float(): + res = query('''SELECT (xsd:float("0.5"^^xsd:decimal) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float)) + +def test_cast_string_to_double(): + res = query('''SELECT (xsd:double("0.5") as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double)) + +def test_cast_int_to_double(): + res = query('''SELECT (xsd:double(1) as ?x) {}''') + eq_(list(res)[0][0], Literal("1", datatype=XSD.double)) + +def test_cast_float_to_double(): + res = query('''SELECT (xsd:double("0.5"^^xsd:float) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double)) + +def test_cast_double_to_double(): + res = query('''SELECT (xsd:double("0.5"^^xsd:double) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double)) + +def test_cast_decimal_to_double(): + res = query('''SELECT (xsd:double("0.5"^^xsd:decimal) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double)) + +def test_cast_string_to_decimal(): + res = query('''SELECT (xsd:decimal("0.5") as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal)) + +def test_cast_int_to_decimal(): + res = query('''SELECT (xsd:decimal(1) as ?x) {}''') + eq_(list(res)[0][0], Literal("1", datatype=XSD.decimal)) + +def test_cast_float_to_decimal(): + res = query('''SELECT (xsd:decimal("0.5"^^xsd:float) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal)) + +def test_cast_double_to_decimal(): + res = query('''SELECT (xsd:decimal("0.5"^^xsd:double) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal)) + +def test_cast_decimal_to_decimal(): + res = query('''SELECT (xsd:decimal("0.5"^^xsd:decimal) as ?x) {}''') + eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal)) + +def test_cast_string_to_int(): + res = query('''SELECT (xsd:integer("42") as ?x) {}''') + eq_(list(res)[0][0], Literal("42", datatype=XSD.integer)) + +def test_cast_int_to_int(): + res = query('''SELECT (xsd:integer(42) as ?x) {}''') + eq_(list(res)[0][0], Literal("42", datatype=XSD.integer)) + +def test_cast_string_to_bool(): + res = query('''SELECT (xsd:boolean("TRUE") as ?x) {}''') + eq_(list(res)[0][0], Literal("true", datatype=XSD.boolean)) + +def test_cast_bool_to_bool(): + res = query('''SELECT (xsd:boolean(true) as ?x) {}''') + eq_(list(res)[0][0], Literal("true", datatype=XSD.boolean)) + +def test_cast_bool_to_bool(): + res = query('''SELECT (ex:f(42, "hello") as ?x) {}''') + eq_(len(list(res)), 0) + +class TestCustom(TestCase): + + @staticmethod + def f(x, y): + return Literal("%s %s" % (x, y), datatype=XSD.string) + + def setUp(self): + register_custom_function(EX.f, self.f) + + def tearDown(self): + unregister_custom_function(EX.f, self.f) + + def test_register_twice_fail(self): + with assert_raises(ValueError): + register_custom_function(EX.f, self.f) + + def test_register_override(self): + register_custom_function(EX.f, self.f, override=True) + + def test_wrong_unregister_fails(self): + with assert_raises(ValueError): + unregister_custom_function(EX.f, lambda x, y: None) + + def test_f(self): + res = query('''SELECT (ex:f(42, "hello") as ?x) {}''') + eq_(list(res)[0][0], Literal("42 hello", datatype=XSD.string)) + + def test_f_too_few_args(self): + res = query('''SELECT (ex:f(42) as ?x) {}''') + eq_(len(list(res)), 0) + + def test_f_too_many_args(self): + res = query('''SELECT (ex:f(42, "hello", "world") as ?x) {}''') + eq_(len(list(res)), 0)