diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 10d73d7..1bb69fe 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -11,6 +11,7 @@ from .callable import * from .context import * from .enums import Enum +from .flow_model import * from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..54f4a9d 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -12,10 +12,11 @@ """ import abc +import inspect import logging from functools import lru_cache, wraps from inspect import Signature, isclass, signature -from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override @@ -27,8 +28,12 @@ ResultBase, ResultType, ) +from .local_persistence import create_ccflow_model from .validators import str_to_log_level +if TYPE_CHECKING: + from .flow_model import FlowAPI + __all__ = ( "GraphDepType", "GraphDepList", @@ -60,6 +65,25 @@ def _cached_signature(fn): return signature(fn) +def _callable_qualname(fn: Callable[..., Any]) -> str: + return getattr(fn, "__qualname__", type(fn).__qualname__) + + +def _declared_type_matches(actual: Any, expected: Any) -> bool: + if isinstance(expected, TypeVar): + return True + if get_origin(expected) is Union: + expected_args = tuple(arg for arg in get_args(expected) if isinstance(arg, type)) + if not expected_args: + return False + if get_origin(actual) is Union: + actual_args = tuple(arg for arg in get_args(actual) if isinstance(arg, type)) + return set(actual_args) == set(expected_args) + return isinstance(actual, type) and any(issubclass(actual, arg) for arg 
in expected_args) + + return isinstance(actual, type) and isinstance(expected, type) and issubclass(actual, expected) + + class MetaData(BaseModel): """Class to represent metadata for all callable models""" @@ -126,7 +150,7 @@ def _check_result_type(cls, result_type): @model_validator(mode="after") def _check_signature(self): sig_call = _cached_signature(self.__class__.__call__) - if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context") + if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: raise ValueError("__call__ method must take a single argument, named 'context'") sig_deps = _cached_signature(self.__class__.__deps__) @@ -268,14 +292,31 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs): if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( - get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) - ): - raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not ( - get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type)) + + # Check if this is an auto_context decorated method + has_auto_context = hasattr(fn, "__auto_context__") + if has_auto_context: + method_context_type = fn.__auto_context__ + else: + method_context_type = model.context_type + + # Validate context type (skip for auto contexts which are always valid ContextBase subclasses) + if not has_auto_context: + if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and 
not ( + get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) + ): + raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") + + # Validate result type - use __result_type__ for auto contexts if available + if has_auto_context and hasattr(fn, "__result_type__"): + method_result_type = fn.__result_type__ + else: + method_result_type = model.result_type + if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not ( + get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type)) ): - raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase") + raise TypeError(f"Result type {method_result_type} must be a subclass of ResultBase") + if self._deps and fn.__name__ != "__deps__": raise ValueError("Can only apply Flow.deps decorator to __deps__") if context is Signature.empty: @@ -285,18 +326,18 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = context = kwargs else: raise TypeError( - f"{fn.__name__}() missing 1 required positional argument: 'context' of type {model.context_type}, or kwargs to construct it" + f"{fn.__name__}() missing 1 required positional argument: 'context' of type {method_context_type}, or kwargs to construct it" ) elif kwargs: # Kwargs passed in as well as context. Not allowed raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'") # Type coercion on input. 
We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message - if not isinstance(context, model.context_type): - if get_origin(model.context_type) is Union and type(None) in get_args(model.context_type): - model_context_type = [t for t in get_args(model.context_type) if t is not type(None)][0] + if not isinstance(context, method_context_type): + if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type): + coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0] else: - model_context_type = model.context_type - context = model_context_type.model_validate(context) + coerce_context_type = method_context_type + context = coerce_context_type.model_validate(context) if fn != getattr(model.__class__, fn.__name__).__wrapped__: # This happens when super().__call__ is used when implementing a CallableModel that derives from another one. @@ -310,9 +351,17 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = return result wrap = wraps(fn)(wrapper) - wrap.get_evaluator = self.get_evaluator - wrap.get_options = self.get_options - wrap.get_evaluation_context = get_evaluation_context + wrap_any = cast(Any, wrap) + wrap_any.get_evaluator = self.get_evaluator + wrap_any.get_options = self.get_options + wrap_any.get_evaluation_context = get_evaluation_context + + # Preserve auto context attributes for introspection + if hasattr(fn, "__auto_context__"): + wrap_any.__auto_context__ = fn.__auto_context__ + if hasattr(fn, "__result_type__"): + wrap_any.__result_type__ = fn.__result_type__ + return wrap @@ -391,7 +440,59 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): - """Decorator for methods on callable models""" + """Decorator for methods on callable models. + + Args: + auto_context: Controls automatic context class generation from the function + signature. 
Accepts three types of values: + - False (default): No auto-generation, use traditional context parameter + - True: Auto-generate context class with no parent + - ContextBase subclass: Auto-generate context class inheriting from this parent + **kwargs: Additional FlowOptions parameters (log_level, verbose, validate_result, + cacheable, evaluator, volatile). + + Basic Example: + class MyModel(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> MyResult: + return MyResult(value=context.x) + + Auto Context Example: + class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> MyResult: + return MyResult(value=f"{x}-{y}") + + model = MyModel() + model(x=42) # Call with kwargs directly + + With Parent Context: + class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> MyResult: + return MyResult(value=date.day + extra) + + # The generated context inherits from DateContext, so it's compatible + # with infrastructure expecting DateContext instances. 
+ + """ + # Extract auto_context option (not part of FlowOptions) + # Can be: False, True, or a ContextBase subclass + auto_context = kwargs.pop("auto_context", False) + + # Determine if auto_context is enabled and extract parent class if provided + if auto_context is False: + auto_context_enabled = False + context_parent = None + elif auto_context is True: + auto_context_enabled = True + context_parent = None + elif isclass(auto_context) and issubclass(auto_context, ContextBase): + auto_context_enabled = True + context_parent = auto_context + else: + raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}") + if len(args) == 1 and callable(args[0]): # No arguments to decorator, this is the decorator fn = args[0] @@ -400,6 +501,14 @@ def call(*args, **kwargs): else: # Arguments to decorator, this is just returning the decorator # Note that the code below is executed only once + if auto_context_enabled: + # Return a decorator that first applies auto_context, then FlowOptions + def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + wrapped = _apply_auto_context(fn, parent=context_parent) + # FlowOptions.__call__ already applies wraps, so we just return its result + return FlowOptions(**kwargs)(wrapped) + + return auto_context_decorator return FlowOptions(**kwargs) @staticmethod @@ -417,6 +526,62 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def model(*args, **kwargs): + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. 
+ + Args: + context_args: List of parameter names that come from context (for unpacked mode) + context_type: Explicit ContextBase subclass to use with context_args mode + cacheable: Enable caching of results (default: False) + volatile: Mark as volatile (default: False) + log_level: Logging verbosity (default: logging.DEBUG) + validate_result: Validate return type (default: True) + verbose: Verbose logging output (default: True) + evaluator: Custom evaluator (default: None) + + Two Context Modes: + + Mode 1 - Explicit context parameter: + Function has a 'context' parameter annotated with a ContextBase subclass. + + @Flow.model + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, context.start_date, context.end_date)) + + Mode 2 - Unpacked context_args: + Context fields are unpacked into function parameters. + + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, start_date, end_date)) + + Dependencies: + Any non-context parameter can be bound either to a literal value or + to another CallableModel. When a CallableModel is supplied, the + generated model treats it as an upstream dependency and resolves it + with the current context before calling the underlying function. 
+ + Usage: + # Create model instances + loader = load_prices(source="prod_db") + returns = compute_returns(prices=loader) + + # Execute + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = returns(ctx) + + Returns: + A factory function that creates CallableModel instances + """ + from .flow_model import flow_model + + return flow_model(*args, **kwargs) + # ***************************************************************************** # Define "Evaluators" and associated types @@ -451,13 +616,15 @@ class ModelEvaluationContext( # TODO: Make the instance check compatible with the generic types instead of the base type @model_validator(mode="wrap") - def _context_validator(cls, values, handler, info): + @classmethod + def _context_validator(cls, values: Any, handler: Any, info: Any): """Override _context_validator from parent""" # Validate the context with the model, if possible - model = values.get("model") - if model and isinstance(model, CallableModel) and not isinstance(values.get("context"), model.context_type): - values["context"] = model.context_type.model_validate(values.get("context")) + if isinstance(values, dict): + model = values.get("model") + if model and isinstance(model, CallableModel) and not isinstance(values.get("context"), model.context_type): + values["context"] = model.context_type.model_validate(values.get("context")) # Apply standard pydantic validation context = handler(values) @@ -485,9 +652,9 @@ def __call__(self) -> ResultType: raise TypeError(f"Model result_type {result_type} is not a subclass of ResultBase") result = result_type.model_validate(result) - return result + return cast(ResultType, result) else: - return fn(self.context) + return cast(ResultType, fn(self.context)) class EvaluatorBase(_CallableModel, abc.ABC): @@ -575,7 +742,7 @@ def context_type(self) -> Type[ContextType]: if not isclass(type_to_check) or not issubclass(type_to_check, ContextBase): raise TypeError(f"Context type declared 
in signature of __call__ must be a subclass of ContextBase. Received {type_to_check}.") - return typ + return cast(Type[ContextType], typ) @property def result_type(self) -> Type[ResultType]: @@ -618,7 +785,7 @@ def result_type(self) -> Type[ResultType]: # Ensure subclass of ResultBase if not isclass(typ) or not issubclass(typ, ResultBase): raise TypeError(f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received {typ}.") - return typ + return cast(Type[ResultType], typ) @Flow.deps def __deps__( @@ -634,6 +801,19 @@ def __deps__( """ return [] + @property + def flow(self) -> "FlowAPI": + """Access flow helpers for execution, context transforms, and introspection.""" + from .flow_model import FlowAPI + + return FlowAPI(self) + + def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: + """Wire this model into a downstream generated ``@Flow.model`` stage.""" + from .flow_model import pipe_model + + return pipe_model(self, stage, param=param, **bindings) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. 
@@ -646,12 +826,12 @@ class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): @property def context_type(self) -> Type[ContextType]: """Return the context type of the underlying model.""" - return self.model.context_type + return cast(CallableModel, self.model).context_type @property def result_type(self) -> Type[ResultType]: """Return the result type of the underlying model.""" - return self.model.result_type + return cast(CallableModel, self.model).result_type class CallableModelGeneric(CallableModel, Generic[ContextType, ResultType]): @@ -723,34 +903,110 @@ def _determine_context_result(cls): if new_context_type is not None: # Set on class - cls._context_generic_type = new_context_type + setattr(cls, "_context_generic_type", new_context_type) if new_result_type is not None: # Set on class - cls._result_generic_type = new_result_type + setattr(cls, "_result_generic_type", new_result_type) @model_validator(mode="wrap") - def _validate_callable_model_generic_type(cls, m, handler, info): + @classmethod + def _validate_callable_model_generic_type(cls, m: Any, handler: Any, info: Any): from ccflow.base import resolve_str if isinstance(m, str): m = resolve_str(m) - if isinstance(m, dict): - m = handler(m) - elif isinstance(m, cls): - m = handler(m) + validated_cls = cast(Any, cls) + if isinstance(m, (dict, CallableModel)): + if isinstance(m, dict): + m = handler(m) + elif isinstance(m, validated_cls): + m = handler(m) # Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/ if not isinstance(m, CallableModel): raise ValueError(f"{m} is not a CallableModel: {type(m)}") subtypes = cls.__pydantic_generic_metadata__["args"] - if subtypes: - TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type) - TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type) + if len(subtypes) >= 1 and not _declared_type_matches(m.context_type, subtypes[0]): + raise ValueError(f"{m} context_type {m.context_type} does not match 
{subtypes[0]}") + if len(subtypes) >= 2 and not _declared_type_matches(m.result_type, subtypes[1]): + raise ValueError(f"{m} result_type {m.result_type} does not match {subtypes[1]}") return m CallableModelGenericType = CallableModelGeneric + + +# ***************************************************************************** +# Auto Context (internal helper for Flow.call(auto_context=True)) +# ***************************************************************************** + + +def _apply_auto_context(func: Callable[..., Any], *, parent: Optional[Type[ContextBase]] = None) -> Callable[..., Any]: + """Internal function that creates an auto context class from function parameters. + + This function extracts the parameters from a function signature and creates + a ContextBase subclass whose fields correspond to those parameters. + The decorated function is then wrapped to accept the context object and + unpack it into keyword arguments. + + Used internally by Flow.call(auto_context=...). + + Example: + class MyCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = MyCallable() + model(x=42, y="hello") # Works with kwargs + """ + sig = signature(func) + base_class = parent or ContextBase + + if sig.return_annotation is inspect.Signature.empty: + raise TypeError(f"Function {_callable_qualname(func)} must have a return type annotation when auto_context=True") + + # Validate parent fields are in function signature + if parent is not None: + parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) + sig_params = set(sig.parameters.keys()) - {"self"} + missing = parent_fields - sig_params + if missing: + raise TypeError(f"Parent context fields {missing} must be included in function signature") + + # Build fields from parameters (skip 'self'), pydantic validates types + fields = {} + for name, param in sig.parameters.items(): + 
if name == "self": + continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError(f"Function {_callable_qualname(func)} does not support {param.kind.description} when auto_context=True") + if param.annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation when auto_context=True") + default = ... if param.default is inspect.Parameter.empty else param.default + fields[name] = (param.annotation, default) + + # Create auto context class + auto_context_class = create_ccflow_model(f"{_callable_qualname(func)}_AutoContext", __base__=base_class, **fields) + + @wraps(func) + def wrapper(self, context): + fn_kwargs = {name: getattr(context, name) for name in fields} + return func(self, **fn_kwargs) + + # Must set __signature__ so CallableModel validation sees 'context' parameter + wrapper_any = cast(Any, wrapper) + wrapper_any.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class), + ], + return_annotation=sig.return_annotation, + ) + wrapper_any.__auto_context__ = auto_context_class + wrapper_any.__result_type__ = sig.return_annotation + return wrapper diff --git a/ccflow/context.py b/ccflow/context.py index 9a04fad..ae69e22 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,16 +1,18 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" +from collections.abc import Mapping from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from .base import ContextBase from 
.exttypes import Frequency from .validators import normalize_date, normalize_datetime __all__ = ( + "FlowContext", "NullContext", "GenericContext", "DateContext", @@ -89,6 +91,49 @@ # Starting 0.8.0 Nullcontext is an alias to ContextBase NullContext = ContextBase + +class FlowContext(ContextBase): + """Universal context for @Flow.model functions. + + Instead of generating a new ContextBase subclass for each @Flow.model, + this single class with extra="allow" serves as the universal carrier. + Validation happens via TypedDict + TypeAdapter at compute() time. + + This design avoids: + - Proliferation of dynamic _funcname_Context classes + - Class registration overhead for serialization + - Pickling issues with Ray/distributed computing + """ + + model_config = ConfigDict(extra="allow", frozen=True) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FlowContext): + return False + return self.model_dump(mode="python") == other.model_dump(mode="python") + + def __hash__(self) -> int: + return hash(_freeze_for_hash(self.model_dump(mode="python"))) + + +def _freeze_for_hash(value: Any) -> Hashable: + if isinstance(value, Mapping): + return tuple(sorted((key, _freeze_for_hash(item)) for key, item in value.items())) + if isinstance(value, (list, tuple)): + return tuple(_freeze_for_hash(item) for item in value) + if isinstance(value, (set, frozenset)): + return frozenset(_freeze_for_hash(item) for item in value) + if hasattr(value, "model_dump"): + return (type(value), _freeze_for_hash(value.model_dump(mode="python"))) + try: + hash(value) + except TypeError as exc: + if hasattr(value, "__dict__"): + return (type(value), _freeze_for_hash(vars(value))) + raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}") from exc + return value + + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py new file mode 100644 index 0000000..da05d8e --- /dev/null +++ b/ccflow/flow_model.py @@ -0,0 
+1,1334 @@ +"""Flow.model decorator implementation. + +This module provides the Flow.model decorator that generates CallableModel classes +from plain Python functions, reducing boilerplate while maintaining full compatibility +with existing ccflow infrastructure. + +Key design: Uses TypedDict + TypeAdapter for context schema validation instead of +generating dynamic ContextBase subclasses. This avoids class registration overhead +and enables clean pickling for distributed computing (e.g., Ray). +""" + +import inspect +import logging +import threading +from functools import wraps +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, get_type_hints + +from pydantic import Field, PrivateAttr, TypeAdapter, model_serializer, model_validator +from typing_extensions import NotRequired, TypedDict + +from .base import ContextBase, ResultBase +from .callable import CallableModel, Flow, GraphDepList, WrapperModel +from .context import FlowContext +from .local_persistence import register_ccflow_import_path +from .result import GenericResult + +__all__ = ("FlowAPI", "BoundModel", "Lazy") + +_AnyCallable = Callable[..., Any] + + +class _DeferredInput: + """Sentinel for dynamic @Flow.model inputs left for runtime context.""" + + def __repr__(self) -> str: + return "" + + +_DEFERRED_INPUT = _DeferredInput() + + +def _callable_name(func: _AnyCallable) -> str: + return getattr(func, "__name__", type(func).__name__) + + +def _callable_module(func: _AnyCallable) -> str: + return getattr(func, "__module__", __name__) + + +class _LazyMarker: + """Sentinel that marks a parameter as lazily evaluated via Lazy[T].""" + + pass + + +def _extract_lazy(annotation) -> Tuple[Any, bool]: + """Check if annotation is Lazy[T]. Returns (base_type, is_lazy). + + Handles nested Annotated types, so we need to check the outermost + Annotated layer for _LazyMarker. 
+ """ + if get_origin(annotation) is Annotated: + args = get_args(annotation) + for metadata in args[1:]: + if isinstance(metadata, _LazyMarker): + return args[0], True + return annotation, False + + +def _make_lazy_thunk(model, context): + """Create a zero-arg callable that evaluates model(context) on demand. + + The thunk caches its result so repeated calls don't re-evaluate. + """ + _cache = {} + + def thunk(): + if "result" not in _cache: + result = model(context) + if isinstance(result, GenericResult): + result = result.value + _cache["result"] = result + return _cache["result"] + + return thunk + + +log = logging.getLogger(__name__) + + +def _context_values(context: ContextBase) -> Dict[str, Any]: + """Return a plain mapping of all context values. + + `dict(context)` uses pydantic's public iteration behavior, which includes + both declared fields and any allowed extra fields. + """ + + return dict(context) + + +def _transform_repr(transform: Any) -> str: + """Render an input transform without noisy object addresses.""" + + if callable(transform): + name = _callable_name(transform) + if name.startswith("<") and name.endswith(">"): + return name + return f"<{name}>" + return repr(transform) + + +def _is_model_dependency(value: Any) -> bool: + return isinstance(value, CallableModel) + + +def _bound_field_names(model: Any) -> set[str]: + fields_set = getattr(model, "model_fields_set", None) + if fields_set is not None: + return set(fields_set) + return set(getattr(model, "_bound_fields", set())) + + +def _has_deferred_input(value: Any) -> bool: + return isinstance(value, _DeferredInput) + + +def _deferred_input_factory() -> _DeferredInput: + return _DEFERRED_INPUT + + +def _effective_bound_field_names(model: Any) -> set[str]: + fields = _bound_field_names(model) + defaults = getattr(model.__class__, "__flow_model_default_param_names__", set()) + return fields | set(defaults) + + +def _runtime_input_names(model: Any) -> set[str]: + all_param_names = 
set(getattr(model.__class__, "__flow_model_all_param_types__", {})) + if not all_param_names: + return set() + return all_param_names - _effective_bound_field_names(model) + + +def _resolve_registry_candidate(value: str) -> Any: + from .base import BaseModel as _BM + + try: + candidate = _BM.model_validate(value) + except Exception: + return None + return candidate if isinstance(candidate, _BM) else None + + +def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: + if _is_model_dependency(candidate): + return True + try: + TypeAdapter(expected_type).validate_python(candidate) + except Exception: + return False + return True + + +def _type_accepts_str(annotation) -> bool: + """Return True when ``str`` is a valid type for *annotation*. + + Handles ``str``, ``Union[str, ...]``, ``Optional[str]``, and + ``Annotated[str, ...]``. + """ + if annotation is str: + return True + origin = get_origin(annotation) + if origin is Annotated: + return _type_accepts_str(get_args(annotation)[0]) + if origin is Union: + return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None)) + return False + + +def _build_typed_dict_adapter(name: str, schema: Dict[str, Type], *, total: bool = True) -> TypeAdapter: + """Build a TypeAdapter for a runtime TypedDict schema.""" + + if not schema: + return TypeAdapter(dict) + return TypeAdapter(TypedDict(name, schema, total=total)) + + +def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: + """Extract a concrete ContextBase subclass from a context annotation.""" + + if isinstance(context_type, type) and issubclass(context_type, ContextBase): + return context_type + + if get_origin(context_type) in (Optional, Union): + for arg in get_args(context_type): + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, ContextBase): + return arg + + return None + + +def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str, Type], 
Dict[str, TypeAdapter]]: + """Precompute validators for constructor fields.""" + + validatable_types: Dict[str, Type] = {} + for name, typ in all_param_types.items(): + try: + TypeAdapter(typ) + validatable_types[name] = typ + except Exception: + pass + + validators = {name: TypeAdapter(typ) for name, typ in validatable_types.items()} + return validatable_types, validators + + +def _coerce_context_value(name: str, value: Any, validators: Dict[str, TypeAdapter], validatable_types: Dict[str, Type]) -> Any: + """Validate/coerce a single context-sourced value. Returns coerced value or raises TypeError.""" + if name not in validators: + return value + try: + return validators[name].validate_python(value) + except Exception as exc: + expected = validatable_types.get(name, "unknown") + raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc + + +def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: + """Validate plain config inputs while still allowing dependency objects.""" + + if not validators: + return + + from .base import ModelRegistry as _MR + + for field_name, validator in validators.items(): + if field_name not in kwargs: + continue + value = kwargs[field_name] + if value is None or _is_model_dependency(value): + continue + if isinstance(value, str) and value in _MR.root(): + expected_type = validatable_types[field_name] + if _type_accepts_str(expected_type): + continue + candidate = _resolve_registry_candidate(value) + if candidate is not None and _registry_candidate_allowed(expected_type, candidate): + continue + try: + validator.validate_python(value) + except Exception: + expected_type = validatable_types[field_name] + raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") + + +def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: + if 
isinstance(stage, BoundModel): + model = stage.model + else: + model = stage + if isinstance(model, _GeneratedFlowModelBase): + return model + return None + + +def _generated_model_class(stage: Any) -> Optional[type["_GeneratedFlowModelBase"]]: + model = _generated_model_instance(stage) + if model is not None: + return type(model) + + generated_model = getattr(stage, "_generated_model", None) + if isinstance(generated_model, type) and issubclass(generated_model, _GeneratedFlowModelBase): + return generated_model + return None + + +def _describe_pipe_stage(stage: Any) -> str: + if isinstance(stage, BoundModel): + return repr(stage) + if isinstance(stage, _GeneratedFlowModelBase): + return repr(stage) + if callable(stage): + return _callable_name(stage) + return repr(stage) + + +def _generated_model_explicit_kwargs(model: "_GeneratedFlowModelBase") -> Dict[str, Any]: + return cast(Dict[str, Any], model.model_dump(mode="python", exclude_unset=True)) + + +def _infer_pipe_param( + stage_name: str, + param_names: List[str], + default_param_names: set[str], + occupied_names: set[str], +) -> str: + required_candidates = [name for name in param_names if name not in occupied_names and name not in default_param_names] + if len(required_candidates) == 1: + return required_candidates[0] + if len(required_candidates) > 1: + candidates = ", ".join(required_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." + ) + + fallback_candidates = [name for name in param_names if name not in occupied_names] + if len(fallback_candidates) == 1: + return fallback_candidates[0] + if len(fallback_candidates) > 1: + candidates = ", ".join(fallback_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." 
+ ) + + raise TypeError(f"pipe() could not find an available target parameter for {stage_name}.") + + +def _resolve_pipe_param(source: Any, stage: Any, param: Optional[str], bindings: Dict[str, Any]) -> Tuple[str, type["_GeneratedFlowModelBase"]]: + del source # Source only matters when binding, not during target resolution. + + generated_model_cls = _generated_model_class(stage) + if generated_model_cls is None: + raise TypeError("pipe() only supports downstream stages created by @Flow.model or bound versions of those stages.") + + stage_name = _describe_pipe_stage(stage) + all_param_types = getattr(generated_model_cls, "__flow_model_all_param_types__", {}) + if not all_param_types: + raise TypeError(f"pipe() could not determine bindable parameters for {stage_name}.") + + param_names = list(all_param_types.keys()) + default_param_names = set(getattr(generated_model_cls, "__flow_model_default_param_names__", set())) + + generated_model = _generated_model_instance(stage) + occupied_names = set(bindings) + if generated_model is not None: + occupied_names |= _bound_field_names(generated_model) + if isinstance(stage, BoundModel): + occupied_names |= set(stage._input_transforms) + + if param is not None: + if param not in all_param_types: + valid = ", ".join(param_names) + raise TypeError(f"pipe() target parameter '{param}' is not valid for {stage_name}. 
Available parameters: {valid}.") + if param in occupied_names: + raise TypeError(f"pipe() target parameter '{param}' is already bound for {stage_name}.") + return param, generated_model_cls + + return _infer_pipe_param(stage_name, param_names, default_param_names, occupied_names), generated_model_cls + + +def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: + """Wire ``source`` into a downstream generated ``@Flow.model`` stage.""" + + if not _is_model_dependency(source): + raise TypeError(f"pipe() source must be a CallableModel, got {type(source).__name__}.") + + target_param, generated_model_cls = _resolve_pipe_param(source, stage, param, bindings) + build_kwargs = dict(bindings) + build_kwargs[target_param] = source + + if isinstance(stage, BoundModel): + generated_model = _generated_model_instance(stage) + if generated_model is None: + raise TypeError("pipe() only supports downstream BoundModel stages created from @Flow.model.") + explicit_kwargs = _generated_model_explicit_kwargs(generated_model) + explicit_kwargs.update(build_kwargs) + rebound_model = generated_model_cls(**explicit_kwargs) + return BoundModel(model=rebound_model, input_transforms=dict(stage._input_transforms)) + + generated_model = _generated_model_instance(stage) + if generated_model is not None: + explicit_kwargs = _generated_model_explicit_kwargs(generated_model) + explicit_kwargs.update(build_kwargs) + return generated_model_cls(**explicit_kwargs) + + return stage(**build_kwargs) + + +class FlowAPI: + """API namespace for deferred computation operations. + + Provides methods for executing models and transforming contexts. + Accessed via model.flow property. 
+ """ + + def __init__(self, model: CallableModel): + self._model = model + + def _build_context(self, kwargs: Dict[str, Any]) -> ContextBase: + """Construct a runtime context for either generated or hand-written models.""" + get_validator = getattr(self._model, "_get_context_validator", None) + if get_validator is not None: + validator = get_validator() + validated = validator.validate_python(kwargs) + if isinstance(validated, ContextBase): + return validated + return FlowContext(**validated) + + validator = TypeAdapter(self._model.context_type) + return validator.validate_python(kwargs) + + def compute(self, **kwargs) -> Any: + """Execute the model with the provided context arguments. + + Validates kwargs against the model's context schema using TypeAdapter, + then wraps in FlowContext and calls the model. + + Args: + **kwargs: Context arguments (e.g., start_date, end_date) + + Returns: + The model's result, using the same return contract as ``model(context)``. + """ + ctx = self._build_context(kwargs) + return self._model(ctx) + + @property + def unbound_inputs(self) -> Dict[str, Type]: + """Return the context schema (field name -> type). + + In deferred mode, this is everything that must still come from runtime context. + """ + all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) + model_cls = self._model.__class__ + + # If explicit context_args was provided, use _context_schema minus + # fields that have function defaults (they aren't truly required). 
+ explicit_args = getattr(model_cls, "__flow_model_explicit_context_args__", None) + if explicit_args is not None: + context_schema = getattr(model_cls, "_context_schema", None) + if context_schema is None: + return {} + ctx_arg_defaults = getattr(model_cls, "__flow_model_context_arg_defaults__", {}) + return {name: typ for name, typ in context_schema.items() if name not in ctx_arg_defaults} + + # Dynamic @Flow.model: unbound = params with no explicit value and no declared default + if all_param_types: + runtime_inputs = _runtime_input_names(self._model) + return {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} + + # Generic CallableModel / Mode 1: runtime inputs are the required + # context fields (fields with defaults are not required). + context_cls = _concrete_context_type(self._model.context_type) + if context_cls is None or not hasattr(context_cls, "model_fields"): + return {} + return {name: info.annotation for name, info in context_cls.model_fields.items() if info.is_required()} + + @property + def bound_inputs(self) -> Dict[str, Any]: + """Return the effective config values for this model.""" + result: Dict[str, Any] = {} + flow_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) + if flow_param_types: + for name in flow_param_types: + value = getattr(self._model, name, _DEFERRED_INPUT) + if _has_deferred_input(value): + continue + result[name] = value + return result + + # Generic CallableModel: configured model fields are the bound inputs. + model_fields = getattr(self._model.__class__, "model_fields", {}) + for name in model_fields: + if name == "meta": + continue + result[name] = getattr(self._model, name) + return result + + def with_inputs(self, **transforms) -> "BoundModel": + """Create a version of this model with transformed context inputs. 
+ + Args: + **transforms: Mapping of field name to either: + - A callable (ctx) -> value for dynamic transforms + - A static value to bind + + Returns: + A BoundModel that applies the transforms before calling. + """ + return BoundModel(model=self._model, input_transforms=transforms) + + +_bound_model_restore = threading.local() + + +def _fingerprint_transforms(transforms: Dict[str, Any]) -> Dict[str, str]: + """Create a stable, hashable fingerprint of input transforms for cache key differentiation. + + Callable transforms are identified by their id() (unique per object), which is + stable within a process lifetime. Static values are repr'd directly. + """ + result = {} + for name, transform in sorted(transforms.items()): + if callable(transform): + result[name] = f"callable:{id(transform)}" + else: + result[name] = repr(transform) + return result + + +class BoundModel(WrapperModel): + """A model with context transforms applied. + + Created by model.flow.with_inputs(). Applies transforms to context + before delegating to the underlying model. + + Context propagation across dependencies: + Each BoundModel transforms the context locally — only for the model it + wraps. When used as a dependency inside another model, the FlowContext + flows through the chain unchanged until it reaches this BoundModel, + which intercepts it, applies its transforms, and passes the modified + context to the wrapped model. Upstream models never see the transform. + + Chaining with_inputs: + Calling ``bound.flow.with_inputs(...)`` merges the new transforms with + the existing ones (new overrides old for the same key). All transforms + are applied to the incoming context in one pass — they don't compose + sequentially (each transform sees the original context, not the output + of a previous transform). 
+ """ + + _input_transforms: Dict[str, Any] = PrivateAttr(default_factory=dict) + + @model_validator(mode="wrap") + @classmethod + def _restore_serialized_transforms(cls, values, handler): + """Strip serialization-injected keys, restore static transforms, guarantee cleanup. + + Uses thread-local storage to pass static transforms to __init__ because + pydantic rejects unknown keys in the input dict. The wrap validator's + try/finally ensures the thread-local is always cleaned up, even if + validation fails before __init__ runs. + """ + if isinstance(values, dict): + values = dict(values) # Don't mutate the caller's dict + values.pop("_input_transforms_token", None) + static = values.pop("_static_transforms", None) + else: + static = None + + if static is not None: + _bound_model_restore.pending = static + try: + return handler(values) + except Exception: + _bound_model_restore.pending = None + raise + + def __init__(self, *, model: CallableModel, input_transforms: Optional[Dict[str, Any]] = None, **kwargs): + super().__init__(model=model, **kwargs) + restore = getattr(_bound_model_restore, "pending", None) + if restore is not None: + _bound_model_restore.pending = None + if input_transforms is not None: + self._input_transforms = input_transforms + elif restore is not None: + self._input_transforms = restore + else: + self._input_transforms = {} + + def _transform_context(self, context: ContextBase) -> ContextBase: + """Return this model's preferred context type with input transforms applied.""" + ctx_dict = _context_values(context) + for name, transform in self._input_transforms.items(): + if callable(transform): + ctx_dict[name] = transform(context) + else: + ctx_dict[name] = transform + context_type = _concrete_context_type(self.model.context_type) + if context_type is not None and context_type is not FlowContext: + return context_type.model_validate(ctx_dict) + return FlowContext(**ctx_dict) + + @Flow.call + def __call__(self, context: ContextBase) -> 
ResultBase: + """Call the model with transformed context.""" + return self.model(self._transform_context(context)) + + @Flow.deps + def __deps__(self, context: ContextBase) -> GraphDepList: + """Declare the wrapped model as an upstream dependency with transformed context.""" + return [(self.model, [self._transform_context(context)])] + + @model_serializer(mode="wrap") + def _serialize_with_transforms(self, handler): + """Include transforms in serialization for cache keys and faithful roundtrips. + + Static (non-callable) transforms are serialized in _static_transforms for + faithful restoration. A fingerprint token covers all transforms (including + callables) for cache key differentiation. + """ + data = handler(self) + static = {k: v for k, v in self._input_transforms.items() if not callable(v)} + if static: + data["_static_transforms"] = static + data["_input_transforms_token"] = _fingerprint_transforms(self._input_transforms) + return data + + def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: + """Wire this bound model into a downstream generated ``@Flow.model`` stage.""" + return pipe_model(self, stage, param=param, **bindings) + + def __repr__(self) -> str: + transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) + return f"{self.model!r}.flow.with_inputs({transforms})" + + @property + def flow(self) -> "FlowAPI": + """Access the flow API.""" + return _BoundFlowAPI(self) + + +class _BoundFlowAPI(FlowAPI): + """FlowAPI that delegates to a BoundModel, honoring transforms.""" + + def __init__(self, bound_model: BoundModel): + self._bound = bound_model + super().__init__(bound_model.model) + + def compute(self, **kwargs) -> Any: + ctx = self._build_context(kwargs) + return self._bound(ctx) # Call through BoundModel, not inner model + + def with_inputs(self, **transforms) -> "BoundModel": + """Chain transforms: merge new transforms with existing ones. 
+ + New transforms override existing ones for the same key. + """ + merged = {**self._bound._input_transforms, **transforms} + return BoundModel(model=self._bound.model, input_transforms=merged) + + +class _GeneratedFlowModelBase(CallableModel): + """Shared behavior for models generated by ``@Flow.model``.""" + + __flow_model_context_type__: ClassVar[Type[ContextBase]] = FlowContext + __flow_model_return_type__: ClassVar[Type[ResultBase]] = GenericResult + __flow_model_func__: ClassVar[_AnyCallable | None] = None + __flow_model_use_context_args__: ClassVar[bool] = True + __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None + __flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_default_param_names__: ClassVar[set[str]] = set() + __flow_model_context_arg_defaults__: ClassVar[Dict[str, Any]] = {} + __flow_model_auto_wrap__: ClassVar[bool] = False + __flow_model_validatable_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_config_validators__: ClassVar[Dict[str, TypeAdapter]] = {} + _context_schema: ClassVar[Dict[str, Type]] = {} + _context_td: ClassVar[Any | None] = None + _cached_context_validator: ClassVar[TypeAdapter | None] = None + + @model_validator(mode="before") + def _resolve_registry_refs(cls, values, info): + if not isinstance(values, dict): + return values + + param_types = getattr(cls, "__flow_model_all_param_types__", {}) + resolved = dict(values) + for field_name, expected_type in param_types.items(): + if field_name not in resolved: + continue + value = resolved[field_name] + if not isinstance(value, str): + continue + if _type_accepts_str(expected_type): + continue + candidate = _resolve_registry_candidate(value) + if candidate is None: + continue + if _registry_candidate_allowed(expected_type, candidate): + resolved[field_name] = candidate + return resolved + + @model_validator(mode="after") + def _validate_field_types(self): + """Validate field values against their declared types. 
+ + This catches type mismatches in the model_validate/deserialization path, + where fields are typed as Any and pydantic won't reject wrong types. + """ + cls = self.__class__ + config_validators = getattr(cls, "__flow_model_config_validators__", {}) + validatable_types = getattr(cls, "__flow_model_validatable_types__", {}) + if not config_validators: + return self + + for field_name, validator in config_validators.items(): + value = getattr(self, field_name, _DEFERRED_INPUT) + if _has_deferred_input(value) or value is None or _is_model_dependency(value): + continue + try: + validator.validate_python(value) + except Exception: + expected_type = validatable_types[field_name] + raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") + return self + + @property + def context_type(self) -> Type[ContextBase]: + return self.__class__.__flow_model_context_type__ + + @property + def result_type(self) -> Type[ResultBase]: + return self.__class__.__flow_model_return_type__ + + @property + def flow(self) -> FlowAPI: + return FlowAPI(self) + + def _get_context_validator(self) -> TypeAdapter: + """Get or create the context validator for this generated model.""" + + cls = self.__class__ + explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) + + if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): + if cls._cached_context_validator is None: + use_ctx_args = getattr(cls, "__flow_model_use_context_args__", True) + ctx_type = cls.__flow_model_context_type__ + if not use_ctx_args and isinstance(ctx_type, type) and issubclass(ctx_type, ContextBase) and ctx_type is not FlowContext: + # Mode 1 with concrete context type — use TypeAdapter(context_type) + # directly so defaults on the context type are respected. 
+ cls._cached_context_validator = TypeAdapter(ctx_type) + elif cls._context_td is not None: + cls._cached_context_validator = TypeAdapter(cls._context_td) + elif cls._context_schema: + cls._cached_context_validator = _build_typed_dict_adapter(f"{cls.__name__}Inputs", cls._context_schema) + else: + cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) + return cls._cached_context_validator + + if not hasattr(self, "_instance_context_validator"): + all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) + runtime_inputs = _runtime_input_names(self) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} + object.__setattr__( + self, + "_instance_context_validator", + _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema, total=False), + ) + return cast(TypeAdapter, getattr(self, "_instance_context_validator")) + + +class Lazy: + """Deferred model execution with runtime context overrides. + + Has two distinct uses: + + 1. **Type annotation** — ``Lazy[T]`` marks a parameter as lazily evaluated. + The framework will NOT pre-evaluate the dependency; instead the function + receives a zero-arg thunk that triggers evaluation on demand:: + + @Flow.model + def smart_training( + data: PreparedData, + fast_metrics: Metrics, + slow_metrics: Lazy[Metrics], # NOT eagerly evaluated + threshold: float = 0.9, + ) -> Metrics: + if fast_metrics.r2 > threshold: + return fast_metrics + return slow_metrics() # Evaluated on demand + + 2. **Runtime helper** — ``Lazy(model)(overrides)`` creates a callable that + applies context overrides before calling the model. Used with + ``with_inputs()`` for deferred execution:: + + lookback = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + + **Which to use:** + + - Use ``Lazy[T]`` in a ``@Flow.model`` signature when you want conditional/ + on-demand evaluation of an expensive upstream dependency. 
+ - Use ``Lazy(model)(...)`` when you need to rewire context fields before + passing them to an existing model (e.g., shifting a date window). + """ + + def __class_getitem__(cls, item): + """Support Lazy[T] syntax as a type annotation marker. + + Returns Annotated[T, _LazyMarker()] so the framework can detect + lazy parameters during signature analysis. + """ + return Annotated[item, _LazyMarker()] + + def __init__(self, model: "CallableModel"): # noqa: F821 + """Wrap a model for deferred execution. + + Args: + model: The CallableModel to wrap + """ + self._model = model + + def __call__(self, **overrides) -> Callable[[ContextBase], Any]: + """Create a callable that applies overrides to context before execution. + + Args: + **overrides: Context field overrides. Values can be: + - Static values (applied directly) + - Callables (ctx) -> value (called with context at runtime) + + Returns: + A callable (context) -> result that applies overrides and calls the model + """ + model = self._model + + def execute_with_overrides(context: ContextBase) -> Any: + # Build context dict from incoming context + ctx_dict = _context_values(context) + + # Apply overrides + for name, value in overrides.items(): + if callable(value): + ctx_dict[name] = value(context) + else: + ctx_dict[name] = value + + # Call model with modified context + new_ctx = FlowContext(**ctx_dict) + return model(new_ctx) + + return execute_with_overrides + + @property + def model(self) -> "CallableModel": # noqa: F821 + """Access the wrapped model.""" + return self._model + + +def _build_context_schema( + context_args: List[str], func: _AnyCallable, sig: inspect.Signature, resolved_hints: Dict[str, Any] +) -> Tuple[Dict[str, Type], Any]: + """Build context schema from context_args parameter names. 
+ + Instead of creating a dynamic ContextBase subclass, this builds: + - A schema dict mapping field names to types + - A TypedDict for Pydantic TypeAdapter validation + - Optionally, a matched existing ContextBase type for compatibility + + Args: + context_args: List of parameter names that come from context + func: The decorated function + sig: The function signature + + Returns: + Tuple of (schema_dict, TypedDict type) + """ + # Build schema dict from parameter annotations + schema = {} + td_schema = {} + for name in context_args: + if name not in sig.parameters: + raise ValueError(f"context_arg '{name}' not found in function parameters") + param = sig.parameters[name] + annotation = resolved_hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty: + raise ValueError(f"context_arg '{name}' must have a type annotation") + schema[name] = annotation + # Use NotRequired in the TypedDict for params that have a default in the + # function signature, so compute() doesn't require them. + if param.default is not inspect.Parameter.empty: + td_schema[name] = NotRequired[annotation] + else: + td_schema[name] = annotation + + # Create TypedDict for validation (not registered anywhere!) 
+ context_td = TypedDict(f"{_callable_name(func)}Inputs", td_schema) + + return schema, context_td + + +def _validate_context_type_override( + context_type: Any, + context_args: List[str], + func_schema: Dict[str, Type], + func_defaults: set[str] = frozenset(), +) -> Type[ContextBase]: + """Validate an explicit ``context_type`` override for ``context_args`` mode.""" + + if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): + raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}") + + context_fields = getattr(context_type, "model_fields", {}) + missing = sorted(name for name in context_args if name not in context_fields) + if missing: + raise TypeError(f"context_type {context_type.__name__} must define fields for context_args: {', '.join(missing)}") + + required_extra_fields = sorted( + name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in context_args and info.is_required() + ) + if required_extra_fields: + raise TypeError(f"context_type {context_type.__name__} has required fields not listed in context_args: {', '.join(required_extra_fields)}") + + # Warn when the function's annotation for a context_arg doesn't match the + # context_type's field annotation. A mismatch means the function declares + # one type but will silently receive whatever Pydantic coerces to. 
+ for name in context_args: + func_ann = func_schema.get(name) + ctx_field = context_fields.get(name) + if func_ann is None or ctx_field is None: + continue + ctx_ann = ctx_field.annotation + if func_ann is ctx_ann: + continue + # Both are concrete types — check subclass relationship + if isinstance(func_ann, type) and isinstance(ctx_ann, type): + if not (issubclass(func_ann, ctx_ann) or issubclass(ctx_ann, func_ann)): + raise TypeError( + f"context_arg '{name}': function annotates {func_ann.__name__} " + f"but context_type {context_type.__name__} declares {ctx_ann.__name__}" + ) + + # Reject if the function has a default for a context_arg but the + # context_type declares that field as required — this is contradictory. + for name in context_args: + if name in func_defaults: + ctx_field = context_fields.get(name) + if ctx_field is not None and ctx_field.is_required(): + raise TypeError(f"context_arg '{name}': function has a default but context_type {context_type.__name__} requires this field") + + return context_type + + +_UNSET = object() + + +def flow_model( + func: Optional[_AnyCallable] = None, + *, + # Context handling + context_args: Optional[List[str]] = None, + context_type: Optional[Type[ContextBase]] = None, + # Flow.call options (passed to generated __call__) + # Default to _UNSET so FlowOptionsOverride can control these globally. + # Only explicitly user-provided values are passed to Flow.call. + cacheable: Any = _UNSET, + volatile: Any = _UNSET, + log_level: Any = _UNSET, + validate_result: Any = _UNSET, + verbose: Any = _UNSET, + evaluator: Any = _UNSET, +) -> _AnyCallable: + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. 
+ + Args: + func: The function to decorate + context_args: List of parameter names that come from context (for unpacked mode) + context_type: Explicit ContextBase subclass to use with ``context_args`` mode. + cacheable: Enable caching of results (default: unset, inherits from FlowOptionsOverride) + volatile: Mark as volatile (always re-execute) (default: unset, inherits from FlowOptionsOverride) + log_level: Logging verbosity (default: unset, inherits from FlowOptionsOverride) + validate_result: Validate return type (default: unset, inherits from FlowOptionsOverride) + verbose: Verbose logging output (default: unset, inherits from FlowOptionsOverride) + evaluator: Custom evaluator (default: unset, inherits from FlowOptionsOverride) + + Two Context Modes: + 1. Explicit context parameter: Function has a 'context' parameter annotated + with a ContextBase subclass. + + @Flow.model + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: + ... + + 2. Unpacked context_args: Context fields are unpacked into function parameters. + + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: + ... + + Returns: + A factory function that creates CallableModel instances + """ + + def decorator(fn: _AnyCallable) -> _AnyCallable: + sig = inspect.signature(fn) + params = sig.parameters + + # Resolve string annotations (PEP 563 / from __future__ import annotations) + # into real type objects. include_extras=True preserves Annotated metadata. 
+ try: + _resolved_hints = get_type_hints(fn, include_extras=True) + except Exception: + _resolved_hints = {} + + # Validate return type + return_type = _resolved_hints.get("return", sig.return_annotation) + if return_type is inspect.Signature.empty: + raise TypeError(f"Function {_callable_name(fn)} must have a return type annotation") + # Check if return type is a ResultBase subclass; if not, auto-wrap in GenericResult + return_origin = get_origin(return_type) or return_type + auto_wrap_result = False + if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): + auto_wrap_result = True + internal_return_type = GenericResult # unparameterized for safety + else: + internal_return_type = return_type + + # ── Context mode selection ── + # The decorator supports three mutually exclusive context modes: + # + # Mode 1 (explicit context): Function has a 'context' (or '_') parameter + # annotated with a ContextBase subclass. Behaves like a traditional + # CallableModel.__call__. Other params become model fields. + # + # Mode 2 (context_args): Decorator specifies context_args=[...] listing + # which params come from the context at runtime. Remaining params become + # model fields. Uses FlowContext unless context_type= overrides it. + # + # Mode 3 (dynamic deferred): No 'context' param and no context_args. + # Every param is a potential model field. Params bound at construction + # are config; unbound params become runtime inputs from FlowContext. + # + context_schema_early: Dict[str, Type] = {} + context_td_early = None + if "context" in params or "_" in params: + # Mode 1: Explicit context parameter (named 'context' or '_' for unused) + if context_type is not None: + raise TypeError("context_type=... 
is only supported when using context_args=[...]") + context_param_name = "context" if "context" in params else "_" + context_param = params[context_param_name] + context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) + if context_annotation is inspect.Parameter.empty: + raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' parameter must have a type annotation") + resolved_context_type = context_annotation + if not (isinstance(resolved_context_type, type) and issubclass(resolved_context_type, ContextBase)): + raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' must be annotated with a ContextBase subclass") + model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} + use_context_args = False + explicit_context_args = None + elif context_args is not None: + # Mode 2: Explicit context_args - specified params come from context + context_param_name = "context" + context_schema_early, context_td_early = _build_context_schema(context_args, fn, sig, _resolved_hints) + _func_defaults_set = {name for name in context_args if sig.parameters[name].default is not inspect.Parameter.empty} + explicit_context_type = ( + _validate_context_type_override(context_type, context_args, context_schema_early, _func_defaults_set) + if context_type is not None + else None + ) + resolved_context_type = explicit_context_type if explicit_context_type is not None else FlowContext + # Exclude context_args from model fields + model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} + use_context_args = True + explicit_context_args = context_args + else: + # Mode 3: Dynamic deferred mode - every param can be configured on the model, + # but only params without Python defaults remain runtime inputs when omitted. + if context_type is not None: + raise TypeError("context_type=... 
is only supported when using context_args=[...]") + context_param_name = "context" + resolved_context_type = FlowContext + model_field_params = {name: param for name, param in params.items() if name != "self"} + use_context_args = True + explicit_context_args = None # Dynamic - determined at construction + + # Analyze parameters to find lazy fields and regular fields. + model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) + lazy_fields: set[str] = set() # Names of parameters marked with Lazy[T] + default_param_names: set[str] = set() + + # In dynamic deferred mode (no explicit context_args), fields without Python defaults + # are internally represented by a deferred sentinel until runtime context supplies them. + dynamic_deferred_mode = use_context_args and explicit_context_args is None + + for name, param in model_field_params.items(): + # Use resolved hint (handles PEP 563 string annotations) + annotation = _resolved_hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation") + + # Check for Lazy[T] annotation first + unwrapped_annotation, is_lazy = _extract_lazy(annotation) + if is_lazy: + lazy_fields.add(name) + + if param.default is not inspect.Parameter.empty: + default_param_names.add(name) + default = param.default + elif dynamic_deferred_mode: + # In dynamic mode, params without defaults remain deferred to runtime context. + default = Field(default_factory=_deferred_input_factory, exclude_if=_has_deferred_input) + else: + # In explicit mode, params without defaults are required + default = ... 
+ + model_fields[name] = (Any, default) + + # Capture variables for closures + ctx_param_name = context_param_name if not use_context_args else "context" + all_param_names = list(model_fields.keys()) # All non-context params (model fields) + all_param_types = {name: _resolved_hints.get(name, param.annotation) for name, param in model_field_params.items()} + # For explicit context_args mode, we also need the list of context arg names + ctx_args_for_closure = context_args if context_args is not None else [] + is_dynamic_mode = use_context_args and explicit_context_args is None + + # Compute context_arg defaults and validators for Mode 2 (context_args) + context_arg_defaults: Dict[str, Any] = {} + _ctx_validatable_types: Dict[str, Type] = {} + _ctx_validators: Dict[str, TypeAdapter] = {} + if context_args is not None: + for name in context_args: + p = sig.parameters[name] + if p.default is not inspect.Parameter.empty: + context_arg_defaults[name] = p.default + _ctx_validatable_types, _ctx_validators = _build_config_validators(context_schema_early) + + # Create the __call__ method + def make_call_impl(): + def __call__(self, context): + def resolve_callable_model(value): + """Resolve a CallableModel field.""" + resolved = value(context) + if isinstance(resolved, GenericResult): + return resolved.value + return resolved + + # Build kwargs for the original function + fn_kwargs = {} + + def _resolve_field(name, value): + """Resolve a single field value, handling lazy wrapping.""" + is_dep = isinstance(value, CallableModel) + if name in lazy_fields: + # Lazy field: wrap in a thunk regardless of type + if is_dep: + return _make_lazy_thunk(value, context) + else: + # Non-dep value: wrap in trivial thunk + return lambda v=value: v + elif is_dep: + return resolve_callable_model(value) + else: + return value + + if not use_context_args: + # Mode 1: Explicit context param - pass context directly + fn_kwargs[ctx_param_name] = context + # Add model fields + for name in 
all_param_names: + value = getattr(self, name) + fn_kwargs[name] = _resolve_field(name, value) + elif not is_dynamic_mode: + # Mode 2: Explicit context_args - get those from context, rest from self + for name in ctx_args_for_closure: + value = getattr(context, name, _UNSET) + if value is _UNSET: + if name in context_arg_defaults: + fn_kwargs[name] = context_arg_defaults[name] + else: + raise TypeError(f"Missing context field '{name}'") + else: + fn_kwargs[name] = _coerce_context_value(name, value, _ctx_validators, _ctx_validatable_types) + # Add model fields + for name in all_param_names: + value = getattr(self, name) + fn_kwargs[name] = _resolve_field(name, value) + else: + # Mode 3: Dynamic deferred mode - explicit values or Python defaults from self, + # otherwise values come from runtime context. + explicit_fields = _bound_field_names(self) + missing_fields = [] + + for name in all_param_names: + value = getattr(self, name, _DEFERRED_INPUT) + if name in explicit_fields or name in default_param_names: + # Explicitly provided or implicitly bound via Python default. + value = getattr(self, name) + fn_kwargs[name] = _resolve_field(name, value) + continue + + if _has_deferred_input(value): + value = getattr(context, name, _UNSET) + if value is _UNSET: + missing_fields.append(name) + continue + # Validate/coerce context-sourced value, skip CallableModel deps + if not _is_model_dependency(value): + value = _coerce_context_value(name, value, _config_validators, _validatable_types) + fn_kwargs[name] = _resolve_field(name, value) + + if missing_fields: + missing = ", ".join(sorted(missing_fields)) + raise TypeError( + f"Missing runtime input(s) for {_callable_name(fn)}: {missing}. " + "Provide them in the call context or bind them at construction time." 
+ ) + + raw_result = fn(**fn_kwargs) + if auto_wrap_result: + return GenericResult(value=raw_result) + return raw_result + + # Set proper signature for CallableModel validation + cast(Any, __call__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), + ], + return_annotation=internal_return_type, + ) + return __call__ + + call_impl = make_call_impl() + + # Apply Flow.call decorator — only include options the user explicitly set + flow_options = {} + for opt_name, opt_val in [ + ("cacheable", cacheable), + ("volatile", volatile), + ("log_level", log_level), + ("validate_result", validate_result), + ("verbose", verbose), + ("evaluator", evaluator), + ]: + if opt_val is not _UNSET: + flow_options[opt_name] = opt_val + + decorated_call = Flow.call(**flow_options)(call_impl) + + # Create the __deps__ method + def make_deps_impl(): + def __deps__(self, context) -> GraphDepList: + deps = [] + # Check ALL fields for CallableModel dependencies (auto-detection) + for name in model_fields: + if name in lazy_fields: + continue # Lazy deps are NOT pre-evaluated + value = getattr(self, name) + if isinstance(value, BoundModel): + deps.append((value.model, [value._transform_context(context)])) + elif isinstance(value, CallableModel): + deps.append((value, [context])) + return deps + + # Set proper signature + cast(Any, __deps__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), + ], + return_annotation=GraphDepList, + ) + return __deps__ + + deps_impl = make_deps_impl() + decorated_deps = Flow.deps(deps_impl) + + # Build pydantic field annotations for the class + annotations = {} + + namespace = { + "__module__": _callable_module(fn), 
+ "__qualname__": f"_{_callable_name(fn)}_Model", + "__call__": decorated_call, + "__deps__": decorated_deps, + } + + for name, (typ, default) in model_fields.items(): + annotations[name] = typ + if default is not ...: + namespace[name] = default + else: + # For required fields, use Field(...) + namespace[name] = Field(...) + + namespace["__annotations__"] = annotations + + _validatable_types, _config_validators = _build_config_validators(all_param_types) + + # Create the class using type() + GeneratedModel = cast(type[_GeneratedFlowModelBase], type(f"_{_callable_name(fn)}_Model", (_GeneratedFlowModelBase,), namespace)) + + # Set class-level attributes after class creation (to avoid pydantic processing) + GeneratedModel.__flow_model_context_type__ = resolved_context_type + GeneratedModel.__flow_model_return_type__ = internal_return_type + setattr(GeneratedModel, "__flow_model_func__", fn) + GeneratedModel.__flow_model_use_context_args__ = use_context_args + GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args + GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + GeneratedModel.__flow_model_default_param_names__ = default_param_names + GeneratedModel.__flow_model_context_arg_defaults__ = context_arg_defaults + GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result + GeneratedModel.__flow_model_validatable_types__ = _validatable_types + GeneratedModel.__flow_model_config_validators__ = _config_validators + + # Build context_schema + context_schema: Dict[str, Type] = {} + context_td = None + + if explicit_context_args is not None: + # Explicit context_args provided - use early-computed schema + context_schema, context_td = context_schema_early, context_td_early + elif not use_context_args: + # Explicit context mode - schema comes from the context type's fields + if hasattr(resolved_context_type, "model_fields"): + context_schema = {name: info.annotation for name, info in 
resolved_context_type.model_fields.items()} + # For dynamic mode (is_dynamic_mode), _context_schema remains empty + # and schema is built dynamically from the instance's unresolved runtime inputs. + + # Store context schema for TypedDict-based validation (picklable!) + GeneratedModel._context_schema = context_schema + GeneratedModel._context_td = context_td + # Validator is created lazily to survive pickling + GeneratedModel._cached_context_validator = None + + # Register the MODEL class for serialization (needed for model_dump/_target_). + # Note: We do NOT register dynamic context classes anymore - context handling + # uses FlowContext + TypedDict instead, which don't need registration. + register_ccflow_import_path(GeneratedModel) + + # Rebuild the model to process annotations properly + GeneratedModel.model_rebuild() + + # Create factory function that returns model instances + @wraps(fn) + def factory(**kwargs) -> _GeneratedFlowModelBase: + _validate_config_kwargs(kwargs, _validatable_types, _config_validators) + return GeneratedModel(**kwargs) + + # Preserve useful attributes on factory + cast(Any, factory)._generated_model = GeneratedModel + factory.__doc__ = fn.__doc__ + + return factory + + # Handle both @Flow.model and @Flow.model(...) syntax + if func is not None: + return decorator(func) + return decorator diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml new file mode 100644 index 0000000..41acfaf --- /dev/null +++ b/ccflow/tests/config/conf_flow.yaml @@ -0,0 +1,80 @@ +# Flow.model configurations for Hydra integration tests +# This file is separate from conf.yaml to avoid affecting existing tests + +# Basic Flow.model +flow_loader: + _target_: ccflow.tests.test_flow_model.basic_loader + source: test_source + multiplier: 5 + +flow_processor: + _target_: ccflow.tests.test_flow_model.string_processor + prefix: "value=" + suffix: "!" 
+ +# Pipeline with dependencies (uses registry name references for same instance) +flow_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 100 + +flow_transformer: + _target_: ccflow.tests.test_flow_model.data_transformer + source: flow_source + factor: 3 + +# Three-stage pipeline +flow_stage1: + _target_: ccflow.tests.test_flow_model.pipeline_stage1 + initial: 10 + +flow_stage2: + _target_: ccflow.tests.test_flow_model.pipeline_stage2 + stage1_output: flow_stage1 + multiplier: 2 + +flow_stage3: + _target_: ccflow.tests.test_flow_model.pipeline_stage3 + stage2_output: flow_stage2 + offset: 50 + +# Diamond dependency pattern +diamond_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 10 + +diamond_branch_a: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 2 + +diamond_branch_b: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 5 + +diamond_aggregator: + _target_: ccflow.tests.test_flow_model.data_aggregator + input_a: diamond_branch_a + input_b: diamond_branch_b + operation: add + +# DateRangeContext with transform +flow_date_loader: + _target_: ccflow.tests.test_flow_model.date_range_loader_previous_day + source: market_data + include_weekends: false + +flow_date_processor: + _target_: ccflow.tests.test_flow_model.date_range_processor + raw_data: flow_date_loader + normalize: true + +# context_args models (auto-unpacked context parameters) +ctx_args_loader: + _target_: ccflow.tests.test_flow_model.context_args_loader + source: data_source + +ctx_args_processor: + _target_: ccflow.tests.test_flow_model.context_args_processor + data: ctx_args_loader + prefix: "output" diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 43f86b5..9b51592 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -462,6 +462,7 @@ def test_types(self): error = "__call__ method must take a single 
argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelMissingContextArg) + # BadModelDoubleContextArg also fails with the same error since extra params aren't allowed error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelDoubleContextArg) @@ -783,3 +784,373 @@ class MyCallableParent_bad_decorator(MyCallableParent): @Flow.deps def foo(self, context): return [] + + +# ============================================================================= +# Tests for Flow.call(auto_context=True) +# ============================================================================= + + +class TestAutoContext(TestCase): + """Tests for @Flow.call(auto_context=True).""" + + def test_basic_usage_with_kwargs(self): + """Test basic auto_context usage with keyword arguments.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = AutoContextCallable() + + # Call with kwargs + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + # Call with default + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_auto_context_attribute(self): + """Test that __auto_context__ attribute is set.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, a: int, b: str) -> GenericResult: + return GenericResult(value=f"{a}-{b}") + + # The __call__ method should have __auto_context__ + call_method = AutoContextCallable.__call__ + self.assertTrue(hasattr(call_method, "__wrapped__")) + # Access the inner function's __auto_context__ + inner = call_method.__wrapped__ + self.assertTrue(hasattr(inner, "__auto_context__")) + + auto_ctx = inner.__auto_context__ + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertIn("a", auto_ctx.model_fields) + self.assertIn("b", 
auto_ctx.model_fields) + + def test_auto_context_is_registered(self): + """Test that the auto context is registered for serialization.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + inner = AutoContextCallable.__call__.__wrapped__ + auto_ctx = inner.__auto_context__ + + # Should have __ccflow_import_path__ set + self.assertTrue(hasattr(auto_ctx, "__ccflow_import_path__")) + self.assertTrue(auto_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + def test_call_with_context_object(self): + """Test calling with a context object instead of kwargs.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = AutoContextCallable() + + # Get the auto context class + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + # Create a context object + ctx = auto_ctx(x=99, y="context") + result = model(ctx) + self.assertEqual(result.value, "99-context") + + def test_with_parent_context(self): + """Test auto_context with a parent context class.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + # Get auto context + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + # Should inherit from ParentContext + self.assertTrue(issubclass(auto_ctx, ParentContext)) + + # Should have both fields + self.assertIn("base_value", auto_ctx.model_fields) + self.assertIn("x", auto_ctx.model_fields) + + # Create context with parent field + ctx = auto_ctx(x=42, base_value="custom") + self.assertEqual(ctx.base_value, "custom") + self.assertEqual(ctx.x, 
42) + + def test_parent_fields_must_be_in_signature(self): + """Test that parent context fields must be included in function signature.""" + + class ParentContext(ContextBase): + required_field: str + + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + self.assertIn("required_field", str(cm.exception)) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle roundtrip for auto_context callable.""" + + class AutoContextCallable(CallableModel): + multiplier: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = AutoContextCallable(multiplier=3) + + # Test roundtrip + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task_execution(self): + """Test auto_context callable in Ray task.""" + + class AutoContextCallable(CallableModel): + factor: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = AutoContextCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 60) # (10 + 2) * 5 + + def test_context_type_property_works(self): + """Test that type_ property works on the auto context.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + ctx = auto_ctx(x=42) + + # type_ should work and be importable + type_path = str(ctx.type_) + self.assertIn("_Local_", type_path) + 
self.assertEqual(ctx.type_.object, auto_ctx) + + def test_complex_field_types(self): + """Test auto_context with complex field types.""" + from typing import List, Optional + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__( + self, + *, + items: List[int], + name: Optional[str] = None, + count: int = 0, + ) -> GenericResult: + total = sum(items) + count + return GenericResult(value=f"{name}:{total}" if name else str(total)) + + model = AutoContextCallable() + + result = model(items=[1, 2, 3], name="test", count=10) + self.assertEqual(result.value, "test:16") + + result = model(items=[5, 5]) + self.assertEqual(result.value, "10") + + def test_with_flow_options(self): + """Test auto_context with FlowOptions parameters.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True, validate_result=False) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = AutoContextCallable() + result = model(x=42) + self.assertEqual(result.value, 42) + + def test_error_without_auto_context(self): + """Test that using kwargs signature without auto_context raises an error.""" + + class BadCallable(CallableModel): + @Flow.call # Missing auto_context=True! 
+ def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + # Error happens at instantiation time when _check_signature validates + with self.assertRaises(ValueError) as cm: + BadCallable() + + # Should fail because __call__ must take a single argument named 'context' + error_msg = str(cm.exception) + self.assertIn("__call__", error_msg) + self.assertIn("context", error_msg) + + def test_invalid_auto_context_value(self): + """Test that invalid auto_context values raise TypeError with helpful message.""" + with self.assertRaises(TypeError) as cm: + + @Flow.call(auto_context="invalid") + def bad_func(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + error_msg = str(cm.exception) + self.assertIn("auto_context must be False, True, or a ContextBase subclass", error_msg) + + def test_auto_context_rejects_var_args(self): + """auto_context should reject *args early with a clear error.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *args: int) -> GenericResult: + return GenericResult(value=len(args)) + + self.assertIn("variadic positional", str(cm.exception)) + + def test_auto_context_rejects_var_kwargs(self): + """auto_context should reject **kwargs early with a clear error.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, **kwargs: int) -> GenericResult: + return GenericResult(value=len(kwargs)) + + self.assertIn("variadic keyword", str(cm.exception)) + + def test_auto_context_requires_return_annotation(self): + """auto_context should reject missing return annotations immediately.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int): + return GenericResult(value=value) + + self.assertIn("must have a 
return type annotation", str(cm.exception)) + + def test_auto_context_rejects_missing_annotation(self): + """auto_context should reject params without type annotations.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value) -> GenericResult: + return GenericResult(value=value) + + self.assertIn("must have a type annotation", str(cm.exception)) + + +class TestDeclaredTypeMatches(TestCase): + """Tests for _declared_type_matches helper in callable.py.""" + + def test_typevar_always_matches(self): + from ccflow.callable import _declared_type_matches + + T = TypeVar("T") + self.assertTrue(_declared_type_matches(int, T)) + + def test_union_expected_no_type_args(self): + """Union with no concrete type args should return False.""" + from ccflow.callable import _declared_type_matches + + # Union[None] after filtering out NoneType has no concrete args + self.assertFalse(_declared_type_matches(int, Union[None])) + + def test_union_expected_with_actual_type(self): + """Concrete type matching Union expected.""" + from ccflow.callable import _declared_type_matches + + self.assertTrue(_declared_type_matches(int, Union[int, str])) + self.assertFalse(_declared_type_matches(float, Union[int, str])) + + def test_union_both_sides(self): + """Both actual and expected are Unions.""" + from ccflow.callable import _declared_type_matches + + self.assertTrue(_declared_type_matches(Union[int, str], Union[int, str])) + self.assertTrue(_declared_type_matches(Union[str, int], Union[int, str])) # order independent + self.assertFalse(_declared_type_matches(Union[int, float], Union[int, str])) + + def test_non_type_actual(self): + """Non-type actual should return False.""" + from ccflow.callable import _declared_type_matches + + self.assertFalse(_declared_type_matches("not_a_type", int)) + + def test_non_type_expected(self): + """Non-type expected should return False.""" + from ccflow.callable 
import _declared_type_matches + + self.assertFalse(_declared_type_matches(int, "not_a_type")) + + +class TestCallableModelGenericValidation(TestCase): + """Tests for CallableModelGeneric type validation paths.""" + + def test_context_type_mismatch_raises(self): + """Generic type validation should reject context type mismatch.""" + + class ContextA(ContextBase): + a: int + + class ContextB(ContextBase): + b: int + + class ModelA(CallableModel): + @Flow.call + def __call__(self, context: ContextA) -> GenericResult[int]: + return GenericResult(value=context.a) + + with self.assertRaises(ValidationError): + # Expect ContextB but model has ContextA + CallableModelGenericType[ContextB, GenericResult[int]].model_validate(ModelA()) + + def test_result_type_mismatch_raises(self): + """Generic type validation should reject result type mismatch.""" + + class MyContext(ContextBase): + x: int + + class ResultA(ResultBase): + a: int + + class ResultB(ResultBase): + b: int + + class ModelA(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> ResultA: + return ResultA(a=context.x) + + with self.assertRaises(ValidationError): + CallableModelGenericType[MyContext, ResultB].model_validate(ModelA()) diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index ad98bd9..64d71e8 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -275,8 +275,13 @@ def split_camel(name: str): def test_inheritance(self): """Test that if a context has a superset of fields of another context, it is a subclass of that context.""" - for parent_name, parent_class in self.classes.items(): - for child_name, child_class in self.classes.items(): + # Exclude FlowContext from this test - it's a special universal carrier with no + # declared fields (uses extra="allow"), so the "superset implies subclass" logic + # doesn't apply to it. 
+ classes_to_check = {name: cls for name, cls in self.classes.items() if name != "FlowContext"} + + for parent_name, parent_class in classes_to_check.items(): + for child_name, child_class in classes_to_check.items(): if parent_class is child_class: continue diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index cc34155..dabf815 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -1,9 +1,21 @@ from datetime import date from unittest import TestCase -from ccflow import DateContext, Evaluator, ModelEvaluationContext +import pytest -from .evaluators.util import MyDateCallable +from ccflow import CallableModel, DateContext, Evaluator, Flow, ModelEvaluationContext + +from .evaluators.util import MyDateCallable, MyResult + + +class MyAutoContextDateCallable(CallableModel): + """Auto context version of MyDateCallable for testing evaluators.""" + + offset: int + + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date) -> MyResult: + return MyResult(x=date.day + self.offset) class TestEvaluator(TestCase): @@ -32,3 +44,57 @@ def test_evaluator_deps(self): evaluator = Evaluator() out2 = evaluator.__deps__(model_evaluation_context) self.assertEqual(out2, out) + + +@pytest.mark.parametrize( + "callable_class", + [MyDateCallable, MyAutoContextDateCallable], + ids=["standard", "auto_context"], +) +class TestEvaluatorParametrized: + """Test evaluators work with both standard and auto_context callables.""" + + def test_evaluator_with_context_object(self, callable_class): + """Test evaluator with a context object.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + + out = model_evaluation_context() + assert out == MyResult(x=2) # day 1 + offset 1 + + evaluator = Evaluator() + out2 = evaluator(model_evaluation_context) + assert out2 == out + + def test_evaluator_with_fn_specified(self, 
callable_class): + """Test evaluator with fn='__call__' explicitly specified.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__call__") + + out = model_evaluation_context() + assert out == MyResult(x=2) + + def test_evaluator_direct_call_matches(self, callable_class): + """Test that evaluator result matches direct call.""" + m1 = callable_class(offset=5) + context = DateContext(date=date(2022, 1, 15)) + + # Direct call + direct_result = m1(context) + + # Via evaluator + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + evaluator_result = model_evaluation_context() + + assert direct_result == evaluator_result + assert direct_result == MyResult(x=20) # day 15 + offset 5 + + def test_evaluator_with_kwargs(self, callable_class): + """Test that evaluator works when callable is called with kwargs.""" + m1 = callable_class(offset=1) + + # Call with kwargs + result = m1(date=date(2022, 1, 10)) + assert result == MyResult(x=11) # day 10 + offset 1 diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py new file mode 100644 index 0000000..970cc08 --- /dev/null +++ b/ccflow/tests/test_flow_context.py @@ -0,0 +1,596 @@ +"""Tests for FlowContext, FlowAPI, and TypedDict-based context validation. + +These tests verify the new deferred computation API that uses: +- FlowContext: Universal context carrier with extra="allow" +- TypedDict + TypeAdapter: Schema validation without dynamic class registration +- FlowAPI: The .flow namespace for compute/with_inputs/etc. 
+""" + +import pickle +from datetime import date, timedelta + +import cloudpickle +import pytest + +from ccflow import CallableModel, ContextBase, Flow, FlowAPI, FlowContext, GenericResult +from ccflow.context import DateRangeContext + + +class NumberContext(ContextBase): + x: int + + +class OffsetModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: NumberContext) -> GenericResult[int]: + return GenericResult(value=context.x + self.offset) + + +class TestFlowContext: + """Tests for the FlowContext universal carrier.""" + + def test_flow_context_basic(self): + """FlowContext accepts arbitrary fields.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + + def test_flow_context_extra_fields(self): + """FlowContext exposes arbitrary fields through normal model APIs.""" + ctx = FlowContext(x=1, y="hello", z=[1, 2, 3]) + assert ctx.x == 1 + assert ctx.y == "hello" + assert ctx.z == [1, 2, 3] + assert dict(ctx) == {"x": 1, "y": "hello", "z": [1, 2, 3]} + + def test_flow_context_frozen(self): + """FlowContext is immutable (frozen).""" + ctx = FlowContext(value=42) + with pytest.raises(Exception): # ValidationError for frozen model + ctx.value = 100 + + def test_flow_context_repr(self): + """FlowContext has a useful repr.""" + ctx = FlowContext(a=1, b=2) + repr_str = repr(ctx) + assert "FlowContext" in repr_str + assert "a=1" in repr_str + assert "b=2" in repr_str + + def test_flow_context_attribute_error(self): + """FlowContext raises AttributeError for missing fields.""" + ctx = FlowContext(x=1) + with pytest.raises(AttributeError, match="no attribute 'missing'"): + _ = ctx.missing + + def test_flow_context_model_dump(self): + """FlowContext can be dumped (includes extra fields).""" + ctx = FlowContext(start_date=date(2024, 1, 1), value=42) + dumped = ctx.model_dump() + assert dumped["start_date"] == date(2024, 1, 1) + assert 
dumped["value"] == 42 + + def test_flow_context_value_semantics_include_extra_fields(self): + """Equality should reflect the actual extra payload.""" + assert FlowContext(x=1) == FlowContext(x=1) + assert FlowContext(x=1) != FlowContext(x=2) + assert FlowContext(x=1) != FlowContext(y=1) + + def test_flow_context_hash_uses_extra_fields(self): + """Distinct extra payloads should remain distinct in hashed collections.""" + first = FlowContext(values=[1, 2], label="a") + second = FlowContext(values=[1, 3], label="a") + third = FlowContext(values=[1, 2], label="b") + + assert len({first, second, third}) == 3 + + def test_flow_context_hash_raises_for_unhashable_values(self): + """FlowContext with truly unhashable values (no __dict__) should raise TypeError.""" + + class Unhashable: + __hash__ = None # type: ignore[assignment] + + def __init__(self): + pass + + # Deliberately no __dict__ suppression — but __hash__ is None, + # so the fallback path in _freeze_for_hash should use __dict__. + # To trigger the actual TypeError path, we need an object with + # no __dict__ and no __hash__. 
+ + class UnhashableSlots: + __slots__ = () + __hash__ = None # type: ignore[assignment] + + ctx = FlowContext(val=UnhashableSlots()) + with pytest.raises(TypeError, match="unhashable value"): + hash(ctx) + + def test_flow_context_eq_non_flow_context(self): + """FlowContext.__eq__ returns False for non-FlowContext objects.""" + ctx = FlowContext(x=1) + assert ctx != 42 + assert ctx != "hello" + assert ctx != None # noqa: E711 + assert ctx != NumberContext(x=1) + + def test_flow_context_hash_with_set_value(self): + """FlowContext with set values should hash correctly via frozenset.""" + ctx = FlowContext(tags=frozenset({"a", "b"})) + # Should not raise + h = hash(ctx) + assert isinstance(h, int) + + def test_flow_context_hash_with_model_dump_object(self): + """_freeze_for_hash should handle objects with model_dump attribute.""" + from ccflow.context import _freeze_for_hash + + # Directly test _freeze_for_hash with an object that has model_dump + # (FlowContext.__hash__ goes through model_dump first which serializes + # nested models, so we test the helper directly) + inner = NumberContext(x=42) + result = _freeze_for_hash(inner) + assert isinstance(result, tuple) + assert result[0] is NumberContext + + def test_flow_context_hash_unhashable_with_dict_fallback(self): + """Objects with __dict__ but no __hash__ should use __dict__ fallback.""" + + class UnhashableWithDict: + __hash__ = None # type: ignore[assignment] + + def __init__(self, val): + self.val = val + + ctx = FlowContext(obj=UnhashableWithDict(42)) + h = hash(ctx) + assert isinstance(h, int) + + def test_flow_context_pickle(self): + """FlowContext pickles cleanly.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + assert unpickled.start_date == date(2024, 1, 1) + assert unpickled.end_date == date(2024, 1, 31) + + def test_flow_context_cloudpickle(self): + """FlowContext works with cloudpickle (for Ray).""" + ctx 
class TestFlowAPI:
    """Exercise the ``.flow`` namespace exposed by Flow.model instances."""

    def test_flow_compute_basic(self):
        """compute() validates keyword inputs and executes the model."""

        @Flow.model(context_args=["start_date", "end_date"])
        def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]:
            return GenericResult(value={"start": start_date, "end": end_date, "source": source})

        configured = load_data(source="api")
        out = configured.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))

        expected = {"start": date(2024, 1, 1), "end": date(2024, 1, 31), "source": "api"}
        for key, want in expected.items():
            assert out.value[key] == want

    def test_flow_compute_type_coercion(self):
        """compute() coerces raw values (ISO strings) via TypeAdapter."""

        @Flow.model(context_args=["start_date", "end_date"])
        def load_data(start_date: date, end_date: date) -> GenericResult[dict]:
            return GenericResult(value={"start": start_date, "end": end_date})

        configured = load_data()
        # Strings go in; real date objects should come out the other side.
        out = configured.flow.compute(start_date="2024-01-01", end_date="2024-01-31")

        assert out.value["start"] == date(2024, 1, 1)
        assert out.value["end"] == date(2024, 1, 31)

    def test_flow_compute_validation_error(self):
        """compute() raises when a required runtime input is absent."""

        @Flow.model(context_args=["start_date", "end_date"])
        def load_data(start_date: date, end_date: date) -> GenericResult[dict]:
            return GenericResult(value={})

        configured = load_data()
        with pytest.raises(Exception):  # ValidationError
            configured.flow.compute(start_date=date(2024, 1, 1))  # end_date omitted

    def test_flow_unbound_inputs(self):
        """unbound_inputs reports the context schema (name -> type)."""

        @Flow.model(context_args=["start_date", "end_date"])
        def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]:
            return GenericResult(value={})

        configured = load_data(source="api")
        schema = configured.flow.unbound_inputs

        assert schema["start_date"] == date
        assert schema["end_date"] == date
        # "source" was bound at construction, so it is not part of the schema.
        assert "source" not in schema

    def test_flow_bound_inputs(self):
        """bound_inputs reports configuration values, never context args."""

        @Flow.model(context_args=["start_date", "end_date"])
        def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]:
            return GenericResult(value={})

        configured = load_data(source="api")
        bound = configured.flow.bound_inputs

        assert bound.get("source") == "api"
        assert "start_date" not in bound
        assert "end_date" not in bound

    def test_flow_compute_regular_callable_model(self):
        """Hand-written CallableModels also expose .flow.compute()."""
        offset_model = OffsetModel(offset=10)
        assert offset_model.flow.compute(x=5).value == 15

    def test_flow_unbound_inputs_regular_callable_model(self):
        """Hand-written CallableModels report their context schema as unbound."""
        offset_model = OffsetModel(offset=10)
        assert offset_model.flow.unbound_inputs == {"x": int}

    def test_flow_bound_inputs_regular_callable_model(self):
        """Hand-written CallableModels report configured fields as bound."""
        offset_model = OffsetModel(offset=10)
        assert offset_model.flow.bound_inputs["offset"] == 10
start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs(start_date=date(2024, 1, 1)) + + # Call with just end_date (start_date is bound) + ctx = FlowContext(end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_transform_function(self): + """with_inputs can use transform functions.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + # Lookback: start_date is 7 days before the context's start_date + bound = model.flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + + ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) # 7 days before + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_multiple_transforms(self): + """with_inputs can apply multiple transforms.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=lambda ctx: ctx.end_date + timedelta(days=1), + ) + + ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 30)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_bound_model_has_flow_property(self): + """BoundModel has a .flow property.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x * 2) + + model = compute() + bound = model.flow.with_inputs(x=42) + 
assert isinstance(bound.flow, FlowAPI) + + def test_bound_model_repr_looks_like_with_inputs_call(self): + """BoundModel repr should mirror the API users wrote.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x * 2) + + model = compute() + bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) + + assert repr(bound) == f"{model!r}.flow.with_inputs(x=)" + + def test_with_inputs_regular_callable_model(self): + """Regular CallableModels support .flow.with_inputs().""" + + model = OffsetModel(offset=1) + shifted = model.flow.with_inputs(x=lambda ctx: ctx.x * 2) + + result = shifted(NumberContext(x=5)) + assert result.value == 11 + + +class TestTypedDictValidation: + """Tests for TypedDict-based context validation.""" + + def test_schema_stored_on_model(self): + """Model stores _context_schema for validation.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + assert hasattr(model, "_context_schema") + assert model._context_schema == {"start_date": date, "end_date": date} + + def test_validator_created_lazily(self): + """TypeAdapter validator is created lazily.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x) + + model = compute() + # Initially None + assert model.__class__._cached_context_validator is None + + # After getting validator, it's cached + validator = model._get_context_validator() + assert validator is not None + assert model.__class__._cached_context_validator is validator + + def test_explicit_context_type_override(self): + """context_type can opt into an existing ContextBase subclass.""" + + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + 
model = load_data() + assert model.context_type == DateRangeContext + + +class TestPicklingSupport: + """Tests for pickling support (important for Ray). + + Note: Regular pickle cannot pickle locally-defined classes (functions decorated + inside test methods). cloudpickle CAN handle this, which is why Ray uses it. + All tests here use cloudpickle to match Ray's behavior. + """ + + def test_model_cloudpickle_roundtrip(self): + """Model works with cloudpickle (for Ray).""" + + @Flow.model(context_args=["x", "y"]) + def compute(x: int, y: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=(x + y) * multiplier) + + model = compute(multiplier=3) + + # cloudpickle roundtrip (what Ray uses) + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Should work after unpickling + result = unpickled.flow.compute(x=1, y=2) + assert result.value == 9 # (1 + 2) * 3 + + def test_model_cloudpickle_simple(self): + """Simple model cloudpickle test.""" + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + result = unpickled.flow.compute(value=21) + assert result.value == 42 + + def test_validator_recreated_after_cloudpickle(self): + """TypeAdapter validator is recreated after cloudpickling.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x) + + model = compute() + # Warm up the validator cache + _ = model._get_context_validator() + assert model.__class__._cached_context_validator is not None + + # cloudpickle and unpickle + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Validator should still work (may be lazily recreated) + result = unpickled.flow.compute(x=42) + assert result.value == 42 + + def test_flow_context_pickle_standard(self): + """FlowContext 
works with standard pickle.""" + ctx = FlowContext(x=1, y=2, z="test") + + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + + assert unpickled.x == 1 + assert unpickled.y == 2 + assert unpickled.z == "test" + + +class TestIntegrationWithExistingContextTypes: + """Tests for integration with existing ContextBase subclasses.""" + + def test_explicit_context_still_works(self): + """Explicit context parameter mode still works.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date, "source": source}) + + model = load_data(source="api") + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["source"] == "api" + + def test_flow_context_coerces_to_date_range(self): + """FlowContext can be used with models expecting DateRangeContext.""" + + @Flow.model + def load_data(context: DateRangeContext) -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data() + # Use FlowContext - should coerce to DateRangeContext + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_flow_api_with_explicit_context(self): + """FlowAPI.compute works with explicit context mode.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data(source="api") + result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + +class 
TestLazy: + """Tests for Lazy (deferred execution with context overrides).""" + + def test_lazy_basic(self): + """Lazy wraps a model for deferred execution.""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def compute(value: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + model = compute(multiplier=3) + lazy = Lazy(model) + + assert lazy.model is model + + def test_lazy_call_with_static_override(self): + """Lazy.__call__ with static override values.""" + from ccflow import Lazy + + @Flow.model(context_args=["x", "y"]) + def add(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + + model = add() + lazy_fn = Lazy(model)(y=100) # Override y to 100 + + ctx = FlowContext(x=5, y=10) # Original y=10 + result = lazy_fn(ctx) + assert result.value == 105 # x=5 + y=100 (overridden) + + def test_lazy_call_with_callable_override(self): + """Lazy.__call__ with callable override (computed at runtime).""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + # Override value to be original value + 10 + lazy_fn = Lazy(model)(value=lambda ctx: ctx.value + 10) + + ctx = FlowContext(value=5) + result = lazy_fn(ctx) + assert result.value == 30 # (5 + 10) * 2 = 30 + + def test_lazy_with_date_transforms(self): + """Lazy works with date transforms.""" + from ccflow import Lazy + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + + # Use Lazy to create a transform that shifts dates + lazy_fn = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7), end_date=lambda ctx: ctx.end_date) + + ctx = FlowContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) + result = lazy_fn(ctx) 
"""Tests for Flow.model decorator."""

from datetime import date, timedelta
from unittest import TestCase

from ray.cloudpickle import dumps as rcpdumps, loads as rcploads

from ccflow import (
    BaseModel,
    CallableModel,
    ContextBase,
    DateRangeContext,
    Flow,
    FlowContext,
    FlowOptionsOverride,
    GenericResult,
    Lazy,
    ModelRegistry,
    ResultBase,
)
from ccflow.evaluators.common import MemoryCacheEvaluator


class SimpleContext(ContextBase):
    """Simple context for testing."""

    value: int


class ExtendedContext(ContextBase):
    """Extended context with multiple fields."""

    x: int
    y: str = "default"


class MyResult(ResultBase):
    """Custom result type for testing."""

    data: str


# =============================================================================
# Basic Flow.model Tests
# =============================================================================


class TestFlowModelBasic(TestCase):
    """Basic Flow.model functionality tests."""

    def test_simple_model_explicit_context(self):
        """Test Flow.model with explicit context parameter."""

        @Flow.model
        def simple_loader(context: SimpleContext, multiplier: int) -> GenericResult[int]:
            return GenericResult(value=context.value * multiplier)

        # Create model instance
        loader = simple_loader(multiplier=3)

        # Should be a CallableModel
        self.assertIsInstance(loader, CallableModel)

        # Execute
        ctx = SimpleContext(value=10)
        result = loader(ctx)

        self.assertIsInstance(result, GenericResult)
        self.assertEqual(result.value, 30)

    def test_model_with_default_params(self):
        """Test Flow.model with default parameter values."""

        @Flow.model
        def loader_with_defaults(context: SimpleContext, multiplier: int = 2, prefix: str = "result") -> GenericResult[str]:
            return GenericResult(value=f"{prefix}:{context.value * multiplier}")

        # Create with defaults
        loader = loader_with_defaults()
        result = loader(SimpleContext(value=5))
        self.assertEqual(result.value, "result:10")

        # Create with custom values
        loader2 = loader_with_defaults(multiplier=3, prefix="custom")
        result2 = loader2(SimpleContext(value=5))
        self.assertEqual(result2.value, "custom:15")

    def test_model_context_type_property(self):
        """Test that generated model has correct context_type."""

        @Flow.model
        def typed_model(context: ExtendedContext, factor: int) -> GenericResult[int]:
            return GenericResult(value=context.x * factor)

        model = typed_model(factor=2)
        self.assertEqual(model.context_type, ExtendedContext)

    def test_model_result_type_property(self):
        """Test that generated model has correct result_type."""

        @Flow.model
        def custom_result_model(context: SimpleContext) -> MyResult:
            return MyResult(data=f"value={context.value}")

        model = custom_result_model()
        self.assertEqual(model.result_type, MyResult)

    def test_model_with_no_extra_params(self):
        """Test Flow.model with only context parameter."""

        @Flow.model
        def identity_model(context: SimpleContext) -> GenericResult[int]:
            return GenericResult(value=context.value)

        model = identity_model()
        result = model(SimpleContext(value=42))
        self.assertEqual(result.value, 42)

    def test_model_with_flow_options(self):
        """Test Flow.model with Flow.call options."""

        @Flow.model(cacheable=True, validate_result=True)
        def cached_model(context: SimpleContext, value: int) -> GenericResult[int]:
            return GenericResult(value=value + context.value)

        model = cached_model(value=10)
        result = model(SimpleContext(value=5))
        self.assertEqual(result.value, 15)

    def test_model_with_underscore_context(self):
        """Test Flow.model with '_' as context parameter (unused context convention)."""

        @Flow.model
        def loader(context: SimpleContext, base: int) -> GenericResult[int]:
            return GenericResult(value=context.value + base)

        @Flow.model
        def consumer(_: SimpleContext, data: int) -> GenericResult[int]:
            # Context not used directly, just passed to dependency
            return GenericResult(value=data * 2)

        load = loader(base=100)
        consume = consumer(data=load)

        result = consume(SimpleContext(value=10))
        # loader: 10 + 100 = 110, consumer: 110 * 2 = 220
        self.assertEqual(result.value, 220)

        # Verify context_type is still correct
        self.assertEqual(consume.context_type, SimpleContext)
result = model(SimpleContext(value=42)) + self.assertEqual(result.value, 42) + + def test_model_with_flow_options(self): + """Test Flow.model with Flow.call options.""" + + @Flow.model(cacheable=True, validate_result=True) + def cached_model(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=value + context.value) + + model = cached_model(value=10) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 15) + + def test_model_with_underscore_context(self): + """Test Flow.model with '_' as context parameter (unused context convention).""" + + @Flow.model + def loader(context: SimpleContext, base: int) -> GenericResult[int]: + return GenericResult(value=context.value + base) + + @Flow.model + def consumer(_: SimpleContext, data: int) -> GenericResult[int]: + # Context not used directly, just passed to dependency + return GenericResult(value=data * 2) + + load = loader(base=100) + consume = consumer(data=load) + + result = consume(SimpleContext(value=10)) + # loader: 10 + 100 = 110, consumer: 110 * 2 = 220 + self.assertEqual(result.value, 220) + + # Verify context_type is still correct + self.assertEqual(consume.context_type, SimpleContext) + + +# ============================================================================= +# context_args Mode Tests +# ============================================================================= + + +class TestFlowModelContextArgs(TestCase): + """Tests for Flow.model with context_args (unpacked context).""" + + def test_context_args_basic(self): + """Test basic context_args usage.""" + + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) + def date_range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + + loader = date_range_loader(source="db") + + # Explicit context_type keeps compatibility with existing contexts. 
class TestFlowModelDependencies(TestCase):
    """Tests for Flow.model with upstream CallableModel inputs."""

    def test_simple_dependency(self):
        """Test passing an upstream model as a normal parameter."""

        @Flow.model
        def loader(context: SimpleContext, value: int) -> GenericResult[int]:
            return GenericResult(value=value + context.value)

        @Flow.model
        def consumer(
            context: SimpleContext,
            data: int,
            multiplier: int = 1,
        ) -> GenericResult[int]:
            return GenericResult(value=data * multiplier)

        # Create pipeline
        load = loader(value=10)
        consume = consumer(data=load, multiplier=2)

        ctx = SimpleContext(value=5)
        result = consume(ctx)

        # loader returns 10 + 5 = 15, consumer multiplies by 2 = 30
        self.assertEqual(result.value, 30)

    def test_dependency_with_direct_value(self):
        """Test that dependency-shaped parameters can also take direct values."""

        @Flow.model
        def consumer(
            context: SimpleContext,
            data: int,
        ) -> GenericResult[int]:
            return GenericResult(value=data + context.value)

        consume = consumer(data=100)

        result = consume(SimpleContext(value=5))
        self.assertEqual(result.value, 105)

    def test_deps_method_generation(self):
        """Test that __deps__ method is correctly generated."""

        @Flow.model
        def loader(context: SimpleContext) -> GenericResult[int]:
            return GenericResult(value=context.value)

        @Flow.model
        def consumer(
            context: SimpleContext,
            data: int,
        ) -> GenericResult[int]:
            return GenericResult(value=data)

        load = loader()
        consume = consumer(data=load)

        ctx = SimpleContext(value=10)
        deps = consume.__deps__(ctx)

        # Should have one dependency
        # Each entry pairs the upstream model with the list of contexts it
        # will be invoked with.
        self.assertEqual(len(deps), 1)
        self.assertEqual(deps[0][0], load)
        self.assertEqual(deps[0][1], [ctx])

    def test_no_deps_when_direct_value(self):
        """Test that __deps__ returns empty when direct values used."""

        @Flow.model
        def consumer(
            context: SimpleContext,
            data: int,
        ) -> GenericResult[int]:
            return GenericResult(value=data)

        consume = consumer(data=100)

        deps = consume.__deps__(SimpleContext(value=10))
        self.assertEqual(len(deps), 0)
test_transformed_dependency_with_inputs(self): + """Test dependency context transformation via .flow.with_inputs().""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer(context: SimpleContext, data: int) -> GenericResult[int]: + return GenericResult(value=data * 2) + + load = loader().flow.with_inputs(value=lambda ctx: ctx.value + 10) + consume = consumer(data=load) + + result = consume(SimpleContext(value=5)) + self.assertEqual(result.value, 30) + + def test_with_inputs_changes_dependency_context_in_deps(self): + """Test that BoundModel contributes transformed dependency contexts.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer(context: SimpleContext, data: int) -> GenericResult[int]: + return GenericResult(value=data) + + load = loader().flow.with_inputs(value=lambda ctx: ctx.value * 3) + consume = consumer(data=load) + + deps = consume.__deps__(SimpleContext(value=7)) + self.assertEqual(len(deps), 1) + transformed_ctx = deps[0][1][0] + self.assertEqual(transformed_ctx.value, 21) + + def test_date_range_transform_with_inputs(self): + """Test date-range lookback wiring via .flow.with_inputs().""" + + @Flow.model(context_args=["start_date", "end_date"]) + def range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date}") + + @Flow.model(context_args=["start_date", "end_date"]) + def range_processor( + start_date: date, + end_date: date, + data: str, + ) -> GenericResult[str]: + return GenericResult(value=f"processed:{data}") + + loader = range_loader(source="db").flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=1)) + processor = range_processor(data=loader) + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) 
+ self.assertEqual(result.value, "processed:db:2024-01-09") + + +# ============================================================================= +# Pipeline Tests +# ============================================================================= + + +class TestFlowModelPipeline(TestCase): + """Tests for multi-stage pipelines with Flow.model.""" + + def test_three_stage_pipeline(self): + """Test a three-stage computation pipeline.""" + + @Flow.model + def stage1(context: SimpleContext, base: int) -> GenericResult[int]: + return GenericResult(value=context.value + base) + + @Flow.model + def stage2( + context: SimpleContext, + input_data: int, + multiplier: int, + ) -> GenericResult[int]: + return GenericResult(value=input_data * multiplier) + + @Flow.model + def stage3( + context: SimpleContext, + input_data: int, + offset: int = 0, + ) -> GenericResult[int]: + return GenericResult(value=input_data + offset) + + # Build pipeline + s1 = stage1(base=100) + s2 = stage2(input_data=s1, multiplier=2) + s3 = stage3(input_data=s2, offset=50) + + ctx = SimpleContext(value=10) + result = s3(ctx) + + # s1: 10 + 100 = 110 + # s2: 110 * 2 = 220 + # s3: 220 + 50 = 270 + self.assertEqual(result.value, 270) + + def test_diamond_dependency_pattern(self): + """Test diamond-shaped dependency pattern.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def branch_a( + context: SimpleContext, + data: int, + ) -> GenericResult[int]: + return GenericResult(value=data * 2) + + @Flow.model + def branch_b( + context: SimpleContext, + data: int, + ) -> GenericResult[int]: + return GenericResult(value=data + 100) + + @Flow.model + def merger( + context: SimpleContext, + a: int, + b: int, + ) -> GenericResult[int]: + return GenericResult(value=a + b) + + src = source() + a = branch_a(data=src) + b = branch_b(data=src) + merge = merger(a=a, b=b) + + ctx = SimpleContext(value=10) + result = merge(ctx) + + # 
class TestFlowModelIntegration(TestCase):
    """Integration tests for Flow.model with ccflow infrastructure."""

    def test_registry_integration(self):
        """Test that Flow.model models work with ModelRegistry."""

        @Flow.model
        def registrable_model(context: SimpleContext, value: int) -> GenericResult[int]:
            return GenericResult(value=context.value + value)

        model = registrable_model(value=100)

        # NOTE: clears the process-global root registry; this test must not
        # run concurrently with others that rely on registry state.
        registry = ModelRegistry.root().clear()
        registry.add("test_model", model)

        retrieved = registry["test_model"]
        self.assertEqual(retrieved, model)

        result = retrieved(SimpleContext(value=10))
        self.assertEqual(result.value, 110)

    def test_serialization_dump(self):
        """Test that generated models can be serialized."""

        @Flow.model
        def serializable_model(context: SimpleContext, value: int = 42) -> GenericResult[int]:
            return GenericResult(value=value)

        model = serializable_model(value=100)
        dumped = model.model_dump(mode="python")

        self.assertIn("value", dumped)
        self.assertEqual(dumped["value"], 100)
        # "type_" is the type tag that makes the dump round-trippable.
        self.assertIn("type_", dumped)

    def test_serialization_roundtrip_preserves_bound_inputs(self):
        """Round-tripping should preserve which inputs were bound at construction."""

        @Flow.model
        def add(x: int, y: int) -> int:
            return x + y

        model = add(x=10)
        dumped = model.model_dump(mode="python")
        restored = type(model).model_validate(dumped)

        # Only the bound input appears in the dump; y stays a runtime input.
        self.assertEqual(dumped["x"], 10)
        self.assertNotIn("y", dumped)
        self.assertEqual(restored.flow.bound_inputs, {"x": 10})
        self.assertEqual(restored.flow.unbound_inputs, {"y": int})
        self.assertEqual(restored.flow.compute(y=5).value, 15)

    def test_serialization_roundtrip_preserves_defaults_and_deferred_inputs(self):
        """Default-valued params should serialize normally without binding runtime-only inputs."""

        @Flow.model
        def load(start_date: str, source: str = "warehouse") -> str:
            return f"{source}:{start_date}"

        model = load()
        dumped = model.model_dump(mode="python")
        restored = type(model).model_validate(dumped)

        self.assertEqual(dumped["source"], "warehouse")
        self.assertNotIn("start_date", dumped)
        self.assertEqual(restored.flow.bound_inputs, {"source": "warehouse"})
        self.assertEqual(restored.flow.unbound_inputs, {"start_date": str})
        self.assertEqual(restored.flow.compute(start_date="2024-01-01").value, "warehouse:2024-01-01")

    def test_pickle_roundtrip(self):
        """Test cloudpickle serialization of generated models."""

        @Flow.model
        def pickleable_model(context: SimpleContext, factor: int) -> GenericResult[int]:
            return GenericResult(value=context.value * factor)

        model = pickleable_model(factor=3)

        # Cloudpickle roundtrip (standard pickle won't work for local classes)
        pickled = rcpdumps(model, protocol=5)
        restored = rcploads(pickled)

        result = restored(SimpleContext(value=10))
        self.assertEqual(result.value, 30)

    def test_mix_with_manual_callable_model(self):
        """Test mixing Flow.model with manually defined CallableModel."""

        class ManualModel(CallableModel):
            offset: int

            @Flow.call
            def __call__(self, context: SimpleContext) -> GenericResult[int]:
                return GenericResult(value=context.value + self.offset)

        @Flow.model
        def generated_consumer(
            context: SimpleContext,
            data: int,
            multiplier: int,
        ) -> GenericResult[int]:
            return GenericResult(value=data * multiplier)

        manual = ManualModel(offset=50)
        generated = generated_consumer(data=manual, multiplier=2)

        result = generated(SimpleContext(value=10))
        # manual: 10 + 50 = 60
        # generated: 60 * 2 = 120
        self.assertEqual(result.value, 120)
# =============================================================================
# Error Case Tests
# =============================================================================


class TestFlowModelErrors(TestCase):
    """Error case tests for Flow.model."""

    def test_missing_return_type(self):
        """Test error when return type annotation is missing."""
        with self.assertRaises(TypeError) as cm:

            @Flow.model
            def no_return(context: SimpleContext):
                return GenericResult(value=1)

        self.assertIn("return type annotation", str(cm.exception))

    def test_auto_wrap_plain_return_type(self):
        """Test that non-ResultBase return types are auto-wrapped in GenericResult."""

        @Flow.model
        def plain_return(context: SimpleContext) -> int:
            return context.value * 2

        model = plain_return()
        result = model(SimpleContext(value=5))
        self.assertIsInstance(result, GenericResult)
        self.assertEqual(result.value, 10)

    def test_auto_wrap_unwrap_as_dependency(self):
        """Test that auto-wrapped model used as dep delivers unwrapped value downstream.

        Auto-wrapped models have result_type=GenericResult (unparameterized).
        When used as an auto-detected dep, the framework resolves
        the GenericResult to its inner value for the downstream function.
        """

        @Flow.model
        def plain_source(context: SimpleContext) -> int:
            return context.value * 3

        @Flow.model
        def consumer(
            context: SimpleContext,
            data: GenericResult[int],  # Auto-detected dep
        ) -> GenericResult[int]:
            # data is auto-unwrapped to the int value by the framework
            return GenericResult(value=data + 1)

        src = plain_source()
        model = consumer(data=src)
        result = model(SimpleContext(value=10))
        # plain_source: 10 * 3 = 30, auto-wrapped to GenericResult(value=30)
        # resolve_callable_model unwraps GenericResult -> 30
        # consumer: 30 + 1 = 31
        self.assertEqual(result.value, 31)

    def test_auto_wrap_result_type_property(self):
        """Test that auto-wrapped model has GenericResult as result_type."""

        @Flow.model
        def plain_return(context: SimpleContext) -> int:
            return context.value

        model = plain_return()
        self.assertEqual(model.result_type, GenericResult)

    def test_dynamic_deferred_mode(self):
        """Test dynamic deferred mode where what you provide at construction = bound."""
        from ccflow import FlowContext

        @Flow.model
        def dynamic_model(value: int, multiplier: int) -> GenericResult[int]:
            return GenericResult(value=value * multiplier)

        # Provide 'multiplier' at construction -> it's bound
        # Don't provide 'value' -> comes from context
        model = dynamic_model(multiplier=3)

        # Check bound vs unbound
        self.assertEqual(model.flow.bound_inputs, {"multiplier": 3})
        self.assertEqual(model.flow.unbound_inputs, {"value": int})

        # Call with context providing 'value'
        ctx = FlowContext(value=10)
        result = model(ctx)
        self.assertEqual(result.value, 30)  # 10 * 3

    def test_dynamic_deferred_mode_missing_runtime_inputs_is_clear(self):
        """Missing deferred inputs should fail at the framework boundary."""

        @Flow.model
        def dynamic_model(value: int, multiplier: int) -> int:
            return value * multiplier

        model = dynamic_model()

        with self.assertRaises(TypeError) as cm:
            model.flow.compute()

        # The message must name the model and every missing input.
        self.assertIn("Missing runtime input(s) for dynamic_model: multiplier, value", str(cm.exception))

    def test_all_defaults_is_valid(self):
        """All-default functions should treat those defaults as bound config."""
        from ccflow import FlowContext

        @Flow.model
        def all_defaults(value: int = 1, other: str = "x") -> GenericResult[str]:
            return GenericResult(value=f"{value}-{other}")

        model = all_defaults()

        self.assertEqual(model.flow.bound_inputs, {"value": 1, "other": "x"})
        self.assertEqual(model.flow.unbound_inputs, {})

        # Bound defaults win over context-provided values (result is "1-x",
        # not "5-y"), per the assertion below.
        ctx = FlowContext(value=5, other="y")
        result = model(ctx)
        self.assertEqual(result.value, "1-x")

    def test_invalid_context_arg(self):
        """Test error when context_args refers to non-existent parameter."""
        with self.assertRaises(ValueError) as cm:

            @Flow.model(context_args=["nonexistent"])
            def bad_context_args(x: int) -> GenericResult[int]:
                return GenericResult(value=x)

        self.assertIn("nonexistent", str(cm.exception))

    def test_context_arg_without_annotation(self):
        """Test error when context_arg parameter lacks type annotation."""
        with self.assertRaises(ValueError) as cm:

            @Flow.model(context_args=["x"])
            def untyped_context_arg(x) -> GenericResult[int]:
                return GenericResult(value=x)

        self.assertIn("type annotation", str(cm.exception))

    def test_context_type_requires_context_args_mode(self):
        """context_type is only valid alongside context_args."""
        with self.assertRaises(TypeError) as cm:

            @Flow.model(context_type=DateRangeContext)
            def dynamic_model(value: int) -> GenericResult[int]:
                return GenericResult(value=value)

        self.assertIn("context_args", str(cm.exception))

    def test_context_type_must_cover_context_args(self):
        """context_type must expose all named context_args fields."""

        class StartOnlyContext(ContextBase):
            start_date: date

        with self.assertRaises(TypeError) as cm:

            @Flow.model(context_args=["start_date", "end_date"], context_type=StartOnlyContext)
            def load_data(start_date: date, end_date: date) -> GenericResult[dict]:
                return GenericResult(value={})

        # The missing field name must appear in the error.
        self.assertIn("end_date", str(cm.exception))
load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + self.assertIn("end_date", str(cm.exception)) + + +# ============================================================================= +# Validation Tests +# ============================================================================= + + +class TestFlowModelValidation(TestCase): + """Tests for Flow.model validation behavior.""" + + def test_config_validation_rejects_bad_type(self): + """Test that config validator rejects wrong types at construction.""" + + @Flow.model + def typed_config(context: SimpleContext, n_estimators: int = 10) -> GenericResult[int]: + return GenericResult(value=n_estimators) + + with self.assertRaises(TypeError) as cm: + typed_config(n_estimators="banana") + + self.assertIn("n_estimators", str(cm.exception)) + + def test_config_validation_accepts_callable_model(self): + """Test that config validator allows CallableModel values for any field.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: + return GenericResult(value=data) + + # Passing a CallableModel for an int field should not raise + src = source() + model = consumer(data=src) + self.assertIsNotNone(model) + + def test_config_validation_accepts_correct_types(self): + """Test that config validator accepts correct types.""" + + @Flow.model + def typed_config(context: SimpleContext, n: int = 10, name: str = "x") -> GenericResult[str]: + return GenericResult(value=f"{name}:{n}") + + # Should not raise + model = typed_config(n=42, name="test") + result = model(SimpleContext(value=1)) + self.assertEqual(result.value, "test:42") + + def test_config_validation_rejects_registry_alias_for_incompatible_type(self): + """Registry aliases should not silently bypass scalar type validation.""" + + class DummyConfig(BaseModel): + x: int = 
1 + + registry = ModelRegistry.root() + registry.clear() + try: + registry.add("dummy_config", DummyConfig()) + + @Flow.model + def typed_config(context: SimpleContext, n: int = 10) -> GenericResult[int]: + return GenericResult(value=n) + + with self.assertRaises(TypeError) as cm: + typed_config(n="dummy_config") + + self.assertIn("n", str(cm.exception)) + finally: + registry.clear() + + def test_config_validation_accepts_registry_alias_for_callable_dependency(self): + """Registry aliases still work for CallableModel dependencies.""" + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + @Flow.model + def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: + return GenericResult(value=data + 1) + + registry.add("source_model", source()) + model = consumer(data="source_model") + + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 11) + finally: + registry.clear() + + def test_context_type_annotation_mismatch_raises(self): + """context_type validation should reject incompatible field annotations.""" + + class StringIdContext(ContextBase): + item_id: str + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["item_id"], context_type=StringIdContext) + def load(item_id: int) -> int: + return item_id + + self.assertIn("item_id", str(cm.exception)) + self.assertIn("int", str(cm.exception)) + self.assertIn("str", str(cm.exception)) + + def test_model_validate_rejects_bad_scalar_type(self): + """model_validate should reject wrong scalar types, not silently accept them.""" + + @Flow.model + def source(context: SimpleContext, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + cls = type(source(x=1)) + with self.assertRaises(TypeError) as cm: + cls.model_validate({"x": "abc"}) + + self.assertIn("x", str(cm.exception)) + + def 
test_model_validate_accepts_correct_type(self): + """model_validate should accept correct types.""" + + @Flow.model + def source(context: SimpleContext, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + cls = type(source(x=1)) + restored = cls.model_validate({"x": 42}) + self.assertEqual(restored(SimpleContext(value=0)).value, 42) + + def test_model_validate_rejects_bad_registry_alias(self): + """Typoed registry aliases should not silently pass through model_validate.""" + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def consumer(context: SimpleContext, n: int = 10) -> GenericResult[int]: + return GenericResult(value=n) + + cls = type(consumer(n=1)) + # "not_in_registry" is not a valid int and not a valid registry key + with self.assertRaises(TypeError) as cm: + cls.model_validate({"n": "not_in_registry"}) + self.assertIn("n", str(cm.exception)) + finally: + registry.clear() + + def test_context_type_compatible_annotations_accepted(self): + """context_type validation should accept matching or subclass annotations.""" + + # Exact match should work + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) + def load_exact(start_date: date, end_date: date) -> str: + return f"{start_date}" + + self.assertIsNotNone(load_exact) + + +# ============================================================================= +# BoundModel Tests +# ============================================================================= + + +class TestBoundModel(TestCase): + """Tests for BoundModel and BoundModel.flow.""" + + def test_bound_model_is_callable_model(self): + """BoundModel should be a proper CallableModel subclass.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + bound = source().flow.with_inputs(x=lambda ctx: ctx.x * 2) + self.assertIsInstance(bound, CallableModel) + + def test_bound_model_flow_compute(self): + """Test that bound.flow.compute() honors transforms.""" + + @Flow.model + def 
my_model(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + + model = my_model(x=10) + + # Create bound model with y transform + bound = model.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 2) + + # flow.compute() should go through BoundModel, applying transform + result = bound.flow.compute(y=5) + # y transform: 5 * 2 = 10, x is bound to 10 + # model: 10 + 10 = 20 + self.assertEqual(result.value, 20) + + def test_bound_model_flow_compute_static_transform(self): + """Test BoundModel.flow.compute() with static value transform.""" + + @Flow.model + def my_model(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x * y) + + model = my_model(x=7) + bound = model.flow.with_inputs(y=3) + + result = bound.flow.compute(y=999) # y should be overridden by transform + # y is statically bound to 3, x=7 + # 7 * 3 = 21 + self.assertEqual(result.value, 21) + + def test_bound_model_dump_validate_roundtrip_static(self): + """Static transforms survive model_dump → model_validate roundtrip.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + dump = bound.model_dump(mode="python") + restored = type(bound).model_validate(dump) + + ctx = SimpleContext(value=1) + self.assertEqual(bound(ctx).value, 420) + self.assertEqual(restored(ctx).value, 420) + + def test_bound_model_validate_same_payload_twice(self): + """Validating the same serialized BoundModel payload twice should work both times.""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + dump = bound.model_dump(mode="python") + + r1 = BoundModel.model_validate(dump) + r2 = BoundModel.model_validate(dump) + + ctx = SimpleContext(value=1) + self.assertEqual(r1(ctx).value, 420) + 
self.assertEqual(r2(ctx).value, 420) + + def test_bound_model_failed_validate_does_not_poison_next_construction(self): + """A failed model_validate must not leak static transforms to subsequent constructions.""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + base = source() + + # Attempt a model_validate that will fail (invalid model field) + try: + BoundModel.model_validate( + { + "model": "not-a-real-model", + "_static_transforms": {"value": 42}, + "_input_transforms_token": {"value": "42"}, + } + ) + except Exception: + pass # Expected to fail + + # Now construct a fresh BoundModel normally — must NOT inherit stale transforms + clean = BoundModel(model=base, input_transforms={}) + ctx = SimpleContext(value=1) + self.assertEqual(clean(ctx).value, 10) # 1 * 10, no transform applied + + def test_bound_model_cloudpickle_with_lambda_transform(self): + """BoundModel with lambda transforms should survive cloudpickle round-trip.""" + + @Flow.model + def my_model(x: int, y: int) -> int: + return x + y + + bound = my_model(x=10).flow.with_inputs(y=lambda ctx: ctx.y * 2) + restored = rcploads(rcpdumps(bound, protocol=5)) + + self.assertEqual(restored.flow.compute(y=6).value, 22) + + def test_bound_model_as_dependency(self): + """Test that BoundModel can be passed as a dependency to another model.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: GenericResult[int]) -> GenericResult[int]: + return GenericResult(value=data + 1) + + src = source() + bound_src = src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + + # Pass BoundModel as a dependency + model = consumer(data=bound_src) + result = model.flow.compute(x=5) + # x transform: 5 * 2 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result.value, 101) + + def 
test_flow_compute_with_upstream_callable_model_dependency(self): + """flow.compute() should resolve upstream generated-model dependencies.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: GenericResult[int], offset: int = 1) -> int: + return data + offset + + model = consumer(data=source(), offset=3) + self.assertEqual(model.flow.compute(x=5).value, 53) + + def test_bound_model_chained_with_inputs(self): + """Test that chaining with_inputs merges transforms correctly.""" + + @Flow.model + def my_model(x: int, y: int, z: int) -> int: + return x + y + z + + model = my_model() + bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + bound2 = bound1.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 3) + + # Both transforms should be active + result = bound2.flow.compute(x=5, y=10, z=1) + # x transform: 5 * 2 = 10 + # y transform: 10 * 3 = 30 + # z from context: 1 + # 10 + 30 + 1 = 41 + self.assertEqual(result.value, 41) + + def test_bound_model_chained_with_inputs_override(self): + """Test that chaining with_inputs allows overriding transforms.""" + + @Flow.model + def my_model(x: int) -> int: + return x + + model = my_model() + bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + bound2 = bound1.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 10) + + # Second transform should override the first for 'x' + result = bound2.flow.compute(x=5) + self.assertEqual(result.value, 50) # 5 * 10, not 5 * 2 + + def test_bound_model_with_default_args(self): + """with_inputs works when the model has parameters with default values.""" + + @Flow.model + def load(start_date: str, end_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}-{end_date}" + + # Bind source at construction, leave dates for context + model = load(source="prod_db") + + # with_inputs transforms a context param; default-valued 'source' stays bound + 
lookback = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) + + result = lookback.flow.compute(start_date="2024-01-01", end_date="2024-06-30") + self.assertEqual(result.value, "prod_db:shifted_2024-01-01-2024-06-30") + + def test_bound_model_with_default_arg_uses_default(self): + """with_inputs should preserve omitted Python defaults as bound config.""" + + @Flow.model + def load(start_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}" + + model = load() + + bound = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) + + self.assertEqual(model.flow.bound_inputs, {"source": "warehouse"}) + self.assertEqual(model.flow.unbound_inputs, {"start_date": str}) + + result = bound.flow.compute(start_date="2024-01-01") + self.assertEqual(result.value, "warehouse:shifted_2024-01-01") + + def test_bound_model_default_arg_as_dependency(self): + """BoundModel with default args works correctly as a dependency.""" + + @Flow.model + def source(x: int, multiplier: int = 2) -> int: + return x * multiplier + + @Flow.model + def consumer(data: int) -> int: + return data + 1 + + src = source(multiplier=5) + bound_src = src.flow.with_inputs(x=lambda ctx: ctx.x * 10) + model = consumer(data=bound_src) + + result = model.flow.compute(x=3) + # x transform: 3 * 10 = 30 + # source: 30 * 5 (multiplier) = 150 + # consumer: 150 + 1 = 151 + self.assertEqual(result.value, 151) + + def test_bound_model_as_lazy_dependency(self): + """Test that BoundModel works as a Lazy dependency.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 3) + + @Flow.model + def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: + if data > 100: + return GenericResult(value=data) + return GenericResult(value=slow()) + + src = source() + bound_src = src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) + 10) + + # Use BoundModel as lazy dependency + model = consumer(data=5, 
slow=bound_src) + result = model.flow.compute(x=7) + # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 + self.assertEqual(result.value, 51) + + def test_differently_transformed_bound_models_have_distinct_cache_keys(self): + """Two BoundModels with different transforms must not collide under caching.""" + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + base = source() + b1 = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + b2 = base.flow.with_inputs(value=lambda ctx: ctx.value + 2) + evaluator = MemoryCacheEvaluator() + ctx = SimpleContext(value=5) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + r1 = b1(ctx) + r2 = b2(ctx) + + # b1 transforms value to 6, source: 6*10=60 + # b2 transforms value to 7, source: 7*10=70 + self.assertEqual(r1.value, 60) + self.assertEqual(r2.value, 70) + # Source called twice (once per distinct transformed context) + self.assertEqual(call_counts["source"], 2) + + def test_bound_and_unbound_models_share_memory_cache(self): + """Shifted and unshifted models should share one evaluator cache. + + They should not share the same cache key when the effective contexts + differ, but repeated evaluations of either model should still hit the + same underlying MemoryCacheEvaluator instance rather than re-executing. 
+ """ + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + base = source() + shifted = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + evaluator = MemoryCacheEvaluator() + ctx = SimpleContext(value=5) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + self.assertEqual(base(ctx).value, 50) + self.assertEqual(shifted(ctx).value, 60) + self.assertEqual(base(ctx).value, 50) + self.assertEqual(shifted(ctx).value, 60) + + # One execution for the unshifted context and one for the shifted context. + self.assertEqual(call_counts["source"], 2) + # Cache has 3 entries: base(ctx), BoundModel(ctx), and base(shifted_ctx). + # BoundModel is a proper CallableModel now, so it gets its own cache entry. + self.assertEqual(len(evaluator.cache), 3) + + def test_transform_error_propagates(self): + """A buggy transform should raise, not silently fall back to FlowContext.""" + + @Flow.model + def load(context: DateRangeContext, source: str = "db") -> str: + return f"{source}:{context.start_date}" + + model = load() + # Transform has a typo — ctx.sart_date instead of ctx.start_date + bound = model.flow.with_inputs(start_date=lambda ctx: ctx.sart_date) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + with self.assertRaises(AttributeError): + bound(ctx) + + def test_transform_validation_error_propagates(self): + """If transforms produce invalid context data, the error should surface.""" + from pydantic import ValidationError + + @Flow.model + def load(context: DateRangeContext, source: str = "db") -> str: + return f"{source}:{context.start_date}" + + model = load() + # Transform returns a string where a date is expected + bound = model.flow.with_inputs(start_date="not-a-date") + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + # Pydantic 
validation should raise, not silently fall back to FlowContext + with self.assertRaises(ValidationError): + bound(ctx) + + +class TestFlowModelPipe(TestCase): + """Tests for the ``.pipe(..., param=...)`` convenience API.""" + + def test_pipe_infers_single_required_parameter(self): + """pipe() should infer the only required downstream parameter.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: int, offset: int = 1) -> int: + return data + offset + + pipeline = source().pipe(consumer, offset=3) + self.assertEqual(pipeline.flow.compute(x=5).value, 53) + + def test_pipe_infers_single_defaulted_parameter(self): + """pipe() should fall back to a single defaulted downstream parameter.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int = 0) -> int: + return data + 1 + + pipeline = source().pipe(consumer) + self.assertEqual(pipeline.flow.compute(x=5).value, 51) + + def test_pipe_param_disambiguates_multiple_parameters(self): + """param= should identify the downstream argument to bind.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def combine(left: int, right: int) -> int: + return left + right + + pipeline = source().pipe(combine, param="right", left=7) + self.assertEqual(pipeline.flow.compute(x=5).value, 57) + + def test_pipe_rejects_ambiguous_downstream_stage(self): + """pipe() should require param= when multiple targets are available.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def combine(left: int, right: int) -> int: + return left + right + + with self.assertRaisesRegex( + TypeError, + r"pipe\(\) could not infer a target parameter for combine; unbound candidates are: left, right", + ): + source().pipe(combine) + + def test_manual_callable_model_can_pipe_into_generated_stage(self): + """Hand-written CallableModels should be usable as pipe sources.""" + + class 
ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + @Flow.model + def consumer(data: int, multiplier: int) -> int: + return data * multiplier + + pipeline = ManualModel(offset=5).pipe(consumer, multiplier=2) + self.assertEqual(pipeline.flow.compute(value=10).value, 30) + + def test_bound_model_pipe_preserves_downstream_transforms(self): + """pipe() should keep downstream with_inputs transforms intact.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int, scale: int) -> int: + return data + scale + + shifted_source = source().flow.with_inputs(x=lambda ctx: ctx.scale + 1) + scaled_consumer = consumer().flow.with_inputs(scale=lambda ctx: ctx.scale * 3) + + pipeline = shifted_source.pipe(scaled_consumer) + self.assertEqual(pipeline.flow.compute(scale=2).value, 76) + + +# ============================================================================= +# PEP 563 (from __future__ import annotations) Compatibility Tests +# ============================================================================= + +# These functions are defined at module level to simulate realistic usage. +# Note: We can't use `from __future__ import annotations` at module level +# since it would affect ALL annotations in this file. Instead, we test +# that the annotation resolution code handles string annotations. + + +class TestPEP563Annotations(TestCase): + """Test that Flow.model handles string annotations (PEP 563).""" + + def test_string_annotation_lazy_resolved(self): + """Test that Lazy annotations work even when passed through get_type_hints. + + This verifies the fix for from __future__ import annotations by + confirming the annotation resolution pipeline processes Lazy correctly. 
+ """ + # Verify _extract_lazy handles real type objects (resolved by get_type_hints) + from ccflow.flow_model import _extract_lazy + + lazy_int = Lazy[int] + unwrapped, is_lazy = _extract_lazy(lazy_int) + self.assertTrue(is_lazy) + self.assertEqual(unwrapped, int) + + def test_string_annotation_return_type_resolved(self): + """Test that string return type annotations are resolved correctly.""" + + @Flow.model + def model_func(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=42) + + # If annotation resolution works, this should create successfully + model = model_func() + self.assertEqual(model.result_type, GenericResult[int]) + + def test_auto_wrap_with_resolved_annotations(self): + """Test that auto-wrap works with properly resolved type annotations.""" + + @Flow.model + def plain_model(value: int) -> int: + return value * 2 + + model = plain_model() + result = model.flow.compute(value=5) + self.assertEqual(result.value, 10) + self.assertEqual(model.result_type, GenericResult) + + +# ============================================================================= +# Hydra Integration Tests +# ============================================================================= + + +# Define Flow.model functions at module level for Hydra to find them +@Flow.model +def hydra_basic_model(context: SimpleContext, value: int, name: str = "default") -> GenericResult[str]: + """Module-level model for Hydra testing.""" + return GenericResult(value=f"{name}:{context.value + value}") + + +# --- Additional module-level fixtures for Hydra YAML tests --- + + +@Flow.model +def basic_loader(context: SimpleContext, source: str, multiplier: int = 1) -> GenericResult[int]: + """Basic loader that multiplies context value by multiplier.""" + return GenericResult(value=context.value * multiplier) + + +@Flow.model +def string_processor(context: SimpleContext, prefix: str, suffix: str = "") -> GenericResult[str]: + """Process context value into a string with prefix and 
suffix.""" + return GenericResult(value=f"{prefix}{context.value}{suffix}") + + +@Flow.model +def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: + """Source that provides base data.""" + return GenericResult(value=context.value + base_value) + + +@Flow.model +def data_transformer( + context: SimpleContext, + source: int, + factor: int = 2, +) -> GenericResult[int]: + """Transform data by multiplying with factor.""" + return GenericResult(value=source * factor) + + +@Flow.model +def data_aggregator( + context: SimpleContext, + input_a: int, + input_b: int, + operation: str = "add", +) -> GenericResult[int]: + """Aggregate two inputs.""" + if operation == "add": + return GenericResult(value=input_a + input_b) + elif operation == "multiply": + return GenericResult(value=input_a * input_b) + else: + return GenericResult(value=input_a - input_b) + + +@Flow.model +def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: + """First stage of pipeline.""" + return GenericResult(value=context.value + initial) + + +@Flow.model +def pipeline_stage2( + context: SimpleContext, + stage1_output: int, + multiplier: int = 2, +) -> GenericResult[int]: + """Second stage of pipeline.""" + return GenericResult(value=stage1_output * multiplier) + + +@Flow.model +def pipeline_stage3( + context: SimpleContext, + stage2_output: int, + offset: int = 0, +) -> GenericResult[int]: + """Third stage of pipeline.""" + return GenericResult(value=stage2_output + offset) + + +@Flow.model +def date_range_loader( + context: DateRangeContext, + source: str, + include_weekends: bool = True, +) -> GenericResult[dict]: + """Load data for a date range.""" + return GenericResult( + value={ + "source": source, + "start_date": str(context.start_date), + "end_date": str(context.end_date), + } + ) + + +@Flow.model +def date_range_loader_previous_day( + context: DateRangeContext, + source: str, + include_weekends: bool = True, +) -> dict: + """Hydra helper that 
applies a one-day lookback before delegating.""" + shifted = context.model_copy(update={"start_date": context.start_date - timedelta(days=1)}) + return date_range_loader(source=source, include_weekends=include_weekends)(shifted).value + + +@Flow.model +def date_range_processor( + context: DateRangeContext, + raw_data: dict, + normalize: bool = False, +) -> GenericResult[str]: + """Process date range data.""" + prefix = "normalized:" if normalize else "raw:" + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") + + +@Flow.model +def hydra_default_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: + """Module-level model with defaults for Hydra testing.""" + return GenericResult(value=context.value + value) + + +@Flow.model +def hydra_source_model(context: SimpleContext, base: int) -> GenericResult[int]: + """Source model for dependency testing.""" + return GenericResult(value=context.value * base) + + +@Flow.model +def hydra_consumer_model( + context: SimpleContext, + source: int, + factor: int = 1, +) -> GenericResult[int]: + """Consumer model for dependency testing.""" + return GenericResult(value=source * factor) + + +# --- context_args fixtures for Hydra testing --- + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[dict]: + """Loader using context_args with DateRangeContext.""" + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def context_args_processor( + start_date: date, + end_date: date, + data: dict, + prefix: str = "processed", +) -> GenericResult[str]: + """Processor using context_args with dependency.""" + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to 
{data['end_date']}") + + +class TestFlowModelHydra(TestCase): + """Tests for Flow.model with Hydra configuration.""" + + def test_hydra_instantiate_basic(self): + """Test that Flow.model factory can be instantiated via Hydra.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create config that references the factory function by module path + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_basic_model", + "value": 100, + "name": "test", + } + ) + + # Instantiate via Hydra + model = instantiate(cfg) + + self.assertIsInstance(model, CallableModel) + result = model(SimpleContext(value=10)) + self.assertEqual(result.value, "test:110") + + def test_hydra_instantiate_with_defaults(self): + """Test Hydra instantiation using default parameter values.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_default_model", + # Not specifying value, should use default + } + ) + + model = instantiate(cfg) + result = model(SimpleContext(value=8)) + self.assertEqual(result.value, 50) + + def test_hydra_instantiate_with_dependency(self): + """Test Hydra instantiation with dependencies.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create nested config + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_consumer_model", + "source": { + "_target_": "ccflow.tests.test_flow_model.hydra_source_model", + "base": 10, + }, + "factor": 2, + } + ) + + model = instantiate(cfg) + + result = model(SimpleContext(value=5)) + # source: 5 * 10 = 50, consumer: 50 * 2 = 100 + self.assertEqual(result.value, 100) + + +# ============================================================================= +# Lazy[T] Type Annotation Tests +# ============================================================================= + + +class TestLazyTypeAnnotation(TestCase): + """Tests for Lazy[T] type 
annotation (deferred/conditional evaluation).""" + + def test_lazy_type_annotation_basic(self): + """Lazy[T] param receives a thunk (zero-arg callable). + + The thunk unwraps GenericResult.value, so calling thunk() returns + the inner value (e.g., int), not the GenericResult wrapper. + """ + from ccflow import Lazy + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # data() returns the unwrapped value (int) + resolved = data() + return GenericResult(value=resolved + 1) + + src = source() + model = consumer(data=src) + result = model(SimpleContext(value=5)) + + # source: 5 * 10 = 50, consumer: 50 + 1 = 51 + self.assertEqual(result.value, 51) + + def test_lazy_conditional_evaluation(self): + """Mirror the smart_training example: lazy dep only evaluated if needed. + + Note: Non-lazy CallableModel deps are auto-resolved and their .value is + unwrapped by the framework (auto-detected dep resolution). So 'fast' + receives the unwrapped int, while 'slow' receives a thunk that returns + the unwrapped value (GenericResult.value) when called. 
+ """ + from ccflow import Lazy + + call_counts = {"fast": 0, "slow": 0} + + @Flow.model + def fast_path(context: SimpleContext) -> GenericResult[int]: + call_counts["fast"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def slow_path(context: SimpleContext) -> GenericResult[int]: + call_counts["slow"] += 1 + return GenericResult(value=context.value * 100) + + @Flow.model + def smart_selector( + context: SimpleContext, + fast: GenericResult[int], # Auto-resolved: receives unwrapped int + slow: Lazy[GenericResult[int]], # Lazy: receives thunk returning unwrapped value + threshold: int = 10, + ) -> GenericResult[int]: + # fast is auto-unwrapped to the int value by the framework + if fast > threshold: + return GenericResult(value=fast) + else: + return GenericResult(value=slow()) + + fast = fast_path() + slow = slow_path() + + # Case 1: fast path sufficient (value > threshold) + model = smart_selector(fast=fast, slow=slow, threshold=10) + result = model(SimpleContext(value=20)) + self.assertEqual(result.value, 20) + self.assertEqual(call_counts["fast"], 1) + self.assertEqual(call_counts["slow"], 0) # Never called! 
+ + # Case 2: fast path insufficient (value <= threshold), slow triggered + call_counts["fast"] = 0 + model2 = smart_selector(fast=fast, slow=slow, threshold=100) + result2 = model2(SimpleContext(value=5)) + self.assertEqual(result2.value, 500) # 5 * 100 + self.assertEqual(call_counts["fast"], 1) + self.assertEqual(call_counts["slow"], 1) + + def test_lazy_thunk_caches_result(self): + """Repeated calls to a thunk return the same value without re-evaluation.""" + from ccflow import Lazy + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # Call thunk multiple times — returns the unwrapped int + val1 = data() + val2 = data() + val3 = data() + self.assertEqual(val1, val2) + self.assertEqual(val2, val3) + return GenericResult(value=val1) + + src = source() + model = consumer(data=src) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 50) + self.assertEqual(call_counts["source"], 1) # Called only once despite 3 thunk() calls + + def test_lazy_with_direct_value(self): + """Pre-computed (non-CallableModel) value wrapped in trivial thunk.""" + from ccflow import Lazy + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[int], + ) -> GenericResult[int]: + # data is a thunk even though the underlying value is a plain int + return GenericResult(value=data() * 2) + + model = consumer(data=42) + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, 84) + + def test_lazy_dep_excluded_from_deps(self): + """__deps__ does NOT include lazy dependencies.""" + from ccflow import Lazy + + @Flow.model + def eager_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def lazy_source(context: SimpleContext) -> 
GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + eager: GenericResult[int], # Auto-resolved, unwrapped to int + lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value + ) -> GenericResult[int]: + return GenericResult(value=eager + lazy_dep()) + + eager = eager_source() + lazy = lazy_source() + model = consumer(eager=eager, lazy_dep=lazy) + + ctx = SimpleContext(value=5) + deps = model.__deps__(ctx) + + # Only eager dep should be in __deps__ + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], eager) + + def test_lazy_eager_dep_still_pre_evaluated(self): + """Non-lazy deps are still eagerly resolved via __deps__.""" + from ccflow import Lazy + + call_counts = {"eager": 0, "lazy": 0} + + @Flow.model + def eager_source(context: SimpleContext) -> GenericResult[int]: + call_counts["eager"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def lazy_source(context: SimpleContext) -> GenericResult[int]: + call_counts["lazy"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + eager: GenericResult[int], # Auto-resolved, unwrapped to int + lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value + ) -> GenericResult[int]: + # eager is auto-unwrapped to int, lazy_dep() returns unwrapped value + return GenericResult(value=eager + lazy_dep()) + + model = consumer(eager=eager_source(), lazy_dep=lazy_source()) + result = model(SimpleContext(value=5)) + + self.assertEqual(result.value, 55) # 5 + 50 + self.assertEqual(call_counts["eager"], 1) + self.assertEqual(call_counts["lazy"], 1) + + def test_lazy_in_dynamic_deferred_mode(self): + """Lazy[T] works in dynamic deferred mode (no context_args).""" + from ccflow import FlowContext, Lazy + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return 
GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + value: int, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + if value > 10: + return GenericResult(value=value) + return GenericResult(value=data()) # data() returns unwrapped int + + # value comes from context, data is bound at construction + model = consumer(data=source()) + result = model(FlowContext(value=20)) # value > 10, lazy not called + self.assertEqual(result.value, 20) + self.assertEqual(call_counts["source"], 0) + + def test_lazy_in_context_args_mode(self): + """Lazy[T] works with explicit context_args.""" + from ccflow import FlowContext, Lazy + + @Flow.model(context_args=["x"]) + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model(context_args=["x"]) + def consumer( + x: int, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=x + data()) # data() returns unwrapped int + + model = consumer(data=source()) + result = model(FlowContext(x=5)) + self.assertEqual(result.value, 55) # 5 + 50 + + def test_lazy_never_evaluated_if_not_called(self): + """If thunk is never called, the dependency is never evaluated.""" + from ccflow import Lazy + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # Never call data() + return GenericResult(value=42) + + model = consumer(data=source()) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 42) + self.assertEqual(call_counts["source"], 0) + + def test_lazy_with_upstream_model(self): + """Lazy[T] works when bound to an upstream model.""" + from ccflow import Lazy + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + 
@Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data() + 1) # data() returns unwrapped int + + src = source() + model = consumer(data=src) + + # Lazy dep should NOT be in __deps__ + deps = model.__deps__(SimpleContext(value=5)) + self.assertEqual(len(deps), 0) + + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 51) # 50 + 1 + + +# ============================================================================= +# Bug Fix Regression Tests +# ============================================================================= + + +class TestFlowModelBugFixes(TestCase): + """Regression tests for four bugs identified during code review.""" + + # ----- Issue 1: .flow.compute() drops context defaults ----- + + def test_compute_respects_explicit_context_defaults(self): + """Mode 1: compute(x=1) should use ExtendedContext's default y='default'.""" + + @Flow.model + def model_fn(context: ExtendedContext, factor: int = 1) -> str: + return f"{context.x}-{context.y}-{factor}" + + model = model_fn() + result = model.flow.compute(x=1) + self.assertEqual(result.value, "1-default-1") + + def test_compute_respects_context_args_defaults(self): + """Mode 2: compute(x=1) should use function default y=42.""" + + @Flow.model(context_args=["x", "y"]) + def model_fn(x: int, y: int = 42) -> int: + return x + y + + model = model_fn() + result = model.flow.compute(x=1) + self.assertEqual(result.value, 43) + + def test_unbound_inputs_excludes_context_args_with_defaults(self): + """Mode 2: unbound_inputs should not include context_args that have function defaults.""" + + @Flow.model(context_args=["x", "y"]) + def model_fn(x: int, y: int = 42) -> int: + return x + y + + model = model_fn() + self.assertEqual(model.flow.unbound_inputs, {"x": int}) + + def test_unbound_inputs_excludes_context_type_defaults(self): + """Mode 1: unbound_inputs should not include context fields that have 
defaults.""" + + @Flow.model + def model_fn(context: ExtendedContext) -> str: + return f"{context.x}-{context.y}" + + model = model_fn() + # ExtendedContext has x: int (required) and y: str = "default" + self.assertEqual(model.flow.unbound_inputs, {"x": int}) + + def test_context_type_rejects_required_field_with_function_default(self): + """Decoration should fail when function has default but context_type requires the field.""" + + class StrictContext(ContextBase): + x: int # required + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["x"], context_type=StrictContext) + def model_fn(x: int = 5) -> int: + return x + + self.assertIn("x", str(cm.exception)) + self.assertIn("requires", str(cm.exception)) + + def test_context_type_accepts_optional_field_with_function_default(self): + """Both context_type and function have defaults — should work.""" + + class OptionalContext(ContextBase): + x: int = 10 + + @Flow.model(context_args=["x"], context_type=OptionalContext) + def model_fn(x: int = 5) -> int: + return x + + model = model_fn() + result = model(OptionalContext()) + self.assertEqual(result.value, 10) # context default wins + + # ----- Issue 2: Lazy[...] 
broken in dynamic deferred mode ----- + + def test_lazy_from_runtime_context_in_dynamic_mode(self): + """Lazy[int] provided via FlowContext should be wrapped in a thunk.""" + + @Flow.model + def model_fn(x: int, y: Lazy[int]) -> int: + return x + y() + + model = model_fn(x=10) + result = model(FlowContext(y=32)) + self.assertEqual(result.value, 42) + + def test_callable_model_from_runtime_context_in_dynamic_mode(self): + """CallableModel provided in FlowContext should be resolved.""" + + @Flow.model + def source(value: int) -> int: + return value * 10 + + @Flow.model + def consumer(x: int, data: int) -> int: + return x + data + + model = consumer(x=1) + src = source() + result = model(FlowContext(data=src, value=5)) + # source resolves with value=5 → 50, consumer: 1 + 50 = 51 + self.assertEqual(result.value, 51) + + # ----- Issue 3: FlowContext-backed models skip schema validation ----- + + def test_direct_call_validates_flowcontext_dynamic_mode(self): + """Dynamic mode: FlowContext(y='hello') for int param should raise TypeError.""" + + @Flow.model + def model_fn(x: int, y: int) -> int: + return x + y + + model = model_fn() + with self.assertRaises(TypeError) as cm: + model(FlowContext(x=1, y="hello")) + + self.assertIn("y", str(cm.exception)) + + def test_direct_call_validates_flowcontext_context_args_mode(self): + """context_args mode: FlowContext(x='hello') for int param should raise TypeError.""" + + @Flow.model(context_args=["x"]) + def model_fn(x: int) -> int: + return x + + model = model_fn() + with self.assertRaises(TypeError) as cm: + model(FlowContext(x="hello")) + + self.assertIn("x", str(cm.exception)) + + def test_with_inputs_validates_transformed_fields_dynamic(self): + """Dynamic mode: with_inputs(y='hello') for int param should raise TypeError.""" + + @Flow.model + def model_fn(x: int, y: int) -> int: + return x + y + + model = model_fn(x=1) + bound = model.flow.with_inputs(y="hello") + + with self.assertRaises(TypeError) as cm: + 
bound(FlowContext()) + + self.assertIn("y", str(cm.exception)) + + def test_with_inputs_validates_transformed_fields_context_args(self): + """context_args mode: with_inputs(x='hello') for int param should raise TypeError.""" + + @Flow.model(context_args=["x"]) + def model_fn(x: int) -> int: + return x + + model = model_fn() + bound = model.flow.with_inputs(x="hello") + + with self.assertRaises(TypeError) as cm: + bound(FlowContext()) + + self.assertIn("x", str(cm.exception)) + + # ----- Issue 4: Registry-name resolution too aggressive for union strings ----- + + def test_registry_resolution_skips_union_str_annotation(self): + """Union[str, int] field with a registry key string should keep the string.""" + from typing import Union + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, tag: Union[str, int] = "none") -> str: + return f"tag={tag}" + + model = consumer(tag="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "tag=my_key") + finally: + registry.clear() + + def test_registry_resolution_skips_optional_str_annotation(self): + """Optional[str] field with a registry key string should keep the string.""" + from typing import Optional + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, label: Optional[str] = None) -> str: + return f"label={label}" + + model = consumer(label="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "label=my_key") + finally: + registry.clear() + + def test_registry_resolution_skips_union_annotated_str(self): + """Union[Annotated[str, ...], int] field with a registry 
key should keep the string.""" + from typing import Annotated, Union + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, tag: Union[Annotated[str, "label"], int] = "none") -> str: + return f"tag={tag}" + + model = consumer(tag="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "tag=my_key") + finally: + registry.clear() + + +# ============================================================================= +# Coverage Gap Tests +# ============================================================================= + + +class TestExtractLazyLoopBody(TestCase): + """Group 1: _extract_lazy loop body with non-LazyMarker metadata.""" + + def test_annotated_with_extra_metadata_before_lazy_marker(self): + """Annotated type where _LazyMarker is NOT the first metadata element.""" + from typing import Annotated + + from ccflow.flow_model import _extract_lazy, _LazyMarker + + # _LazyMarker is the second metadata element — loop must iterate past "other" + ann = Annotated[int, "other_metadata", _LazyMarker()] + base_type, is_lazy = _extract_lazy(ann) + self.assertTrue(is_lazy) + self.assertIs(base_type, int) + + def test_annotated_without_lazy_marker(self): + """Annotated type with no _LazyMarker returns is_lazy=False.""" + from typing import Annotated + + from ccflow.flow_model import _extract_lazy + + ann = Annotated[int, "just_metadata"] + base_type, is_lazy = _extract_lazy(ann) + self.assertFalse(is_lazy) + + def test_lazy_type_annotation_with_extra_annotated(self): + """End-to-end: Lazy wrapping of an Annotated type.""" + + @Flow.model + def model_with_lazy( + x: int, + dep: Lazy[int], + ) -> int: + return x + dep() + + @Flow.model + def upstream(x: int) -> int: + return x * 10 + + model = model_with_lazy(x=1, dep=upstream()) + result = 
model.flow.compute(x=1) + self.assertEqual(result.value, 11) + + def test_lazy_dep_returning_custom_result(self): + """Lazy dep returning custom ResultBase (not GenericResult) should return raw result.""" + + @Flow.model + def upstream(context: SimpleContext) -> MyResult: + return MyResult(data=f"v={context.value}") + + @Flow.model + def consumer(context: SimpleContext, dep: Lazy[MyResult]) -> GenericResult[str]: + result = dep() + return GenericResult(value=result.data) + + model = consumer(dep=upstream()) + result = model(SimpleContext(value=42)) + self.assertEqual(result.value, "v=42") + + +class TestTransformReprNamedCallable(TestCase): + """Group 2: _transform_repr with a named callable.""" + + def test_named_function_transform_in_repr(self): + """Named functions should appear in BoundModel repr wrapped in angle brackets.""" + from ccflow.flow_model import _transform_repr + + def my_custom_transform(ctx): + return ctx.value + 1 + + result = _transform_repr(my_custom_transform) + self.assertIn("my_custom_transform", result) + self.assertTrue(result.startswith("<")) + self.assertTrue(result.endswith(">")) + + def test_static_value_repr(self): + """Static (non-callable) values should use repr().""" + from ccflow.flow_model import _transform_repr + + self.assertEqual(_transform_repr(42), "42") + self.assertEqual(_transform_repr("hello"), "'hello'") + + +class TestBoundFieldNamesFallback(TestCase): + """Group 3: _bound_field_names fallback for objects without model_fields_set.""" + + def test_fallback_to_bound_fields_attr(self): + from ccflow.flow_model import _bound_field_names + + class FakeModel: + _bound_fields = {"x", "y"} + + result = _bound_field_names(FakeModel()) + self.assertEqual(result, {"x", "y"}) + + def test_fallback_no_attrs(self): + from ccflow.flow_model import _bound_field_names + + class Empty: + pass + + result = _bound_field_names(Empty()) + self.assertEqual(result, set()) + + +class TestRuntimeInputNamesEmpty(TestCase): + """Group 4: 
_runtime_input_names when all_param_names is empty.""" + + def test_non_flow_model_returns_empty(self): + from ccflow.flow_model import _runtime_input_names + + class ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + model = ManualModel(offset=5) + self.assertEqual(_runtime_input_names(model), set()) + + +class TestRegistryCandidateAllowed(TestCase): + """Group 5: _registry_candidate_allowed TypeAdapter success path.""" + + def test_non_callable_model_passes_type_check(self): + """Registry value that is not a CallableModel but passes TypeAdapter validation.""" + from ccflow.flow_model import _registry_candidate_allowed + + # int value passes TypeAdapter(int).validate_python + self.assertTrue(_registry_candidate_allowed(int, 42)) + + def test_non_callable_model_fails_type_check(self): + from ccflow.flow_model import _registry_candidate_allowed + + self.assertFalse(_registry_candidate_allowed(int, "not_an_int")) + + +class TestConcreteContextTypeOptional(TestCase): + """Group 6: _concrete_context_type with Optional/Union types.""" + + def test_optional_context_type(self): + """Optional[T] has NoneType that should be skipped to find T.""" + from typing import Optional + + from ccflow.flow_model import _concrete_context_type + + # Optional[SimpleContext] = Union[SimpleContext, None] + # The NoneType arg must be skipped (line 196-197) + result = _concrete_context_type(Optional[SimpleContext]) + self.assertIs(result, SimpleContext) + + def test_union_with_none_first(self): + """Union[None, T] should skip NoneType and find T.""" + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + # NoneType comes first, must be skipped + result = _concrete_context_type(Union[None, SimpleContext]) + self.assertIs(result, SimpleContext) + + def test_union_context_type(self): + from typing import Union + + from 
ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type(Union[SimpleContext, None]) + self.assertIs(result, SimpleContext) + + def test_union_no_context_base(self): + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type(Union[int, str]) + self.assertIsNone(result) + + def test_returns_none_for_non_type(self): + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type("not_a_type") + self.assertIsNone(result) + + +class TestBuildConfigValidatorsException(TestCase): + """Group 7: _build_config_validators when TypeAdapter fails.""" + + def test_unadaptable_type_skipped(self): + """Types that TypeAdapter can't handle should be silently skipped.""" + from ccflow.flow_model import _build_config_validators + + # type(...) (EllipsisType) makes TypeAdapter fail + validatable, validators = _build_config_validators({"x": int, "y": type(...)}) + self.assertIn("x", validatable) + self.assertNotIn("y", validatable) + self.assertIn("x", validators) + self.assertNotIn("y", validators) + + +class TestCoerceContextValueNoValidator(TestCase): + """Group 8: _coerce_context_value early return for fields without validators.""" + + def test_field_without_validator_passes_through(self): + from ccflow.flow_model import _coerce_context_value + + # When name is not in validators, value should pass through unchanged + result = _coerce_context_value("unknown_field", 42, {}, {}) + self.assertEqual(result, 42) + + +class TestGeneratedModelClassFactoryPath(TestCase): + """Group 9: _generated_model_class when stage has no generated model.""" + + def test_returns_none_for_plain_callable(self): + from ccflow.flow_model import _generated_model_class + + def plain_func(): + pass + + self.assertIsNone(_generated_model_class(plain_func)) + + +class TestDescribePipeStagePaths(TestCase): + """Group 10: _describe_pipe_stage for different stage types.""" + + def 
test_generated_model_instance(self): + from ccflow.flow_model import _describe_pipe_stage + + @Flow.model + def my_stage(x: int) -> int: + return x + + desc = _describe_pipe_stage(my_stage()) + self.assertIn("my_stage", desc) + + def test_callable_stage(self): + from ccflow.flow_model import _describe_pipe_stage + + @Flow.model + def factory_stage(x: int) -> int: + return x + + desc = _describe_pipe_stage(factory_stage) + self.assertIn("factory_stage", desc) + + def test_non_callable_stage(self): + from ccflow.flow_model import _describe_pipe_stage + + desc = _describe_pipe_stage(42) + self.assertEqual(desc, "42") + + +class TestInferPipeParamAmbiguousDefaults(TestCase): + """Cover _infer_pipe_param fallback path with multiple defaulted candidates.""" + + def test_ambiguous_defaulted_candidates(self): + """When all candidates have defaults but multiple are unoccupied.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(a: int = 1, b: int = 2) -> int: + return a + b + + # Both a and b have defaults, both are unoccupied -> ambiguous + with self.assertRaisesRegex(TypeError, "could not infer a target parameter"): + source().pipe(consumer) + + +class TestPipeErrorPaths(TestCase): + """Group 11: pipe() error paths not covered by existing tests.""" + + def test_pipe_non_callable_model_source(self): + """pipe() should reject non-CallableModel source.""" + from ccflow.flow_model import pipe_model + + @Flow.model + def consumer(data: int) -> int: + return data + + with self.assertRaisesRegex(TypeError, "pipe\\(\\) source must be a CallableModel"): + pipe_model("not_a_model", consumer) + + def test_pipe_non_flow_model_target(self): + """pipe() should reject non-@Flow.model target.""" + from ccflow.flow_model import pipe_model + + @Flow.model + def source(x: int) -> int: + return x + + class ManualTarget(CallableModel): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=0) + + with 
self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream stages"): + pipe_model(source(), ManualTarget()) + + def test_pipe_invalid_param_name(self): + """pipe() should reject invalid target parameter names.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(data: int) -> int: + return data + + with self.assertRaisesRegex(TypeError, "is not valid for"): + source().pipe(consumer, param="nonexistent") + + def test_pipe_already_bound_param(self): + """pipe() should reject already-bound parameters.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(data: int) -> int: + return data + + model = consumer(data=5) + with self.assertRaisesRegex(TypeError, "is already bound"): + source().pipe(model, param="data") + + def test_pipe_no_available_target_parameter(self): + """pipe() should error when all downstream params are occupied.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(data: int) -> int: + return data + + model = consumer(data=5) + with self.assertRaisesRegex(TypeError, "could not find an available target parameter"): + source().pipe(model) + + def test_pipe_into_generated_instance_rebuilds(self): + """pipe() into an existing generated model instance should rebuild.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int, extra: int = 1) -> int: + return data + extra + + instance = consumer(extra=5) + pipeline = source().pipe(instance) + result = pipeline.flow.compute(x=3) + self.assertEqual(result.value, 35) # 3*10 + 5 + + def test_pipe_bound_model_wrapping_non_generated_rejects(self): + """pipe() into BoundModel wrapping a non-generated model should fail.""" + from ccflow.flow_model import BoundModel, pipe_model + + @Flow.model + def source(x: int) -> int: + return x + + class ManualModel(CallableModel): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: 
+ return GenericResult(value=context.value) + + bound = BoundModel(model=ManualModel(), input_transforms={"value": 42}) + with self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream"): + pipe_model(source(), bound) + + +class TestFlowAPIBuildContextFallback(TestCase): + """Group 12: FlowAPI._build_context when _context_schema is None/unset.""" + + def test_unbound_inputs_on_manual_callable_model(self): + """Manual CallableModel with context should show required fields.""" + + class ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + model = ManualModel(offset=5) + unbound = model.flow.unbound_inputs + self.assertIn("value", unbound) + + +class TestBoundModelRestoreNonDict(TestCase): + """Group 13: BoundModel._restore_serialized_transforms non-dict path.""" + + def test_restore_from_model_instance(self): + """model_validate from an existing BoundModel instance (non-dict).""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + # Pass existing instance through model_validate (non-dict path) + restored = BoundModel.model_validate(bound) + ctx = SimpleContext(value=1) + self.assertEqual(restored(ctx).value, 420) + + +class TestBoundModelInitEmptyTransforms(TestCase): + """Group 14: BoundModel.__init__ with no transforms.""" + + def test_init_without_transforms(self): + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + bound = BoundModel(model=source()) + self.assertEqual(bound._input_transforms, {}) + result = bound(SimpleContext(value=5)) + self.assertEqual(result.value, 5) + + +class TestBoundModelDeps(TestCase): + """Group 15: 
BoundModel.__deps__.""" + + def test_deps_returns_wrapped_model(self): + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + deps = bound.__deps__(SimpleContext(value=1)) + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], bound.model) + + +class TestValidateFieldTypesAfterValidator(TestCase): + """Group 16: _validate_field_types in the model_validate path.""" + + def test_model_validate_rejects_wrong_type(self): + """model_validate should reject wrong scalar types.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + cls = type(source(x=5)) + with self.assertRaisesRegex(TypeError, "Field 'x'"): + cls.model_validate({"x": "not_an_int"}) + + +class TestGetContextValidatorPaths(TestCase): + """Group 17: _get_context_validator fallback paths.""" + + def test_mode2_context_validator_from_schema(self): + """Mode 2 model should build validator from _context_schema.""" + + @Flow.model(context_args=["start_date"]) + def loader(start_date: str, source: str = "db") -> str: + return f"{source}:{start_date}" + + model = loader() + # Trigger validator creation by calling flow.compute + result = model.flow.compute(start_date="2024-01-01") + self.assertEqual(result.value, "db:2024-01-01") + + def test_mode1_context_validator_uses_context_type_directly(self): + """Mode 1 should use TypeAdapter(context_type) directly.""" + + @Flow.model + def model_fn(context: SimpleContext, offset: int = 0) -> GenericResult[int]: + return GenericResult(value=context.value + offset) + + model = model_fn() + # compute with SimpleContext fields + result = model.flow.compute(value=5) + self.assertEqual(result.value, 5) + + +class TestValidateContextTypeOverrideErrors(TestCase): + """Group 18: _validate_context_type_override error paths.""" + + def test_non_context_base_raises(self): + with self.assertRaisesRegex(TypeError, "context_type must be a ContextBase 
subclass"): + + @Flow.model(context_args=["x"], context_type=int) + def bad_model(x: int) -> int: + return x + + def test_context_type_missing_context_args_fields(self): + """context_type missing required context_args fields.""" + + class TinyContext(ContextBase): + a: int + + with self.assertRaisesRegex(TypeError, "must define fields for context_args"): + + @Flow.model(context_args=["a", "b"], context_type=TinyContext) + def bad_model(a: int, b: int) -> int: + return a + b + + def test_context_type_extra_required_fields(self): + """context_type has required fields not listed in context_args.""" + + class BigContext(ContextBase): + a: int + b: int + extra: str + + with self.assertRaisesRegex(TypeError, "has required fields not listed in context_args"): + + @Flow.model(context_args=["a"], context_type=BigContext) + def bad_model(a: int) -> int: + return a + + def test_annotation_type_mismatch(self): + """Function and context_type disagree on annotation type.""" + + class TypedContext(ContextBase): + x: str + + with self.assertRaisesRegex(TypeError, "context_arg 'x'"): + + @Flow.model(context_args=["x"], context_type=TypedContext) + def bad_model(x: int) -> int: + return x + + def test_annotation_skip_when_func_ann_is_none(self): + """Annotation check should skip when function annotation is absent from schema.""" + from ccflow.flow_model import _validate_context_type_override + + class CompatContext(ContextBase): + a: int + + # context_args has 'a', schema has 'a': int. Compatible, no error. 
+ result = _validate_context_type_override(CompatContext, ["a"], {"a": int}) + self.assertIs(result, CompatContext) + + def test_subclass_annotations_allowed(self): + """context_type with subclass-compatible annotations should pass.""" + from ccflow.flow_model import _validate_context_type_override + + class ContextWithBase(ContextBase): + ctx: ContextBase + + # Function declares SimpleContext which is a subclass of ContextBase — should pass + result = _validate_context_type_override(ContextWithBase, ["ctx"], {"ctx": SimpleContext}) + self.assertIs(result, ContextWithBase) + + def test_default_vs_required_field_conflict(self): + """Function has default for context_arg but context_type requires it.""" + + class StrictContext(ContextBase): + x: int + + with self.assertRaisesRegex(TypeError, "function has a default but context_type"): + + @Flow.model(context_args=["x"], context_type=StrictContext) + def bad_model(x: int = 5) -> int: + return x + + +class TestDecoratorErrorPaths(TestCase): + """Group 19: Decorator error paths.""" + + def test_context_type_with_explicit_context_param(self): + """context_type= with explicit context param should raise.""" + with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + + @Flow.model(context_type=SimpleContext) + def bad_model(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=0) + + def test_context_type_without_context_args(self): + """context_type= without context_args should raise in dynamic mode.""" + with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + + @Flow.model(context_type=SimpleContext) + def bad_model(x: int) -> int: + return x + + def test_missing_context_annotation(self): + """Missing type annotation on context param should raise.""" + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + @Flow.model + def bad_model(context) -> int: + return 0 + + def test_missing_param_annotation(self): + """Missing type annotation on a model 
field param should raise.""" + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + @Flow.model + def bad_model(context: SimpleContext, untyped_param) -> int: + return 0 + + def test_context_param_not_context_base(self): + """context param annotated with non-ContextBase type should raise.""" + with self.assertRaisesRegex(TypeError, "must be annotated with a ContextBase subclass"): + + @Flow.model + def bad_model(context: int) -> int: + return 0 + + def test_pep563_fallback_on_failed_get_type_hints(self): + """When get_type_hints fails, falls back to raw annotations.""" + + # This is hard to trigger directly, but we can test that string annotations work + @Flow.model + def model_with_string_return(x: int) -> "int": + return x * 2 + + result = model_with_string_return().flow.compute(x=5) + self.assertEqual(result.value, 10) + + +class TestMode1CallPath(TestCase): + """Group 20: Mode 1 explicit context pass-through in __call__.""" + + def test_mode1_resolve_callable_model_returns_non_generic_result(self): + """Mode 1 should handle deps that return raw ResultBase (not GenericResult).""" + + @Flow.model + def upstream(context: SimpleContext) -> MyResult: + return MyResult(data=f"value={context.value}") + + @Flow.model + def downstream(context: SimpleContext, dep: CallableModel) -> GenericResult[str]: + # dep is resolved to MyResult since it's not GenericResult + return GenericResult(value=f"got:{dep}") + + model = downstream(dep=upstream()) + result = model(SimpleContext(value=42)) + self.assertIn("value=42", result.value) + + +class TestDynamicModeContextLookup(TestCase): + """Group 21: Dynamic mode context lookup for deferred values.""" + + def test_deferred_value_from_context(self): + """Dynamic mode should pull deferred values from context.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + # y is deferred — pulled from context + result = model.flow.compute(y=5) + self.assertEqual(result.value, 15) + + 
def test_missing_deferred_value_raises(self): + """Dynamic mode should raise for missing deferred values.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + with self.assertRaisesRegex(TypeError, "Missing runtime input"): + model.flow.compute() # y not provided + + def test_context_sourced_value_coercion(self): + """Dynamic mode should coerce context-sourced values through validators.""" + + @Flow.model + def typed_model(x: int, y: int) -> int: + return x + y + + model = typed_model(x=10) + # y provided as a value that can be coerced to int + result = model.flow.compute(y=5) + self.assertEqual(result.value, 15) + + def test_deferred_value_from_context_object(self): + """Dynamic mode should look up deferred values from context attributes.""" + + @Flow.model + def multiply(x: int, y: int) -> int: + return x * y + + model = multiply(x=3) + # Call directly with a FlowContext — y must come from context + result = model(FlowContext(y=7)) + self.assertEqual(result.value, 21) + + +class TestGetContextValidatorFallbacks(TestCase): + """Group 17 additional: _get_context_validator edge cases.""" + + def test_mode2_with_context_type_override(self): + """Mode 2 with explicit context_type should use that type's validator.""" + + @Flow.model(context_args=["value"], context_type=SimpleContext) + def typed_model(value: int) -> int: + return value * 2 + + model = typed_model() + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 10) + + def test_dynamic_mode_instance_validator(self): + """Dynamic mode should create instance-specific validator.""" + + @Flow.model + def add(x: int, y: int, z: int = 0) -> int: + return x + y + z + + m1 = add(x=1) + m2 = add(x=1, y=2) + # Different bound fields => different runtime inputs + self.assertIn("y", m1.flow.unbound_inputs) + self.assertNotIn("y", m2.flow.unbound_inputs) + + +class TestRegistryResolutionInValidateFieldTypes(TestCase): + """Group 16: _resolve_registry_refs and 
_validate_field_types paths.""" + + def test_registry_string_not_resolving_passes_through(self): + """String value that doesn't resolve from registry should fail type validation.""" + + @Flow.model + def model_fn(x: int) -> int: + return x + + cls = type(model_fn(x=1)) + with self.assertRaisesRegex(TypeError, "Field 'x'"): + cls.model_validate({"x": "nonexistent_registry_key"}) + + def test_registry_ref_resolves_to_callable_model(self): + """String value resolving to a CallableModel should be substituted.""" + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def upstream(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def downstream(context: SimpleContext, dep: CallableModel) -> GenericResult[int]: + return GenericResult(value=0) + + registry.add("my_upstream", upstream()) + cls = type(downstream(dep=upstream())) + restored = cls.model_validate({"dep": "my_upstream"}) + self.assertIsNotNone(restored) + finally: + registry.clear() + + +class TestMode2MissingContextField(TestCase): + """Line 1155: Mode 2 missing context field error.""" + + def test_mode2_missing_required_context_field(self): + """Mode 2 model called with context missing a required field should raise.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def loader(start_date: str, end_date: str, source: str = "db") -> str: + return f"{source}:{start_date}-{end_date}" + + model = loader() + # Call with a FlowContext missing end_date + with self.assertRaisesRegex(TypeError, "Missing context field"): + model(FlowContext(start_date="2024-01-01")) + + +class TestDynamicModeContextObjectLookup(TestCase): + """Line 1155/1176: Dynamic mode pulling deferred values from context object.""" + + def test_deferred_value_coercion_through_context(self): + """Dynamic mode should coerce values from FlowContext through validators.""" + + @Flow.model + def typed_add(x: int, y: int) -> int: + return x + y + + model = 
typed_add(x=10) + # Calling with a FlowContext — y pulled from context and coerced + result = model(FlowContext(y=5)) + self.assertEqual(result.value, 15) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py new file mode 100644 index 0000000..28f2883 --- /dev/null +++ b/ccflow/tests/test_flow_model_hydra.py @@ -0,0 +1,441 @@ +"""Hydra integration tests for Flow.model. + +These tests verify that Flow.model decorated functions work correctly when +loaded from YAML configuration files using ModelRegistry.load_config_from_path(). + +Key feature: Registry name references (e.g., `source: flow_source`) ensure the same +object instance is shared across all consumers. +""" + +from datetime import date +from pathlib import Path +from unittest import TestCase + +from omegaconf import OmegaConf + +from ccflow import CallableModel, DateRangeContext, GenericResult, ModelRegistry + +from .test_flow_model import SimpleContext + +CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") + + +class TestFlowModelHydraYAML(TestCase): + """Tests loading Flow.model from YAML config files using ModelRegistry.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_basic_loader_from_yaml(self): + """Test basic model instantiation from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + + self.assertIsInstance(loader, CallableModel) + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 50) # 10 * 5 + + def test_string_processor_from_yaml(self): + """Test string processor model from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_processor"] + + ctx = SimpleContext(value=42) + result = processor(ctx) + self.assertEqual(result.value, "value=42!") + + 
def test_two_stage_pipeline_from_yaml(self): + """Test two-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + self.assertIsInstance(transformer, CallableModel) + + ctx = SimpleContext(value=5) + result = transformer(ctx) + # flow_source: 5 + 100 = 105 + # flow_transformer: 105 * 3 = 315 + self.assertEqual(result.value, 315) + + def test_three_stage_pipeline_from_yaml(self): + """Test three-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage3 = r["flow_stage3"] + + ctx = SimpleContext(value=10) + result = stage3(ctx) + # stage1: 10 + 10 = 20 + # stage2: 20 * 2 = 40 + # stage3: 40 + 50 = 90 + self.assertEqual(result.value, 90) + + def test_diamond_dependency_from_yaml(self): + """Test diamond dependency pattern from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + aggregator = r["diamond_aggregator"] + + ctx = SimpleContext(value=10) + result = aggregator(ctx) + # source: 10 + 10 = 20 + # branch_a: 20 * 2 = 40 + # branch_b: 20 * 5 = 100 + # aggregator: 40 + 100 = 140 + self.assertEqual(result.value, 140) + + def test_date_range_pipeline_from_yaml(self): + """Test DateRangeContext pipeline with transforms from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + # The transform extends start_date back by one day + self.assertIn("2024-01-09", result.value) + self.assertIn("normalized:", result.value) + + def test_context_args_from_yaml(self): + """Test context_args model from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + + self.assertIsInstance(loader, CallableModel) + # context_args models use DateRangeContext + 
self.assertEqual(loader.context_type, DateRangeContext) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = loader(ctx) + self.assertEqual( + result.value, + { + "source": "data_source", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + }, + ) + + def test_context_args_pipeline_from_yaml(self): + """Test context_args pipeline with dependencies from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["ctx_args_processor"] + + ctx = DateRangeContext(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) + result = processor(ctx) + # loader: "data_source:2024-03-01 to 2024-03-31" + # processor: "output:data_source:2024-03-01 to 2024-03-31" + self.assertEqual(result.value, "output:data_source:2024-03-01 to 2024-03-31") + + def test_context_args_shares_instance(self): + """Test that context_args pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + processor = r["ctx_args_processor"] + + self.assertIs(processor.data, loader) + + +class TestFlowModelHydraInstanceSharing(TestCase): + """Tests that registry name references share the same object instance.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_pipeline_shares_instance(self): + """Test that pipeline stages share the same dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + source = r["flow_source"] + + self.assertIs(transformer.source, source) + + def test_three_stage_pipeline_shares_instances(self): + """Test that three-stage pipeline shares instances correctly.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage1 = r["flow_stage1"] + stage2 = r["flow_stage2"] + stage3 = r["flow_stage3"] + + self.assertIs(stage2.stage1_output, stage1) + 
self.assertIs(stage3.stage2_output, stage2) + + def test_diamond_pattern_shares_source_instance(self): + """Test that diamond pattern branches share the same source instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + source = r["diamond_source"] + branch_a = r["diamond_branch_a"] + branch_b = r["diamond_branch_b"] + aggregator = r["diamond_aggregator"] + + # Both branches should share the SAME source instance + self.assertIs(branch_a.source, source) + self.assertIs(branch_b.source, source) + self.assertIs(branch_a.source, branch_b.source) + + self.assertIs(aggregator.input_a, branch_a) + self.assertIs(aggregator.input_b, branch_b) + + def test_date_range_shares_instance(self): + """Test that date range pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_date_loader"] + processor = r["flow_date_processor"] + + self.assertIs(processor.raw_data, loader) + + +class TestFlowModelHydraOmegaConf(TestCase): + """Tests using OmegaConf.create for dynamic config creation.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_instantiate_with_omegaconf(self): + """Test instantiation using OmegaConf.create via ModelRegistry.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "dynamic_source", + "multiplier": 7, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=3) + result = loader(ctx) + self.assertEqual(result.value, 21) # 3 * 7 + + def test_nested_deps_with_omegaconf(self): + """Test nested dependencies using OmegaConf with registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 50, + }, + "transformer": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + 
"source": "source", + "factor": 4, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + transformer = r["transformer"] + + ctx = SimpleContext(value=10) + result = transformer(ctx) + # source: 10 + 50 = 60 + # transformer: 60 * 4 = 240 + self.assertEqual(result.value, 240) + + self.assertIs(transformer.source, r["source"]) + + def test_diamond_with_omegaconf(self): + """Test diamond pattern with OmegaConf using registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 10, + }, + "branch_a": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 2, + }, + "branch_b": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 3, + }, + "aggregator": { + "_target_": "ccflow.tests.test_flow_model.data_aggregator", + "input_a": "branch_a", + "input_b": "branch_b", + "operation": "multiply", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + aggregator = r["aggregator"] + + ctx = SimpleContext(value=5) + result = aggregator(ctx) + # source: 5 + 10 = 15 + # branch_a: 15 * 2 = 30 + # branch_b: 15 * 3 = 45 + # aggregator: 30 * 45 = 1350 + self.assertEqual(result.value, 1350) + + # Verify SAME source instance is shared + self.assertIs(r["branch_a"].source, r["source"]) + self.assertIs(r["branch_b"].source, r["source"]) + + +class TestFlowModelHydraDefaults(TestCase): + """Tests that default parameter values work with Hydra.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_defaults_used_when_not_specified(self): + """Test that default values are used when not in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = 
SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 10) # 10 * 1 (default) + + def test_defaults_can_be_overridden(self): + """Test that defaults can be overridden in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + "multiplier": 100, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 1000) # 10 * 100 + + +class TestFlowModelHydraModelProperties(TestCase): + """Tests that model properties are correct after Hydra instantiation.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_context_type_property(self): + """Test that context_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.context_type, SimpleContext) + + def test_result_type_property(self): + """Test that result_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.result_type, GenericResult[int]) + + def test_deps_method_works(self): + """Test that __deps__ method works after Hydra instantiation.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + ctx = SimpleContext(value=5) + deps = transformer.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIsInstance(deps[0][0], CallableModel) + self.assertEqual(deps[0][1], [ctx]) + self.assertIs(deps[0][0], r["flow_source"]) + + +class TestFlowModelHydraDateRangeTransforms(TestCase): + """Tests transforms with DateRangeContext from Hydra config.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def 
test_transform_applied_from_yaml(self): + """Test that transform is applied when loaded from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + deps = processor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + dep_model, dep_contexts = deps[0] + + self.assertIs(dep_model, r["flow_date_loader"]) + self.assertEqual(dep_contexts[0], ctx) + self.assertEqual(dep_model(ctx).value["start_date"], "2024-01-09") + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index 586b03f..dc2db55 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -1235,7 +1235,7 @@ class TestCreateCcflowModelCloudpickleCrossProcess: id="context_only", ), pytest.param( - # Dynamic context with CallableModel + # Runtime-created context with CallableModel """ from ray.cloudpickle import dump from ccflow import CallableModel, ContextBase, GenericResult, Flow diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md new file mode 100644 index 0000000..7b6ac9f --- /dev/null +++ b/docs/design/flow_model_design.md @@ -0,0 +1,270 @@ +# Flow.model Design + +## Overview + +`@Flow.model` turns a plain Python function into a real `CallableModel`. + +The core goals are: + +- keep the authoring model close to an ordinary function, +- preserve the existing evaluator / registry / serialization machinery, +- make deferred execution explicit with `.flow.compute(...)` and `.flow.with_inputs(...)`, +- allow callers to pass either literal values or upstream models for ordinary parameters. + +`@Flow.model` is syntactic sugar over the existing ccflow framework. 
The +generated object is still a standard `CallableModel`, so you can execute it the +same way as any other model by calling it with a context object. The +`.flow.compute(...)` helper is an explicit, ergonomic way to mark the deferred +execution boundary when supplying runtime inputs as keyword arguments. + +## Core Patterns + +### Default Deferred Style + +This is the most ergonomic mode. Bind some parameters up front, then provide +the remaining runtime inputs later. + +```python +from ccflow import Flow, FlowContext + + +@Flow.model +def add(x: int, y: int) -> int: + return x + y + + +model = add(x=10) + +# Explicit deferred entry point +assert model.flow.compute(y=5).value == 15 + +# Standard CallableModel call path +assert model(FlowContext(y=5)).value == 15 + +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5).value == 20 +``` + +In this mode: + +- bound parameters are model configuration, +- unbound parameters become runtime inputs for that model instance. + +### Explicit Context Parameter + +```python +from ccflow import DateRangeContext, Flow + + +@Flow.model +def load_revenue(context: DateRangeContext, region: str) -> float: + return 125.0 +``` + +This is the most direct mode. The function receives a normal context object and +returns either a `ResultBase` subclass or a plain value. Plain values are +wrapped into `GenericResult` automatically by the generated model. + +### `context_args` + +```python +from datetime import date + +from ccflow import DateRangeContext, Flow + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + return 125.0 +``` + +This keeps the function signature focused on the inputs it actually uses while +still producing a `CallableModel` that accepts a context at runtime.
+ +Use `context_args` when certain parameters are semantically the execution +context and you want that split to be explicit and stable across model +instances. + +By default, `context_args` models use `FlowContext`. If you want compatibility +with an existing context class, pass `context_type=...` explicitly. + +### Upstream Models as Normal Arguments + +Any non-context parameter can be given either: + +- a literal value, or +- another `CallableModel` / `BoundModel`. + +If a model is passed, it is evaluated with the current context and its result is +unwrapped before the function is called. + +```python +from ccflow import DateRangeContext, Flow + + +@Flow.model +def load_revenue(context: DateRangeContext, region: str) -> float: + return 125.0 + + +@Flow.model +def double_revenue(_: DateRangeContext, revenue: float) -> float: + return revenue * 2 + + +revenue = load_revenue(region="us") +model = double_revenue(revenue=revenue) +result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") +``` + +This is the main composition story for the core API. + +### `.flow.with_inputs(...)` + +`with_inputs` is how a caller rewires context locally for one upstream model. 
+ +```python +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def revenue_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: + return { + "window_end": end_date, + "growth_pct": round((current - previous) / previous * 100, 2), + } + + +current = load_revenue(region="us") +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) + +model = revenue_growth(current=current, previous=previous) +ctx = DateRangeContext( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +direct = model(ctx) +computed = model.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +assert direct == computed +``` + +The transform is local to the bound upstream model. The parent model continues +to receive the original context. + +### `.flow.compute(...)` + +`compute` is the ergonomic entry point for deferred execution: + +```python +from ccflow import Flow + + +@Flow.model +def add(x: int, y: int) -> int: + return x + y + + +model = add(x=10) +assert model.flow.compute(y=5).value == 15 +``` + +It validates the supplied keyword arguments against the generated context +schema, creates a `FlowContext`, and executes the model. + +It returns the same result object you would get from calling `model(context)`. + +It is not the only execution path. Because the generated object is still a +standard `CallableModel`, calling `model(context)` remains fully supported. + +## Lazy Inputs + +`Lazy[T]` marks a parameter as on-demand. 
Instead of eagerly resolving an +upstream model, the generated model passes a zero-argument thunk. The thunk +caches its first result. Lazy dependencies are excluded from the `__deps__` +graph, so they are not pre-evaluated by the evaluator infrastructure. + +```python +from ccflow import Flow, Lazy + + +@Flow.model +def source(value: int) -> int: + return value * 10 + + +@Flow.model +def maybe_use_source(value: int, data: Lazy[int]) -> int: + if value > 10: + return value + return data() +``` + +## FlowContext + +`FlowContext` is the universal frozen carrier for generated contexts that do +not map to a dedicated built-in context type. + +The implementation stays intentionally small: + +- context validation is driven by `TypedDict` + `TypeAdapter`, +- runtime execution uses one reusable `FlowContext` type, +- public pydantic iteration (`dict(context)`) is used instead of pydantic + internals. + +## BoundModel + +`.flow.with_inputs(...)` returns a `BoundModel`, which is just a thin wrapper +around: + +- the original model, and +- a mapping of input transforms. + +At call time it: + +1. converts the incoming context into a plain dictionary, +1. applies the configured transforms, +1. rebuilds a `FlowContext`, +1. delegates to the wrapped model. + +That keeps transformed dependency wiring explicit without adding special +annotation machinery to the core API. + +## Flow.call with `auto_context` + +Separately from `@Flow.model`, `Flow.call(auto_context=...)` provides a similar +convenience for class-based `CallableModel`s. Instead of defining a separate +`ContextBase` subclass, the decorator generates one from the function's +keyword-only parameters. 
+ +```python +from ccflow import CallableModel, Flow, GenericResult + + +class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") +``` + +Passing a `ContextBase` subclass (e.g., `auto_context=DateContext`) makes the +generated context inherit from that class, so it remains compatible with +infrastructure that expects the parent type. + +The generated class is registered via `create_ccflow_model` for serialization +support. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 616e3d8..5fb27de 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -22,6 +22,313 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan `CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +### Flow.model Decorator + +The `@Flow.model` decorator provides a simpler way to define `CallableModel`s +using plain Python functions instead of classes. It automatically generates a +standard `CallableModel` class with proper `__call__` and `__deps__` methods, +so it still uses the normal ccflow framework for evaluation, caching, +serialization, and registry loading. + +If a `@Flow.model` function returns a plain value instead of a `ResultBase` +subclass, the generated model automatically wraps it in `GenericResult` at +runtime so it still behaves like a normal `CallableModel`. 
+ +You can execute a generated model in two equivalent ways: + +- call it directly with a context object: `model(ctx)` +- use `.flow.compute(...)` to supply runtime inputs as keyword arguments + +`.flow.compute(...)` is mainly an explicit, ergonomic way to mark the deferred +execution point. + +#### Context Modes + +There are three ways to define how a `@Flow.model` function receives its +runtime context. + +**Mode 1 — Explicit context parameter:** + +The function takes a `context` parameter (or `_` if unused) annotated with a +`ContextBase` subclass. This is the most direct mode and behaves like a +traditional `CallableModel.__call__`. + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +@Flow.model +def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: + return GenericResult(value=query_db(source, context.start_date, context.end_date)) + +loader = load_data(source="my_database") + +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) +``` + +**Mode 2 — Unpacked context with `context_args`:** + +Instead of receiving a context object, you list which parameters should come +from the context at runtime. The remaining parameters are model configuration. 
+ +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + +loader = load_data(source="my_database") + +# Opt in explicitly when you want compatibility with an existing context type +assert loader.context_type == DateRangeContext + +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) +``` + +By default, `context_args` models use `FlowContext`, a universal frozen carrier +for the validated fields. If you want the generated model to advertise and +accept an existing `ContextBase` subclass, pass `context_type=...` explicitly. + +Use `context_args` when some parameters are semantically "the execution +context" and you want that split to stay stable and explicit: + +- the runtime context should be stable across instances +- the split between config and runtime inputs matters semantically +- the model is naturally "run over a context" such as date windows, + partitions, or scenarios +- you want the generated model to accept a specific existing context type + such as `DateRangeContext` + +**Mode 3 — Dynamic deferred style (no explicit context):** + +When there is no `context` parameter and no `context_args`, all parameters are +potential configuration or runtime inputs. Parameters provided at construction +are bound (configuration); everything else comes from the context at runtime. + +```python +from ccflow import Flow + +@Flow.model +def add(x: int, y: int) -> int: + return x + y + +model = add(x=10) + +# `x` is bound when the model is created. +# `y` is supplied later at execution time. +assert model.flow.compute(y=5).value == 15 + +# `.flow.with_inputs(...)` rewrites runtime inputs for this call path. 
+doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert doubled_y.flow.compute(y=5).value == 20 +``` + +#### Composing Dependencies + +Any non-context parameter can be bound either to a literal value or to another +`CallableModel`. If you pass an upstream model, `@Flow.model` evaluates it with +the current context and passes the resolved value into your function. + +```python +from datetime import date, timedelta +from ccflow import DateRangeContext, Flow + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def revenue_growth( + start_date: date, + end_date: date, + current: float, + previous: float, +) -> dict: + return { + "window_end": end_date, + "growth_pct": round((current - previous) / previous * 100, 2), + } + +current = load_revenue(region="us") +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) +growth = revenue_growth(current=current, previous=previous) + +ctx = DateRangeContext( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +# Standard ccflow execution +direct = growth(ctx) + +# Equivalent explicit deferred entry point +computed = growth.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +assert direct == computed +``` + +#### Deferred Execution Helpers + +**`.flow.compute(**kwargs)`** validates the keyword arguments against the +generated context schema, wraps them in a `FlowContext`, and calls the model. +It returns the same result object you would get from `model(context)`. + +**`.flow.with_inputs(**transforms)`** returns a `BoundModel` that applies +context transforms before delegating to the underlying model. 
Each transform +is either a static value or a `(ctx) -> value` callable. Transforms are local +to the wrapped model — upstream models never see them. + +```python +from ccflow import Flow, FlowContext + +@Flow.model +def add(x: int, y: int) -> int: + return x + y + +model = add(x=10) +assert model.flow.compute(y=5).value == 15 + +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5).value == 20 + +# You can also call with a context object directly +ctx = FlowContext(y=5) +assert model(ctx).value == 15 +assert shifted(ctx).value == 20 +``` + +#### Lazy Dependencies with `Lazy[T]` + +Mark a parameter with `Lazy[T]` to defer its evaluation. Instead of eagerly +resolving the upstream model, the generated model passes a zero-argument thunk +that evaluates on first call and caches the result. The thunk unwraps +`GenericResult` automatically, so `T` should be the inner value type. + +```python +from ccflow import ContextBase, Flow, GenericResult, Lazy + +class SimpleContext(ContextBase): + value: int + +@Flow.model +def fast_path(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + +@Flow.model +def slow_path(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 100) + +@Flow.model +def smart_selector( + context: SimpleContext, + fast: int, # Eagerly resolved and unwrapped + slow: Lazy[int], # Deferred — receives a thunk returning unwrapped int + threshold: int = 10, +) -> GenericResult[int]: + if fast > threshold: + return GenericResult(value=fast) + return GenericResult(value=slow()) # Evaluated only when called + +model = smart_selector( + fast=fast_path(), + slow=slow_path(), + threshold=10, +) +``` + +`Lazy` dependencies are excluded from the model's `__deps__` graph, so they +are not pre-evaluated by the evaluator infrastructure. 
+ +#### Decorator Options + +`@Flow.model(...)` accepts the same options as `Flow.call` to control execution +behavior: + +- `cacheable` — enable caching of results +- `volatile` — mark as volatile (always re-execute) +- `log_level` — logging verbosity +- `validate_result` — validate return type +- `verbose` — verbose logging output +- `evaluator` — custom evaluator + +When not explicitly set, these inherit from any active `FlowOptionsOverride`. + +#### Hydra / YAML Configuration + +`@Flow.model` decorated functions work seamlessly with Hydra configuration and +the `ModelRegistry`: + +```yaml +# config.yaml +data: + _target_: mymodule.load_data + source: my_database + +transformed: + _target_: mymodule.transform_data + raw_data: data # Reference by registry name (same instance is shared) + +aggregated: + _target_: mymodule.aggregate_data + transformed: transformed # Reference by registry name +``` + +```python +from ccflow import ModelRegistry + +registry = ModelRegistry.root() +registry.load_config_from_path("config.yaml") + +# References by name ensure the same object instance is shared +model = registry["aggregated"] +``` + +### Flow.call with `auto_context` + +For class-based `CallableModel`s, `Flow.call(auto_context=...)` provides a +similar convenience. Instead of defining a separate `ContextBase` subclass, the +decorator generates one from the function's keyword-only parameters. 
+ +```python +from ccflow import CallableModel, Flow, GenericResult + +class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + +model = MyModel() +result = model(x=42, y="hello") +assert result.value == "42-hello" +``` + +You can also pass a parent context class so the generated context inherits +from it: + +```python +from datetime import date +from ccflow import CallableModel, DateContext, Flow, GenericResult + +class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> GenericResult: + return GenericResult(value=date.day + extra) +``` + +The generated context class is a proper `ContextBase` subclass, so it works +with all existing evaluator and registry infrastructure. + ## Model Registry A `ModelRegistry` is a named collection of models. diff --git a/examples/config/flow_model_hydra_builder_demo.yaml b/examples/config/flow_model_hydra_builder_demo.yaml new file mode 100644 index 0000000..5579a5c --- /dev/null +++ b/examples/config/flow_model_hydra_builder_demo.yaml @@ -0,0 +1,24 @@ +# Hydra config for examples/flow_model_hydra_builder_demo.py +# +# Pattern: +# - configure static pipeline specs in YAML +# - use model_alias to pass already-registered models into a plain Python builder +# - keep runtime context as runtime inputs, supplied later at execution time + +current_revenue: + _target_: examples.flow_model_hydra_builder_demo.load_revenue + region: us + +week_over_week: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: week_over_week + +month_over_month: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: month_over_month diff --git 
a/examples/evaluator_demo.py b/examples/evaluator_demo.py new file mode 100644 index 0000000..a85b087 --- /dev/null +++ b/examples/evaluator_demo.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +""" +Evaluator Demo: Caching & Execution Strategies +=============================================== + +Shows how to change execution behavior (caching, graph evaluation, logging) +WITHOUT changing user code. The same @Flow.model functions work with any +evaluator stack — you just configure it at the top level. + +Key insight: "default lazy" is an evaluator concern, not a wiring concern. +Users write plain functions and wire them by passing outputs as inputs. +The evaluator layer controls how they execute. + +Demonstrates: + 1. Default execution (eager, no caching) — diamond dep calls load twice + 2. MemoryCacheEvaluator — deduplicates shared deps in a diamond + 3. GraphEvaluator + Cache — topological evaluation + deduplication + 4. LoggingEvaluator — adds tracing around every model call + 5. Per-model opt-out — @Flow.model(cacheable=False) overrides global + +Run with: python examples/evaluator_demo.py +""" + +from __future__ import annotations + +import logging +import sys + +# Suppress default debug logging from ccflow evaluators for clean demo output +logging.disable(logging.DEBUG) + +from ccflow import Flow, FlowOptionsOverride # noqa: E402 +from ccflow.evaluators.common import ( # noqa: E402 + GraphEvaluator, + LoggingEvaluator, + MemoryCacheEvaluator, + MultiEvaluator, +) + +# ============================================================================= +# Plain @Flow.model functions — no evaluator concerns in the code +# ============================================================================= + +call_counts: dict[str, int] = {} + + +def _track(name: str) -> None: + call_counts[name] = call_counts.get(name, 0) + 1 + + +@Flow.model +def load_data(x: int, source: str = "warehouse") -> list: + """Load raw data. 
Expensive — we want to avoid calling this twice.""" + _track("load_data") + return [x, x * 2, x * 3] + + +@Flow.model +def compute_sum(data: list) -> int: + """Branch A: sum the data.""" + _track("compute_sum") + return sum(data) + + +@Flow.model +def compute_max(data: list) -> int: + """Branch B: max of the data.""" + _track("compute_max") + return max(data) + + +@Flow.model +def combine(sum_result: int, max_result: int) -> dict: + """Combine results from both branches.""" + _track("combine") + return {"sum": sum_result, "max": max_result, "total": sum_result + max_result} + + +@Flow.model(cacheable=False) +def volatile_timestamp(seed: int) -> str: + """Explicitly non-cacheable — always re-executes even with global caching.""" + _track("volatile_timestamp") + from datetime import datetime + + return datetime.now().isoformat() + + +# ============================================================================= +# Wire the pipeline — diamond dependency on load_data +# +# load_data ──┬── compute_sum ──┐ +# └── compute_max ──┴── combine +# ============================================================================= + +shared = load_data(source="prod") +branch_a = compute_sum(data=shared) +branch_b = compute_max(data=shared) +pipeline = combine(sum_result=branch_a, max_result=branch_b) + + +def run() -> dict: + call_counts.clear() + result = pipeline.flow.compute(x=5) + loads = call_counts.get("load_data", 0) + print(f" Result: {result.value}") + print(f" load_data called: {loads}x | total model calls: {sum(call_counts.values())}") + return result.value + + +# ============================================================================= +# Demo 1: Default — no evaluator +# ============================================================================= + +print("=" * 70) +print("1. 
Default (eager, no caching)") +print(" load_data is called TWICE — once per branch") +print("=" * 70) +run() + +# ============================================================================= +# Demo 2: MemoryCacheEvaluator — deduplicates shared deps +# ============================================================================= + +print() +print("=" * 70) +print("2. MemoryCacheEvaluator (global override)") +print(" load_data is called ONCE — second branch hits cache") +print("=" * 70) +with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): + run() + +# ============================================================================= +# Demo 3: Cache + GraphEvaluator — topological order + deduplication +# ============================================================================= + +print() +print("=" * 70) +print("3. GraphEvaluator + MemoryCacheEvaluator") +print(" Evaluates in dependency order: load_data → branches → combine") +print("=" * 70) +evaluator = MultiEvaluator(evaluators=[MemoryCacheEvaluator(), GraphEvaluator()]) +with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + run() + +# ============================================================================= +# Demo 4: Logging — trace every model call +# ============================================================================= + +print() +print("=" * 70) +print("4. 
LoggingEvaluator + MemoryCacheEvaluator") +print(" Adds timing/tracing around every evaluation") +print("=" * 70) + +# Re-enable logging for this demo (use stdout so log lines interleave with print correctly) +logging.disable(logging.NOTSET) +logging.basicConfig(level=logging.INFO, format=" LOG: %(message)s", stream=sys.stdout) + +evaluator = MultiEvaluator(evaluators=[LoggingEvaluator(log_level=logging.INFO), MemoryCacheEvaluator()]) +with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + run() + +# Suppress again for clean output +logging.disable(logging.DEBUG) +logging.getLogger().handlers.clear() + +# ============================================================================= +# Demo 5: Per-model opt-out — cacheable=False overrides global +# ============================================================================= + +print() +print("=" * 70) +print("5. Per-model opt-out: @Flow.model(cacheable=False)") +print(" volatile_timestamp always re-executes despite global cacheable=True") +print("=" * 70) + +ts = volatile_timestamp(seed=0) + +with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): + call_counts.clear() + r1 = ts.flow.compute(seed=0) + r2 = ts.flow.compute(seed=0) + print(f" Call 1: {r1.value}") + print(f" Call 2: {r2.value}") + print(f" volatile_timestamp called: {call_counts.get('volatile_timestamp', 0)}x") + print(f" (Same result? {r1.value == r2.value} — called twice, timestamps may differ)") diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py new file mode 100644 index 0000000..27d5d0e --- /dev/null +++ b/examples/flow_model_example.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +"""Canonical Flow.model example. + +This is the main `@Flow.model` story: + +1. define workflow steps as plain Python functions, +2. wire them together by passing upstream models as normal arguments, +3. use a small Python builder for reusable composition, +4. 
execute either as a normal CallableModel or via `.flow.compute(...)`. + +Run with: + python examples/flow_model_example.py +""" + +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + """Return synthetic revenue for one reporting window.""" + days = (end_date - start_date).days + 1 + region_base = {"us": 1000.0, "eu": 850.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def revenue_change( + start_date: date, + end_date: date, + current: float, + previous: float, + label: str, + days_back: int, +) -> dict: + """Compare the current window against a shifted previous window.""" + previous_start = start_date - timedelta(days=days_back) + previous_end = end_date - timedelta(days=days_back) + growth_pct = round((current - previous) / previous * 100, 2) + return { + "comparison": label, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", + "current": current, + "previous": previous, + "growth_pct": growth_pct, + } + + +def shifted_window(model, *, days_back: int): + """Reuse one upstream model with a shifted runtime window.""" + return model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=days_back), + end_date=lambda ctx: ctx.end_date - timedelta(days=days_back), + ) + + +def build_week_over_week_pipeline(region: str): + """Build one reusable pipeline from plain Flow.model functions.""" + current = load_revenue(region=region) + previous = shifted_window(current, days_back=7) + return revenue_change( + current=current, + previous=previous, + label="week_over_week", + days_back=7, + ) + + +def 
main() -> None: + print("=" * 64) + print("Flow.model Example") + print("=" * 64) + + pipeline = build_week_over_week_pipeline(region="us") + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), + ) + + direct = pipeline(ctx) + computed = pipeline.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ) + + print("\nPipeline wired from plain functions:") + print(" current input:", pipeline.current) + print(" previous input:", pipeline.previous) + + print("\nDirect call and .flow.compute(...) are equivalent:") + print(f" direct == computed: {direct == computed}") + + print("\nResult:") + for key, value in computed.value.items(): + print(f" {key}: {value}") + + +if __name__ == "__main__": + main() diff --git a/examples/flow_model_hydra_builder_demo.py b/examples/flow_model_hydra_builder_demo.py new file mode 100644 index 0000000..00c9571 --- /dev/null +++ b/examples/flow_model_hydra_builder_demo.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +"""Hydra + Flow.model builder demo. + +This example shows a clean way to mix: + +1. ergonomic `@Flow.model` pipeline wiring in Python, and +2. Hydra / ModelRegistry configuration for static pipeline specs. + +The pattern is: + +- keep runtime context (`start_date`, `end_date`) as runtime inputs, +- use a plain Python builder function for graph construction, +- let Hydra instantiate that builder and register the returned model. 
+ +Run with: + python examples/flow_model_hydra_builder_demo.py +""" + +from calendar import monthrange +from datetime import date, timedelta +from pathlib import Path +from typing import Literal, Protocol, cast + +from ccflow import BoundModel, CallableModel, DateRangeContext, Flow, FlowAPI, GenericResult, ModelRegistry +from typing_extensions import TypedDict + +CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" +ComparisonName = Literal["week_over_week", "month_over_month"] + + +class RevenueChangeResult(TypedDict): + comparison: ComparisonName + current_window: str + previous_window: str + current: float + previous: float + delta: float + growth_pct: float + + +class RevenueChangeModel(Protocol): + flow: FlowAPI + + def __call__(self, context: DateRangeContext) -> GenericResult[RevenueChangeResult]: ... + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + """Return synthetic revenue for a date window.""" + days = (end_date - start_date).days + 1 + region_base = {"us": 1000.0, "eu": 850.0, "apac": 920.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def revenue_change( + start_date: date, + end_date: date, + current: float, + previous: float, + comparison: ComparisonName, +) -> RevenueChangeResult: + """Compare the current window against a shifted previous window.""" + growth = (current - previous) / previous + previous_start, previous_end = comparison_window(start_date, end_date, comparison) + return { + "comparison": comparison, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", + "current": current, + "previous": previous, + "delta": round(current - 
previous, 2), + "growth_pct": round(growth * 100, 2), + } + + +def comparison_window(start_date: date, end_date: date, comparison: ComparisonName) -> tuple[date, date]: + """Return the previous window for a named comparison policy.""" + if comparison == "week_over_week": + return start_date - timedelta(days=7), end_date - timedelta(days=7) + + if start_date.day != 1: + raise ValueError("month_over_month requires start_date to be the first day of a month") + if start_date.year != end_date.year or start_date.month != end_date.month: + raise ValueError("month_over_month requires the current window to stay within one calendar month") + expected_end = date(end_date.year, end_date.month, monthrange(end_date.year, end_date.month)[1]) + if end_date != expected_end: + raise ValueError("month_over_month requires end_date to be the last day of that month") + + previous_year = start_date.year if start_date.month > 1 else start_date.year - 1 + previous_month = start_date.month - 1 if start_date.month > 1 else 12 + previous_start = date(previous_year, previous_month, 1) + previous_end = date(previous_year, previous_month, monthrange(previous_year, previous_month)[1]) + return previous_start, previous_end + + +def comparison_input(model: CallableModel, comparison: ComparisonName) -> BoundModel: + """Apply a named comparison policy to one dependency.""" + return model.flow.with_inputs( + start_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[0], + end_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[1], + ) + + +def build_comparison(current: CallableModel, *, comparison: ComparisonName) -> RevenueChangeModel: + """Hydra-friendly builder that returns a configured comparison model.""" + previous = comparison_input(current, comparison) + return revenue_change( + current=current, + previous=previous, + comparison=comparison, + ) + + +def main() -> None: + registry = ModelRegistry.root() + registry.clear() + try: + 
registry.load_config_from_path(str(CONFIG_PATH), overwrite=True) + + week_over_week = cast(RevenueChangeModel, registry["week_over_week"]) + month_over_month = cast(RevenueChangeModel, registry["month_over_month"]) + + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), + ) + + print("=" * 68) + print("Hydra + Flow.model Builder Demo") + print("=" * 68) + print("\nLoaded from config:") + print(" current_revenue:", registry["current_revenue"]) + print(" week_over_week:", week_over_week) + print(" month_over_month:", month_over_month) + + week_over_week_result = cast( + RevenueChangeResult, + week_over_week.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ).value, + ) + month_over_month_result = month_over_month(ctx).value + + print("\nWeek-over-week:") + for key, value in week_over_week_result.items(): + print(f" {key}: {value}") + + print("\nMonth-over-month:") + for key, value in month_over_month_result.items(): + print(f" {key}: {value}") + finally: + registry.clear() + + +if __name__ == "__main__": + main() diff --git a/examples/ml_pipeline_demo.py b/examples/ml_pipeline_demo.py new file mode 100644 index 0000000..0c8920f --- /dev/null +++ b/examples/ml_pipeline_demo.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +""" +ML Pipeline Demo: Smart Model Selection +======================================== + +This is the example from the original design conversation — a realistic ML +pipeline that demonstrates how Flow.model lets you write plain functions, +wire them by passing outputs as inputs, and execute with .flow.compute(). + +Features demonstrated: + 1. @Flow.model with auto-wrap (plain return types, no GenericResult needed) + 2. Lazy[T] for conditional evaluation (skip slow model if fast is good enough) + 3. .flow.compute() for execution with automatic context propagation + 4. .flow.with_inputs() for context transforms (lookback windows) + 5. 
Factored wiring — build_pipeline() shows how to reuse the same graph + structure with different data sources + +The pipeline: + + load_dataset ──> prepare_features ──> train_linear ──> evaluate ──> fast_metrics ──┐ + └──> train_forest ──> evaluate ──> slow_metrics ──┴──> smart_training + +Run with: python examples/ml_pipeline_demo.py +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date, timedelta +from math import sin + +from ccflow import Flow, Lazy + + +# ============================================================================= +# Domain types (stand-ins for real ML objects) +# ============================================================================= + + +@dataclass +class PreparedData: + """Container for train/test split data.""" + + X_train: list # list of feature vectors + X_test: list + y_train: list # list of target values + y_test: list + + +@dataclass +class TrainedModel: + """A fitted model (placeholder).""" + + name: str + coefficients: list + intercept: float + augment: bool # Whether to add sin feature during prediction + + +@dataclass +class Metrics: + """Evaluation metrics.""" + + r2: float + mse: float + model_name: str + + +# ============================================================================= +# Data Loading +# ============================================================================= + + +@Flow.model +def load_dataset(start_date: date, end_date: date, source: str = "warehouse") -> list: + """Load raw dataset for a date range. + + Returns a list of dicts (standing in for a DataFrame). + Auto-wrapped: returns plain list, framework wraps in GenericResult. 
+ """ + n_days = (end_date - start_date).days + 1 + print(f" [load_dataset] Loading {n_days} days from '{source}' ({start_date} to {end_date})") + # True relationship: target = 2.0 * x + 10.0 + 15.0 * sin(x * 0.2) + # Linear model captures the trend (R^2 ~0.93), forest also captures the sin wave (~0.99) + return [ + { + "date": str(start_date + timedelta(days=i)), + "x": float(i), + "target": 2.0 * i + 10.0 + 15.0 * sin(i * 0.2), + } + for i in range(n_days) + ] + + +# ============================================================================= +# Feature Engineering +# ============================================================================= + + +@Flow.model +def prepare_features(raw_data: list) -> PreparedData: + """Split data into train/test. + + Returns a PreparedData dataclass — the framework auto-wraps it in GenericResult. + Downstream models can request individual fields via prepared["X_train"] etc. + """ + n = len(raw_data) + split = int(n * 0.8) + print(f" [prepare_features] {n} rows, split at {split}") + + X = [[r["x"]] for r in raw_data] + y = [r["target"] for r in raw_data] + + return PreparedData( + X_train=X[:split], + X_test=X[split:], + y_train=y[:split], + y_test=y[split:], + ) + + +# ============================================================================= +# Model Training +# ============================================================================= + + +def _ols_fit(X, y): + """Simple OLS: compute coefficients and intercept.""" + n = len(X) + n_feat = len(X[0]) + y_mean = sum(y) / n + x_means = [sum(row[j] for row in X) / n for j in range(n_feat)] + + coefficients = [] + for j in range(n_feat): + cov = sum((X[i][j] - x_means[j]) * (y[i] - y_mean) for i in range(n)) / n + var = sum((X[i][j] - x_means[j]) ** 2 for i in range(n)) / n + coefficients.append(cov / var if var > 1e-10 else 0.0) + + intercept = y_mean - sum(c * m for c, m in zip(coefficients, x_means)) + return coefficients, intercept + + +def _augment(X): + """Add sin(x*0.2) 
feature to capture non-linearity.""" + return [row + [sin(row[0] * 0.2)] for row in X] + + +@Flow.model +def train_linear(prepared: PreparedData) -> TrainedModel: + """Train a fast linear model (linear features only).""" + print(f" [train_linear] Fitting on {len(prepared.X_train)} samples") + coefficients, intercept = _ols_fit(prepared.X_train, prepared.y_train) + return TrainedModel(name="LinearRegression", coefficients=coefficients, intercept=intercept, augment=False) + + +@Flow.model +def train_forest(prepared: PreparedData, n_estimators: int = 100) -> TrainedModel: + """Train a model that also captures non-linear patterns (simulated).""" + print(f" [train_forest] Fitting {n_estimators} trees on {len(prepared.X_train)} samples") + # Augment with sin feature to capture non-linearity + X_aug = _augment(prepared.X_train) + coefficients, intercept = _ols_fit(X_aug, prepared.y_train) + return TrainedModel( + name=f"RandomForest(n={n_estimators})", + coefficients=coefficients, + intercept=intercept, + augment=True, + ) + + +# ============================================================================= +# Model Evaluation +# ============================================================================= + + +@Flow.model +def evaluate_model(model: TrainedModel, prepared: PreparedData) -> Metrics: + """Evaluate a trained model on test data.""" + X_test = prepared.X_test + y_test = prepared.y_test + X_eval = _augment(X_test) if model.augment else X_test + + y_pred = [ + model.intercept + sum(c * x for c, x in zip(model.coefficients, row)) + for row in X_eval + ] + + y_mean = sum(y_test) / len(y_test) if y_test else 0 + ss_tot = sum((y - y_mean) ** 2 for y in y_test) or 1 + ss_res = sum((yt - yp) ** 2 for yt, yp in zip(y_test, y_pred)) + r2 = 1.0 - ss_res / ss_tot + mse = ss_res / len(y_test) if y_test else 0 + + print(f" [evaluate_model] {model.name}: R^2={r2:.4f}, MSE={mse:.2f}") + return Metrics(r2=r2, mse=mse, model_name=model.name) + + +# 
============================================================================= +# Smart Pipeline with Conditional Execution +# ============================================================================= + + +@Flow.model +def smart_training( + # data: PreparedData, + fast_metrics: Metrics, + slow_metrics: Lazy[Metrics], # Only evaluated if fast isn't good enough + threshold: float = 0.9, +) -> Metrics: + """Use fast model if good enough, else fall back to slow. + + The slow_metrics parameter is Lazy — it receives a zero-arg thunk. + If the fast model exceeds the threshold, the slow model is never + trained or evaluated at all. + """ + print(f" [smart_training] Fast R^2={fast_metrics.r2:.4f}, threshold={threshold}") + if fast_metrics.r2 >= threshold: + print(" [smart_training] Fast model is good enough! Skipping slow model.") + return fast_metrics + else: + print(" [smart_training] Fast model below threshold, evaluating slow model...") + return slow_metrics() + + +# ============================================================================= +# Pipeline Wiring Helper +# ============================================================================= + + +def build_pipeline(raw, *, n_estimators=200, threshold=0.95): + """Wire a complete train/evaluate/select pipeline from a data source. + + This function shows the flexibility of the approach: the same wiring + logic can be applied to different data sources (raw, lookback_raw, etc.) + without duplicating code. Everything here is just wiring — no computation + happens until .flow.compute() is called. + + Args: + raw: A CallableModel or BoundModel that produces raw data (list of dicts) + n_estimators: Number of trees for the forest model + threshold: R^2 threshold for the fast/slow model selection + + Returns: + A smart_training model instance ready for .flow.compute() + """ + # Feature engineering — returns a PreparedData with X_train, X_test, etc. 
+ prepared = prepare_features(raw_data=raw) + + # Train both models — each receives the whole PreparedData and extracts + # the fields it needs internally. + linear = train_linear(prepared=prepared) + forest = train_forest(prepared=prepared, n_estimators=n_estimators) + + # Evaluate both + linear_metrics = evaluate_model(model=linear, prepared=prepared) + forest_metrics = evaluate_model(model=forest, prepared=prepared) + + # Smart selection with Lazy — forest is only evaluated if linear isn't good enough + return smart_training( + fast_metrics=linear_metrics, + slow_metrics=forest_metrics, + threshold=threshold, + ) + + +# ============================================================================= +# Main: Wire and execute the pipeline +# ============================================================================= + + +def main(): + print("=" * 70) + print("ML Pipeline Demo: Smart Model Selection with Flow.model") + print("=" * 70) + + # ------------------------------------------------------------------ + # Step 1: Wire the pipeline (no computation happens here) + # ------------------------------------------------------------------ + print("\n--- Wiring the pipeline (lazy, no computation yet) ---\n") + + raw = load_dataset(source="prod_warehouse") + + # build_pipeline factors out the repeated wiring logic. + # Linear R^2 ≈ 0.93. Threshold is 0.95 → falls through to forest. + pipeline = build_pipeline(raw, n_estimators=200, threshold=0.95) + + print("Pipeline wired. 
No functions have been called yet.") + + # ------------------------------------------------------------------ + # Step 2: Execute — linear not good enough, falls back to forest + # ------------------------------------------------------------------ + print("\n--- Executing pipeline (Jan-Jun 2024) ---\n") + result = pipeline.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + + print(f"\n Best model: {result.value.model_name}") + print(f" R^2: {result.value.r2:.4f}") + print(f" MSE: {result.value.mse:.2f}") + + # ------------------------------------------------------------------ + # Step 3: Context transforms (lookback) — reuse build_pipeline + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("With Lookback: Same pipeline structure, extra history for loading") + print("=" * 70) + + # flow.with_inputs() creates a BoundModel that transforms the context + # before calling the underlying model. start_date is shifted 30 days earlier. + lookback_raw = raw.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30) + ) + + # Same wiring logic, different data source — no duplication. 
+ lookback_pipeline = build_pipeline(lookback_raw, n_estimators=200, threshold=0.95) + + print("\n--- Executing lookback pipeline ---\n") + result2 = lookback_pipeline.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + # Notice: load_dataset gets start_date=2023-12-02 (30 days earlier) + + print(f"\n Best model: {result2.value.model_name}") + print(f" R^2: {result2.value.r2:.4f}") + print(f" MSE: {result2.value.mse:.2f}") + + # ------------------------------------------------------------------ + # Step 4: Lower threshold — linear is good enough, skip forest + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("Lazy Evaluation: Lower threshold so fast model is good enough") + print("=" * 70) + + # With threshold=0.80, the linear model's R^2 (~0.93) passes. + # The forest is NEVER trained or evaluated — Lazy skips it entirely. + fast_pipeline = build_pipeline(raw, n_estimators=200, threshold=0.80) + + print("\n--- Executing (slow model should NOT be trained) ---\n") + result3 = fast_pipeline.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + print(f"\n Selected: {result3.value.model_name} (R^2={result3.value.r2:.4f})") + print(" (Notice: train_forest and its evaluate_model were never called)") + + +if __name__ == "__main__": + main()