-
Notifications
You must be signed in to change notification settings - Fork 181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUG] Fix actor pool initialization in ray client mode #3028
Changes from all commits
2234066
72c556d
617bb53
d809564
79f5776
4c8df04
93c838a
adb5061
43f8920
1795942
a937931
bbc2149
ba384b4
d2b50b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,14 +8,14 @@ | |
from uuid import uuid4 | ||
|
||
from daft.datatype import TimeUnit | ||
from daft.table import MicroPartition | ||
|
||
if TYPE_CHECKING: | ||
import pandas as pd | ||
import pyarrow as pa | ||
|
||
from daft.expressions.expressions import Expression | ||
from daft.logical.schema import Schema | ||
from daft.table import MicroPartition | ||
|
||
PartID = int | ||
|
||
|
@@ -271,6 +271,92 @@ def wait(self) -> None: | |
raise NotImplementedError() | ||
|
||
|
||
class LocalPartitionSet(PartitionSet[MicroPartition]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a pure movement of the class from pyrunner.py? Why did we decide to move it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was being imported by dataframe.py and ray_runner.py which was causing some circular import issues. I figured since it was being used outside of the pyrunner I would move it out to simplify the dependency tree. |
||
_partitions: dict[PartID, MaterializedResult[MicroPartition]] | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self._partitions = {} | ||
|
||
def items(self) -> list[tuple[PartID, MaterializedResult[MicroPartition]]]: | ||
return sorted(self._partitions.items()) | ||
|
||
def _get_merged_micropartition(self) -> MicroPartition: | ||
ids_and_partitions = self.items() | ||
assert ids_and_partitions[0][0] == 0 | ||
assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions) | ||
return MicroPartition.concat([part.partition() for id, part in ids_and_partitions]) | ||
|
||
def _get_preview_micropartitions(self, num_rows: int) -> list[MicroPartition]: | ||
ids_and_partitions = self.items() | ||
preview_parts = [] | ||
for _, mat_result in ids_and_partitions: | ||
part: MicroPartition = mat_result.partition() | ||
part_len = len(part) | ||
if part_len >= num_rows: # if this part has enough rows, take what we need and break | ||
preview_parts.append(part.slice(0, num_rows)) | ||
break | ||
else: # otherwise, take the whole part and keep going | ||
num_rows -= part_len | ||
preview_parts.append(part) | ||
return preview_parts | ||
|
||
def get_partition(self, idx: PartID) -> MaterializedResult[MicroPartition]: | ||
return self._partitions[idx] | ||
|
||
def set_partition(self, idx: PartID, part: MaterializedResult[MicroPartition]) -> None: | ||
self._partitions[idx] = part | ||
|
||
def set_partition_from_table(self, idx: PartID, part: MicroPartition) -> None: | ||
self._partitions[idx] = LocalMaterializedResult(part, PartitionMetadata.from_table(part)) | ||
|
||
def delete_partition(self, idx: PartID) -> None: | ||
del self._partitions[idx] | ||
|
||
def has_partition(self, idx: PartID) -> bool: | ||
return idx in self._partitions | ||
|
||
def __len__(self) -> int: | ||
return sum(len(partition.partition()) for partition in self._partitions.values()) | ||
|
||
def size_bytes(self) -> int | None: | ||
size_bytes_ = [partition.partition().size_bytes() for partition in self._partitions.values()] | ||
size_bytes: list[int] = [size for size in size_bytes_ if size is not None] | ||
if len(size_bytes) != len(size_bytes_): | ||
return None | ||
else: | ||
return sum(size_bytes) | ||
|
||
def num_partitions(self) -> int: | ||
return len(self._partitions) | ||
|
||
def wait(self) -> None: | ||
pass | ||
|
||
|
||
@dataclass | ||
class LocalMaterializedResult(MaterializedResult[MicroPartition]): | ||
_partition: MicroPartition | ||
_metadata: PartitionMetadata | None = None | ||
|
||
def partition(self) -> MicroPartition: | ||
return self._partition | ||
|
||
def micropartition(self) -> MicroPartition: | ||
return self._partition | ||
|
||
def metadata(self) -> PartitionMetadata: | ||
if self._metadata is None: | ||
self._metadata = PartitionMetadata.from_table(self._partition) | ||
return self._metadata | ||
|
||
def cancel(self) -> None: | ||
return None | ||
|
||
def _noop(self, _: MicroPartition) -> None: | ||
return None | ||
|
||
|
||
@dataclass(eq=False, repr=False) | ||
class PartitionCacheEntry: | ||
key: str | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice