Skip to content

Commit

Permalink
Average Baskets per period
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiekt committed Jan 13, 2023
1 parent e5b03be commit c155abf
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 34 deletions.
2 changes: 1 addition & 1 deletion jstark/feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class FeatureGenerator(metaclass=ABCMeta):
def __init__(
self,
as_at: date,
feature_periods: List[Union[FeaturePeriod, str]] = [
feature_periods: Union[List[FeaturePeriod], List[str]] = [
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 52, 0),
],
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion jstark/purchasing_feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class PurchasingFeatureGenerator(FeatureGenerator):
def __init__(
self,
as_at: date,
feature_periods: List[Union[FeaturePeriod, str]] = [
feature_periods: Union[List[FeaturePeriod], List[str]] = [
FeaturePeriod(PeriodUnitOfMeasure.DAY, 2, 0),
FeaturePeriod(PeriodUnitOfMeasure.DAY, 4, 3),
],
Expand Down
61 changes: 29 additions & 32 deletions tests/test_purchasing_feature_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import date, datetime, timedelta
import pytest
from math import pow
from typing import List
from pyspark.sql import DataFrame, Row
import pyspark.sql.functions as f

Expand Down Expand Up @@ -428,23 +429,7 @@ def test_recencyweightedbasketweeks_luke_and_leia(
"""
fg = PurchasingFeatureGenerator(
as_at=as_at_timestamp.date(),
feature_periods=[
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 13, 0),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 0, 0),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 1, 1),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 2, 2),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 3, 3),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 4, 4),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 5, 5),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 6, 6),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 7, 7),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 8, 8),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 9, 9),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 10, 10),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 11, 11),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 12, 12),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 13, 13),
],
feature_periods=feature_periods_for_perweek_features(),
)
df = luke_and_leia_purchases.groupBy().agg(*fg.features)
df2 = df.select(
Expand Down Expand Up @@ -493,25 +478,33 @@ def test_recencyweightedbasketweeks_luke_and_leia(
)


def feature_periods_for_perweek_features() -> List[FeaturePeriod]:
    """Build the feature periods shared by the per-week feature tests.

    The returned list holds, in order, ``FeaturePeriod(WEEK, 13, 0)``
    followed by fourteen periods ``FeaturePeriod(WEEK, n, n)`` for
    ``n`` running from 0 to 13 inclusive.
    """
    week = PeriodUnitOfMeasure.WEEK
    # The 13/0 period comes first, then one n/n period per value of n.
    periods = [FeaturePeriod(week, 13, 0)]
    periods.extend(FeaturePeriod(week, n, n) for n in range(14))
    return periods


def test_averagebasketsperweek_luke_and_leia(
as_at_timestamp: datetime, luke_and_leia_purchases: DataFrame
):
"""Test AverageBasketsPerWeek"""
fg = PurchasingFeatureGenerator(
as_at=as_at_timestamp.date(),
feature_periods=[
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 9, 0),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 0, 0),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 1, 1),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 2, 2),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 3, 3),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 4, 4),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 5, 5),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 6, 6),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 7, 7),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 8, 8),
FeaturePeriod(PeriodUnitOfMeasure.WEEK, 9, 9),
],
feature_periods=feature_periods_for_perweek_features(),
)
df = luke_and_leia_purchases.groupBy().agg(*fg.features)
df2 = df.select(
Expand All @@ -525,12 +518,16 @@ def test_averagebasketsperweek_luke_and_leia(
"BasketCount_7w7",
"BasketCount_8w8",
"BasketCount_9w9",
"BasketCount_10w10",
"BasketCount_11w11",
"BasketCount_12w12",
"BasketCount_13w13",
)
df_first = df.first()
df2_first = df2.first()
assert df_first is not None
assert df2_first is not None
n = 10
n = 14
total_baskets = sum(df2_first[f"BasketCount_{i}w{i}"] for i in range(n))
average_baskets_per_week = total_baskets / n
assert average_baskets_per_week == df_first["AverageBasketsPerWeek_9w0"]
assert average_baskets_per_week == df_first["AverageBasketsPerWeek_13w0"]

0 comments on commit c155abf

Please sign in to comment.