Skip to content

Commit

Permalink
move_to_x: Fix name of non-implemented argument "copy" to "copy_x", i…
Browse files Browse the repository at this point in the history
…mplement & test (#832)

* fix name of & implement missing arg in move_to_x

* fix argument description
  • Loading branch information
eroell authored Dec 2, 2024
1 parent 03cd180 commit ee84d9e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
9 changes: 6 additions & 3 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,13 @@ def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData:
return adata


def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:
def move_to_x(adata: AnnData, to_x: list[str] | str, copy_x: bool = False) -> AnnData:
"""Move features from obs to X inplace.
Args:
adata: The AnnData object
to_x: The columns to move to X
copy: Whether to return a copy or not
copy_x: The values are copied to X (and therefore kept in obs) instead of moved completely
Returns:
A new AnnData object with moved columns from obs to X. This should not be used for datetime columns currently.
Expand Down Expand Up @@ -292,7 +292,10 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:

if cols_not_in_x:
new_adata = concat([adata, AnnData(adata.obs[cols_not_in_x])], axis=1)
new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]]
if copy_x:
new_adata.obs = adata.obs
else:
new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]]

# AnnData's concat discards var if they don't match in their keys, so we need to create a new var
created_var = pd.DataFrame(index=cols_not_in_x)
Expand Down
7 changes: 7 additions & 0 deletions tests/anndata/test_anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ def test_move_to_x(adata_move_obs_mix):
)


def test_move_to_x_copy_x(adata_move_obs_mix):
move_to_obs(adata_move_obs_mix, ["name"], copy_obs=False)
obs_df = adata_move_obs_mix.obs.copy()
new_adata = move_to_x(adata_move_obs_mix, ["name"], copy_x=True)
assert_frame_equal(new_adata.obs, obs_df)


def test_move_to_x_invalid_column_names(adata_move_obs_mix):
move_to_obs(adata_move_obs_mix, ["name"], copy_obs=True)
move_to_obs(adata_move_obs_mix, ["clinic_id"], copy_obs=False)
Expand Down

0 comments on commit ee84d9e

Please sign in to comment.