diff --git a/src/dsl/functions/image/decode.rs b/src/dsl/functions/image/decode.rs
index 873e7cf1f5..b3e1d7b960 100644
--- a/src/dsl/functions/image/decode.rs
+++ b/src/dsl/functions/image/decode.rs
@@ -1,4 +1,10 @@
-use crate::{datatypes::Field, dsl::Expr, error::DaftResult, schema::Schema, series::Series};
+use crate::{
+    datatypes::{DataType, Field},
+    dsl::Expr,
+    error::{DaftError, DaftResult},
+    schema::Schema,
+    series::Series,
+};
 
 use super::super::FunctionEvaluator;
 
@@ -9,11 +15,42 @@ impl FunctionEvaluator for DecodeEvaluator {
         "decode"
     }
 
-    fn to_field(&self, _: &[Expr], _: &Schema) -> DaftResult<Field> {
-        todo!("not implemented");
+    /// Resolves the output field: a single `Binary` input becomes an
+    /// `Image(UInt8, None)` field that keeps the input's name.
+    ///
+    /// Returns `TypeError` for a non-binary input and `SchemaMismatch` when
+    /// the number of inputs is not exactly one.
+    fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult<Field> {
+        match inputs {
+            [input] => {
+                let field = input.to_field(schema)?;
+                if !matches!(field.dtype, DataType::Binary) {
+                    return Err(DaftError::TypeError(format!(
+                        "ImageDecode can only decode BinaryArrays, got {}",
+                        field
+                    )));
+                }
+                Ok(Field::new(
+                    field.name,
+                    DataType::Image(Box::new(DataType::UInt8), None),
+                ))
+            }
+            _ => Err(DaftError::SchemaMismatch(format!(
+                "Expected 1 input arg, got {}",
+                inputs.len()
+            ))),
+        }
     }
 
-    fn evaluate(&self, _: &[Series], _: &Expr) -> DaftResult<Series> {
-        todo!("not implemented");
+    /// Decodes the single input series via `image_decode`; any other number
+    /// of inputs yields a `ValueError`.
+    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
+        match inputs {
+            [input] => input.image_decode(),
+            _ => Err(DaftError::ValueError(format!(
+                "Expected 1 input arg, got {}",
+                inputs.len()
+            ))),
+        }
     }
 }
diff --git a/src/dsl/functions/image/resize.rs b/src/dsl/functions/image/resize.rs
index c2f40de73d..c9717928c5 100644
--- a/src/dsl/functions/image/resize.rs
+++ b/src/dsl/functions/image/resize.rs
@@ -1,3 +1,4 @@
+use crate::datatypes::DataType;
 use crate::dsl::functions::image::ImageExpr;
 use crate::error::DaftError;
 use crate::{datatypes::Field, dsl::Expr, error::DaftResult, schema::Schema, series::Series};
@@ -13,8 +14,29 @@ impl FunctionEvaluator for ResizeEvaluator {
         "resize"
     }
 
-    fn to_field(&self, _: &[Expr], _: &Schema) -> DaftResult<Field> {
-        todo!("not implemented");
+    /// Resolves the output field: a single `Image` input is passed through
+    /// unchanged (same name and dtype).
+    ///
+    /// Returns `TypeError` for a non-image input and `SchemaMismatch` when
+    /// the number of inputs is not exactly one.
+    fn to_field(&self, inputs: &[Expr], schema: &Schema) -> DaftResult<Field> {
+        match inputs {
+            [input] => {
+                let field = input.to_field(schema)?;
+
+                match &field.dtype {
+                    DataType::Image(_, _) => Ok(field),
+                    _ => Err(DaftError::TypeError(format!(
+                        "ImageResize can only resize ImageArrays, got {}",
+                        field
+                    ))),
+                }
+            }
+            _ => Err(DaftError::SchemaMismatch(format!(
+                "Expected 1 input arg, got {}",
+                inputs.len()
+            ))),
+        }
     }
 
     fn evaluate(&self, inputs: &[Series], expr: &Expr) -> DaftResult<Series> {
diff --git a/tests/cookbook/assets/images/0000.jpg b/tests/cookbook/assets/images/0000.jpg
new file mode 100644
index 0000000000..3a7fff270f
Binary files /dev/null and b/tests/cookbook/assets/images/0000.jpg differ
diff --git a/tests/cookbook/assets/images/0007.jpg b/tests/cookbook/assets/images/0007.jpg
new file mode 100644
index 0000000000..2f917a56fb
Binary files /dev/null and b/tests/cookbook/assets/images/0007.jpg differ
diff --git a/tests/cookbook/assets/images/0018.png b/tests/cookbook/assets/images/0018.png
new file mode 100644
index 0000000000..f2eb8c46b8
Binary files /dev/null and b/tests/cookbook/assets/images/0018.png differ
diff --git a/tests/cookbook/assets/images/0025.tiff b/tests/cookbook/assets/images/0025.tiff
new file mode 100644
index 0000000000..27c322f784
Binary files /dev/null and b/tests/cookbook/assets/images/0025.tiff differ
diff --git a/tests/cookbook/test_image.py b/tests/cookbook/test_image.py
new file mode 100644
index 0000000000..c96c4f3733
--- /dev/null
+++ b/tests/cookbook/test_image.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import numpy as np
+from PIL import Image
+
+import daft
+from daft import col
+from daft.datatype import DataType
+from daft.series import Series
+from tests.cookbook.assets import ASSET_FOLDER
+
+
+def test_image_resize_mixed_modes() -> None:
+    """Resize RGB, RGBA, L and LA arrays plus a null to 5x5; check dtype and pixel values."""
+    rgba = np.ones((2, 2, 4), dtype=np.uint8)
+    rgba[..., 1] = 2
+    rgba[..., 2] = 3
+    rgba[..., 3] = 4
+
+    data = [
+        rgba[..., :3],  # RGB
+        rgba,  # RGBA
+        np.arange(12, dtype=np.uint8).reshape((1, 4, 3)),  # RGB
+        np.arange(12, dtype=np.uint8).reshape((3, 4)) * 10,  # L
+        np.ones(24, dtype=np.uint8).reshape((3, 4, 2)) * 10,  # LA
+        None,
+    ]
+
+    s = Series.from_pylist(data, pyobj="force")
+    df = daft.from_pydict({"img": s})
+
+    target_dtype = DataType.image()
+    df = df.select(df["img"].cast(target_dtype))
+
+    assert df.schema()["img"].dtype == target_dtype
+
+    df = df.with_column("resized", df["img"].image.resize(5, 5))
+
+    assert df.schema()["resized"].dtype == target_dtype
+
+    as_py = df.to_pydict()["resized"]
+
+    first_resized = np.array(as_py[0]["data"]).reshape(5, 5, 3)
+    assert np.all(first_resized[..., 0] == 1)
+    assert np.all(first_resized[..., 1] == 2)
+    assert np.all(first_resized[..., 2] == 3)
+
+    second_resized = np.array(as_py[1]["data"]).reshape(5, 5, 4)
+    assert np.all(second_resized[..., 0] == 1)
+    assert np.all(second_resized[..., 1] == 2)
+    assert np.all(second_resized[..., 2] == 3)
+    assert np.all(second_resized[..., 3] == 4)
+
+    for i in range(2, 4):
+        resized_i = np.array(as_py[i]["data"]).reshape(5, 5, -1)
+        resized_i_gt = np.asarray(Image.fromarray(data[i]).resize((5, 5), resample=Image.BILINEAR)).reshape(5, 5, -1)
+        assert np.all(resized_i == resized_i_gt), f"{i} does not match"
+
+    # LA sampling doesn't work for some reason in PIL
+    resized_i = np.array(as_py[4]["data"]).reshape(5, 5, -1)
+    assert np.all(resized_i == 10)
+
+    assert as_py[-1] is None
+
+
+def test_image_decode() -> None:
+    """Download, decode and resize the bundled asset images; check the resulting column dtype."""
+    df = (
+        daft.from_glob_path(f"{ASSET_FOLDER}/images/**")
+        .into_partitions(2)
+        .with_column("image", col("path").url.download().image.decode().image.resize(10, 10))
+    )
+    target_dtype = DataType.image()
+    assert df.schema()["image"].dtype == target_dtype
+    df.collect()