Skip to content

Commit

Permalink
Use sgkit.distarray for count_variant_alleles and variant_stats
Browse files Browse the repository at this point in the history
Get count_hom working with explicit sum reduction over samples
  • Loading branch information
tomwhite committed Sep 11, 2024
1 parent 1591cdf commit ba39078
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cubed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ jobs:
- name: Test with pytest
run: |
pytest -v sgkit/tests/test_aggregation.py -k "test_count_call_alleles" --use-cubed
pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
27 changes: 16 additions & 11 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,14 @@ def count_variant_alleles(
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
n_alleles = ds.sizes["alleles"]
n_variant = ds.sizes["variants"]
G = da.asarray(ds[call_genotype]).reshape((n_variant, -1))
G = da.asarray(ds[call_genotype])
G = da.reshape(G, (n_variant, -1))
shape = (G.chunks[0], n_alleles)
# use uint64 dummy array to return uin64 counts array
N = np.empty(n_alleles, dtype=np.uint64)
AC = da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=1, new_axis=1)
AC = da.map_blocks(
count_alleles, G, N, chunks=shape, dtype=np.uint64, drop_axis=1, new_axis=1
)
AC = xr.DataArray(AC, dims=["variants", "alleles"])
else:
options = {variables.call_genotype, variables.call_allele_count}
Expand Down Expand Up @@ -692,22 +695,23 @@ def variant_stats(
using=variables.call_genotype, # improved performance
merge=False,
)[variant_allele_count]
G = da.array(ds[call_genotype].data)
G = da.asarray(ds[call_genotype].data)
H = xr.DataArray(
da.map_blocks(
count_hom,
lambda *args: count_hom(*args)[:, np.newaxis, :],
G,
np.zeros(3, np.uint64),
drop_axis=(1, 2),
new_axis=1,
drop_axis=2,
new_axis=2,
dtype=np.int64,
chunks=(G.chunks[0], 3),
chunks=(G.chunks[0], 1, 3),
),
dims=["variants", "categories"],
dims=["variants", "samples", "categories"],
)
H = H.sum(axis=1)
_, n_sample, _ = G.shape
n_called = H.sum(axis=-1)
call_rate = n_called / n_sample
call_rate = n_called.astype(float) / float(n_sample)
n_hom_ref = H[:, 0]
n_hom_alt = H[:, 1]
n_het = H[:, 2]
Expand All @@ -723,7 +727,8 @@ def variant_stats(
variables.variant_n_non_ref: n_non_ref,
variables.variant_allele_count: AC,
variables.variant_allele_total: allele_total,
variables.variant_allele_frequency: AC / allele_total,
variables.variant_allele_frequency: AC.astype(float)
/ allele_total.astype(float),
}
)
# for backwards compatible behavior
Expand Down Expand Up @@ -798,7 +803,7 @@ def sample_stats(
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
if mixed_ploidy:
raise ValueError("Mixed-ploidy dataset")
G = da.array(ds[call_genotype].data)
G = da.asarray(ds[call_genotype].data)
H = xr.DataArray(
da.map_blocks(
count_hom,
Expand Down
11 changes: 6 additions & 5 deletions sgkit/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_count_variant_alleles__chunked(using):
chunks={"variants": 5, "samples": 5}
)
ac2 = count_variant_alleles(ds, using=using)
assert isinstance(ac2["variant_allele_count"].data, da.Array)
assert hasattr(ac2["variant_allele_count"].data, "chunks")
xr.testing.assert_equal(ac1, ac2)


Expand Down Expand Up @@ -786,13 +786,14 @@ def test_variant_stats__tetraploid():
)


@pytest.mark.parametrize(
"chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1), (100, 10, 1)]
)
def test_variant_stats__chunks(chunks):
@pytest.mark.parametrize("precompute_variant_allele_count", [False, True])
@pytest.mark.parametrize("chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1)])
def test_variant_stats__chunks(precompute_variant_allele_count, chunks):
ds = simulate_genotype_call_dataset(
n_variant=1000, n_sample=30, missing_pct=0.01, seed=0
)
if precompute_variant_allele_count:
ds = count_variant_alleles(ds)
expect = variant_stats(ds, merge=False).compute()
ds["call_genotype"] = ds["call_genotype"].chunk(chunks)
actual = variant_stats(ds, merge=False).compute()
Expand Down

0 comments on commit ba39078

Please sign in to comment.