From a6a9bc8f5f296bef0cd466c658c094d5d8bad0e1 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 7 Jun 2024 16:12:51 +0200 Subject: [PATCH] fix: Raise on non-positive json schema inference (#16770) --- crates/polars-io/src/json/infer.rs | 12 +++++++----- crates/polars-io/src/json/mod.rs | 11 ++++++----- crates/polars-io/src/ndjson/core.rs | 8 ++++---- crates/polars-io/src/ndjson/mod.rs | 4 +++- crates/polars-json/src/ndjson/file.rs | 5 +++-- .../src/physical_plan/executors/scan/ndjson.rs | 5 ++++- crates/polars/tests/it/io/json.rs | 8 ++++---- py-polars/polars/io/ndjson.py | 3 +++ py-polars/src/dataframe/io.rs | 7 +++---- py-polars/tests/unit/io/test_lazy_json.py | 5 +++++ 10 files changed, 42 insertions(+), 26 deletions(-) diff --git a/crates/polars-io/src/json/infer.rs b/crates/polars-io/src/json/infer.rs index 83026b6111c99..9cd82721d156c 100644 --- a/crates/polars-io/src/json/infer.rs +++ b/crates/polars-io/src/json/infer.rs @@ -1,23 +1,25 @@ +use std::num::NonZeroUsize; + use polars_core::prelude::DataType; use polars_core::utils::try_get_supertype; -use polars_error::PolarsResult; +use polars_error::{polars_bail, PolarsResult}; use simd_json::BorrowedValue; pub(crate) fn json_values_to_supertype( values: &[BorrowedValue], - infer_schema_len: usize, + infer_schema_len: NonZeroUsize, ) -> PolarsResult { // struct types may have missing fields so find supertype values .iter() - .take(infer_schema_len) + .take(infer_schema_len.into()) .map(|value| polars_json::json::infer(value).map(|dt| DataType::from(&dt))) .reduce(|l, r| { let l = l?; let r = r?; try_get_supertype(&l, &r) }) - .unwrap() + .unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type")) } pub(crate) fn data_types_to_supertype>( @@ -30,5 +32,5 @@ pub(crate) fn data_types_to_supertype>( let r = r?; try_get_supertype(&l, &r) }) - .unwrap() + .unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type")) } diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index 2e7726d7614a9..4653e4f692f83 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -25,7 +25,7 @@ //! let file = Cursor::new(basic_json); //! let df = JsonReader::new(file) //! .with_json_format(JsonFormat::JsonLines) -//! .infer_schema_len(Some(3)) +//! .infer_schema_len(NonZeroUsize::new(3)) //! .with_batch_size(NonZeroUsize::new(3).unwrap()) //! .finish() //! .unwrap(); @@ -206,7 +206,7 @@ where reader: R, rechunk: bool, ignore_errors: bool, - infer_schema_len: Option, + infer_schema_len: Option, batch_size: NonZeroUsize, projection: Option>, schema: Option, @@ -223,7 +223,7 @@ where reader, rechunk: true, ignore_errors: false, - infer_schema_len: Some(100), + infer_schema_len: Some(NonZeroUsize::new(100).unwrap()), batch_size: NonZeroUsize::new(8192).unwrap(), projection: None, schema: None, @@ -265,7 +265,8 @@ where let inner_dtype = if let BorrowedValue::Array(values) = &json_value { infer::json_values_to_supertype( values, - self.infer_schema_len.unwrap_or(usize::MAX), + self.infer_schema_len + .unwrap_or(NonZeroUsize::new(usize::MAX).unwrap()), )? .to_arrow(true) } else { @@ -360,7 +361,7 @@ where /// /// It is an error to pass `max_records = Some(0)`, as a schema cannot be inferred from 0 records when deserializing /// from JSON (unlike CSVs, there is no header row to inspect for column names). - pub fn infer_schema_len(mut self, max_records: Option) -> Self { + pub fn infer_schema_len(mut self, max_records: Option) -> Self { self.infer_schema_len = max_records; self } diff --git a/crates/polars-io/src/ndjson/core.rs b/crates/polars-io/src/ndjson/core.rs index afc1c79d6295e..30eaeeab517c3 100644 --- a/crates/polars-io/src/ndjson/core.rs +++ b/crates/polars-io/src/ndjson/core.rs @@ -26,7 +26,7 @@ where rechunk: bool, n_rows: Option, n_threads: Option, - infer_schema_len: Option, + infer_schema_len: Option, chunk_size: NonZeroUsize, schema: Option, schema_overwrite: Option<&'a Schema>, @@ -58,7 +58,7 @@ where self } - pub fn infer_schema_len(mut self, infer_schema_len: Option) -> Self { + pub fn infer_schema_len(mut self, infer_schema_len: Option) -> Self { self.infer_schema_len = infer_schema_len; self } @@ -112,7 +112,7 @@ where rechunk: true, n_rows: None, n_threads: None, - infer_schema_len: Some(128), + infer_schema_len: Some(NonZeroUsize::new(100).unwrap()), schema: None, schema_overwrite: None, path: None, @@ -166,7 +166,7 @@ impl<'a> CoreJsonReader<'a> { sample_size: usize, chunk_size: NonZeroUsize, low_memory: bool, - infer_schema_len: Option, + infer_schema_len: Option, ignore_errors: bool, ) -> PolarsResult> { let reader_bytes = reader_bytes; diff --git a/crates/polars-io/src/ndjson/mod.rs b/crates/polars-io/src/ndjson/mod.rs index 3fb432929f457..8d6ad6c2d6802 100644 --- a/crates/polars-io/src/ndjson/mod.rs +++ b/crates/polars-io/src/ndjson/mod.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroUsize; + use arrow::array::StructArray; use polars_core::prelude::*; @@ -6,7 +8,7 @@ pub mod core; pub fn infer_schema( reader: &mut R, - infer_schema_len: Option, + infer_schema_len: Option, ) -> PolarsResult { let data_types = polars_json::ndjson::iter_unique_dtypes(reader, infer_schema_len)?; let data_type = diff --git a/crates/polars-json/src/ndjson/file.rs b/crates/polars-json/src/ndjson/file.rs index 35700c1a6001d..3bc2e126fb850 100644 --- a/crates/polars-json/src/ndjson/file.rs +++ b/crates/polars-json/src/ndjson/file.rs @@ -1,4 +1,5 @@ use std::io::BufRead; +use std::num::NonZeroUsize; use arrow::datatypes::ArrowDataType; use fallible_streaming_iterator::FallibleStreamingIterator; @@ -100,7 +101,7 @@ fn parse_value<'a>(scratch: &'a mut Vec, val: &[u8]) -> PolarsResult( reader: &mut R, - number_of_rows: Option, + number_of_rows: Option, ) -> PolarsResult> { if reader.fill_buf().map(|b| b.is_empty())? { return Err(PolarsError::ComputeError( @@ -109,7 +110,7 @@ pub fn iter_unique_dtypes( } let rows = vec!["".to_string(); 1]; // 1 <=> read row by row - let mut reader = FileReader::new(reader, rows, number_of_rows); + let mut reader = FileReader::new(reader, rows, number_of_rows.map(|v| v.into())); let mut data_types = PlIndexSet::default(); let mut buf = vec![]; diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs index 40b66fef3ae64..a5a6257ba8ba8 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroUsize; + use super::*; impl AnonymousScan for LazyJsonLineReader { @@ -17,6 +19,7 @@ impl AnonymousScan for LazyJsonLineReader { } fn schema(&self, infer_schema_length: Option) -> PolarsResult { + polars_ensure!(infer_schema_length != Some(0), InvalidOperation: "JSON requires positive 'infer_schema_length'"); // Short-circuit schema inference if the schema has been explicitly provided, // or already inferred if let Some(schema) = &(*self.schema.read().unwrap()) { @@ -28,7 +31,7 @@ impl AnonymousScan for LazyJsonLineReader { let schema = Arc::new(polars_io::ndjson::infer_schema( &mut reader, - infer_schema_length, + infer_schema_length.and_then(NonZeroUsize::new), )?); let mut guard = self.schema.write().unwrap(); *guard = Some(schema.clone()); diff --git a/crates/polars/tests/it/io/json.rs b/crates/polars/tests/it/io/json.rs index 92f7aa233df73..faf17d71d07e4 100644 --- a/crates/polars/tests/it/io/json.rs +++ b/crates/polars/tests/it/io/json.rs @@ -20,7 +20,7 @@ fn read_json() { "#; let file = Cursor::new(basic_json); let df = JsonReader::new(file) - .infer_schema_len(Some(3)) + .infer_schema_len(NonZeroUsize::new(3)) .with_json_format(JsonFormat::JsonLines) .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() @@ -48,7 +48,7 @@ fn read_json_with_whitespace() { { "a":100000000000000, "b":0.6, "c":false, "d":"text"}"#; let file = Cursor::new(basic_json); let df = JsonReader::new(file) - .infer_schema_len(Some(3)) + .infer_schema_len(NonZeroUsize::new(3)) .with_json_format(JsonFormat::JsonLines) .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() @@ -73,7 +73,7 @@ fn read_json_with_escapes() { "#; let file = Cursor::new(escaped_json); let df = JsonLineReader::new(file) - .infer_schema_len(Some(6)) + .infer_schema_len(NonZeroUsize::new(6)) .finish() .unwrap(); assert_eq!("id", df.get_columns()[0].name()); @@ -102,7 +102,7 @@ fn read_unordered_json() { "#; let file = Cursor::new(unordered_json); let df = JsonReader::new(file) - .infer_schema_len(Some(3)) + .infer_schema_len(NonZeroUsize::new(3)) .with_json_format(JsonFormat::JsonLines) .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index 2f08d5d119fbd..d129a9da8e52e 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -143,6 +143,9 @@ def scan_ndjson( else: sources = [normalize_filepath(source) for source in source] source = None # type: ignore[assignment] + if infer_schema_length == 0: + msg = "'infer_schema_length' should be positive" + raise ValueError(msg) pylf = PyLazyFrame.new_from_ndjson( source, diff --git a/py-polars/src/dataframe/io.rs b/py-polars/src/dataframe/io.rs index 58a839caa2f26..783274880bcbc 100644 --- a/py-polars/src/dataframe/io.rs +++ b/py-polars/src/dataframe/io.rs @@ -215,6 +215,7 @@ impl PyDataFrame { schema: Option>, schema_overrides: Option>, ) -> PyResult { + assert!(infer_schema_length != Some(0)); use crate::file::read_if_bytesio; py_f = read_if_bytesio(py_f); let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; @@ -222,7 +223,7 @@ impl PyDataFrame { py.allow_threads(move || { let mut builder = JsonReader::new(mmap_bytes_r) .with_json_format(JsonFormat::Json) - .infer_schema_len(infer_schema_length); + .infer_schema_len(infer_schema_length.and_then(NonZeroUsize::new)); if let Some(schema) = schema { builder = builder.with_schema(Arc::new(schema.0)); @@ -232,9 +233,7 @@ impl PyDataFrame { builder = builder.with_schema_overwrite(&schema.0); } - let out = builder - .finish() - .map_err(|e| PyPolarsErr::Other(format!("{e}")))?; + let out = builder.finish().map_err(PyPolarsErr::from)?; Ok(out.into()) }) } diff --git a/py-polars/tests/unit/io/test_lazy_json.py b/py-polars/tests/unit/io/test_lazy_json.py index 97e32f3eaee64..0d2dbf6be0e54 100644 --- a/py-polars/tests/unit/io/test_lazy_json.py +++ b/py-polars/tests/unit/io/test_lazy_json.py @@ -56,6 +56,11 @@ def test_scan_ndjson_with_schema(foods_ndjson_path: Path) -> None: assert df["sugars_g"].dtype == pl.Float64 +def test_scan_ndjson_infer_0(foods_ndjson_path: Path) -> None: + with pytest.raises(ValueError): + pl.scan_ndjson(foods_ndjson_path, infer_schema_length=0) + + def test_scan_ndjson_batch_size_zero() -> None: with pytest.raises(ValueError, match="invalid zero value"): pl.scan_ndjson("test.ndjson", batch_size=0)