Skip to content

Commit

Permalink
- Add rand_range (#269)
Browse files Browse the repository at this point in the history
- Add rand_range + range_laplace test
  • Loading branch information
zeotuan authored Oct 10, 2024
1 parent 7c1332a commit f538672
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
4 changes: 4 additions & 0 deletions quinn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
week_end_date,
week_start_date,
)
from quinn.math import (
rand_laplace,
rand_range,
)
from quinn.schema_helpers import print_schema_as_code
from quinn.split_columns import split_col
from quinn.transformations import (
Expand Down
23 changes: 23 additions & 0 deletions quinn/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,29 @@ def rand_laplace(
)


def rand_range(
minimum: Union[int, Column],
maximum: Union[int, Column],
seed: Optional[int] = None,
) -> Column:
"""Generate random numbers uniformly distributed in [`minimum`, `maximum`).
:param minimum: minimum value of the random numbers
:param maximum: maximum value of the random numbers
:param seed: random seed value (optional, default None)
:returns: column with random numbers
"""
if not isinstance(minimum, Column):
minimum = F.lit(minimum)

if not isinstance(maximum, Column):
maximum = F.lit(maximum)

u = F.rand(seed)

return minimum + (maximum - minimum) * u


def div_or_else(
cola: Column,
colb: Column,
Expand Down
40 changes: 40 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pyspark.sql.functions as F

import quinn
import math
from .spark import spark


def test_rand_laplace():
stats = (
spark.range(100000)
.select(quinn.rand_laplace(0.0, 1.0, 42))
.agg(
F.mean("laplace_random").alias("mean"),
F.stddev("laplace_random").alias("std_dev"),
)
.first()
)

laplace_mean = stats["mean"]
laplace_stddev = stats["std_dev"]

# Laplace distribution with mean=0.0 and scale=1.0 has mean=0.0 and stddev=sqrt(2.0)
assert abs(laplace_mean) <= 0.1
assert abs(laplace_stddev - math.sqrt(2.0)) < 0.5


def test_rand_range():
lower_bound = 5
upper_bound = 10
stats = (
spark.range(1000)
.select(quinn.rand_range(lower_bound, upper_bound).alias("rand_uniform"))
.agg(F.min("rand_uniform").alias("min"), F.min("rand_uniform").alias("max"))
.first()
)

uniform_min = stats["min"]
uniform_max = stats["max"]

assert lower_bound <= uniform_min <= uniform_max <= upper_bound

0 comments on commit f538672

Please sign in to comment.