From 503f353166e17f0cc1df36fafe0dd6d52939c12a Mon Sep 17 00:00:00 2001 From: J van Zundert Date: Sun, 2 Apr 2023 19:30:18 +0100 Subject: [PATCH] feat(python,rust): Add seed argument to rank for random (#7913) --- .../src/chunked_array/ops/unique/rank.rs | 43 +++++++++++-------- polars/polars-core/src/series/mod.rs | 4 +- .../polars-plan/src/dsl/functions.rs | 22 ++++++---- polars/polars-lazy/polars-plan/src/dsl/mod.rs | 4 +- polars/polars-lazy/src/tests/queries.rs | 22 ++++++---- polars/tests/it/lazy/expressions/arity.rs | 8 ++-- py-polars/polars/expr/expr.py | 11 ++++- py-polars/polars/series/series.py | 15 ++++++- py-polars/src/lazy/dsl.rs | 4 +- py-polars/tests/unit/test_exprs.py | 14 ++++++ py-polars/tests/unit/test_series.py | 7 +++ 11 files changed, 106 insertions(+), 48 deletions(-) diff --git a/polars/polars-core/src/chunked_array/ops/unique/rank.rs b/polars/polars-core/src/chunked_array/ops/unique/rank.rs index 5ad4485d8d9e..aa0eec03d0b8 100644 --- a/polars/polars-core/src/chunked_array/ops/unique/rank.rs +++ b/polars/polars-core/src/chunked_array/ops/unique/rank.rs @@ -1,8 +1,9 @@ use polars_arrow::prelude::FromData; #[cfg(feature = "random")] use rand::prelude::SliceRandom; +use rand::prelude::*; #[cfg(feature = "random")] -use rand::{rngs::SmallRng, thread_rng, SeedableRng}; +use rand::{rngs::SmallRng, SeedableRng}; use crate::prelude::*; @@ -33,7 +34,14 @@ impl Default for RankOptions { } } -pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series { +#[cfg(feature = "random")] +fn get_random_seed() -> u64 { + let mut rng = SmallRng::from_entropy(); + + rng.next_u64() +} + +pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> Series { match s.len() { 1 => { return match method { @@ -65,7 +73,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series { }; let s = s.fill_null(null_strategy).unwrap(); - let mut out = rank(&s, method, descending); + let mut out = rank(&s, method, descending, seed); unsafe { let arr = &mut out.chunks_mut()[0]; *arr = arr.with_validity(Some(validity.clone())) @@ -151,8 +159,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series { let mut sort_idx = sort_idx.to_vec(); - let mut thread_rng = thread_rng(); - let rng = &mut SmallRng::from_rng(&mut thread_rng).unwrap(); + let rng = &mut SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed)); // Shuffle sort_idx positions which point to ties in the original series. for i in 0..(ties_indices.len() - 1) { @@ -313,7 +320,7 @@ mod test { fn test_rank() -> PolarsResult<()> { let s = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]); - let out = rank(&s, RankMethod::Ordinal, false) + let out = rank(&s, RankMethod::Ordinal, false, None) .idx()? .into_no_null_iter() .collect::>(); @@ -321,7 +328,7 @@ mod test { #[cfg(feature = "random")] { - let out = rank(&s, RankMethod::Random, false) + let out = rank(&s, RankMethod::Random, false, None) .idx()? .into_no_null_iter() .collect::>(); @@ -334,25 +341,25 @@ mod test { assert_ne!(out[3], out[4]); } - let out = rank(&s, RankMethod::Dense, false) + let out = rank(&s, RankMethod::Dense, false, None) .idx()? .into_no_null_iter() .collect::>(); assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]); - let out = rank(&s, RankMethod::Max, false) + let out = rank(&s, RankMethod::Max, false, None) .idx()? .into_no_null_iter() .collect::>(); assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]); - let out = rank(&s, RankMethod::Min, false) + let out = rank(&s, RankMethod::Min, false, None) .idx()? .into_no_null_iter() .collect::>(); assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]); - let out = rank(&s, RankMethod::Average, false) + let out = rank(&s, RankMethod::Average, false, None) .f32()? .into_no_null_iter() .collect::>(); @@ -363,7 +370,7 @@ mod test { &[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)], ); - let out = rank(&s, RankMethod::Average, false) + let out = rank(&s, RankMethod::Average, false, None) .f32()? .into_iter() .collect::>(); @@ -393,7 +400,7 @@ mod test { Some(8), ], ); - let out = rank(&s, RankMethod::Max, false) + let out = rank(&s, RankMethod::Max, false, None) .idx()? .into_iter() .collect::>(); @@ -417,12 +424,12 @@ mod test { #[test] fn test_rank_all_null() -> PolarsResult<()> { let s = UInt32Chunked::new("", &[None, None, None]).into_series(); - let out = rank(&s, RankMethod::Average, false) + let out = rank(&s, RankMethod::Average, false, None) .f32()? .into_no_null_iter() .collect::>(); assert_eq!(out, &[2.0f32, 2.0, 2.0]); - let out = rank(&s, RankMethod::Dense, false) + let out = rank(&s, RankMethod::Dense, false, None) .idx()? .into_no_null_iter() .collect::>(); @@ -433,16 +440,16 @@ mod test { #[test] fn test_rank_empty() { let s = UInt32Chunked::from_slice("", &[]).into_series(); - let out = rank(&s, RankMethod::Average, false); + let out = rank(&s, RankMethod::Average, false, None); assert_eq!(out.dtype(), &DataType::Float32); - let out = rank(&s, RankMethod::Max, false); + let out = rank(&s, RankMethod::Max, false, None); assert_eq!(out.dtype(), &IDX_DTYPE); } #[test] fn test_rank_reverse() -> PolarsResult<()> { let s = Series::new("", &[None, Some(1), Some(1), Some(5), None]); - let out = rank(&s, RankMethod::Dense, true) + let out = rank(&s, RankMethod::Dense, true, None) .idx()? .into_iter() .collect::>(); diff --git a/polars/polars-core/src/series/mod.rs b/polars/polars-core/src/series/mod.rs index 232111b7c449..e84bc2b00309 100644 --- a/polars/polars-core/src/series/mod.rs +++ b/polars/polars-core/src/series/mod.rs @@ -695,8 +695,8 @@ impl Series { } #[cfg(feature = "rank")] - pub fn rank(&self, options: RankOptions) -> Series { - rank(self, options.method, options.descending) + pub fn rank(&self, options: RankOptions, seed: Option) -> Series { + rank(self, options.method, options.descending, seed) } /// Cast throws an error if conversion had overflows diff --git a/polars/polars-lazy/polars-plan/src/dsl/functions.rs b/polars/polars-lazy/polars-plan/src/dsl/functions.rs index 1f666c4bb46e..8fa169396a46 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/functions.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/functions.rs @@ -204,14 +204,20 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E let a = a.drop_nulls(); let b = b.drop_nulls(); - let a_idx = a.rank(RankOptions { - method: RankMethod::Min, - ..Default::default() - }); - let b_idx = b.rank(RankOptions { - method: RankMethod::Min, - ..Default::default() - }); + let a_idx = a.rank( + RankOptions { + method: RankMethod::Min, + ..Default::default() + }, + None, + ); + let b_idx = b.rank( + RankOptions { + method: RankMethod::Min, + ..Default::default() + }, + None, + ); let a_idx = a_idx.idx().unwrap(); let b_idx = b_idx.idx().unwrap(); diff --git a/polars/polars-lazy/polars-plan/src/dsl/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/mod.rs index 50d469d8d484..20ef7bc7d2bc 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/mod.rs @@ -1620,9 +1620,9 @@ impl Expr { } #[cfg(feature = "rank")] - pub fn rank(self, options: RankOptions) -> Expr { + pub fn rank(self, options: RankOptions, seed: Option) -> Expr { self.apply( - move |s| Ok(Some(s.rank(options))), + move |s| Ok(Some(s.rank(options, seed))), GetOutput::map_field(move |fld| match options.method { RankMethod::Average => Field::new(fld.name(), DataType::Float32), _ => Field::new(fld.name(), IDX_DTYPE), diff --git a/polars/polars-lazy/src/tests/queries.rs b/polars/polars-lazy/src/tests/queries.rs index e6f4afe351f1..ef7d34e8e69b 100644 --- a/polars/polars-lazy/src/tests/queries.rs +++ b/polars/polars-lazy/src/tests/queries.rs @@ -1565,10 +1565,13 @@ fn test_groupby_rank() -> PolarsResult<()> { let out = df .lazy() .groupby_stable([col("cars")]) - .agg([col("B").rank(RankOptions { - method: RankMethod::Dense, - ..Default::default() - })]) + .agg([col("B").rank( + RankOptions { + method: RankMethod::Dense, + ..Default::default() + }, + None, + )]) .collect()?; let out = out.column("B")?; @@ -1659,10 +1662,13 @@ fn test_single_ranked_group() -> PolarsResult<()> { let out = df .lazy() .with_columns([col("value") - .rank(RankOptions { - method: RankMethod::Average, - ..Default::default() - }) + .rank( + RankOptions { + method: RankMethod::Average, + ..Default::default() + }, + None, + ) .list() .over([col("group")])]) .collect()?; diff --git a/polars/tests/it/lazy/expressions/arity.rs b/polars/tests/it/lazy/expressions/arity.rs index c9393d2076bd..5ad20414ae1d 100644 --- a/polars/tests/it/lazy/expressions/arity.rs +++ b/polars/tests/it/lazy/expressions/arity.rs @@ -291,7 +291,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .lazy() .groupby([col("name")]) .agg([when(col("value").sum().eq(lit(3))) - .then(col("value").rank(Default::default())) + .then(col("value").rank(Default::default(), None)) .otherwise(lit(Series::new("", &[10 as IdxSize])))]) .sort("name", Default::default()) .collect()?; @@ -312,7 +312,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .groupby([col("name")]) .agg([when(col("value").sum().eq(lit(3))) .then(lit(Series::new("", &[10 as IdxSize])).alias("value")) - .otherwise(col("value").rank(Default::default()))]) + .otherwise(col("value").rank(Default::default(), None))]) .sort("name", Default::default()) .collect()?; @@ -331,7 +331,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .lazy() .groupby([col("name")]) .agg([when(col("value").sum().eq(lit(3))) - .then(col("value").rank(Default::default())) + .then(col("value").rank(Default::default(), None)) .otherwise(Null {}.lit())]) .sort("name", Default::default()) .collect()?; @@ -346,7 +346,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .groupby([col("name")]) .agg([when(col("value").sum().eq(lit(3))) .then(Null {}.lit().alias("value")) - .otherwise(col("value").rank(Default::default()))]) + .otherwise(col("value").rank(Default::default(), None))]) .sort("name", Default::default()) .collect()?; diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index a7661a0261bd..8cd67e90f19a 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -4859,7 +4859,12 @@ def argsort(self, descending: bool = False, nulls_last: bool = False) -> Self: return self.arg_sort(descending, nulls_last) @deprecated_alias(reverse="descending") - def rank(self, method: RankMethod = "average", descending: bool = False) -> Self: + def rank( + self, + method: RankMethod = "average", + descending: bool = False, + seed: int | None = None, + ) -> Self: """ Assign ranks to data, dealing with ties appropriately. @@ -4885,6 +4890,8 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Self on the order that the values occur in the Series. descending Rank in descending order. + seed + If `method="random"`, use this as seed. Examples -------- @@ -4923,7 +4930,7 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Self └─────┘ """ - return self._from_pyexpr(self._pyexpr.rank(method, descending)) + return self._from_pyexpr(self._pyexpr.rank(method, descending, seed)) def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Self: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 826a9d57bac6..a2782ad630c6 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4816,7 +4816,12 @@ def abs(self) -> Series: @deprecated_alias(reverse="descending") @deprecate_nonkeyword_arguments(allowed_args=["self", "method"], stacklevel=3) - def rank(self, method: RankMethod = "average", descending: bool = False) -> Series: + def rank( + self, + method: RankMethod = "average", + descending: bool = False, + seed: int | None = None, + ) -> Series: """ Assign ranks to data, dealing with ties appropriately. @@ -4842,6 +4847,8 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Seri on the order that the values occur in the Series. descending Rank in descending order. + seed + If `method="random"`, use this as seed. Examples -------- @@ -4876,7 +4883,11 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Seri """ return ( self.to_frame() - .select(F.col(self._s.name()).rank(method=method, descending=descending)) + .select( + F.col(self._s.name()).rank( + method=method, descending=descending, seed=seed + ) + ) .to_series() ) diff --git a/py-polars/src/lazy/dsl.rs b/py-polars/src/lazy/dsl.rs index 4e9f99752d32..0578a075110d 100644 --- a/py-polars/src/lazy/dsl.rs +++ b/py-polars/src/lazy/dsl.rs @@ -1628,12 +1628,12 @@ impl PyExpr { .into()) } - fn rank(&self, method: Wrap, descending: bool) -> Self { + fn rank(&self, method: Wrap, descending: bool, seed: Option) -> Self { let options = RankOptions { method: method.0, descending, }; - self.inner.clone().rank(options).into() + self.inner.clone().rank(options, seed).into() } fn diff(&self, n: usize, null_behavior: Wrap) -> Self { diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index a829a9ff22a8..c740e381f69d 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -460,6 +460,20 @@ def test_rank_so_4109() -> None: } +def test_rank_random() -> None: + df = pl.from_dict( + {"a": [1] * 5, "b": [1, 2, 3, 4, 5], "c": [200, 100, 100, 50, 100]} + ) + + df_ranks1 = df.with_columns( + pl.col("c").rank(method="random", seed=1).over("a").alias("rank") + ) + df_ranks2 = df.with_columns( + pl.col("c").rank(method="random", seed=1).over("a").alias("rank") + ) + assert_frame_equal(df_ranks1, df_ranks2) + + def test_unique_empty() -> None: for dt in [pl.Utf8, pl.Boolean, pl.Int32, pl.UInt32]: s = pl.Series([], dtype=dt) diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 6de1f6f038ab..0b2f9f49619f 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -1284,6 +1284,13 @@ def test_rank() -> None: ) +def test_rank_random() -> None: + s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) + assert_series_equal( + s.rank("random", seed=1), pl.Series("a", [2, 4, 7, 3, 5, 6, 1], dtype=UInt32) + ) + + def test_diff() -> None: s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) expected = pl.Series("a", [1, 1, -1, 0, 1, -3])