Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added a hook to add custom functions to SPARQL #723

Merged
merged 4 commits into from
May 29, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 59 additions & 15 deletions rdflib/plugins/sparql/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,14 +582,66 @@ def Builtin_EXISTS(e, ctx):
return Literal(not exists)


def Function(e, ctx):
_CUSTOM_FUNCTIONS = {}

def register_custom_function(uri, func, override=False, raw=False):
"""
Register a custom SPARQL function.

By default, the function will be passed the RDF terms in the argument list.
If raw is True, the function will be passed an Expression and a Context.

The function must return an RDF term, or raise a SparqlError.
"""
Custom functions (and casts!)
if not override and uri in _CUSTOM_FUNCTIONS:
raise ValueError("A function is already registered as %s" % uri.n3())
_CUSTOM_FUNCTIONS[uri] = (func, raw)

def custom_function(uri, override=False, raw=False):
"""
Decorator version of :func:`register_custom_function`.
"""
def decorator(func):
register_custom_function(uri, func, override=override, raw=raw)
return func
return decorator

if e.iri in XSD_DTs:
# a cast
def unregister_custom_function(uri, func):
if _CUSTOM_FUNCTIONS.get(uri, (None, None))[0] != func:
raise ValueError("This function is not registered as %s" % uri.n3())
del _CUSTOM_FUNCTIONS[uri]


def Function(e, ctx):
"""
Custom functions and casts
"""
pair =_CUSTOM_FUNCTIONS.get(e.iri)
if pair is None:
# no such function is registered
raise SPARQLError('Unknown function %r"%e.iri')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think what you actually want is raise SPARQLError('Unknown function %r' % e.iri)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what the... :-/ well spotted !

func, raw = pair
if raw:
# function expects expression and context
return func(e, ctx)
else:
# function expects the argument list
try:
return func(*e.expr)
except TypeError as ex:
# wrong argument number
raise SPARQLError(*ex.args)



@custom_function(XSD.string, raw=True)
@custom_function(XSD.dateTime, raw=True)
@custom_function(XSD.float, raw=True)
@custom_function(XSD.double, raw=True)
@custom_function(XSD.decimal, raw=True)
@custom_function(XSD.integer, raw=True)
@custom_function(XSD.boolean, raw=True)
def default_cast(e, ctx):
if not e.expr:
raise SPARQLError("Nothing given to cast.")
if len(e.expr) > 1:
Expand Down Expand Up @@ -654,15 +706,7 @@ def Function(e, ctx):
return Literal(True)
if x.lower() in ("0", "false"):
return Literal(False)

raise SPARQLError("Cannot interpret '%r' as bool" % x)
else:
raise Exception("I do not know how to cast to %r" % e.iri)

else:
raise SPARQLError('Unknown function %r"%e.iri')

# TODO: Custom functions!


def UnaryNot(expr, ctx):
Expand Down Expand Up @@ -768,7 +812,7 @@ def RelationalExpression(e, ctx):
try:
if x == expr:
return Literal(True ^ res)
except SPARQLError, e:
except SPARQLError as e:
error = e
if not error:
return Literal(False ^ res)
Expand Down Expand Up @@ -802,7 +846,7 @@ def RelationalExpression(e, ctx):
r = ops[op](expr, other)
if r == NotImplemented:
raise SPARQLError('Error when comparing')
except TypeError, te:
except TypeError as te:
raise SPARQLError(*te.args)
return Literal(r)

Expand Down Expand Up @@ -841,7 +885,7 @@ def ConditionalOrExpression(e, ctx):
try:
if EBV(x):
return Literal(True)
except SPARQLError, e:
except SPARQLError as e:
error = e
if error:
raise error
Expand Down
168 changes: 168 additions & 0 deletions test/test_issue274.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from nose.tools import assert_raises
from nose.tools import eq_
from unittest import TestCase

from rdflib import BNode, Graph, Literal, Namespace, RDFS, XSD
from rdflib.plugins.sparql.operators import register_custom_function, unregister_custom_function

EX = Namespace('http://example.org/')
G = Graph()
G.add((BNode(), RDFS.label, Literal("bnode")))
NS = {
'ex': EX,
'rdfs': RDFS,
'xsd': XSD,
}

def query(querystr, initNs=NS, initBindings=None):
return G.query(querystr, initNs=initNs, initBindings=initBindings)

def setup():
pass

def teardown():
pass


def test_cast_string_to_string():
res = query('''SELECT (xsd:string("hello") as ?x) {}''')
eq_(list(res)[0][0], Literal("hello", datatype=XSD.string))

def test_cast_int_to_string():
res = query('''SELECT (xsd:string(42) as ?x) {}''')
eq_(list(res)[0][0], Literal("42", datatype=XSD.string))

def test_cast_float_to_string():
res = query('''SELECT (xsd:string(3.14) as ?x) {}''')
eq_(list(res)[0][0], Literal("3.14", datatype=XSD.string))

def test_cast_bool_to_string():
res = query('''SELECT (xsd:string(true) as ?x) {}''')
eq_(list(res)[0][0], Literal("true", datatype=XSD.string))

def test_cast_iri_to_string():
res = query('''SELECT (xsd:string(<http://example.org/>) as ?x) {}''')
eq_(list(res)[0][0], Literal("http://example.org/", datatype=XSD.string))

def test_cast_datetime_to_datetime():
res = query('''SELECT (xsd:dateTime("1970-01-01T00:00:00Z"^^xsd:dateTime) as ?x) {}''')
eq_(list(res)[0][0], Literal("1970-01-01T00:00:00Z", datatype=XSD.dateTime))

def test_cast_string_to_datetime():
res = query('''SELECT (xsd:dateTime("1970-01-01T00:00:00Z"^^xsd:string) as ?x) {}''')
eq_(list(res)[0][0], Literal("1970-01-01T00:00:00Z", datatype=XSD.dateTime))

def test_cast_string_to_float():
res = query('''SELECT (xsd:float("0.5") as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_int_to_float():
res = query('''SELECT (xsd:float(1) as ?x) {}''')
eq_(list(res)[0][0], Literal("1", datatype=XSD.float))

def test_cast_float_to_float():
res = query('''SELECT (xsd:float("0.5"^^xsd:float) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_double_to_float():
res = query('''SELECT (xsd:float("0.5"^^xsd:double) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_decimal_to_float():
res = query('''SELECT (xsd:float("0.5"^^xsd:decimal) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.float))

def test_cast_string_to_double():
res = query('''SELECT (xsd:double("0.5") as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_int_to_double():
res = query('''SELECT (xsd:double(1) as ?x) {}''')
eq_(list(res)[0][0], Literal("1", datatype=XSD.double))

def test_cast_float_to_double():
res = query('''SELECT (xsd:double("0.5"^^xsd:float) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_double_to_double():
res = query('''SELECT (xsd:double("0.5"^^xsd:double) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_decimal_to_double():
res = query('''SELECT (xsd:double("0.5"^^xsd:decimal) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.double))

def test_cast_string_to_decimal():
res = query('''SELECT (xsd:decimal("0.5") as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_int_to_decimal():
res = query('''SELECT (xsd:decimal(1) as ?x) {}''')
eq_(list(res)[0][0], Literal("1", datatype=XSD.decimal))

def test_cast_float_to_decimal():
res = query('''SELECT (xsd:decimal("0.5"^^xsd:float) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_double_to_decimal():
res = query('''SELECT (xsd:decimal("0.5"^^xsd:double) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_decimal_to_decimal():
res = query('''SELECT (xsd:decimal("0.5"^^xsd:decimal) as ?x) {}''')
eq_(list(res)[0][0], Literal("0.5", datatype=XSD.decimal))

def test_cast_string_to_int():
res = query('''SELECT (xsd:integer("42") as ?x) {}''')
eq_(list(res)[0][0], Literal("42", datatype=XSD.integer))

def test_cast_int_to_int():
res = query('''SELECT (xsd:integer(42) as ?x) {}''')
eq_(list(res)[0][0], Literal("42", datatype=XSD.integer))

def test_cast_string_to_bool():
res = query('''SELECT (xsd:boolean("TRUE") as ?x) {}''')
eq_(list(res)[0][0], Literal("true", datatype=XSD.boolean))

def test_cast_bool_to_bool():
res = query('''SELECT (xsd:boolean(true) as ?x) {}''')
eq_(list(res)[0][0], Literal("true", datatype=XSD.boolean))

def test_cast_bool_to_bool():
res = query('''SELECT (ex:f(42, "hello") as ?x) {}''')
eq_(len(list(res)), 0)

class TestCustom(TestCase):

@staticmethod
def f(x, y):
return Literal("%s %s" % (x, y), datatype=XSD.string)

def setUp(self):
register_custom_function(EX.f, self.f)

def tearDown(self):
unregister_custom_function(EX.f, self.f)

def test_register_twice_fail(self):
with assert_raises(ValueError):
register_custom_function(EX.f, self.f)

def test_register_override(self):
register_custom_function(EX.f, self.f, override=True)

def test_wrong_unregister_fails(self):
with assert_raises(ValueError):
unregister_custom_function(EX.f, lambda x, y: None)

def test_f(self):
res = query('''SELECT (ex:f(42, "hello") as ?x) {}''')
eq_(list(res)[0][0], Literal("42 hello", datatype=XSD.string))

def test_f_too_few_args(self):
res = query('''SELECT (ex:f(42) as ?x) {}''')
eq_(len(list(res)), 0)

def test_f_too_many_args(self):
res = query('''SELECT (ex:f(42, "hello", "world") as ?x) {}''')
eq_(len(list(res)), 0)