diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index fb420202..38d60224 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -252,13 +252,13 @@ def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData: return adata -def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: +def move_to_x(adata: AnnData, to_x: list[str] | str, copy_x: bool = False) -> AnnData: """Move features from obs to X inplace. Args: adata: The AnnData object to_x: The columns to move to X - copy: Whether to return a copy or not + copy_x: The values are copied to X (and therefore kept in obs) instead of moved completely Returns: A new AnnData object with moved columns from obs to X. This should not be used for datetime columns currently. @@ -292,7 +292,10 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: if cols_not_in_x: new_adata = concat([adata, AnnData(adata.obs[cols_not_in_x])], axis=1) - new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]] + if copy_x: + new_adata.obs = adata.obs + else: + new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]] # AnnData's concat discards var if they don't match in their keys, so we need to create a new var created_var = pd.DataFrame(index=cols_not_in_x) diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py index aca4ca77..6e5bbf83 100644 --- a/tests/anndata/test_anndata_ext.py +++ b/tests/anndata/test_anndata_ext.py @@ -164,6 +164,13 @@ def test_move_to_x(adata_move_obs_mix): ) +def test_move_to_x_copy_x(adata_move_obs_mix): + move_to_obs(adata_move_obs_mix, ["name"], copy_obs=False) + obs_df = adata_move_obs_mix.obs.copy() + new_adata = move_to_x(adata_move_obs_mix, ["name"], copy_x=True) + assert_frame_equal(new_adata.obs, obs_df) + + def test_move_to_x_invalid_column_names(adata_move_obs_mix): move_to_obs(adata_move_obs_mix, ["name"], copy_obs=True) move_to_obs(adata_move_obs_mix, ["clinic_id"], copy_obs=False)