-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathloss.py
More file actions
101 lines (58 loc) · 2.51 KB
/
loss.py
File metadata and controls
101 lines (58 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class CAMLoss(nn.Module):
    """Loss over a two-channel activation map.

    Per sample: the mean of the element-wise minimum of the two channel
    maps, plus the mean of channel 0 when the label is 1, or the mean of
    channel 1 when the label is 0.
    """

    def __init__(self, size_average=True):
        super(CAMLoss, self).__init__()
        # True -> reduce per-sample losses by mean; False -> by sum.
        self.size_average = size_average
        # NOTE(review): unused below; retained for interface compatibility.
        self.epsilon = 1e-9

    def forward(self, input, label):
        n = input.size(0)
        # Flatten each channel map to (batch, H*W).
        ch0 = input[:, 0, :, :].view(n, -1)
        ch1 = input[:, 1, :, :].view(n, -1)
        label = label.view(-1)
        # Overlap term: per-sample mean of the element-wise channel minimum.
        per_sample = torch.min(ch0, ch1).mean(dim=1)
        # Label-selected channel term: channel 0 for label==1, channel 1 for label==0.
        per_sample = per_sample + label * ch0.mean(dim=1) + (1 - label) * ch1.mean(dim=1)
        return per_sample.mean() if self.size_average else per_sample.sum()
class exclusLoss(nn.Module):
    """Mutual-exclusion loss between two CAM branches.

    Combines (a) a cross-entropy term classifying each branch on the region
    NOT claimed by the other branch, and (b) for positive samples of branch 2,
    a penalty on the overlap of the two foreground maps (mean of the
    element-wise min) minus a small reward on their union (mean of the max).
    """

    def __init__(self, size_average=True):
        super(exclusLoss, self).__init__()
        # True -> reduce final loss by mean; False -> by sum.
        self.size_average = size_average
        # NOTE(review): unused below; retained for interface compatibility.
        self.epsilon = 1e-9
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, cam1, cam2, label1, label2):
        # Foreground (channel 1) map of each branch, kept 4-D for broadcasting.
        cam_1 = cam1[:, 1, :, :].unsqueeze(1)
        cam_2 = cam2[:, 1, :, :].unsqueeze(1)
        # Classify each branch on the area the other branch does NOT occupy;
        # the other branch's map is detached so gradients stay within a branch.
        cl2_ = F.adaptive_avg_pool2d(cam2 * (1 - cam_1.detach()), (1, 1)).view(-1, 2)
        cl1_ = F.adaptive_avg_pool2d(cam1 * (1 - cam_2.detach()), (1, 1)).view(-1, 2)
        loss1 = self.criterion(cl2_, label2) + 0.1 * self.criterion(cl1_, label1)

        # Overlap penalty, applied only where label2 == 1.
        cam_1 = cam_1.view(cam_1.size(0), -1)
        cam_2 = cam_2.view(cam_2.size(0), -1)
        label2 = label2.view(-1).float()
        # BUGFIX: the max term previously lacked .mean(dim=1), so a (B,) tensor
        # was subtracted from a (B, H*W) tensor (broadcast error for H*W != B),
        # and label2 was unsqueezed to (B, 1), which would broadcast the (B,)
        # result to (B, B). Both terms are now reduced per sample, mirroring
        # the min/mean construction used by CAMLoss and AlignLoss.
        loss2 = label2 * (torch.min(cam_1, cam_2).mean(dim=1)
                          - 0.1 * torch.max(cam_1, cam_2).mean(dim=1))
        loss = loss1 + 0.1 * loss2
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
class AlignLoss(torch.nn.Module):
    """Alignment loss between two foreground (channel-1) maps.

    For positive labels the loss is the negated element-wise minimum of the
    two maps, so larger common activation lowers the loss.
    """

    def __init__(self, reduction='elementwise_mean'):
        super().__init__()
        # 'elementwise_mean' -> mean, 'sum' -> sum, anything else -> no reduction.
        self.reduction = reduction

    def forward(self, cam, cam_co, label):
        # Flatten the channel-1 map of each input to (batch, H*W).
        fg = cam[:, 1, :, :].view(cam.size(0), -1)
        fg_co = cam_co[:, 1, :, :].view(cam_co.size(0), -1)
        # (batch, 1) label column broadcasts over the flattened maps.
        lbl = label.view(-1).unsqueeze(1)
        per_elem = -lbl * torch.min(fg, fg_co)
        if self.reduction == 'sum':
            return torch.sum(per_elem)
        if self.reduction == 'elementwise_mean':
            return torch.mean(per_elem)
        return per_elem