Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Change Log
All notable changes to this project will be documented in this file.

## 2.1.2 - 2026-01
### Maker
- Feature type normalization to Double and Long

## 2.1.1 - 2026-01
- Replaces 2.1.0 release

## 2.1.0 - 2025-10
### General
- Updated python version to 3.12 and pyspark to 4.0
- Migrated from poetry to UV
### Runner
- Added merge_schema manual override option
### Maker
- Added another feature decorator _@template_ to support feature-to-text conversion

## 2.0.11 - 2025-08-12
### Loader
Expand Down
36 changes: 32 additions & 4 deletions rialto/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@

import inspect
import os
from typing import Any, List
from typing import Any

import pyspark.sql.functions as F
import yaml
from pyspark.sql import DataFrame
from pyspark.sql.types import FloatType
from pyspark.sql.types import (
ByteType,
DecimalType,
DoubleType,
FloatType,
IntegerType,
LongType,
ShortType,
)

from rialto.common.env_yaml import EnvLoader

Expand Down Expand Up @@ -62,12 +70,32 @@ def get_caller_module() -> Any:
0th entry is this function
1st entry is the function which needs to know who called it
2nd entry is the calling function

Therefore, we'll return a module which contains the function at the 2nd place on the stack.

:return: Python Module containing the calling function.
"""

stack = inspect.stack()
last_stack = stack[2]
return inspect.getmodule(last_stack[0])


def normalize_types(df: DataFrame) -> DataFrame:
    """
    Normalize numeric data types in the DataFrame.

    Casts all float and decimal columns to DoubleType and all
    narrow integer columns (byte/short/int) to LongType, so every
    numeric feature ends up in one of two canonical wide types.
    Columns of any other type are passed through unchanged.

    :param df: input DataFrame
    :return: DataFrame with the same columns, numeric types widened
    """
    float_types = (FloatType, DecimalType)
    int_types = (ByteType, ShortType, IntegerType)

    def _normalized(field):
        # Widen narrow/decimal numerics; leave everything else as-is.
        if isinstance(field.dataType, float_types):
            return F.col(field.name).cast(DoubleType())
        if isinstance(field.dataType, int_types):
            return F.col(field.name).cast(LongType())
        return F.col(field.name)

    return df.select([_normalized(f) for f in df.schema.fields])
13 changes: 9 additions & 4 deletions rialto/maker/feature_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from loguru import logger
from pyspark.sql import DataFrame

from rialto.common.utils import normalize_types
from rialto.maker.containers import FeatureFunction, FeatureHolder


Expand Down Expand Up @@ -140,7 +141,8 @@ def _make_sequential(self, keep_preexisting: bool) -> DataFrame:
if not keep_preexisting:
logger.info("Dropping non-selected columns")
self.data_frame = self.data_frame.select(*self.key, *feature_names)
return self._filter_null_keys(self.data_frame)
df = self._filter_null_keys(self.data_frame)
return normalize_types(df)

def _make_aggregated(self) -> DataFrame:
"""
Expand All @@ -154,7 +156,8 @@ def _make_aggregated(self) -> DataFrame:
aggregates.append(feature_function.callable().alias(feature_function.get_feature_name()))

self.data_frame = self.data_frame.groupBy(self.key).agg(*aggregates)
return self._filter_null_keys(self.data_frame)
df = self._filter_null_keys(self.data_frame)
return normalize_types(df)

def make(
self,
Expand Down Expand Up @@ -237,7 +240,8 @@ def make_single_feature(
self.make_date = make_date
feature_functions = self._register_module(features_module)
feature = self._find_feature(name, feature_functions)
return df.withColumn(feature.get_feature_name(), feature.callable()).select(feature.get_feature_name())
df = df.withColumn(feature.get_feature_name(), feature.callable()).select(feature.get_feature_name())
return normalize_types(df)

def make_single_agg_feature(
self,
Expand All @@ -261,7 +265,8 @@ def make_single_agg_feature(
self.make_date = make_date
feature_functions = self._register_module(features_module)
feature = self._find_feature(name, feature_functions)
return df.groupBy(key).agg(feature.callable().alias(feature.get_feature_name()))
df = df.groupBy(key).agg(feature.callable().alias(feature.get_feature_name()))
return normalize_types(df)


FeatureMaker = _FeatureMaker()
24 changes: 23 additions & 1 deletion tests/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import pyspark.sql.functions as F
import pytest
from numpy import dtype
from pyspark.sql.types import DoubleType, LongType, StringType

from rialto.common.utils import cast_decimals_to_floats
from rialto.common.utils import cast_decimals_to_floats, normalize_types


@pytest.fixture
Expand All @@ -29,6 +30,16 @@ def sample_df(spark):
return df.select("a", "b", "c", F.col("d").cast("decimal"), F.col("e").cast("decimal(18,5)"))


@pytest.fixture
def sample_df2(spark):
    """Frame mixing long, float, string, decimal, double and int columns."""
    raw = spark.createDataFrame(
        [(1, 2.33, "str", 4.55, 5.66, 4), (1, 2.33, "str", 4.55, 5.66, 5), (1, 2.33, "str", 4.55, 5.66, 6)],
        schema="a long, b float, c string, d float, e float, f int",
    )
    # Recast d/e so the fixture covers decimal and double inputs as well.
    return raw.select("a", "b", "c", F.col("d").cast("decimal"), F.col("e").cast("double"), "f")


def test_cast_decimals_to_floats(sample_df):
df_fixed = cast_decimals_to_floats(sample_df)

Expand All @@ -42,3 +53,14 @@ def test_cast_decimals_to_floats_topandas_works(sample_df):

assert df_pd.dtypes.iloc[3] == dtype("float32")
assert df_pd.dtypes.iloc[4] == dtype("float32")


def test_normalize_types(sample_df2):
    """Numeric columns are widened to Long/Double; strings stay untouched."""
    normalized = normalize_types(sample_df2)

    expected_types = {
        "a": LongType,
        "b": DoubleType,
        "c": StringType,
        "d": DoubleType,
        "e": DoubleType,
        "f": LongType,
    }
    for column, expected in expected_types.items():
        assert isinstance(normalized.schema[column].dataType, expected)