-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathvisualizations.py
More file actions
141 lines (120 loc) · 5.54 KB
/
visualizations.py
File metadata and controls
141 lines (120 loc) · 5.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import torch
from PIL import Image, ImageOps
import torchvision.transforms as tfm
# Height and width (in pixels) of a single image for visualization
IMG_HW = 512
# Height (in pixels) of the canvas the labels are drawn on; write_labels_to_image
# crops the rendered strip afterwards to remove empty space
TEXT_H = 175
# Font size passed to ImageFont.truetype when rendering the labels
FONTSIZE = 50
SPACE = 50 # Horizontal space (in pixels) between two images
def write_labels_to_image(labels=("text1", "text2")):
    """Render a horizontal strip of text labels, one centered above each image slot.

    The strip has one IMG_HW-wide slot per label, with SPACE pixels between
    consecutive slots, so it lines up with a row of images laid out the same way.

    Parameters
    ----------
    labels : sequence of str
        Text to draw, one entry per image slot (left to right).

    Returns
    -------
    PIL.Image.Image
        RGB image with black text on a white background, cropped to its
        first 100 rows.

    Raises
    ------
    OSError
        If the DejaVuSans TrueType font file cannot be read.
    """
    font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", FONTSIZE)
    # Use SPACE (not a literal) so the canvas width always agrees with the
    # per-slot offsets used when placing the text below.
    width = IMG_HW * len(labels) + SPACE * (len(labels) - 1)
    # The background is 1 (not 255) on purpose: after drawing text with value 0
    # every pixel is either 0 or 1, and the final `* 255` turns that into a
    # pure black-and-white image.
    img = Image.new("RGB", (width, TEXT_H), (1, 1, 1))
    d = ImageDraw.Draw(img)
    for i, text in enumerate(labels):
        # Only the right edge of the bounding box is needed to center the text
        _, _, w, _ = d.textbbox((0, 0), text, font=font)
        d.text(((IMG_HW + SPACE) * i + IMG_HW // 2 - w // 2, 1), text, fill=(0, 0, 0), font=font)
    return Image.fromarray(np.array(img)[:100] * 255)  # Remove some empty space
def draw_box(img, c=(0, 1, 0), thickness=20):
    """Draw a colored border around an image and return the framed copy.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to frame. The color is reshaped to (3, 1, 1), which only
        broadcasts over a 3-channel tensor, so pass an RGB image.
    c : sequence of 3 numbers
        Border color as (R, G, B), expressed on the tensor scale used here
        (e.g. (0, 1, 0) for green, (1, 0, 0) for red).
    thickness : int
        Border width in pixels. Values <= 0 draw no border.

    Returns
    -------
    PIL.Image.Image
        New image obtained by converting to a tensor, painting the four
        edges and converting back.
    """
    assert isinstance(img, Image.Image)
    img = tfm.ToTensor()(img)
    assert len(img.shape) >= 2, f"{img.shape=}"
    # Guard against thickness == 0: the slice `-0:` means "everything", which
    # would paint the whole image instead of drawing no border at all.
    if thickness > 0:
        c = torch.tensor(c).type(torch.float).reshape(3, 1, 1)
        img[..., :thickness, :] = c  # top edge
        img[..., -thickness:, :] = c  # bottom edge
        img[..., :, -thickness:] = c  # right edge
        img[..., :, :thickness] = c  # left edge
    return tfm.ToPILImage()(img)
def build_prediction_image(images_paths, preds_correct):
    """Build a row of images: the query first, followed by its predictions.

    Every image whose entry in `preds_correct` is not None gets a colored
    border: green if the entry is truthy, red otherwise. Images are then
    fitted into IMG_HW x IMG_HW white squares, pasted side by side with SPACE
    pixels between them, and a strip of text labels is stacked on top.

    Parameters
    ----------
    images_paths : sequence of paths
        Paths of the images to show; the first one is the query.
    preds_correct : sequence of (bool or None)
        One entry per image, same length as `images_paths`. None means
        "no border" (used for the query, or when no labels are available).

    Returns
    -------
    PIL.Image.Image
        The labeled row, or just the row of images if the label font could
        not be loaded.

    Raises
    ------
    ValueError
        If `images_paths` is empty.
    """
    assert len(images_paths) == len(preds_correct)
    if len(images_paths) == 0:
        raise ValueError("build_prediction_image needs at least one image (the query)")

    labels = ["Query"]
    for i, is_correct in enumerate(preds_correct[1:]):
        labels.append(f"Pred{i}" if is_correct is None else f"Pred{i} - {is_correct}")

    images = []
    for path, is_correct in zip(images_paths, preds_correct):
        # Release the file handle as soon as the pixels have been decoded
        with Image.open(path) as opened:
            img = opened.convert("RGB")
        if is_correct is not None:
            img = draw_box(img, (0, 1, 0) if is_correct else (1, 0, 0))
        images.append(img)

    # ImageOps.pad already scales each image to fit the square (keeping its
    # aspect ratio) before filling the rest with white, so no separate resize
    # step is needed. The previous tfm.Resize pass was discarded anyway.
    squared_images = [ImageOps.pad(img, (IMG_HW, IMG_HW), color="white") for img in images]

    total_w = len(squared_images) * IMG_HW + max(0, len(squared_images) - 1) * SPACE
    concat_image = Image.new("RGB", (total_w, IMG_HW), (255, 255, 255))
    for slot, img in enumerate(squared_images):
        concat_image.paste(img, (slot * (IMG_HW + SPACE), 0))

    try:
        labels_image = write_labels_to_image(labels)
    except OSError:  # Handle error in case of missing PIL ImageFont
        return concat_image
    # Stack the label strip above the row of images (both have the same width)
    return Image.fromarray(np.concatenate((np.array(labels_image), np.array(concat_image)), axis=0))
def save_file_with_paths(query_path, preds_paths, positives_paths, output_path, use_labels=True):
    """Write a text file listing the query path, its predictions and its positives.

    The file contains a "Query path:" section, a "Predictions paths:" section
    and, when `use_labels` is true, a "Positives paths:" section. Each section
    lists one path per line and sections are separated by a blank line.

    Parameters
    ----------
    query_path : str or path-like
        Path of the query image.
    preds_paths : iterable of str or path-like
        Paths of the predicted images.
    positives_paths : iterable of str or path-like, or None
        Paths of the positives. Ignored when `use_labels` is false; None is
        treated as an empty list.
    output_path : str or path-like
        Destination text file (overwritten if it exists).
    use_labels : bool
        Whether to write the positives section.
    """
    # str() lets callers pass pathlib.Path objects as well as plain strings
    sections = [
        "Query path:",
        f"{query_path}\n",
        "Predictions paths:",
        "\n".join(str(path) for path in preds_paths) + "\n",
    ]
    if use_labels:
        positives = () if positives_paths is None else positives_paths
        sections.append("Positives paths:")
        sections.append("\n".join(str(path) for path in positives) + "\n")
    # Explicit UTF-8 so non-ASCII paths are written the same way on every platform
    with open(output_path, "w", encoding="utf-8") as file:
        file.write("\n".join(sections))
def save_preds(predictions, eval_ds, log_dir, save_only_wrong_preds=None, use_labels=True):
    """For each query, save an image containing the query and its predictions,
    and a file with the paths of the query, its predictions and its positives.

    Parameters
    ----------
    predictions : np.array of shape [num_queries x num_preds_to_viz], with the preds
        for each query (indexes into eval_ds.database_paths)
    eval_ds : TestDataset, must provide queries_paths, database_paths and,
        when use_labels is true, get_positives()
    log_dir : Path with the path to save the predictions; outputs go to the
        "preds" sub-directory, which must not exist yet
    save_only_wrong_preds : bool, if True save only the wrongly predicted queries,
        i.e. the ones where the first pred is incorrect (not among the query's
        positives; what counts as a positive is decided by the dataset, not here)
    use_labels : bool, if False no positives are read from eval_ds, no
        correct/wrong borders are drawn and no positives are written to the
        text files
    """
    if use_labels:
        positives_per_query = eval_ds.get_positives()

    viz_dir = log_dir / "preds"
    # NOTE(review): assuming log_dir is a pathlib.Path, this raises if the
    # directory already exists, so outputs of a previous run are not overwritten.
    viz_dir.mkdir()
    for query_index, preds in enumerate(tqdm(predictions, desc=f"Saving preds in {viz_dir}")):
        query_path = eval_ds.queries_paths[query_index]
        positives = positives_per_query[query_index] if use_labels else None

        list_of_images_paths = [query_path]
        # List of None (query), True (correct preds) or False (wrong preds)
        preds_correct = [None]
        for pred in preds:
            list_of_images_paths.append(eval_ds.database_paths[pred])
            preds_correct.append(pred in positives if use_labels else None)

        # The length check keeps a query with zero predictions from raising
        # IndexError; such a query has no "correct first prediction" to skip on.
        if save_only_wrong_preds and len(preds_correct) > 1 and preds_correct[1]:
            continue

        prediction_image = build_prediction_image(list_of_images_paths, preds_correct)
        pred_image_path = viz_dir / f"{query_index:03d}.jpg"
        prediction_image.save(pred_image_path)

        if use_labels:
            positives_paths = [eval_ds.database_paths[idx] for idx in positives]
        else:
            positives_paths = None
        save_file_with_paths(
            query_path=list_of_images_paths[0],
            preds_paths=list_of_images_paths[1:],
            positives_paths=positives_paths,
            output_path=viz_dir / f"{query_index:03d}.txt",
            use_labels=use_labels,
        )