-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #176 from praekeltfoundation/randomisation-stratif…
…ication Initial randomisation app commit
- Loading branch information
Showing
15 changed files
with
403 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
"rp_transferto", | ||
"rp_recruit", | ||
"rp_interceptors", | ||
"randomisation", | ||
] | ||
|
||
MIDDLEWARE = [ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from django.contrib import admin | ||
|
||
from randomisation.models import Arm, Strata, StrataOption, Strategy | ||
|
||
|
||
class ArmInline(admin.TabularInline): | ||
model = Arm | ||
|
||
|
||
class StrataOptionInline(admin.TabularInline): | ||
model = StrataOption | ||
|
||
|
||
@admin.register(Strategy) | ||
class StrategyAdmin(admin.ModelAdmin): | ||
list_display = ("name",) | ||
|
||
inlines = [ArmInline] | ||
|
||
|
||
@admin.register(Strata) | ||
class StrataAdmin(admin.ModelAdmin): | ||
list_display = ("__str__",) | ||
|
||
inlines = [StrataOptionInline] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from django.apps import AppConfig | ||
|
||
|
||
class RandomisationConfig(AppConfig): | ||
default_auto_field = "django.db.models.BigAutoField" | ||
name = "randomisation" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Generated by Django 4.2.11 on 2024-04-17 09:03 | ||
|
||
import django.db.models.deletion | ||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
initial = True | ||
|
||
dependencies = [] | ||
|
||
operations = [ | ||
migrations.CreateModel( | ||
name="Strata", | ||
fields=[ | ||
( | ||
"id", | ||
models.BigAutoField( | ||
auto_created=True, | ||
primary_key=True, | ||
serialize=False, | ||
verbose_name="ID", | ||
), | ||
), | ||
("name", models.CharField(max_length=200)), | ||
], | ||
), | ||
migrations.CreateModel( | ||
name="Strategy", | ||
fields=[ | ||
( | ||
"id", | ||
models.BigAutoField( | ||
auto_created=True, | ||
primary_key=True, | ||
serialize=False, | ||
verbose_name="ID", | ||
), | ||
), | ||
("name", models.CharField(max_length=200)), | ||
( | ||
"stratas", | ||
models.ManyToManyField( | ||
related_name="stategy_stratas", to="randomisation.strata" | ||
), | ||
), | ||
], | ||
options={ | ||
"verbose_name_plural": "Strategies", | ||
}, | ||
), | ||
migrations.CreateModel( | ||
name="StrataOption", | ||
fields=[ | ||
( | ||
"id", | ||
models.BigAutoField( | ||
auto_created=True, | ||
primary_key=True, | ||
serialize=False, | ||
verbose_name="ID", | ||
), | ||
), | ||
("description", models.CharField(max_length=200)), | ||
( | ||
"strata", | ||
models.ForeignKey( | ||
on_delete=django.db.models.deletion.CASCADE, | ||
related_name="options", | ||
to="randomisation.strata", | ||
), | ||
), | ||
], | ||
), | ||
migrations.CreateModel( | ||
name="StrataMatrix", | ||
fields=[ | ||
( | ||
"id", | ||
models.BigAutoField( | ||
auto_created=True, | ||
primary_key=True, | ||
serialize=False, | ||
verbose_name="ID", | ||
), | ||
), | ||
("strata_data", models.JSONField()), | ||
("next_index", models.IntegerField(default=0)), | ||
("arm_order", models.CharField(max_length=255)), | ||
( | ||
"strategy", | ||
models.ForeignKey( | ||
on_delete=django.db.models.deletion.CASCADE, | ||
related_name="matrix_records", | ||
to="randomisation.strategy", | ||
), | ||
), | ||
], | ||
), | ||
migrations.CreateModel( | ||
name="Arm", | ||
fields=[ | ||
( | ||
"id", | ||
models.BigAutoField( | ||
auto_created=True, | ||
primary_key=True, | ||
serialize=False, | ||
verbose_name="ID", | ||
), | ||
), | ||
("name", models.CharField(max_length=200)), | ||
( | ||
"strategy", | ||
models.ForeignKey( | ||
on_delete=django.db.models.deletion.CASCADE, | ||
related_name="arms", | ||
to="randomisation.strategy", | ||
), | ||
), | ||
], | ||
), | ||
] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from django.db import models | ||
from django.db.models import JSONField | ||
from django.utils.text import slugify | ||
|
||
|
||
class Strata(models.Model): | ||
name = models.CharField(max_length=200, null=False, blank=False) | ||
|
||
@property | ||
def slug(self): | ||
return slugify(self.name) | ||
|
||
def __str__(self): | ||
options = [option.description for option in self.options.all()] | ||
return f"{self.name} - [{', '.join(options)}]" | ||
|
||
|
||
class Strategy(models.Model): | ||
class Meta: | ||
verbose_name_plural = "Strategies" | ||
|
||
name = models.CharField(max_length=200, null=False, blank=False) | ||
stratas = models.ManyToManyField(Strata, related_name="stategy_stratas") | ||
|
||
|
||
class Arm(models.Model): | ||
strategy = models.ForeignKey( | ||
Strategy, | ||
related_name="arms", | ||
null=False, | ||
on_delete=models.CASCADE, | ||
) | ||
name = models.CharField(max_length=200, null=False, blank=False) | ||
|
||
|
||
class StrataOption(models.Model): | ||
strata = models.ForeignKey( | ||
Strata, | ||
related_name="options", | ||
null=False, | ||
on_delete=models.CASCADE, | ||
) | ||
description = models.CharField(max_length=200, null=False, blank=False) | ||
|
||
|
||
class StrataMatrix(models.Model): | ||
strategy = models.ForeignKey( | ||
Strategy, | ||
related_name="matrix_records", | ||
null=False, | ||
on_delete=models.CASCADE, | ||
) | ||
strata_data = JSONField() | ||
next_index = models.IntegerField(default=0) | ||
arm_order = models.CharField(max_length=255, null=False, blank=False) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import random | ||
from collections import defaultdict | ||
|
||
from django.test import TestCase | ||
|
||
from randomisation.utils import ( | ||
get_random_stratification_arm, | ||
validate_stratification_data, | ||
) | ||
|
||
from .utils import create_test_strategy | ||
|
||
|
||
# TODO: add docstrings to tests | ||
class TestValidateStratificationData(TestCase): | ||
def setUp(self): | ||
self.strategy = create_test_strategy() | ||
|
||
def test_stratification_validation_valid_data(self): | ||
error = validate_stratification_data( | ||
self.strategy, {"age-group": "18-29", "province": "WC"} | ||
) | ||
self.assertIsNone(error) | ||
|
||
def test_stratification_validation_missing_key(self): | ||
error = validate_stratification_data(self.strategy, {"age-group": "18-29"}) | ||
self.assertEqual(error, "'province' is a required property") | ||
|
||
def test_stratification_validation_extra_key(self): | ||
error = validate_stratification_data( | ||
self.strategy, {"age-group": "18-29", "province": "WC", "extra": "key"} | ||
) | ||
|
||
self.assertEqual( | ||
error, "Additional properties are not allowed ('extra' was unexpected)" | ||
) | ||
|
||
def test_stratification_validation_invalid_option(self): | ||
error = validate_stratification_data( | ||
self.strategy, {"age-group": "18-29", "province": "FS"} | ||
) | ||
self.assertEqual(error, "'FS' is not one of ['WC', 'GT']") | ||
|
||
|
||
class TestGetRandomStratification(TestCase): | ||
|
||
# TODO: add more tests for randomisation | ||
|
||
def test_stratification_balancing(self): | ||
strategy = create_test_strategy() | ||
|
||
totals = defaultdict(int) | ||
stratas = defaultdict(lambda: defaultdict(int)) | ||
for i in range(100): | ||
random_age = random.choice(["18-29", "29-39"]) | ||
random_province = random.choice(["WC", "GT"]) | ||
|
||
data = {"age-group": random_age, "province": random_province} | ||
|
||
random_arm = get_random_stratification_arm(strategy, data) | ||
stratas[f"{random_age}_{random_province}"][random_arm] += 1 | ||
totals[random_arm] += 1 | ||
|
||
def check_arms_balanced(arms, diff, description): | ||
values = [value for value in arms.values()] | ||
msg = f"Arms not balanced: {description} - {values}" | ||
assert max(values) - diff < values[0] < min(values) + diff, msg | ||
|
||
check_arms_balanced(totals, 3, "Totals") | ||
|
||
for key, arms in stratas.items(): | ||
check_arms_balanced(arms, 3, key) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from randomisation.models import Arm, Strata, StrataOption, Strategy | ||
|
||
DEFAULT_STRATEGY = { | ||
"name": "Test Strategy", | ||
"arms": ["Arm 1", "Arm 2", "Arm 3"], | ||
"stratas": [ | ||
{"name": "Age Group", "options": ["18-29", "29-39"]}, | ||
{"name": "Province", "options": ["WC", "GT"]}, | ||
], | ||
} | ||
|
||
|
||
def create_test_strategy(data=DEFAULT_STRATEGY): | ||
strategy = Strategy.objects.create(name=data["name"]) | ||
|
||
for arm in data["arms"]: | ||
Arm.objects.create(strategy=strategy, name=arm) | ||
|
||
for strata_data in data["stratas"]: | ||
strata = Strata.objects.create(name=strata_data["name"]) | ||
|
||
for option in strata_data["options"]: | ||
StrataOption.objects.create(strata=strata, description=option) | ||
|
||
strategy.stratas.add(strata) | ||
|
||
return strategy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from django.urls import path | ||
|
||
from . import views | ||
|
||
urlpatterns = [ | ||
path( | ||
"<int:strategy_id>/get_random_arm/", | ||
views.GetRandomArmView.as_view(), | ||
name="get_random_arm", | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import random | ||
|
||
from jsonschema import validate | ||
from jsonschema.exceptions import ValidationError | ||
|
||
from randomisation.models import StrataMatrix | ||
|
||
|
||
def validate_stratification_data(strategy, data): | ||
try: | ||
schema = { | ||
"type": "object", | ||
"properties": {}, | ||
"required": [strata.slug for strata in strategy.stratas.all()], | ||
"additionalProperties": False, | ||
} | ||
|
||
for strata in strategy.stratas.all(): | ||
options = [option.description for option in strata.options.all()] | ||
schema["properties"][strata.slug] = {"type": "string", "enum": options} | ||
|
||
validate(instance=data, schema=schema) | ||
except ValidationError as e: | ||
return e.message | ||
|
||
|
||
def get_random_stratification_arm(strategy, data): | ||
matrix, created = StrataMatrix.objects.get_or_create( | ||
strategy=strategy, strata_data=data | ||
) | ||
|
||
if created: | ||
study_arms = [arm.name for arm in strategy.arms.all()] | ||
random.shuffle(study_arms) | ||
random_arms = study_arms | ||
matrix.arm_order = ",".join(study_arms) | ||
else: | ||
random_arms = matrix.arm_order.split(",") | ||
|
||
arm = random_arms[matrix.next_index] | ||
|
||
if matrix.next_index + 1 == len(random_arms): | ||
matrix.delete() | ||
else: | ||
matrix.next_index += 1 | ||
matrix.save() | ||
|
||
return arm |
Oops, something went wrong.