Skip to content

Commit

Permalink
Set domain to be appropriately sized in Visium ingestion (#3249)
Browse files Browse the repository at this point in the history
This sets the domain for all Visium assets to be large enough to hold the current data.

A fix was added so that an empty string domain `("", "")` is interpreted as the full domain, rather than the smallest possible domain.
  • Loading branch information
jp-dark authored Oct 28, 2024
1 parent 24a7438 commit 50c5d71
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 146 deletions.
294 changes: 155 additions & 139 deletions apis/python/notebooks/tutorial_spatial.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions apis/python/src/tiledbsoma/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,10 @@ def _fill_out_slot_soma_domain(
if slot_domain is not None:
# User-specified; go with it when possible
if (
pa_type == pa.string()
or pa_type == pa.large_string()
(
(pa_type == pa.string() or pa_type == pa.large_string())
and slot_domain != ("", "")
)
or pa_type == pa.binary()
or pa_type == pa.large_binary()
):
Expand Down
19 changes: 14 additions & 5 deletions apis/python/src/tiledbsoma/experimental/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,14 @@ def from_visium(
exp.obs.read(column_names=["soma_joinid", "obs_id"]).concat().to_pandas()
)
if write_obs_spatial_presence or write_var_spatial_presence:
x_layer = exp.ms[measurement_name].X[X_layer_name].read().tables().concat()
x_layer = exp.ms[measurement_name].X[X_layer_name]
(len_obs_id, len_var_id) = x_layer.shape
x_layer_data = x_layer.read().tables().concat()
if write_obs_spatial_presence:
obs_id = pacomp.unique(x_layer["soma_dim_0"])
obs_id = pacomp.unique(x_layer_data["soma_dim_0"])

if write_var_spatial_presence:
var_id = pacomp.unique(x_layer["soma_dim_1"])
var_id = pacomp.unique(x_layer_data["soma_dim_1"])

# Add spatial information to the experiment.
with Experiment.open(experiment_uri, mode="w", context=context) as exp:
Expand Down Expand Up @@ -591,7 +594,7 @@ def from_visium(
if write_obs_spatial_presence:
obs_spatial_presence_uri = _util.uri_joinpath(uri, "obs_spatial_presence")
obs_spatial_presence = _write_scene_presence_dataframe(
obs_id, scene_name, obs_spatial_presence_uri, **ingest_ctx
obs_id, len_obs_id, scene_name, obs_spatial_presence_uri, **ingest_ctx
)
_maybe_set(
exp,
Expand All @@ -605,7 +608,7 @@ def from_visium(
"var_spatial_presence",
)
var_spatial_presence = _write_scene_presence_dataframe(
var_id, scene_name, var_spatial_presence_uri, **ingest_ctx
var_id, len_var_id, scene_name, var_spatial_presence_uri, **ingest_ctx
)
meas = exp.ms[measurement_name]
_maybe_set(
Expand All @@ -619,6 +622,7 @@ def from_visium(

def _write_scene_presence_dataframe(
joinids: pa.array,
max_joinid_len: int,
scene_id: str,
df_uri: str,
*,
Expand All @@ -639,6 +643,7 @@ def _write_scene_presence_dataframe(
("data", pa.bool_()),
]
),
domain=((0, max_joinid_len - 1), ("", "")),
index_column_names=("soma_joinid", "scene_id"),
platform_config=platform_config,
context=context,
Expand Down Expand Up @@ -708,13 +713,17 @@ def _write_visium_spots(
df = pd.merge(obs_df, df, how="inner", on=id_column_name)
df.drop(id_column_name, axis=1, inplace=True)

domain = ((df["x"].min(), df["x"].max()), (df["y"].min(), df["y"].max()))

arrow_table = df_to_arrow(df)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
soma_point_cloud = PointCloudDataFrame.create(
df_uri,
schema=arrow_table.schema,
index_column_names=("x", "y"),
domain=domain,
platform_config=platform_config,
context=context,
)
Expand Down
38 changes: 38 additions & 0 deletions apis/python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,3 +1831,41 @@ def test_fix_update_dataframe_with_var_strings(tmp_path):
with soma.DataFrame.open(uri, "r") as sdf:
results = sdf.read().concat().to_pandas()
assert results.equals(updated_sdf)


def test_presence_matrix(tmp_path):
    """Round-trip a presence-matrix-style dataframe.

    Creates a ``soma.DataFrame`` indexed by (soma_joinid, scene_id) with an
    explicit domain — including the empty-string domain ``("", "")`` that is
    now treated as the full string domain — writes 20 rows split across two
    scenes, and verifies the data reads back unchanged.
    """
    uri = tmp_path.as_uri()

    # Create the dataframe.  The soma_joinid domain (0, 99) is large enough
    # to hold the joinids written below; ("", "") requests the full string
    # domain for scene_id.
    soma_df = soma.DataFrame.create(
        uri,
        schema=pa.schema(
            [
                ("soma_joinid", pa.int64()),
                ("scene_id", pa.string()),
                ("data", pa.bool_()),
            ]
        ),
        domain=((0, 99), ("", "")),
        index_column_names=("soma_joinid", "scene_id"),
    )

    # Create data to write: 20 joinids (0, 5, ..., 95), half assigned to
    # each of two scenes, all flagged present.
    joinid_data = pa.array(np.arange(0, 100, 5))
    scene_id_data = 10 * ["scene1"] + 10 * ["scene2"]
    df = pd.DataFrame(
        {
            "soma_joinid": joinid_data,
            "scene_id": scene_id_data,
            "data": 20 * [True],
        }
    )
    arrow_table = pa.Table.from_pandas(df)

    try:
        soma_df.write(arrow_table)
    finally:
        # Release the write handle even if the write raises, so the array
        # is not left open for the read below.
        soma_df.close()

    # Read everything back and verify the round trip is lossless.
    with soma.DataFrame.open(uri) as soma_df:
        actual = soma_df.read().concat().to_pandas()

    assert actual.equals(df)

0 comments on commit 50c5d71

Please sign in to comment.