-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinfer.py
More file actions
107 lines (88 loc) · 3.29 KB
/
infer.py
File metadata and controls
107 lines (88 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import glob
import logging
import os
import numpy as np
import tensorflow as tf
import tqdm
from PIL import Image
import docclean
AUTOTUNE = tf.data.experimental.AUTOTUNE
if __name__ == "__main__":
    # Command-line entry point: run a trained DocClean model (cycle-GAN or
    # autoencoder) over every PNG in --data_dir and write the cleaned result
    # next to each input as <name>_cleaned.png.
    parser = argparse.ArgumentParser(
        description="DocClean", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("-v", "--verbose", help="Be verbose", action="store_true")
    parser.add_argument(
        "-g",
        "--gpu_id",
        help="GPU ID (use -1 for CPU)",
        type=int,
        required=False,
        default=0,
    )
    parser.add_argument(
        "-c",
        "--data_dir",
        help="Directory with candidate pngs.",
        required=True,
        type=str,
    )
    parser.add_argument(
        "-b", "--batch_size", help="Batch size for training data", default=32, type=int
    )
    parser.add_argument(
        "-t",
        "--type",
        help="Which model to train",
        choices=["cycle_gan", "autoencoder"],
        required=True,
        type=str,
    )
    parser.add_argument(
        "-w", "--weights", help="Model weights", required=False, type=str, default=None
    )
    args = parser.parse_args()

    logging_format = (
        "%(asctime)s - %(funcName)s -%(name)s - %(levelname)s - %(message)s"
    )
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format=logging_format,
    )

    # Pin TensorFlow to the requested GPU; an empty string forces CPU-only mode.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}" if args.gpu_id >= 0 else ""

    # Glob the candidate images once and reuse the list everywhere.  The
    # original code globbed a second time with the broken pattern
    # f"{args.data_dir} + '/*png" (a literal " + '" embedded in the path),
    # which always produced an empty dataset.  sorted() makes the mapping
    # from input file to prediction index deterministic across runs.
    input_img_list = sorted(glob.glob(f"{args.data_dir}/*png"))
    if not input_img_list:
        raise FileNotFoundError("No candidates to evaluate.")

    # Build the requested model architecture.
    if args.type == "cycle_gan":
        from tensorflow_examples.models.pix2pix import pix2pix

        model = pix2pix.unet_generator(3, norm_type="instancenorm")
    else:
        model = docclean.autoencoder.Autoencoder().autoencoder_model

    # Load weights: either the user-supplied checkpoint, or pretrained
    # weights fetched (and cached by Keras) from the project's S3 bucket.
    if args.weights is None:
        if args.type == "cycle_gan":
            cg_url = "https://docclean.s3.us-east-2.amazonaws.com/cg_weights.tar.gz"
            tf.keras.utils.get_file("cg_weights", cg_url, untar=True)
            model.load_weights(os.path.expanduser("~/.keras/datasets/weights/cg"))
        else:
            ae_url = "https://docclean.s3.us-east-2.amazonaws.com/ae_weights.tar.gz"
            tf.keras.utils.get_file("ae_weights", ae_url, untar=True)
            model.load_weights(os.path.expanduser("~/.keras/datasets/weights/ae"))
    else:
        model.load_weights(args.weights)

    # Input pipeline: decode each PNG, normalize, and prefetch.
    list_ds = tf.data.Dataset.from_tensor_slices(input_img_list)
    infer_images = list_ds.map(
        docclean.utils.get_png_data, num_parallel_calls=AUTOTUNE
    )
    infer_images = infer_images.map(
        docclean.normalize, num_parallel_calls=AUTOTUNE
    ).prefetch(AUTOTUNE)

    inferred_images = model.predict(infer_images, verbose=1, batch_size=args.batch_size)
    if not isinstance(inferred_images, np.ndarray):
        inferred_images = inferred_images.numpy()

    # Write each cleaned image next to its source as <name>_cleaned.png.
    # NOTE(review): Image.fromarray(..., "RGB") expects a uint8 array; if
    # docclean.normalize maps pixels to floats, the model output here will
    # need clipping/scaling back to uint8 — confirm against docclean.normalize.
    for idx, img_name in tqdm.tqdm(enumerate(input_img_list)):
        outname = img_name[:-4] + "_cleaned.png"
        im = Image.fromarray(inferred_images[idx], "RGB")
        im.save(outname)