Skip to content

Commit

Permalink
indicate non-keyword arguments in docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Apr 5, 2022
1 parent a16f02a commit a22ca65
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
5 changes: 5 additions & 0 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,12 @@ static void nb_func_render_signature(const func_record *f) noexcept {
PyErr_Clear();
}
}

arg_index++;

if (arg_index == f->nargs - has_var_args - has_var_kwargs && !has_args)
buf.put(", /");

break;

case '%':
Expand Down
12 changes: 6 additions & 6 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ def assert_stats(**kwargs):
def test01_signature():
assert t.Struct.__init__.__doc__ == (
"__init__(self) -> None\n"
"__init__(self, arg: int) -> None"
"__init__(self, arg: int, /) -> None"
)

assert t.Struct.value.__doc__ == "value(self) -> int"
assert t.Struct.create_move.__doc__ == "create_move() -> test_classes_ext.Struct"
assert t.Struct.set_value.__doc__ == "set_value(self, value: int) -> None"
assert t.Struct.__doc__ == 'Some documentation'
assert t.Struct.static_test.__doc__ == (
"static_test(arg: int) -> int\n"
"static_test(arg: float) -> int")
"static_test(arg: int, /) -> int\n"
"static_test(arg: float, /) -> int")


def test02_static_overload():
Expand Down Expand Up @@ -244,7 +244,7 @@ def test12_large_pointers():


def test13_implicitly_convertible():
assert t.get_d.__doc__ == "get_d(arg: test_classes_ext.D) -> int"
assert t.get_d.__doc__ == "get_d(arg: test_classes_ext.D, /) -> int"
a = t.A(1)
b = t.B(2)
b2 = t.B2(3)
Expand All @@ -254,7 +254,7 @@ def test13_implicitly_convertible():
t.get_d(c)
assert str(excinfo.value) == (
"get_d(): incompatible function arguments. The following argument types are supported:\n"
" 1. get_d(arg: test_classes_ext.D) -> int\n"
" 1. get_d(arg: test_classes_ext.D, /) -> int\n"
"\n"
"Invoked with types: C")
assert t.get_d(a) == 11
Expand Down Expand Up @@ -424,7 +424,7 @@ def test21_low_level(clean):


def test22_handle_of(clean):
assert t.test_handle_of.__doc__ == 'test_handle_of(arg: test_classes_ext.Struct) -> object'
assert t.test_handle_of.__doc__ == 'test_handle_of(arg: test_classes_ext.Struct, /) -> object'
s = t.test_handle_of(t.Struct(5))
assert s.value() == 5
del s
Expand Down
12 changes: 7 additions & 5 deletions tests/test_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ def test01_metadata():


def test02_docstr():
assert t.get_shape.__doc__ == "get_shape(arg: tensor[]) -> list"
assert t.pass_uint32.__doc__ == "pass_uint32(arg: tensor[dtype=uint32]) -> None"
assert t.pass_float32.__doc__ == "pass_float32(arg: tensor[dtype=float32]) -> None"
assert t.pass_float32_shaped.__doc__ == "pass_float32_shaped(arg: tensor[dtype=float32, shape=(3, *, 4)]) -> None"
assert t.pass_float32_shaped_ordered.__doc__ == "pass_float32_shaped_ordered(arg: tensor[dtype=float32, order='C', shape=(*, *, 4)]) -> None"
assert t.get_shape.__doc__ == "get_shape(arg: tensor[], /) -> list"
assert t.pass_uint32.__doc__ == "pass_uint32(arg: tensor[dtype=uint32], /) -> None"
assert t.pass_float32.__doc__ == "pass_float32(arg: tensor[dtype=float32], /) -> None"
assert t.pass_float32_shaped.__doc__ == "pass_float32_shaped(arg: tensor[dtype=float32, shape=(3, *, 4)], /) -> None"
assert t.pass_float32_shaped_ordered.__doc__ == "pass_float32_shaped_ordered(arg: tensor[dtype=float32, order='C', shape=(*, *, 4)], /) -> None"
assert t.check_device.__doc__ == ("check_device(arg: tensor[device='cpu'], /) -> str\n"
"check_device(arg: tensor[device='cuda'], /) -> str")


def test03_constrain_dtype():
Expand Down
16 changes: 8 additions & 8 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ def test05_signature():
assert t.test_01.__doc__ == 'test_01() -> None'
assert t.test_02.__doc__ == 'test_02(j: int = 8, k: int = 1) -> int'
assert t.test_05.__doc__ == (
"test_05(arg: int) -> int\n"
"test_05(arg: float) -> int\n"
"test_05(arg: int, /) -> int\n"
"test_05(arg: float, /) -> int\n"
"\n"
"Overloaded function.\n"
"\n"
"1. ``test_05(arg: int) -> int``\n"
"1. ``test_05(arg: int, /) -> int``\n"
"\n"
"doc_1\n"
"\n"
"2. ``test_05(arg: float) -> int``\n"
"2. ``test_05(arg: float, /) -> int``\n"
"\n"
"doc_2")

assert t.test_07.__doc__ == (
"test_07(arg0: int, arg1: int, *args, **kwargs) -> tuple[int, int]\n"
"test_07(arg0: int, arg1: int, /, *args, **kwargs) -> tuple[int, int]\n"
"test_07(a: int, b: int, *myargs, **mykwargs) -> tuple[int, int]")

def test06_signature_error():
Expand All @@ -54,8 +54,8 @@ def test06_signature_error():
assert str(excinfo.value) == (
"test_05(): incompatible function arguments. The "
"following argument types are supported:\n"
" 1. test_05(arg: int) -> int\n"
" 2. test_05(arg: float) -> int\n\n"
" 1. test_05(arg: int, /) -> int\n"
" 2. test_05(arg: float, /) -> int\n\n"
"Invoked with types: str, kwargs = { y: int }")


Expand Down Expand Up @@ -143,7 +143,7 @@ def test16_raw_doc():
assert t.test_08.__doc__ == 'raw'

def test17_type_check_manual():
assert t.test_09.__doc__ == 'test_09(arg: type) -> bool'
assert t.test_09.__doc__ == 'test_09(arg: type, /) -> bool'

assert t.test_09(bool) is True
assert t.test_09(int) is False
Expand Down

0 comments on commit a22ca65

Please sign in to comment.