-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
81 lines (62 loc) · 2.28 KB
/
main.py
File metadata and controls
81 lines (62 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import cv2
import torch
import torch.utils.data
from torch.utils.data.dataset import Dataset
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from src.networkUtils import *
import src.Dataset as ds
import src.network
import argparse
def getArgs():
    """Parse the command-line options for the training script.

    Options:
        -e  Number of epochs to train for (integer, required).
        -b  Boolean switch; ``False`` unless given on the command line.
            ``main`` forwards it unchanged as the last argument of
            ``train``.

    Returns:
        argparse.Namespace: object with attributes ``e`` (int) and
        ``b`` (bool).

    Raises:
        SystemExit: raised by argparse when ``-e`` is missing or is not an
            integer (a usage message is printed to stderr).
    """
    parser = argparse.ArgumentParser()
    # Without -e the parsed value would be None and the epoch loop's
    # range(None) would fail with a TypeError only after the data set and
    # the model have been set up, so reject the omission up front.
    parser.add_argument('-e', type=int, required=True,
                        help='number of training epochs')
    parser.add_argument('-b', action='store_true',
                        help='switch passed through to train()')
    return parser.parse_args()
def main():
    """Train the ``eccv16`` model from ``src.network`` and save checkpoints.

    Steps, in order:
      1. Parse the command-line arguments with ``getArgs``.
      2. Pick ``cuda:0`` when CUDA is available, otherwise the CPU.
      3. Load ``ds.ImageDataset`` from ``dummyData/ColorfulOriginal`` with a
         ``ToTensor`` transform and split it randomly 90% / 10% into a
         training and a test subset.
      4. For each epoch call ``train`` and then ``validate`` (the latter
         with gradients disabled). Whenever the value returned by
         ``validate`` is lower than the best seen so far, write the model's
         ``state_dict`` to the ``checkpoints`` directory.
    """
    # Imported locally: the checkpoint code below needs pathlib, and the
    # file's explicit import list does not bring that name into scope.
    import pathlib

    # Get the command-line arguments
    args = getArgs()

    # Set the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the device is printed twice in this message; the first
    # placeholder may have been meant for something else. Output kept as is.
    print(f"{device}; device: {device}")

    # Load data
    trainingTransforms = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = ds.ImageDataset("dummyData/ColorfulOriginal", transform=trainingTransforms)

    # Calculate 90%-10% for a split between train-test set. The test length
    # is derived by subtraction so the two lengths always sum to len(dataset).
    dsLen = len(dataset)
    trainLen = int(dsLen * 0.9)
    testLen = dsLen - trainLen
    trainSet, testSet = torch.utils.data.random_split(dataset, [trainLen, testLen])

    # Create DataLoader objects from the data sets
    trainLoader = torch.utils.data.DataLoader(trainSet, batch_size=10, shuffle=True)
    testLoader = torch.utils.data.DataLoader(testSet, batch_size=10, shuffle=True)

    model = src.network.eccv16()
    if device.type != 'cpu':
        print("Pushing model to CUDA")
        # Move to the exact device selected above, so the model and the
        # device handed to train()/validate() cannot disagree.
        model.to(device)

    # TODO Implement pretrained weights from original author and my own model
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)

    save_images = True
    best_losses = 1e10
    epochs = args.e
    print(f"Epochs {epochs}")

    # Train model
    for epoch in range(epochs):
        # Train for one epoch, then validate
        train(trainLoader, model, criterion, optimizer, epoch, device, args.b)
        with torch.no_grad():
            losses = validate(testLoader, model, criterion, save_images, epoch, device)

        # Save checkpoint and replace old best model if current model is better
        if losses < best_losses:
            best_losses = losses
            pathlib.Path('checkpoints').mkdir(parents=True, exist_ok=True)
            # File names use the 1-based epoch number.
            torch.save(
                model.state_dict(),
                f'checkpoints/model-epoch-{epoch + 1}-losses-{losses:.3f}.pth',
            )
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()