Skip to content

Commit

Permalink
Merge pull request #723 from pchampin/custom-functions
Browse files Browse the repository at this point in the history
added a hook to add custom functions to SPARQL
  • Loading branch information
gromgull authored May 29, 2017
2 parents 7c65b34 + 5cbfe91 commit 2182c0e
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 12 deletions.
67 changes: 55 additions & 12 deletions rdflib/plugins/sparql/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,14 +585,65 @@ def Builtin_EXISTS(e, ctx):
return Literal(not exists)


def Function(e, ctx):
_CUSTOM_FUNCTIONS = {}

def register_custom_function(uri, func, override=False, raw=False):
"""
Custom functions (and casts!)
Register a custom SPARQL function.
By default, the function will be passed the RDF terms in the argument list.
If raw is True, the function will be passed an Expression and a Context.
The function must return an RDF term, or raise a SparqlError.
"""
if not override and uri in _CUSTOM_FUNCTIONS:
raise ValueError("A function is already registered as %s" % uri.n3())
_CUSTOM_FUNCTIONS[uri] = (func, raw)

if e.iri in XSD_DTs:
# a cast
def custom_function(uri, override=False, raw=False):
"""
Decorator version of :func:`register_custom_function`.
"""
def decorator(func):
register_custom_function(uri, func, override=override, raw=raw)
return func
return decorator

def unregister_custom_function(uri, func):
if _CUSTOM_FUNCTIONS.get(uri, (None, None))[0] != func:
raise ValueError("This function is not registered as %s" % uri.n3())
del _CUSTOM_FUNCTIONS[uri]


def Function(e, ctx):
"""
Custom functions and casts
"""
pair =_CUSTOM_FUNCTIONS.get(e.iri)
if pair is None:
# no such function is registered
raise SPARQLError('Unknown function %r' % e.iri)
func, raw = pair
if raw:
# function expects expression and context
return func(e, ctx)
else:
# function expects the argument list
try:
return func(*e.expr)
except TypeError as ex:
# wrong argument number
raise SPARQLError(*ex.args)


@custom_function(XSD.string, raw=True)
@custom_function(XSD.dateTime, raw=True)
@custom_function(XSD.float, raw=True)
@custom_function(XSD.double, raw=True)
@custom_function(XSD.decimal, raw=True)
@custom_function(XSD.integer, raw=True)
@custom_function(XSD.boolean, raw=True)
def default_cast(e, ctx):
if not e.expr:
raise SPARQLError("Nothing given to cast.")
if len(e.expr) > 1:
Expand Down Expand Up @@ -657,15 +708,7 @@ def Function(e, ctx):
return Literal(True)
if x.lower() in ("0", "false"):
return Literal(False)

raise SPARQLError("Cannot interpret '%r' as bool" % x)
else:
raise Exception("I do not know how to cast to %r" % e.iri)

else:
raise SPARQLError('Unknown function %r"%e.iri')

# TODO: Custom functions!


def UnaryNot(expr, ctx):
Expand Down
168 changes: 168 additions & 0 deletions test/test_issue274.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from nose.tools import assert_raises
from nose.tools import eq_
from unittest import TestCase

from rdflib import BNode, Graph, Literal, Namespace, RDFS, XSD
from rdflib.plugins.sparql.operators import register_custom_function, unregister_custom_function

EX = Namespace('http://example.org/')
G = Graph()
G.add((BNode(), RDFS.label, Literal("bnode")))
NS = {
'ex': EX,
'rdfs': RDFS,
'xsd': XSD,
}

def query(querystr, initNs=NS, initBindings=None):
return G.query(querystr, initNs=initNs, initBindings=initBindings)

def setup():
pass

def teardown():
pass


def test_cast_string_to_string():
res = query('''SELECT (xsd:string("hello") as ?x) {}''')
eq_(list(res)[0][0], Literal("hello", datatype=XSD.string))

def test_cast_int_to_string():
res = query('''SELECT (xsd:string(42) as ?x) {}''')
eq_(list(res)[0][0], Literal("42", datatype=XSD.string))

def test_cast_float_to_string():
res = query('''SELECT (xsd:string(3.14) as ?x) {}''')
eq_(list(res)[0][0], Literal("3.14", datatype=XSD.string))

def test_cast_bool_to_string():
res = query('''SELECT (xsd:string(true) as ?x) {}''')
eq_(list(res)[0][0], Literal("true", datatype=XSD.string))

def test_cast_iri_to_string():
res = query('''SELECT (xsd:string(<http://example.org/>) as ?x) {}''')
eq_(list(res)[0][0], Literal("http://example.org/", datatype=XSD.string))

def test_cast_datetime_to_datetime():
res = query('''SELECT (xsd:dateTime("1970-01-01T00:00:00Z"^^xsd:dateTime) as ?x) {}''')
eq_(list(res)[0][0], Literal("1970-01-01T00:00:00Z", datatype=XSD.dateTime))

def test_cast_string_to_datetime():
res = query('''SELECT (xsd:dateTime("1970-01-01T00:00:00Z"^^xsd:string) as ?x) {}''')
eq_(list(res)[0][0], Literal("1970-01-01T00:00:00Z", datatype=XSD.dateTime))

def test_cast_string_to_float():
res = query('''SELECT (xsd:float("0.5") as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_int_to_float():
res = query('''SELECT (xsd:float(1) as ?x) {}''')
eq_(list(res)[0][0], Literal("1", datatype=XSD.float))

def test_cast_float_to_float():
res = query('''SELECT (xsd:float("0.5"^^xsd:float) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_double_to_float():
res = query('''SELECT (xsd:float("0.5"^^xsd:double) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_decimal_to_float():
res = query('''SELECT (xsd:float("0.5"^^xsd:decimal) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_string_to_double():
res = query('''SELECT (xsd:double("0.5") as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_int_to_double():
res = query('''SELECT (xsd:double(1) as ?x) {}''')
eq_(list(res)[0][0], Literal("1", datatype=XSD.double))

def test_cast_float_to_double():
res = query('''SELECT (xsd:double("0.5"^^xsd:float) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_double_to_double():
res = query('''SELECT (xsd:double("0.5"^^xsd:double) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_decimal_to_double():
res = query('''SELECT (xsd:double("0.5"^^xsd:decimal) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_string_to_decimal():
res = query('''SELECT (xsd:decimal("0.5") as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_int_to_decimal():
res = query('''SELECT (xsd:decimal(1) as ?x) {}''')
eq_(list(res)[0][0], Literal("1", datatype=XSD.decimal))

def test_cast_float_to_decimal():
res = query('''SELECT (xsd:decimal("0.5"^^xsd:float) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_double_to_decimal():
res = query('''SELECT (xsd:decimal("0.5"^^xsd:double) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_decimal_to_decimal():
res = query('''SELECT (xsd:decimal("0.5"^^xsd:decimal) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_string_to_int():
res = query('''SELECT (xsd:integer("42") as ?x) {}''')
eq_(list(res)[0][0], Literal("42", datatype=XSD.integer))

def test_cast_int_to_int():
res = query('''SELECT (xsd:integer(42) as ?x) {}''')
eq_(list(res)[0][0], Literal("42", datatype=XSD.integer))

def test_cast_string_to_bool():
res = query('''SELECT (xsd:boolean("TRUE") as ?x) {}''')
eq_(list(res)[0][0], Literal("true", datatype=XSD.boolean))

def test_cast_bool_to_bool():
res = query('''SELECT (xsd:boolean(true) as ?x) {}''')
eq_(list(res)[0][0], Literal("true", datatype=XSD.boolean))

def test_cast_bool_to_bool():
res = query('''SELECT (ex:f(42, "hello") as ?x) {}''')
eq_(len(list(res)), 0)

class TestCustom(TestCase):

@staticmethod
def f(x, y):
return Literal("%s %s" % (x, y), datatype=XSD.string)

def setUp(self):
register_custom_function(EX.f, self.f)

def tearDown(self):
unregister_custom_function(EX.f, self.f)

def test_register_twice_fail(self):
with assert_raises(ValueError):
register_custom_function(EX.f, self.f)

def test_register_override(self):
register_custom_function(EX.f, self.f, override=True)

def test_wrong_unregister_fails(self):
with assert_raises(ValueError):
unregister_custom_function(EX.f, lambda x, y: None)

def test_f(self):
res = query('''SELECT (ex:f(42, "hello") as ?x) {}''')
eq_(list(res)[0][0], Literal("42 hello", datatype=XSD.string))

def test_f_too_few_args(self):
res = query('''SELECT (ex:f(42) as ?x) {}''')
eq_(len(list(res)), 0)

def test_f_too_many_args(self):
res = query('''SELECT (ex:f(42, "hello", "world") as ?x) {}''')
eq_(len(list(res)), 0)

0 comments on commit 2182c0e

Please sign in to comment.