diff --git a/pyo3-arrow/src/table.rs b/pyo3-arrow/src/table.rs index 3c00612..dd9cc52 100644 --- a/pyo3-arrow/src/table.rs +++ b/pyo3-arrow/src/table.rs @@ -147,13 +147,9 @@ impl PyTable { mapping: HashMap, schema: Option, metadata: Option, - ) -> PyResult { + ) -> PyArrowResult { let (names, arrays): (Vec<_>, Vec<_>) = mapping.into_iter().unzip(); Self::from_arrays(cls, arrays, Some(names), schema, metadata) - // TODO: Construct record batches from Vec - // Can I reuse from_pylist here? I.e. this func only unwraps the dict to a list of column anmes and a list of chunked arrays, and then that passes in to from_arrays - // I probably want a helper to rechunk as necessary - // todo!() } #[classmethod] @@ -164,27 +160,65 @@ impl PyTable { names: Option>, schema: Option, metadata: Option, - ) -> PyResult { + ) -> PyArrowResult { let columns = arrays .into_iter() .map(|array| array.into_chunked_array()) .collect::>>()?; - // let schema = schema.map(|schema| schema.into_inner()).unwrap_or_else(|| { - // let fields = columns - // .iter() - // .zip(names.iter()) - // .map(|(array, name)| { - // Field::new(name.clone(), array.field().data_type().clone(), true) - // }) - // .collect::>(); - // Arc::new( - // Schema::new(fields) - // .with_metadata(metadata.unwrap_or_default().into_string_hashmap().unwrap()), - // ) - // }); - - todo!() + let schema: SchemaRef = if let Some(schema) = schema { + schema.into_inner() + } else { + let names = names.ok_or(PyValueError::new_err( + "names must be passed if schema is not passed.", + ))?; + + let fields = columns + .iter() + .zip(names.iter()) + .map(|(array, name)| Field::new(name.clone(), array.data_type().clone(), true)) + .collect::>(); + Arc::new( + Schema::new(fields) + .with_metadata(metadata.unwrap_or_default().into_string_hashmap().unwrap()), + ) + }; + + if columns.is_empty() { + return Ok(Self::new(vec![], schema)); + } + + let column_chunk_lengths = columns + .iter() + .map(|column| { + let chunk_lengths = column + .chunks() + .iter() + .map(|chunk| chunk.len()) + .collect::>(); + chunk_lengths + }) + .collect::>(); + if !column_chunk_lengths.windows(2).all(|w| w[0] == w[1]) { + return Err( + PyValueError::new_err("All columns must have the same chunk lengths").into(), + ); + } + let num_batches = column_chunk_lengths[0].len(); + + let mut batches = vec![]; + for batch_idx in 0..num_batches { + let batch = RecordBatch::try_new( + schema.clone(), + columns + .iter() + .map(|column| column.chunks()[batch_idx].clone()) + .collect(), + )?; + batches.push(batch); + } + + Ok(Self::new(batches, schema)) } pub fn add_column(