Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Change Log
All notable changes to this project will be documented in this file.

## 2.1.3 - 2026-02
### Runner
- Separated the writer from the runner; the output schema is sorted to align with the written table
- Added a debug run option that returns a dataframe without writing anything

## 2.1.2 - 2026-01
### Maker
- Feature type normalization to Double and Long
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rialto"
version = "2.1.2"
version = "2.1.3"
description = "Rialto is a framework for building and deploying machine learning features in a scalable and reusable way. It provides a set of tools that make it easy to define and deploy features and models, and it provides a way to orchestrate the execution of these features and models."
authors = [
{ name = "Marek Dobransky", email = "marekdobr@gmail.com" },
Expand Down
6 changes: 3 additions & 3 deletions rialto/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def normalize_types(df: DataFrame) -> DataFrame:

return df.select(
[
F.col(f.name).cast(DoubleType())
F.when(F.col(f.name).isNotNull(), F.col(f.name).cast(DoubleType())).otherwise(F.lit(None)).alias(f.name)
if isinstance(f.dataType, float_types)
else F.col(f.name).cast(LongType())
else F.when(F.col(f.name).isNotNull(), F.col(f.name).cast(LongType())).otherwise(F.lit(None)).alias(f.name)
if isinstance(f.dataType, int_types)
else F.col(f.name)
else F.when(F.col(f.name).isNotNull(), F.col(f.name)).otherwise(F.lit(None)).alias(f.name)
for f in df.schema.fields
]
)
53 changes: 23 additions & 30 deletions rialto/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from rialto.runner.reporting.tracker import Tracker
from rialto.runner.table import Table
from rialto.runner.transformation import Transformation
from rialto.runner.writer import Writer


class Runner:
Expand All @@ -52,7 +53,7 @@ def __init__(
self.rerun = rerun
self.skip_dependencies = skip_dependencies
self.op = op
self.merge_schema = merge_schema
self.writer = Writer(spark, merge_schema=merge_schema)
self.tracker = Tracker(
mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark
)
Expand Down Expand Up @@ -96,34 +97,6 @@ def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineC

return df

def _create_schema(self, table: Table):
"""
Create schema if it doesn't exist

:param schema_path: path to schema
"""
self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {table.get_schema_path()}")

def _write(self, df: DataFrame, info_date: date, table: Table) -> None:
"""
Write dataframe to storage

:param df: dataframe to write
:param info_date: date to partition
:param table: path to write to
:return: None
"""
self._create_schema(table)

df = df.withColumn(table.partition, F.lit(info_date))
if self.merge_schema is True:
df.write.partitionBy(table.partition).mode("overwrite").option("mergeSchema", "true").saveAsTable(
table.get_table_path()
)
else:
df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path())
logger.info(f"Results writen to {table.get_table_path()}")

def _check_written(self, info_date: date, table: Table) -> int:
"""
Check if there are records written for given date
Expand Down Expand Up @@ -238,7 +211,7 @@ def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: dat

feature_group = utils.load_module(pipeline.module)
df = self._execute(feature_group, run_date, pipeline)
self._write(df, info_date, target)
self.writer.write(df, info_date, target)
records = self._check_written(info_date, target)
logger.info(f"Generated {records} records")
if records == 0:
Expand Down Expand Up @@ -331,3 +304,23 @@ def __call__(self):
print(self.tracker.records)
self.tracker.report_by_mail()
logger.info("Execution finished")

def debug(self) -> DataFrame:
    """
    Debug mode - run only the first op for one date and return the resulting dataframe.

    Nothing is written to storage; the dataframe is passed through the writer's
    processing step (partition column added, schema aligned) exactly as it would
    be just before a real write.

    :return: processed dataframe for the first selected run date, or None when
        there are no dates to run
    :raises ValueError: when ``self.op`` is set but matches no configured pipeline
    """
    logger.info("Running in debug mode")
    if self.op:
        pipeline = next((p for p in self.config.pipelines if p.name == self.op), None)
        if pipeline is None:
            # Fail with a clear message instead of an opaque IndexError on [0].
            raise ValueError(f"Pipeline '{self.op}' not found in configuration")
    else:
        pipeline = self.config.pipelines[0]

    target = Table(
        schema_path=pipeline.target.target_schema,
        class_name=pipeline.module.python_class,
        partition=pipeline.target.target_partition_column,
    )
    selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target)
    if not selected_run_dates:
        logger.info("No dates to run in debug mode")
        return None

    df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline)
    # NOTE(review): reaches into Writer's private _process; consider exposing a
    # public preprocessing method on Writer.
    return self.writer._process(df, selected_info_dates[0], target)
96 changes: 96 additions & 0 deletions rialto/runner/writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2022 ABSA Group Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["Writer"]

from datetime import date
from typing import List

import pyspark.sql.functions as F
from loguru import logger
from pyspark.sql import DataFrame, SparkSession

from rialto.runner.table import Table


class Writer:
    """Supporting class for runner: persists results to partitioned Spark tables."""

    def __init__(self, spark: SparkSession, merge_schema: bool = False):
        """
        :param spark: active Spark session used for DDL statements and writes
        :param merge_schema: when True, writes set the ``mergeSchema`` option so
            new columns can be added to an existing table's schema
        """
        self.spark = spark
        self.merge_schema = merge_schema

    def _create_schema(self, table: Table) -> None:
        """
        Create schema if it doesn't exist

        :param table: table whose schema path is ensured to exist
        """
        self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {table.get_schema_path()}")

    def _get_existing_columns(self, table: Table):
        """
        Get existing column names of the target table, if it exists

        :param table: table to check for
        :return: list of existing column names, or None when the table cannot
            be read (e.g. it does not exist yet)
        """
        try:
            return self.spark.table(table.get_table_path()).columns
        except Exception as e:
            # Best effort: a missing table is expected on first write, so any
            # read failure is logged and treated as "no existing schema".
            logger.warning(f"Could not get existing schema for {table.get_table_path()}: {e}")
            return None

    def _align_schema(self, df: DataFrame, existing_columns: List) -> DataFrame:
        """
        Reorder dataframe columns to match the existing table's column order

        Columns already present in the table come first, in the table's order,
        followed by any new columns the dataframe introduces.

        :param df: dataframe to align
        :param existing_columns: column order of the existing table, or None
        :return: dataframe with aligned column order (unchanged when the table
            does not exist yet)
        """
        if existing_columns is None:
            return df
        known = [F.col(c) for c in existing_columns if c in df.columns]
        added = [F.col(c) for c in df.columns if c not in existing_columns]
        return df.select(*known, *added)

    def _process(self, df: DataFrame, info_date: date, table: Table) -> DataFrame:
        """
        Add the partition column and align column order to the target table

        :param df: dataframe to process
        :param info_date: value written into the partition column
        :param table: target table providing partition name and existing schema
        :return: dataframe ready to be written
        """
        df = df.withColumn(table.partition, F.lit(info_date))
        return self._align_schema(df, self._get_existing_columns(table))

    def write(self, df: DataFrame, info_date: date, table: Table) -> None:
        """
        Write dataframe to storage

        :param df: dataframe to write
        :param info_date: date to partition by
        :param table: table to write to
        :return: None
        """
        self._create_schema(table)
        df = self._process(df, info_date, table)

        # Build the writer once; only the mergeSchema option differs by config.
        writer = df.write.partitionBy(table.partition).mode("overwrite")
        if self.merge_schema:
            writer = writer.option("mergeSchema", "true")
        writer.saveAsTable(table.get_table_path())
        logger.info(f"Results written to {table.get_table_path()}")
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.