@@ -1271,7 +1271,394 @@ plt.show()
Here we show how we can use the simulated data to train an RNN. The example below uses an LSTM; LSTMs and GRUs will be discussed in more detail next week.
12721272
12731273!bc pycod
1274+ #!/usr/bin/env python3
1275+ """
1276+ RNN for Learning ODE Solutions
1277+ """
1278+
1279+ import numpy as np
1280+ import matplotlib.pyplot as plt
1281+ import torch
1282+ import torch.nn as nn
1283+ import torch.optim as optim
1284+ from torch.utils.data import Dataset, DataLoader
1285+ import time
1286+ from math import ceil, cos
1287+ import sys
1288+
# Seed both random-number generators so every run is reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pin all computation to the CPU; this script deliberately avoids the GPU.
device = torch.device('cpu')
print(f"Using device: {device} (CPU mode to avoid hanging)")
1296+
1297+ # ============================================================================
1298+ # PART I: ODE SOLVER (OPTIMIZED)
1299+ # ============================================================================
1300+
def SpringForce(v, x, t, gamma=0.2, Omega=0.5, F=1.0):
    """Return the force on a driven, damped harmonic oscillator (unit mass).

    Parameters
    ----------
    v, x, t : float
        Velocity, position and time.
    gamma : float
        Damping coefficient.
    Omega : float
        Angular frequency of the external drive.
    F : float
        Amplitude of the external drive.
    """
    damping = -2.0 * gamma * v
    restoring = -x
    driving = F * cos(Omega * t)
    return damping + restoring + driving
1304+
print("\n" + "="*70)
print("SOLVING ODE (REDUCED SIZE FOR SPEED)")
print("="*70)

# REDUCED parameters to avoid hanging
DeltaT = 0.002  # integration time step (larger step -> fewer points)
tfinal = 10.0   # total simulated time (shorter run for speed)
n = ceil(tfinal/DeltaT)  # number of grid points

print(f"\nODE Parameters:")
print(f" Time step: {DeltaT}")
print(f" Final time: {tfinal}")
print(f" Number of points: {n}")

# Solve ODE: pre-allocate time, position and velocity arrays.
t = np.zeros(n)
x = np.zeros(n)
v = np.zeros(n)

# Initial conditions and physical parameters
# (these match the defaults of SpringForce above).
x[0] = 1.0
v[0] = 0.0
gamma = 0.2
Omega = 0.5
F = 1.0

print("\nSolving ODE with RK4...")
# Classical fourth-order Runge-Kutta for the coupled first-order system
#   dx/dt = v,   dv/dt = SpringForce(v, x, t).
for i in range(n-1):
    if i % 1000 == 0:
        print(f" Progress: {100*i/n:.1f}%", end='\r')

    # RK4 step: k1 evaluates the slopes at the start of the interval.
    k1x = DeltaT * v[i]
    k1v = DeltaT * SpringForce(v[i], x[i], t[i], gamma, Omega, F)

    # k2: midpoint estimate using the k1 slopes.
    vv = v[i] + k1v*0.5
    xx = x[i] + k1x*0.5
    tt = t[i] + DeltaT*0.5
    k2x = DeltaT * vv
    k2v = DeltaT * SpringForce(vv, xx, tt, gamma, Omega, F)

    # k3: second midpoint estimate, now using the k2 slopes
    # (tt is deliberately unchanged: still the midpoint time).
    vv = v[i] + k2v*0.5
    xx = x[i] + k2x*0.5
    k3x = DeltaT * vv
    k3v = DeltaT * SpringForce(vv, xx, tt, gamma, Omega, F)

    # k4: full-step estimate using the k3 slopes.
    vv = v[i] + k3v
    xx = x[i] + k3x
    tt = t[i] + DeltaT
    k4x = DeltaT * vv
    k4v = DeltaT * SpringForce(vv, xx, tt, gamma, Omega, F)

    # Combine the four slope estimates with the standard 1-2-2-1 weights.
    x[i+1] = x[i] + (k1x + 2*k2x + 2*k3x + k4x)/6.0
    v[i+1] = v[i] + (k1v + 2*k2v + 2*k3v + k4v)/6.0
    t[i+1] = t[i] + DeltaT

print(f" Progress: 100.0% - Complete!")
print(f"\nODE solved: {len(x)} points")
print(f" Position range: [{x.min():.4f}, {x.max():.4f}]")
1363+
1364+ # ============================================================================
1365+ # PART II: PREPARE DATA
1366+ # ============================================================================
1367+
print("\n" + "="*70)
print("PREPARING TRAINING DATA")
print("="*70)

seq_length = 50 # Shorter sequences
X_list, y_list = [], []

print(f"\nCreating sequences (length={seq_length})...")
# Sliding windows: each input is x[i : i+seq_length], the target is the
# next value x[i+seq_length].  Fix: the original upper bound
# len(x) - seq_length - 1 dropped the last valid window; with
# len(x) - seq_length the final target index is len(x) - 1, still in range.
for i in range(len(x) - seq_length):
    X_list.append(x[i:i + seq_length])
    y_list.append(x[i + seq_length])

X = np.array(X_list)
y = np.array(y_list).reshape(-1, 1)  # column vector of targets

print(f" Created {len(X)} sequences")

# 75/25 split, chronological (no shuffling) so the test set is the
# unseen tail of the time series.
train_size = int(0.75 * len(X))
X_train = X[:train_size]
X_test = X[train_size:]
y_train = y[:train_size]
y_test = y[train_size:]

print(f" Train: {len(X_train)} ({100*len(X_train)/len(X):.1f}%)")
print(f" Test: {len(X_test)} ({100*len(X_test)/len(X):.1f}%)")
1394+
1395+ # ============================================================================
1396+ # PART III: PYTORCH DATASET
1397+ # ============================================================================
1398+
class TimeSeriesDataset(Dataset):
    """Pairs of (input window, next-value target) as float32 tensors."""

    def __init__(self, X, y):
        # The trailing axis turns each window of shape (seq_length,)
        # into (seq_length, 1): one feature per time step, as expected
        # by a batch_first recurrent network.
        self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(-1)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
1409+
# Wrap the chronological numpy splits in torch Datasets.
train_dataset = TimeSeriesDataset(X_train, y_train)
test_dataset = TimeSeriesDataset(X_test, y_test)

# CRITICAL: num_workers=0 to avoid multiprocessing hanging
batch_size = 32
# Shuffle only the training batches; the test set keeps its temporal
# order so evaluation is deterministic.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         shuffle=False, num_workers=0)

print(f"\nDataLoaders ready:")
print(f" Batch size: {batch_size}")
print(f" Train batches: {len(train_loader)}")
1423+
1424+ # ============================================================================
1425+ # PART IV: LSTM MODEL (SINGLE MODEL FOR SPEED)
1426+ # ============================================================================
1427+
class LSTMModel(nn.Module):
    """Stacked LSTM mapping a (batch, seq_len, 1) sequence to one scalar.

    The prediction is a linear read-out of the last layer's hidden state
    at the final time step.

    Parameters
    ----------
    hidden_size : int
        Width of each LSTM layer.
    num_layers : int
        Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size=64, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Fix: allocate the initial hidden/cell states on the input's
        # device and dtype; the original hard-coded CPU float32 tensors,
        # which breaks as soon as the model is moved to another device.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)

        out, _ = self.lstm(x, (h0, c0))
        # Keep only the final time step's hidden state for the read-out.
        out = self.fc(out[:, -1, :])
        return out
1449+
1450+ # ============================================================================
1451+ # PART V: TRAINING WITH PROGRESS
1452+ # ============================================================================
1453+
print("\n" + "="*70)
print("TRAINING LSTM MODEL")
print("="*70)

# Model, loss and optimizer; MSE is the natural loss for this
# one-step-ahead regression task.
model = LSTMModel(hidden_size=64, num_layers=2)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 50 # Reduced for speed
print(f"\nStarting training ({epochs} epochs)...")
print(f" Hidden size: 64")
print(f" Num layers: 2")

# Per-epoch loss history, consumed later by the training-curve plot.
train_losses = []
test_losses = []

start_time = time.time()

for epoch in range(epochs):
    # Training pass: one sweep over the shuffled training batches.
    model.train()
    total_train_loss = 0
    batch_num = 0

    for X_batch, y_batch in train_loader:
        batch_num += 1

        # Forward
        predictions = model(X_batch)
        loss = criterion(predictions, y_batch)

        # Backward: clear old gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    # Mean loss over batches (each batch weighted equally).
    train_loss = total_train_loss / len(train_loader)

    # Evaluation pass: eval mode and no gradient tracking.
    model.eval()
    total_test_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            predictions = model(X_batch)
            loss = criterion(predictions, y_batch)
            total_test_loss += loss.item()

    test_loss = total_test_loss / len(test_loader)

    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # Print progress on the first epoch and every fifth one after that.
    if (epoch + 1) % 5 == 0 or epoch == 0:
        elapsed = time.time() - start_time
        print(f" Epoch {epoch+1:3d}/{epochs}: Train={train_loss:.6f}, Test={test_loss:.6f}, Time={elapsed:.1f}s")

total_time = time.time() - start_time
print(f"\nTraining complete in {total_time:.2f} seconds!")
print(f"Final: Train Loss = {train_losses[-1]:.6f}, Test Loss = {test_losses[-1]:.6f}")
1516+
1517+ # ============================================================================
1518+ # PART VI: PREDICTIONS
1519+ # ============================================================================
1520+
print("\n" + "="*70)
print("GENERATING PREDICTIONS")
print("="*70)

model.eval()

# Performance fix: the original looped over the train and test sets one
# sequence at a time (one Python-level forward call per sample).  Two
# batched forward passes produce the same predictions far faster.
with torch.no_grad():
    train_in = torch.FloatTensor(X_train).unsqueeze(-1)  # (N_train, seq_length, 1)
    test_in = torch.FloatTensor(X_test).unsqueeze(-1)    # (N_test, seq_length, 1)
    train_preds = model(train_in).squeeze(-1).numpy().astype(np.float64)
    test_preds = model(test_in).squeeze(-1).numpy().astype(np.float64)

# Metrics on the held-out test set.
mse = np.mean((y_test.flatten() - test_preds)**2)
rmse = np.sqrt(mse)
mae = np.mean(np.abs(y_test.flatten() - test_preds))
# Coefficient of determination: 1 - SS_res / SS_tot.
r2 = 1 - (np.sum((y_test.flatten() - test_preds)**2) /
          np.sum((y_test.flatten() - np.mean(y_test))**2))

print(f"\nTest Metrics:")
print(f" MSE = {mse:.6f}")
print(f" RMSE = {rmse:.6f}")
print(f" MAE = {mae:.6f}")
print(f" R² = {r2:.6f}")
1555+
1556+ # ============================================================================
1557+ # PART VII: VISUALIZATION
1558+ # ============================================================================
1559+
print("\n" + "="*70)
print("CREATING VISUALIZATION")
print("="*70)

# Six-panel summary figure: solution, phase space, loss curves,
# predictions, error histogram and a text summary.
fig = plt.figure(figsize=(16, 10))

# Plot 1: ODE solution
ax1 = plt.subplot(2, 3, 1)
ax1.plot(t, x, 'b-', linewidth=1, alpha=0.7)
# Index in the original series where the test targets begin:
# train_size windows, each offset by seq_length.
split_point = train_size + seq_length
if split_point < len(t):
    ax1.axvline(x=t[split_point], color='r', linestyle='--', linewidth=2, label='Train/Test')
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Position x [m]')
ax1.set_title('ODE Solution', fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Phase space
ax2 = plt.subplot(2, 3, 2)
ax2.plot(x, v, 'b-', linewidth=0.5, alpha=0.5)
ax2.set_xlabel('Position x')
ax2.set_ylabel('Velocity v')
ax2.set_title('Phase Space', fontweight='bold')
ax2.grid(True, alpha=0.3)

# Plot 3: Training curves (log scale to show late-epoch improvement)
ax3 = plt.subplot(2, 3, 3)
ax3.plot(train_losses, 'b-', linewidth=2, label='Train')
ax3.plot(test_losses, 'r-', linewidth=2, label='Test')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Loss (MSE)')
ax3.set_title('Training Curves', fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_yscale('log')

# Plot 4: Predictions
ax4 = plt.subplot(2, 3, 4)
# The i-th prediction targets x[i + seq_length], so shift both index
# ranges by seq_length to align them with the original series.
train_idx = np.arange(seq_length, seq_length + len(train_preds))
test_idx = np.arange(seq_length + len(train_preds),
                     seq_length + len(train_preds) + len(test_preds))
ax4.plot(train_idx, y_train.flatten(), 'b-', linewidth=1, alpha=0.5, label='Train True')
ax4.plot(train_idx, train_preds, 'g-', linewidth=1, label='Train Pred')
ax4.plot(test_idx, y_test.flatten(), 'r-', linewidth=1, alpha=0.5, label='Test True')
ax4.plot(test_idx, test_preds, 'orange', linewidth=1, label='Test Pred')
ax4.set_xlabel('Time Step')
ax4.set_ylabel('Position')
ax4.set_title('Predictions', fontweight='bold')
ax4.legend(fontsize=8)
ax4.grid(True, alpha=0.3)

# Plot 5: Error distribution (centered on zero for an unbiased model)
ax5 = plt.subplot(2, 3, 5)
errors = test_preds - y_test.flatten()
ax5.hist(errors, bins=30, alpha=0.7, edgecolor='black')
ax5.axvline(x=0, color='r', linestyle='--', linewidth=2)
ax5.set_xlabel('Prediction Error')
ax5.set_ylabel('Frequency')
ax5.set_title(f'Error Distribution (MAE={mae:.4f})', fontweight='bold')
ax5.grid(True, alpha=0.3, axis='y')

# Plot 6: Summary stats (text-only panel, axes hidden)
ax6 = plt.subplot(2, 3, 6)
ax6.axis('off')
summary_text = f"""
TRAINING SUMMARY

Dataset:
 ODE points: {len(x)}
 Sequences: {len(X)}
 Train: {len(X_train)} (75%)
 Test: {len(X_test)} (25%)

Model: LSTM
 Hidden: 64
 Layers: 2
 Epochs: {epochs}

Results:
 MSE: {mse:.6f}
 RMSE: {rmse:.6f}
 MAE: {mae:.6f}
 R²: {r2:.6f}

Time: {total_time:.1f}s
"""
ax6.text(0.1, 0.5, summary_text, fontsize=11, family='monospace',
         verticalalignment='center')

plt.tight_layout()
# Fix: the original printed "Plot saved" while the savefig call was
# commented out.  Save the figure (before plt.show(), which may clear
# the current figure in some backends) so the message is accurate.
plt.savefig('rnn_ode_optimized.png', dpi=150)
plt.show()
print("\n✓ Plot saved: rnn_ode_optimized.png")

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)
print(f"\n✓ Successfully trained LSTM on ODE data")
print(f"✓ Test R² score: {r2:.4f}")
print(f"✓ No hanging issues!")
print("="*70)
12751662!ec
12761663
12771664
0 commit comments