Skip to content

Commit 8004398

Browse files
authored
test: add unit tests for RNN and MultimodalRNN models (#936)
Add test_rnn.py with 12 test cases covering: TestRNN (8 tests): - Model initialization with correct attributes - Forward pass output structure and shapes - Backward pass gradient propagation - Embedding extraction via embed=True - Custom hyperparameters (embedding_dim, hidden_dim) - LSTM cell type variant - Vanilla RNN cell type variant - Bidirectional RNN variant TestMultimodalRNN (4 tests): - Initialization with correct sequential/non-sequential classification - Forward pass with mixed modalities (sequence + multi_hot + tensor) - Backward pass gradient propagation - Embedding extraction with correct mixed-modality dimensions Follows the established test pattern from test_mlp.py and test_tcn.py using create_sample_dataset with synthetic data. Ref #425
1 parent f7dd848 commit 8004398

1 file changed

Lines changed: 275 additions & 0 deletions

File tree

tests/core/test_rnn.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
import unittest
2+
3+
import torch
4+
5+
from pyhealth.datasets import create_sample_dataset, get_dataloader
6+
from pyhealth.models import RNN
7+
from pyhealth.models.rnn import MultimodalRNN
8+
9+
10+
class TestRNN(unittest.TestCase):
11+
"""Test cases for the RNN model."""
12+
13+
def setUp(self):
14+
"""Set up test data and model."""
15+
self.samples = [
16+
{
17+
"patient_id": "patient-0",
18+
"visit_id": "visit-0",
19+
"conditions": ["cond-33", "cond-86", "cond-80", "cond-12"],
20+
"procedures": ["proc-12", "proc-45", "proc-23"],
21+
"label": 0,
22+
},
23+
{
24+
"patient_id": "patient-1",
25+
"visit_id": "visit-1",
26+
"conditions": ["cond-33", "cond-86", "cond-80"],
27+
"procedures": ["proc-12"],
28+
"label": 1,
29+
},
30+
]
31+
32+
self.input_schema = {
33+
"conditions": "sequence",
34+
"procedures": "sequence",
35+
}
36+
self.output_schema = {"label": "binary"}
37+
38+
self.dataset = create_sample_dataset(
39+
samples=self.samples,
40+
input_schema=self.input_schema,
41+
output_schema=self.output_schema,
42+
dataset_name="test",
43+
)
44+
45+
self.model = RNN(dataset=self.dataset)
46+
47+
def test_model_initialization(self):
48+
"""Test that the RNN model initializes correctly."""
49+
self.assertIsInstance(self.model, RNN)
50+
self.assertEqual(self.model.embedding_dim, 128)
51+
self.assertEqual(self.model.hidden_dim, 128)
52+
self.assertEqual(len(self.model.feature_keys), 2)
53+
self.assertIn("conditions", self.model.feature_keys)
54+
self.assertIn("procedures", self.model.feature_keys)
55+
self.assertEqual(self.model.label_key, "label")
56+
57+
def test_model_forward(self):
58+
"""Test that the RNN model forward pass works correctly."""
59+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
60+
data_batch = next(iter(train_loader))
61+
62+
with torch.no_grad():
63+
ret = self.model(**data_batch)
64+
65+
self.assertIn("loss", ret)
66+
self.assertIn("y_prob", ret)
67+
self.assertIn("y_true", ret)
68+
self.assertIn("logit", ret)
69+
70+
self.assertEqual(ret["y_prob"].shape[0], 2)
71+
self.assertEqual(ret["y_true"].shape[0], 2)
72+
self.assertEqual(ret["logit"].shape[0], 2)
73+
self.assertEqual(ret["loss"].dim(), 0)
74+
75+
def test_model_backward(self):
76+
"""Test that the RNN model backward pass works correctly."""
77+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
78+
data_batch = next(iter(train_loader))
79+
80+
ret = self.model(**data_batch)
81+
ret["loss"].backward()
82+
83+
has_gradient = False
84+
for param in self.model.parameters():
85+
if param.requires_grad and param.grad is not None:
86+
has_gradient = True
87+
break
88+
self.assertTrue(
89+
has_gradient, "No parameters have gradients after backward pass"
90+
)
91+
92+
def test_model_with_embedding(self):
93+
"""Test that the RNN model returns embeddings when requested."""
94+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
95+
data_batch = next(iter(train_loader))
96+
data_batch["embed"] = True
97+
98+
with torch.no_grad():
99+
ret = self.model(**data_batch)
100+
101+
self.assertIn("embed", ret)
102+
self.assertEqual(ret["embed"].shape[0], 2)
103+
expected_embed_dim = len(self.model.feature_keys) * self.model.hidden_dim
104+
self.assertEqual(ret["embed"].shape[1], expected_embed_dim)
105+
106+
def test_custom_hyperparameters(self):
107+
"""Test RNN model with custom hyperparameters."""
108+
model = RNN(
109+
dataset=self.dataset,
110+
embedding_dim=64,
111+
hidden_dim=32,
112+
)
113+
114+
self.assertEqual(model.embedding_dim, 64)
115+
self.assertEqual(model.hidden_dim, 32)
116+
117+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
118+
data_batch = next(iter(train_loader))
119+
120+
with torch.no_grad():
121+
ret = model(**data_batch)
122+
123+
self.assertIn("loss", ret)
124+
self.assertIn("y_prob", ret)
125+
126+
def test_rnn_type_lstm(self):
127+
"""Test RNN model with LSTM cell type."""
128+
model = RNN(
129+
dataset=self.dataset,
130+
rnn_type="LSTM",
131+
)
132+
133+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
134+
data_batch = next(iter(train_loader))
135+
136+
with torch.no_grad():
137+
ret = model(**data_batch)
138+
139+
self.assertIn("loss", ret)
140+
self.assertEqual(ret["y_prob"].shape[0], 2)
141+
142+
def test_rnn_type_vanilla(self):
143+
"""Test RNN model with vanilla RNN cell type."""
144+
model = RNN(
145+
dataset=self.dataset,
146+
rnn_type="RNN",
147+
)
148+
149+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
150+
data_batch = next(iter(train_loader))
151+
152+
with torch.no_grad():
153+
ret = model(**data_batch)
154+
155+
self.assertIn("loss", ret)
156+
self.assertEqual(ret["y_prob"].shape[0], 2)
157+
158+
def test_bidirectional(self):
159+
"""Test RNN model with bidirectional layers."""
160+
model = RNN(
161+
dataset=self.dataset,
162+
bidirectional=True,
163+
)
164+
165+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
166+
data_batch = next(iter(train_loader))
167+
168+
with torch.no_grad():
169+
ret = model(**data_batch)
170+
171+
self.assertIn("loss", ret)
172+
self.assertEqual(ret["y_prob"].shape[0], 2)
173+
174+
175+
class TestMultimodalRNN(unittest.TestCase):
176+
"""Test cases for the MultimodalRNN model with mixed input modalities."""
177+
178+
def setUp(self):
179+
"""Set up test data with both sequential and non-sequential features."""
180+
self.samples = [
181+
{
182+
"patient_id": "patient-0",
183+
"visit_id": "visit-0",
184+
"conditions": ["cond-33", "cond-86", "cond-80"],
185+
"demographics": ["asian", "male"],
186+
"vitals": [120.0, 80.0, 98.6],
187+
"label": 1,
188+
},
189+
{
190+
"patient_id": "patient-1",
191+
"visit_id": "visit-1",
192+
"conditions": ["cond-12", "cond-52"],
193+
"demographics": ["white", "female"],
194+
"vitals": [110.0, 75.0, 98.2],
195+
"label": 0,
196+
},
197+
]
198+
199+
self.input_schema = {
200+
"conditions": "sequence",
201+
"demographics": "multi_hot",
202+
"vitals": "tensor",
203+
}
204+
self.output_schema = {"label": "binary"}
205+
206+
self.dataset = create_sample_dataset(
207+
samples=self.samples,
208+
input_schema=self.input_schema,
209+
output_schema=self.output_schema,
210+
dataset_name="test",
211+
)
212+
213+
self.model = MultimodalRNN(dataset=self.dataset)
214+
215+
def test_model_initialization(self):
216+
"""Test that the MultimodalRNN model initializes correctly."""
217+
self.assertIsInstance(self.model, MultimodalRNN)
218+
self.assertEqual(len(self.model.feature_keys), 3)
219+
self.assertIn("conditions", self.model.sequential_features)
220+
self.assertIn("demographics", self.model.non_sequential_features)
221+
self.assertIn("vitals", self.model.non_sequential_features)
222+
223+
def test_model_forward(self):
224+
"""Test that the MultimodalRNN forward pass works correctly."""
225+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
226+
data_batch = next(iter(train_loader))
227+
228+
with torch.no_grad():
229+
ret = self.model(**data_batch)
230+
231+
self.assertIn("loss", ret)
232+
self.assertIn("y_prob", ret)
233+
self.assertIn("y_true", ret)
234+
self.assertIn("logit", ret)
235+
236+
self.assertEqual(ret["y_prob"].shape[0], 2)
237+
self.assertEqual(ret["loss"].dim(), 0)
238+
239+
def test_model_backward(self):
240+
"""Test that the MultimodalRNN backward pass works correctly."""
241+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
242+
data_batch = next(iter(train_loader))
243+
244+
ret = self.model(**data_batch)
245+
ret["loss"].backward()
246+
247+
has_gradient = False
248+
for param in self.model.parameters():
249+
if param.requires_grad and param.grad is not None:
250+
has_gradient = True
251+
break
252+
self.assertTrue(
253+
has_gradient, "No parameters have gradients after backward pass"
254+
)
255+
256+
def test_model_with_embedding(self):
257+
"""Test that the MultimodalRNN returns embeddings when requested."""
258+
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
259+
data_batch = next(iter(train_loader))
260+
data_batch["embed"] = True
261+
262+
with torch.no_grad():
263+
ret = self.model(**data_batch)
264+
265+
self.assertIn("embed", ret)
266+
self.assertEqual(ret["embed"].shape[0], 2)
267+
expected_embed_dim = (
268+
len(self.model.sequential_features) * self.model.hidden_dim
269+
+ len(self.model.non_sequential_features) * self.model.embedding_dim
270+
)
271+
self.assertEqual(ret["embed"].shape[1], expected_embed_dim)
272+
273+
274+
if __name__ == "__main__":
275+
unittest.main()

0 commit comments

Comments
 (0)