Commit 2ebab21

Zonal functions accept vector zones directly (#999)
* Accept vector zones in zonal functions (#998)

  stats(), crosstab(), apply(), and crop() now accept GeoDataFrames or list-of-(geometry, value) pairs as the zones argument. When vector input is detected, rasterize() is called internally using the values raster as the template. Adds column and rasterize_kw parameters.

* Add vector zones docs and notebook examples (#998)

  Document column and rasterize_kw parameters in stats(), crosstab(), apply(), and crop() docstrings. Add vector zones section to the 3_Zonal user guide notebook with GeoDataFrame, list-of-pairs, crosstab, and rasterize_kw examples.
1 parent 332ad5e commit 2ebab21
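
The commit message says vector input is detected and rasterized internally against the values raster. The diff for `xrspatial/zonal.py` is not shown on this page, so the following is only a minimal sketch of how such a dispatch could look, not the library's actual code. The helper names (`is_geodataframe_like`, `is_list_of_pairs`, `resolve_zones`) and the duck-typed GeoDataFrame check are assumptions made for illustration; the error-message fragments and the `rasterize(zones, like=..., column=...)` call shape are taken from the tests in this commit. The rasterizer is passed in as an argument so the sketch runs without xrspatial installed.

```python
from typing import Any, Callable, Optional


def is_geodataframe_like(zones: Any) -> bool:
    """Duck-typed check (assumption): GeoDataFrames expose `geometry` and `columns`."""
    return hasattr(zones, "geometry") and hasattr(zones, "columns")


def is_list_of_pairs(zones: Any) -> bool:
    """True for a non-empty list/tuple whose items are 2-tuples (geometry, value)."""
    return (
        isinstance(zones, (list, tuple))
        and len(zones) > 0
        and all(isinstance(p, tuple) and len(p) == 2 for p in zones)
    )


def resolve_zones(
    zones: Any,
    values: Any,
    rasterize: Callable[..., Any],
    column: Optional[str] = None,
    rasterize_kw: Optional[dict] = None,
) -> Any:
    """Return a zones raster, rasterizing vector input onto the `values` grid.

    Hypothetical helper: raster zones are returned unchanged, vector zones are
    passed to `rasterize` with `values` as the template.
    """
    kw = dict(rasterize_kw or {})
    if is_geodataframe_like(zones):
        if column is None:
            raise ValueError("column is required when zones is a GeoDataFrame")
        return rasterize(zones, like=values, column=column, **kw)
    if is_list_of_pairs(zones):
        if column is not None:
            raise ValueError("column should not be set for (geometry, value) pairs")
        return rasterize(zones, like=values, **kw)
    # Anything else (for example an existing zones raster) passes through.
    return zones
```

With xrspatial installed, `rasterize` would presumably be `xrspatial.rasterize.rasterize`, as imported in the tests; passing it in explicitly keeps the validation logic testable on its own.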

3 files changed

Lines changed: 373 additions & 39 deletions

examples/user_guide/3_Zonal.ipynb

Lines changed: 63 additions & 14 deletions
@@ -3,18 +3,7 @@
   {
    "cell_type": "markdown",
    "metadata": {},
-   "source": [
-    "# Xarray-spatial\n",
-    "### User Guide: Zonal\n",
-    "-----\n",
-    "\n",
-    "Xarray-spatial's zonal functions provide an easy way to generate statistics for zones within a raster aggregate. It's set up with a default set of calculations, or you can input any set of custom calculations you'd like to perform.\n",
-    "\n",
-    "[Generate terrain](#Generate-Terrain-Data)\n",
-    "[Zonal Statistics](#Zonal-Statistics)\n",
-    "\n",
-    "-----------\n"
-   ]
+   "source": "# Xarray-spatial\n### User Guide: Zonal\n-----\n\nXarray-spatial's zonal functions provide an easy way to generate statistics for zones within a raster aggregate. It's set up with a default set of calculations, or you can input any set of custom calculations you'd like to perform.\n\n[Generate terrain](#Generate-Terrain-Data)\n[Zonal Statistics](#Zonal-Statistics)\n[Zonal Crosstab](#Zonal-crosstab)\n[Vector Zones](#Vector-zones)\n\n-----------"
   },
   {
    "cell_type": "markdown",
@@ -239,11 +228,71 @@
    "source": "# Calculate crosstab: rows are trail segments, columns are tree species (1=Oak, 2=Pine, 3=Maple)\nresult = tree_species_agg.xrs.zonal_crosstab(zones_agg, zone_ids=[11, 12, 13, 14, 15, 16], cat_ids=[1, 2, 3])\nresult.columns = ['Trail Segment', 'Oak', 'Pine', 'Maple']\nresult"
   },
   {
-   "cell_type": "code",
+   "cell_type": "markdown",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": "## Vector zones\n\nZonal functions can also accept vector geometries directly -- a GeoDataFrame or a list of `(geometry, value)` pairs. The rasterization to match the values grid happens automatically inside the function call, so there's no separate `rasterize()` step.\n\nThis works with `zonal_stats`, `zonal_crosstab`, `zonal_apply`, and `crop`."
+  },
+  {
+   "cell_type": "markdown",
+   "source": "### GeoDataFrame zones\n\nLet's define three rectangular regions as polygons in a GeoDataFrame and compute elevation statistics directly.",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": "import geopandas as gpd\nfrom shapely.geometry import box\nfrom xrspatial.zonal import stats\n\n# Three rectangular regions across the terrain\nzones_gdf = gpd.GeoDataFrame(\n {'region_id': [1.0, 2.0, 3.0],\n 'name': ['West valley', 'Central ridge', 'East slope']},\n geometry=[\n box(-20, -20, -5, 20), # left third\n box(-5, -20, 5, 20), # middle third\n box(5, -20, 20, 20), # right third\n ],\n)\n\n# Pass the GeoDataFrame directly -- no manual rasterize() needed\nresult = stats(zones_gdf, terrain, column='region_id',\n stats_funcs=['mean', 'min', 'max', 'std', 'count'])\nresult",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": "The same thing works with the accessor. The values raster is used as the template for rasterization, so the zones automatically align to the grid.",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": "# Accessor style -- same result\nterrain.xrs.zonal_stats(zones_gdf, column='region_id',\n stats_funcs=['mean', 'min', 'max', 'std', 'count'])",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": "### List-of-pairs zones\n\nIf you don't have a GeoDataFrame, you can pass a list of `(geometry, zone_id)` tuples. No `column` argument needed here since the zone ID is the second element of each pair.",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": "pairs = [\n (box(-15, -15, 0, 0), 10.0), # southwest quadrant\n (box(0, 0, 15, 15), 20.0), # northeast quadrant\n]\n\nstats(pairs, terrain, stats_funcs=['mean', 'count'])",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": "### Crosstab with vector zones\n\nVector zones work with `zonal_crosstab` too. Here we'll use the tree species data from earlier with our GeoDataFrame regions.",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": "from xrspatial.zonal import crosstab\n\nct = crosstab(zones_gdf, tree_species_agg, column='region_id', cat_ids=[1, 2, 3])\nct.columns = ['Region', 'Oak', 'Pine', 'Maple']\nct",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": "### Forwarding rasterize options\n\nYou can pass extra options to the internal `rasterize()` call via `rasterize_kw`. For example, `all_touched=True` will include pixels that touch the geometry boundary, not just those whose centre falls inside.",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": "# Small polygon -- compare default vs all_touched\nsmall_zone = gpd.GeoDataFrame(\n {'zone': [1.0]}, geometry=[box(-2, -2, 2, 2)]\n)\n\ndefault = stats(small_zone, terrain, column='zone', stats_funcs=['count'])\ntouched = stats(small_zone, terrain, column='zone', stats_funcs=['count'],\n rasterize_kw={'all_touched': True})\n\nprint(f\"Default pixel count: {int(default['count'].iloc[0])}\")\nprint(f\"all_touched pixel count: {int(touched['count'].iloc[0])}\")",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
   }
  ],
  "metadata": {

xrspatial/tests/test_zonal.py

Lines changed: 142 additions & 0 deletions
@@ -1686,3 +1686,145 @@ def _guarded_unique(a, *args, **kwargs):
     result = result.compute()
     assert isinstance(result, pd.DataFrame)
     assert len(result) > 0
+
+
+# ---------------------------------------------------------------------------
+# Vector zones (GeoDataFrame / list-of-pairs) — implicit rasterization
+# ---------------------------------------------------------------------------
+
+try:
+    from shapely.geometry import box as shapely_box
+    import geopandas as gpd
+    _has_vector = True
+except ImportError:
+    _has_vector = False
+
+skip_no_vector = pytest.mark.skipif(
+    not _has_vector, reason="shapely/geopandas not installed"
+)
+
+
+def _make_grid(width=20, height=20, bounds=(0, 0, 100, 100)):
+    """Build a template DataArray with pixel-centre coords."""
+    xmin, ymin, xmax, ymax = bounds
+    px = (xmax - xmin) / width
+    py = (ymax - ymin) / height
+    x = np.linspace(xmin + px / 2, xmax - px / 2, width)
+    y = np.linspace(ymax - py / 2, ymin + py / 2, height)
+    return xr.DataArray(
+        np.ones((height, width), dtype=np.float64),
+        dims=['y', 'x'],
+        coords={'y': y, 'x': x},
+    )
+
+
+@skip_no_vector
+class TestVectorZones:
+    """Verify that zonal functions accept vector zones directly."""
+
+    def _zones_raster_and_gdf(self):
+        """Two non-overlapping boxes covering different quadrants."""
+        from xrspatial.rasterize import rasterize
+
+        values = _make_grid()
+        gdf = gpd.GeoDataFrame(
+            {'zone_id': [1.0, 2.0]},
+            geometry=[
+                shapely_box(0, 50, 50, 100),  # top-left
+                shapely_box(50, 0, 100, 50),  # bottom-right
+            ],
+        )
+        zones_raster = rasterize(gdf, like=values, column='zone_id')
+        return values, gdf, zones_raster
+
+    # -- stats --
+
+    def test_stats_gdf_matches_raster(self):
+        values, gdf, zones_raster = self._zones_raster_and_gdf()
+        expected = stats(zones_raster, values, stats_funcs=['mean', 'count'])
+        result = stats(gdf, values, stats_funcs=['mean', 'count'],
+                       column='zone_id')
+        pd.testing.assert_frame_equal(result, expected)
+
+    def test_stats_list_of_pairs(self):
+        values, gdf, zones_raster = self._zones_raster_and_gdf()
+        pairs = [
+            (shapely_box(0, 50, 50, 100), 1.0),
+            (shapely_box(50, 0, 100, 50), 2.0),
+        ]
+        expected = stats(zones_raster, values, stats_funcs=['mean', 'count'])
+        result = stats(pairs, values, stats_funcs=['mean', 'count'])
+        pd.testing.assert_frame_equal(result, expected)
+
+    def test_stats_gdf_missing_column_raises(self):
+        values = _make_grid()
+        gdf = gpd.GeoDataFrame(
+            {'zone_id': [1.0]},
+            geometry=[shapely_box(0, 0, 50, 50)],
+        )
+        with pytest.raises(ValueError, match="column is required"):
+            stats(gdf, values)
+
+    def test_stats_pairs_with_column_raises(self):
+        values = _make_grid()
+        pairs = [(shapely_box(0, 0, 50, 50), 1.0)]
+        with pytest.raises(ValueError, match="column should not be set"):
+            stats(pairs, values, column='zone_id')
+
+    def test_stats_accessor(self):
+        import xrspatial  # noqa: F401 — registers accessor
+        values, gdf, zones_raster = self._zones_raster_and_gdf()
+        expected = stats(zones_raster, values, stats_funcs=['mean', 'count'])
+        result = values.xrs.zonal_stats(gdf, stats_funcs=['mean', 'count'],
+                                        column='zone_id')
+        pd.testing.assert_frame_equal(result, expected)
+
+    # -- crosstab --
+
+    def test_crosstab_gdf(self):
+        values, gdf, zones_raster = self._zones_raster_and_gdf()
+        expected = crosstab(zones_raster, values)
+        result = crosstab(gdf, values, column='zone_id')
+        pd.testing.assert_frame_equal(result, expected)
+
+    # -- apply --
+
+    def test_apply_gdf(self):
+        values, gdf, zones_raster = self._zones_raster_and_gdf()
+        # rasterize produces float zones; apply needs int zones
+        zones_int = zones_raster.copy(data=zones_raster.values.astype(int))
+        fn = lambda x: x * 2
+        expected = apply(zones_int, values, fn)
+        result = apply(gdf, values, fn, column='zone_id',
+                       rasterize_kw={'dtype': int})
+        xr.testing.assert_identical(result, expected)
+
+    # -- crop --
+
+    def test_crop_gdf(self):
+        values, gdf, zones_raster = self._zones_raster_and_gdf()
+        expected = crop(zones_raster, values, zones_ids=[1.0])
+        result = crop(gdf, values, zones_ids=[1.0], column='zone_id')
+        xr.testing.assert_identical(result, expected)
+
+    # -- rasterize_kw forwarding --
+
+    def test_stats_rasterize_kw_all_touched(self):
+        values, gdf, _ = self._zones_raster_and_gdf()
+        from xrspatial.rasterize import rasterize
+        zones_at = rasterize(gdf, like=values, column='zone_id',
+                             all_touched=True)
+        expected = stats(zones_at, values, stats_funcs=['count'])
+        result = stats(gdf, values, stats_funcs=['count'],
+                       column='zone_id',
+                       rasterize_kw={'all_touched': True})
+        pd.testing.assert_frame_equal(result, expected)
+
+    # -- raster zones still work --
+
+    def test_raster_zones_unchanged(self):
+        """Passing a DataArray directly should still work as before."""
+        values, _, zones_raster = self._zones_raster_and_gdf()
+        result = stats(zones_raster, values, stats_funcs=['count'])
+        assert isinstance(result, pd.DataFrame)
+        assert len(result) > 0
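
The notebook text in this commit contrasts the default rule (a pixel counts when its centre falls inside the geometry) with `all_touched=True` (a pixel counts when the geometry touches it). The sketch below illustrates that difference in pure Python for axis-aligned boxes on a pixel-centre grid laid out like `_make_grid` in the tests. It is not xrspatial's rasterizer: `pixel_centres` and `count_pixels` are hypothetical names, and the boundary handling (half-open intervals for the centre rule, positive-area overlap for the touched rule) is an assumption, since real rasterizers differ on exact edge cases.

```python
def pixel_centres(lo, hi, n):
    """Centres of n equal-width pixels spanning [lo, hi]."""
    step = (hi - lo) / n
    return [lo + step * (i + 0.5) for i in range(n)]


def count_pixels(box, bounds, width, height, all_touched=False):
    """Count grid pixels assigned to an axis-aligned box (xmin, ymin, xmax, ymax)."""
    bxmin, bymin, bxmax, bymax = box
    xmin, ymin, xmax, ymax = bounds
    half_x = (xmax - xmin) / width / 2
    half_y = (ymax - ymin) / height / 2
    count = 0
    for cy in pixel_centres(ymin, ymax, height):
        for cx in pixel_centres(xmin, xmax, width):
            if all_touched:
                # Pixel rectangle overlaps the box with positive area.
                hit = (cx - half_x < bxmax and cx + half_x > bxmin
                       and cy - half_y < bymax and cy + half_y > bymin)
            else:
                # Pixel centre lies inside the box (half-open on the upper edges).
                hit = bxmin <= cx < bxmax and bymin <= cy < bymax
            if hit:
                count += 1
    return count


# 20 x 20 grid over (0, 0, 100, 100): each pixel is 5 units wide.
grid = dict(bounds=(0, 0, 100, 100), width=20, height=20)
small = (13, 13, 22, 22)  # box that is not aligned to pixel edges
default_count = count_pixels(small, **grid)                    # 1
touched_count = count_pixels(small, all_touched=True, **grid)  # 9
```

For a box whose edges coincide with pixel edges, such as the `(0, 50, 50, 100)` quadrant used in the tests, both rules in this sketch give 100 pixels, so the two options only diverge for geometries that cut through pixels.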
