Skip to content

Commit 0a72dc3

Browse files
jan-matthiscopybara-github
authored andcommitted
Add decorators.
PiperOrigin-RevId: 548633295
1 parent a4fa5c4 commit 0a72dc3

10 files changed

Lines changed: 1668 additions & 0 deletions

File tree

decorators/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
# Copyright 2023 The Google Research Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.

decorators/affine.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
# coding=utf-8
2+
# Copyright 2023 The Google Research Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Affine transform decorators."""
16+
17+
from typing import Any, Mapping, MutableMapping, Optional, Sequence
18+
19+
from connectomics.common import opencv_utils
20+
from connectomics.volume.decorators import Decorator # pylint: disable=g-importing-member
21+
import gin
22+
import numpy as np
23+
import skimage.feature
24+
import tensorstore as ts
25+
26+
JsonSpec = Mapping[str, Any]
27+
MutableJsonSpec = MutableMapping[str, Any]
28+
29+
30+
@gin.register
31+
class OptimAffineTransformSectionwise(Decorator):
32+
"""Finds 2D affine transforms sectionwise by ECC optimization."""
33+
34+
def __init__(self,
35+
fixed_spec: JsonSpec,
36+
image_dims: Sequence[str] = ('x', 'y'),
37+
batch_dim: Optional[str] = None,
38+
init_previous: bool = False,
39+
context_spec: Optional[MutableJsonSpec] = None,
40+
**optim_args):
41+
"""Optimize affine transform sectionwise.
42+
43+
Uses OpenCV's `cv.findTransformECC` to find an affine transformations that
44+
aligns moving and fixed 2D images, where moving images are taken from the
45+
input TensorStore and fixed images from the one specified by `fixed_spec`.
46+
47+
Note that optimisation is done per 2D section according to `image_dims`.
48+
The resulting TensorStore contains 2x3 transformation matrices in
49+
dimensions 'r' (row) and 'c' (column), for all non-image dimensions.
50+
Transformation matrices are stored in individual chunks.
51+
52+
Args:
53+
fixed_spec: TensorStore containing fixed images to align against.
54+
Must have same dimensions as input TS (labels and shape).
55+
image_dims: Image dimensions to transform, e.g., `x` and `y` (two).
56+
batch_dim: Optional dimension to batch reads.
57+
init_previous: If True, initializes transform for subsequent calls
58+
of the optimization function with the previous result. Requires
59+
specification of a `batch_dim`.
60+
context_spec: Spec for virtual chunked context overriding its defaults.
61+
**optim_args: Passed to `opencv_utils.optim_transform`.
62+
"""
63+
super().__init__(context_spec)
64+
self._fixed_spec = fixed_spec
65+
self._image_dims = image_dims
66+
self._batch_dim = batch_dim
67+
self._init_previous = init_previous
68+
if init_previous and not batch_dim:
69+
raise ValueError('`batch_dim` must be specified to use `init_previous`.')
70+
if 'transform_initial' in optim_args:
71+
self._transform_initial = optim_args['transform_initial']
72+
optim_args.pop('transform_initial')
73+
else:
74+
self._transform_initial = None
75+
self._optim_args = optim_args
76+
77+
def decorate(self, input_ts: ts.TensorStore) -> ts.TensorStore:
78+
"""Wraps input TensorStore with a virtual_chunked for optim_transform."""
79+
80+
fixed_ts = ts.open(self._fixed_spec).result()
81+
if input_ts.domain.labels != fixed_ts.domain.labels:
82+
raise ValueError(
83+
'Input TS and fixed TS must have same labels, but they are ' +
84+
f'{input_ts.domain.labels} and {fixed_ts.domain.labels}.')
85+
if input_ts.shape != fixed_ts.shape:
86+
raise ValueError(
87+
'Input TS and fixed TS must have same shape, but they are ' +
88+
f'{input_ts.shape} and {fixed_ts.shape}.')
89+
90+
if len(self._image_dims) != 2:
91+
raise ValueError(
92+
f'2 image dimensions are required, but got {len(self._image_dims)}.')
93+
for d in self._image_dims:
94+
if d not in input_ts.domain.labels:
95+
raise ValueError(
96+
f'image dimension {d} not among labels {input_ts.domain.labels}.')
97+
elif input_ts.domain[d].size < 2:
98+
raise ValueError(
99+
'image dimension {d} must at least have size 2 but has size: ' +
100+
f'{input_ts.domain[d].size}.')
101+
102+
non_image_dims = [
103+
l for l in input_ts.domain.labels if l not in self._image_dims]
104+
input_domain_dict = {dim.label: dim for dim in list(input_ts.domain)}
105+
batch_idx = (input_ts.domain.labels.index(self._batch_dim)
106+
if self._batch_dim else None)
107+
108+
def read_fn(domain: ts.IndexDomain, array: np.ndarray,
109+
unused_read_params: ts.VirtualChunkedReadParameters):
110+
domain_dict = {dim.label: dim for dim in list(domain)}
111+
112+
if self._transform_initial:
113+
transform_initial = self._transform_initial.copy()
114+
else:
115+
transform_initial = None
116+
117+
if not self._batch_dim:
118+
read_domain = []
119+
for l in non_image_dims:
120+
read_domain.append(domain_dict[l])
121+
for l in self._image_dims:
122+
read_domain.append(input_domain_dict[l])
123+
read_domain = ts.IndexDomain(read_domain)
124+
125+
# Images are transposed since OpenCV uses [x,y]-convention.
126+
# See `opencv_utils` for more details.
127+
_, transform = opencv_utils.optim_transform(
128+
fix=np.array(fixed_ts[read_domain], dtype=np.float32).squeeze().T,
129+
mov=np.array(input_ts[read_domain], dtype=np.float32).squeeze().T,
130+
transform_initial=transform_initial,
131+
**self._optim_args)
132+
133+
array[...] = transform.reshape(array.shape)
134+
else:
135+
for i, j in enumerate(domain_dict[self._batch_dim]):
136+
read_domain = []
137+
for l in non_image_dims:
138+
if l != self._batch_dim:
139+
read_domain.append(domain_dict[l])
140+
else:
141+
read_domain.append(
142+
ts.Dim(inclusive_min=j, exclusive_max=j+1,
143+
label=self._batch_dim))
144+
for l in self._image_dims:
145+
read_domain.append(input_domain_dict[l])
146+
read_domain = ts.IndexDomain(read_domain)
147+
148+
# Images are transposed since OpenCV uses [x,y]-convention.
149+
# See `opencv_utils` for more details.
150+
_, transform = opencv_utils.optim_transform(
151+
fix=np.array(fixed_ts[read_domain], dtype=np.float32).squeeze().T,
152+
mov=np.array(input_ts[read_domain], dtype=np.float32).squeeze().T,
153+
transform_initial=transform_initial,
154+
**self._optim_args)
155+
if self._init_previous:
156+
transform_initial = transform
157+
158+
idx = [slice(None) for _ in range(array.ndim)]
159+
idx[batch_idx] = i
160+
array[tuple(idx)] = transform.reshape(array[tuple(idx)].shape)
161+
162+
chunksize = [2, 3]
163+
for l in non_image_dims:
164+
if l != self._batch_dim:
165+
chunksize.append(1)
166+
else:
167+
chunksize.append(input_domain_dict[l].size)
168+
schema = {
169+
'chunk_layout': {
170+
'read_chunk': {'shape': chunksize},
171+
'write_chunk': {'shape': chunksize},
172+
},
173+
'domain': {
174+
'labels': ['r', 'c',] + non_image_dims,
175+
'inclusive_min': [0, 0] + [
176+
input_domain_dict[l].inclusive_min for l in non_image_dims],
177+
'exclusive_max': [2, 3] + [
178+
input_domain_dict[l].exclusive_max for l in non_image_dims],
179+
},
180+
'dtype': 'float64',
181+
'rank': len(chunksize),
182+
}
183+
184+
return ts.virtual_chunked(
185+
read_fn, schema=ts.Schema(schema), context=self._context)
186+
187+
188+
@gin.register
189+
class OptimTranslationTransform(Decorator):
190+
"""Finds 2D/3D translations for registration via cross-correlation."""
191+
192+
def __init__(self,
193+
fixed_spec: JsonSpec,
194+
image_dims: Sequence[str] = ('x', 'y'),
195+
context_spec: Optional[MutableJsonSpec] = None,
196+
**optim_args):
197+
"""Computes cross-correlation between volumes for registration.
198+
199+
Uses skimage's `registration.phase_cross_correlation` to find translation
200+
matrices for registration of two volumes, where 2D or 3D moving images are
201+
taken from the input TensorStore and fixed images from the one specified by
202+
`fixed_spec`.
203+
204+
The resulting TensorStore contains 2x3 (2D) or 3x4 (3D) transformation
205+
matrices in dimensions 'r' (row) and 'c' (column), for all non-image
206+
dimensions. Transformation matrices are stored in individual chunks.
207+
208+
Args:
209+
fixed_spec: TensorStore containing fixed images to align against.
210+
Must have same dimensions as input TS (labels and shape).
211+
image_dims: Image dimensions to transform, e.g., `x` and `y` (two).
212+
context_spec: Spec for virtual chunked context overriding its defaults.
213+
**optim_args: Passed to `skimage.registration.phase_cross_correlation`.
214+
"""
215+
super().__init__(context_spec)
216+
self._fixed_spec = fixed_spec
217+
self._image_dims = image_dims
218+
self._optim_args = optim_args
219+
220+
def decorate(self, input_ts: ts.TensorStore) -> ts.TensorStore:
221+
"""Wraps input TensorStore with a virtual_chunked."""
222+
223+
fixed_ts = ts.open(self._fixed_spec).result()
224+
if input_ts.domain.labels != fixed_ts.domain.labels:
225+
raise ValueError(
226+
'Input TS and fixed TS must have same labels, but they are ' +
227+
f'{input_ts.domain.labels} and {fixed_ts.domain.labels}.')
228+
if input_ts.shape != fixed_ts.shape:
229+
raise ValueError(
230+
'Input TS and fixed TS must have same shape, but they are ' +
231+
f'{input_ts.shape} and {fixed_ts.shape}.')
232+
233+
num_image_dims = len(self._image_dims)
234+
if num_image_dims not in (2, 3):
235+
raise ValueError(
236+
f'2 or 3 image dimensions are required, but got {num_image_dims}.')
237+
for d in self._image_dims:
238+
if d not in input_ts.domain.labels:
239+
raise ValueError(
240+
f'image dimension {d} not among labels {input_ts.domain.labels}.')
241+
elif input_ts.domain[d].size < 2:
242+
raise ValueError(
243+
'image dimension {d} must at least have size 2 but has size: ' +
244+
f'{input_ts.domain[d].size}.')
245+
246+
non_image_dims = [
247+
l for l in input_ts.domain.labels if l not in self._image_dims]
248+
input_domain_dict = {dim.label: dim for dim in list(input_ts.domain)}
249+
250+
def read_fn(domain: ts.IndexDomain, array: np.ndarray,
251+
unused_read_params: ts.VirtualChunkedReadParameters):
252+
domain_dict = {dim.label: dim for dim in list(domain)}
253+
254+
read_domain = []
255+
for l in non_image_dims:
256+
read_domain.append(domain_dict[l])
257+
for l in self._image_dims:
258+
read_domain.append(input_domain_dict[l])
259+
read_domain = ts.IndexDomain(read_domain)
260+
261+
# Default to no normalization.
262+
if 'normalization' not in self._optim_args:
263+
self._optim_args['normalization'] = None
264+
265+
translation, _, _ = skimage.registration.phase_cross_correlation(
266+
reference_image=np.array(
267+
fixed_ts[read_domain], dtype=np.float32).squeeze(),
268+
moving_image=np.array(
269+
input_ts[read_domain], dtype=np.float32).squeeze(),
270+
**self._optim_args)
271+
transform = np.hstack(
272+
(np.eye(len(self._image_dims)), translation.reshape(-1, 1)))
273+
274+
array[...] = transform.reshape(array.shape)
275+
276+
chunksize = [num_image_dims, num_image_dims + 1]
277+
for _ in non_image_dims:
278+
chunksize.append(1)
279+
schema = {
280+
'chunk_layout': {
281+
'read_chunk': {'shape': chunksize},
282+
'write_chunk': {'shape': chunksize},
283+
},
284+
'domain': {
285+
'labels': ['r', 'c',] + non_image_dims,
286+
'inclusive_min': [0, 0] + [
287+
input_domain_dict[l].inclusive_min for l in non_image_dims],
288+
'exclusive_max': [num_image_dims, num_image_dims + 1] + [
289+
input_domain_dict[l].exclusive_max for l in non_image_dims],
290+
},
291+
'dtype': 'float64',
292+
'rank': len(chunksize),
293+
}
294+
295+
return ts.virtual_chunked(
296+
read_fn, schema=ts.Schema(schema), context=self._context)

0 commit comments

Comments
 (0)