Commit

nameexhaustion committed Jul 11, 2024
1 parent daf2e49 commit 2c70546
Showing 3 changed files with 136 additions and 83 deletions.
215 changes: 133 additions & 82 deletions crates/polars-io/src/partition.rs
@@ -129,6 +129,7 @@ where
     path
 }

+/// Write a partitioned parquet dataset. This functionality is unstable.
 pub fn write_partitioned_dataset<S>(
     df: &DataFrame,
     path: &Path,
@@ -139,100 +140,150 @@ pub fn write_partitioned_dataset<S>
 where
     S: AsRef<str>,
 {
-    let base_path = path;
+    // Note: When adding support for formats other than Parquet, avoid writing the partitioned
+    // columns into the file. We write them for parquet because they are encoded efficiently with
+    // RLE and also gives us a way to get the hive schema from the parquet file for free.
+    let get_hive_path_part = {
+        let schema = &df.schema();

-    for (path_part, part_df) in get_hive_partitions_iter(df, partition_by)? {
-        let dir = base_path.join(path_part);
-        std::fs::create_dir_all(&dir)?;
+        let partition_by_col_idx = partition_by
+            .iter()
+            .map(|x| {
+                let Some(i) = schema.index_of(x.as_ref()) else {
+                    polars_bail!(ColumnNotFound: "{}", x.as_ref())
+                };
+                Ok(i)
+            })
+            .collect::<PolarsResult<Vec<_>>>()?;

-        let n_files = (part_df.estimated_size() / chunk_size).clamp(1, 0xf_ffff_ffff_ffff);
-        let rows_per_file = (df.height() / n_files).saturating_add(1);
+        const CHAR_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
+            .add(b'/')
+            .add(b'=')
+            .add(b':')
+            .add(b' ');

-        fn get_path_for_index(i: usize) -> String {
-            // Use a fixed-width file name so that it sorts properly.
-            format!("{:013x}.parquet", i)
+        move |df: &DataFrame| {
+            let cols = df.get_columns();
+
+            partition_by_col_idx
+                .iter()
+                .map(|&i| {
+                    let s = &cols[i].slice(0, 1).cast(&DataType::String).unwrap();
+
+                    format!(
+                        "{}={}",
+                        s.name(),
+                        percent_encoding::percent_encode(
+                            s.str()
+                                .unwrap()
+                                .get(0)
+                                .unwrap_or("__HIVE_DEFAULT_PARTITION__")
+                                .as_bytes(),
+                            CHAR_SET
+                        )
+                    )
+                })
+                .collect::<Vec<_>>()
+                .join("/")
         }
+    };

-        for (i, slice_start) in (0..part_df.height()).step_by(rows_per_file).enumerate() {
-            let f = std::fs::File::create(dir.join(get_path_for_index(i)))?;
+    let base_path = path;
+    let groups = df.group_by(partition_by)?.take_groups();

-            file_write_options
-                .to_writer(f)
-                .finish(&mut part_df.slice(slice_start as i64, rows_per_file))?;
-        }
-    }
+    let init_part_base_dir = |part_df: &DataFrame| {
+        let path_part = get_hive_path_part(part_df);
+        let dir = base_path.join(path_part);
+        std::fs::create_dir_all(&dir)?;

-    Ok(())
-}
+        PolarsResult::Ok(dir)
+    };

-/// Creates an iterator of (hive partition path, DataFrame) pairs, e.g.:
-/// ("a=1/b=1", DataFrame)
-fn get_hive_partitions_iter<'a, S>(
-    df: &'a DataFrame,
-    partition_by: &'a [S],
-) -> PolarsResult<Box<dyn Iterator<Item = (String, DataFrame)> + 'a>>
-where
-    S: AsRef<str>,
-{
-    let schema = df.schema();

-    let partition_by_col_idx = partition_by
-        .iter()
-        .map(|x| {
-            let Some(i) = schema.index_of(x.as_ref()) else {
-                polars_bail!(ColumnNotFound: "{}", x.as_ref())
-            };
-            Ok(i)
-        })
-        .collect::<PolarsResult<Vec<_>>>()?;

-    let get_hive_path_part = move |df: &DataFrame| {
-        const CHAR_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
-            .add(b'/')
-            .add(b'=')
-            .add(b':')
-            .add(b' ');
+    fn get_path_for_index(i: usize) -> String {
+        // Use a fixed-width file name so that it sorts properly.
+        format!("{:08x}.parquet", i)
+    }

-        let cols = df.get_columns();
+    let get_n_files_and_rows_per_file = |part_df: &DataFrame| {
+        let n_files = (part_df.estimated_size() / chunk_size).clamp(1, 0xffff_ffff);
+        let rows_per_file = (df.height() / n_files).saturating_add(1);
+        (n_files, rows_per_file)
+    };

-        partition_by_col_idx
-            .iter()
-            .map(|&i| {
-                let s = &cols[i].slice(0, 1).cast(&DataType::String).unwrap();
-
-                format!(
-                    "{}={}",
-                    s.name(),
-                    percent_encoding::percent_encode(
-                        s.str()
-                            .unwrap()
-                            .get(0)
-                            .unwrap_or("__HIVE_DEFAULT_PARTITION__")
-                            .as_bytes(),
-                        CHAR_SET
-                    )
-                )
-            })
-            .collect::<Vec<_>>()
-            .join("/")
+    let write_part = |mut df: DataFrame, path: &Path| {
+        let f = std::fs::File::create(path)?;
+        file_write_options.to_writer(f).finish(&mut df)?;
+        PolarsResult::Ok(())
     };

-    let groups = df.group_by(partition_by)?;
-    let groups = groups.take_groups();

-    let out: Box<dyn Iterator<Item = (String, DataFrame)>> = match groups {
-        GroupsProxy::Idx(idx) => Box::new(idx.into_iter().map(move |(_, group)| {
-            let part_df =
-                unsafe { df._take_unchecked_slice_sorted(&group, false, IsSorted::Ascending) };
-            (get_hive_path_part(&part_df), part_df)
-        })),
-        GroupsProxy::Slice { groups, .. } => {
-            Box::new(groups.into_iter().map(move |[offset, len]| {
-                let part_df = df.slice(offset as i64, len as usize);
-                (get_hive_path_part(&part_df), part_df)
-            }))
-        },
+    // This is sqrt(N) of the actual limit - we chunk the input both at the groups
+    // proxy level and within every group.
+    const MAX_OPEN_FILES: usize = 8;
+
+    let finish_part_df = |df: DataFrame| {
+        let dir_path = init_part_base_dir(&df)?;
+        let (n_files, rows_per_file) = get_n_files_and_rows_per_file(&df);
+
+        if n_files == 1 {
+            write_part(df.clone(), &dir_path.join(get_path_for_index(0)))
+        } else {
+            (0..df.height())
+                .step_by(rows_per_file)
+                .enumerate()
+                .collect::<Vec<_>>()
+                .chunks(MAX_OPEN_FILES)
+                .map(|chunk| {
+                    chunk
+                        .into_par_iter()
+                        .map(|&(idx, slice_start)| {
+                            let df = df.slice(slice_start as i64, rows_per_file);
+                            write_part(df.clone(), &dir_path.join(get_path_for_index(idx)))
+                        })
+                        .reduce(
+                            || PolarsResult::Ok(()),
+                            |a, b| if a.is_err() { a } else { b },
+                        )
+                })
+                .collect::<PolarsResult<Vec<()>>>()?;
+            Ok(())
+        }
     };

-    Ok(out)
+    POOL.install(|| match groups {
+        GroupsProxy::Idx(idx) => idx
+            .all()
+            .chunks(MAX_OPEN_FILES)
+            .map(|chunk| {
+                chunk
+                    .par_iter()
+                    .map(|group| {
+                        let df = unsafe {
+                            df._take_unchecked_slice_sorted(group, false, IsSorted::Ascending)
+                        };
+                        finish_part_df(df)
+                    })
+                    .reduce(
+                        || PolarsResult::Ok(()),
+                        |a, b| if a.is_err() { a } else { b },
+                    )
+            })
+            .collect::<PolarsResult<Vec<()>>>(),
+        GroupsProxy::Slice { groups, .. } => groups
+            .chunks(MAX_OPEN_FILES)
+            .map(|chunk| {
+                chunk
+                    .into_par_iter()
+                    .map(|&[offset, len]| {
+                        let df = df.slice(offset as i64, len as usize);
+                        finish_part_df(df)
+                    })
+                    .reduce(
+                        || PolarsResult::Ok(()),
+                        |a, b| if a.is_err() { a } else { b },
+                    )
+            })
+            .collect::<PolarsResult<Vec<()>>>(),
+    })?;
+
+    Ok(())
 }
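Note: the key=value path segments produced by the new get_hive_path_part closure above can be reproduced with the percent_encoding crate alone. The following is a minimal standalone sketch of that encoding scheme (control bytes plus '/', '=', ':' and ' ' are escaped, and a missing value falls back to __HIVE_DEFAULT_PARTITION__); the hive_path_part helper, column names and values below are made up for the example and are not part of the Polars API.

use percent_encoding::{percent_encode, AsciiSet, CONTROLS};

// Same escape set as in partition.rs: control bytes plus '/', '=', ':' and ' '.
const CHAR_SET: &AsciiSet = &CONTROLS.add(b'/').add(b'=').add(b':').add(b' ');

// Build one "key=value" segment per partition column and join them with '/'.
// `None` stands in for a null partition value.
fn hive_path_part(parts: &[(&str, Option<&str>)]) -> String {
    parts
        .iter()
        .map(|&(name, value)| {
            format!(
                "{}={}",
                name,
                percent_encode(
                    value.unwrap_or("__HIVE_DEFAULT_PARTITION__").as_bytes(),
                    CHAR_SET
                )
            )
        })
        .collect::<Vec<_>>()
        .join("/")
}

fn main() {
    let part = hive_path_part(&[("year", Some("2024")), ("city", Some("New York")), ("note", None)]);
    assert_eq!(part, "year=2024/city=New%20York/note=__HIVE_DEFAULT_PARTITION__");
    println!("{part}");
}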
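Note: the POOL.install / chunks(MAX_OPEN_FILES) structure above is what bounds how many partition files are open at once - partitions (and slices within a partition) are processed in sequential batches of at most MAX_OPEN_FILES, each batch runs in parallel, and the reduce step keeps the first error. Below is a rough standalone sketch of that pattern using rayon's default thread pool instead of Polars' POOL; the temp-directory path and file names are illustrative only.

use rayon::prelude::*;

const MAX_OPEN_FILES: usize = 8;

fn main() -> std::io::Result<()> {
    let dir = std::env::temp_dir().join("partition_demo");
    std::fs::create_dir_all(&dir)?;

    // 20 output "files", written in sequential batches of at most 8.
    let tasks: Vec<usize> = (0..20).collect();

    tasks
        .chunks(MAX_OPEN_FILES) // sequential batches bound the number of open files
        .map(|chunk| {
            chunk
                .par_iter() // files within one batch are written in parallel
                .map(|&i| -> std::io::Result<()> {
                    let path = dir.join(format!("{:08x}.parquet", i));
                    std::fs::File::create(path)?;
                    Ok(())
                })
                // Keep the first error, otherwise Ok(()) - mirrors the reduce calls above.
                .reduce(|| Ok(()), |a, b| if a.is_err() { a } else { b })
        })
        .collect::<std::io::Result<Vec<()>>>()?;

    Ok(())
}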
2 changes: 2 additions & 0 deletions crates/polars-io/src/prelude.rs
@@ -9,5 +9,7 @@ pub use crate::json::*;
 pub use crate::ndjson::core::*;
 #[cfg(feature = "parquet")]
 pub use crate::parquet::{metadata::*, read::*, write::*};
+#[cfg(feature = "parquet")]
+pub use crate::partition::write_partitioned_dataset;
 pub use crate::shared::{SerReader, SerWriter};
 pub use crate::utils::*;
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/test_hive.py
@@ -681,7 +681,7 @@ def test_hive_write(tmp_path: Path, df: pl.DataFrame) -> None:

 @pytest.mark.slow()
 @pytest.mark.write_disk()
-def test_hive_write_multiple_files(tmp_path: Path, monkeypatch: Any) -> None:
+def test_hive_write_multiple_files(tmp_path: Path) -> None:
     chunk_size = 262_144
     n_rows = 100_000
     df = pl.select(a=pl.repeat(0, n_rows), b=pl.int_range(0, n_rows))
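Note: as a rough illustration of how chunk_size drives the file count, the sketch below repeats the arithmetic from get_n_files_and_rows_per_file in partition.rs with the chunk_size used in this test; the 1.6 MB estimated partition size is an assumed figure for illustration, not a value taken from the test.

fn main() {
    let chunk_size: usize = 262_144; // target bytes per output file (as in the test above)
    let height: usize = 100_000; // rows to spread across the files
    let estimated_size: usize = 1_600_000; // hypothetical in-memory size of one partition, in bytes

    // Same arithmetic as get_n_files_and_rows_per_file in partition.rs.
    let n_files = (estimated_size / chunk_size).clamp(1, 0xffff_ffff);
    let rows_per_file = (height / n_files).saturating_add(1);

    assert_eq!(n_files, 6);
    assert_eq!(rows_per_file, 16_667);
    println!("{n_files} files, about {rows_per_file} rows per file");
}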
