Skip to content

Commit

Permalink
Merge pull request #176 from praekeltfoundation/randomisation-stratif…
Browse files Browse the repository at this point in the history
…ication

Initial randomisation app commit
  • Loading branch information
erikh360 authored Apr 22, 2024
2 parents c073f6b + dee1d2a commit 7c39594
Show file tree
Hide file tree
Showing 15 changed files with 403 additions and 0 deletions.
1 change: 1 addition & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"rp_transferto",
"rp_recruit",
"rp_interceptors",
"randomisation",
]

MIDDLEWARE = [
Expand Down
1 change: 1 addition & 0 deletions config/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
path("recruit/", include("rp_recruit.urls"), name="rp_recruit"),
path("interceptor/", include("rp_interceptors.urls")),
path("dtone/", include("rp_dtone.urls")),
path("randomisation/", include("randomisation.urls")),
]
Empty file added randomisation/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions randomisation/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from django.contrib import admin

from randomisation.models import Arm, Strata, StrataOption, Strategy


class ArmInline(admin.TabularInline):
model = Arm


class StrataOptionInline(admin.TabularInline):
model = StrataOption


@admin.register(Strategy)
class StrategyAdmin(admin.ModelAdmin):
list_display = ("name",)

inlines = [ArmInline]


@admin.register(Strata)
class StrataAdmin(admin.ModelAdmin):
list_display = ("__str__",)

inlines = [StrataOptionInline]
6 changes: 6 additions & 0 deletions randomisation/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class RandomisationConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "randomisation"
124 changes: 124 additions & 0 deletions randomisation/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Generated by Django 4.2.11 on 2024-04-17 09:03

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

initial = True

dependencies = []

operations = [
migrations.CreateModel(
name="Strata",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=200)),
],
),
migrations.CreateModel(
name="Strategy",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=200)),
(
"stratas",
models.ManyToManyField(
related_name="stategy_stratas", to="randomisation.strata"
),
),
],
options={
"verbose_name_plural": "Strategies",
},
),
migrations.CreateModel(
name="StrataOption",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("description", models.CharField(max_length=200)),
(
"strata",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="options",
to="randomisation.strata",
),
),
],
),
migrations.CreateModel(
name="StrataMatrix",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("strata_data", models.JSONField()),
("next_index", models.IntegerField(default=0)),
("arm_order", models.CharField(max_length=255)),
(
"strategy",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="matrix_records",
to="randomisation.strategy",
),
),
],
),
migrations.CreateModel(
name="Arm",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=200)),
(
"strategy",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="arms",
to="randomisation.strategy",
),
),
],
),
]
Empty file.
55 changes: 55 additions & 0 deletions randomisation/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from django.db import models
from django.db.models import JSONField
from django.utils.text import slugify


class Strata(models.Model):
name = models.CharField(max_length=200, null=False, blank=False)

@property
def slug(self):
return slugify(self.name)

def __str__(self):
options = [option.description for option in self.options.all()]
return f"{self.name} - [{', '.join(options)}]"


class Strategy(models.Model):
class Meta:
verbose_name_plural = "Strategies"

name = models.CharField(max_length=200, null=False, blank=False)
stratas = models.ManyToManyField(Strata, related_name="stategy_stratas")


class Arm(models.Model):
strategy = models.ForeignKey(
Strategy,
related_name="arms",
null=False,
on_delete=models.CASCADE,
)
name = models.CharField(max_length=200, null=False, blank=False)


class StrataOption(models.Model):
strata = models.ForeignKey(
Strata,
related_name="options",
null=False,
on_delete=models.CASCADE,
)
description = models.CharField(max_length=200, null=False, blank=False)


class StrataMatrix(models.Model):
strategy = models.ForeignKey(
Strategy,
related_name="matrix_records",
null=False,
on_delete=models.CASCADE,
)
strata_data = JSONField()
next_index = models.IntegerField(default=0)
arm_order = models.CharField(max_length=255, null=False, blank=False)
Empty file added randomisation/tests/__init__.py
Empty file.
72 changes: 72 additions & 0 deletions randomisation/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import random
from collections import defaultdict

from django.test import TestCase

from randomisation.utils import (
get_random_stratification_arm,
validate_stratification_data,
)

from .utils import create_test_strategy


# TODO: add docstrings to tests
class TestValidateStratificationData(TestCase):
def setUp(self):
self.strategy = create_test_strategy()

def test_stratification_validation_valid_data(self):
error = validate_stratification_data(
self.strategy, {"age-group": "18-29", "province": "WC"}
)
self.assertIsNone(error)

def test_stratification_validation_missing_key(self):
error = validate_stratification_data(self.strategy, {"age-group": "18-29"})
self.assertEqual(error, "'province' is a required property")

def test_stratification_validation_extra_key(self):
error = validate_stratification_data(
self.strategy, {"age-group": "18-29", "province": "WC", "extra": "key"}
)

self.assertEqual(
error, "Additional properties are not allowed ('extra' was unexpected)"
)

def test_stratification_validation_invalid_option(self):
error = validate_stratification_data(
self.strategy, {"age-group": "18-29", "province": "FS"}
)
self.assertEqual(error, "'FS' is not one of ['WC', 'GT']")


class TestGetRandomStratification(TestCase):

# TODO: add more tests for randomisation

def test_stratification_balancing(self):
strategy = create_test_strategy()

totals = defaultdict(int)
stratas = defaultdict(lambda: defaultdict(int))
for i in range(100):
random_age = random.choice(["18-29", "29-39"])
random_province = random.choice(["WC", "GT"])

data = {"age-group": random_age, "province": random_province}

random_arm = get_random_stratification_arm(strategy, data)
stratas[f"{random_age}_{random_province}"][random_arm] += 1
totals[random_arm] += 1

def check_arms_balanced(arms, diff, description):
values = [value for value in arms.values()]
msg = f"Arms not balanced: {description} - {values}"
assert max(values) - diff < values[0] < min(values) + diff, msg

check_arms_balanced(totals, 3, "Totals")

for key, arms in stratas.items():
check_arms_balanced(arms, 3, key)
27 changes: 27 additions & 0 deletions randomisation/tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from randomisation.models import Arm, Strata, StrataOption, Strategy

DEFAULT_STRATEGY = {
"name": "Test Strategy",
"arms": ["Arm 1", "Arm 2", "Arm 3"],
"stratas": [
{"name": "Age Group", "options": ["18-29", "29-39"]},
{"name": "Province", "options": ["WC", "GT"]},
],
}


def create_test_strategy(data=DEFAULT_STRATEGY):
strategy = Strategy.objects.create(name=data["name"])

for arm in data["arms"]:
Arm.objects.create(strategy=strategy, name=arm)

for strata_data in data["stratas"]:
strata = Strata.objects.create(name=strata_data["name"])

for option in strata_data["options"]:
StrataOption.objects.create(strata=strata, description=option)

strategy.stratas.add(strata)

return strategy
11 changes: 11 additions & 0 deletions randomisation/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from django.urls import path

from . import views

urlpatterns = [
path(
"<int:strategy_id>/get_random_arm/",
views.GetRandomArmView.as_view(),
name="get_random_arm",
),
]
48 changes: 48 additions & 0 deletions randomisation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import random

from jsonschema import validate
from jsonschema.exceptions import ValidationError

from randomisation.models import StrataMatrix


def validate_stratification_data(strategy, data):
try:
schema = {
"type": "object",
"properties": {},
"required": [strata.slug for strata in strategy.stratas.all()],
"additionalProperties": False,
}

for strata in strategy.stratas.all():
options = [option.description for option in strata.options.all()]
schema["properties"][strata.slug] = {"type": "string", "enum": options}

validate(instance=data, schema=schema)
except ValidationError as e:
return e.message


def get_random_stratification_arm(strategy, data):
matrix, created = StrataMatrix.objects.get_or_create(
strategy=strategy, strata_data=data
)

if created:
study_arms = [arm.name for arm in strategy.arms.all()]
random.shuffle(study_arms)
random_arms = study_arms
matrix.arm_order = ",".join(study_arms)
else:
random_arms = matrix.arm_order.split(",")

arm = random_arms[matrix.next_index]

if matrix.next_index + 1 == len(random_arms):
matrix.delete()
else:
matrix.next_index += 1
matrix.save()

return arm
Loading

0 comments on commit 7c39594

Please sign in to comment.