Skip to content

Commit 178034e

Browse files
authored
Merge pull request #97 from iraikov/feature/multiple_input_spike_namespaces
Support for multiple input spike namespaces
2 parents 2fa4ad9 + d6442b2 commit 178034e

8 files changed

Lines changed: 57 additions & 42 deletions

File tree

src/miv_simulator/env.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Union
1+
from typing import Dict, Optional, Union, List, Any, Optional
22

33
import logging
44
import os
@@ -84,9 +84,9 @@ def __init__(
8484
lptbal: bool = False,
8585
cell_selection_path: None = None,
8686
microcircuit_inputs: bool = False,
87-
spike_input_path: None = None,
88-
spike_input_namespace: None = None,
89-
spike_input_attr: None = None,
87+
spike_input_path: Optional[str] = None,
88+
spike_input_namespaces: List[str] = [],
89+
spike_input_attr: Optional[str] = None,
9090
coordinates_namespace: str = "Coordinates",
9191
cache_queries: bool = False,
9292
profile_memory: bool = False,
@@ -312,7 +312,7 @@ def __init__(
312312

313313
# Spike input path
314314
self.spike_input_path = spike_input_path
315-
self.spike_input_ns = spike_input_namespace
315+
self.spike_input_namespaces = spike_input_namespaces
316316
self.spike_input_attr = spike_input_attr
317317
self.spike_input_attribute_info = None
318318
if self.spike_input_path is not None:

src/miv_simulator/input_spike_trains.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,6 @@ def generate_input_spike_trains(
9393
logger.info(f"{comm.size} ranks have been allocated")
9494

9595
population_name = population.name
96-
start_gid = 0
97-
if hasattr(population, "start_gid"):
98-
start_gid = population.start_gid
9996

10097
soma_positions_dict = None
10198
if coords_path is not None:
@@ -189,8 +186,6 @@ def generate_input_spike_trains(
189186
feature_items = list(population.features.items())
190187
n_iter = comm.allreduce(len(feature_items), op=MPI.MAX)
191188

192-
logger.info(f"n_iter = {n_iter} feature_items = {feature_items}")
193-
194189
if not dry_run and rank == 0:
195190
if output_path is None:
196191
raise RuntimeError("generate_input_spike_trains: missing output_path")
@@ -203,7 +198,6 @@ def generate_input_spike_trains(
203198
for iter_count in range(n_iter):
204199
if iter_count < len(feature_items):
205200
gid, input_feature = feature_items[iter_count]
206-
gid += start_gid
207201
else:
208202
gid, input_feature = None, None
209203
if gid is not None:
@@ -225,7 +219,19 @@ def generate_input_spike_trains(
225219
# Get spike response
226220
response = input_feature.get_response(processed_signal)
227221
if isinstance(response, list):
228-
response = np.concatenate(np.concatenate(response, dtype=np.float32))
222+
response_length = 0
223+
for x in response:
224+
response_length += len(x)
225+
if response_length > 0:
226+
try:
227+
response = np.concatenate(np.concatenate(response, dtype=np.float32))
228+
except Exception as e:
229+
logger.error(f"error concatenating response: {response}")
230+
raise e
231+
else:
232+
response = np.asarray([], dtype=np.float32)
233+
else:
234+
response = response.reshape((-1,)).astype(np.float32)
229235

230236
if len(response) > 0:
231237
spikes_attr_dict[gid] = {output_spike_train_attr_name: response}

src/miv_simulator/interface/legacy/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __call__(self):
106106
cell_selection_path=None,
107107
microcircuit_inputs=False,
108108
spike_input_path=self.config.spike_input_path,
109-
spike_input_namespace=self.config.spike_input_namespace,
109+
spike_input_namespaces=[self.config.spike_input_namespace],
110110
spike_input_attr=self.config.spike_input_attr,
111111
cleanup=True,
112112
cache_queries=False,

src/miv_simulator/network.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,10 +1203,10 @@ def make_input_cell_selection(env):
12031203

12041204
has_spike_train = False
12051205
if (env.spike_input_attribute_info is not None) and (
1206-
env.spike_input_ns is not None
1206+
len(env.spike_input_namespaces) > 0
12071207
):
12081208
if (pop_name in env.spike_input_attribute_info) and (
1209-
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
1209+
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
12101210
):
12111211
has_spike_train = True
12121212

@@ -1309,18 +1309,19 @@ def init_input_cells(env: Env) -> None:
13091309
has_vecstim = False
13101310
vecstim_source_loc = []
13111311
if (env.spike_input_attribute_info is not None) and (
1312-
env.spike_input_ns is not None
1312+
len(env.spike_input_namespaces) > 0
13131313
):
13141314
if (pop_name in env.spike_input_attribute_info) and (
1315-
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
1315+
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
13161316
):
13171317
has_vecstim = True
1318-
vecstim_source_loc.append(
1319-
(
1320-
env.spike_input_path,
1321-
env.spike_input_ns,
1322-
env.spike_input_attr,
1323-
)
1318+
for ns in env.spike_input_namespaces:
1319+
vecstim_source_loc.append(
1320+
(
1321+
env.spike_input_path,
1322+
ns,
1323+
env.spike_input_attr,
1324+
)
13241325
)
13251326
if (env.cell_attribute_info is not None) and (
13261327
vecstim_namespace is not None
@@ -1454,23 +1455,25 @@ def init_input_cells(env: Env) -> None:
14541455
has_spike_train = False
14551456
spike_input_source_loc = []
14561457
if (env.spike_input_attribute_info is not None) and (
1457-
env.spike_input_ns is not None
1458+
len(env.spike_input_namespaces) > 0
14581459
):
14591460
if (pop_name in env.spike_input_attribute_info) and (
1460-
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
1461+
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
14611462
):
14621463
has_spike_train = True
1463-
spike_input_source_loc.append(
1464-
(env.spike_input_path, env.spike_input_ns)
1464+
for ns in env.spike_input_namespaces:
1465+
spike_input_source_loc.append(
1466+
(env.spike_input_path, ns)
14651467
)
14661468
if (env.cell_attribute_info is not None) and (
1467-
env.spike_input_ns is not None
1469+
len(env.spike_input_namespaces) > 0
14681470
):
14691471
if (pop_name in env.cell_attribute_info) and (
1470-
env.spike_input_ns in env.cell_attribute_info[pop_name]
1472+
set(env.spike_input_namespaces).intersection(set(env.cell_attribute_info[pop_name].keys()))
14711473
):
14721474
has_spike_train = True
1473-
spike_input_source_loc.append((input_file_path, env.spike_input_ns))
1475+
for ns in env.spike_input_namespaces:
1476+
spike_input_source_loc.append((input_file_path, ns))
14741477

14751478
if rank == 0:
14761479
logger.info(
@@ -1525,7 +1528,7 @@ def init_input_cells(env: Env) -> None:
15251528
elif len(this_gid_range) > 0:
15261529
raise RuntimeError(
15271530
f"init_input_cells: unable to determine spike train attribute for population {pop_name} in spike input file {env.spike_input_path};"
1528-
f" namespace {env.spike_input_ns}; attr keys {list(cell_spikes_attr_info.keys())}"
1531+
f" namespaces {env.spike_input_namespaces}; attr keys {list(cell_spikes_attr_info.keys())}"
15291532
)
15301533
for gid, cell_spikes_tuple in cell_spikes_iter:
15311534
if not (env.pc.gid_exists(gid)):
@@ -1566,7 +1569,7 @@ def init_input_cells(env: Env) -> None:
15661569
if rank == 0:
15671570
logger.warning(
15681571
f"No spike train data found for population {pop_name} in spike input file {env.spike_input_path}; "
1569-
f"namespace: {env.spike_input_ns}"
1572+
f"namespaces: {env.spike_input_namespaces}"
15701573
)
15711574

15721575
gc.collect()

src/miv_simulator/scripts/run_network.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def mpi_excepthook(type, value, traceback):
187187
@click.option(
188188
"--spike-input-namespace",
189189
required=False,
190+
multiple=True,
190191
type=str,
191192
help="namespace for input spikes when cell selection is specified",
192193
)
@@ -290,6 +291,7 @@ def main(
290291
np.seterr(all="raise")
291292
params = dict(locals())
292293
params["config"] = params.pop("config_file")
294+
params["spike_input_namespaces"] = params.get("spike_input_namespace", [])
293295
env = Env(**params)
294296

295297
compile_and_load(directory=env.mechanisms_path)

src/miv_simulator/scripts/tools/cut_slice.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def mpi_excepthook(type, value, traceback):
5252
@click.option(
5353
"--spike-input-namespace",
5454
required=False,
55+
multiple=True,
5556
type=str,
5657
help="namespace for input spikes when cell selection is specified",
5758
)
@@ -103,7 +104,7 @@ def main(
103104
dataset_prefix=dataset_prefix,
104105
results_path=output_path,
105106
spike_input_path=spike_input_path,
106-
spike_input_namespace=spike_input_namespace,
107+
spike_input_namespaces=spike_input_namespace,
107108
spike_input_attr=spike_input_attr,
108109
coordinates_namespace=coordinates_namespace,
109110
io_size=io_size,

src/miv_simulator/scripts/tools/sample_cells.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def mpi_excepthook(type, value, traceback):
5959
@click.option(
6060
"--spike-input-namespace",
6161
required=False,
62+
multiple=True,
6263
type=str,
6364
help="namespace for input spikes",
6465
)
@@ -125,7 +126,7 @@ def main(
125126
dataset_prefix=dataset_prefix,
126127
results_path=output_path,
127128
spike_input_path=spike_input_path,
128-
spike_input_namespace=spike_input_namespace,
129+
spike_input_namespaces=spike_input_namespace,
129130
spike_input_attr=spike_input_attr,
130131
arena_id=arena_id,
131132
stimulus_id=stimulus_id,

src/miv_simulator/utils/io.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,21 +1025,23 @@ def write_input_cell_selection(
10251025
has_spike_train = False
10261026
spike_input_source_loc = []
10271027
if (env.spike_input_attribute_info is not None) and (
1028-
env.spike_input_ns is not None
1028+
len(env.spike_input_namespaces) > 0
10291029
):
10301030
if (pop_name in env.spike_input_attribute_info) and (
1031-
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
1031+
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
10321032
):
10331033
has_spike_train = True
1034-
spike_input_source_loc.append(
1035-
(env.spike_input_path, env.spike_input_ns)
1034+
for ns in env.spike_input_namespaces:
1035+
spike_input_source_loc.append(
1036+
(env.spike_input_path, ns)
10361037
)
1037-
if (env.cell_attribute_info is not None) and (env.spike_input_ns is not None):
1038+
if (env.cell_attribute_info is not None) and (len(env.spike_input_namespaces) > 0):
10381039
if (pop_name in env.cell_attribute_info) and (
1039-
env.spike_input_ns in env.cell_attribute_info[pop_name]
1040+
set(env.spike_input_namespaces).intersection(set(env.cell_attribute_info[pop_name].keys()))
10401041
):
10411042
has_spike_train = True
1042-
spike_input_source_loc.append((input_file_path, env.spike_input_ns))
1043+
for ns in env.spike_input_namespaces:
1044+
spike_input_source_loc.append((input_file_path, ns))
10431045

10441046
if rank == 0:
10451047
logger.info(
@@ -1083,7 +1085,7 @@ def write_input_cell_selection(
10831085
write_selection_file_path,
10841086
pop_name,
10851087
spikes_output_dict,
1086-
namespace=env.spike_input_ns,
1088+
namespace=env.spike_input_namespaces[0],
10871089
**write_kwds,
10881090
)
10891091

0 commit comments

Comments
 (0)