11from __future__ import annotations
22
33import logging
4- import os
5- import pickle
64import sys
7- from typing import Optional
5+ import time
6+ from contextlib import contextmanager , nullcontext
7+ from dataclasses import asdict , dataclass , field
8+ from pathlib import Path
9+ from typing import TYPE_CHECKING , Optional , Sequence
810
911import matplotlib .pyplot as plt
1012import numpy as np
13+ import yaml
1114from skimage .registration import phase_cross_correlation
15+ from tqdm import tqdm
1216from typing_extensions import Self
1317
1418from instamatic import config
19+ from instamatic ._typing import AnyPath
1520from instamatic .calibrate .filenames import *
1621from instamatic .calibrate .fit import fit_affine_transformation
22+ from instamatic .formats import read_tiff
1723from instamatic .image_utils import autoscale , imgscale
1824from instamatic .processing .find_holes import find_holes
19- from instamatic .tools import find_beam_center , printer
25+ from instamatic .tools import find_beam_center
26+ from instamatic .utils .yaml import Numpy2DDumper
27+
28+ if TYPE_CHECKING :
29+ from instamatic .gui .videostream_processor import DeferredImageDraw , VideoStreamProcessor
30+
2031
2132logger = logging .getLogger (__name__ )
2233
2334
35+ Vector2 = np .ndarray # numpy array with two float (or int) elements
36+ VectorNx2 = np .ndarray # numpy array with N Vector2-s
37+ Matrix2x2 = np .ndarray # numpy array of shape (2, 2) with float elements
38+
39+
40+ @dataclass
2441class CalibBeamShift :
2542 """Simple class to hold the methods to perform transformations from one
26- setting to another based on calibration results."""
43+ setting to another based on calibration results.
44+
45+ Throughout this class, the following two terms are used consistently:
46+ - pixel: the (x, y) beam position in pixels as determined from camera image
47+ - shift: the unitless (x, y) value pair reported by the BeamShift deflector
48+ """
2749
28- def __init__ ( self , transform , reference_shift , reference_pixel ):
29- super (). __init__ ()
30- self . transform = transform
31- self . reference_shift = reference_shift
32- self . reference_pixel = reference_pixel
33- self . has_data = False
50+ transform : Matrix2x2
51+ reference_pixel : Vector2
52+ reference_shift : Vector2
53+ pixels : Optional [ VectorNx2 ] = field ( default = None , repr = False )
54+ shifts : Optional [ VectorNx2 ] = field ( default = None , repr = False )
55+ images : Optional [ list [ np . ndarray ]] = field ( default = None , repr = False )
3456
3557 def __repr__ (self ):
3658 return f'CalibBeamShift(transform=\n { self .transform } ,\n reference_shift=\n { self .reference_shift } ,\n reference_pixel=\n { self .reference_pixel } )'
3759
38- def beamshift_to_pixelcoord (self , beamshift ) :
60+ def beamshift_to_pixelcoord (self , beamshift : Sequence [ float , float ]) -> Vector2 :
3961 """Converts from beamshift x,y to pixel coordinates."""
4062 r_i = np .linalg .inv (self .transform )
41- pixelcoord = np .dot (self .reference_shift - beamshift , r_i ) + self .reference_pixel
42- return pixelcoord
63+ return np .dot (self .reference_shift - np .array (beamshift ), r_i ) + self .reference_pixel
4364
44- def pixelcoord_to_beamshift (self , pixelcoord ) -> np . ndarray :
65+ def pixelcoord_to_beamshift (self , pixelcoord : Sequence [ float , float ] ) -> Vector2 :
4566 """Converts from pixel coordinates to beamshift x,y."""
46- r = self .transform
47- beamshift = self .reference_shift - np .dot (pixelcoord - self .reference_pixel , r )
48- return beamshift
67+ pc = np .array (pixelcoord )
68+ return self .reference_shift - np .dot (pc - self .reference_pixel , self .transform )
4969
5070 @classmethod
51- def from_data (cls , shifts , beampos , reference_shift , reference_pixel , header = None ) -> Self :
52- fit_result = fit_affine_transformation (shifts , beampos )
53- r = fit_result .r
54- t = fit_result .t
55-
56- c = cls (transform = r , reference_shift = reference_shift , reference_pixel = reference_pixel )
57- c .data_shifts = shifts
58- c .data_beampos = beampos
59- c .has_data = True
60- c .header = header
61-
62- return c
71+ def from_data (
72+ cls ,
73+ pixels : VectorNx2 ,
74+ shifts : VectorNx2 ,
75+ reference_pixel : Vector2 ,
76+ reference_shift : Vector2 ,
77+ images : Optional [list [np .ndarray ]] = None ,
78+ ) -> Self :
79+ return cls (
80+ transform = fit_affine_transformation (pixels , shifts ).r ,
81+ reference_pixel = reference_pixel ,
82+ reference_shift = reference_shift ,
83+ pixels = pixels ,
84+ shifts = shifts ,
85+ images = images ,
86+ )
6387
6488 @classmethod
65- def from_file (cls , fn = CALIB_BEAMSHIFT ) -> Self :
89+ def from_file (cls , fn : AnyPath = CALIB_BEAMSHIFT ) -> Self :
6690 """Read calibration from file."""
67- import pickle
68-
6991 try :
70- return pickle .load (open (fn , 'rb' ))
92+ with open (Path (fn ), 'r' ) as yaml_file :
93+ return cls (** {k : np .array (v ) for k , v in yaml .safe_load (yaml_file ).items ()})
7194 except OSError as e :
7295 prog = 'instamatic.calibrate_beamshift'
7396 raise OSError (f'{ e .strerror } : { fn } . Please run { prog } first.' )
7497
7598 @classmethod
76- def live (cls , ctrl , outdir = '.' ) -> Self :
99+ def live (
100+ cls , ctrl , outdir : AnyPath = '.' , vsp : Optional [VideoStreamProcessor ] = None
101+ ) -> Self :
77102 while True :
78103 c = calibrate_beamshift (ctrl = ctrl , save_images = True , outdir = outdir )
79- if input (' >> Accept? [y/n] ' ) == 'y' :
80- return c
104+ with c .annotate_videostream (vsp ) if vsp else nullcontext ():
105+ if input (' >> Accept? [y/n] ' ) == 'y' :
106+ return c
81107
82- def to_file (self , fn = CALIB_BEAMSHIFT , outdir = '.' ):
108+ def to_file (self , fn : AnyPath = CALIB_BEAMSHIFT , outdir : AnyPath = '.' ) -> None :
83109 """Save calibration to file."""
84- fout = os .path .join (outdir , fn )
85- pickle .dump (self , open (fout , 'wb' ))
86-
87- def plot (self , to_file = None , outdir = '' ):
88- if not self .has_data :
89- return
90-
91- if to_file :
92- to_file = f'calib_{ beamshift } .png'
93-
94- beampos = self .data_beampos
95- shifts = self .data_shifts
96-
97- r_i = np .linalg .inv (self .transform )
98- beampos_ = np .dot (beampos , r_i )
99-
100- plt .scatter (* shifts .T , marker = '>' , label = 'Observed pixel shifts' )
101- plt .scatter (* beampos_ .T , marker = '<' , label = 'Positions in pixel coords' )
110+ yaml_path = Path (outdir ) / fn
111+ yaml_dict = asdict (self ) # type: ignore[arg-type]
112+ yaml_dict = {k : v .tolist () for k , v in yaml_dict .items () if k != 'images' }
113+ with open (yaml_path , 'w' ) as yaml_file :
114+ yaml .dump (yaml_dict , yaml_file , Dumper = Numpy2DDumper , default_flow_style = None )
115+
116+ def plot (self , to_file : Optional [AnyPath ] = None ):
117+ """Assuming the data is present, plot the data."""
118+ shifts = np .dot (self .shifts , np .linalg .inv (self .transform ))
119+ plt .scatter (* self .pixels .T , marker = '>' , label = 'Observed pixel shifts' )
120+ plt .scatter (* shifts .T , marker = '<' , label = 'Reconstructed pixel shifts' )
102121 plt .legend ()
103122 plt .title ('BeamShift vs. Direct beam position (Imaging)' )
104123 if to_file :
105- plt .savefig (os . path . join ( outdir , to_file ))
124+ plt .savefig (Path ( to_file ) / 'calib_beamshift.png' )
106125 plt .close ()
107126 else :
108127 plt .show ()
109128
129+ @contextmanager
130+ def annotate_videostream (self , vsp : Optional [VideoStreamProcessor ] = None ) -> None :
131+ shifts = np .dot (self .shifts , np .linalg .inv (self .transform ))
132+ ins : list [DeferredImageDraw .Instruction ] = []
133+
134+ vsp .temporary_frame = np .max (self .images , axis = 0 )
135+ print ('Determined (blue) vs calibrated (orange) beam positions:' )
136+ for p , s in zip (self .pixels , shifts ):
137+ p = (p + self .reference_pixel )[::- 1 ] # xy coords inverted for plot
138+ s = (s + self .reference_pixel )[::- 1 ] # xy coords inverted for plot
139+ ins .append (vsp .draw .circle (p , radius = 3 , fill = 'blue' ))
140+ ins .append (vsp .draw .circle (s , radius = 3 , fill = 'orange' ))
141+ ins .append (vsp .draw .circle (self .reference_pixel [::- 1 ], radius = 3 , fill = 'black' ))
142+ yield
143+ vsp .temporary_frame = None
144+ for i in ins :
145+ vsp .draw .instructions .remove (i )
146+
110147 def center (self , ctrl ) -> Optional [np .ndarray ]:
111148 """Return beamshift values to center the beam in the frame."""
112149 pixel_center = [val / 2.0 for val in ctrl .cam .get_image_dimensions ()]
@@ -119,8 +156,13 @@ def center(self, ctrl) -> Optional[np.ndarray]:
119156
120157
121158def calibrate_beamshift_live (
122- ctrl , gridsize = None , stepsize = None , save_images = False , outdir = '.' , ** kwargs
123- ):
159+ ctrl ,
160+ gridsize : Optional [int ] = None ,
161+ stepsize : Optional [float ] = None ,
162+ save_images : bool = False ,
163+ outdir : AnyPath = '.' ,
164+ ** kwargs ,
165+ ) -> CalibBeamShift :
124166 """Calibrate pixel->beamshift coordinates live on the microscope.
125167
126168 ctrl: instance of `TEMController`
@@ -141,87 +183,54 @@ def calibrate_beamshift_live(
141183 """
142184 exposure = kwargs .get ('exposure' , ctrl .cam .default_exposure )
143185 binsize = kwargs .get ('binsize' , ctrl .cam .default_binsize )
186+ gridsize = gridsize or config .camera .calib_beamshift .get ('gridsize' , 5 )
187+ stepsize = stepsize or config .camera .calib_beamshift .get ('stepsize' , 250 )
188+ outfile = Path (outdir ) / 'calib_beamshift_center' if save_images else None
189+ kwargs = {'exposure' : exposure , 'binsize' : binsize , 'out' : outfile }
144190
145- if not gridsize :
146- gridsize = config .camera .calib_beamshift .get ('gridsize' , 5 )
147- if not stepsize :
148- stepsize = config .camera .calib_beamshift .get ('stepsize' , 250 )
149-
150- img_cent , h_cent = ctrl .get_image (
151- exposure = exposure , binsize = binsize , comment = 'Beam in center of image'
152- )
191+ comment = 'Beam in the center of the image'
192+ img_cent , h_cent = ctrl .get_image (comment = comment , ** kwargs )
153193 x_cent , y_cent = beamshift_cent = np .array (h_cent ['BeamShift' ])
154194
155- magnification = h_cent ['Magnification' ]
156- stepsize = 2500.0 / magnification * stepsize
195+ stepsize = 2500.0 / h_cent ['Magnification' ] * stepsize
157196
158197 print (f'Gridsize: { gridsize } | Stepsize: { stepsize :.2f} ' )
159198
160199 img_cent , scale = autoscale (img_cent )
161-
162- outfile = os .path .join (outdir , 'calib_beamcenter' ) if save_images else None
163-
164200 pixel_cent = find_beam_center (img_cent ) * binsize / scale
165201
166202 print ('Beamshift: x={} | y={}' .format (* beamshift_cent ))
167203 print ('Pixel: x={} | y={}' .format (* pixel_cent ))
168204
169- shifts = []
170- beampos = []
171-
172- n = int ((gridsize - 1 ) / 2 ) # number of points = n*(n+1)
173- x_grid , y_grid = np .meshgrid (
174- np .arange (- n , n + 1 ) * stepsize , np .arange (- n , n + 1 ) * stepsize
175- )
176- tot = gridsize * gridsize
205+ images , pixels , shifts = [], [], []
206+ dx_dy = ((np .indices ((gridsize , gridsize )) - gridsize // 2 ) * stepsize ).reshape (2 , - 1 ).T
177207
178- i = 0
179- for dx , dy in np . stack ([ x_grid , y_grid ]). reshape ( 2 , - 1 ). T :
208+ progress_bar = tqdm ( dx_dy , desc = 'Beamshift calibration' )
209+ for i , ( dx , dy ) in enumerate ( progress_bar ) :
180210 ctrl .beamshift .set (x = float (x_cent + dx ), y = float (y_cent + dy ))
211+ progress_bar .set_postfix_str (ctrl .beamshift )
212+ time .sleep (config .camera .calib_beamshift .get ('delay' , 0.0 ))
181213
182- printer (f'Position: { i + 1 } /{ tot } : { ctrl .beamshift } ' )
183-
184- outfile = os .path .join (outdir , f'calib_beamshift_{ i :04d} ' ) if save_images else None
185-
214+ kwargs ['out' ] = Path (outdir ) / f'calib_beamshift_{ i :04d} ' if save_images else None
186215 comment = f'Calib image { i } : dx={ dx } - dy={ dy } '
187- img , h = ctrl .get_image (
188- exposure = exposure ,
189- binsize = binsize ,
190- out = outfile ,
191- comment = comment ,
192- header_keys = ('BeamShift' ,),
193- )
216+ img , h = ctrl .get_image (comment = comment , header_keys = ('BeamShift' ,), ** kwargs )
194217 img = imgscale (img , scale )
195218
196- shift , error , phasediff = phase_cross_correlation (img_cent , img , upsample_factor = 10 )
197-
198- beamshift = np .array (h ['BeamShift' ])
199- beampos .append (beamshift )
200- shifts .append (shift )
201-
202- i += 1
219+ images .append (img )
220+ pixels .append (phase_cross_correlation (img_cent , img , upsample_factor = 10 )[0 ])
221+ shifts .append (np .array (h ['BeamShift' ]))
203222
204223 print ('' )
205- # print "\nReset to center"
206-
207224 ctrl .beamshift .set (* (float (_ ) for _ in beamshift_cent ))
208225
209- # correct for binsize, store in binsize=1
210- shifts = np .array (shifts ) * binsize / scale
211- beampos = np .array (beampos ) - np .array (beamshift_cent )
212-
226+ # normalize to binsize = 1 and 512-pixel image scale before initializing
213227 c = CalibBeamShift .from_data (
214- shifts ,
215- beampos ,
216- reference_shift = beamshift_cent ,
228+ np .array (pixels ) * binsize / scale ,
229+ np .array (shifts ) - beamshift_cent ,
217230 reference_pixel = pixel_cent ,
218- header = h_cent ,
231+ reference_shift = beamshift_cent ,
232+ images = images ,
219233 )
220-
221- # Calling c.plot with videostream crashes program
222- # if not hasattr(ctrl.cam, "VideoLoop"):
223- # c.plot()
224-
225234 return c
226235
227236
@@ -239,7 +248,7 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
239248 print ()
240249 print ('Center:' , center_fn )
241250
242- img_cent , h_cent = load_img (center_fn )
251+ img_cent , h_cent = read_tiff (center_fn )
243252 beamshift_cent = np .array (h_cent ['BeamShift' ])
244253
245254 img_cent , scale = autoscale (img_cent , maxdim = 512 )
@@ -252,11 +261,12 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
252261 print ('Beamshift: x={} | y={}' .format (* beamshift_cent ))
253262 print ('Pixel: x={:.2f} | y={:.2f}' .format (* pixel_cent ))
254263
264+ images = []
255265 shifts = []
256266 beampos = []
257267
258268 for fn in other_fn :
259- img , h = load_img (fn )
269+ img , h = read_tiff (fn )
260270 img = imgscale (img , scale )
261271
262272 beamshift = np .array (h ['BeamShift' ])
@@ -266,6 +276,7 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
266276
267277 shift , error , phasediff = phase_cross_correlation (img_cent , img , upsample_factor = 10 )
268278
279+ images .append (img )
269280 beampos .append (beamshift )
270281 shifts .append (shift )
271282
@@ -276,9 +287,9 @@ def calibrate_beamshift_from_image_fn(center_fn, other_fn):
276287 c = CalibBeamShift .from_data (
277288 shifts ,
278289 beampos ,
279- reference_shift = beamshift_cent ,
280290 reference_pixel = pixel_cent ,
281- header = h_cent ,
291+ reference_shift = beamshift_cent ,
292+ images = images ,
282293 )
283294 c .plot ()
284295
0 commit comments