# MPFNet-C.py
import glob
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from CA_I3D import CAI3D
# Dataset class to load and preprocess the pre-extracted micro-expression features
class MEFeaturesDataset(Dataset):
    def __init__(self, dataset_name, root_dir, transform=None):
        self.dataset_name = dataset_name
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.labels = []
        self._load_data()

    def _load_data(self):
        if self.dataset_name == 'SMIC':
            # SMIC provides two feature streams that are fused channel-wise
            smic_dir = os.path.join(self.root_dir, 'ME_features', 'SMIC')
            flow_feature_dir = os.path.join(smic_dir, 'flow_feature')
            frame_diff_feature_dir = os.path.join(smic_dir, 'frame_diff_feature')
            # Sort both listings so the two feature streams pair up deterministically
            flow_features = sorted(glob.glob(os.path.join(flow_feature_dir, '*.npy')))
            frame_diff_features = sorted(glob.glob(os.path.join(frame_diff_feature_dir, '*.npy')))
            for flow_file, frame_diff_file in zip(flow_features, frame_diff_features):
                flow_feature = np.load(flow_file)
                frame_diff_feature = np.load(frame_diff_file)
                # Concatenate along the channel (last) dimension
                combined_feature = np.concatenate((flow_feature, frame_diff_feature), axis=-1)
                # Extract the label from the filename, not the full path
                label = int(os.path.basename(flow_file).split('_')[0])
                self.data.append(combined_feature)
                self.labels.append(label)
        # For other datasets such as 'CASME II', 'SAMM', and 'MEGC2019-CD'
        else:
            dataset_dir = os.path.join(self.root_dir, 'ME_features', self.dataset_name)
            feature_files = glob.glob(os.path.join(dataset_dir, '*.npy'))
            for feature_file in feature_files:
                feature = np.load(feature_file)
                # Extract the label from the filename (adjust the parsing as needed)
                label = int(os.path.basename(feature_file).split('_')[0])
                self.data.append(feature)
                self.labels.append(label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            feature = self.transform(feature)
        return torch.tensor(feature, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
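
# Quick sanity check (illustrative sketch; assumes the ME_features directory
# layout described above exists on disk relative to the working directory):
#
#   ds = MEFeaturesDataset(dataset_name='SMIC', root_dir='.')
#   feature, label = ds[0]
#   print(feature.shape, label)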
# Meta-Learning Module with cascade and residual architecture
class MetaLearningModule(nn.Module):
    def __init__(self, gfe_model_path, afe_model_path, input_dim, output_dim):
        super(MetaLearningModule, self).__init__()
        # GFE and AFE feature extractors; the AFE consumes the residual
        # concatenation of the raw input and the GFE features
        self.gfe = CAI3D(input_dim=input_dim, output_dim=output_dim)
        self.afe = CAI3D(input_dim=input_dim + output_dim, output_dim=output_dim)
        # Load the pre-trained weights for GFE and AFE
        self.gfe.load_state_dict(torch.load(gfe_model_path))
        self.afe.load_state_dict(torch.load(afe_model_path))

    def forward(self, support_set, query_set, support_labels):
        # GFE feature extraction on both sets
        support_gfe_features = self.gfe(support_set)
        query_gfe_features = self.gfe(query_set)
        # Residual connection: concatenate the GFE features with the original
        # features along the channel dimension
        support_gfe_residual = torch.cat((support_set, support_gfe_features), dim=1)
        query_gfe_residual = torch.cat((query_set, query_gfe_features), dim=1)
        # AFE feature extraction on the residual features
        support_afe_features = self.afe(support_gfe_residual)
        query_afe_features = self.afe(query_gfe_residual)
        # Per-class centroids of the support features (one average vector per
        # class; task-local labels are expected to be 0..num_classes-1)
        centroids = torch.stack([
            support_afe_features[support_labels == c].mean(dim=0)
            for c in torch.unique(support_labels, sorted=True)
        ])
        # Cosine similarity between each query feature and each class centroid,
        # computed in torch so gradients can flow back through both extractors
        d_FE = F.cosine_similarity(query_afe_features.unsqueeze(1), centroids.unsqueeze(0), dim=-1)
        return d_FE
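
# Illustrative shape check (assumption: CAI3D maps a batch of inputs to a
# (batch, output_dim) feature matrix). For a 3-way 5-shot episode with one
# query per class, d_FE has one row per query and one column per centroid:
#
#   d_FE = model(support_set, query_set, support_labels)
#   print(d_FE.shape)  # -> torch.Size([3, 3])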
# Function to create 3-way 5-shot or 5-way 5-shot tasks
def create_task(dataset, num_classes=3, num_shots=5):
    classes = list(set(dataset.labels))  # All unique classes
    selected_classes = random.sample(classes, num_classes)  # Randomly pick the task's classes
    # Remap the selected labels to task-local indices 0..num_classes-1,
    # which is what CrossEntropyLoss (and the centroid ordering) expects
    label_map = {c: i for i, c in enumerate(selected_classes)}
    support_set, query_set = [], []
    support_labels, query_labels = [], []
    # For each selected class, draw `num_shots` support samples and 1 query sample
    for class_label in selected_classes:
        class_samples = [i for i, label in enumerate(dataset.labels) if label == class_label]
        selected_samples = random.sample(class_samples, num_shots + 1)
        # The first `num_shots` samples go to the support set, the last one to the query set
        support_set.extend(torch.as_tensor(dataset.data[i], dtype=torch.float32) for i in selected_samples[:-1])
        query_set.append(torch.as_tensor(dataset.data[selected_samples[-1]], dtype=torch.float32))
        support_labels.extend(label_map[class_label] for _ in selected_samples[:-1])
        query_labels.append(label_map[class_label])
    support_set = torch.stack(support_set)
    query_set = torch.stack(query_set)
    return support_set, query_set, torch.tensor(support_labels), torch.tensor(query_labels)
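
# A minimal multi-episode sampler (an assumption: episodic training normally
# iterates over many randomly drawn tasks rather than the single episode built
# below). Each yielded tuple has the same structure as create_task's output,
# so a list of episodes can stand in for the train_loader, e.g.:
#
#   train_loader = list(sample_episodes(dataset, num_tasks=200))
def sample_episodes(dataset, num_tasks=100, num_classes=3, num_shots=5):
    for _ in range(num_tasks):
        yield create_task(dataset, num_classes=num_classes, num_shots=num_shots)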
# Classification Module: takes the fused similarity features and classifies them
class ClassificationModule(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(ClassificationModule, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)  # Fully connected layer for classification

    def forward(self, x):
        # Return raw logits; CrossEntropyLoss applies log-softmax internally,
        # so an explicit softmax here would double-normalize the outputs
        return self.fc(x)
# Loss function (cross-entropy) for training; the optimizer is created below,
# once the model and classifier have been instantiated
criterion = nn.CrossEntropyLoss()
# Training loop function
def train(model, classifier, optimizer, train_loader, num_epochs=10):
    model.train()
    classifier.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for support_set, query_set, support_labels, query_labels in train_loader:
            optimizer.zero_grad()
            # Fused similarity vector from the MetaLearningModule
            d_FE = model(support_set, query_set, support_labels)
            # Classify the similarity vector
            output = classifier(d_FE)
            # Cross-entropy loss against the task-local query labels
            loss = criterion(output, query_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")
# Input feature dimensions (128x128x5x10)
input_dim = 128 * 128 * 5 * 10  # Flattened feature size
output_dim = 128  # Output dimensionality of the feature space
num_classes = 3  # Number of emotion classes (can be adjusted to 5)

# Paths to the pre-trained GFE and AFE models
gfe_model_path = "models/PLTN.pth"
afe_model_path = "models/PLSM.pth"

# Dataset to use
dataset_name = 'SMIC'  # Change to 'CASME II', 'SAMM', or 'MEGC2019-CD' as needed
root_dir = '.'  # _load_data already appends 'ME_features', so point this at its parent

# Initialize the dataset
dataset = MEFeaturesDataset(dataset_name=dataset_name, root_dir=root_dir)

# Build one episode with the custom task-creation function
support_set, query_set, support_labels, query_labels = create_task(dataset, num_classes=num_classes, num_shots=5)

# Initialize the model and the classifier (d_FE has one similarity score per class)
model = MetaLearningModule(gfe_model_path, afe_model_path, input_dim, output_dim)
classifier = ClassificationModule(input_dim=num_classes, num_classes=num_classes)

# SGD optimizer over both the feature extractors and the classifier
optimizer = optim.SGD(list(model.parameters()) + list(classifier.parameters()), lr=0.05, momentum=0.9)

# Start training on a single-episode "loader"
train(model, classifier, optimizer, [(support_set, query_set, support_labels, query_labels)], num_epochs=10)
# predicted_class = predict(model, classifier, support_set, query_set, support_labels)