-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathdemo_metauas.py
More file actions
90 lines (70 loc) · 2.22 KB
/
demo_metauas.py
File metadata and controls
90 lines (70 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : demo_metauas.py
@Time : 2025/03/26 23:49:14
@Author : Bin-Bin Gao
@Email : csgaobb@gmail.com
@Homepage: https://csgaobb.github.io/
@Version : 1.0
@Desc : MetaUAS Demo
'''
import os
import cv2
import torch
import json
import shutil
import kornia as K
import numpy as np
from easydict import EasyDict
from argparse import ArgumentParser
from metauas import MetaUAS, set_random_seed, normalize, apply_ad_scoremap, read_image_as_tensor, safely_load_state_dict
if __name__ == "__main__":
    # MetaUAS demo: load a pretrained checkpoint, run it on one query image
    # together with one prompt image, and save the resulting score map
    # rendered on top of the query image.

    # Seed the RNGs so repeated runs of the demo are comparable.
    random_seed = 1
    set_random_seed(random_seed)

    # Checkpoint and the input resolution it is used with.
    ckt_path = 'weights/metauas-256.ckpt'
    img_size = 256
    # Alternative checkpoint / resolution pair:
    # ckt_path = "weights/metauas-512.ckpt"
    # img_size = 512

    # Prefer the GPU (same as the former hard-coded .cuda() calls) but fall
    # back to the CPU so the demo does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load model
    encoder = 'efficientnet-b4'
    decoder = 'unet'
    encoder_depth = 5
    decoder_depth = 5
    num_crossfa_layers = 3
    alignment_type = 'sa'
    fusion_policy = 'cat'

    model = MetaUAS(
        encoder,
        decoder,
        encoder_depth,
        decoder_depth,
        num_crossfa_layers,
        alignment_type,
        fusion_policy,
    )
    model = safely_load_state_dict(model, ckt_path)
    model.to(device)
    model.eval()

    # load test images
    path_root = "./images/"
    path_to_prompt = os.path.join(path_root, "036.png")
    path_to_query = os.path.join(path_root, "024.png")

    query = read_image_as_tensor(path_to_query)
    prompt = read_image_as_tensor(path_to_prompt)

    # Both images are resized whenever the query's shape[1] differs from
    # img_size.
    # NOTE(review): only the query is inspected, and only along one axis; the
    # prompt is never checked on its own. Confirm the tensor layout returned
    # by read_image_as_tensor before relying on this guard.
    # NOTE(review): `return_transform=True` makes the call return more than
    # the image (hence the [0]); newer kornia releases appear to have dropped
    # this argument -- confirm against the pinned kornia version.
    if query.shape[1] != img_size:
        resize_trans = K.augmentation.Resize([img_size, img_size], return_transform=True)
        query = resize_trans(query)[0]
        prompt = resize_trans(prompt)[0]

    test_data = {
        "query_image": query.to(device),
        "prompt_image": prompt.to(device),
    }

    # forward
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        predicted_masks = model(test_data)

    # visualization
    # Scale the first query image by 255 and move the leading axis last.
    query_img = test_data["query_image"][0] * 255
    query_img = query_img.permute(1, 2, 0)

    # Invert the prediction (1 - mask), add a trailing axis and repeat it
    # three times so the map has three identical channels.
    pred = (1 - predicted_masks.squeeze().detach())[:, :, None].cpu().numpy().repeat(3, 2)

    # normalize just for analysis
    scoremap_self = apply_ad_scoremap(query_img.cpu(), normalize(pred))

    # cv2.imwrite reports failure through its return value instead of
    # raising, so check it rather than silently producing no output file.
    output_path = './anomaly_map.jpg'
    if not cv2.imwrite(output_path, scoremap_self):
        raise OSError(f"Failed to write anomaly map to {output_path}")