Commit cb6ffeb

Add advanced numerical embeddings with periodic, PLE, and combined layers
Co-authored-by: piotr.laczkowski <piotr.laczkowski@gmail.com>
1 parent 933d411 commit cb6ffeb

11 files changed

Lines changed: 2739 additions & 43 deletions

docs/advanced/numerical-embeddings.md

Lines changed: 255 additions & 15 deletions
@@ -171,6 +171,63 @@ The `NumericalEmbedding` layer processes each numerical feature through two para
171171
- Adaptively weights continuous vs. discrete representations
172172
- Learns optimal combination per feature and dimension
173173

174+
### Periodic Embeddings (`PeriodicEmbedding`)
175+
176+
The `PeriodicEmbedding` layer uses trigonometric functions to capture cyclical patterns:
177+
178+
1. **Frequency Learning**:
179+
- Learns optimal frequencies for each feature
180+
- Supports multiple initialization strategies (uniform, log-uniform, constant)
181+
- Frequencies are constrained to be positive
182+
183+
2. **Periodic Transformation**:
184+
- Applies sin/cos transformations: `sin(freq * x)` and `cos(freq * x)`
185+
- Captures cyclical patterns with smooth, differentiable representations
186+
- Particularly effective for features with natural periodicity
187+
188+
3. **Post-Processing**:
189+
- Optional MLP for further feature transformation
190+
- Residual connections for stability
191+
- Batch normalization and dropout for regularization
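
The sin/cos transformation described above can be sketched in plain Python. This is a minimal illustration, not KDP's `PeriodicEmbedding` implementation: the function name `periodic_features` is hypothetical, and using softplus to keep frequencies positive is an assumption made for the example (the source only states that frequencies are constrained to be positive).

```python
import math


def periodic_features(x, raw_freqs):
    """Map a scalar x to [sin(f*x)..., cos(f*x)...], one pair per frequency f.

    raw_freqs are unconstrained parameters; softplus keeps the effective
    frequencies positive (an assumed choice of constraint for this sketch).
    """
    freqs = [math.log1p(math.exp(r)) for r in raw_freqs]  # softplus > 0
    sines = [math.sin(f * x) for f in freqs]
    cosines = [math.cos(f * x) for f in freqs]
    return sines + cosines


features = periodic_features(0.5, raw_freqs=[-1.0, 0.0, 1.0, 2.0])
print(len(features))  # 8: two values (sin and cos) per frequency
```

In a trainable layer the raw frequencies would be learned weights, and the resulting vector would feed the optional MLP described in step 3.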
192+
193+
### PLE Embeddings (`PLEEmbedding`)
194+
195+
The `PLEEmbedding` layer (Parameterized Linear Expansion) provides learnable piecewise linear transformations:
196+
197+
1. **Segment Learning**:
198+
- Learns optimal segment boundaries for each feature
199+
- Supports uniform and quantile-based initialization
200+
- Each segment has learnable slope and intercept
201+
202+
2. **Piecewise Linear Transformation**:
203+
- Applies different linear transformations to different input ranges
204+
- Captures complex non-linear patterns through piecewise approximation
205+
- Supports various activation functions (ReLU, Sigmoid, Tanh)
206+
207+
3. **Flexible Architecture**:
208+
- Configurable number of segments for precision vs. efficiency trade-off
209+
- Optional MLP and residual connections
210+
- Batch normalization and dropout for regularization
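
A piecewise linear transformation of the kind described above can be sketched as follows. This is an illustration under stated assumptions rather than the `PLEEmbedding` code: the function name `piecewise_linear` and the hand-picked boundaries, slopes, and intercepts are hypothetical, and in a real layer these values would be trainable parameters.

```python
from bisect import bisect_right


def piecewise_linear(x, boundaries, slopes, intercepts):
    """Apply the linear map of the segment that contains x.

    boundaries: sorted interior breakpoints (len = num_segments - 1)
    slopes, intercepts: one pair per segment (len = num_segments)
    """
    segment = bisect_right(boundaries, x)  # index of the segment holding x
    return slopes[segment] * x + intercepts[segment]


# Three segments split at 0.3 and 0.7 (illustrative values, chosen so the
# function is continuous at both breakpoints).
boundaries = [0.3, 0.7]
slopes = [1.0, 2.0, 0.5]
intercepts = [0.0, -0.3, 0.75]

for value in (0.1, 0.5, 0.9):
    print(value, piecewise_linear(value, boundaries, slopes, intercepts))
```

More segments give a finer approximation at the cost of more parameters, which is the precision vs. efficiency trade-off noted in step 3.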
211+
212+
### Advanced Combined Embeddings (`AdvancedNumericalEmbedding`)
213+
214+
The `AdvancedNumericalEmbedding` layer combines multiple embedding approaches:
215+
216+
1. **Multi-Modal Processing**:
217+
- Supports any combination of periodic, PLE, and dual-branch embeddings
218+
- Learnable gates to combine different embedding types
219+
- Adaptive weighting per feature and dimension
220+
221+
2. **Flexible Configuration**:
222+
- Choose from `['periodic', 'ple', 'dual_branch']` embedding types
223+
- Configure each embedding type independently
224+
- Enable/disable gating mechanism
225+
226+
3. **Optimal Performance**:
227+
- Designed to narrow the gap between MLPs/Transformers and tree-based baselines
228+
- Particularly effective on tabular tasks
229+
- Maintains interpretability while improving performance
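
The gating idea can be illustrated with a small sketch. It assumes a softmax over one logit per embedding type, which is a simplification: the text above describes weighting per feature and dimension, and the source does not say how `AdvancedNumericalEmbedding` normalizes its gates. The name `gated_combine` is hypothetical.

```python
import math


def gated_combine(embeddings, gate_logits):
    """Blend same-sized embedding vectors with softmax-normalized gates."""
    peak = max(gate_logits)
    exps = [math.exp(g - peak) for g in gate_logits]  # numerically stable softmax
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(embeddings[0])
    return [
        sum(w * emb[i] for w, emb in zip(weights, embeddings))
        for i in range(dim)
    ]


periodic_out = [1.0, 0.0]
ple_out = [0.0, 1.0]
# Equal logits give equal weights, so the result is the element-wise mean.
print(gated_combine([periodic_out, ple_out], gate_logits=[0.0, 0.0]))
```

Because the gate logits would be learned, the model can shift weight toward whichever embedding type is more useful for a given feature.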
230+
174231
```
175232
Input value
176233
┌────────┐ ┌────────┐
@@ -220,6 +277,35 @@ This approach is ideal for:
220277
| `numerical_dropout_rate` | float | 0.1 | Dropout rate for regularization |
221278
| `numerical_use_batch_norm` | bool | True | Apply batch normalization |
222279

280+
### Periodic Embeddings
281+
282+
| Parameter | Type | Default | Description |
283+
|-----------|------|---------|-------------|
284+
| `use_periodic_embedding` | bool | False | Enable periodic embeddings |
285+
| `num_frequencies` | int | 4 | Number of frequency components |
286+
| `frequency_init` | str | "log_uniform" | Frequency initialization method |
287+
| `min_frequency` | float | 1e-4 | Minimum frequency for initialization |
288+
| `max_frequency` | float | 1e2 | Maximum frequency for initialization |
289+
| `use_residual` | bool | True | Use residual connections |
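
To make the `"log_uniform"` option concrete, the sketch below draws frequencies uniformly in log space between `min_frequency` and `max_frequency`. Whether KDP samples randomly or spaces the frequencies deterministically is not stated here, so random sampling is an assumption, and `log_uniform_frequencies` is a hypothetical helper name.

```python
import math
import random


def log_uniform_frequencies(num_frequencies, min_frequency, max_frequency, seed=0):
    """Sample frequencies whose logarithms are uniformly distributed."""
    rng = random.Random(seed)
    low, high = math.log(min_frequency), math.log(max_frequency)
    return [math.exp(rng.uniform(low, high)) for _ in range(num_frequencies)]


# Defaults from the table above: 4 frequencies between 1e-4 and 1e2.
freqs = log_uniform_frequencies(4, min_frequency=1e-4, max_frequency=1e2)
print(sorted(freqs))
```

Sampling in log space gives each order of magnitude between the bounds an equal chance, whereas plain uniform sampling over `[1e-4, 1e2]` would place almost every frequency near the upper end.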
290+
291+
### PLE Embeddings
292+
293+
| Parameter | Type | Default | Description |
294+
|-----------|------|---------|-------------|
295+
| `use_ple_embedding` | bool | False | Enable PLE embeddings |
296+
| `num_segments` | int | 8 | Number of linear segments |
297+
| `segment_init` | str | "uniform" | Segment initialization method |
298+
| `ple_activation` | str | "relu" | Activation function for PLE |
299+
| `use_residual` | bool | True | Use residual connections |
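
As a rough illustration of what a quantile-based `segment_init` could compute, the sketch below derives interior segment boundaries from sample data using the standard library. This is an assumed interpretation, not KDP's code: `quantile_boundaries` is a hypothetical helper, and the source only lists `"uniform"` as the default and mentions quantile-based initialization as supported.

```python
from statistics import quantiles


def quantile_boundaries(values, num_segments):
    """Return num_segments - 1 interior cut points at equal-frequency quantiles."""
    return quantiles(values, n=num_segments, method="inclusive")


# Skewed toy data: most values are small, a few are large.
data = [1, 1, 2, 2, 3, 3, 4, 5, 8, 13, 21, 34, 55, 89, 144, 233]
print(quantile_boundaries(data, num_segments=4))
```

With skewed data, quantile boundaries cluster where most observations lie, while uniform boundaries would leave several segments nearly empty.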
300+
301+
### Advanced Combined Embeddings
302+
303+
| Parameter | Type | Default | Description |
304+
|-----------|------|---------|-------------|
305+
| `use_advanced_combined_embedding` | bool | False | Enable combined embeddings |
306+
| `embedding_types` | list | ["dual_branch"] | List of embedding types to use |
307+
| `use_gating` | bool | True | Use learnable gates to combine embeddings |
308+
223309
### Global Embeddings
224310

225311
| Parameter | Type | Default | Description |
@@ -267,24 +353,23 @@ features_specs = {
267353
name="income",
268354
feature_type=FeatureType.FLOAT_RESCALED,
269355
use_embedding=True,
356+
embedding_type="periodic", # Use periodic embedding for income
270357
embedding_dim=8,
271-
num_bins=15,
272-
init_min=0,
273-
init_max=1000000
358+
num_frequencies=4
274359
),
275360
"debt_ratio": NumericalFeature(
276361
name="debt_ratio",
277362
feature_type=FeatureType.FLOAT_NORMALIZED,
278363
use_embedding=True,
364+
embedding_type="ple", # Use PLE for debt ratio
279365
embedding_dim=4,
280-
num_bins=8,
281-
init_min=0,
282-
init_max=1 # Ratio typically between 0-1
366+
num_segments=8
283367
),
284368
"credit_score": NumericalFeature(
285369
name="credit_score",
286370
feature_type=FeatureType.FLOAT_NORMALIZED,
287371
use_embedding=True,
372+
embedding_type="dual_branch", # Traditional dual-branch
288373
embedding_dim=6,
289374
num_bins=10,
290375
init_min=300,
@@ -294,21 +379,160 @@ features_specs = {
294379
name="payment_history",
295380
feature_type=FeatureType.FLOAT_NORMALIZED,
296381
use_embedding=True,
382+
embedding_type="combined", # Combined approach
297383
embedding_dim=8,
298-
num_bins=5,
299-
init_min=0,
300-
init_max=1 # Simplified score between 0-1
384+
num_frequencies=4,
385+
num_segments=8
301386
)
302387
}
303388

304-
# Create preprocessing model
389+
# Create preprocessing model with advanced embeddings
305390
preprocessor = PreprocessingModel(
306391
path_data="data/financial_data.csv",
307392
features_specs=features_specs,
308-
use_numerical_embedding=True,
309-
numerical_mlp_hidden_units=16,
310-
numerical_dropout_rate=0.2, # Higher dropout for financial data
311-
numerical_use_batch_norm=True
393+
use_advanced_numerical_embedding=True,
394+
use_periodic_embedding=True,
395+
use_ple_embedding=True,
396+
use_advanced_combined_embedding=True,
397+
embedding_dim=8,
398+
num_frequencies=4,
399+
num_segments=8,
400+
dropout_rate=0.2, # Higher dropout for financial data
401+
use_batch_norm=True
402+
)
403+
```
404+
405+
### Healthcare Patient Analysis with Periodic Embeddings
406+
407+
```python
408+
from kdp import PreprocessingModel
409+
from kdp.features import NumericalFeature
410+
from kdp.enums import FeatureType
411+
412+
# Define patient features, mixing periodic, PLE, and combined embeddings
413+
features_specs = {
414+
"age": NumericalFeature(
415+
name="age",
416+
feature_type=FeatureType.FLOAT_NORMALIZED,
417+
use_embedding=True,
418+
embedding_type="periodic",
419+
embedding_dim=8,
420+
num_frequencies=6, # More frequencies for age patterns
421+
kwargs={
422+
"frequency_init": "constant",
423+
"min_frequency": 1e-3,
424+
"max_frequency": 1e2
425+
}
426+
),
427+
"bmi": NumericalFeature(
428+
name="bmi",
429+
feature_type=FeatureType.FLOAT_NORMALIZED,
430+
use_embedding=True,
431+
embedding_type="ple",
432+
embedding_dim=6,
433+
num_segments=12, # More segments for BMI precision
434+
kwargs={
435+
"segment_init": "uniform",
436+
"ple_activation": "relu"
437+
}
438+
),
439+
"blood_pressure": NumericalFeature(
440+
name="blood_pressure",
441+
feature_type=FeatureType.FLOAT_NORMALIZED,
442+
use_embedding=True,
443+
embedding_type="combined",
444+
embedding_dim=10,
445+
num_frequencies=4,
446+
num_segments=8,
447+
kwargs={
448+
"embedding_types": ["periodic", "ple"],
449+
"use_gating": True
450+
}
451+
)
452+
}
453+
454+
# Create preprocessing model
455+
preprocessor = PreprocessingModel(
456+
path_data="data/patient_data.csv",
457+
features_specs=features_specs,
458+
use_advanced_numerical_embedding=True,
459+
use_periodic_embedding=True,
460+
use_ple_embedding=True,
461+
use_advanced_combined_embedding=True,
462+
embedding_dim=8,
463+
num_frequencies=6,
464+
num_segments=12,
465+
frequency_init="constant",
466+
segment_init="uniform",
467+
ple_activation="relu",
468+
use_gating=True
469+
)
470+
```
471+
472+
### Time Series Forecasting with PLE Embeddings
473+
474+
```python
475+
from kdp import PreprocessingModel
476+
from kdp.features import NumericalFeature
477+
from kdp.enums import FeatureType
478+
479+
# Define time series features: periodic for seasonality, PLE for trends, combined for complex patterns
480+
features_specs = {
481+
"temperature": NumericalFeature(
482+
name="temperature",
483+
feature_type=FeatureType.FLOAT_NORMALIZED,
484+
use_embedding=True,
485+
embedding_type="periodic", # Periodic for seasonal patterns
486+
embedding_dim=12,
487+
num_frequencies=8,
488+
kwargs={
489+
"frequency_init": "log_uniform",
490+
"min_frequency": 1e-4,
491+
"max_frequency": 1e3
492+
}
493+
),
494+
"humidity": NumericalFeature(
495+
name="humidity",
496+
feature_type=FeatureType.FLOAT_NORMALIZED,
497+
use_embedding=True,
498+
embedding_type="ple", # PLE for humidity trends
499+
embedding_dim=8,
500+
num_segments=16,
501+
kwargs={
502+
"segment_init": "quantile",
503+
"ple_activation": "sigmoid"
504+
}
505+
),
506+
"pressure": NumericalFeature(
507+
name="pressure",
508+
feature_type=FeatureType.FLOAT_NORMALIZED,
509+
use_embedding=True,
510+
embedding_type="combined", # Combined for complex patterns
511+
embedding_dim=10,
512+
num_frequencies=6,
513+
num_segments=12,
514+
kwargs={
515+
"embedding_types": ["periodic", "ple", "dual_branch"],
516+
"use_gating": True
517+
}
518+
)
519+
}
520+
521+
# Create preprocessing model
522+
preprocessor = PreprocessingModel(
523+
path_data="data/weather_data.csv",
524+
features_specs=features_specs,
525+
use_advanced_numerical_embedding=True,
526+
use_periodic_embedding=True,
527+
use_ple_embedding=True,
528+
use_advanced_combined_embedding=True,
529+
embedding_dim=10,
530+
num_frequencies=8,
531+
num_segments=16,
532+
frequency_init="log_uniform",
533+
segment_init="quantile",
534+
ple_activation="sigmoid",
535+
use_gating=True
312536
)
313537
```
314538

@@ -349,27 +573,43 @@ preprocessor = PreprocessingModel(
349573
1. **Choose the Right Embedding Type**
350574
- Use individual embeddings for interpretability and precise control
351575
- Use global embeddings for efficiency with many numerical features
576+
- Use periodic embeddings for features with cyclical patterns (time, angles, seasons)
577+
- Use PLE embeddings for features with complex non-linear relationships
578+
- Use combined embeddings for maximum performance on challenging datasets
352579

353580
2. **Distribution-Aware Initialization**
354581
- Set `init_min` and `init_max` based on your data's actual distribution
355582
- Use domain knowledge to set meaningful boundary points
356583
- Initialize closer to anticipated feature range for faster convergence
584+
- For periodic embeddings, use log-uniform initialization to spread frequencies evenly across orders of magnitude
585+
- For PLE embeddings, use quantile-based initialization for data-driven segment boundaries
357586

358587
3. **Dimensionality Guidelines**
359588
- Start with `embedding_dim` = 4-8 for simple features
360589
- Use 8-16 for complex features with non-linear patterns
361590
- For global embeddings, scale with the number of features (16-64)
591+
- For periodic embeddings, use 4-8 frequencies for most features
592+
- For PLE embeddings, use 8-16 segments for smooth approximations
362593

363594
4. **Performance Tuning**
364595
- Increase `num_bins` for more granular discrete representations
365596
- Adjust `mlp_hidden_units` to 2-4x the embedding dimension
366597
- Use batch normalization for faster, more stable training
367598
- Adjust dropout based on dataset size (higher for small datasets)
599+
- For periodic embeddings, experiment with different frequency ranges
600+
- For PLE embeddings, try different activation functions (relu, sigmoid, tanh)
601+
602+
5. **Advanced Embedding Strategies**
603+
- **Periodic Embeddings**: Best for time-based features, angles, cyclical patterns
604+
- **PLE Embeddings**: Best for features with piecewise linear relationships
605+
- **Combined Embeddings**: Best for maximum performance, especially on tabular tasks
606+
- **Mixed Strategies**: Use different embedding types for different features based on their characteristics
368607

369-
5. **Combine with Other KDP Features**
608+
6. **Combine with Other KDP Features**
370609
- Pair with distribution-aware encoding for optimal numerical handling
371610
- Use with tabular attention to learn cross-feature interactions
372611
- Combine with feature selection for automatic dimensionality reduction
612+
- Use with transformer blocks for advanced feature interactions
373613

374614
## 🔗 Related Topics
375615
