-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_vrnn_mech3d.py
More file actions
97 lines (83 loc) · 3.8 KB
/
train_vrnn_mech3d.py
File metadata and controls
97 lines (83 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# %%
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchinfo import summary
import matplotlib.pyplot as plt
from vrnn.normalization import NormalizedDataset, NormalizationModule, SpectralNormalization
from vrnn.data_mechanical import Dataset3DMechanical
from vrnn.models import VanillaModule, MixedActivationMLP
from vrnn import utils
import numpy as np
from datetime import datetime
from vrnn.losses import VoigtReussNormalizedLoss
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Features, targets and images all share the same double-precision dtype.
dtypes = dict.fromkeys(('features', 'targets', 'images'), torch.float64)
# %%
# Locate the HDF5 feature file and the CSV metadata inside the data directory
data_dir = utils.get_data_dir()
h5_file = data_dir / 'feature_engineering_mechanical_3D.h5'
csv_file = data_dir / 'metadata_mechanical_3D.csv'

# Arguments that are identical for the training and the validation split
_shared_dataset_kwargs = dict(
    csv_file_path=csv_file,
    h5_file_path=h5_file,
    random_seed=42,
    input_mode='descriptors',
    feature_idx=None,
    feature_key="feature_vector",
    device=device,
    dtypes=dtypes,
)

# Training split (maximum number of samples: 751089)
train_data = Dataset3DMechanical(
    group="train_set",
    num_samples=751089,
    **_shared_dataset_kwargs,
)

# Validation split (maximum number of samples: 263188)
val_data = Dataset3DMechanical(
    group="val_set",
    num_samples=263188,
    **_shared_dataset_kwargs,
)
# Wrap both (un-normalized) splits in shuffling dataloaders of equal batch size
batch_size = 75000
train_loader, val_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_data, val_data)
)

# Network input/output widths are read off the last axis of the training tensors
in_dim = train_data.features.shape[-1]
out_dim = train_data.targets.shape[-1]
# %%
# Define normalization
# Per-feature extrema over the union of the training and validation features.
# The concatenated tensor is built once and reused for both reductions
# (previously it was materialized twice), then released right away so the
# temporary copy does not stay alive for the rest of the script.
# NOTE(review): validation features contribute to the normalization range here;
# confirm that this is intended.
all_features = torch.cat([train_data.features, val_data.features], dim=0)
features_max = all_features.max(dim=0)[0]
features_min = all_features.min(dim=0)[0]
del all_features

# Dont normalize the first feature (volume fraction): force its range to [0, 1]
features_min[0] = 0.0
features_max[0] = 1.0

normalization = SpectralNormalization(
    dim=6,
    features_min=features_min,
    features_max=features_max,
    bounds_fn=train_data.calc_bounds,
)

# Normalize data
train_data_norm = NormalizedDataset(train_data, normalization)
val_data_norm = NormalizedDataset(val_data, normalization)

# NOTE(review): the normalized training loader iterates in a fixed order
# (shuffle=False) while the raw loaders above shuffle; left unchanged here,
# confirm that this is deliberate.
train_loader_norm = DataLoader(train_data_norm, batch_size=batch_size, shuffle=False)
val_loader_norm = DataLoader(val_data_norm, batch_size=batch_size, shuffle=False)
# %%
# Architecture settings handed to MixedActivationMLP
hidden_layer_sizes = [1024, 512, 256, 128, 128]
mixed_activations = [nn.SELU(), nn.Tanh(), nn.Sigmoid(), nn.Identity()]
ann_model = MixedActivationMLP(
    input_dim=in_dim,
    hidden_dims=hidden_layer_sizes,
    output_dim=out_dim,
    activation_fns=mixed_activations,
    output_activation=nn.Sigmoid(),
    use_batch_norm=True,
)

# Wrap the network, then move it to the target device in the feature dtype
model_norm = VanillaModule(ann_model).to(device=device, dtype=dtypes['features'])

# Print a layer-by-layer summary for one full batch of inputs
model_summary = summary(
    model_norm,
    input_size=(batch_size, in_dim),
    dtypes=[dtypes['features']],
    device=device,
)
print(model_summary)

# Loss, optimizer and a scheduler that halves the learning rate on plateaus
loss_fn = VoigtReussNormalizedLoss(dim=6)
optimizer = torch.optim.AdamW(model_norm.parameters(), lr=1e-1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=50,
    min_lr=5e-5,
)
# %%
# Train on the normalized loaders and keep the loss history
epochs = 1000
training_history = utils.model_training(
    model_norm,
    loss_fn,
    optimizer,
    train_loader_norm,
    val_loader_norm,
    epochs,
    verbose=True,
    scheduler=scheduler,
)
train_losses, val_losses, best_epoch = training_history

# Plot the training/validation loss curves on a fresh figure
fig, ax = plt.subplots()
utils.plot_training_history(ax, train_losses, val_losses, best_epoch)
# %%
# Combine the trained normalized-space network with the normalization object
# into a single module and move it to the target device.
model = NormalizationModule(normalized_module=model_norm, normalization=normalization).to(device=device)

# Timestamp shared by the checkpoint and the training-history figure
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")

# Make sure the output directory exists before writing; otherwise a missing
# folder would raise only after the whole training run has finished.
# Assumes data_dir is a pathlib.Path (it is combined with `/` above) — TODO confirm.
model_dir = data_dir / 'Mechanical3D_models'
model_dir.mkdir(parents=True, exist_ok=True)

# NOTE: torch.save(model, ...) pickles the whole module object, so loading it
# later requires the same vrnn classes to be importable.
torch.save(model, model_dir / f'vrnn_mech3D_{current_time}.pt')
fig.savefig(model_dir / f'vrnn_mech3D_training_history_{current_time}.png', dpi=300)