-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLinearRegression.py
More file actions
61 lines (46 loc) · 2.19 KB
/
LinearRegression.py
File metadata and controls
61 lines (46 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
class MyLinearRegression:
"""Linear Regression: y = w0 + w1*x1 + w2*x2 + ... + wn*xn"""
def __init__(self, fit_intercept: bool = True) -> None:
"""
Args:
fit_intercept: True = có sử dụng bias trick, False = không sử dụng bias trick
"""
self.fit_intercept = fit_intercept
@property
def coef_(self) -> np.ndarray:
"""Trả về hệ số [w1, w2, ..., wn] (không bao gồm w0)"""
if not hasattr(self, 'w'): # Kiểm tra đã train chưa
raise ValueError("Model chưa được train")
if self.fit_intercept:
return self.w[1:] # Bỏ w0, chỉ lấy [w1, w2, ...]
return self.w # Trả về toàn bộ nếu không có w0
def fit(self, X, y) -> None:
"""
Train model bằng công thức: w = (X^T X)^(-1) X^T y
Args:
X: Ma trận features (n_samples, n_features)
y: Vector target (n_samples,)
"""
new_data = np.copy(X) # Copy để không thay đổi dữ liệu gốc
if self.fit_intercept: # Nếu model có w0
intercept = np.ones((new_data.shape[0], 1)) # Tạo cột 1 cho w0
new_data = np.hstack((intercept, new_data)) # Ghép cột 1 vào đầu: [1, x1, x2, ...]
self.data = new_data # Lưu data đã thêm intercept
# Sử dụng công thức tính w = (X^T X)^(-1) X^T y
# self.data.T: X^T
# self.data: X
self.w = np.linalg.pinv(self.data.T @ self.data) @ self.data.T @ y
def predict(self, x):
"""
Dự đoán: y_pred = X @ w
Args:
x: Ma trận features (n_samples, n_features)
Returns:
np.ndarray: Vector dự đoán (n_samples,)
"""
new_x = np.copy(x) # Copy để không thay đổi dữ liệu gốc
if self.fit_intercept: # Nếu model có w0
intercept = np.ones((new_x.shape[0], 1)) # Tạo cột 1 để khớp với w
new_x = np.hstack((intercept, new_x)) # Ghép cột 1 vào đầu
return new_x @ self.w # Phép nhân ma trận: y = Xw