Skip to content

Commit 9150d9f

Browse files
authored
Fix time lag issue for inversion (#370)
* Fix time lag issue for inversion - t passed to cost_fn(t) was out of sync because update_forcings was being used so the optimal result was not correct for channel inversion - Now need to pass the cost function through the export_func rather than through update_forcings - Typo fixed in Tohoku Makefile as well * Update inversion tools for parallelisation - Previously would hang because you could have a processor with no Station coordinates within its partition of the mesh, causing issues
1 parent 44fc60b commit 9150d9f

4 files changed

Lines changed: 26 additions & 24 deletions

File tree

examples/channel_inversion/inverse_problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
cost_function = inv_manager.get_cost_function(solver_obj)
114114

115115
# Solve and setup reduced functional
116-
solver_obj.iterate(update_forcings=cost_function)
116+
solver_obj.iterate(export_func=cost_function) # note that export time should equal dt if not using a custom callback
117117
inv_manager.stop_annotating()
118118

119119
# Run inversion

examples/tohoku_inversion/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ invert:
99

1010
plot:
1111
python3 plot_elevation_progress.py
12-
python3 plot_elevation_optimised.py
12+
python3 plot_elevation_optimized.py
1313
python3 plot_convergence.py
1414

1515
clean:

examples/tohoku_inversion/inverse_problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
cost_function = inv_manager.get_cost_function(solver_obj, weight_by_variance=True)
102102

103103
# Solve and setup the reduced functional
104-
solver_obj.iterate(update_forcings=cost_function)
104+
solver_obj.iterate(export_func=cost_function)
105105
inv_manager.stop_annotating()
106106

107107
# Run inversion

thetis/inversion_tools.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ def get_cost_function(self, solver_obj, weight_by_variance=False):
252252
var.dat.data[i] = numpy.var(self.sta_manager.observation_values[j])
253253
self.sta_manager.station_weight_0d.interpolate(1/var)
254254

255-
def cost_fn(t):
255+
def cost_fn():
256+
t = solver_obj.simulation_time
256257
misfit = self.sta_manager.eval_cost_function(t)
257258
self.J_misfit += misfit
258259
self.J += misfit
@@ -487,26 +488,27 @@ def construct_evaluator(self):
487488
# Construct timeseries interpolator
488489
self.station_interpolators = []
489490
self.local_station_index = []
490-
for i in range(self.fs_points_0d.dof_dset.size):
491-
# loop over local DOFs and match coordinates to observations
492-
# NOTE this must be done manually as VertexOnlyMesh reorders points
493-
x_mesh, y_mesh = mesh0d.coordinates.dat.data[i, :]
494-
xy_diff = xy - numpy.array([x_mesh, y_mesh])
495-
xy_dist = numpy.sqrt(xy_diff[:, 0]**2 + xy_diff[:, 1]**2)
496-
j = numpy.argmin(xy_dist)
497-
self.local_station_index.append(j)
498-
499-
x, y = xy[j, :]
500-
t = self.observation_time[j]
501-
v = self.observation_values[j]
502-
x_mesh, y_mesh = mesh0d.coordinates.dat.data[i, :]
503-
504-
msg = 'bad station location ' \
505-
f'{j} {i} {x} {x_mesh} {y} {y_mesh} {x-x_mesh} {y-y_mesh}'
506-
assert numpy.allclose([x, y], [x_mesh, y_mesh]), msg
507-
# create temporal interpolator
508-
ip = interp1d(t, v, **interp_kw)
509-
self.station_interpolators.append(ip)
491+
if len(mesh0d.coordinates.dat.data[:]) > 0:
492+
for i in range(self.fs_points_0d.dof_dset.size):
493+
# loop over local DOFs and match coordinates to observations
494+
# NOTE this must be done manually as VertexOnlyMesh reorders points
495+
x_mesh, y_mesh = mesh0d.coordinates.dat.data[i, :]
496+
xy_diff = xy - numpy.array([x_mesh, y_mesh])
497+
xy_dist = numpy.sqrt(xy_diff[:, 0]**2 + xy_diff[:, 1]**2)
498+
j = numpy.argmin(xy_dist)
499+
self.local_station_index.append(j)
500+
501+
x, y = xy[j, :]
502+
t = self.observation_time[j]
503+
v = self.observation_values[j]
504+
x_mesh, y_mesh = mesh0d.coordinates.dat.data[i, :]
505+
506+
msg = 'bad station location ' \
507+
f'{j} {i} {x} {x_mesh} {y} {y_mesh} {x-x_mesh} {y-y_mesh}'
508+
assert numpy.allclose([x, y], [x_mesh, y_mesh]), msg
509+
# create temporal interpolator
510+
ip = interp1d(t, v, **interp_kw)
511+
self.station_interpolators.append(ip)
510512

511513
# Process start and end times for observations
512514
self.obs_start_times = numpy.array([

0 commit comments

Comments
 (0)