Skip to content

Commit 6747bf0

Browse files
committed
2 parents 9a49914 + c5d9fec commit 6747bf0

10 files changed

Lines changed: 325 additions & 31 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,4 @@ tests/plots/generated
117117

118118
# Local temp script for testing user bugs (Luca)
119119
temp/
120+
uv.lock

docs/contributing.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,25 @@
11
# Contributing guide
22

33
Please refer to the [contribution guide from the `spatialdata` repository](https://github.com/scverse/spatialdata/blob/main/docs/contributing.md).
4+
5+
## Debugging napari GUI tests
6+
7+
To visually inspect what a test is rendering in napari:
8+
9+
1. Change `make_napari_viewer()` to `make_napari_viewer(show=True)`
10+
2. Add `napari.run()` before the end of the test (before the assertions)
11+
12+
Example:
13+
14+
```python
15+
import napari
16+
17+
18+
def test_my_visualization(make_napari_viewer):
19+
viewer = make_napari_viewer(show=True)
20+
# ... setup code ...
21+
napari.run()
22+
# assertions...
23+
```
24+
25+
Remember to revert these changes before committing.

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ install_requires =
5454
scipy
5555
shapely
5656
scikit-learn
57-
spatialdata>=0.7.0dev0
57+
spatialdata>=0.7.0dev1
5858
superqt
5959
typing_extensions>=4.8.0
6060
vispy

src/napari_spatialdata/_viewer.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from spatialdata import get_element_annotators, get_element_instances
1919
from spatialdata._core.query.relational_query import _left_join_spatialelement_table
2020
from spatialdata._types import ArrayLike
21-
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channel_names
21+
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channel_names
2222
from spatialdata.transformations import Affine, Identity
2323

2424
from napari_spatialdata._model import DataModel
2525
from napari_spatialdata.constants import config
26-
from napari_spatialdata.constants.config import CIRCLES_AS_POINTS
2726
from napari_spatialdata.utils._utils import (
2827
_adjust_channels_order,
2928
_get_ellipses_from_circles,
@@ -470,7 +469,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi:
470469
if multi:
471470
original_name = original_name[: original_name.rfind("_")]
472471

473-
affine = _get_transform(sdata.images[original_name], selected_cs)
472+
affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True)
474473
rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name])
475474

476475
channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name])
@@ -517,6 +516,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
517516
df = sdata.shapes[original_name]
518517
affine = _get_transform(sdata.shapes[original_name], selected_cs)
519518

519+
# 2.5D circles not supported yet
520520
xy = np.array([df.geometry.x, df.geometry.y]).T
521521
yx = np.fliplr(xy)
522522
radii = df.radius.to_numpy()
@@ -541,10 +541,10 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
541541
version = get_napari_version()
542542
kwargs: dict[str, Any] = (
543543
{"edge_width": 0.0}
544-
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS
544+
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS
545545
else {"border_width": 0.0}
546546
)
547-
if CIRCLES_AS_POINTS:
547+
if config.CIRCLES_AS_POINTS:
548548
layer = Points(
549549
yx,
550550
name=key,
@@ -556,7 +556,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
556556
assert affine is not None
557557
self._adjust_radii_of_points_layer(layer=layer, affine=affine)
558558
else:
559-
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS:
559+
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS:
560560
kwargs |= {"edge_color": "white"}
561561
else:
562562
kwargs |= {"border_color": "white"}
@@ -597,7 +597,8 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
597597
original_name = original_name[: original_name.rfind("_")]
598598

599599
df = sdata.shapes[original_name]
600-
affine = _get_transform(sdata.shapes[original_name], selected_cs)
600+
include_z = not config.PROJECT_2_5D_SHAPES_TO_2D
601+
affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z)
601602

602603
# when mulitpolygons are present, we select the largest ones
603604
if "MultiPolygon" in np.unique(df.geometry.type):
@@ -609,7 +610,7 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
609610
df = df.sort_index() # reset the index to the first order
610611

611612
simplify = len(df) > config.POLYGON_THRESHOLD
612-
polygons, indices = _get_polygons_properties(df, simplify)
613+
polygons, indices = _get_polygons_properties(df, simplify, include_z=include_z)
613614

614615
# this will only work for polygons and not for multipolygons
615616
polygons = _transform_coordinates(polygons, f=lambda x: x[::-1])
@@ -662,7 +663,7 @@ def get_sdata_labels(self, sdata: SpatialData, key: str, selected_cs: str, multi
662663
original_name = original_name[: original_name.rfind("_")]
663664

664665
indices = get_element_instances(sdata.labels[original_name])
665-
affine = _get_transform(sdata.labels[original_name], selected_cs)
666+
affine = _get_transform(sdata.labels[original_name], selected_cs, include_z=True)
666667
rgb_labels, _ = _adjust_channels_order(element=sdata.labels[original_name])
667668

668669
adata, table_name, table_names = self._get_table_data(sdata, original_name)
@@ -706,8 +707,10 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
706707
if multi:
707708
original_name = original_name[: original_name.rfind("_")]
708709

710+
axes = get_axes_names(sdata.points[original_name])
709711
points = sdata.points[original_name].compute()
710-
affine = _get_transform(sdata.points[original_name], selected_cs)
712+
include_z = "z" in axes and not config.PROJECT_3D_POINTS_TO_2D
713+
affine = _get_transform(sdata.points[original_name], selected_cs, include_z=include_z)
711714
adata, table_name, table_names = self._get_table_data(sdata, original_name)
712715

713716
if len(points) < config.POINT_THRESHOLD:
@@ -727,14 +730,16 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
727730
_, adata = _left_join_spatialelement_table(
728731
{"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left"
729732
)
730-
xy = subsample_points[["y", "x"]].values
731-
np.fliplr(xy)
733+
axes = sorted(axes, reverse=True)
734+
if not include_z and "z" in axes:
735+
axes.remove("z")
736+
coords = subsample_points[axes].values
732737
# radii_size = _calc_default_radii(self.viewer, sdata, selected_cs)
733738
radii_size = 3
734739
version = get_napari_version()
735740
kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0}
736741
layer = Points(
737-
xy,
742+
coords,
738743
name=key,
739744
size=radii_size * 2,
740745
affine=affine,

src/napari_spatialdata/_widgets.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@
3232
except ImportError:
3333
from scanpy.plotting._utils import set_default_colors_for_categorical_obs
3434

35+
# See https://github.com/scverse/squidpy/issues/1061 for more details.
36+
# Scanpy 0.11.x-0.12.x renamed set_default_colors_for_categorical_obs to _set_default_colors_for_categorical_obs
37+
# and then changed it back. Try underscore version first, fall back to non-underscore.
38+
try:
39+
from scanpy.plotting._utils import _set_colors_for_categorical_obs as set_colors_for_categorical_obs
40+
except ImportError:
41+
from scanpy.plotting._utils import set_colors_for_categorical_obs
42+
3543
from napari_spatialdata._model import DataModel
3644
from napari_spatialdata.utils._utils import _min_max_norm, get_napari_version
3745

@@ -219,7 +227,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]:
219227
if self._attr != "columns_df":
220228
if vec_color_name not in self.model.adata.uns:
221229
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
222-
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
230+
set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
223231
colors = colorer.uns["vec_colors"]
224232
color_dict = dict(zip(vec.cat.categories, colors, strict=False))
225233
color_dict.update({np.nan: "#808080ff"})
@@ -230,7 +238,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]:
230238
df = layer.metadata["_columns_df"]
231239
if vec_color_name not in df.columns:
232240
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
233-
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
241+
set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
234242
colors = colorer.uns["vec_colors"]
235243
color_dict = dict(zip(vec.cat.categories, colors, strict=False))
236244
color_dict.update({np.nan: "#808080ff"})

src/napari_spatialdata/constants/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
N_SHAPES_WARNING_THRESHOLD = 10000
55
POINT_SIZE_SCATTERPLOT_WIDGET = 6
66
CIRCLES_AS_POINTS = True
7+
PROJECT_3D_POINTS_TO_2D = True
8+
PROJECT_2_5D_SHAPES_TO_2D = True

src/napari_spatialdata/utils/_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,20 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]:
181181
return [[f(xy) for xy in sublist] for sublist in data]
182182

183183

184-
def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike:
184+
def _get_transform(
185+
element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None
186+
) -> None | ArrayLike:
185187
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
186188
raise RuntimeError("Cannot get transform for {type(element)}")
187189

188190
transformations = get_transformation(element, get_all=True)
189191
cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name
190192
ct = transformations.get(cs)
191193
if ct:
192-
return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x"))
194+
axes_element = get_axes_names(element)
195+
include_z = include_z and "z" in axes_element
196+
axes_transformation = ("z", "y", "x") if include_z else ("y", "x")
197+
return ct.to_affine_matrix(input_axes=axes_transformation, output_axes=axes_transformation)
193198
return None
194199

195200

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,37 @@
11
from geopandas import GeoDataFrame
2+
from spatialdata.models import get_axes_names
23

4+
# type aliases, only used in this module
5+
Coord2D = tuple[float, float]
6+
Coord3D = tuple[float, float, float]
7+
Polygon2D = list[Coord2D]
8+
Polygon3D = list[Coord3D]
9+
Polygon = Polygon2D | Polygon3D
310

4-
def _get_polygons_properties(df: GeoDataFrame, simplify: bool) -> tuple[list[list[tuple[float, float]]], list[int]]:
5-
indices = []
6-
polygons = []
711

8-
if simplify:
9-
for i in range(0, len(df)):
10-
indices.append(df.iloc[i].name)
11-
# This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly
12-
polygons.append(list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords))
13-
else:
14-
for i in range(0, len(df)):
15-
indices.append(df.iloc[i].name)
16-
polygons.append(list(df.geometry.iloc[i].exterior.coords))
12+
def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) -> tuple[list[Polygon], list[int]]:
13+
# assumes no "Polygon Z": z is in separate column if present
14+
indices: list[int] = []
15+
polygons: list[Polygon] = []
16+
17+
axes = get_axes_names(df)
18+
add_z = include_z and "z" in axes
19+
20+
for i in range(len(df)):
21+
indices.append(int(df.index[i]))
22+
23+
if simplify:
24+
xy = list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)
25+
else:
26+
xy = list(df.geometry.iloc[i].exterior.coords)
27+
28+
coords: Polygon2D | Polygon3D
29+
if add_z:
30+
z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z)
31+
coords = [(x, y, z_val) for x, y in xy]
32+
else:
33+
coords = xy
34+
35+
polygons.append(coords)
1736

1837
return polygons, indices

tests/conftest.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from __future__ import annotations
22

3+
# MUST set environment variables BEFORE any Qt/napari/vispy imports
4+
# to enable headless mode in CI environments (Ubuntu/Linux without display)
35
import os
6+
import sys
7+
8+
# Only use offscreen on Linux - macOS doesn't support the offscreen Qt platform plugin
9+
if sys.platform == "linux":
10+
os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")
11+
12+
os.environ.setdefault("NAPARI_HEADLESS", "1")
413
import random
514
import string
615
from abc import ABC, ABCMeta
@@ -9,19 +18,23 @@
918
from pathlib import Path
1019
from typing import Any
1120

21+
import geopandas as gpd
1222
import napari
1323
import numpy as np
1424
import pandas as pd
1525
import pytest
1626
from anndata import AnnData
27+
from dask.dataframe import from_pandas
1728
from loguru import logger
1829
from matplotlib.testing.compare import compare_images
1930
from scipy import ndimage as ndi
31+
from shapely import MultiPolygon, Polygon
2032
from skimage import data
2133
from spatialdata import SpatialData
2234
from spatialdata._types import ArrayLike
2335
from spatialdata.datasets import blobs
24-
from spatialdata.models import TableModel
36+
from spatialdata.models import PointsModel, ShapesModel, TableModel
37+
from spatialdata.transformations import Identity, set_transformation
2538

2639
from napari_spatialdata.utils._test_utils import export_figure, save_image
2740

@@ -259,3 +272,61 @@ def caplog(caplog):
259272
def always_sync(monkeypatch, request):
260273
if request.node.get_closest_marker("use_thread_loader") is None:
261274
monkeypatch.setattr("napari_spatialdata._sdata_widgets.PROBLEMATIC_NUMPY_MACOS", True)
275+
276+
277+
@pytest.fixture
278+
def sdata_3d_points() -> SpatialData:
279+
"""Create a SpatialData object with 3D points (x, y, z coordinates)."""
280+
n_points = 10
281+
rng = np.random.default_rng(SEED)
282+
df = pd.DataFrame(
283+
{
284+
"x": rng.uniform(0, 100, n_points),
285+
"y": rng.uniform(0, 100, n_points),
286+
"z": rng.uniform(0, 50, n_points),
287+
}
288+
)
289+
dask_df = from_pandas(df, npartitions=1)
290+
points = PointsModel.parse(dask_df)
291+
set_transformation(points, {"global": Identity()}, set_all=True)
292+
293+
return SpatialData(points={"points_3d": points})
294+
295+
296+
@pytest.fixture
297+
def sdata_2_5d_shapes() -> SpatialData:
298+
"""Create a SpatialData object with 2.5D shapes (3 layers at different z, polygons + multipolygons)."""
299+
shapes = {}
300+
301+
geometries = []
302+
z_values = []
303+
indices = []
304+
for i, z_val in enumerate([0.0, 10.0, 20.0]):
305+
# Add simple polygons (triangles and quadrilaterals)
306+
poly1 = Polygon([(10 + i * 5, 10), (20 + i * 5, 10), (15 + i * 5, 20)])
307+
poly2 = Polygon([(30 + i * 5, 30), (40 + i * 5, 30), (40 + i * 5, 40), (30 + i * 5, 40)])
308+
geometries.extend([poly1, poly2])
309+
indices.extend([0, 1])
310+
z_values.extend([z_val] * 2)
311+
312+
# Add a multipolygon (two separate polygon parts)
313+
multi_poly = MultiPolygon(
314+
[
315+
Polygon([(50 + i * 5, 10), (60 + i * 5, 10), (55 + i * 5, 20)]),
316+
Polygon([(50 + i * 5, 30), (60 + i * 5, 30), (60 + i * 5, 40), (50 + i * 5, 40)]),
317+
]
318+
)
319+
geometries.append(multi_poly)
320+
indices.append(2)
321+
z_values.append(z_val)
322+
323+
gdf = gpd.GeoDataFrame(
324+
{"z": z_values, "geometry": geometries},
325+
index=indices,
326+
)
327+
328+
shape_element = ShapesModel.parse(gdf)
329+
set_transformation(shape_element, {"global": Identity()}, set_all=True)
330+
shapes["shapes_2.5d"] = shape_element
331+
332+
return SpatialData(shapes=shapes)

0 commit comments

Comments
 (0)