-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_models.py
More file actions
57 lines (47 loc) · 1.74 KB
/
train_models.py
File metadata and controls
57 lines (47 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import pandas as pd
import re
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
import os
# Unit suffixes that may follow a numeric value in a "Parameters" string.
_UNIT_PATTERN = re.compile(r"Mbps|ms|%|devices/km2|km/h")


def parse_parameters(param_str):
    """Parse a ``"key=value | key=value | ..."`` string into a dict.

    Each ``|``-separated segment that contains ``=`` is split on its first
    ``=``; key and value are whitespace-stripped. If the value, with the known
    unit suffixes (Mbps, ms, %, devices/km2, km/h) removed, parses as a float,
    the float is stored. Otherwise the value is kept as its original stripped
    text. Segments without ``=`` are ignored.

    Args:
        param_str: The raw parameter string. Non-string input (for example
            the float NaN pandas produces for an empty CSV cell) yields an
            empty dict.

    Returns:
        A dict mapping parameter names to floats or strings.
    """
    if not isinstance(param_str, str):
        # Missing cells arrive as float NaN; treat them as "no parameters"
        # instead of failing on ``.split``.
        return {}

    params = {}
    for part in param_str.split("|"):
        key, sep, raw_val = part.strip().partition("=")
        if not sep:
            continue
        key = key.strip()
        raw_val = raw_val.strip()
        try:
            params[key] = float(_UNIT_PATTERN.sub("", raw_val).strip())
        except ValueError:
            # Keep categorical text intact: stripping unit substrings from it
            # would corrupt words that merely contain them
            # (e.g. "Streams" -> "Strea").
            params[key] = raw_val
    return params
# --- Training script ---------------------------------------------------------
# Loads the network-slicing dataset and expands its "Parameters" column into
# per-parameter fields. It then trains:
#   * a RandomForestClassifier that predicts the "Error" field from the five
#     feature columns (trained on the raw, unscaled feature values), and
#   * a 4-cluster KMeans model on the same features after StandardScaler
#     normalisation.
# The classifier, clusterer, scaler and label encoder are written to
# ``models_dir`` as joblib files. Code that later calls ``kmeans.predict``
# must first transform its input with the saved scaler.
#
# The input CSV and output directory default to the original hard-coded
# locations. Override them with the NETWORK_SLICING_CSV and
# NETWORK_SLICING_MODELS_DIR environment variables.
DATA_PATH = os.environ.get(
    "NETWORK_SLICING_CSV", "/Users/Agentic RAN/network_slicing_300.csv"
)
models_dir = os.environ.get(
    "NETWORK_SLICING_MODELS_DIR", "/Users/Agentic RAN/models"
)
FEATURE_COLUMNS = ["Throughput", "Latency", "Reliability", "Density", "Mobility"]
LABEL_COLUMN = "Error"

print("Loading dataset...")
df_raw = pd.read_csv(DATA_PATH)
parsed_data = df_raw["Parameters"].apply(parse_parameters)
df = pd.DataFrame(list(parsed_data))

# Fail early with a readable message if the parsed parameter strings do not
# provide every field the models need.
missing_cols = [c for c in [*FEATURE_COLUMNS, LABEL_COLUMN] if c not in df.columns]
if missing_cols:
    raise KeyError(f"Parsed parameters are missing required fields: {missing_cols}")

# Both models need numeric, NaN-free features. A row can be missing a
# parameter, or hold a value that did not parse as a number (e.g. an
# unrecognised unit). Such rows would abort fit() with an opaque error, so
# drop them explicitly and report how many were lost.
df[FEATURE_COLUMNS] = df[FEATURE_COLUMNS].apply(pd.to_numeric, errors="coerce")
complete_rows = df[FEATURE_COLUMNS].notna().all(axis=1) & df[LABEL_COLUMN].notna()
n_dropped = int((~complete_rows).sum())
if n_dropped:
    print(f"Dropping {n_dropped} row(s) with missing or non-numeric values")
df = df.loc[complete_rows].reset_index(drop=True)
if df.empty:
    raise ValueError("No complete rows left to train on")

X = df[FEATURE_COLUMNS]

print("Training Random Forest Classifier...")
label_enc = LabelEncoder()
y = label_enc.fit_transform(df[LABEL_COLUMN])
rf_model = RandomForestClassifier(
    n_estimators=100, random_state=42, class_weight="balanced"
)
rf_model.fit(X, y)

print("Training KMeans Clusterer...")
# The scaler is used only for clustering; the random forest above is trained
# on unscaled features.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
kmeans = KMeans(n_clusters=4, random_state=42, n_init="auto")
kmeans.fit(X_scaled)

# Save the models and preprocessors
os.makedirs(models_dir, exist_ok=True)
joblib.dump(rf_model, os.path.join(models_dir, "rf_model.joblib"))
joblib.dump(kmeans, os.path.join(models_dir, "kmeans_model.joblib"))
joblib.dump(scaler, os.path.join(models_dir, "scaler.joblib"))
joblib.dump(label_enc, os.path.join(models_dir, "label_enc.joblib"))
print(f"Models successfully trained and saved to {models_dir}")