Skip to content

Commit e440eaa

Browse files
authored
Merge pull request #100 from iraikov/fix/optimization_memory
Reduce memory used by network optimization objectives
2 parents 09106c1 + 9636276 commit e440eaa

4 files changed

Lines changed: 29 additions & 20 deletions

File tree

src/miv_simulator/optimization.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,15 +370,20 @@ def update_run_params(env, param_tuples):
370370
def network_features(env, t_start, t_stop, target_populations):
371371
features_dict = dict()
372372

373-
temporal_resolution = float(env.stimulus_config["Temporal Resolution"])
374-
time_bins = np.arange(t_start, t_stop, temporal_resolution)
373+
analysis_config = env.analysis_config
374+
if analysis_config is None:
375+
analysis_config = {}
376+
fr_inference_config = analysis_config.get("Firing Rate Inference", {})
377+
378+
temporal_resolution = float(fr_inference_config.get("Temporal Resolution", 2.0))
379+
time_bins = np.arange(t_start, t_stop, temporal_resolution).astype(np.float32)
375380

376381
pop_spike_dict = spikedata.get_env_spike_dict(env, include_artificial=False)
377382

378383
for pop_name in target_populations:
379384
n_active = 0
380385
spike_density_dict = spikedata.spike_density_estimate(
381-
pop_name, pop_spike_dict[pop_name], time_bins
386+
pop_name, pop_spike_dict[pop_name], time_bins, return_time_bins=False
382387
)
383388
for gid, dens_dict in spike_density_dict.items():
384389
mean_rate = np.mean(dens_dict["rate"])

src/miv_simulator/optimize_network.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Network model optimization script for optimization with dmosopt
44
"""
5-
5+
import gc
66
import os
77
import sys
88
import datetime
@@ -410,9 +410,9 @@ def compute_objectives(local_features, operational_config, opt_targets):
410410
sum_mean_rate_local += mean_rate
411411
ip_rate = np.interp(
412412
fr_time_centers,
413-
dens_dict["time"].astype(np.float32),
413+
time_bins_ref,
414414
dens_dict["rate"].astype(np.float32),
415-
)
415+
).astype(np.float32)
416416
active_per_bin = ip_rate > active_threshold
417417
sum_active_per_bin += active_per_bin
418418

@@ -439,17 +439,19 @@ def compute_objectives(local_features, operational_config, opt_targets):
439439
)
440440

441441
all_features_dict[f"{pop_name} mean fraction active per time bin"] = (
442-
mean_fraction_active_per_bin
442+
float(mean_fraction_active_per_bin)
443443
)
444444
all_features_dict[f"{pop_name} std fraction active per time bin"] = (
445-
std_fraction_active_per_bin
445+
float(std_fraction_active_per_bin)
446446
)
447-
all_features_dict[f"{pop_name} fraction active"] = fraction_active
448-
all_features_dict[f"{pop_name} firing rate"] = mean_rate
447+
all_features_dict[f"{pop_name} fraction active"] = float(fraction_active)
448+
all_features_dict[f"{pop_name} firing rate"] = float(mean_rate)
449449

450450
rate_constr = mean_rate if mean_rate > 0.0 else -1.0
451451
constraints.append(rate_constr)
452452

453+
gc.collect()
454+
453455
objective_names = operational_config["objective_names"]
454456
feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names]
455457

@@ -470,7 +472,7 @@ def compute_objectives(local_features, operational_config, opt_targets):
470472
features.append(feature_val)
471473

472474
result = (
473-
np.asarray(objectives),
475+
np.asarray(objectives, dtype=np.float32),
474476
np.array([tuple(features)], dtype=np.dtype(feature_dtypes)),
475477
np.asarray(constraints, dtype=np.float32),
476478
)

src/miv_simulator/spikedata.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def spike_density_estimate(
295295
trajectory_id=None,
296296
output_file_path=None,
297297
progress=False,
298+
return_time_bins=True,
298299
inferred_rate_attr_name="Inferred Rate Map",
299300
**kwargs,
300301
):
@@ -339,7 +340,7 @@ def make_spktrain(lst, t_start, t_stop):
339340
spk_rate_dict = {
340341
ind: baks(spkts / 1000.0, time_bins / 1000.0, **baks_args)[0].reshape((-1,))
341342
if len(spkts) > 1
342-
else np.zeros(time_bins.shape)
343+
else np.zeros(time_bins.shape, dtype=np.float32)
343344
for ind, spkts in seq
344345
}
345346

@@ -360,13 +361,14 @@ def make_spktrain(lst, t_start, t_stop):
360361
output_file_path, population, attr_dict, namespace=namespace
361362
)
362363

363-
result = {
364-
ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items()
365-
}
366-
367-
result = {
368-
ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items()
369-
}
364+
if return_time_bins:
365+
result = {
366+
ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items()
367+
}
368+
else:
369+
result = {
370+
ind: {"rate": rate} for ind, rate in spk_rate_dict.items()
371+
}
370372

371373
return result
372374

src/miv_simulator/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ def baks(spktimes, time, a=1.5, b=None):
984984

985985
threshold = 6.0
986986
h = (gamma(a) / gamma(a + 0.5)) * (sumnum / sumdenom)
987-
rate = np.zeros((len(time),))
987+
rate = np.zeros((len(time),), dtype=np.float32)
988988
for j in range(n):
989989
time_diff = time - spktimes[j]
990990
abs_diff = np.abs(time_diff)

0 commit comments

Comments
 (0)