Skip to content

Commit

Permalink
fix: Column selection wasn't applied when reading CSV with no rows (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored and Wouittone committed Jun 22, 2024
1 parent 2115b2f commit 9e9fb71
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
38 changes: 26 additions & 12 deletions crates/polars-io/src/csv/read/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,18 @@ pub(crate) fn cast_columns(
};

if parallel {
let cols = df
.get_columns()
.iter()
.map(|s| {
if let Some(fld) = to_cast.iter().find(|fld| fld.name().as_str() == s.name()) {
cast_fn(s, fld)
} else {
Ok(s.clone())
}
})
.collect::<PolarsResult<Vec<_>>>()?;
let cols = POOL.install(|| {
df.get_columns()
.into_par_iter()
.map(|s| {
if let Some(fld) = to_cast.iter().find(|fld| fld.name().as_str() == s.name()) {
cast_fn(s, fld)
} else {
Ok(s.clone())
}
})
.collect::<PolarsResult<Vec<_>>>()
})?;
*df = unsafe { DataFrame::new_no_checks(cols) }
} else {
// cast to the original dtypes in the schema
Expand Down Expand Up @@ -473,7 +474,20 @@ impl<'a> CoreReader<'a> {

// An empty file with a schema should return an empty DataFrame with that schema
if bytes.is_empty() {
let mut df = DataFrame::from(self.schema.as_ref());
let mut df = if projection.len() == self.schema.len() {
DataFrame::from(self.schema.as_ref())
} else {
DataFrame::from(
&projection
.iter()
.map(|&i| self.schema.get_at_index(i).unwrap())
.map(|(name, dtype)| Field {
name: name.clone(),
dtype: dtype.clone(),
})
.collect::<Schema>(),
)
};
if let Some(ref row_index) = self.row_index {
df.insert_column(0, Series::new_empty(&row_index.name, &IDX_DTYPE))?;
}
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2160,3 +2160,24 @@ def test_read_csv_dtypes_deprecated() -> None:
schema={"a": pl.Int8, "b": pl.Int8, "c": pl.Int8},
)
assert_frame_equal(df, expected)


def test_projection_applied_on_file_with_no_rows_16606(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

path = tmp_path / "data.csv"

data = """\
a,b,c,d
"""

with path.open("w") as f:
f.write(data)

columns = ["a", "b"]

out = pl.read_csv(path, columns=columns).columns
assert out == columns

out = pl.scan_csv(path).select(columns).collect().columns
assert out == columns

0 comments on commit 9e9fb71

Please sign in to comment.