Skip to content

Commit

Permalink
perf(rust, python): first check rev-map on categorical equality check (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and zundertj committed Jan 7, 2023
1 parent 9bf4397 commit ae31c5b
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ impl RevMapping {
}
}

pub fn get_optional(&self, idx: u32) -> Option<&str> {
match self {
Self::Global(map, a, _) => {
let idx = *map.get(&idx)?;
a.get(idx as usize)
}
Self::Local(a) => a.get(idx as usize),
}
}

/// Categorical to str
///
/// # Safety
Expand Down
30 changes: 22 additions & 8 deletions polars/polars-core/src/series/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,17 @@ impl ChunkCompare<&Series> for Series {
#[cfg(feature = "dtype-categorical")]
(Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => {
if rev_map_l.same_src(rev_map_r) {
self.categorical()
.unwrap()
.logical()
.equal(rhs.categorical().unwrap().logical())
let rhs = rhs.categorical().unwrap().logical();

// first check the rev-map
if rhs.len() == 1 && rhs.null_count() == 0 {
let rhs = rhs.get(0).unwrap();
if rev_map_l.get_optional(rhs).is_none() {
return Ok(BooleanChunked::full(self.name(), false, self.len()));
}
}

self.categorical().unwrap().logical().equal(rhs)
} else {
return Err(PolarsError::ComputeError("Cannot compare categoricals originating from different sources. Consider setting a global string cache.".into()));
}
Expand Down Expand Up @@ -203,10 +210,17 @@ impl ChunkCompare<&Series> for Series {
#[cfg(feature = "dtype-categorical")]
(Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => {
if rev_map_l.same_src(rev_map_r) {
self.categorical()
.unwrap()
.logical()
.not_equal(rhs.categorical().unwrap().logical())
let rhs = rhs.categorical().unwrap().logical();

// first check the rev-map
if rhs.len() == 1 && rhs.null_count() == 0 {
let rhs = rhs.get(0).unwrap();
if rev_map_l.get_optional(rhs).is_none() {
return Ok(BooleanChunked::full(self.name(), true, self.len()));
}
}

self.categorical().unwrap().logical().not_equal(rhs)
} else {
return Err(PolarsError::ComputeError("Cannot compare categoricals originating from different sources. Consider setting a global string cache.".into()));
}
Expand Down
79 changes: 76 additions & 3 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ae31c5b

Please sign in to comment.