-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_setup.py
More file actions
93 lines (77 loc) · 2.36 KB
/
model_setup.py
File metadata and controls
93 lines (77 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import mediapipe as mp
# Module-level handle to MediaPipe's pose solution; Model_Setup reads
# mp_pose.POSE_CONNECTIONS from it when building its constants.
mp_pose = mp.solutions.pose
class Model_Setup:
    """Hold the constants and hyperparameters shared by the models.

    Each constructor argument is stored on the instance under the
    upper-cased argument name (``batch_size`` -> ``BATCH_SIZE``,
    ``_rgb_channels`` -> ``_RGB_CHANNELS``). There are two exceptions:
    ``hist_windows`` is stored as ``HIST_WINDOW`` and ``hdf_name`` is
    stored as ``HDF``. The instance also exposes ``POSE_CONNECTIONS``,
    copied from the module-level ``mp_pose`` (MediaPipe pose solution).

    Args:
        _presence_threshold: Stored as ``_PRESENCE_THRESHOLD`` (default 0.5).
        _rgb_channels: Stored as ``_RGB_CHANNELS`` (default 3).
        _visibility_threshold: Stored as ``_VISIBILITY_THRESHOLD``
            (default 0.5).
        batch_size: Training batch size (default 256).
        black_color, blue_color, green_color, red_color, white_color:
            Color triples for plotting. In the defaults, blue is
            (255, 0, 0) and red is (0, 0, 255), i.e. BGR channel order
            (presumably for OpenCV -- confirm with the plotting code).
        buffer_size: Stored as ``BUFFER_SIZE`` (default 150); presumably a
            dataset shuffle buffer -- confirm with the training code.
        epochs: Number of training epochs (default 150).
        hist_windows: History window length, stored as ``HIST_WINDOW``
            (default 24 * 3 = 72).
        horizon: Stored as ``HORIZON`` (default 1); presumably the number
            of steps predicted ahead -- confirm with the training code.
        model_name: Model identifier string (default "autoregression").
        model: Optional model object (default None).
        num_coords: Stored as ``NUM_COORDS`` (default 33).
        show_window: Stored as ``SHOW_WINDOW`` (default False).
        steps_per_epoch: Stored as ``STEPS_PER_EPOCH`` (default 100).
        train_split: Stored as ``TRAIN_SPLIT`` (default 0.2); which
            partition this fraction describes is decided by the consumer.
        validation_steps: Stored as ``VALIDATION_STEPS`` (default 50).
        verbose: Verbosity level (default 1).
        transformer_layers: Stored as ``TRANSFORMER_LAYERS`` (default 8).
        num_heads: Stored as ``NUM_HEADS`` (default 4).
        projection_dim: Stored as ``PROJECTION_DIM`` (default 99).
        transformer_units: Stored as ``TRANSFORMER_UNITS``. When omitted
            (or None), a new ``[99 * 2, 99]`` list is created per instance.
        mlp_head_units: Stored as ``MLP_HEAD_UNITS``. When omitted
            (or None), a new ``[1024, 512]`` list is created per instance.
        hdf_name: HDF5 file name, stored as ``HDF`` (default 'data.h5').
        create_hdf: Stored as ``CREATE_HDF`` (default True).
        keep_frame: Stored as ``KEEP_FRAME`` (default 1).
    """

    def __init__(
        self,
        _presence_threshold=0.5,
        _rgb_channels=3,
        _visibility_threshold=0.5,
        batch_size=256,
        black_color=(0, 0, 0),
        blue_color=(255, 0, 0),
        buffer_size=150,
        epochs=150,
        green_color=(0, 128, 0),
        hist_windows=24 * 3,
        horizon=1,
        model_name="autoregression",
        model=None,
        num_coords=33,
        red_color=(0, 0, 255),
        show_window=False,
        steps_per_epoch=100,
        train_split=0.2,
        validation_steps=50,
        verbose=1,
        white_color=(224, 224, 224),
        transformer_layers=8,
        num_heads=4,
        projection_dim=99,
        transformer_units=None,
        mlp_head_units=None,
        hdf_name="data.h5",
        create_hdf=True,
        keep_frame=1,
    ) -> None:
        # Build fresh lists per instance: a list literal as a default
        # argument is created once and shared by every Model_Setup, so
        # mutating one config's TRANSFORMER_UNITS / MLP_HEAD_UNITS would
        # silently change all the others.
        if transformer_units is None:
            transformer_units = [99 * 2, 99]
        if mlp_head_units is None:
            mlp_head_units = [1024, 512]

        # Train.py
        self.KEEP_FRAME = keep_frame
        self.HIST_WINDOW = hist_windows
        self.HORIZON = horizon
        self.TRAIN_SPLIT = train_split
        self.MODEL_NAME = model_name
        self.MODEL = model
        self.BATCH_SIZE = batch_size
        self.BUFFER_SIZE = buffer_size
        self.EPOCHS = epochs
        self.STEPS_PER_EPOCH = steps_per_epoch
        self.VALIDATION_STEPS = validation_steps
        self.VERBOSE = verbose
        self.HDF = hdf_name
        self.CREATE_HDF = create_hdf

        # Multihead attention
        self.TRANSFORMER_LAYERS = transformer_layers
        self.NUM_HEADS = num_heads
        self.PROJECTION_DIM = projection_dim
        self.TRANSFORMER_UNITS = transformer_units
        self.MLP_HEAD_UNITS = mlp_head_units

        # PoseDataGenerator
        self.SHOW_WINDOW = show_window
        self.NUM_COORDS = num_coords

        # Plot3D
        self.WHITE_COLOR = white_color
        self.BLACK_COLOR = black_color
        self.RED_COLOR = red_color
        self.GREEN_COLOR = green_color
        self.BLUE_COLOR = blue_color
        self._PRESENCE_THRESHOLD = _presence_threshold
        self._VISIBILITY_THRESHOLD = _visibility_threshold
        self._RGB_CHANNELS = _rgb_channels
        # Relies on the module-level mp_pose alias (mp.solutions.pose).
        self.POSE_CONNECTIONS = mp_pose.POSE_CONNECTIONS