Skip to content

Commit

Permalink
Merge branch 'main' into scheduler_annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 11, 2022
2 parents 8a158a3 + e390609 commit e51734f
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 71 deletions.
3 changes: 3 additions & 0 deletions continuous_integration/scripts/parse_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import annotations

import html
import re
import sys
from collections import Counter, defaultdict
Expand Down Expand Up @@ -70,6 +71,8 @@ def build_xml(rows: list[tuple[str, str, set[str | None]]]) -> None:
)

for clsname, tname, outcomes in rows:
clsname = html.escape(clsname)
tname = html.escape(tname)
print(f'<testcase classname="{clsname}" name="{tname}" time="0.0"', end="")
if outcomes == {"PASSED"}:
print(" />")
Expand Down
17 changes: 12 additions & 5 deletions distributed/comm/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
get_tcp_server_address,
to_frames,
)
from distributed.utils import ensure_bytes, nbytes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -137,14 +136,18 @@ async def write(self, msg, serializers=None, on_error=None):
frame_split_size=BIG_BYTES_SHARD_SIZE,
)
n = struct.pack("Q", len(frames))
nbytes_frames = 0
try:
await self.handler.write_message(n, binary=True)
for frame in frames:
await self.handler.write_message(ensure_bytes(frame), binary=True)
if type(frame) is not bytes:
frame = bytes(frame)
await self.handler.write_message(frame, binary=True)
nbytes_frames += len(frame)
except WebSocketClosedError as e:
raise CommClosedError(str(e))

return sum(map(nbytes, frames))
return nbytes_frames

def abort(self):
self.handler.close()
Expand Down Expand Up @@ -226,14 +229,18 @@ async def write(self, msg, serializers=None, on_error=None):
frame_split_size=BIG_BYTES_SHARD_SIZE,
)
n = struct.pack("Q", len(frames))
nbytes_frames = 0
try:
await self.sock.write_message(n, binary=True)
for frame in frames:
await self.sock.write_message(ensure_bytes(frame), binary=True)
if type(frame) is not bytes:
frame = bytes(frame)
await self.sock.write_message(frame, binary=True)
nbytes_frames += len(frame)
except WebSocketClosedError as e:
raise CommClosedError(e)

return sum(map(nbytes, frames))
return nbytes_frames

async def close(self):
if not self.sock.close_code:
Expand Down
15 changes: 7 additions & 8 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
pack_frames_prelude,
unpack_frames,
)
from distributed.utils import has_keyword
from distributed.utils import ensure_memoryview, has_keyword

dask_serialize = dask.utils.Dispatch("dask_serialize")
dask_deserialize = dask.utils.Dispatch("dask_deserialize")
Expand Down Expand Up @@ -91,7 +91,7 @@ def pickle_loads(header, frames):
memoryviews = map(memoryview, buffers)
for w, mv in zip(writeable, memoryviews):
if w == mv.readonly:
if mv.readonly:
if w:
mv = memoryview(bytearray(mv))
else:
mv = memoryview(bytes(mv))
Expand Down Expand Up @@ -765,12 +765,11 @@ def _serialize_array(obj):
@dask_deserialize.register(array)
def _deserialize_array(header, frames):
a = array(header["typecode"])
for f in map(memoryview, frames):
try:
f = f.cast("B")
except TypeError:
f = f.tobytes()
a.frombytes(f)
nframes = len(frames)
if nframes == 1:
a.frombytes(ensure_memoryview(frames[0]))
elif nframes > 1:
a.frombytes(b"".join(map(ensure_memoryview, frames)))
return a


Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

np = pytest.importorskip("numpy")

from dask.utils import tmpfile
from dask.utils import ensure_bytes, tmpfile

from distributed.protocol import (
decompress,
Expand All @@ -20,7 +20,7 @@
from distributed.protocol.pickle import HIGHEST_PROTOCOL
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.system import MEMORY_LIMIT
from distributed.utils import ensure_bytes, nbytes
from distributed.utils import nbytes
from distributed.utils_test import gen_cluster


Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
np = pytest.importorskip("numpy")

from dask.dataframe.utils import assert_eq
from dask.utils import ensure_bytes

from distributed.protocol import (
decompress,
Expand All @@ -13,7 +14,6 @@
serialize,
to_serialize,
)
from distributed.utils import ensure_bytes

dfs = [
pd.DataFrame({}),
Expand Down
32 changes: 29 additions & 3 deletions distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
to_serialize,
)
from distributed.protocol.serialize import check_dask_serializable
from distributed.utils import nbytes
from distributed.utils import ensure_memoryview, nbytes
from distributed.utils_test import gen_test, inc


Expand Down Expand Up @@ -88,19 +88,45 @@ def test_serialize_bytestrings():
assert bb == b


def test_serialize_empty_array():
a = array("I")

# serialize array
header, frames = serialize(a)
assert frames[0] == memoryview(a)
# drop empty frame
del frames[:]
# deserialize with no frames
a2 = deserialize(header, frames)
assert type(a2) == type(a)
assert a2.typecode == a.typecode
assert a2 == a


@pytest.mark.parametrize(
"typecode", ["b", "B", "h", "H", "i", "I", "l", "L", "q", "Q", "f", "d"]
)
def test_serialize_arrays(typecode):
a = array(typecode)
a.extend(range(5))
a = array(typecode, range(5))

# handle normal round trip through serialization
header, frames = serialize(a)
assert frames[0] == memoryview(a)
a2 = deserialize(header, frames)
assert type(a2) == type(a)
assert a2.typecode == a.typecode
assert a2 == a

# split up frames to test joining them back together
header, frames = serialize(a)
(f,) = frames
f = ensure_memoryview(f)
frames = [f[:1], f[1:2], f[2:-1], f[-1:]]
a3 = deserialize(header, frames)
assert type(a3) == type(a)
assert a3.typecode == a.typecode
assert a3 == a


def test_Serialize():
s = Serialize(123)
Expand Down
7 changes: 5 additions & 2 deletions distributed/tests/test_parse_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
distributed/tests/test1.py::test_flaky PASSED [ 70%]
distributed/tests/test1.py::test_leaking PASSED [ 72%]
distributed/tests/test1.py::test_leaking LEAKED [ 72%]
distributed/tests/test1.py::test_pass [32mPASSED[0m[31m [ 80%][0m
distributed/tests/test1.py::test_pass [32mPASSED[0m[31m [ 75%][0m
distributed/tests/test1.py::test_params[a a] PASSED [ 78%]
distributed/tests/test1.py::test_escape_chars[<lambda>] PASSED [ 80%]
distributed/tests/test1.py::MyTest::test_unittest PASSED [ 86%]
distributed/tests/test1.py::test_timeout
"""
Expand Down Expand Up @@ -57,6 +58,7 @@ def test_parse_rows():
("distributed.tests.test1", "test_leaking", {"PASSED"}),
("distributed.tests.test1", "test_pass", {"PASSED"}),
("distributed.tests.test1", "test_params[a a]", {"PASSED"}),
("distributed.tests.test1", "test_escape_chars[<lambda>]", {"PASSED"}),
("distributed.tests.test1.MyTest", "test_unittest", {"PASSED"}),
("distributed.tests.test1", "test_timeout", {None}),
]
Expand All @@ -74,7 +76,7 @@ def test_build_xml(capsys):
expect = """
<?xml version="1.0" encoding="utf-8"?>
<testsuites>
<testsuite name="distributed" errors="3" failures="3" skipped="2" tests="14" time="0.0" timestamp="snip" hostname="">
<testsuite name="distributed" errors="3" failures="3" skipped="2" tests="15" time="0.0" timestamp="snip" hostname="">
<testcase classname="distributed.tests.test1" name="test_fail" time="0.0"><failure message=""></failure></testcase>
<testcase classname="distributed.tests.test1" name="test_error_in_setup" time="0.0"><error message="failed on setup"></error></testcase>
<testcase classname="distributed.tests.test1" name="test_pass_and_then_error_in_teardown" time="0.0"><error message="failed on teardown"></error></testcase>
Expand All @@ -86,6 +88,7 @@ def test_build_xml(capsys):
<testcase classname="distributed.tests.test1" name="test_leaking" time="0.0" />
<testcase classname="distributed.tests.test1" name="test_pass" time="0.0" />
<testcase classname="distributed.tests.test1" name="test_params[a a]" time="0.0" />
<testcase classname="distributed.tests.test1" name="test_escape_chars[&lt;lambda&gt;]" time="0.0" />
<testcase classname="distributed.tests.test1.MyTest" name="test_unittest" time="0.0" />
<testcase classname="distributed.tests.test1" name="test_timeout" time="0.0"><failure message="pytest-timeout exceeded"></failure></testcase>
</testsuite>
Expand Down
22 changes: 0 additions & 22 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
LoopRunner,
TimeoutError,
_maybe_complex,
ensure_bytes,
ensure_ip,
ensure_memoryview,
format_dashboard_link,
Expand Down Expand Up @@ -249,27 +248,6 @@ def test_seek_delimiter_endline():
assert f.tell() == 7


def test_ensure_bytes():
data = [b"1", "1", memoryview(b"1"), bytearray(b"1"), array.array("b", [49])]
for d in data:
result = ensure_bytes(d)
assert isinstance(result, bytes)
assert result == b"1"


def test_ensure_bytes_ndarray():
np = pytest.importorskip("numpy")
result = ensure_bytes(np.arange(12))
assert isinstance(result, bytes)


def test_ensure_bytes_pyarrow_buffer():
pa = pytest.importorskip("pyarrow")
buf = pa.py_buffer(b"123")
result = ensure_bytes(buf)
assert isinstance(result, bytes)


def test_ensure_memoryview_empty():
result = ensure_memoryview(b"")
assert isinstance(result, memoryview)
Expand Down
Loading

0 comments on commit e51734f

Please sign in to comment.