Skip to content

Commit

Permalink
Merge pull request #345 from jaamarks/issue324-revisit-exp-rep-retain…
Browse files Browse the repository at this point in the history
…ment

bugfix: cleaner data handling and improved logic for `sample_qc_table.py` (issue #324)
  • Loading branch information
jaamarks authored Oct 9, 2024
2 parents 0df8110 + df985b5 commit c618175
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 63 deletions.
87 changes: 40 additions & 47 deletions src/cgr_gwas_qc/workflow/scripts/sample_qc_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def _read_contam(file_name: Optional[Path], Sample_IDs: pd.Index) -> pd.DataFram
return pd.DataFrame(
index=Sample_IDs,
columns=["Contamination_Rate", "is_contaminated"],
).astype({"Contamination_Rate": pd.NA, "is_contaminated": pd.NA})
).astype({"Contamination_Rate": "float", "is_contaminated": "boolean"})

return (
agg_contamination.read(file_name)
Expand Down Expand Up @@ -464,12 +464,12 @@ def add_qc_columns(
remove_rep_discordant: bool,
) -> pd.DataFrame:
add_call_rate_flags(sample_qc)
_add_identifiler(sample_qc)
_add_analytic_exclusion(
sample_qc,
remove_contam,
remove_rep_discordant,
)
_add_identifiler(sample_qc)
_add_subject_representative(sample_qc)
_add_subject_dropped_from_study(sample_qc)

Expand Down Expand Up @@ -516,64 +516,53 @@ def reason_string(row: pd.Series) -> str:
def _retain_valid_discordant_replicates(
sample_qc: pd.DataFrame,
) -> pd.DataFrame:
"""Check and update the status of a pair of samples labeled as
discordant expected replicates.
This function verifies if the provided sample pair is labeled as
discordant expected replicates.
If they are, it checks for contamination or low call rate flags on
each sample.
If one of the samples is found to be contaminated or has a low call
rate, the function retains the non-contaminated and non-low-call-rate
sample, updating its status to remove the expected replicate label.
"""Updates the status of discordant expected replicates.
Given a pair of discordant expected replicates, it checks for contamination
or low call rate flags on each sample. If one sample is found to be
contaminated or has a low call rate, the function updates the
"is_discordant_replicate" status of the non-contaminated and
non-low-call-rate sample to "False". This way, it can be retained for
subject-level analysis.
"""

# Assuming sample_qc is your DataFrame
if "Sample_ID" in sample_qc.columns:
sample_qc = sample_qc.set_index("Sample_ID")

# Iterate through each sample in the DataFrame
for index, row in sample_qc.iterrows():
# Check if the sample is a discordant expected replicate
if row["is_discordant_replicate"]:
# Get the list of other samples it is discordant with
discordant_samples = row["replicate_ids"].split("|")

# Initialize flags for contamination and call rate issues
is_current_sample_low_call_rate = row["is_cr2_filtered"]
is_current_sample_low_call_rate = row["is_call_rate_filtered"]
is_current_sample_contaminated = (
row["is_contaminated"] if not pd.isna(row["is_contaminated"]) else False
# Treat pd.NA as not contaminated
row["is_contaminated"]
if not pd.isna(row["is_contaminated"])
else False
)

# Initialize a flag to track if all discordant samples have issues
# flag to track if all discordant samples have issues
all_other_samples_issue = True
discordant_samples = row["replicate_ids"].split("|")

# Check each discordant sample
for sample_id in discordant_samples:
if sample_id == row.name:
continue
if sample_id == row["Sample_ID"]:
continue # only look at other samples
else:
# Get the row for the discordant sample
discordant_row = sample_qc.loc[sample_id]
if not discordant_row.empty:
contaminated = (
discordant_row["is_contaminated"]
if not pd.isna(discordant_row["is_contaminated"])
else False
)
low_call_rate = discordant_row["is_cr2_filtered"]

# Check if the discordant sample is contaminated or has a low call rate
if not contaminated and not low_call_rate:
all_other_samples_issue = False
break

# If the current sample is not contaminated or low call rate
discordant_row = sample_qc[sample_qc["Sample_ID"] == sample_id].iloc[0]
low_call_rate = discordant_row["is_call_rate_filtered"]
contaminated = (
discordant_row["is_contaminated"]
if not pd.isna(discordant_row["is_contaminated"])
else False
)

# Check if the discordant sample is contaminated and/or has a low call rate
if not contaminated and not low_call_rate:
all_other_samples_issue = False
break

# Retain the current sample if not contaminated nor low call rate...
if not is_current_sample_contaminated and not is_current_sample_low_call_rate:
# If all other samples have issues, update the current sample's status
# and the other samples to have issues
if all_other_samples_issue:
sample_qc.at[index, "is_discordant_replicate"] = False

return sample_qc


Expand All @@ -596,22 +585,26 @@ def _add_analytic_exclusion(
"is_cr2_filtered": "Call Rate 2 Filtered",
}

_retain_valid_discordant_replicates(sample_qc)
sample_qc = _retain_valid_discordant_replicates(sample_qc)

if remove_contam:
exclusion_criteria["is_contaminated"] = "Contamination"

if remove_rep_discordant:
exclusion_criteria["is_discordant_replicate"] = "Replicate Discordance"

# adding this new column which is a boolean. Checks for any T in a series (e.g., {F, F, F, F, T}.any())
sample_qc["analytic_exclusion"] = sample_qc.reindex(exclusion_criteria.keys(), axis=1).any(
axis=1
)

# looking across the colums and adding up the Trues
sample_qc["num_analytic_exclusion"] = (
sample_qc.reindex(exclusion_criteria.keys(), axis=1).sum(axis=1).astype(int)
)
sample_qc["analytic_exclusion_reason"] = _get_reason(sample_qc, exclusion_criteria)

# get the names of the columns that are True and return a "|" delimited string of dict values whose keys were true.
sample_qc["analytic_exclusion_reason"] = _get_reason(sample_qc, exclusion_criteria)
return sample_qc


Expand Down
131 changes: 115 additions & 16 deletions tests/workflow/scripts/test_sample_qc_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,22 +254,95 @@ def fake_sample_qc() -> pd.DataFrame:
"Call_Rate_2",
"is_cr1_filtered",
"is_cr2_filtered",
"is_call_rate_filtered",
"is_contaminated",
"is_discordant_replicate",
]

data = [
("SP00001", "SB00001", "", False, False, 0.99, False, False, False, False),
("SP00002", "SB00002", "", False, False, 0.82, False, True, False, False),
("SP00003", "SB00003", "SP00003|SP00004", False, False, 0.99, False, False, True, True),
("SP00004", "SB00003", "SP00003|SP00004", False, False, 0.99, False, False, False, True),
("SP00005", "SB00004", "SP00005|SP00006", False, False, 0.99, False, False, False, True),
("SP00006", "SB00004", "SP00005|SP00006", False, False, 0.99, False, False, False, True),
("SP00007", "SB00005", "", False, False, 0.99, False, False, False, False),
("SP00008", "SB00006", "", False, False, 0.99, False, False, False, False),
("SP00009", "SB00007", "", False, False, 0.99, False, False, False, False),
("SP00010", "SB00008", "SP00010|SP00011", False, False, 0.99, False, False, False, False),
("SP00011", "SB00008", "SP00010|SP00011", False, False, 0.94, False, False, False, False),
("SP00001", "SB00001", "", False, False, 0.99, False, False, False, False, False),
("SP00002", "SB00002", "", False, False, 0.82, False, True, True, False, False),
(
"SP00003",
"SB00003",
"SP00003|SP00004",
False,
False,
0.99,
False,
False,
False,
True,
True,
),
(
"SP00004",
"SB00003",
"SP00003|SP00004",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
(
"SP00005",
"SB00004",
"SP00005|SP00006",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
(
"SP00006",
"SB00004",
"SP00005|SP00006",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
("SP00007", "SB00005", "", False, False, 0.99, False, False, False, False, False),
("SP00008", "SB00006", "", False, False, 0.99, False, False, False, False, False),
("SP00009", "SB00007", "", False, False, 0.99, False, False, False, False, False),
(
"SP00010",
"SB00008",
"SP00010|SP00011",
False,
False,
0.99,
False,
False,
False,
False,
False,
),
(
"SP00011",
"SB00008",
"SP00010|SP00011",
False,
False,
0.94,
False,
False,
False,
False,
False,
),
(
"SP00012",
"SB00009",
Expand All @@ -281,6 +354,7 @@ def fake_sample_qc() -> pd.DataFrame:
False,
False,
False,
False,
),
(
"SP00013",
Expand All @@ -293,6 +367,7 @@ def fake_sample_qc() -> pd.DataFrame:
False,
False,
False,
False,
),
(
"SP00014",
Expand All @@ -305,24 +380,48 @@ def fake_sample_qc() -> pd.DataFrame:
False,
False,
False,
False,
),
(
"SP00015",
"SB00010",
"SP00015|SP00016",
False,
False,
0.99,
False,
True,
True,
False,
True,
),
(
"SP00016",
"SB00010",
"SP00015|SP00016",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
("SP00015", "SB00010", "SP00015|SP00016", False, False, 0.99, False, True, False, True),
("SP00016", "SB00010", "SP00015|SP00016", False, False, 0.99, False, False, False, True),
]
return pd.DataFrame(data, columns=columns).set_index("Sample_ID")
return pd.DataFrame(data, columns=columns)


@pytest.mark.parametrize(
"contam,rep_discordant,num_removed",
[(False, False, 2), (True, False, 3), (False, True, 5), (True, True, 5)], # call rate filtered
[(False, False, 2), (True, False, 3), (False, True, 5), (True, True, 5)],
)
def test_add_analytic_exclusion(fake_sample_qc, contam, rep_discordant, num_removed):
pd.set_option("display.max_columns", None)
sample_qc_table._add_analytic_exclusion(fake_sample_qc, contam, rep_discordant)
assert num_removed == fake_sample_qc.analytic_exclusion.sum()


# change these since I updated fake_sample_qc
@pytest.mark.parametrize(
"contam,rep_discordant,num_subjects",
[(False, False, 9), (True, False, 9), (False, True, 8), (True, True, 8)],
Expand Down

0 comments on commit c618175

Please sign in to comment.