Skip to content

Commit d7c9ba1

Browse files
committed
2 parents ed60620 + d5496e8 commit d7c9ba1

9 files changed

Lines changed: 2071 additions & 1282 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v5.0.0
3+
rev: v6.0.0
44
hooks:
55
- id: trailing-whitespace
66
- id: check-json
77
- id: end-of-file-fixer
88
- repo: https://github.com/astral-sh/ruff-pre-commit
9-
rev: v0.11.7
9+
rev: v0.12.9
1010
hooks:
1111
- id: ruff
1212
args: [--fix]

src/miv_simulator/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
config: Optional[str] = None,
5959
template_paths: str = "templates",
6060
hoc_lib_path: Optional[str] = None,
61-
mechanisms_path: Optional[str] = None,
61+
mechanisms_path: Optional[str] = "mechanisms",
6262
dataset_prefix: Optional[str] = None,
6363
results_path: Optional[str] = None,
6464
results_file_id: Optional[str] = None,

src/miv_simulator/interface/synapses.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def __call__(self):
7070
def on_write_meta_data(self):
7171
return MPI.COMM_WORLD.Get_rank() == 0
7272

73+
def on_after_dispatch(self, success):
74+
# ensure shutdown
75+
MPI.COMM_WORLD.Abort(int(not success))
76+
7377
def compute_context(self):
7478
context = super().compute_context()
7579
del context["config"]["forest_filepath"]

src/miv_simulator/optimization.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -377,27 +377,21 @@ def network_features(env, t_start, t_stop, target_populations):
377377

378378
for pop_name in target_populations:
379379
n_active = 0
380-
sum_mean_rate = 0.0
381380
spike_density_dict = spikedata.spike_density_estimate(
382381
pop_name, pop_spike_dict[pop_name], time_bins
383382
)
384383
for gid, dens_dict in spike_density_dict.items():
385384
mean_rate = np.mean(dens_dict["rate"])
386-
sum_mean_rate += mean_rate
387385
if mean_rate > 0.0:
388386
n_active += 1
389387

390388
n_total = len(env.cells[pop_name]) - len(env.artificial_cells[pop_name])
391389

392-
n_target_rate_map = 0
393-
sum_snr = None
394-
395390
pop_features_dict = {}
396391
pop_features_dict["n_total"] = n_total
397392
pop_features_dict["n_active"] = n_active
398-
pop_features_dict["n_target_rate_map"] = n_target_rate_map
399-
pop_features_dict["sum_mean_rate"] = sum_mean_rate
400-
pop_features_dict["sum_snr"] = sum_snr
393+
pop_features_dict["time_bins"] = time_bins
394+
pop_features_dict["spike_density_dict"] = spike_density_dict
401395

402396
features_dict[pop_name] = pop_features_dict
403397

src/miv_simulator/optimize_network.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -374,57 +374,80 @@ def compute_objectives(local_features, operational_config, opt_targets):
374374
all_features_dict = {}
375375
constraints = []
376376

377+
active_threshold = 0.01
377378
target_populations = operational_config["target_populations"]
379+
temporal_resolution = operational_config["temporal_resolution"]
378380
for pop_name in target_populations:
379381
pop_features_dicts = [
380382
features_dict[0][pop_name] for features_dict in local_features
381383
]
382384

383385
sum_mean_rate = 0.0
384-
sum_snr = 0.0
385386
n_total = 0
386387
n_active = 0
387-
n_target_rate_map = 0
388+
time_bins_ref = None
389+
sum_active_per_bin = None
388390
for pop_feature_dict in pop_features_dicts:
389391
n_active_local = pop_feature_dict["n_active"]
390392
n_total_local = pop_feature_dict["n_total"]
391-
n_target_rate_map_local = pop_feature_dict["n_target_rate_map"]
392-
sum_mean_rate_local = pop_feature_dict["sum_mean_rate"]
393-
sum_snr_local = pop_feature_dict["sum_snr"]
393+
time_bins = pop_feature_dict["time_bins"]
394+
if time_bins_ref is None:
395+
time_bins_ref = time_bins
396+
spike_density_dict = pop_feature_dict["spike_density_dict"]
397+
sum_mean_rate_local = 0.0
398+
t_start = time_bins_ref[0]
399+
t_end = time_bins_ref[-1] + (time_bins_ref[1] - time_bins_ref[0])
400+
# time bins for fraction active per time bin calculation
401+
fr_time_bins = np.arange(t_start, t_end, temporal_resolution)
402+
bin_width = time_bins_ref[1] - time_bins_ref[0]
403+
time_centers = time_bins_ref + bin_width / 2
404+
fr_time_centers = fr_time_bins + temporal_resolution / 2
405+
if sum_active_per_bin is None:
406+
sum_active_per_bin = np.zeros_like(time_centers)
407+
for gid, dens_dict in spike_density_dict.items():
408+
mean_rate = np.mean(dens_dict["rate"])
409+
if mean_rate > 0.0:
410+
sum_mean_rate_local += mean_rate
411+
ip_rate = np.interp1d(
412+
fr_time_centers,
413+
dens_dict["rate"],
414+
kind="linear",
415+
bounds_error=False,
416+
fill_value=0.0,
417+
)
418+
active_per_bin = ip_rate > active_threshold
419+
sum_active_per_bin += active_per_bin
394420

395421
n_total += n_total_local
396422
n_active += n_active_local
397-
n_target_rate_map += n_target_rate_map_local
398423
sum_mean_rate += sum_mean_rate_local
399424

400-
if sum_snr_local is not None:
401-
sum_snr += sum_snr_local
402-
403425
if n_active > 0:
404426
mean_rate = sum_mean_rate / n_active
405427
else:
406428
mean_rate = 0.0
407429

408430
if n_total > 0:
409431
fraction_active = n_active / n_total
432+
mean_fraction_active_per_bin = np.mean(sum_active_per_bin / float(n_total))
433+
std_fraction_active_per_bin = np.std(sum_active_per_bin / float(n_total))
410434
else:
411435
fraction_active = 0.0
412-
413-
mean_snr = None
414-
if n_target_rate_map > 0:
415-
mean_snr = sum_snr / n_target_rate_map
436+
mean_fraction_active_per_bin = 0.0
437+
std_fraction_active_per_bin = 0.0
416438

417439
logger.info(
418440
f"population {pop_name}: n_active = {n_active} n_total = {n_total} mean rate = {mean_rate}"
419441
)
420-
logger.info(
421-
f"population {pop_name}: n_target_rate_map = {n_target_rate_map} snr: sum = {sum_snr} mean = {mean_snr}"
422-
)
423442

443+
all_features_dict[f"{pop_name} mean fraction active per time bin"] = (
444+
mean_fraction_active_per_bin
445+
)
446+
all_features_dict[f"{pop_name} std fraction active per time bin"] = (
447+
std_fraction_active_per_bin
448+
)
424449
all_features_dict[f"{pop_name} fraction active"] = fraction_active
425450
all_features_dict[f"{pop_name} firing rate"] = mean_rate
426-
if mean_snr is not None:
427-
all_features_dict[f"{pop_name} snr"] = mean_snr
428451

429452
rate_constr = mean_rate if mean_rate > 0.0 else -1.0
430453
constraints.append(rate_constr)

src/miv_simulator/scripts/distribute_synapse_locs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
import click
5+
from mpi4py import MPI
56
from miv_simulator.simulator.distribute_synapses import (
67
distribute_synapse_locations,
78
)
@@ -75,3 +76,5 @@ def main(
7576
config_prefix=config_prefix,
7677
mechanisms_path=mechanisms_path,
7778
)
79+
80+
MPI.COMM_WORLD.Abort(0)

src/miv_simulator/simulator/distribute_synapses.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def distribute_synapses(
264264
dry_run: bool,
265265
):
266266
logger = utils.get_script_logger(os.path.basename(__file__))
267-
267+
write_size = 1
268268
comm = MPI.COMM_WORLD
269269
rank = comm.rank
270270

@@ -413,7 +413,15 @@ def distribute_synapses(
413413
else:
414414
logger.info(f"Rank {rank} gid is None")
415415
gc.collect()
416-
if (not dry_run) and (write_size > 0) and (gid_count % write_size == 0):
416+
417+
# Check if any rank has reached the write_size threshold and
418+
# ensure that all ranks in that group call collectively
419+
local_should_write = (
420+
(not dry_run) and (write_size > 0) and (gid_count % write_size == 0)
421+
)
422+
global_should_write = comm.allreduce(local_should_write, op=MPI.LOR)
423+
424+
if global_should_write:
417425
append_cell_attributes(
418426
output_filepath,
419427
population,
@@ -428,7 +436,13 @@ def distribute_synapses(
428436
syn_stats[population] = syn_stats_dict
429437
count += 1
430438

431-
if not dry_run:
439+
# Final write for any remaining synapse data - allreduce to ensure all ranks participate
440+
local_should_write_final = not dry_run
441+
global_should_write_final = comm.allreduce(
442+
local_should_write_final, op=MPI.LAND
443+
)
444+
445+
if global_should_write_final:
432446
append_cell_attributes(
433447
output_filepath,
434448
population,
@@ -447,7 +461,3 @@ def distribute_synapses(
447461
f"to compute synapse locations for {np.sum(global_count)} cells"
448462
)
449463
logger.info(summary)
450-
451-
comm.barrier()
452-
453-
MPI.Finalize()

src/miv_simulator/synapses.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,10 +2425,11 @@ def modify_mechanism_parameters(
24252425

24262426
assert new_val is not None
24272427

2428-
value_result, source = (
2429-
self.syn_store.param_store.get_parameter_value_hierarchy(
2430-
gid, array_index, mech_name, k, syn_id
2431-
)
2428+
(
2429+
value_result,
2430+
source,
2431+
) = self.syn_store.param_store.get_parameter_value_hierarchy(
2432+
gid, array_index, mech_name, k, syn_id
24322433
)
24332434
old_val = value_result if value_result is not None else mech_param
24342435

@@ -3133,13 +3134,12 @@ def insert_cell_syns(
31333134
list(sec_list),
31343135
)
31353136
)
3136-
if len(reduced_section_dict) == 0:
3137-
if hasattr(cell, "soma"):
3138-
cell_soma = cell.soma
3139-
if isinstance(cell_soma, list):
3140-
cell_soma = cell_soma[0]
3141-
if hasattr(cell, "dend"):
3142-
cell_dendrite = cell.dend
3137+
if hasattr(cell, "soma"):
3138+
cell_soma = cell.soma
3139+
if isinstance(cell_soma, list):
3140+
cell_soma = cell_soma[0]
3141+
if hasattr(cell, "dend"):
3142+
cell_dendrite = cell.dend
31433143

31443144
syn_manager = env.synapse_manager
31453145

0 commit comments

Comments
 (0)