Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python,rust): Add seed argument to rank for random #7913

Merged
merged 4 commits into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions polars/polars-core/src/chunked_array/ops/unique/rank.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use polars_arrow::prelude::FromData;
#[cfg(feature = "random")]
use rand::prelude::SliceRandom;
use rand::prelude::*;
#[cfg(feature = "random")]
use rand::{rngs::SmallRng, thread_rng, SeedableRng};
use rand::{rngs::SmallRng, SeedableRng};

use crate::prelude::*;

Expand Down Expand Up @@ -33,7 +34,14 @@ impl Default for RankOptions {
}
}

pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series {
#[cfg(feature = "random")]
fn get_random_seed() -> u64 {
let mut rng = SmallRng::from_entropy();

rng.next_u64()
}

pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) -> Series {
match s.len() {
1 => {
return match method {
Expand Down Expand Up @@ -65,7 +73,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series {
};
let s = s.fill_null(null_strategy).unwrap();

let mut out = rank(&s, method, descending);
let mut out = rank(&s, method, descending, seed);
unsafe {
let arr = &mut out.chunks_mut()[0];
*arr = arr.with_validity(Some(validity.clone()))
Expand Down Expand Up @@ -151,8 +159,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool) -> Series {

let mut sort_idx = sort_idx.to_vec();

let mut thread_rng = thread_rng();
let rng = &mut SmallRng::from_rng(&mut thread_rng).unwrap();
let rng = &mut SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));

// Shuffle sort_idx positions which point to ties in the original series.
for i in 0..(ties_indices.len() - 1) {
Expand Down Expand Up @@ -313,15 +320,15 @@ mod test {
fn test_rank() -> PolarsResult<()> {
let s = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]);

let out = rank(&s, RankMethod::Ordinal, false)
let out = rank(&s, RankMethod::Ordinal, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2 as IdxSize, 3, 6, 4, 5, 7, 1]);

#[cfg(feature = "random")]
{
let out = rank(&s, RankMethod::Random, false)
let out = rank(&s, RankMethod::Random, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
Expand All @@ -334,25 +341,25 @@ mod test {
assert_ne!(out[3], out[4]);
}

let out = rank(&s, RankMethod::Dense, false)
let out = rank(&s, RankMethod::Dense, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]);

let out = rank(&s, RankMethod::Max, false)
let out = rank(&s, RankMethod::Max, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]);

let out = rank(&s, RankMethod::Min, false)
let out = rank(&s, RankMethod::Min, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]);

let out = rank(&s, RankMethod::Average, false)
let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.into_no_null_iter()
.collect::<Vec<_>>();
Expand All @@ -363,7 +370,7 @@ mod test {
&[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)],
);

let out = rank(&s, RankMethod::Average, false)
let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.into_iter()
.collect::<Vec<_>>();
Expand Down Expand Up @@ -393,7 +400,7 @@ mod test {
Some(8),
],
);
let out = rank(&s, RankMethod::Max, false)
let out = rank(&s, RankMethod::Max, false, None)
.idx()?
.into_iter()
.collect::<Vec<_>>();
Expand All @@ -417,12 +424,12 @@ mod test {
#[test]
fn test_rank_all_null() -> PolarsResult<()> {
let s = UInt32Chunked::new("", &[None, None, None]).into_series();
let out = rank(&s, RankMethod::Average, false)
let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2.0f32, 2.0, 2.0]);
let out = rank(&s, RankMethod::Dense, false)
let out = rank(&s, RankMethod::Dense, false, None)
.idx()?
.into_no_null_iter()
.collect::<Vec<_>>();
Expand All @@ -433,16 +440,16 @@ mod test {
#[test]
fn test_rank_empty() {
let s = UInt32Chunked::from_slice("", &[]).into_series();
let out = rank(&s, RankMethod::Average, false);
let out = rank(&s, RankMethod::Average, false, None);
assert_eq!(out.dtype(), &DataType::Float32);
let out = rank(&s, RankMethod::Max, false);
let out = rank(&s, RankMethod::Max, false, None);
assert_eq!(out.dtype(), &IDX_DTYPE);
}

#[test]
fn test_rank_reverse() -> PolarsResult<()> {
let s = Series::new("", &[None, Some(1), Some(1), Some(5), None]);
let out = rank(&s, RankMethod::Dense, true)
let out = rank(&s, RankMethod::Dense, true, None)
.idx()?
.into_iter()
.collect::<Vec<_>>();
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,8 @@ impl Series {
}

#[cfg(feature = "rank")]
pub fn rank(&self, options: RankOptions) -> Series {
rank(self, options.method, options.descending)
pub fn rank(&self, options: RankOptions, seed: Option<u64>) -> Series {
rank(self, options.method, options.descending, seed)
}

/// Cast throws an error if conversion had overflows
Expand Down
22 changes: 14 additions & 8 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,20 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E
let a = a.drop_nulls();
let b = b.drop_nulls();

let a_idx = a.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
});
let b_idx = b.rank(RankOptions {
method: RankMethod::Min,
..Default::default()
});
let a_idx = a.rank(
RankOptions {
method: RankMethod::Min,
..Default::default()
},
None,
);
let b_idx = b.rank(
RankOptions {
method: RankMethod::Min,
..Default::default()
},
None,
);
let a_idx = a_idx.idx().unwrap();
let b_idx = b_idx.idx().unwrap();

Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1620,9 +1620,9 @@ impl Expr {
}

#[cfg(feature = "rank")]
pub fn rank(self, options: RankOptions) -> Expr {
pub fn rank(self, options: RankOptions, seed: Option<u64>) -> Expr {
self.apply(
move |s| Ok(Some(s.rank(options))),
move |s| Ok(Some(s.rank(options, seed))),
GetOutput::map_field(move |fld| match options.method {
RankMethod::Average => Field::new(fld.name(), DataType::Float32),
_ => Field::new(fld.name(), IDX_DTYPE),
Expand Down
22 changes: 14 additions & 8 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1565,10 +1565,13 @@ fn test_groupby_rank() -> PolarsResult<()> {
let out = df
.lazy()
.groupby_stable([col("cars")])
.agg([col("B").rank(RankOptions {
method: RankMethod::Dense,
..Default::default()
})])
.agg([col("B").rank(
RankOptions {
method: RankMethod::Dense,
..Default::default()
},
None,
)])
.collect()?;

let out = out.column("B")?;
Expand Down Expand Up @@ -1659,10 +1662,13 @@ fn test_single_ranked_group() -> PolarsResult<()> {
let out = df
.lazy()
.with_columns([col("value")
.rank(RankOptions {
method: RankMethod::Average,
..Default::default()
})
.rank(
RankOptions {
method: RankMethod::Average,
..Default::default()
},
None,
)
.list()
.over([col("group")])])
.collect()?;
Expand Down
8 changes: 4 additions & 4 deletions polars/tests/it/lazy/expressions/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.lazy()
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(col("value").rank(Default::default()))
.then(col("value").rank(Default::default(), None))
.otherwise(lit(Series::new("", &[10 as IdxSize])))])
.sort("name", Default::default())
.collect()?;
Expand All @@ -312,7 +312,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(lit(Series::new("", &[10 as IdxSize])).alias("value"))
.otherwise(col("value").rank(Default::default()))])
.otherwise(col("value").rank(Default::default(), None))])
.sort("name", Default::default())
.collect()?;

Expand All @@ -331,7 +331,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.lazy()
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(col("value").rank(Default::default()))
.then(col("value").rank(Default::default(), None))
.otherwise(Null {}.lit())])
.sort("name", Default::default())
.collect()?;
Expand All @@ -346,7 +346,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> {
.groupby([col("name")])
.agg([when(col("value").sum().eq(lit(3)))
.then(Null {}.lit().alias("value"))
.otherwise(col("value").rank(Default::default()))])
.otherwise(col("value").rank(Default::default(), None))])
.sort("name", Default::default())
.collect()?;

Expand Down
11 changes: 9 additions & 2 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4855,7 +4855,12 @@ def argsort(self, descending: bool = False, nulls_last: bool = False) -> Self:
return self.arg_sort(descending, nulls_last)

@deprecated_alias(reverse="descending")
def rank(self, method: RankMethod = "average", descending: bool = False) -> Self:
def rank(
self,
method: RankMethod = "average",
descending: bool = False,
seed: int | None = None,
) -> Self:
"""
Assign ranks to data, dealing with ties appropriately.

Expand All @@ -4881,6 +4886,8 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Self
on the order that the values occur in the Series.
descending
Rank in descending order.
seed
If `method="random"`, use this as seed.

Examples
--------
Expand Down Expand Up @@ -4919,7 +4926,7 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Self
└─────┘

"""
return self._from_pyexpr(self._pyexpr.rank(method, descending))
return self._from_pyexpr(self._pyexpr.rank(method, descending, seed))

def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Self:
"""
Expand Down
15 changes: 13 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4816,7 +4816,12 @@ def abs(self) -> Series:

@deprecated_alias(reverse="descending")
@deprecate_nonkeyword_arguments(allowed_args=["self", "method"], stacklevel=3)
def rank(self, method: RankMethod = "average", descending: bool = False) -> Series:
def rank(
self,
method: RankMethod = "average",
descending: bool = False,
seed: int | None = None,
) -> Series:
"""
Assign ranks to data, dealing with ties appropriately.

Expand All @@ -4842,6 +4847,8 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Seri
on the order that the values occur in the Series.
descending
Rank in descending order.
seed
If `method="random"`, use this as seed.

Examples
--------
Expand Down Expand Up @@ -4876,7 +4883,11 @@ def rank(self, method: RankMethod = "average", descending: bool = False) -> Seri
"""
return (
self.to_frame()
.select(F.col(self._s.name()).rank(method=method, descending=descending))
.select(
F.col(self._s.name()).rank(
method=method, descending=descending, seed=seed
)
)
.to_series()
)

Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1628,12 +1628,12 @@ impl PyExpr {
.into())
}

fn rank(&self, method: Wrap<RankMethod>, descending: bool) -> Self {
fn rank(&self, method: Wrap<RankMethod>, descending: bool, seed: Option<u64>) -> Self {
let options = RankOptions {
method: method.0,
descending,
};
self.inner.clone().rank(options).into()
self.inner.clone().rank(options, seed).into()
}

fn diff(&self, n: usize, null_behavior: Wrap<NullBehavior>) -> Self {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,20 @@ def test_rank_so_4109() -> None:
}


def test_rank_random() -> None:
df = pl.from_dict(
{"a": [1] * 5, "b": [1, 2, 3, 4, 5], "c": [200, 100, 100, 50, 100]}
)

df_ranks1 = df.with_columns(
pl.col("c").rank(method="random", seed=1).over("a").alias("rank")
)
df_ranks2 = df.with_columns(
pl.col("c").rank(method="random", seed=1).over("a").alias("rank")
)
assert_frame_equal(df_ranks1, df_ranks2)


def test_unique_empty() -> None:
for dt in [pl.Utf8, pl.Boolean, pl.Int32, pl.UInt32]:
s = pl.Series([], dtype=dt)
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,13 @@ def test_rank() -> None:
)


def test_rank_random() -> None:
s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
assert_series_equal(
s.rank("random", seed=1), pl.Series("a", [2, 4, 7, 3, 5, 6, 1], dtype=UInt32)
)


def test_diff() -> None:
s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
expected = pl.Series("a", [1, 1, -1, 0, 1, -3])
Expand Down