Skip to content

Commit

Permalink
[Data] Reimplement of fix memory pandas (ray-project#48970)
Browse files Browse the repository at this point in the history
Signed-off-by: hjiang <dentinyhao@gmail.com>
  • Loading branch information
Bye-legumes authored and dentiny committed Dec 7, 2024
1 parent 986c666 commit 1716681
Show file tree
Hide file tree
Showing 2 changed files with 328 additions and 3 deletions.
99 changes: 98 additions & 1 deletion python/ray/data/_internal/pandas_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import heapq
import logging
import sys
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -41,6 +43,10 @@
from ray.data.aggregate import AggregateFn

T = TypeVar("T")
# Max number of samples used to estimate the Pandas block size.
_PANDAS_SIZE_BYTES_MAX_SAMPLE_COUNT = 50

logger = logging.getLogger(__name__)

_pandas = None

Expand Down Expand Up @@ -294,7 +300,98 @@ def num_rows(self) -> int:
return self._table.shape[0]

def size_bytes(self) -> int:
return int(self._table.memory_usage(index=True, deep=True).sum())
from pandas.api.types import is_object_dtype

from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.data.extensions import TensorArrayElement, TensorDtype

pd = lazy_import_pandas()

def get_deep_size(obj):
"""Calculates the memory size of objects,
including nested objects using an iterative approach."""
seen = set()
total_size = 0
objects = collections.deque([obj])
while objects:
current = objects.pop()

# Skip interning-eligible immutable objects
if isinstance(current, (str, bytes, int, float)):
size = sys.getsizeof(current)
total_size += size
continue

# Check if the object has been seen before
# i.e. a = np.ndarray([1,2,3]), b = [a,a]
# The patten above will have only one memory copy
if id(current) in seen:
continue
seen.add(id(current))

try:
size = sys.getsizeof(current)
except TypeError:
size = 0
total_size += size

# Handle specific cases
if isinstance(current, np.ndarray):
total_size += current.nbytes - size # Avoid double counting
elif isinstance(current, pd.DataFrame):
total_size += (
current.memory_usage(index=True, deep=True).sum() - size
)
elif isinstance(current, (list, tuple, set)):
objects.extend(current)
elif isinstance(current, dict):
objects.extend(current.keys())
objects.extend(current.values())
elif isinstance(current, TensorArrayElement):
objects.extend(current.to_numpy())
return total_size

# Get initial memory usage including deep introspection
memory_usage = self._table.memory_usage(index=True, deep=True)

# TensorDtype for ray.air.util.tensor_extensions.pandas.TensorDtype
object_need_check = (TensorDtype,)
max_sample_count = _PANDAS_SIZE_BYTES_MAX_SAMPLE_COUNT

# Handle object columns separately
for column in self._table.columns:
# Check pandas object dtype and the extension dtype
if is_object_dtype(self._table[column].dtype) or isinstance(
self._table[column].dtype, object_need_check
):
total_size = len(self._table[column])

# Determine the sample size based on max_sample_count
sample_size = min(total_size, max_sample_count)
# Following codes can also handel case that sample_size == total_size
sampled_data = self._table[column].sample(n=sample_size).values

try:
if isinstance(sampled_data, TensorArray) and np.issubdtype(
sampled_data[0].numpy_dtype, np.number
):
column_memory_sample = sampled_data.nbytes
else:
vectorized_size_calc = np.vectorize(lambda x: get_deep_size(x))
column_memory_sample = np.sum(
vectorized_size_calc(sampled_data)
)
# Scale back to the full column size if we sampled
column_memory = column_memory_sample * (total_size / sample_size)
memory_usage[column] = int(column_memory)
except Exception as e:
# Handle or log the exception as needed
logger.warning(f"Error calculating size for column '{column}': {e}")

# Sum up total memory usage
total_memory_usage = memory_usage.sum()

return int(total_memory_usage)

def _zip(self, acc: BlockAccessor) -> "pandas.DataFrame":
r = self.to_pandas().copy(deep=False)
Expand Down
232 changes: 230 additions & 2 deletions python/ray/data/tests/test_pandas_block.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import pickle
import random
import sys

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import ray
import ray.data
from ray.data._internal.pandas_block import PandasBlockAccessor
from ray.data.extensions.object_extension import _object_extension_type_allowed

# Set seed for the test for size as it related to sampling
np.random.seed(42)


def test_append_column(ray_start_regular_shared):
animals = ["Flamingo", "Centipede"]
Expand Down Expand Up @@ -48,7 +57,226 @@ def fn2(batch):
assert isinstance(block, pd.DataFrame)


if __name__ == "__main__":
import sys
class TestSizeBytes:
def test_small(ray_start_regular_shared):
animals = ["Flamingo", "Centipede"]
block = pd.DataFrame({"animals": animals})

block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

# check that memory usage is within 10% of the size_bytes
# For strings, Pandas seems to be fairly accurate, so let's use that.
memory_usage = block.memory_usage(index=True, deep=True).sum()
assert bytes_size == pytest.approx(memory_usage, rel=0.1), (
bytes_size,
memory_usage,
)

def test_large_str(ray_start_regular_shared):
animals = [
random.choice(["alligator", "crocodile", "centipede", "flamingo"])
for i in range(100_000)
]
block = pd.DataFrame({"animals": animals})
block["animals"] = block["animals"].astype("string")

block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

memory_usage = block.memory_usage(index=True, deep=True).sum()
assert bytes_size == pytest.approx(memory_usage, rel=0.1), (
bytes_size,
memory_usage,
)

def test_large_str_object(ray_start_regular_shared):
"""Note - this test breaks if you refactor/move the list of animals."""
num = 100_000
animals = [
random.choice(["alligator", "crocodile", "centipede", "flamingo"])
for i in range(num)
]
block = pd.DataFrame({"animals": animals})

block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

memory_usage = sum([sys.getsizeof(animal) for animal in animals])

assert bytes_size == pytest.approx(memory_usage, rel=0.1), (
bytes_size,
memory_usage,
)

def test_large_floats(ray_start_regular_shared):
animals = [random.random() for i in range(100_000)]
block = pd.DataFrame({"animals": animals})

block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

memory_usage = pickle.dumps(block).__sizeof__()
# check that memory usage is within 10% of the size_bytes
assert bytes_size == pytest.approx(memory_usage, rel=0.1), (
bytes_size,
memory_usage,
)

def test_bytes_object(ray_start_regular_shared):
def generate_data(batch):
for _ in range(8):
yield {"data": [[b"\x00" * 128 * 1024 * 128]]}

ds = (
ray.data.range(1, override_num_blocks=1)
.map_batches(generate_data, batch_size=1)
.map_batches(lambda batch: batch, batch_format="pandas")
)

true_value = 128 * 1024 * 128 * 8
for bundle in ds.iter_internal_ref_bundles():
size = bundle.size_bytes()
# assert that true_value is within 10% of bundle.size_bytes()
assert size == pytest.approx(true_value, rel=0.1), (
size,
true_value,
)

def test_nested_numpy(ray_start_regular_shared):
size = 1024
rows = 1_000
data = [
np.random.randint(size=size, low=0, high=100, dtype=np.int8)
for _ in range(rows)
]
df = pd.DataFrame({"data": data})

block_accessor = PandasBlockAccessor.for_block(df)
block_size = block_accessor.size_bytes()
true_value = rows * size
assert block_size == pytest.approx(true_value, rel=0.1), (
block_size,
true_value,
)

def test_nested_objects(ray_start_regular_shared):
size = 10
rows = 10_000
lists = [[random.randint(0, 100) for _ in range(size)] for _ in range(rows)]
data = {"lists": lists}
block = pd.DataFrame(data)

block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

# List overhead + 10 integers per list
true_size = rows * (
sys.getsizeof([random.randint(0, 100) for _ in range(size)]) + size * 28
)

assert bytes_size == pytest.approx(true_size, rel=0.1), (
bytes_size,
true_size,
)

def test_mixed_types(ray_start_regular_shared):
rows = 10_000

data = {
"integers": [random.randint(0, 100) for _ in range(rows)],
"floats": [random.random() for _ in range(rows)],
"strings": [
random.choice(["apple", "banana", "cherry"]) for _ in range(rows)
],
"object": [b"\x00" * 128 for _ in range(rows)],
}
block = pd.DataFrame(data)
block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

# Manually calculate the size
int_size = rows * 8
float_size = rows * 8
str_size = sum(sys.getsizeof(string) for string in data["strings"])
object_size = rows * sys.getsizeof(b"\x00" * 128)

true_size = int_size + float_size + str_size + object_size
assert bytes_size == pytest.approx(true_size, rel=0.1), (bytes_size, true_size)

def test_nested_lists_strings(ray_start_regular_shared):
rows = 5_000
nested_lists = ["a"] * 3 + ["bb"] * 4 + ["ccc"] * 3
data = {
"nested_lists": [nested_lists for _ in range(rows)],
}
block = pd.DataFrame(data)
block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

# Manually calculate the size
list_overhead = sys.getsizeof(block["nested_lists"].iloc[0]) + sum(
[sys.getsizeof(x) for x in nested_lists]
)
true_size = rows * list_overhead
assert bytes_size == pytest.approx(true_size, rel=0.1), (bytes_size, true_size)

@pytest.mark.parametrize("size", [10, 1024])
def test_multi_level_nesting(ray_start_regular_shared, size):
rows = 1_000
data = {
"complex": [
{"list": [np.random.rand(size)], "value": {"key": "val"}}
for _ in range(rows)
],
}
block = pd.DataFrame(data)
block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

numpy_size = np.random.rand(size).nbytes

values = ["list", "value", "key", "val"]
str_size = sum([sys.getsizeof(v) for v in values])

list_ref_overhead = sys.getsizeof([np.random.rand(size)])

dict_overhead1 = sys.getsizeof({"key": "val"})

dict_overhead3 = sys.getsizeof(
{"list": [np.random.rand(size)], "value": {"key": "val"}}
)

true_size = (
numpy_size + str_size + list_ref_overhead + dict_overhead1 + dict_overhead3
) * rows
assert bytes_size == pytest.approx(true_size, rel=0.15), (
bytes_size,
true_size,
)

def test_boolean(ray_start_regular_shared):
data = [random.choice([True, False, None]) for _ in range(100_000)]
block = pd.DataFrame({"flags": pd.Series(data, dtype="boolean")})
block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

# No object case
true_size = block.memory_usage(index=True, deep=True).sum()
assert bytes_size == pytest.approx(true_size, rel=0.1), (bytes_size, true_size)

def test_arrow(ray_start_regular_shared):
data = [
random.choice(["alligator", "crocodile", "flamingo"]) for _ in range(50_000)
]
arrow_dtype = pd.ArrowDtype(pa.string())
block = pd.DataFrame({"animals": pd.Series(data, dtype=arrow_dtype)})
block_accessor = PandasBlockAccessor.for_block(block)
bytes_size = block_accessor.size_bytes()

true_size = block.memory_usage(index=True, deep=True).sum()
assert bytes_size == pytest.approx(true_size, rel=0.1), (bytes_size, true_size)


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))

0 comments on commit 1716681

Please sign in to comment.