Skip to content

Commit

Permalink
Fix int csr_matrix (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Aug 29, 2024
1 parent d236ed4 commit ec7de24
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 64 deletions.
2 changes: 1 addition & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
Expand Down
11 changes: 3 additions & 8 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ name: Unit Tests

on:
push:
branches:
- main
branches: [main]
pull_request:
branches:
- "*"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down Expand Up @@ -47,14 +44,12 @@ jobs:
packages: pandoc gfortran libblas-dev liblapack-dev libedit-dev llvm-dev libcurl4-openssl-dev ffmpeg libhdf5-dev
version: 1.0

- name: Set up Python
uses: actions/setup-python@v5
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.config.python }}
cache: pip

- name: Set up R
id: setup-r
- id: setup-r
uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}
Expand Down
5 changes: 1 addition & 4 deletions src/anndata2ri/_r2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import _conv_name
from ._conv import converter, mat_rpy2py
from ._rpy2_ext import importr
from ._rpy2_ext import R_INT_BYTES, importr
from .scipy2ri import supported_r_matrix_classes
from .scipy2ri._r2py import rmat_to_spmat

Expand All @@ -23,9 +23,6 @@
from scipy.sparse import spmatrix


R_INT_BYTES = 4


@converter.rpy2py.register(SexpS4)
def rpy2py_s4(obj: SexpS4) -> pd.DataFrame | AnnData | None:
"""Convert known S4 class instance to Python object.
Expand Down
9 changes: 6 additions & 3 deletions src/anndata2ri/_rpy2_ext.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from __future__ import annotations

from functools import lru_cache
from functools import cache

from rpy2.robjects import Environment, packages


@lru_cache
R_INT_BYTES = 4


@cache
def importr(name: str) -> packages.Package:
return packages.importr(name)


@lru_cache
@cache
def data(package: str, name: str | None = None) -> packages.PackageData | Environment:
if name is None:
return packages.data(importr(package))
Expand Down
53 changes: 22 additions & 31 deletions src/anndata2ri/scipy2ri/_py2r.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

from functools import lru_cache, wraps
from functools import cache
from importlib.resources import files
from typing import TYPE_CHECKING

import numpy as np
from rpy2.robjects import default_converter, numpy2ri
from rpy2.robjects.conversion import localconverter
from rpy2.robjects.packages import Package, SignatureTranslatedAnonymousPackage
from rpy2.robjects.packages import InstalledSTPackage, SignatureTranslatedAnonymousPackage
from scipy import sparse

from anndata2ri._rpy2_ext import importr
from anndata2ri._rpy2_ext import R_INT_BYTES, importr

from ._conv import converter

Expand All @@ -21,55 +21,46 @@
from rpy2.rinterface import Sexp


matrix: SignatureTranslatedAnonymousPackage | None = None
base: Package | None = None
@cache
def baseenv() -> InstalledSTPackage:
with localconverter(default_converter):
return importr('base')


@lru_cache
def get_r_code() -> str:
return files('anndata2ri').joinpath('scipy2ri', '_py2r_helpers.r').read_text()
@cache
def matrixenv() -> SignatureTranslatedAnonymousPackage:
with localconverter(default_converter):
importr('Matrix') # make class available
r_code = files('anndata2ri').joinpath('scipy2ri', '_py2r_helpers.r').read_text()
return SignatureTranslatedAnonymousPackage(r_code, 'matrix')


def get_type_conv(dtype: np.dtype) -> Callable[[np.ndarray], Sexp]:
global base # noqa: PLW0603
if base is None:
base = importr('base')
base = baseenv()
if np.issubdtype(dtype, np.floating):
return base.as_double
if np.issubdtype(dtype, np.integer):
if dtype.itemsize <= R_INT_BYTES:
return base.as_integer
return base.as_numeric # maybe uses R_xlen_t?
if np.issubdtype(dtype, np.bool_):
return base.as_logical
msg = f'Unknown dtype {dtype!r} cannot be converted to ?gRMatrix.'
raise ValueError(msg)


def py2r_context(f: Callable[[sparse.spmatrix], Sexp]) -> Callable[[sparse.spmatrix], Sexp]:
"""R globalenv context with some helper functions."""

@wraps(f)
def wrapper(obj: sparse.spmatrix) -> Sexp:
global matrix # noqa: PLW0603
if matrix is None:
importr('Matrix') # make class available
r_code = get_r_code()
matrix = SignatureTranslatedAnonymousPackage(r_code, 'matrix')

return f(obj)

return wrapper


@converter.py2rpy.register(sparse.csc_matrix)
@py2r_context
def csc_to_rmat(csc: sparse.csc_matrix) -> Sexp:
matrix = matrixenv()
csc.sort_indices()
conv_data = get_type_conv(csc.dtype)
with localconverter(default_converter + numpy2ri.converter):
return matrix.from_csc(i=csc.indices, p=csc.indptr, x=csc.data, dims=list(csc.shape), conv_data=conv_data)


@converter.py2rpy.register(sparse.csr_matrix)
@py2r_context
def csr_to_rmat(csr: sparse.csr_matrix) -> Sexp:
matrix = matrixenv()
csr.sort_indices()
conv_data = get_type_conv(csr.dtype)
with localconverter(default_converter + numpy2ri.converter):
Expand All @@ -83,8 +74,8 @@ def csr_to_rmat(csr: sparse.csr_matrix) -> Sexp:


@converter.py2rpy.register(sparse.coo_matrix)
@py2r_context
def coo_to_rmat(coo: sparse.coo_matrix) -> Sexp:
matrix = matrixenv()
conv_data = get_type_conv(coo.dtype)
with localconverter(default_converter + numpy2ri.converter):
return matrix.from_coo(
Expand All @@ -97,8 +88,8 @@ def coo_to_rmat(coo: sparse.coo_matrix) -> Sexp:


@converter.py2rpy.register(sparse.dia_matrix)
@py2r_context
def dia_to_rmat(dia: sparse.dia_matrix) -> Sexp:
matrix = matrixenv()
conv_data = get_type_conv(dia.dtype)
if len(dia.offsets) > 1:
msg = (
Expand Down
43 changes: 31 additions & 12 deletions tests/test_py2rpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import catch_warnings, simplefilter
from warnings import catch_warnings, filterwarnings, simplefilter

import numpy as np
import pytest
Expand All @@ -10,6 +10,7 @@
from pandas import DataFrame
from rpy2.robjects import baseenv, globalenv
from rpy2.robjects.conversion import localconverter
from scipy import sparse

import anndata2ri
from anndata2ri._rpy2_ext import importr
Expand All @@ -31,6 +32,26 @@ def mk_ad_simple() -> AnnData:
)


@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('mat_type', [np.asarray, sparse.csr_matrix])
def test_simple(
py2r: Py2R,
dtype: np.dtype,
mat_type: Callable[[np.ndarray, np.dtype], np.ndarray | sparse.spmatrix],
) -> None:
data = mk_ad_simple()
if data.X is not None:
data.X = mat_type(data.X, dtype=dtype)
ex = py2r(anndata2ri, data)
assert tuple(baseenv['dim'](ex)[::-1]) == data.shape


def krumsiek() -> AnnData:
adata = sc.datasets.krumsiek11()
adata.obs_names_make_unique()
return adata


def check_empty(_: Sexp) -> None:
pass

Expand All @@ -45,33 +66,31 @@ def check_pca(ex: Sexp) -> None:
datasets = [
pytest.param(check_empty, (0, 0), AnnData, id='empty'),
pytest.param(check_pca, (2, 3), mk_ad_simple, id='simple'),
pytest.param(check_empty, (640, 11), sc.datasets.krumsiek11, id='krumsiek'),
pytest.param(check_empty, (640, 11), krumsiek, id='krumsiek'),
]


@pytest.mark.parametrize(('check', 'shape', 'dataset'), datasets)
def test_py2rpy(
def test_datasets(
py2r: Py2R,
check: Callable[[Sexp], None],
shape: tuple[int, ...],
dataset: Callable[[], AnnData],
) -> None:
if dataset is sc.datasets.krumsiek11:
with (
pytest.warns(UserWarning, match=r'Duplicated obs_names'),
pytest.warns(UserWarning, match=r'Observation names are not unique'),
# TODO(flying-sheep): Adapt to rpy2 changes instead
# https://github.com/theislab/anndata2ri/issues/109
pytest.warns(DeprecationWarning, match=r'rpy2\.robjects\.conversion is deprecated'),
):
if dataset is krumsiek:
# TODO(flying-sheep): Adapt to rpy2 changes instead
# https://github.com/theislab/anndata2ri/issues/109
with pytest.warns(DeprecationWarning, match=r'rpy2\.robjects\.conversion is deprecated'):
filterwarnings('ignore', r'Duplicated obs_names', UserWarning)
filterwarnings('ignore', r'Observation names are not unique', UserWarning)
ex = py2r(anndata2ri, dataset())
else:
ex = py2r(anndata2ri, dataset())
assert tuple(baseenv['dim'](ex)[::-1]) == shape
check(ex)


def test_py2rpy2_numpy_pbmc68k() -> None:
def test_numpy_pbmc68k() -> None:
"""Not tested above as the pbmc68k dataset has some weird metadata."""
from scanpy.datasets import pbmc68k_reduced

Expand Down
6 changes: 3 additions & 3 deletions tests/test_rpy2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def check_example(adata: AnnData) -> None:


@pytest.mark.parametrize(('check', 'shape', 'dataset'), expression_sets)
def test_convert_manual(
def test_manual(
r2py: R2Py,
check: Callable[[AnnData], None],
shape: tuple[int, ...],
Expand All @@ -81,15 +81,15 @@ def test_convert_manual(
check(ad)


def test_convert_empty_df_with_rows(r2py: R2Py) -> None:
def test_empty_df_with_rows(r2py: R2Py) -> None:
df = r('S4Vectors::DataFrame(a=1:10)[, -1]')
assert df.slots['nrows'][0] == 10

df_py = r2py(anndata2ri, lambda: conversion.get_conversion().rpy2py(df))
assert isinstance(df_py, pd.DataFrame)


def test_convert_factor(r2py: R2Py) -> None:
def test_factor(r2py: R2Py) -> None:
code = """
SingleCellExperiment::SingleCellExperiment(
assays = list(counts = matrix(rpois(6*4, 5), ncol=4)),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scipy_py2rpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

@pytest.mark.parametrize('typ', ['l', 'd'])
@pytest.mark.parametrize(('shape', 'dataset', 'cls'), mats)
def test_py2rpy(
def test_mats(
py2r: Py2R,
typ: Literal['l', 'd'],
shape: tuple[int, ...],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scipy_rpy2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@


@pytest.mark.parametrize(('shape', 'cls', 'dtype', 'arr', 'dataset'), mats)
def test_py2rpy(
def test_mats(
r2py: R2Py,
shape: tuple[int, int],
cls: type[sparse.spmatrix],
Expand Down

0 comments on commit ec7de24

Please sign in to comment.