Skip to content

Commit

Permalink
feat(python,rust): Add seed argument to rank for random (#7913)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj authored Apr 2, 2023
1 parent ae8698b commit 503f353
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 48 deletions.
43 changes: 25 additions & 18 deletions polars/polars-core/src/chunked_array/ops/unique/rank.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use polars_arrow::prelude::FromData;
#[cfg(feature = "random")]
use rand::prelude::SliceRandom;
use rand::prelude::*;
#[cfg(feature = "random")]
use rand::{rngs::SmallRng, thread_rng, SeedableRng};
use rand::{rngs::SmallRng, SeedableRng};

use crate::prelude::*;

Expand Down Expand Up @@ -33,7 +34,14 @@ impl Default for RankOptions {
}
}

pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series {
#[cfg(feature = "random")]
/// Draws a fresh 64-bit seed from an entropy-initialised `SmallRng`.
///
/// Serves as the fallback seed source when `rank` is called without an
/// explicit `seed`.
fn get_random_seed() -> u64 {
    SmallRng::from_entropy().next_u64()
}

pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) -> Series {
match s.len() {
1 => {
return match method {
Expand Down Expand Up @@ -65,7 +73,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series {
};
let s = s.fill_null(null_strategy).unwrap();

let mut out = rank(&s, method, descending);
let mut out = rank(&s, method, descending, seed);
unsafe {
let arr = &mut out.chunks_mut()[0];
*arr = arr.with_validity(Some(validity.clone()))
Expand Down Expand Up @@ -151,8 +159,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series {

let mut sort_idx = sort_idx.to_vec();

let mut thread_rng = thread_rng();
let rng = &mut SmallRng::from_rng(&mut thread_rng).unwrap();
let rng = &mut SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));

// Shuffle sort_idx positions which point to ties in the original series.
for i in 0..(ties_indices.len() - 1) {
Expand Down Expand Up @@ -313,15 +320,15 @@ mod test {
fn test_rank() -> PolarsResult<()> {
let s = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]);

let out = rank(&s, RankMethod::Ordinal, false)
let out = rank(&s, RankMethod::Ordinal, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2 as IdxSize, 3, 6, 4, 5, 7, 1]);

#[cfg(feature = "random")]
{
let out = rank(&s, RankMethod::Random, false)
let out = rank(&s, RankMethod::Random, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
Expand All @@ -334,25 +341,25 @@ mod test {
assert_ne!(out[3], out[4]);
}

let out = rank(&s, RankMethod::Dense, false)
let out = rank(&s, RankMethod::Dense, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]);

let out = rank(&s, RankMethod::Max, false)
let out = rank(&s, RankMethod::Max, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]);

let out = rank(&s, RankMethod::Min, false)
let out = rank(&s, RankMethod::Min, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]);

let out = rank(&s, RankMethod::Average, false)
let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.into_no_null_iter()
.collect::<Vec<_>>();
Expand All @@ -363,7 +370,7 @@ mod test {
&[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)],
);

let out = rank(&s, RankMethod::Average, false)
let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.into_iter()
.collect::<Vec<_>>();
Expand Down Expand Up @@ -393,7 +400,7 @@ mod test {
Some(8),
],
);
let out = rank(&s, RankMethod::Max, false)
let out = rank(&s, RankMethod::Max, false, None)
.idx()?
.into_iter()
.collect::<Vec<_>>();
Expand All @@ -417,12 +424,12 @@ mod test {
#[test]
fn test_rank_all_null() -> PolarsResult<()> {
let s = UInt32Chunked::new("", &[None, None, None]).into_series();
let out = rank(&s, RankMethod::Average, false)
let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2.0f32, 2.0, 2.0]);
let out = rank(&s, RankMethod::Dense, false)
let out = rank(&s, RankMethod::Dense, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
Expand All @@ -433,16 +440,16 @@ mod test {
#[test]
fn test_rank_empty() {
let s = UInt32Chunked::from_slice("", &[]).into_series();
let out = rank(&s, RankMethod::Average, false);
let out = rank(&s, RankMethod::Average, false, None);
assert_eq!(out.dtype(), &DataType::Float32);
let out = rank(&s, RankMethod::Max, false);
let out = rank(&s, RankMethod::Max, false, None);
assert_eq!(out.dtype(), &IDX_DTYPE);
}

#[test]
fn test_rank_reverse() -> PolarsResult<()> {
let s = Series::new("", &[None, Some(1), Some(1), Some(5), None]);
let out = rank(&s, RankMethod::Dense, true)
let out = rank(&s, RankMethod::Dense, true, None)
.idx()?
.into_iter()
.collect::<Vec<_>>();
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,8 @@ impl Series {
}

#[cfg(feature = "rank")]
pub fn rank(&self, options: RankOptions) -> Series {
rank(self, options.method, options.descending)
pub fn rank(&self, options: RankOptions, seed: Option<u64>) -> Series {
rank(self, options.method, options.descending, seed)
}

/// Cast throws an error if conversion had overflows
Expand Down
22 changes: 14 additions & 8 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,20 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E
let a = a.drop_nulls();
let b = b.drop_nulls();

let a_idx = a.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
});
let b_idx = b.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
});
let a_idx = a.rank(
RankOptions {
method: RankMethod::Min,
..Default::default()
},
None,
);
let b_idx = b.rank(
RankOptions {
method: RankMethod::Min,
..Default::default()
},
None,
);
let a_idx = a_idx.idx().unwrap();
let b_idx = b_idx.idx().unwrap();

Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1620,9 +1620,9 @@ impl Expr {
}

#[cfg(feature = "rank")]
pub fn rank(self, options: RankOptions) -> Expr {
pub fn rank(self, options: RankOptions, seed: Option<u64>) -> Expr {
self.apply(
move |s| Ok(Some(s.rank(options))),
move |s| Ok(Some(s.rank(options, seed))),
GetOutput::map_field(move |fld| match options.method {
RankMethod::Average => Field::new(fld.name(), DataType::Float32),
_ => Field::new(fld.name(), IDX_DTYPE),
Expand Down
22 changes: 14 additions & 8 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1565,10 +1565,13 @@ fn test_groupby_rank() -> PolarsResult<()> {
let out = df
.lazy()
.groupby_stable([col("cars")])
.agg([col("B").rank(RankOptions {
method: RankMethod::Dense,
..Default::default()
})])
.agg([col("B").rank(
RankOptions {
method: RankMethod::Dense,
..Default::default()
},
None,
)])
.collect()?;

let out = out.column("B")?;
Expand Down Expand Up @@ -1659,10 +1662,13 @@ fn test_single_ranked_group() -> PolarsResult<()> {
let out = df
.lazy()
.with_columns([col("value")
.rank(RankOptions {
method: RankMethod::Average,
..Default::default()
})
.rank(
RankOptions {
method: RankMethod::Average,
..Default::default()
},
None,
)
.list()
.over([col("group")])])
.collect()?;
Expand Down
8 changes: 4 additions & 4 deletions polars/tests/it/lazy/expressions/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.lazy()
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(col("value").rank(Default::default()))
.then(col("value").rank(Default::default(), None))
.otherwise(lit(Series::new("", &[10 as IdxSize])))])
.sort("name", Default::default())
.collect()?;
Expand All @@ -312,7 +312,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(lit(Series::new("", &[10 as IdxSize])).alias("value"))
.otherwise(col("value").rank(Default::default()))])
.otherwise(col("value").rank(Default::default(), None))])
.sort("name", Default::default())
.collect()?;

Expand All @@ -331,7 +331,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.lazy()
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(col("value").rank(Default::default()))
.then(col("value").rank(Default::default(), None))
.otherwise(Null {}.lit())])
.sort("name", Default::default())
.collect()?;
Expand All @@ -346,7 +346,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(Null {}.lit().alias("value"))
.otherwise(col("value").rank(Default::default()))])
.otherwise(col("value").rank(Default::default(), None))])
.sort("name", Default::default())
.collect()?;

Expand Down
11 changes: 9 additions & 2 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4859,7 +4859,12 @@ def argsort(self, descending: bool = False, nulls_last: bool = False) -> Self:
return self.arg_sort(descending, nulls_last)

@deprecated_alias(reverse="descending")
def rank(self, method: RankMethod = "average", descending: bool = False) -> Self:
def rank(
self,
method: RankMethod = "average",
descending: bool = False,
seed: int | None = None,
) -> Self:
"""
Assign ranks to data, dealing with ties appropriately.
Expand All @@ -4885,6 +4890,8 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Self
on the order that the values occur in the Series.
descending
Rank in descending order.
seed
Seed for the random number generator; only used when `method="random"`.
If None (default), a new random seed is generated for each call.
Examples
--------
Expand Down Expand Up @@ -4923,7 +4930,7 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Self
└─────┘
"""
return self._from_pyexpr(self._pyexpr.rank(method, descending))
return self._from_pyexpr(self._pyexpr.rank(method, descending, seed))

def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Self:
"""
Expand Down
15 changes: 13 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4816,7 +4816,12 @@ def abs(self) -> Series:

@deprecated_alias(reverse="descending")
@deprecate_nonkeyword_arguments(allowed_args=["self", "method"], stacklevel=3)
def rank(self, method: RankMethod = "average", descending: bool = False) -> Series:
def rank(
self,
method: RankMethod = "average",
descending: bool = False,
seed: int | None = None,
) -> Series:
"""
Assign ranks to data, dealing with ties appropriately.
Expand All @@ -4842,6 +4847,8 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Seri
on the order that the values occur in the Series.
descending
Rank in descending order.
seed
Seed for the random number generator; only used when `method="random"`.
If None (default), a new random seed is generated for each call.
Examples
--------
Expand Down Expand Up @@ -4876,7 +4883,11 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Seri
"""
return (
self.to_frame()
.select(F.col(self._s.name()).rank(method=method, descending=descending))
.select(
F.col(self._s.name()).rank(
method=method, descending=descending, seed=seed
)
)
.to_series()
)

Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1628,12 +1628,12 @@ impl PyExpr {
.into())
}

fn rank(&self, method: Wrap<RankMethod>, descending: bool) -> Self {
fn rank(&self, method: Wrap<RankMethod>, descending: bool, seed: Option<u64>) -> Self {
let options = RankOptions {
method: method.0,
descending,
};
self.inner.clone().rank(options).into()
self.inner.clone().rank(options, seed).into()
}

fn diff(&self, n: usize, null_behavior: Wrap<NullBehavior>) -> Self {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,20 @@ def test_rank_so_4109() -> None:
}


def test_rank_random() -> None:
    """Random ranking with a fixed seed gives the same frame on repeated runs."""
    frame = pl.from_dict(
        {"a": [1] * 5, "b": [1, 2, 3, 4, 5], "c": [200, 100, 100, 50, 100]}
    )

    # Build and evaluate the seeded expression twice, from scratch each time,
    # so the comparison covers two independent evaluations.
    first, second = (
        frame.with_columns(
            pl.col("c").rank(method="random", seed=1).over("a").alias("rank")
        )
        for _ in range(2)
    )
    assert_frame_equal(first, second)


def test_unique_empty() -> None:
for dt in [pl.Utf8, pl.Boolean, pl.Int32, pl.UInt32]:
s = pl.Series([], dtype=dt)
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,13 @@ def test_rank() -> None:
)


def test_rank_random() -> None:
    """Seeded random ranking of a Series with ties yields the pinned ranks."""
    values = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
    expected = pl.Series("a", [2, 4, 7, 3, 5, 6, 1], dtype=UInt32)

    # NOTE(review): the expected ranks are pinned to what the seeded generator
    # produces for seed=1; presumably they need re-deriving if that generator
    # ever changes -- confirm against the Rust implementation.
    ranked = values.rank("random", seed=1)
    assert_series_equal(ranked, expected)


def test_diff() -> None:
s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
expected = pl.Series("a", [1, 1, -1, 0, 1, -3])
Expand Down

0 comments on commit 503f353

Please sign in to comment.