feat: Support hf:// in read_(csv|ipc|ndjson) functions #17785

Merged · 8 commits · Jul 23, 2024
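
For orientation, this is what the change enables at the Python API level. A minimal sketch, assuming a public Hugging Face dataset; the repository and file names below are hypothetical:

```python
import polars as pl

# Hypothetical dataset path: any `hf://datasets/<org>/<repo>/<path>` pointing
# at a CSV/IPC/NDJSON file should dispatch through the new branch.
df = pl.read_csv("hf://datasets/some-org/some-dataset/data.csv")

# Per the PR title, the same applies to the other eager readers:
# pl.read_ipc("hf://datasets/some-org/some-dataset/data.arrow")
# pl.read_ndjson("hf://datasets/some-org/some-dataset/data.jsonl")
```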
12 changes: 4 additions & 8 deletions crates/polars-io/src/path_utils/mod.rs
@@ -146,6 +146,8 @@ pub fn expand_paths_hive(
     if is_cloud || { cfg!(not(target_family = "windows")) && config::force_async() } {
         #[cfg(feature = "cloud")]
         {
+            use polars_utils::_limit_path_len_io_err;
+
             use crate::cloud::object_path_from_string;

             if first_path.starts_with("hf://") {
@@ -199,14 +201,8 @@ pub fn expand_paths_hive(
                 // indistinguishable from an empty directory.
                 let path = PathBuf::from(path);
                 if !path.is_dir() {
-                    path.metadata().map_err(|err| {
-                        let msg =
-                            Some(format!("{}: {}", err, path.to_str().unwrap()).into());
-                        PolarsError::IO {
-                            error: err.into(),
-                            msg,
-                        }
-                    })?;
+                    path.metadata()
+                        .map_err(|err| _limit_path_len_io_err(&path, err))?;
                 }
             }

5 changes: 1 addition & 4 deletions crates/polars-io/src/utils/other.rs
@@ -168,10 +168,7 @@ pub(crate) fn update_row_counts3(dfs: &mut [DataFrame], heights: &[IdxSize], off
 }

 #[cfg(feature = "json")]
-pub(crate) fn overwrite_schema(
-    schema: &mut Schema,
-    overwriting_schema: &Schema,
-) -> PolarsResult<()> {
+pub fn overwrite_schema(schema: &mut Schema, overwriting_schema: &Schema) -> PolarsResult<()> {
     for (k, value) in overwriting_schema.iter() {
         *schema.try_get_mut(k)? = value.clone();
     }
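
Making `overwrite_schema` `pub` (instead of `pub(crate)`) lets the NDJSON planner in `polars-plan` reuse it; see `scans.rs` below. A minimal Python sketch of its semantics, with dtypes as plain strings for illustration: it only replaces the dtypes of columns that already exist, and errors on unknown names (the `try_get_mut` call).

```python
def overwrite_schema(schema: dict[str, str], overrides: dict[str, str]) -> None:
    # Mirrors the Rust helper: every key in `overrides` must already exist in
    # `schema`; unknown columns raise instead of being appended.
    for name, dtype in overrides.items():
        if name not in schema:
            raise KeyError(f"column {name!r} not found in schema")
        schema[name] = dtype

schema = {"a": "i64", "b": "str"}
overwrite_schema(schema, {"b": "f64"})    # ok -> {"a": "i64", "b": "f64"}
# overwrite_schema(schema, {"c": "i32"})  # raises KeyError
```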
10 changes: 10 additions & 0 deletions crates/polars-lazy/src/scan/ndjson.rs
@@ -18,6 +18,7 @@ pub struct LazyJsonLineReader {
     pub(crate) low_memory: bool,
     pub(crate) rechunk: bool,
     pub(crate) schema: Option<SchemaRef>,
+    pub(crate) schema_overwrite: Option<SchemaRef>,
     pub(crate) row_index: Option<RowIndex>,
     pub(crate) infer_schema_length: Option<NonZeroUsize>,
     pub(crate) n_rows: Option<usize>,
@@ -38,6 +39,7 @@ impl LazyJsonLineReader {
             low_memory: false,
             rechunk: false,
             schema: None,
+            schema_overwrite: None,
             row_index: None,
             infer_schema_length: NonZeroUsize::new(100),
             ignore_errors: false,
@@ -82,6 +84,13 @@ impl LazyJsonLineReader {
         self
     }

+    /// Overwrite parts of the inferred schema.
+    #[must_use]
+    pub fn with_schema_overwrite(mut self, schema_overwrite: Option<SchemaRef>) -> Self {
+        self.schema_overwrite = schema_overwrite;
+        self
+    }
+
     /// Reduce memory usage at the expense of performance
     #[must_use]
     pub fn low_memory(mut self, toggle: bool) -> Self {
@@ -129,6 +138,7 @@ impl LazyFileListReader for LazyJsonLineReader {
             low_memory: self.low_memory,
             ignore_errors: self.ignore_errors,
             schema: self.schema,
+            schema_overwrite: self.schema_overwrite,
         };

         let scan_type = FileScan::NDJson {
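
This file is only the Rust-side plumbing. Assuming the Python binding exposes the new option as a `schema_overrides` parameter on `scan_ndjson` (that wiring is not shown in this diff), usage would look roughly like:

```python
import polars as pl

# `data.jsonl` is a placeholder path; `schema_overrides` is assumed here to
# map onto the new `with_schema_overwrite` builder method.
lf = pl.scan_ndjson(
    "data.jsonl",
    schema_overrides={"id": pl.UInt64},  # override the inferred dtype of `id`
)
df = lf.collect()
```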
7 changes: 6 additions & 1 deletion crates/polars-plan/src/plans/conversion/scans.rs
@@ -325,7 +325,7 @@ pub(super) fn ndjson_file_info(
     };
     let mut reader = std::io::BufReader::new(f);

-    let (reader_schema, schema) = if let Some(schema) = ndjson_options.schema.take() {
+    let (mut reader_schema, schema) = if let Some(schema) = ndjson_options.schema.take() {
         if file_options.row_index.is_none() {
             (schema.clone(), schema.clone())
         } else {
@@ -340,6 +340,11 @@ pub(super) fn ndjson_file_info(
         prepare_schemas(schema, file_options.row_index.as_ref())
     };

+    if let Some(overwriting_schema) = &ndjson_options.schema_overwrite {
+        let schema = Arc::make_mut(&mut reader_schema);
+        overwrite_schema(schema, overwriting_schema)?;
+    }
+
     Ok(FileInfo::new(
         schema,
         Some(Either::Right(reader_schema)),
1 change: 1 addition & 0 deletions crates/polars-plan/src/plans/options.rs
@@ -359,4 +359,5 @@ pub struct NDJsonReadOptions {
     pub low_memory: bool,
     pub ignore_errors: bool,
     pub schema: Option<SchemaRef>,
+    pub schema_overwrite: Option<SchemaRef>,
 }
6 changes: 3 additions & 3 deletions crates/polars-utils/src/io.rs
@@ -4,7 +4,7 @@ use std::path::Path;

 use polars_error::*;

-fn map_err(path: &Path, err: io::Error) -> PolarsError {
+pub fn _limit_path_len_io_err(path: &Path, err: io::Error) -> PolarsError {
     let path = path.to_string_lossy();
     let msg = if path.len() > 88 {
         let truncated_path: String = path.chars().skip(path.len() - 88).collect();
@@ -19,12 +19,12 @@ pub fn open_file<P>(path: P) -> PolarsResult<File>
 where
     P: AsRef<Path>,
 {
-    File::open(&path).map_err(|err| map_err(path.as_ref(), err))
+    File::open(&path).map_err(|err| _limit_path_len_io_err(path.as_ref(), err))
 }

 pub fn create_file<P>(path: P) -> PolarsResult<File>
 where
     P: AsRef<Path>,
 {
-    File::create(&path).map_err(|err| map_err(path.as_ref(), err))
+    File::create(&path).map_err(|err| _limit_path_len_io_err(path.as_ref(), err))
 }
97 changes: 82 additions & 15 deletions py-polars/polars/io/csv/functions.py
@@ -1,11 +1,13 @@
 from __future__ import annotations

 import contextlib
+import os
 from io import BytesIO, StringIO
 from pathlib import Path
 from typing import IO, TYPE_CHECKING, Any, Callable, Mapping, Sequence

 import polars._reexport as pl
+import polars.functions as F
 from polars._utils.deprecation import deprecate_renamed_parameter
 from polars._utils.various import (
     _process_null_values,
@@ -419,45 +421,110 @@ def read_csv(
     if not infer_schema:
         infer_schema_length = 0

-    with prepare_file_arg(
-        source,
-        encoding=encoding,
-        use_pyarrow=False,
-        raise_if_empty=raise_if_empty,
-        storage_options=storage_options,
-    ) as data:
-        df = _read_csv_impl(
-            data,
+    # TODO: scan_csv doesn't support a "dtype slice" (i.e. list[DataType])
+    schema_overrides_is_list = isinstance(schema_overrides, Sequence)
+    encoding_supported_in_lazy = encoding in {"utf8", "utf8-lossy"}
+
+    if (
+        # Check that it is not a BytesIO object
+        isinstance(v := source, (str, Path))
+    ) and (
+        # HuggingFace only for now ⊂( ◜◒◝ )⊃
+        str(v).startswith("hf://")
+        # Also dispatch on FORCE_ASYNC, so that this codepath gets run
+        # through by our test suite during CI.
+        or (
+            os.getenv("POLARS_FORCE_ASYNC") == "1"
+            and not schema_overrides_is_list
+            and encoding_supported_in_lazy
+        )
+        # TODO: We can't dispatch this for all paths due to a few reasons:
+        # * `scan_csv` does not support compressed files
+        # * The `storage_options` configuration keys are different between
+        #   fsspec and object_store (would require a breaking change)
+    ):
+        if schema_overrides_is_list:
+            msg = "passing a list to `schema_overrides` is unsupported for hf:// paths"
+            raise ValueError(msg)
+        if not encoding_supported_in_lazy:
+            msg = f"unsupported encoding {encoding} for hf:// paths"
+            raise ValueError(msg)
+
+        lf = _scan_csv_impl(
+            source,  # type: ignore[arg-type]
             has_header=has_header,
-            columns=columns if columns else projection,
             separator=separator,
             comment_prefix=comment_prefix,
             quote_char=quote_char,
             skip_rows=skip_rows,
-            schema_overrides=schema_overrides,
+            schema_overrides=schema_overrides,  # type: ignore[arg-type]
             schema=schema,
             null_values=null_values,
             missing_utf8_is_empty_string=missing_utf8_is_empty_string,
             ignore_errors=ignore_errors,
             try_parse_dates=try_parse_dates,
-            n_threads=n_threads,
             infer_schema_length=infer_schema_length,
-            batch_size=batch_size,
             n_rows=n_rows,
-            encoding=encoding if encoding == "utf8-lossy" else "utf8",
+            encoding=encoding,  # type: ignore[arg-type]
             low_memory=low_memory,
             rechunk=rechunk,
             skip_rows_after_header=skip_rows_after_header,
             row_index_name=row_index_name,
             row_index_offset=row_index_offset,
-            sample_size=sample_size,
             eol_char=eol_char,
             raise_if_empty=raise_if_empty,
             truncate_ragged_lines=truncate_ragged_lines,
             decimal_comma=decimal_comma,
             glob=glob,
         )
+
+        if columns:
+            lf = lf.select(columns)
+        elif projection:
+            lf = lf.select(F.nth(projection))
+
+        df = lf.collect()
+
+    else:
+        with prepare_file_arg(
+            source,
+            encoding=encoding,
+            use_pyarrow=False,
+            raise_if_empty=raise_if_empty,
+            storage_options=storage_options,
+        ) as data:
+            df = _read_csv_impl(
+                data,
+                has_header=has_header,
+                columns=columns if columns else projection,
+                separator=separator,
+                comment_prefix=comment_prefix,
+                quote_char=quote_char,
+                skip_rows=skip_rows,
+                schema_overrides=schema_overrides,
+                schema=schema,
+                null_values=null_values,
+                missing_utf8_is_empty_string=missing_utf8_is_empty_string,
+                ignore_errors=ignore_errors,
+                try_parse_dates=try_parse_dates,
+                n_threads=n_threads,
+                infer_schema_length=infer_schema_length,
+                batch_size=batch_size,
+                n_rows=n_rows,
+                encoding=encoding if encoding == "utf8-lossy" else "utf8",
+                low_memory=low_memory,
+                rechunk=rechunk,
+                skip_rows_after_header=skip_rows_after_header,
+                row_index_name=row_index_name,
+                row_index_offset=row_index_offset,
+                sample_size=sample_size,
+                eol_char=eol_char,
+                raise_if_empty=raise_if_empty,
+                truncate_ragged_lines=truncate_ragged_lines,
+                decimal_comma=decimal_comma,
+                glob=glob,
+            )

     if new_columns:
         return _update_columns(df, new_columns)
     return df

Review comment (author), on the `os.getenv("POLARS_FORCE_ASYNC")` line: Magical test coverage for the hf:// dispatch 😉
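
The eager semantics are preserved on top of the lazy engine: named `columns` become `lf.select(columns)`, integer projections become `lf.select(F.nth(...))`, and the result is collected. A usage sketch with a hypothetical dataset path:

```python
import polars as pl

url = "hf://datasets/some-org/some-dataset/data.csv"  # hypothetical path

df_by_name = pl.read_csv(url, columns=["a", "b"])  # -> lf.select(columns)
df_by_index = pl.read_csv(url, columns=[0, 1])     # -> lf.select(F.nth(...))
```

Setting `POLARS_FORCE_ASYNC=1` routes ordinary local paths through the same branch, which is how the test suite exercises it in CI.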
42 changes: 38 additions & 4 deletions py-polars/polars/io/ipc/functions.py
@@ -1,10 +1,12 @@
 from __future__ import annotations

 import contextlib
+import os
 from pathlib import Path
 from typing import IO, TYPE_CHECKING, Any, Sequence

 import polars._reexport as pl
+import polars.functions as F
 from polars._utils.deprecation import deprecate_renamed_parameter
 from polars._utils.various import (
     is_str_sequence,
@@ -29,8 +31,6 @@
     from polars._typing import SchemaDict


-@deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
-@deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4")
 def read_ipc(
     source: str | Path | IO[bytes] | bytes,
     *,
@@ -92,6 +92,42 @@ def read_ipc(
     That means that you cannot write to the same filename.
     E.g. `pl.read_ipc("my_file.arrow").write_ipc("my_file.arrow")` will fail.
     """
+    if (
+        # Check that it is not a BytesIO object
+        isinstance(v := source, (str, Path))
+    ) and (
+        # HuggingFace only for now ⊂( ◜◒◝ )⊃
+        (is_hf := str(v).startswith("hf://"))
+        # Also dispatch on FORCE_ASYNC, so that this codepath gets run
+        # through by our test suite during CI.
+        or os.getenv("POLARS_FORCE_ASYNC") == "1"
+        # TODO: Dispatch all paths to `scan_ipc` - this will need a breaking
+        # change to the `storage_options` parameter.
+    ):
+        if is_hf and use_pyarrow:
+            msg = "`use_pyarrow=True` is not supported for Hugging Face"
+            raise ValueError(msg)
+
+        lf = scan_ipc(
+            source,  # type: ignore[arg-type]
+            n_rows=n_rows,
+            memory_map=memory_map,
+            storage_options=storage_options,
+            row_index_name=row_index_name,
+            row_index_offset=row_index_offset,
+            rechunk=rechunk,
+        )
+
+        if columns:
+            if isinstance(columns[0], int):
+                lf = lf.select(F.nth(columns))  # type: ignore[arg-type]
+            else:
+                lf = lf.select(columns)
+
+        df = lf.collect()
+
+        return df
+
     if use_pyarrow and n_rows and not memory_map:
         msg = "`n_rows` cannot be used with `use_pyarrow=True` and `memory_map=False`"
         raise ValueError(msg)
@@ -305,8 +341,6 @@ def read_ipc_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, DataTyp
     return _read_ipc_schema(source)


-@deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
-@deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4")
 def scan_ipc(
     source: str | Path | list[str] | list[Path],
     *,
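
Because the Hugging Face branch in `read_ipc` collects through `scan_ipc`, the pyarrow-based reader cannot be combined with it, hence the explicit error. A usage sketch with a hypothetical path:

```python
import polars as pl

url = "hf://datasets/some-org/some-dataset/data.arrow"  # hypothetical path

df = pl.read_ipc(url)  # fine: dispatches to scan_ipc(...).collect()

# pl.read_ipc(url, use_pyarrow=True)
# -> ValueError: `use_pyarrow=True` is not supported for Hugging Face
```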