-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_node_base.py
More file actions
74 lines (62 loc) · 2.25 KB
/
inference_node_base.py
File metadata and controls
74 lines (62 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# ============================================
__author__ = "ShigemichiMatsuzaki"
__maintainer__ = "ShigemichiMatsuzaki"
# ============================================
# ROS related
import rospy
from sensor_msgs.msg import Image
# PyTorch related
import torch
from torchvision import transforms
from typing import Union
# from util.util import import_model
class InferenceNodeBase(object):
    """Base class for ROS nodes that run a PyTorch model on incoming images.

    On construction it optionally loads a model through ``import_model_func``,
    moves it to CUDA when available (CPU otherwise) and switches it to eval
    mode, builds the input preprocessing pipeline, and creates the ROS I/O:

    * a publisher on the private topic ``~visualize`` (``sensor_msgs/Image``);
    * a subscriber on the private topic ``~image`` (``sensor_msgs/Image``)
      that dispatches each message to :meth:`image_callback`.

    Subclasses must override :meth:`image_callback`.
    """

    # Per-channel mean/std used by the default ``transforms.Normalize`` step.
    # These are the ImageNet statistics commonly used with torchvision
    # pretrained models. Tuples so the shared class-level values are immutable.
    _NORMALIZE_MEAN = (0.485, 0.456, 0.406)
    _NORMALIZE_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        model_name: str = "",
        version: str = "",
        import_model_func=None,
        size: Union[int, tuple] = 224,
    ):
        """Load the model (if requested), set up preprocessing and ROS topics.

        Args:
            model_name: Name passed to ``import_model_func``. When empty, no
                model is loaded and ``self.model`` is ``None``.
            version: Version string passed to ``import_model_func``.
            import_model_func: Callable ``(model_name, version) -> dict``.
                The returned dict must contain ``"model"`` and may contain
                ``"transform"``; a ``None`` or missing transform selects the
                default preprocessing. Required when ``model_name`` is set.
            size: Size argument for ``transforms.Resize`` in the default
                preprocessing pipeline.

        Raises:
            TypeError: If ``model_name`` is non-empty but
                ``import_model_func`` is not callable.
            SystemExit: With status 1 when ``import_model_func`` raises
                ``ValueError`` (invalid model name/version).
        """
        self.model_name = model_name
        self.version = version
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.size = size
        rospy.loginfo(self.device)

        if model_name:
            if not callable(import_model_func):
                raise TypeError(
                    "import_model_func must be callable when model_name is "
                    "given (got {!r})".format(import_model_func))
            # Only the loader call is guarded: a ValueError raised while
            # building the preprocessing pipeline (e.g. a bad ``size``) must
            # not be misreported as an invalid model name/version.
            try:
                model_etc = import_model_func(model_name, version)
            except ValueError:
                rospy.logerr(
                    "Invalid input ('{}', '{}')".format(model_name, version))
                raise SystemExit(1)

            self.model = model_etc["model"]
            custom_transform = model_etc.get("transform")
            if custom_transform is None:
                self.transforms = self._build_default_transforms()
            else:
                self.transforms = custom_transform
            self.model.to(self.device)
            self.model.eval()
        else:
            self.model = None
            self.transforms = self._build_default_transforms()

        # Create the publisher before the subscriber. rospy invokes subscriber
        # callbacks from its own threads, so image_callback may run as soon as
        # the subscriber exists and must already be able to use image_pub.
        # NOTE(review): for the same reason, the callback may fire before a
        # subclass's __init__ finishes; attributes a subclass sets after
        # calling super().__init__() may not exist yet inside the callback.
        self.image_pub = rospy.Publisher('~visualize', Image, queue_size=10)
        self.image_sub = rospy.Subscriber(
            '~image', Image, self.image_callback, queue_size=1)

    def _build_default_transforms(self):
        """Return the default preprocessing pipeline.

        Resize to ``self.size``, convert to a tensor, then normalize each
        channel with ``_NORMALIZE_MEAN`` / ``_NORMALIZE_STD``.
        """
        return transforms.Compose([
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize(self._NORMALIZE_MEAN, self._NORMALIZE_STD),
        ])

    def image_callback(self, img_msg):
        """Process one incoming image message with the model.

        Subclasses must override this; the base implementation only raises.

        Args:
            img_msg: ``sensor_msgs/Image`` message received on ``~image``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError