Skip to content

Commit e4d83ca

Browse files
Baharisstefsmeets
andauthored
Improve the readability of beam shift calibration code & results (#144)
* Streamline the `calibrate_beamshift_live` function * Streamline `CalibBeamShift.plot` * Improvements to `CalibBeamShift` readability (WIP) * Add option to calibrate beamshift with vsp * Fix errors, add beam center * Switch calibration output format from pickle to yaml * Allow delay during calibrate beamshift * Add necessary reflections to fix plotting * Final tweaks * Make the yaml produced by CalibBeamShift human-readable * Update src/instamatic/calibrate/calibrate_beamshift.py Co-authored-by: Stef Smeets <stefsmeets@users.noreply.github.com> * Update src/instamatic/calibrate/calibrate_beamshift.py Co-authored-by: Stef Smeets <stefsmeets@users.noreply.github.com> * Minor post-review type-hint improvements + ruff --------- Co-authored-by: Stef Smeets <stefsmeets@users.noreply.github.com>
1 parent 3cf45a1 commit e4d83ca

5 files changed

Lines changed: 166 additions & 123 deletions

File tree

docs/config.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,9 @@ This file holds the specifications of the camera. This file is must be located t
243243
: Set the correction ratio for the cross pixels in the Timepix detector, default: 3.
244244

245245
**calib_beamshift**
246-
: Set up the grid and stepsize for the calibration of the beam shift in SerialED. The calibration will run a grid of `stepsize` by `stepsize` points, with steps of `stepsize`. The stepsize must be given corresponding to 2500x, and instamatic will then adjust the stepsize depending on the actual magnification, if needed. For example:
246+
: Set up the grid and stepsize for the calibration of the beam shift in SerialED. The calibration will run a grid of `stepsize` by `stepsize` points, with steps of `stepsize`. The stepsize must be given corresponding to 2500x, and instamatic will then adjust the stepsize depending on the actual magnification, if needed. If the beam moves too slow, a `delay` between setting beam shift and getting image can be introduced. For example:
247247
```yaml
248-
{gridsize: 5, stepsize: 500}
248+
{gridsize: 5, stepsize: 500, delay: 0.5}
249249
```
250250

251251
**calib_directbeam**

src/instamatic/calibrate/calibrate_beamshift.py

Lines changed: 130 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,112 +1,149 @@
11
from __future__ import annotations
22

33
import logging
4-
import os
5-
import pickle
64
import sys
7-
from typing import Optional
5+
import time
6+
from contextlib import contextmanager, nullcontext
7+
from dataclasses import asdict, dataclass, field
8+
from pathlib import Path
9+
from typing import TYPE_CHECKING, Optional, Sequence
810

911
import matplotlib.pyplot as plt
1012
import numpy as np
13+
import yaml
1114
from skimage.registration import phase_cross_correlation
15+
from tqdm import tqdm
1216
from typing_extensions import Self
1317

1418
from instamatic import config
19+
from instamatic._typing import AnyPath
1520
from instamatic.calibrate.filenames import *
1621
from instamatic.calibrate.fit import fit_affine_transformation
22+
from instamatic.formats import read_tiff
1723
from instamatic.image_utils import autoscale, imgscale
1824
from instamatic.processing.find_holes import find_holes
19-
from instamatic.tools import find_beam_center, printer
25+
from instamatic.tools import find_beam_center
26+
from instamatic.utils.yaml import Numpy2DDumper
27+
28+
if TYPE_CHECKING:
29+
from instamatic.gui.videostream_processor import DeferredImageDraw, VideoStreamProcessor
30+
2031

2132
logger = logging.getLogger(__name__)
2233

2334

35+
Vector2 = np.ndarray # numpy array with two float (or int) elements
36+
VectorNx2 = np.ndarray # numpy array with N Vector2-s
37+
Matrix2x2 = np.ndarray # numpy array of shape (2, 2) with float elements
38+
39+
40+
@dataclass
2441
class CalibBeamShift:
2542
"""Simple class to hold the methods to perform transformations from one
26-
setting to another based on calibration results."""
43+
setting to another based on calibration results.
44+
45+
Throughout this class, the following two terms are used consistently:
46+
- pixel: the (x, y) beam position in pixels as determined from camera image
47+
- shift: the unitless (x, y) value pair reported by the BeamShift deflector
48+
"""
2749

28-
def __init__(self, transform, reference_shift, reference_pixel):
29-
super().__init__()
30-
self.transform = transform
31-
self.reference_shift = reference_shift
32-
self.reference_pixel = reference_pixel
33-
self.has_data = False
50+
transform: Matrix2x2
51+
reference_pixel: Vector2
52+
reference_shift: Vector2
53+
pixels: Optional[VectorNx2] = field(default=None, repr=False)
54+
shifts: Optional[VectorNx2] = field(default=None, repr=False)
55+
images: Optional[list[np.ndarray]] = field(default=None, repr=False)
3456

3557
def __repr__(self):
3658
return f'CalibBeamShift(transform=\n{self.transform},\n reference_shift=\n{self.reference_shift},\n reference_pixel=\n{self.reference_pixel})'
3759

38-
def beamshift_to_pixelcoord(self, beamshift):
60+
def beamshift_to_pixelcoord(self, beamshift: Sequence[float, float]) -> Vector2:
3961
"""Converts from beamshift x,y to pixel coordinates."""
4062
r_i = np.linalg.inv(self.transform)
41-
pixelcoord = np.dot(self.reference_shift - beamshift, r_i) + self.reference_pixel
42-
return pixelcoord
63+
return np.dot(self.reference_shift - np.array(beamshift), r_i) + self.reference_pixel
4364

44-
def pixelcoord_to_beamshift(self, pixelcoord) -> np.ndarray:
65+
def pixelcoord_to_beamshift(self, pixelcoord: Sequence[float, float]) -> Vector2:
4566
"""Converts from pixel coordinates to beamshift x,y."""
46-
r = self.transform
47-
beamshift = self.reference_shift - np.dot(pixelcoord - self.reference_pixel, r)
48-
return beamshift
67+
pc = np.array(pixelcoord)
68+
return self.reference_shift - np.dot(pc - self.reference_pixel, self.transform)
4969

5070
@classmethod
51-
def from_data(cls, shifts, beampos, reference_shift, reference_pixel, header=None) -> Self:
52-
fit_result = fit_affine_transformation(shifts, beampos)
53-
r = fit_result.r
54-
t = fit_result.t
55-
56-
c = cls(transform=r, reference_shift=reference_shift, reference_pixel=reference_pixel)
57-
c.data_shifts = shifts
58-
c.data_beampos = beampos
59-
c.has_data = True
60-
c.header = header
61-
62-
return c
71+
def from_data(
72+
cls,
73+
pixels: VectorNx2,
74+
shifts: VectorNx2,
75+
reference_pixel: Vector2,
76+
reference_shift: Vector2,
77+
images: Optional[list[np.ndarray]] = None,
78+
) -> Self:
79+
return cls(
80+
transform=fit_affine_transformation(pixels, shifts).r,
81+
reference_pixel=reference_pixel,
82+
reference_shift=reference_shift,
83+
pixels=pixels,
84+
shifts=shifts,
85+
images=images,
86+
)
6387

6488
@classmethod
65-
def from_file(cls, fn=CALIB_BEAMSHIFT) -> Self:
89+
def from_file(cls, fn: AnyPath = CALIB_BEAMSHIFT) -> Self:
6690
"""Read calibration from file."""
67-
import pickle
68-
6991
try:
70-
return pickle.load(open(fn, 'rb'))
92+
with open(Path(fn), 'r') as yaml_file:
93+
return cls(**{k: np.array(v) for k, v in yaml.safe_load(yaml_file).items()})
7194
except OSError as e:
7295
prog = 'instamatic.calibrate_beamshift'
7396
raise OSError(f'{e.strerror}: {fn}. Please run {prog} first.')
7497

7598
@classmethod
76-
def live(cls, ctrl, outdir='.') -> Self:
99+
def live(
100+
cls, ctrl, outdir: AnyPath = '.', vsp: Optional[VideoStreamProcessor] = None
101+
) -> Self:
77102
while True:
78103
c = calibrate_beamshift(ctrl=ctrl, save_images=True, outdir=outdir)
79-
if input(' >> Accept? [y/n] ') == 'y':
80-
return c
104+
with c.annotate_videostream(vsp) if vsp else nullcontext():
105+
if input(' >> Accept? [y/n] ') == 'y':
106+
return c
81107

82-
def to_file(self, fn=CALIB_BEAMSHIFT, outdir='.'):
108+
def to_file(self, fn: AnyPath = CALIB_BEAMSHIFT, outdir: AnyPath = '.') -> None:
83109
"""Save calibration to file."""
84-
fout = os.path.join(outdir, fn)
85-
pickle.dump(self, open(fout, 'wb'))
86-
87-
def plot(self, to_file=None, outdir=''):
88-
if not self.has_data:
89-
return
90-
91-
if to_file:
92-
to_file = f'calib_{beamshift}.png'
93-
94-
beampos = self.data_beampos
95-
shifts = self.data_shifts
96-
97-
r_i = np.linalg.inv(self.transform)
98-
beampos_ = np.dot(beampos, r_i)
99-
100-
plt.scatter(*shifts.T, marker='>', label='Observed pixel shifts')
101-
plt.scatter(*beampos_.T, marker='<', label='Positions in pixel coords')
110+
yaml_path = Path(outdir) / fn
111+
yaml_dict = asdict(self) # type: ignore[arg-type]
112+
yaml_dict = {k: v.tolist() for k, v in yaml_dict.items() if k != 'images'}
113+
with open(yaml_path, 'w') as yaml_file:
114+
yaml.dump(yaml_dict, yaml_file, Dumper=Numpy2DDumper, default_flow_style=None)
115+
116+
def plot(self, to_file: Optional[AnyPath] = None):
117+
"""Assuming the data is present, plot the data."""
118+
shifts = np.dot(self.shifts, np.linalg.inv(self.transform))
119+
plt.scatter(*self.pixels.T, marker='>', label='Observed pixel shifts')
120+
plt.scatter(*shifts.T, marker='<', label='Reconstructed pixel shifts')
102121
plt.legend()
103122
plt.title('BeamShift vs. Direct beam position (Imaging)')
104123
if to_file:
105-
plt.savefig(os.path.join(outdir, to_file))
124+
plt.savefig(Path(to_file) / 'calib_beamshift.png')
106125
plt.close()
107126
else:
108127
plt.show()
109128

129+
@contextmanager
130+
def annotate_videostream(self, vsp: Optional[VideoStreamProcessor] = None) -> None:
131+
shifts = np.dot(self.shifts, np.linalg.inv(self.transform))
132+
ins: list[DeferredImageDraw.Instruction] = []
133+
134+
vsp.temporary_frame = np.max(self.images, axis=0)
135+
print('Determined (blue) vs calibrated (orange) beam positions:')
136+
for p, s in zip(self.pixels, shifts):
137+
p = (p + self.reference_pixel)[::-1] # xy coords inverted for plot
138+
s = (s + self.reference_pixel)[::-1] # xy coords inverted for plot
139+
ins.append(vsp.draw.circle(p, radius=3, fill='blue'))
140+
ins.append(vsp.draw.circle(s, radius=3, fill='orange'))
141+
ins.append(vsp.draw.circle(self.reference_pixel[::-1], radius=3, fill='black'))
142+
yield
143+
vsp.temporary_frame = None
144+
for i in ins:
145+
vsp.draw.instructions.remove(i)
146+
110147
def center(self, ctrl) -> Optional[np.ndarray]:
111148
"""Return beamshift values to center the beam in the frame."""
112149
pixel_center = [val / 2.0 for val in ctrl.cam.get_image_dimensions()]
@@ -119,8 +156,13 @@ def center(self, ctrl) -> Optional[np.ndarray]:
119156

120157

121158
def calibrate_beamshift_live(
122-
ctrl, gridsize=None, stepsize=None, save_images=False, outdir='.', **kwargs
123-
):
159+
ctrl,
160+
gridsize: Optional[int] = None,
161+
stepsize: Optional[float] = None,
162+
save_images: bool = False,
163+
outdir: AnyPath = '.',
164+
**kwargs,
165+
) -> CalibBeamShift:
124166
"""Calibrate pixel->beamshift coordinates live on the microscope.
125167
126168
ctrl: instance of `TEMController`
@@ -141,87 +183,54 @@ def calibrate_beamshift_live(
141183
"""
142184
exposure = kwargs.get('exposure', ctrl.cam.default_exposure)
143185
binsize = kwargs.get('binsize', ctrl.cam.default_binsize)
186+
gridsize = gridsize or config.camera.calib_beamshift.get('gridsize', 5)
187+
stepsize = stepsize or config.camera.calib_beamshift.get('stepsize', 250)
188+
outfile = Path(outdir) / 'calib_beamshift_center' if save_images else None
189+
kwargs = {'exposure': exposure, 'binsize': binsize, 'out': outfile}
144190

145-
if not gridsize:
146-
gridsize = config.camera.calib_beamshift.get('gridsize', 5)
147-
if not stepsize:
148-
stepsize = config.camera.calib_beamshift.get('stepsize', 250)
149-
150-
img_cent, h_cent = ctrl.get_image(
151-
exposure=exposure, binsize=binsize, comment='Beam in center of image'
152-
)
191+
comment = 'Beam in the center of the image'
192+
img_cent, h_cent = ctrl.get_image(comment=comment, **kwargs)
153193
x_cent, y_cent = beamshift_cent = np.array(h_cent['BeamShift'])
154194

155-
magnification = h_cent['Magnification']
156-
stepsize = 2500.0 / magnification * stepsize
195+
stepsize = 2500.0 / h_cent['Magnification'] * stepsize
157196

158197
print(f'Gridsize: {gridsize} | Stepsize: {stepsize:.2f}')
159198

160199
img_cent, scale = autoscale(img_cent)
161-
162-
outfile = os.path.join(outdir, 'calib_beamcenter') if save_images else None
163-
164200
pixel_cent = find_beam_center(img_cent) * binsize / scale
165201

166202
print('Beamshift: x={} | y={}'.format(*beamshift_cent))
167203
print('Pixel: x={} | y={}'.format(*pixel_cent))
168204

169-
shifts = []
170-
beampos = []
171-
172-
n = int((gridsize - 1) / 2) # number of points = n*(n+1)
173-
x_grid, y_grid = np.meshgrid(
174-
np.arange(-n, n + 1) * stepsize, np.arange(-n, n + 1) * stepsize
175-
)
176-
tot = gridsize * gridsize
205+
images, pixels, shifts = [], [], []
206+
dx_dy = ((np.indices((gridsize, gridsize)) - gridsize // 2) * stepsize).reshape(2, -1).T
177207

178-
i = 0
179-
for dx, dy in np.stack([x_grid, y_grid]).reshape(2, -1).T:
208+
progress_bar = tqdm(dx_dy, desc='Beamshift calibration')
209+
for i, (dx, dy) in enumerate(progress_bar):
180210
ctrl.beamshift.set(x=float(x_cent + dx), y=float(y_cent + dy))
211+
progress_bar.set_postfix_str(ctrl.beamshift)
212+
time.sleep(config.camera.calib_beamshift.get('delay', 0.0))
181213

182-
printer(f'Position: {i + 1}/{tot}: {ctrl.beamshift}')
183-
184-
outfile = os.path.join(outdir, f'calib_beamshift_{i:04d}') if save_images else None
185-
214+
kwargs['out'] = Path(outdir) / f'calib_beamshift_{i:04d}' if save_images else None
186215
comment = f'Calib image {i}: dx={dx} - dy={dy}'
187-
img, h = ctrl.get_image(
188-
exposure=exposure,
189-
binsize=binsize,
190-
out=outfile,
191-
comment=comment,
192-
header_keys=('BeamShift',),
193-
)
216+
img, h = ctrl.get_image(comment=comment, header_keys=('BeamShift',), **kwargs)
194217
img = imgscale(img, scale)
195218

196-
shift, error, phasediff = phase_cross_correlation(img_cent, img, upsample_factor=10)
197-
198-
beamshift = np.array(h['BeamShift'])
199-
beampos.append(beamshift)
200-
shifts.append(shift)
201-
202-
i += 1
219+
images.append(img)
220+
pixels.append(phase_cross_correlation(img_cent, img, upsample_factor=10)[0])
221+
shifts.append(np.array(h['BeamShift']))
203222

204223
print('')
205-
# print "\nReset to center"
206-
207224
ctrl.beamshift.set(*(float(_) for _ in beamshift_cent))
208225

209-
# correct for binsize, store in binsize=1
210-
shifts = np.array(shifts) * binsize / scale
211-
beampos = np.array(beampos) - np.array(beamshift_cent)
212-
226+
# normalize to binsize = 1 and 512-pixel image scale before initializing
213227
c = CalibBeamShift.from_data(
214-
shifts,
215-
beampos,
216-
reference_shift=beamshift_cent,
228+
np.array(pixels) * binsize / scale,
229+
np.array(shifts) - beamshift_cent,
217230
reference_pixel=pixel_cent,
218-
header=h_cent,
231+
reference_shift=beamshift_cent,
232+
images=images,
219233
)
220-
221-
# Calling c.plot with videostream crashes program
222-
# if not hasattr(ctrl.cam, "VideoLoop"):
223-
# c.plot()
224-
225234
return c
226235

227236

@@ -239,7 +248,7 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
239248
print()
240249
print('Center:', center_fn)
241250

242-
img_cent, h_cent = load_img(center_fn)
251+
img_cent, h_cent = read_tiff(center_fn)
243252
beamshift_cent = np.array(h_cent['BeamShift'])
244253

245254
img_cent, scale = autoscale(img_cent, maxdim=512)
@@ -252,11 +261,12 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
252261
print('Beamshift: x={} | y={}'.format(*beamshift_cent))
253262
print('Pixel: x={:.2f} | y={:.2f}'.format(*pixel_cent))
254263

264+
images = []
255265
shifts = []
256266
beampos = []
257267

258268
for fn in other_fn:
259-
img, h = load_img(fn)
269+
img, h = read_tiff(fn)
260270
img = imgscale(img, scale)
261271

262272
beamshift = np.array(h['BeamShift'])
@@ -266,6 +276,7 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
266276

267277
shift, error, phasediff = phase_cross_correlation(img_cent, img, upsample_factor=10)
268278

279+
images.append(img)
269280
beampos.append(beamshift)
270281
shifts.append(shift)
271282

@@ -276,9 +287,9 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
276287
c = CalibBeamShift.from_data(
277288
shifts,
278289
beampos,
279-
reference_shift=beamshift_cent,
280290
reference_pixel=pixel_cent,
281-
header=h_cent,
291+
reference_shift=beamshift_cent,
292+
images=images,
282293
)
283294
c.plot()
284295

src/instamatic/calibrate/filenames.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
CALIB_STAGE_LOWMAG = 'calib_stage_lowmag.pickle'
4-
CALIB_BEAMSHIFT = 'calib_beamshift.pickle'
4+
CALIB_BEAMSHIFT = 'calib_beamshift.yaml'
55
CALIB_BRIGHTNESS = 'calib_brightness.pickle'
66
CALIB_DIFFSHIFT = 'calib_diffshift.pickle'
77
CALIB_DIRECTBEAM = 'calib_directbeam.pickle'

0 commit comments

Comments
 (0)