From de87bb26551e8e4feba8c2ad3dbff99f701446e4 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 31 Dec 2025 03:56:58 -0500 Subject: [PATCH 01/17] Add ability to define dynamic context from kwargs Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 165 ++++++++++++-- ccflow/tests/test_callable.py | 381 +++++++++++++++++++++++++++++++++ ccflow/tests/test_evaluator.py | 70 +++++- 3 files changed, 599 insertions(+), 17 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..b6580c9 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -12,10 +12,11 @@ """ import abc +import inspect import logging -from functools import lru_cache, wraps +from functools import lru_cache, partial, wraps from inspect import Signature, isclass, signature -from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override @@ -27,6 +28,7 @@ ResultBase, ResultType, ) +from .local_persistence import create_ccflow_model from .validators import str_to_log_level __all__ = ( @@ -44,6 +46,7 @@ "EvaluatorBase", "Evaluator", "WrapperModel", + "dynamic_context", ) log = logging.getLogger(__name__) @@ -268,14 +271,31 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs): if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( - get_origin(model.context_type) is Union and type(None) in 
get_args(model.context_type) - ): - raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not ( - get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type)) + + # Check if this is a dynamic_context decorated method + has_dynamic_context = hasattr(fn, "__dynamic_context__") + if has_dynamic_context: + method_context_type = fn.__dynamic_context__ + else: + method_context_type = model.context_type + + # Validate context type (skip for dynamic contexts which are always valid ContextBase subclasses) + if not has_dynamic_context: + if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( + get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) + ): + raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") + + # Validate result type - use __result_type__ for dynamic contexts if available + if has_dynamic_context and hasattr(fn, "__result_type__"): + method_result_type = fn.__result_type__ + else: + method_result_type = model.result_type + if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not ( + get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type)) ): - raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase") + raise TypeError(f"Result type {method_result_type} must be a subclass of ResultBase") + if self._deps and fn.__name__ != "__deps__": raise ValueError("Can only apply Flow.deps decorator to __deps__") if context is Signature.empty: @@ -285,18 +305,18 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = context = kwargs else: raise TypeError( - f"{fn.__name__}() missing 1 required positional argument: 
'context' of type {model.context_type}, or kwargs to construct it" + f"{fn.__name__}() missing 1 required positional argument: 'context' of type {method_context_type}, or kwargs to construct it" ) elif kwargs: # Kwargs passed in as well as context. Not allowed raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'") # Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message - if not isinstance(context, model.context_type): - if get_origin(model.context_type) is Union and type(None) in get_args(model.context_type): - model_context_type = [t for t in get_args(model.context_type) if t is not type(None)][0] + if not isinstance(context, method_context_type): + if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type): + coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0] else: - model_context_type = model.context_type - context = model_context_type.model_validate(context) + coerce_context_type = method_context_type + context = coerce_context_type.model_validate(context) if fn != getattr(model.__class__, fn.__name__).__wrapped__: # This happens when super().__call__ is used when implementing a CallableModel that derives from another one. 
@@ -313,6 +333,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = wrap.get_evaluator = self.get_evaluator wrap.get_options = self.get_options wrap.get_evaluation_context = get_evaluation_context + + # Preserve dynamic context attributes for introspection + if hasattr(fn, "__dynamic_context__"): + wrap.__dynamic_context__ = fn.__dynamic_context__ + if hasattr(fn, "__result_type__"): + wrap.__result_type__ = fn.__result_type__ + return wrap @@ -417,6 +444,49 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def dynamic_call(*args, **kwargs): + """Decorator for methods that creates a dynamic context from the function signature. + + This combines @Flow.call and @dynamic_context into a single decorator, allowing + you to define the context inline in the function signature instead of creating + a separate context class. + + Example: + class MyModel(CallableModel): + @Flow.dynamic_call + def __call__(self, *, a: int, b: str = "default") -> GenericResult: + return GenericResult(value=f"{a}-{b}") + + model = MyModel() + model(a=42) # Works with kwargs + model(a=42, b="test") # Also works + + Args: + *args: When used without arguments, the decorated function + **kwargs: FlowOptions parameters (log_level, verbose, validate_result, etc.) 
+ plus dynamic_context options: + - parent: Optional parent context class to inherit from + """ + # Import here to avoid circular import at module level + from ccflow.callable import dynamic_context + + # Extract dynamic_context-specific options + parent = kwargs.pop("parent", None) + + if len(args) == 1 and callable(args[0]): + # No arguments to decorator (@Flow.dynamic_call) + fn = args[0] + wrapped = dynamic_context(fn, parent=parent) + return Flow.call(wrapped) + else: + # Arguments to decorator (@Flow.dynamic_call(...)) + def decorator(fn): + wrapped = dynamic_context(fn, parent=parent) + return Flow.call(**kwargs)(wrapped) + + return decorator + # ***************************************************************************** # Define "Evaluators" and associated types @@ -754,3 +824,68 @@ def _validate_callable_model_generic_type(cls, m, handler, info): CallableModelGenericType = CallableModelGeneric + + +# ***************************************************************************** +# Dynamic Context Decorator +# ***************************************************************************** + + +def dynamic_context(func: Callable = None, *, parent: Type[ContextBase] = None) -> Callable: + """Decorator that creates a dynamic context class from function parameters. + + This decorator extracts the parameters from a function signature and creates + a dynamic ContextBase subclass whose fields correspond to those parameters. + The decorated function is then wrapped to accept the context object and + unpack it into keyword arguments. 
+ + Example: + class MyCallable(CallableModel): + @Flow.dynamic_call # or @Flow.call @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = MyCallable() + model(x=42, y="hello") # Works with kwargs + """ + if func is None: + return partial(dynamic_context, parent=parent) + + sig = signature(func) + base_class = parent or ContextBase + + # Validate parent fields are in function signature + if parent is not None: + parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) + sig_params = set(sig.parameters.keys()) - {"self"} + missing = parent_fields - sig_params + if missing: + raise TypeError(f"Parent context fields {missing} must be included in function signature") + + # Build fields from parameters (skip 'self'), pydantic validates types + fields = {} + for name, param in sig.parameters.items(): + if name == "self": + continue + default = ... if param.default is inspect.Parameter.empty else param.default + fields[name] = (param.annotation, default) + + # Create dynamic context class + dyn_context = create_ccflow_model(f"{func.__qualname__}_DynamicContext", __base__=base_class, **fields) + + @wraps(func) + def wrapper(self, context): + fn_kwargs = {name: getattr(context, name) for name in fields} + return func(self, **fn_kwargs) + + # Must set __signature__ so CallableModel validation sees 'context' parameter + wrapper.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dyn_context), + ], + return_annotation=sig.return_annotation, + ) + wrapper.__dynamic_context__ = dyn_context + wrapper.__result_type__ = sig.return_annotation + return wrapper diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 43f86b5..444d496 100644 --- a/ccflow/tests/test_callable.py +++ 
b/ccflow/tests/test_callable.py @@ -20,6 +20,7 @@ ResultBase, ResultType, WrapperModel, + dynamic_context, ) from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME @@ -783,3 +784,383 @@ class MyCallableParent_bad_decorator(MyCallableParent): @Flow.deps def foo(self, context): return [] + + +# ============================================================================= +# Tests for dynamic_context decorator +# ============================================================================= + + +class TestDynamicContext(TestCase): + """Tests for the @dynamic_context decorator.""" + + def test_basic_usage_with_kwargs(self): + """Test basic dynamic_context usage with keyword arguments.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + # Call with kwargs + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + # Call with default + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_dynamic_context_attribute(self): + """Test that __dynamic_context__ attribute is set.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, a: int, b: str) -> GenericResult: + return GenericResult(value=f"{a}-{b}") + + # The __call__ method should have __dynamic_context__ + call_method = DynamicCallable.__call__ + self.assertTrue(hasattr(call_method, "__wrapped__")) + # Access the inner function's __dynamic_context__ + inner = call_method.__wrapped__ + self.assertTrue(hasattr(inner, "__dynamic_context__")) + + dyn_ctx = inner.__dynamic_context__ + self.assertTrue(issubclass(dyn_ctx, ContextBase)) + self.assertIn("a", dyn_ctx.model_fields) + self.assertIn("b", dyn_ctx.model_fields) + + def test_dynamic_context_is_registered(self): + """Test that the dynamic context is registered for serialization.""" + + class 
DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + inner = DynamicCallable.__call__.__wrapped__ + dyn_ctx = inner.__dynamic_context__ + + # Should have __ccflow_import_path__ set + self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) + self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + def test_call_with_context_object(self): + """Test calling with a context object instead of kwargs.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + # Get the dynamic context class + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + + # Create a context object + ctx = dyn_ctx(x=99, y="context") + result = model(ctx) + self.assertEqual(result.value, "99-context") + + def test_with_parent_context(self): + """Test dynamic_context with parent context class.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context(parent=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + # Get dynamic context + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + + # Should inherit from ParentContext + self.assertTrue(issubclass(dyn_ctx, ParentContext)) + + # Should have both fields + self.assertIn("base_value", dyn_ctx.model_fields) + self.assertIn("x", dyn_ctx.model_fields) + + # Create context with parent field + ctx = dyn_ctx(x=42, base_value="custom") + self.assertEqual(ctx.base_value, "custom") + self.assertEqual(ctx.x, 42) + + def test_parent_fields_must_be_in_signature(self): + """Test that parent fields must be included in function signature.""" + + class 
ParentContext(ContextBase): + required_field: str + + with self.assertRaises(TypeError) as cm: + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context(parent=ParentContext) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + self.assertIn("required_field", str(cm.exception)) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle roundtrip for dynamic context callable.""" + + class DynamicCallable(CallableModel): + multiplier: int = 2 + + @Flow.call + @dynamic_context + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = DynamicCallable(multiplier=3) + + # Test roundtrip + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task_execution(self): + """Test dynamic context callable in Ray task.""" + + class DynamicCallable(CallableModel): + factor: int = 2 + + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = DynamicCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 60) # (10 + 2) * 5 + + def test_multiple_dynamic_context_methods(self): + """Test callable with multiple dynamic_context decorated methods.""" + + class MultiMethodCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, a: int) -> GenericResult: + return GenericResult(value=a) + + @dynamic_context + def other_method(self, *, b: str, c: float = 1.0) -> GenericResult: + return GenericResult(value=f"{b}-{c}") + + model = MultiMethodCallable() + + # Test __call__ + result1 = model(a=42) + self.assertEqual(result1.value, 42) + + # Test other_method (without Flow.call, just the dynamic_context wrapper) + # Need to create the context 
manually + other_ctx = model.other_method.__dynamic_context__ + ctx = other_ctx(b="hello", c=2.5) + result2 = model.other_method(ctx) + self.assertEqual(result2.value, "hello-2.5") + + def test_context_type_property_works(self): + """Test that type_ property works on the dynamic context.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + ctx = dyn_ctx(x=42) + + # type_ should work and be importable + type_path = str(ctx.type_) + self.assertIn("_Local_", type_path) + self.assertEqual(ctx.type_.object, dyn_ctx) + + def test_complex_field_types(self): + """Test dynamic_context with complex field types.""" + from typing import List, Optional + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__( + self, + *, + items: List[int], + name: Optional[str] = None, + count: int = 0, + ) -> GenericResult: + total = sum(items) + count + return GenericResult(value=f"{name}:{total}" if name else str(total)) + + model = DynamicCallable() + + result = model(items=[1, 2, 3], name="test", count=10) + self.assertEqual(result.value, "test:16") + + result = model(items=[5, 5]) + self.assertEqual(result.value, "10") + + +class TestFlowDynamicCall(TestCase): + """Tests for @Flow.dynamic_call decorator.""" + + def test_basic_usage(self): + """Test basic @Flow.dynamic_call usage.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_dynamic_context_attributes_preserved(self): + """Test that __dynamic_context__ and __result_type__ are directly accessible.""" + + class 
DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + # Should be directly accessible without traversing __wrapped__ chain + method = DynamicCallable.__call__ + self.assertTrue(hasattr(method, "__dynamic_context__")) + self.assertTrue(hasattr(method, "__result_type__")) + self.assertTrue(issubclass(method.__dynamic_context__, ContextBase)) + self.assertEqual(method.__result_type__, GenericResult) + + def test_model_result_type_property(self): + """Test that model.result_type returns correct type for dynamic contexts.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = DynamicCallable() + self.assertEqual(model.result_type, GenericResult) + + def test_with_parent_context(self): + """Test @Flow.dynamic_call with parent context.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call(parent=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + model = DynamicCallable() + + # Get dynamic context by traversing __wrapped__ chain + dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + + # Should inherit from ParentContext + self.assertTrue(issubclass(dyn_ctx, ParentContext)) + + # Call should work, uses parent default + result = model(x=42, base_value="custom") + self.assertEqual(result.value, "42-custom") + + def test_with_flow_options(self): + """Test @Flow.dynamic_call with FlowOptions parameters.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call(validate_result=False) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = DynamicCallable() + result = model(x=42) + self.assertEqual(result.value, 42) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle 
roundtrip with @Flow.dynamic_call.""" + + class DynamicCallable(CallableModel): + multiplier: int = 2 + + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = DynamicCallable(multiplier=3) + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task(self): + """Test @Flow.dynamic_call in Ray task.""" + + class DynamicCallable(CallableModel): + factor: int = 2 + + @Flow.dynamic_call + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = DynamicCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 60) + + def test_dynamic_context_is_registered(self): + """Test that the dynamic context from @Flow.dynamic_call is registered.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + # Find dynamic context by traversing __wrapped__ chain + dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + + self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) + self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + +def _find_dynamic_context(func): + """Helper to find __dynamic_context__ by traversing the __wrapped__ chain.""" + visited = set() + current = func + while current is not None and id(current) not in visited: + visited.add(id(current)) + if hasattr(current, "__dynamic_context__"): + return current.__dynamic_context__ + current = getattr(current, "__wrapped__", None) + return None diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index cc34155..34f3f7e 100644 --- a/ccflow/tests/test_evaluator.py +++ 
b/ccflow/tests/test_evaluator.py @@ -1,9 +1,21 @@ from datetime import date from unittest import TestCase -from ccflow import DateContext, Evaluator, ModelEvaluationContext +import pytest -from .evaluators.util import MyDateCallable +from ccflow import CallableModel, DateContext, Evaluator, Flow, ModelEvaluationContext + +from .evaluators.util import MyDateCallable, MyResult + + +class MyDynamicDateCallable(CallableModel): + """Dynamic context version of MyDateCallable for testing evaluators.""" + + offset: int + + @Flow.dynamic_call(parent=DateContext) + def __call__(self, *, date: date) -> MyResult: + return MyResult(x=date.day + self.offset) class TestEvaluator(TestCase): @@ -32,3 +44,57 @@ def test_evaluator_deps(self): evaluator = Evaluator() out2 = evaluator.__deps__(model_evaluation_context) self.assertEqual(out2, out) + + +@pytest.mark.parametrize( + "callable_class", + [MyDateCallable, MyDynamicDateCallable], + ids=["standard", "dynamic"], +) +class TestEvaluatorParametrized: + """Test evaluators work with both standard and dynamic context callables.""" + + def test_evaluator_with_context_object(self, callable_class): + """Test evaluator with a context object.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + + out = model_evaluation_context() + assert out == MyResult(x=2) # day 1 + offset 1 + + evaluator = Evaluator() + out2 = evaluator(model_evaluation_context) + assert out2 == out + + def test_evaluator_with_fn_specified(self, callable_class): + """Test evaluator with fn='__call__' explicitly specified.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__call__") + + out = model_evaluation_context() + assert out == MyResult(x=2) + + def test_evaluator_direct_call_matches(self, callable_class): + """Test that evaluator result 
matches direct call.""" + m1 = callable_class(offset=5) + context = DateContext(date=date(2022, 1, 15)) + + # Direct call + direct_result = m1(context) + + # Via evaluator + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + evaluator_result = model_evaluation_context() + + assert direct_result == evaluator_result + assert direct_result == MyResult(x=20) # day 15 + offset 5 + + def test_evaluator_with_kwargs(self, callable_class): + """Test that evaluator works when callable is called with kwargs.""" + m1 = callable_class(offset=1) + + # Call with kwargs + result = m1(date=date(2022, 1, 10)) + assert result == MyResult(x=11) # day 10 + offset 1 From 95119e335d80a3edb4c2e255eef21adb4331be7a Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 31 Dec 2025 23:19:39 -0500 Subject: [PATCH 02/17] Remove dynamic context, add option to Flow.call Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 186 +++++++-------- ccflow/tests/test_callable.py | 318 +++++++------------------ ccflow/tests/test_evaluator.py | 12 +- ccflow/tests/test_local_persistence.py | 2 +- 4 files changed, 180 insertions(+), 338 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 9b971c7..748759c 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,7 +14,7 @@ import abc import inspect import logging -from functools import lru_cache, partial, wraps +from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -46,7 +46,6 @@ "EvaluatorBase", "Evaluator", "WrapperModel", - "dynamic_context", ) log = logging.getLogger(__name__) @@ -272,22 +271,22 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - # 
Check if this is a dynamic_context decorated method - has_dynamic_context = hasattr(fn, "__dynamic_context__") - if has_dynamic_context: - method_context_type = fn.__dynamic_context__ + # Check if this is an auto_context decorated method + has_auto_context = hasattr(fn, "__auto_context__") + if has_auto_context: + method_context_type = fn.__auto_context__ else: method_context_type = model.context_type - # Validate context type (skip for dynamic contexts which are always valid ContextBase subclasses) - if not has_dynamic_context: + # Validate context type (skip for auto contexts which are always valid ContextBase subclasses) + if not has_auto_context: if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) ): raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - # Validate result type - use __result_type__ for dynamic contexts if available - if has_dynamic_context and hasattr(fn, "__result_type__"): + # Validate result type - use __result_type__ for auto contexts if available + if has_auto_context and hasattr(fn, "__result_type__"): method_result_type = fn.__result_type__ else: method_result_type = model.result_type @@ -334,9 +333,9 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = wrap.get_options = self.get_options wrap.get_evaluation_context = get_evaluation_context - # Preserve dynamic context attributes for introspection - if hasattr(fn, "__dynamic_context__"): - wrap.__dynamic_context__ = fn.__dynamic_context__ + # Preserve auto context attributes for introspection + if hasattr(fn, "__auto_context__"): + wrap.__auto_context__ = fn.__auto_context__ if hasattr(fn, "__result_type__"): wrap.__result_type__ = fn.__result_type__ @@ -418,7 +417,58 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): - 
"""Decorator for methods on callable models""" + """Decorator for methods on callable models. + + Args: + auto_context: Controls automatic context class generation from the function + signature. Accepts three types of values: + - False (default): No auto-generation, use traditional context parameter + - True: Auto-generate context class with no parent + - ContextBase subclass: Auto-generate context class inheriting from this parent + **kwargs: Additional FlowOptions parameters (log_level, verbose, validate_result, + cacheable, evaluator, volatile). + + Basic Example: + class MyModel(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> MyResult: + return MyResult(value=context.x) + + Auto Context Example: + class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> MyResult: + return MyResult(value=f"{x}-{y}") + + model = MyModel() + model(x=42) # Call with kwargs directly + + With Parent Context: + class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> MyResult: + return MyResult(value=date.day + extra) + + # The generated context inherits from DateContext, so it's compatible + # with infrastructure expecting DateContext instances. 
+ """ + # Extract auto_context option (not part of FlowOptions) + # Can be: False, True, or a ContextBase subclass + auto_context = kwargs.pop("auto_context", False) + + # Determine if auto_context is enabled and extract parent class if provided + if auto_context is False: + auto_context_enabled = False + context_parent = None + elif auto_context is True: + auto_context_enabled = True + context_parent = None + elif isclass(auto_context) and issubclass(auto_context, ContextBase): + auto_context_enabled = True + context_parent = auto_context + else: + raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}") + if len(args) == 1 and callable(args[0]): # No arguments to decorator, this is the decorator fn = args[0] @@ -427,6 +477,14 @@ def call(*args, **kwargs): else: # Arguments to decorator, this is just returning the decorator # Note that the code below is executed only once + if auto_context_enabled: + # Return a decorator that first applies auto_context, then FlowOptions + def auto_context_decorator(fn): + wrapped = _apply_auto_context(fn, parent=context_parent) + # FlowOptions.__call__ already applies wraps, so we just return its result + return FlowOptions(**kwargs)(wrapped) + + return auto_context_decorator return FlowOptions(**kwargs) @staticmethod @@ -444,81 +502,6 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) - @staticmethod - def dynamic_call(*args, **kwargs): - """Decorator that combines @Flow.call with dynamic context creation. - - Instead of defining a separate context class, this decorator creates one - automatically from the function signature. The method can then be called - with keyword arguments directly. 
- - Basic Example: - class MyModel(CallableModel): - @Flow.dynamic_call - def __call__(self, *, date: date, region: str = "US") -> MyResult: - return MyResult(value=f"{date}-{region}") - - model = MyModel() - model(date=date.today()) # Uses default region="US" - model(date=date.today(), region="EU") # Override default - - With Parent Context: - class MyModel(CallableModel): - @Flow.dynamic_call(parent=DateContext) - def __call__(self, *, date: date, extra: int = 0) -> MyResult: - return MyResult(value=date.day + extra) - - # Parent fields (date) must be included in the function signature. - # This is useful for integrating with existing infrastructure that - # expects specific context types. - - Args: - *args: The decorated function when used without parentheses - **kwargs: Combined options for FlowOptions and dynamic_context: - - Dynamic context options: - parent: Parent context class to inherit from. All parent fields - must appear in the function signature. - - FlowOptions (passed through to @Flow.call): - log_level: Logging level for evaluation (default: DEBUG) - verbose: Use verbose logging (default: True) - validate_result: Validate return against result_type (default: True) - cacheable: Allow result caching (default: False) - evaluator: Custom evaluator instance - - Returns: - A decorated method that accepts keyword arguments matching the signature. 
- - Notes: - - All parameters (except 'self') must have type annotations - - Use keyword-only parameters (after *) for cleaner signatures - - The generated context class is accessible via method.__dynamic_context__ - - The return type is accessible via method.__result_type__ - - See Also: - dynamic_context: The underlying decorator for context creation - Flow.call: The underlying decorator for flow evaluation - """ - # Import here to avoid circular import at module level - from ccflow.callable import dynamic_context - - # Extract dynamic_context-specific options - parent = kwargs.pop("parent", None) - - if len(args) == 1 and callable(args[0]): - # No arguments to decorator (@Flow.dynamic_call) - fn = args[0] - wrapped = dynamic_context(fn, parent=parent) - return Flow.call(wrapped) - else: - # Arguments to decorator (@Flow.dynamic_call(...)) - def decorator(fn): - wrapped = dynamic_context(fn, parent=parent) - return Flow.call(**kwargs)(wrapped) - - return decorator - # ***************************************************************************** # Define "Evaluators" and associated types @@ -859,30 +842,29 @@ def _validate_callable_model_generic_type(cls, m, handler, info): # ***************************************************************************** -# Dynamic Context Decorator +# Auto Context (internal helper for Flow.call(auto_context=True)) # ***************************************************************************** -def dynamic_context(func: Callable = None, *, parent: Type[ContextBase] = None) -> Callable: - """Decorator that creates a dynamic context class from function parameters. +def _apply_auto_context(func: Callable, *, parent: Type[ContextBase] = None) -> Callable: + """Internal function that creates an auto context class from function parameters. - This decorator extracts the parameters from a function signature and creates - a dynamic ContextBase subclass whose fields correspond to those parameters. 
+ This function extracts the parameters from a function signature and creates + a ContextBase subclass whose fields correspond to those parameters. The decorated function is then wrapped to accept the context object and unpack it into keyword arguments. + Used internally by Flow.call(auto_context=...). + Example: class MyCallable(CallableModel): - @Flow.dynamic_call # or @Flow.call @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") model = MyCallable() model(x=42, y="hello") # Works with kwargs """ - if func is None: - return partial(dynamic_context, parent=parent) - sig = signature(func) base_class = parent or ContextBase @@ -902,8 +884,8 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: default = ... if param.default is inspect.Parameter.empty else param.default fields[name] = (param.annotation, default) - # Create dynamic context class - dyn_context = create_ccflow_model(f"{func.__qualname__}_DynamicContext", __base__=base_class, **fields) + # Create auto context class + auto_context_class = create_ccflow_model(f"{func.__qualname__}_AutoContext", __base__=base_class, **fields) @wraps(func) def wrapper(self, context): @@ -914,10 +896,10 @@ def wrapper(self, context): wrapper.__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dyn_context), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class), ], return_annotation=sig.return_annotation, ) - wrapper.__dynamic_context__ = dyn_context + wrapper.__auto_context__ = auto_context_class wrapper.__result_type__ = sig.return_annotation return wrapper diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 444d496..a748765 100644 --- a/ccflow/tests/test_callable.py +++ 
b/ccflow/tests/test_callable.py @@ -20,7 +20,6 @@ ResultBase, ResultType, WrapperModel, - dynamic_context, ) from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME @@ -787,23 +786,22 @@ def foo(self, context): # ============================================================================= -# Tests for dynamic_context decorator +# Tests for Flow.call(auto_context=True) # ============================================================================= -class TestDynamicContext(TestCase): - """Tests for the @dynamic_context decorator.""" +class TestAutoContext(TestCase): + """Tests for @Flow.call(auto_context=True).""" def test_basic_usage_with_kwargs(self): - """Test basic dynamic_context usage with keyword arguments.""" + """Test basic auto_context usage with keyword arguments.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") - model = DynamicCallable() + model = AutoContextCallable() # Call with kwargs result = model(x=42, y="hello") @@ -813,117 +811,111 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: result = model(x=10) self.assertEqual(result.value, "10-default") - def test_dynamic_context_attribute(self): - """Test that __dynamic_context__ attribute is set.""" + def test_auto_context_attribute(self): + """Test that __auto_context__ attribute is set.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, a: int, b: str) -> GenericResult: return GenericResult(value=f"{a}-{b}") - # The __call__ method should have __dynamic_context__ - call_method = DynamicCallable.__call__ + # The __call__ method should have __auto_context__ + call_method = AutoContextCallable.__call__ self.assertTrue(hasattr(call_method, 
"__wrapped__")) - # Access the inner function's __dynamic_context__ + # Access the inner function's __auto_context__ inner = call_method.__wrapped__ - self.assertTrue(hasattr(inner, "__dynamic_context__")) + self.assertTrue(hasattr(inner, "__auto_context__")) - dyn_ctx = inner.__dynamic_context__ - self.assertTrue(issubclass(dyn_ctx, ContextBase)) - self.assertIn("a", dyn_ctx.model_fields) - self.assertIn("b", dyn_ctx.model_fields) + auto_ctx = inner.__auto_context__ + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertIn("a", auto_ctx.model_fields) + self.assertIn("b", auto_ctx.model_fields) - def test_dynamic_context_is_registered(self): - """Test that the dynamic context is registered for serialization.""" + def test_auto_context_is_registered(self): + """Test that the auto context is registered for serialization.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, value: int) -> GenericResult: return GenericResult(value=value) - inner = DynamicCallable.__call__.__wrapped__ - dyn_ctx = inner.__dynamic_context__ + inner = AutoContextCallable.__call__.__wrapped__ + auto_ctx = inner.__auto_context__ # Should have __ccflow_import_path__ set - self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) - self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + self.assertTrue(hasattr(auto_ctx, "__ccflow_import_path__")) + self.assertTrue(auto_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) def test_call_with_context_object(self): """Test calling with a context object instead of kwargs.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") - model = DynamicCallable() + model = 
AutoContextCallable() - # Get the dynamic context class - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + # Get the auto context class + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ # Create a context object - ctx = dyn_ctx(x=99, y="context") + ctx = auto_ctx(x=99, y="context") result = model(ctx) self.assertEqual(result.value, "99-context") def test_with_parent_context(self): - """Test dynamic_context with parent context class.""" + """Test auto_context with a parent context class.""" class ParentContext(ContextBase): base_value: str = "base" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context(parent=ParentContext) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) def __call__(self, *, x: int, base_value: str) -> GenericResult: return GenericResult(value=f"{x}-{base_value}") - # Get dynamic context - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + # Get auto context + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ # Should inherit from ParentContext - self.assertTrue(issubclass(dyn_ctx, ParentContext)) + self.assertTrue(issubclass(auto_ctx, ParentContext)) # Should have both fields - self.assertIn("base_value", dyn_ctx.model_fields) - self.assertIn("x", dyn_ctx.model_fields) + self.assertIn("base_value", auto_ctx.model_fields) + self.assertIn("x", auto_ctx.model_fields) # Create context with parent field - ctx = dyn_ctx(x=42, base_value="custom") + ctx = auto_ctx(x=42, base_value="custom") self.assertEqual(ctx.base_value, "custom") self.assertEqual(ctx.x, 42) def test_parent_fields_must_be_in_signature(self): - """Test that parent fields must be included in function signature.""" + """Test that parent context fields must be included in function signature.""" class ParentContext(ContextBase): required_field: str with self.assertRaises(TypeError) as cm: - class DynamicCallable(CallableModel): - @Flow.call - 
@dynamic_context(parent=ParentContext) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) self.assertIn("required_field", str(cm.exception)) def test_cloudpickle_roundtrip(self): - """Test cloudpickle roundtrip for dynamic context callable.""" + """Test cloudpickle roundtrip for auto_context callable.""" - class DynamicCallable(CallableModel): + class AutoContextCallable(CallableModel): multiplier: int = 2 - @Flow.call - @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x * self.multiplier) - model = DynamicCallable(multiplier=3) + model = AutoContextCallable(multiplier=3) # Test roundtrip restored = rcploads(rcpdumps(model)) @@ -932,13 +924,12 @@ def __call__(self, *, x: int) -> GenericResult: self.assertEqual(result.value, 30) def test_ray_task_execution(self): - """Test dynamic context callable in Ray task.""" + """Test auto_context callable in Ray task.""" - class DynamicCallable(CallableModel): + class AutoContextCallable(CallableModel): factor: int = 2 - @Flow.call - @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: int = 1) -> GenericResult: return GenericResult(value=(x + y) * self.factor) @@ -946,63 +937,35 @@ def __call__(self, *, x: int, y: int = 1) -> GenericResult: def run_callable(model, **kwargs): return model(**kwargs).value - model = DynamicCallable(factor=5) + model = AutoContextCallable(factor=5) with ray.init(num_cpus=1): result = ray.get(run_callable.remote(model, x=10, y=2)) self.assertEqual(result, 60) # (10 + 2) * 5 - def test_multiple_dynamic_context_methods(self): - """Test callable with multiple dynamic_context decorated methods.""" - - class MultiMethodCallable(CallableModel): - @Flow.call - @dynamic_context - def __call__(self, *, a: int) -> GenericResult: - return GenericResult(value=a) - - @dynamic_context - def 
other_method(self, *, b: str, c: float = 1.0) -> GenericResult: - return GenericResult(value=f"{b}-{c}") - - model = MultiMethodCallable() - - # Test __call__ - result1 = model(a=42) - self.assertEqual(result1.value, 42) - - # Test other_method (without Flow.call, just the dynamic_context wrapper) - # Need to create the context manually - other_ctx = model.other_method.__dynamic_context__ - ctx = other_ctx(b="hello", c=2.5) - result2 = model.other_method(ctx) - self.assertEqual(result2.value, "hello-2.5") - def test_context_type_property_works(self): - """Test that type_ property works on the dynamic context.""" + """Test that type_ property works on the auto context.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ - ctx = dyn_ctx(x=42) + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + ctx = auto_ctx(x=42) # type_ should work and be importable type_path = str(ctx.type_) self.assertIn("_Local_", type_path) - self.assertEqual(ctx.type_.object, dyn_ctx) + self.assertEqual(ctx.type_.object, auto_ctx) def test_complex_field_types(self): - """Test dynamic_context with complex field types.""" + """Test auto_context with complex field types.""" from typing import List, Optional - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__( self, *, @@ -1013,7 +976,7 @@ def __call__( total = sum(items) + count return GenericResult(value=f"{name}:{total}" if name else str(total)) - model = DynamicCallable() + model = AutoContextCallable() result = model(items=[1, 2, 3], name="test", count=10) self.assertEqual(result.value, "test:16") @@ -1021,146 +984,43 @@ def __call__( result = model(items=[5, 5]) 
self.assertEqual(result.value, "10") - -class TestFlowDynamicCall(TestCase): - """Tests for @Flow.dynamic_call decorator.""" - - def test_basic_usage(self): - """Test basic @Flow.dynamic_call usage.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int, y: str = "default") -> GenericResult: - return GenericResult(value=f"{x}-{y}") - - model = DynamicCallable() - - result = model(x=42, y="hello") - self.assertEqual(result.value, "42-hello") - - result = model(x=10) - self.assertEqual(result.value, "10-default") - - def test_dynamic_context_attributes_preserved(self): - """Test that __dynamic_context__ and __result_type__ are directly accessible.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x) - - # Should be directly accessible without traversing __wrapped__ chain - method = DynamicCallable.__call__ - self.assertTrue(hasattr(method, "__dynamic_context__")) - self.assertTrue(hasattr(method, "__result_type__")) - self.assertTrue(issubclass(method.__dynamic_context__, ContextBase)) - self.assertEqual(method.__result_type__, GenericResult) - - def test_model_result_type_property(self): - """Test that model.result_type returns correct type for dynamic contexts.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x) - - model = DynamicCallable() - self.assertEqual(model.result_type, GenericResult) - - def test_with_parent_context(self): - """Test @Flow.dynamic_call with parent context.""" - - class ParentContext(ContextBase): - base_value: str = "base" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call(parent=ParentContext) - def __call__(self, *, x: int, base_value: str) -> GenericResult: - return GenericResult(value=f"{x}-{base_value}") - - model = DynamicCallable() - - # Get dynamic context by traversing __wrapped__ chain - 
dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) - - # Should inherit from ParentContext - self.assertTrue(issubclass(dyn_ctx, ParentContext)) - - # Call should work, uses parent default - result = model(x=42, base_value="custom") - self.assertEqual(result.value, "42-custom") - def test_with_flow_options(self): - """Test @Flow.dynamic_call with FlowOptions parameters.""" + """Test auto_context with FlowOptions parameters.""" - class DynamicCallable(CallableModel): - @Flow.dynamic_call(validate_result=False) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True, validate_result=False) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) - model = DynamicCallable() + model = AutoContextCallable() result = model(x=42) self.assertEqual(result.value, 42) - def test_cloudpickle_roundtrip(self): - """Test cloudpickle roundtrip with @Flow.dynamic_call.""" - - class DynamicCallable(CallableModel): - multiplier: int = 2 - - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x * self.multiplier) - - model = DynamicCallable(multiplier=3) - restored = rcploads(rcpdumps(model)) - - result = restored(x=10) - self.assertEqual(result.value, 30) - - def test_ray_task(self): - """Test @Flow.dynamic_call in Ray task.""" - - class DynamicCallable(CallableModel): - factor: int = 2 - - @Flow.dynamic_call - def __call__(self, *, x: int, y: int = 1) -> GenericResult: - return GenericResult(value=(x + y) * self.factor) - - @ray.remote - def run_callable(model, **kwargs): - return model(**kwargs).value - - model = DynamicCallable(factor=5) + def test_error_without_auto_context(self): + """Test that using kwargs signature without auto_context raises an error.""" - with ray.init(num_cpus=1): - result = ray.get(run_callable.remote(model, x=10, y=2)) - - self.assertEqual(result, 60) - - def test_dynamic_context_is_registered(self): - """Test that the dynamic context from @Flow.dynamic_call is 
registered.""" + class BadCallable(CallableModel): + @Flow.call # Missing auto_context=True! + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, value: int) -> GenericResult: - return GenericResult(value=value) + # Error happens at instantiation time when _check_signature validates + with self.assertRaises(ValueError) as cm: + BadCallable() - # Find dynamic context by traversing __wrapped__ chain - dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + # Should fail because __call__ must take a single argument named 'context' + error_msg = str(cm.exception) + self.assertIn("__call__", error_msg) + self.assertIn("context", error_msg) - self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) - self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + def test_invalid_auto_context_value(self): + """Test that invalid auto_context values raise TypeError with helpful message.""" + with self.assertRaises(TypeError) as cm: + @Flow.call(auto_context="invalid") + def bad_func(self, *, x: int) -> GenericResult: + return GenericResult(value=x) -def _find_dynamic_context(func): - """Helper to find __dynamic_context__ by traversing the __wrapped__ chain.""" - visited = set() - current = func - while current is not None and id(current) not in visited: - visited.add(id(current)) - if hasattr(current, "__dynamic_context__"): - return current.__dynamic_context__ - current = getattr(current, "__wrapped__", None) - return None + error_msg = str(cm.exception) + self.assertIn("auto_context must be False, True, or a ContextBase subclass", error_msg) + self.assertIn("invalid", error_msg) diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index 34f3f7e..dabf815 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -8,12 +8,12 @@ from .evaluators.util 
import MyDateCallable, MyResult -class MyDynamicDateCallable(CallableModel): - """Dynamic context version of MyDateCallable for testing evaluators.""" +class MyAutoContextDateCallable(CallableModel): + """Auto context version of MyDateCallable for testing evaluators.""" offset: int - @Flow.dynamic_call(parent=DateContext) + @Flow.call(auto_context=DateContext) def __call__(self, *, date: date) -> MyResult: return MyResult(x=date.day + self.offset) @@ -48,11 +48,11 @@ def test_evaluator_deps(self): @pytest.mark.parametrize( "callable_class", - [MyDateCallable, MyDynamicDateCallable], - ids=["standard", "dynamic"], + [MyDateCallable, MyAutoContextDateCallable], + ids=["standard", "auto_context"], ) class TestEvaluatorParametrized: - """Test evaluators work with both standard and dynamic context callables.""" + """Test evaluators work with both standard and auto_context callables.""" def test_evaluator_with_context_object(self, callable_class): """Test evaluator with a context object.""" diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index 586b03f..dc2db55 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -1235,7 +1235,7 @@ class TestCreateCcflowModelCloudpickleCrossProcess: id="context_only", ), pytest.param( - # Dynamic context with CallableModel + # Runtime-created context with CallableModel """ from ray.cloudpickle import dump from ccflow import CallableModel, ContextBase, GenericResult, Flow From 989e278eccc1a1dec446eb26fb53febad1d3449e Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Sun, 4 Jan 2026 19:22:25 -0500 Subject: [PATCH 03/17] Add @Flow.model decorator, new annotation that pulls from deps Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 1 + ccflow/callable.py | 216 +++- ccflow/dep.py | 278 +++++ ccflow/flow_model.py | 341 ++++++ ccflow/tests/config/conf_flow.yaml | 80 ++ ccflow/tests/test_callable.py | 1 + ccflow/tests/test_flow_model.py | 
1477 +++++++++++++++++++++++++ ccflow/tests/test_flow_model_hydra.py | 437 ++++++++ docs/wiki/Key-Features.md | 115 ++ examples/flow_model_example.py | 219 ++++ 10 files changed, 3163 insertions(+), 2 deletions(-) create mode 100644 ccflow/dep.py create mode 100644 ccflow/flow_model.py create mode 100644 ccflow/tests/config/conf_flow.yaml create mode 100644 ccflow/tests/test_flow_model.py create mode 100644 ccflow/tests/test_flow_model_hydra.py create mode 100644 examples/flow_model_example.py diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 163f275..9916168 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -10,6 +10,7 @@ from .compose import * from .callable import * from .context import * +from .dep import * from .enums import Enum from .global_state import * from .local_persistence import * diff --git a/ccflow/callable.py b/ccflow/callable.py index 748759c..5296bfe 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -28,6 +28,7 @@ ResultBase, ResultType, ) +from .dep import Dep, extract_dep from .local_persistence import create_ccflow_model from .validators import str_to_log_level @@ -128,7 +129,7 @@ def _check_result_type(cls, result_type): @model_validator(mode="after") def _check_signature(self): sig_call = _cached_signature(self.__class__.__call__) - if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context") + if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: raise ValueError("__call__ method must take a single argument, named 'context'") sig_deps = _cached_signature(self.__class__.__deps__) @@ -195,6 +196,114 @@ def _get_logging_evaluator(log_level): return LoggingEvaluator(log_level=log_level) +def _get_dep_fields(model_class) -> Dict[str, Dep]: + """Analyze class fields to find Dep-annotated fields. + + Returns a dict mapping field name to Dep instance for fields that need resolution. 
+ """ + dep_fields = {} + + # Get type hints from the class + hints = {} + for cls in model_class.__mro__: + if hasattr(cls, "__annotations__"): + for name, annotation in cls.__annotations__.items(): + if name not in hints: # Don't override child class annotations + hints[name] = annotation + + for name, annotation in hints.items(): + base_type, dep = extract_dep(annotation) + if dep is not None: + dep_fields[name] = dep + + return dep_fields + + +def _wrap_with_dep_resolution(fn): + """Wrap a function to auto-resolve DepOf fields before calling. + + For each Dep-annotated field on the model that contains a CallableModel, + resolves it using __deps__ and temporarily sets the resolved value on self. + + Note: This wrapper is only applied at runtime when the function is called, + not during decoration. This avoids issues with functools.wraps flattening + the __wrapped__ chain. + + Args: + fn: The original function + + Returns: + The original function unchanged - dep resolution happens at the call site + """ + # Don't modify the function - dep resolution is handled in ModelEvaluationContext + return fn + + +def _resolve_deps_and_call(model, context, fn): + """Resolve DepOf fields and call the function. + + This is called from ModelEvaluationContext.__call__ to handle dep resolution. 
+ + Args: + model: The CallableModel instance + context: The context to pass to the function + fn: The function to call + + Returns: + The result of calling fn(model, context) + """ + # Don't resolve deps for __deps__ method + if fn.__name__ == "__deps__": + return fn(model, context) + + # Get Dep-annotated fields for this model class + dep_fields = _get_dep_fields(model.__class__) + + if not dep_fields: + return fn(model, context) + + # Get dependencies from __deps__ + deps_result = model.__deps__(context) + # Build a map from model instance id to (model, contexts) for lookup + dep_map = {} + for dep_model, contexts in deps_result: + dep_map[id(dep_model)] = (dep_model, contexts) + + # Store original values and resolve + originals = {} + for field_name, dep in dep_fields.items(): + field_value = getattr(model, field_name, None) + if field_value is None: + continue + + # Check if field is a CallableModel that needs resolution + if not isinstance(field_value, _CallableModel): + continue # Already a resolved value, skip + + originals[field_name] = field_value + + # Check if this field is in __deps__ (for custom transforms) + if id(field_value) in dep_map: + dep_model, contexts = dep_map[id(field_value)] + # Call dependency with the (transformed) context + resolved = dep_model(contexts[0]) if contexts else dep_model(context) + else: + # Not in __deps__, use Dep annotation transform directly + transformed_ctx = dep.apply(context) + resolved = field_value(transformed_ctx) + + # Temporarily set resolved value on model + object.__setattr__(model, field_name, resolved) + + try: + # Call original function + return fn(model, context) + finally: + # Restore original CallableModel values + for field_name, original_value in originals.items(): + object.__setattr__(model, field_name, original_value) + + class FlowOptions(BaseModel): """Options for Flow evaluation. 
@@ -246,6 +355,9 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase": return self._get_evaluator_from_options(options) def __call__(self, fn): + # Wrap function with dependency resolution for DepOf fields + fn = _wrap_with_dep_resolution(fn) + # Used for building a graph of model evaluation contexts without evaluating def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None): # Create the evaluation context. @@ -451,6 +563,33 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: # The generated context inherits from DateContext, so it's compatible # with infrastructure expecting DateContext instances. + + Auto-Resolve Dependencies Example: + When __call__ has parameters beyond 'self' and 'context' that match field + names annotated with DepOf/Dep, those dependencies are automatically resolved + using __deps__ (if defined) or auto-generated from Dep annotations. + + class MyModel(CallableModel): + data: Annotated[GenericResult[dict], Dep(transform=my_transform)] + + @Flow.call + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: + # data is automatically resolved - no manual calling needed + return GenericResult(value=process(data.value)) + + For transforms that need access to instance fields, define __deps__ manually: + + class MyModel(CallableModel): + data: DepOf[..., GenericResult[dict]] + window: int = 7 + + def __deps__(self, context): + # Can access self.window here + return [(self.data, [context.with_lookback(self.window)])] + + @Flow.call + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: + return GenericResult(value=process(data.value)) """ # Extract auto_context option (not part of FlowOptions) # Can be: False, True, or a ContextBase subclass @@ -502,6 +641,78 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod 
+ def model(*args, **kwargs): + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. + + Args: + context_args: List of parameter names that come from context (for unpacked mode) + cacheable: Enable caching of results (default: False) + volatile: Mark as volatile (default: False) + log_level: Logging verbosity (default: logging.DEBUG) + validate_result: Validate return type (default: True) + verbose: Verbose logging output (default: True) + evaluator: Custom evaluator (default: None) + + Two Context Modes: + + Mode 1 - Explicit context parameter: + Function has a 'context' parameter annotated with a ContextBase subclass. + + @Flow.model + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, context.start_date, context.end_date)) + + Mode 2 - Unpacked context_args: + Context fields are unpacked into function parameters. 
+ + @Flow.model(context_args=["start_date", "end_date"]) + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, start_date, end_date)) + + Dependencies: + Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies: + + from ccflow import Dep, DepOf + from typing import Annotated + + @Flow.model + def compute_returns( + context: DateRangeContext, + prices: Annotated[GenericResult[pl.DataFrame], Dep( + transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + )] + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=prices.value.pct_change()) + + # Or use DepOf shorthand for no transform: + @Flow.model + def compute_stats( + context: DateRangeContext, + data: DepOf[..., GenericResult[pl.DataFrame]] + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=data.value.describe()) + + Usage: + # Create model instances + loader = load_prices(source="prod_db") + returns = compute_returns(prices=loader) + + # Execute + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = returns(ctx) + + Returns: + A factory function that creates CallableModel instances + """ + from .flow_model import flow_model + + return flow_model(*args, **kwargs) + # ***************************************************************************** # Define "Evaluators" and associated types @@ -555,7 +766,8 @@ def _context_validator(cls, values, handler, info): def __call__(self) -> ResultType: fn = getattr(self.model, self.fn) if hasattr(fn, "__wrapped__"): - result = fn.__wrapped__(self.model, self.context) + # Call through _resolve_deps_and_call to handle DepOf field resolution + result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__) # If it's a callable model, then we can validate the result if self.options.get("validate_result", True): if fn.__name__ == "__deps__": diff --git 
a/ccflow/dep.py b/ccflow/dep.py new file mode 100644 index 0000000..a7e0121 --- /dev/null +++ b/ccflow/dep.py @@ -0,0 +1,278 @@ +"""Dependency annotation markers for Flow.model. + +This module provides: +- Dep: Annotation marker for dependency parameters that can accept CallableModel +- DepOf: Shorthand for Annotated[Union[T, CallableModel], Dep()] +""" + +from typing import TYPE_CHECKING, Annotated, Callable, Optional, Type, TypeVar, Union, get_args, get_origin + +from .base import ContextBase + +if TYPE_CHECKING: + from .callable import CallableModel + +__all__ = ("Dep", "DepOf") + +T = TypeVar("T") + +# Lazy reference to CallableModel to avoid circular import +_CallableModel = None + + +def _get_callable_model(): + """Lazily import CallableModel to avoid circular imports.""" + global _CallableModel + if _CallableModel is None: + from .callable import CallableModel + + _CallableModel = CallableModel + return _CallableModel + + +class _DepOfMeta(type): + """Metaclass that makes DepOf[ContextType, ResultType] work.""" + + def __getitem__(cls, item): + if not isinstance(item, tuple) or len(item) != 2: + raise TypeError( + "DepOf requires 2 type arguments: DepOf[ContextType, ResultType]. " + "Use ... for ContextType to inherit from parent: DepOf[..., ResultType]" + ) + context_type, result_type = item + CallableModel = _get_callable_model() + + if context_type is ...: + # DepOf[..., ResultType] - inherit context from parent + return Annotated[Union[result_type, CallableModel], Dep()] + else: + # DepOf[ContextType, ResultType] - explicit context type + return Annotated[Union[result_type, CallableModel], Dep(context_type=context_type)] + + +class DepOf(metaclass=_DepOfMeta): + """ + Shorthand for Annotated[Union[ResultType, CallableModel], Dep(context_type=...)]. 
+ + Follows Callable convention: DepOf[InputContext, OutputResult] + + For class fields, accepts either: + - The result type directly (pre-computed value) + - A CallableModel that produces the result type (resolved at call time) + + Usage: + # Inherit context type from parent model (most common) + data: DepOf[..., GenericResult[dict]] + + # Explicit context type validation + data: DepOf[DateRangeContext, GenericResult[dict]] + + At call time, if the field contains a CallableModel, it will be automatically + resolved using __deps__ and the resolved value will be accessible via self.field_name. + + For dependencies with transforms, define them in __deps__: + def __deps__(self, context): + transformed_ctx = context.model_copy(update={...}) + return [(self.data, [transformed_ctx])] + """ + + pass + + +def _is_compatible_type(actual: Type, expected: Type) -> bool: + """Check if actual type is compatible with expected type. + + Handles generic types like GenericResult[pl.DataFrame] where issubclass + would raise TypeError. 
+ + Args: + actual: The actual type to check + expected: The expected type to match against + + Returns: + True if actual is compatible with expected + """ + # Handle None/empty types + if actual is None or expected is None: + return actual is expected + + # Get origins for generic types + actual_origin = get_origin(actual) or actual + expected_origin = get_origin(expected) or expected + + # Check if origins are compatible + try: + if not (isinstance(actual_origin, type) and isinstance(expected_origin, type)): + return False + if not issubclass(actual_origin, expected_origin): + return False + except TypeError: + # issubclass can fail for certain types + return False + + # Check generic args if present + actual_args = get_args(actual) + expected_args = get_args(expected) + + if expected_args and actual_args: + if len(actual_args) != len(expected_args): + return False + return all(_is_compatible_type(a, e) for a, e in zip(actual_args, expected_args)) + + return True + + +class Dep: + """ + Annotation marker for dependency parameters. + + Marks a parameter as accepting either the declared type or a CallableModel + that produces that type. Supports optional context transform and + construction-time type validation. 
+ + Usage: + # No transform, no explicit validation (uses parent's context_type) + prices: Annotated[GenericResult[pl.DataFrame], Dep()] + + # With transform + prices: Annotated[GenericResult[pl.DataFrame], Dep( + transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) + )] + + # With explicit context_type validation + prices: Annotated[GenericResult[pl.DataFrame], Dep( + context_type=DateRangeContext, + transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) + )] + + # Cross-context dependency (transform changes context type) + sim_data: Annotated[GenericResult[pl.DataFrame], Dep( + context_type=SimulationContext, + transform=date_to_simulation_context + )] + """ + + def __init__( + self, + transform: Optional[Callable[[ContextBase], ContextBase]] = None, + context_type: Optional[Type[ContextBase]] = None, + ): + """ + Args: + transform: Optional function to transform context before calling dependency. + Signature: (context) -> transformed_context + context_type: Expected context_type of the dependency CallableModel. + If None, defaults to the parent model's context_type. + Validated at construction time when a CallableModel is passed. + """ + self.transform = transform + self.context_type = context_type + + def apply(self, context: ContextBase) -> ContextBase: + """Apply the transform to a context, or return unchanged if no transform.""" + if self.transform is not None: + return self.transform(context) + return context + + def validate_dependency( + self, + value: "CallableModel", # noqa: F821 + expected_result_type: Type, + parent_context_type: Type[ContextBase], + param_name: str, + ) -> None: + """ + Validate a CallableModel dependency at construction time. 
+ + Args: + value: The CallableModel being passed as a dependency + expected_result_type: The result type from the Annotated type hint + parent_context_type: The context_type of the parent model + param_name: Name of the parameter (for error messages) + + Raises: + TypeError: If context_type or result_type don't match + """ + # Import here to avoid circular import + from .callable import CallableModel + + if not isinstance(value, CallableModel): + return # Not a CallableModel, skip validation + + # Determine expected context type + expected_ctx = self.context_type if self.context_type is not None else parent_context_type + + # Validate context_type - the dependency's context_type should be compatible + # with what we'll pass to it (expected_ctx) + dep_context_type = value.context_type + try: + if not issubclass(expected_ctx, dep_context_type): + raise TypeError( + f"Dependency '{param_name}': expected context_type compatible with " + f"{dep_context_type.__name__}, but will pass {expected_ctx.__name__}" + ) + except TypeError: + # issubclass can fail for certain types, try alternate check + if expected_ctx != dep_context_type: + raise TypeError(f"Dependency '{param_name}': context_type mismatch - expected {dep_context_type}, got {expected_ctx}") + + # Validate result_type using the generic-safe comparison + # If expected_result_type is Union[T, CallableModel], extract T for validation + dep_result_type = value.result_type + actual_expected_type = expected_result_type + + # Handle Union[T, CallableModel] from DepOf expansion + if get_origin(expected_result_type) is Union: + union_args = get_args(expected_result_type) + # Filter out CallableModel from the union + non_callable_types = [t for t in union_args if t is not CallableModel] + if non_callable_types: + actual_expected_type = non_callable_types[0] + + if not _is_compatible_type(dep_result_type, actual_expected_type): + raise TypeError( + f"Dependency '{param_name}': expected result_type compatible with " + 
f"{actual_expected_type}, but got CallableModel with result_type {dep_result_type}" + ) + + def __repr__(self): + parts = [] + if self.transform is not None: + parts.append(f"transform={self.transform}") + if self.context_type is not None: + parts.append(f"context_type={self.context_type.__name__}") + return f"Dep({', '.join(parts)})" if parts else "Dep()" + + def __eq__(self, other): + if not isinstance(other, Dep): + return False + return self.transform == other.transform and self.context_type == other.context_type + + def __hash__(self): + # Make Dep hashable for use in sets/dicts + return hash((id(self.transform), self.context_type)) + + +def extract_dep(annotation) -> tuple: + """Extract Dep from Annotated[T, Dep(...)] or DepOf[ContextType, T]. + + When multiple Dep annotations exist (e.g., from nested Annotated that flattens), + returns the LAST one, which represents the outermost user annotation. + + Args: + annotation: A type annotation, possibly Annotated with Dep + + Returns: + Tuple of (base_type, Dep instance or None) + """ + if get_origin(annotation) is Annotated: + args = get_args(annotation) + base_type = args[0] + # Find the LAST Dep - nested Annotated flattens, so outer annotation comes last + last_dep = None + for metadata in args[1:]: + if isinstance(metadata, Dep): + last_dep = metadata + if last_dep is not None: + return base_type, last_dep + return annotation, None diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py new file mode 100644 index 0000000..3a96886 --- /dev/null +++ b/ccflow/flow_model.py @@ -0,0 +1,341 @@ +"""Flow.model decorator implementation. + +This module provides the Flow.model decorator that generates CallableModel classes +from plain Python functions, reducing boilerplate while maintaining full compatibility +with existing ccflow infrastructure. 
+""" + +import inspect +import logging +from functools import wraps +from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_origin + +from pydantic import Field + +from .base import ContextBase, ResultBase +from .dep import Dep, extract_dep +from .local_persistence import register_ccflow_import_path + +__all__ = ("flow_model",) + +log = logging.getLogger(__name__) + + +def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: inspect.Signature) -> Type[ContextBase]: + """Infer or create a context type from context_args parameter names. + + This attempts to match existing context types or creates a new one. + + Args: + context_args: List of parameter names that come from context + func: The decorated function + sig: The function signature + + Returns: + A ContextBase subclass + """ + from .local_persistence import create_ccflow_model + + # Build field definitions for the context from parameter annotations + fields = {} + for name in context_args: + if name not in sig.parameters: + raise ValueError(f"context_arg '{name}' not found in function parameters") + param = sig.parameters[name] + if param.annotation is inspect.Parameter.empty: + raise ValueError(f"context_arg '{name}' must have a type annotation") + default = ... 
if param.default is inspect.Parameter.empty else param.default + fields[name] = (param.annotation, default) + + # Try to match common context types + from .context import DateRangeContext + + # Check for DateRangeContext pattern + if set(context_args) == {"start_date", "end_date"}: + from datetime import date + + if all( + sig.parameters[name].annotation in (date, "date") + or (isinstance(sig.parameters[name].annotation, type) and sig.parameters[name].annotation is date) + for name in context_args + ): + return DateRangeContext + + # Create a new context type dynamically + context_class = create_ccflow_model( + f"_{func.__name__}_Context", + __base__=ContextBase, + **fields, + ) + return context_class + + +def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: + """Extract dependency info from an annotation. + + Returns: + Tuple of (base_type, Dep instance or None) + """ + return extract_dep(annotation) + + +def flow_model( + func: Callable = None, + *, + # Context handling + context_args: Optional[List[str]] = None, + # Flow.call options (passed to generated __call__) + cacheable: bool = False, + volatile: bool = False, + log_level: int = logging.DEBUG, + validate_result: bool = True, + verbose: bool = True, + evaluator: Optional[Any] = None, +) -> Callable: + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. + + Args: + func: The function to decorate + context_args: List of parameter names that come from context (for unpacked mode) + cacheable: Enable caching of results + volatile: Mark as volatile (always re-execute) + log_level: Logging verbosity + validate_result: Validate return type + verbose: Verbose logging output + evaluator: Custom evaluator + + Two Context Modes: + 1. 
Explicit context parameter: Function has a 'context' parameter annotated + with a ContextBase subclass. + + @Flow.model + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: + ... + + 2. Unpacked context_args: Context fields are unpacked into function parameters. + + @Flow.model(context_args=["start_date", "end_date"]) + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: + ... + + Returns: + A factory function that creates CallableModel instances + """ + + def decorator(fn: Callable) -> Callable: + # Import here to avoid circular imports + from .callable import CallableModel, Flow, GraphDepList + + sig = inspect.signature(fn) + params = sig.parameters + + # Validate return type + return_type = sig.return_annotation + if return_type is inspect.Signature.empty: + raise TypeError(f"Function {fn.__name__} must have a return type annotation") + # Check that return type is a ResultBase subclass + return_origin = get_origin(return_type) or return_type + if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): + raise TypeError(f"Function {fn.__name__} must return a ResultBase subclass, got {return_type}") + + # Determine context mode and extract info + if context_args is not None: + # Mode 2: Unpacked context args + context_type = _infer_context_type_from_args(context_args, fn, sig) + model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} + use_context_args = True + elif "context" in params or "_" in params: + # Mode 1: Explicit context parameter (named 'context' or '_' for unused) + context_param_name = "context" if "context" in params else "_" + context_param = params[context_param_name] + if context_param.annotation is inspect.Parameter.empty: + raise TypeError(f"Function {fn.__name__}: '{context_param_name}' parameter must have a type annotation") + context_type = context_param.annotation + if not 
(isinstance(context_type, type) and issubclass(context_type, ContextBase)): + raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") + model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} + use_context_args = False + else: + raise TypeError(f"Function {fn.__name__} must either have a 'context' (or '_') parameter or specify context_args in the decorator") + + # Analyze parameters to find dependencies and regular fields + dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) + model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) + + for name, param in model_field_params.items(): + if param.annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation") + + base_type, dep = _get_dep_info(param.annotation) + default = ... if param.default is inspect.Parameter.empty else param.default + + if dep is not None: + # This is a dependency parameter + dep_fields[name] = (base_type, dep) + # Use Annotated so _resolve_deps_and_call in callable.py can find the Dep + # This consolidates resolution logic into one place + model_fields[name] = (Annotated[Union[base_type, CallableModel], dep], default) + else: + # Regular model field + model_fields[name] = (param.annotation, default) + + # Capture context_args in local variable for closures + ctx_args_list = context_args or [] + # Capture context parameter name for closures (only used in mode 1) + ctx_param_name = context_param_name if not use_context_args else "context" + + # Create the __call__ method + def make_call_impl(): + def __call__(self, context): + # Build kwargs for the original function + if use_context_args: + # Unpack context into args + fn_kwargs = {name: getattr(context, name) for name in ctx_args_list} + else: + # Pass context directly (using actual parameter name: 'context' or '_') + fn_kwargs = {ctx_param_name: 
context} + + # Add model fields (deps are resolved by _resolve_deps_and_call in callable.py) + for name in model_fields: + fn_kwargs[name] = getattr(self, name) + + return fn(**fn_kwargs) + + # Set proper signature for CallableModel validation + __call__.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + ], + return_annotation=return_type, + ) + return __call__ + + call_impl = make_call_impl() + + # Apply Flow.call decorator + flow_options = { + "cacheable": cacheable, + "volatile": volatile, + "log_level": log_level, + "validate_result": validate_result, + "verbose": verbose, + } + if evaluator is not None: + flow_options["evaluator"] = evaluator + + decorated_call = Flow.call(**flow_options)(call_impl) + + # Create the __deps__ method + def make_deps_impl(): + def __deps__(self, context) -> GraphDepList: + deps = [] + for dep_name, (base_type, dep_obj) in dep_fields.items(): + value = getattr(self, dep_name) + if isinstance(value, CallableModel): + transformed_ctx = dep_obj.apply(context) + deps.append((value, [transformed_ctx])) + return deps + + # Set proper signature + __deps__.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + ], + return_annotation=GraphDepList, + ) + return __deps__ + + deps_impl = make_deps_impl() + decorated_deps = Flow.deps(deps_impl) + + # Build pydantic field annotations for the class + annotations = {} + + namespace = { + "__module__": fn.__module__, + "__qualname__": f"_{fn.__name__}_Model", + "__call__": decorated_call, + "__deps__": decorated_deps, + } + + for name, (typ, default) in model_fields.items(): + annotations[name] = typ + if default is not ...: + namespace[name] = default + else: + # For 
required fields, use Field(...) + namespace[name] = Field(...) + + namespace["__annotations__"] = annotations + + # Add model validator for dependency validation if we have dep fields + if dep_fields: + from pydantic import model_validator + + # Create validator function that captures dep_fields and context_type + def make_dep_validator(d_fields, ctx_type): + @model_validator(mode="after") + def __validate_deps__(self): + from .callable import CallableModel + + for dep_name, (base_type, dep_obj) in d_fields.items(): + value = getattr(self, dep_name) + if isinstance(value, CallableModel): + dep_obj.validate_dependency(value, base_type, ctx_type, dep_name) + return self + + return __validate_deps__ + + namespace["__validate_deps__"] = make_dep_validator(dep_fields, context_type) + + # Create the class using type() + GeneratedModel = type(f"_{fn.__name__}_Model", (CallableModel,), namespace) + + # Set class-level attributes after class creation (to avoid pydantic processing) + GeneratedModel.__flow_model_context_type__ = context_type + GeneratedModel.__flow_model_return_type__ = return_type + GeneratedModel.__flow_model_func__ = fn + GeneratedModel.__flow_model_dep_fields__ = dep_fields + GeneratedModel.__flow_model_use_context_args__ = use_context_args + GeneratedModel.__flow_model_context_args__ = ctx_args_list + + # Override context_type property after class creation + @property + def context_type_getter(self) -> Type[ContextBase]: + return self.__class__.__flow_model_context_type__ + + # Override result_type property after class creation + @property + def result_type_getter(self) -> Type[ResultBase]: + return self.__class__.__flow_model_return_type__ + + GeneratedModel.context_type = context_type_getter + GeneratedModel.result_type = result_type_getter + + # Register for serialization (local classes need this) + register_ccflow_import_path(GeneratedModel) + + # Rebuild the model to process annotations properly + GeneratedModel.model_rebuild() + + # Create factory 
function that returns model instances + @wraps(fn) + def factory(**kwargs) -> GeneratedModel: + return GeneratedModel(**kwargs) + + # Preserve useful attributes on factory + factory._generated_model = GeneratedModel + factory.__doc__ = fn.__doc__ + + return factory + + # Handle both @Flow.model and @Flow.model(...) syntax + if func is not None: + return decorator(func) + return decorator diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml new file mode 100644 index 0000000..781bd24 --- /dev/null +++ b/ccflow/tests/config/conf_flow.yaml @@ -0,0 +1,80 @@ +# Flow.model configurations for Hydra integration tests +# This file is separate from conf.yaml to avoid affecting existing tests + +# Basic Flow.model +flow_loader: + _target_: ccflow.tests.test_flow_model.basic_loader + source: test_source + multiplier: 5 + +flow_processor: + _target_: ccflow.tests.test_flow_model.string_processor + prefix: "value=" + suffix: "!" + +# Pipeline with dependencies (uses registry name references for same instance) +flow_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 100 + +flow_transformer: + _target_: ccflow.tests.test_flow_model.data_transformer + source: flow_source + factor: 3 + +# Three-stage pipeline +flow_stage1: + _target_: ccflow.tests.test_flow_model.pipeline_stage1 + initial: 10 + +flow_stage2: + _target_: ccflow.tests.test_flow_model.pipeline_stage2 + stage1_output: flow_stage1 + multiplier: 2 + +flow_stage3: + _target_: ccflow.tests.test_flow_model.pipeline_stage3 + stage2_output: flow_stage2 + offset: 50 + +# Diamond dependency pattern +diamond_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 10 + +diamond_branch_a: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 2 + +diamond_branch_b: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 5 + +diamond_aggregator: + _target_: 
ccflow.tests.test_flow_model.data_aggregator + input_a: diamond_branch_a + input_b: diamond_branch_b + operation: add + +# DateRangeContext with transform +flow_date_loader: + _target_: ccflow.tests.test_flow_model.date_range_loader + source: market_data + include_weekends: false + +flow_date_processor: + _target_: ccflow.tests.test_flow_model.date_range_processor + raw_data: flow_date_loader + normalize: true + +# context_args models (auto-unpacked context parameters) +ctx_args_loader: + _target_: ccflow.tests.test_flow_model.context_args_loader + source: data_source + +ctx_args_processor: + _target_: ccflow.tests.test_flow_model.context_args_processor + data: ctx_args_loader + prefix: "output" diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index a748765..29f4524 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -462,6 +462,7 @@ def test_types(self): error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelMissingContextArg) + # BadModelDoubleContextArg also fails with the same error since extra params aren't allowed error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelDoubleContextArg) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py new file mode 100644 index 0000000..75e0899 --- /dev/null +++ b/ccflow/tests/test_flow_model.py @@ -0,0 +1,1477 @@ +"""Tests for Flow.model decorator.""" + +from datetime import date, timedelta +from typing import Annotated +from unittest import TestCase + +from pydantic import ValidationError +from ray.cloudpickle import dumps as rcpdumps, loads as rcploads + +from ccflow import ( + CallableModel, + ContextBase, + DateRangeContext, + Dep, + DepOf, + Flow, + GenericResult, + ModelRegistry, + ResultBase, +) + + +class SimpleContext(ContextBase): + """Simple context for testing.""" + + value: int + + +class 
ExtendedContext(ContextBase): + """Extended context with multiple fields.""" + + x: int + y: str = "default" + + +class MyResult(ResultBase): + """Custom result type for testing.""" + + data: str + + +# ============================================================================= +# Basic Flow.model Tests +# ============================================================================= + + +class TestFlowModelBasic(TestCase): + """Basic Flow.model functionality tests.""" + + def test_simple_model_explicit_context(self): + """Test Flow.model with explicit context parameter.""" + + @Flow.model + def simple_loader(context: SimpleContext, multiplier: int) -> GenericResult[int]: + return GenericResult(value=context.value * multiplier) + + # Create model instance + loader = simple_loader(multiplier=3) + + # Should be a CallableModel + self.assertIsInstance(loader, CallableModel) + + # Execute + ctx = SimpleContext(value=10) + result = loader(ctx) + + self.assertIsInstance(result, GenericResult) + self.assertEqual(result.value, 30) + + def test_model_with_default_params(self): + """Test Flow.model with default parameter values.""" + + @Flow.model + def loader_with_defaults(context: SimpleContext, multiplier: int = 2, prefix: str = "result") -> GenericResult[str]: + return GenericResult(value=f"{prefix}:{context.value * multiplier}") + + # Create with defaults + loader = loader_with_defaults() + result = loader(SimpleContext(value=5)) + self.assertEqual(result.value, "result:10") + + # Create with custom values + loader2 = loader_with_defaults(multiplier=3, prefix="custom") + result2 = loader2(SimpleContext(value=5)) + self.assertEqual(result2.value, "custom:15") + + def test_model_context_type_property(self): + """Test that generated model has correct context_type.""" + + @Flow.model + def typed_model(context: ExtendedContext, factor: int) -> GenericResult[int]: + return GenericResult(value=context.x * factor) + + model = typed_model(factor=2) + 
self.assertEqual(model.context_type, ExtendedContext) + + def test_model_result_type_property(self): + """Test that generated model has correct result_type.""" + + @Flow.model + def custom_result_model(context: SimpleContext) -> MyResult: + return MyResult(data=f"value={context.value}") + + model = custom_result_model() + self.assertEqual(model.result_type, MyResult) + + def test_model_with_no_extra_params(self): + """Test Flow.model with only context parameter.""" + + @Flow.model + def identity_model(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + model = identity_model() + result = model(SimpleContext(value=42)) + self.assertEqual(result.value, 42) + + def test_model_with_flow_options(self): + """Test Flow.model with Flow.call options.""" + + @Flow.model(cacheable=True, validate_result=True) + def cached_model(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=value + context.value) + + model = cached_model(value=10) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 15) + + def test_model_with_underscore_context(self): + """Test Flow.model with '_' as context parameter (unused context convention).""" + + @Flow.model + def loader(context: SimpleContext, base: int) -> GenericResult[int]: + return GenericResult(value=context.value + base) + + @Flow.model + def consumer(_: SimpleContext, data: DepOf[..., GenericResult[int]]) -> GenericResult[int]: + # Context not used directly, just passed to dependency + return GenericResult(value=data.value * 2) + + load = loader(base=100) + consume = consumer(data=load) + + result = consume(SimpleContext(value=10)) + # loader: 10 + 100 = 110, consumer: 110 * 2 = 220 + self.assertEqual(result.value, 220) + + # Verify context_type is still correct + self.assertEqual(consume.context_type, SimpleContext) + + +# ============================================================================= +# context_args Mode Tests +# 
============================================================================= + + +class TestFlowModelContextArgs(TestCase): + """Tests for Flow.model with context_args (unpacked context).""" + + def test_context_args_basic(self): + """Test basic context_args usage.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def date_range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + + loader = date_range_loader(source="db") + + # Should use DateRangeContext + self.assertEqual(loader.context_type, DateRangeContext) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = loader(ctx) + self.assertEqual(result.value, "db:2024-01-01 to 2024-01-31") + + def test_context_args_custom_context(self): + """Test context_args with custom context type.""" + + @Flow.model(context_args=["x", "y"]) + def unpacked_model(x: int, y: str, multiplier: int = 1) -> GenericResult[str]: + return GenericResult(value=f"{y}:{x * multiplier}") + + model = unpacked_model(multiplier=2) + + # Create context with generated type + ctx_type = model.context_type + ctx = ctx_type(x=5, y="test") + + result = model(ctx) + self.assertEqual(result.value, "test:10") + + def test_context_args_with_defaults(self): + """Test context_args where context fields have defaults.""" + + @Flow.model(context_args=["value"]) + def model_with_ctx_default(value: int = 42, extra: str = "foo") -> GenericResult[str]: + return GenericResult(value=f"{extra}:{value}") + + model = model_with_ctx_default() + + # Create context - the generated context should allow default + ctx_type = model.context_type + ctx = ctx_type(value=100) + + result = model(ctx) + self.assertEqual(result.value, "foo:100") + + +# ============================================================================= +# Dependency Tests +# ============================================================================= + + 
+class TestFlowModelDependencies(TestCase): + """Tests for Flow.model with dependencies.""" + + def test_simple_dependency_with_depof(self): + """Test simple dependency using DepOf shorthand.""" + + @Flow.model + def loader(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=value + context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + multiplier: int = 1, + ) -> GenericResult[int]: + return GenericResult(value=data.value * multiplier) + + # Create pipeline + load = loader(value=10) + consume = consumer(data=load, multiplier=2) + + ctx = SimpleContext(value=5) + result = consume(ctx) + + # loader returns 10 + 5 = 15, consumer multiplies by 2 = 30 + self.assertEqual(result.value, 30) + + def test_dependency_with_explicit_dep(self): + """Test dependency using explicit Dep() annotation.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep()], + ) -> GenericResult[int]: + return GenericResult(value=data.value + 100) + + load = loader() + consume = consumer(data=load) + + result = consume(SimpleContext(value=10)) + # loader: 10 * 2 = 20, consumer: 20 + 100 = 120 + self.assertEqual(result.value, 120) + + def test_dependency_with_direct_value(self): + """Test that Dep fields can accept direct values (not CallableModel).""" + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value + context.value) + + # Pass direct value instead of CallableModel + consume = consumer(data=GenericResult(value=100)) + + result = consume(SimpleContext(value=5)) + self.assertEqual(result.value, 105) + + def test_deps_method_generation(self): + """Test that __deps__ method is correctly generated.""" + + @Flow.model + def 
loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=10) + deps = consume.__deps__(ctx) + + # Should have one dependency + self.assertEqual(len(deps), 1) + self.assertEqual(deps[0][0], load) + self.assertEqual(deps[0][1], [ctx]) + + def test_no_deps_when_direct_value(self): + """Test that __deps__ returns empty when direct values used.""" + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + consume = consumer(data=GenericResult(value=100)) + + deps = consume.__deps__(SimpleContext(value=10)) + self.assertEqual(len(deps), 0) + + +# ============================================================================= +# Transform Tests +# ============================================================================= + + +class TestFlowModelTransforms(TestCase): + """Tests for Flow.model with context transforms.""" + + def test_transform_in_dep(self): + """Test dependency with context transform.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[ + GenericResult[int], + Dep(transform=lambda ctx: ctx.model_copy(update={"value": ctx.value + 10})), + ], + ) -> GenericResult[int]: + return GenericResult(value=data.value * 2) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=5) + result = consume(ctx) + + # Transform adds 10 to context.value: 5 + 10 = 15 + # Loader returns that: 15 + # Consumer multiplies by 2: 30 + self.assertEqual(result.value, 30) + + def test_transform_in_deps_method(self): + 
"""Test that transform is applied in __deps__ method.""" + + def transform_fn(ctx): + return ctx.model_copy(update={"value": ctx.value * 3}) + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep(transform=transform_fn)], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=7) + deps = consume.__deps__(ctx) + + # Transform should be applied + self.assertEqual(len(deps), 1) + transformed_ctx = deps[0][1][0] + self.assertEqual(transformed_ctx.value, 21) # 7 * 3 + + def test_date_range_transform(self): + """Test transform pattern with date ranges using context_args.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date}") + + def lookback_transform(ctx: DateRangeContext) -> DateRangeContext: + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + + @Flow.model(context_args=["start_date", "end_date"]) + def range_processor( + start_date: date, + end_date: date, + data: Annotated[GenericResult[str], Dep(transform=lookback_transform)], + ) -> GenericResult[str]: + return GenericResult(value=f"processed:{data.value}") + + loader = range_loader(source="db") + processor = range_processor(data=loader) + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + # Transform should shift start_date back by 1 day + self.assertEqual(result.value, "processed:db:2024-01-09") + + +# ============================================================================= +# Pipeline Tests +# ============================================================================= + + +class TestFlowModelPipeline(TestCase): 
+ """Tests for multi-stage pipelines with Flow.model.""" + + def test_three_stage_pipeline(self): + """Test a three-stage computation pipeline.""" + + @Flow.model + def stage1(context: SimpleContext, base: int) -> GenericResult[int]: + return GenericResult(value=context.value + base) + + @Flow.model + def stage2( + context: SimpleContext, + input_data: DepOf[..., GenericResult[int]], + multiplier: int, + ) -> GenericResult[int]: + return GenericResult(value=input_data.value * multiplier) + + @Flow.model + def stage3( + context: SimpleContext, + input_data: DepOf[..., GenericResult[int]], + offset: int = 0, + ) -> GenericResult[int]: + return GenericResult(value=input_data.value + offset) + + # Build pipeline + s1 = stage1(base=100) + s2 = stage2(input_data=s1, multiplier=2) + s3 = stage3(input_data=s2, offset=50) + + ctx = SimpleContext(value=10) + result = s3(ctx) + + # s1: 10 + 100 = 110 + # s2: 110 * 2 = 220 + # s3: 220 + 50 = 270 + self.assertEqual(result.value, 270) + + def test_diamond_dependency_pattern(self): + """Test diamond-shaped dependency pattern.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def branch_a( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value * 2) + + @Flow.model + def branch_b( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value + 100) + + @Flow.model + def merger( + context: SimpleContext, + a: DepOf[..., GenericResult[int]], + b: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=a.value + b.value) + + src = source() + a = branch_a(data=src) + b = branch_b(data=src) + merge = merger(a=a, b=b) + + ctx = SimpleContext(value=10) + result = merge(ctx) + + # source: 10 + # branch_a: 10 * 2 = 20 + # branch_b: 10 + 100 = 110 + # merger: 20 + 110 = 130 
+ self.assertEqual(result.value, 130) + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestFlowModelIntegration(TestCase): + """Integration tests for Flow.model with ccflow infrastructure.""" + + def test_registry_integration(self): + """Test that Flow.model models work with ModelRegistry.""" + + @Flow.model + def registrable_model(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=context.value + value) + + model = registrable_model(value=100) + + registry = ModelRegistry.root().clear() + registry.add("test_model", model) + + retrieved = registry["test_model"] + self.assertEqual(retrieved, model) + + result = retrieved(SimpleContext(value=10)) + self.assertEqual(result.value, 110) + + def test_serialization_dump(self): + """Test that generated models can be serialized.""" + + @Flow.model + def serializable_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: + return GenericResult(value=value) + + model = serializable_model(value=100) + dumped = model.model_dump(mode="python") + + self.assertIn("value", dumped) + self.assertEqual(dumped["value"], 100) + self.assertIn("type_", dumped) + + def test_pickle_roundtrip(self): + """Test cloudpickle serialization of generated models.""" + + @Flow.model + def pickleable_model(context: SimpleContext, factor: int) -> GenericResult[int]: + return GenericResult(value=context.value * factor) + + model = pickleable_model(factor=3) + + # Cloudpickle roundtrip (standard pickle won't work for local classes) + pickled = rcpdumps(model, protocol=5) + restored = rcploads(pickled) + + result = restored(SimpleContext(value=10)) + self.assertEqual(result.value, 30) + + def test_mix_with_manual_callable_model(self): + """Test mixing Flow.model with manually defined CallableModel.""" + + class ManualModel(CallableModel): + offset: int + + 
@Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + @Flow.model + def generated_consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + multiplier: int, + ) -> GenericResult[int]: + return GenericResult(value=data.value * multiplier) + + manual = ManualModel(offset=50) + generated = generated_consumer(data=manual, multiplier=2) + + result = generated(SimpleContext(value=10)) + # manual: 10 + 50 = 60 + # generated: 60 * 2 = 120 + self.assertEqual(result.value, 120) + + +# ============================================================================= +# Error Case Tests +# ============================================================================= + + +class TestFlowModelErrors(TestCase): + """Error case tests for Flow.model.""" + + def test_missing_return_type(self): + """Test error when return type annotation is missing.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model + def no_return(context: SimpleContext): + return GenericResult(value=1) + + self.assertIn("return type annotation", str(cm.exception)) + + def test_non_result_return_type(self): + """Test error when return type is not ResultBase subclass.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model + def bad_return(context: SimpleContext) -> int: + return 42 + + self.assertIn("ResultBase", str(cm.exception)) + + def test_missing_context_and_context_args(self): + """Test error when neither context param nor context_args provided.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model + def no_context(value: int) -> GenericResult[int]: + return GenericResult(value=value) + + self.assertIn("context", str(cm.exception)) + + def test_invalid_context_arg(self): + """Test error when context_args refers to non-existent parameter.""" + with self.assertRaises(ValueError) as cm: + + @Flow.model(context_args=["nonexistent"]) + def bad_context_args(x: int) -> GenericResult[int]: + 
return GenericResult(value=x) + + self.assertIn("nonexistent", str(cm.exception)) + + def test_context_arg_without_annotation(self): + """Test error when context_arg parameter lacks type annotation.""" + with self.assertRaises(ValueError) as cm: + + @Flow.model(context_args=["x"]) + def untyped_context_arg(x) -> GenericResult[int]: + return GenericResult(value=x) + + self.assertIn("type annotation", str(cm.exception)) + + +# ============================================================================= +# Dep and DepOf Tests +# ============================================================================= + + +class TestDepAndDepOf(TestCase): + """Tests for Dep and DepOf classes.""" + + def test_depof_creates_annotated(self): + """Test that DepOf[..., T] creates Annotated[Union[T, CallableModel], Dep()].""" + from typing import Union as TypingUnion, get_args, get_origin + + annotation = DepOf[..., GenericResult[int]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[int], union_args) + self.assertIn(CallableModel, union_args) + # Second arg is Dep() + self.assertIsInstance(args[1], Dep) + self.assertIsNone(args[1].context_type) # ... 
means inherit from parent + + def test_depof_with_generic_type(self): + """Test DepOf with nested generic types.""" + from typing import List as TypingList, Union as TypingUnion, get_args, get_origin + + annotation = DepOf[..., GenericResult[TypingList[str]]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[TypingList[str]], union_args) + self.assertIn(CallableModel, union_args) + + def test_depof_with_context_type(self): + """Test DepOf[ContextType, ResultType] syntax.""" + from typing import Union as TypingUnion, get_args, get_origin + + annotation = DepOf[SimpleContext, GenericResult[int]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[int], union_args) + self.assertIn(CallableModel, union_args) + # Second arg is Dep with context_type + self.assertIsInstance(args[1], Dep) + self.assertEqual(args[1].context_type, SimpleContext) + + def test_extract_dep_with_annotated(self): + """Test extract_dep with Annotated type.""" + from ccflow.dep import extract_dep + + dep = Dep(context_type=SimpleContext) + annotation = Annotated[GenericResult[int], dep] + + base_type, extracted_dep = extract_dep(annotation) + self.assertEqual(base_type, GenericResult[int]) + self.assertEqual(extracted_dep, dep) + + def test_extract_dep_with_depof(self): + """Test extract_dep with DepOf type.""" + from typing import Union as TypingUnion, get_args, get_origin + + from ccflow.dep import extract_dep + + annotation = DepOf[..., GenericResult[str]] + base_type, extracted_dep = extract_dep(annotation) + + # base_type is Union[ResultType, CallableModel] + self.assertEqual(get_origin(base_type), 
TypingUnion) + union_args = get_args(base_type) + self.assertIn(GenericResult[str], union_args) + self.assertIn(CallableModel, union_args) + self.assertIsInstance(extracted_dep, Dep) + + def test_extract_dep_without_dep(self): + """Test extract_dep with regular type (no Dep).""" + from ccflow.dep import extract_dep + + base_type, extracted_dep = extract_dep(int) + self.assertEqual(base_type, int) + self.assertIsNone(extracted_dep) + + def test_extract_dep_annotated_without_dep(self): + """Test extract_dep with Annotated but no Dep marker.""" + from ccflow.dep import extract_dep + + annotation = Annotated[int, "some metadata"] + base_type, extracted_dep = extract_dep(annotation) + + # When no Dep marker is found, returns original annotation unchanged + self.assertEqual(base_type, annotation) + self.assertIsNone(extracted_dep) + + def test_is_compatible_type_simple(self): + """Test _is_compatible_type with simple types.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(int, int)) + self.assertFalse(_is_compatible_type(int, str)) + self.assertTrue(_is_compatible_type(bool, int)) # bool is subclass of int + + def test_is_compatible_type_generic(self): + """Test _is_compatible_type with generic types.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(GenericResult[int], GenericResult[int])) + self.assertFalse(_is_compatible_type(GenericResult[int], GenericResult[str])) + self.assertTrue(_is_compatible_type(GenericResult, GenericResult)) + + def test_is_compatible_type_none(self): + """Test _is_compatible_type with None.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(None, None)) + self.assertFalse(_is_compatible_type(None, int)) + self.assertFalse(_is_compatible_type(int, None)) + + def test_is_compatible_type_subclass(self): + """Test _is_compatible_type with subclasses.""" + from ccflow.dep import _is_compatible_type + + 
self.assertTrue(_is_compatible_type(MyResult, ResultBase)) + self.assertFalse(_is_compatible_type(ResultBase, MyResult)) + + def test_dep_validate_dependency_success(self): + """Test Dep.validate_dependency with valid dependency.""" + + @Flow.model + def valid_dep(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + dep = Dep() + model = valid_dep() + + # Should not raise + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + def test_dep_validate_dependency_context_mismatch(self): + """Test Dep.validate_dependency with context type mismatch.""" + + class OtherContext(ContextBase): + other: str + + @Flow.model + def other_dep(context: OtherContext) -> GenericResult[int]: + return GenericResult(value=42) + + dep = Dep(context_type=SimpleContext) + model = other_dep() + + with self.assertRaises(TypeError) as cm: + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + self.assertIn("context_type", str(cm.exception)) + + def test_dep_validate_dependency_result_mismatch(self): + """Test Dep.validate_dependency with result type mismatch.""" + + @Flow.model + def wrong_result(context: SimpleContext) -> MyResult: + return MyResult(data="test") + + dep = Dep() + model = wrong_result() + + with self.assertRaises(TypeError) as cm: + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + self.assertIn("result_type", str(cm.exception)) + + def test_dep_validate_dependency_non_callable(self): + """Test Dep.validate_dependency with non-CallableModel value.""" + dep = Dep() + # Should not raise for non-CallableModel values + dep.validate_dependency(GenericResult(value=42), GenericResult[int], SimpleContext, "data") + dep.validate_dependency("string", GenericResult[int], SimpleContext, "data") + dep.validate_dependency(123, GenericResult[int], SimpleContext, "data") + + def test_dep_hash(self): + """Test Dep is hashable for use in sets/dicts.""" + dep1 = Dep() + dep2 = 
Dep(context_type=SimpleContext) + + # Should be hashable + dep_set = {dep1, dep2} + self.assertEqual(len(dep_set), 2) + + dep_dict = {dep1: "value1", dep2: "value2"} + self.assertEqual(dep_dict[dep1], "value1") + self.assertEqual(dep_dict[dep2], "value2") + + def test_dep_apply_with_transform(self): + """Test Dep.apply with transform function.""" + + def transform(ctx): + return ctx.model_copy(update={"value": ctx.value * 2}) + + dep = Dep(transform=transform) + + ctx = SimpleContext(value=10) + result = dep.apply(ctx) + + self.assertEqual(result.value, 20) + + def test_dep_apply_without_transform(self): + """Test Dep.apply without transform (identity).""" + dep = Dep() + + ctx = SimpleContext(value=10) + result = dep.apply(ctx) + + self.assertIs(result, ctx) + + def test_dep_repr(self): + """Test Dep string representation.""" + dep1 = Dep() + self.assertEqual(repr(dep1), "Dep()") + + dep2 = Dep(context_type=SimpleContext) + self.assertIn("SimpleContext", repr(dep2)) + + dep3 = Dep(transform=lambda x: x) + self.assertIn("transform=", repr(dep3)) + + def test_dep_equality(self): + """Test Dep equality comparison.""" + dep1 = Dep() + dep2 = Dep() + dep3 = Dep(context_type=SimpleContext) + + # Note: Two Dep() instances with no arguments are equal + self.assertEqual(dep1, dep2) + self.assertNotEqual(dep1, dep3) + + +# ============================================================================= +# Validation Tests +# ============================================================================= + + +class TestFlowModelValidation(TestCase): + """Tests for dependency validation in Flow.model.""" + + def test_context_type_validation(self): + """Test that context_type mismatch is detected.""" + + class OtherContext(ContextBase): + other: str + + @Flow.model + def simple_loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def other_loader(context: OtherContext) -> GenericResult[int]: + return 
GenericResult(value=42) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep(context_type=SimpleContext)], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + # Should work with matching context + load1 = simple_loader() + consume1 = consumer(data=load1) + self.assertIsNotNone(consume1) + + # Should fail with mismatched context + load2 = other_loader() + with self.assertRaises((TypeError, ValidationError)): + consumer(data=load2) + + +# ============================================================================= +# Hydra Integration Tests +# ============================================================================= + + +# Define Flow.model functions at module level for Hydra to find them +@Flow.model +def hydra_basic_model(context: SimpleContext, value: int, name: str = "default") -> GenericResult[str]: + """Module-level model for Hydra testing.""" + return GenericResult(value=f"{name}:{context.value + value}") + + +# --- Additional module-level fixtures for Hydra YAML tests --- + + +@Flow.model +def basic_loader(context: SimpleContext, source: str, multiplier: int = 1) -> GenericResult[int]: + """Basic loader that multiplies context value by multiplier.""" + return GenericResult(value=context.value * multiplier) + + +@Flow.model +def string_processor(context: SimpleContext, prefix: str, suffix: str = "") -> GenericResult[str]: + """Process context value into a string with prefix and suffix.""" + return GenericResult(value=f"{prefix}{context.value}{suffix}") + + +@Flow.model +def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: + """Source that provides base data.""" + return GenericResult(value=context.value + base_value) + + +@Flow.model +def data_transformer( + context: SimpleContext, + source: DepOf[..., GenericResult[int]], + factor: int = 2, +) -> GenericResult[int]: + """Transform data by multiplying with factor.""" + return GenericResult(value=source.value * 
factor) + + +@Flow.model +def data_aggregator( + context: SimpleContext, + input_a: DepOf[..., GenericResult[int]], + input_b: DepOf[..., GenericResult[int]], + operation: str = "add", +) -> GenericResult[int]: + """Aggregate two inputs.""" + if operation == "add": + return GenericResult(value=input_a.value + input_b.value) + elif operation == "multiply": + return GenericResult(value=input_a.value * input_b.value) + else: + return GenericResult(value=input_a.value - input_b.value) + + +@Flow.model +def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: + """First stage of pipeline.""" + return GenericResult(value=context.value + initial) + + +@Flow.model +def pipeline_stage2( + context: SimpleContext, + stage1_output: DepOf[..., GenericResult[int]], + multiplier: int = 2, +) -> GenericResult[int]: + """Second stage of pipeline.""" + return GenericResult(value=stage1_output.value * multiplier) + + +@Flow.model +def pipeline_stage3( + context: SimpleContext, + stage2_output: DepOf[..., GenericResult[int]], + offset: int = 0, +) -> GenericResult[int]: + """Third stage of pipeline.""" + return GenericResult(value=stage2_output.value + offset) + + +def lookback_one_day(ctx: DateRangeContext) -> DateRangeContext: + """Transform that extends start_date back by one day.""" + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + + +@Flow.model +def date_range_loader( + context: DateRangeContext, + source: str, + include_weekends: bool = True, +) -> GenericResult[str]: + """Load data for a date range.""" + return GenericResult(value=f"{source}:{context.start_date} to {context.end_date}") + + +@Flow.model +def date_range_processor( + context: DateRangeContext, + raw_data: Annotated[GenericResult[str], Dep(transform=lookback_one_day)], + normalize: bool = False, +) -> GenericResult[str]: + """Process date range data with lookback.""" + prefix = "normalized:" if normalize else "raw:" + return 
GenericResult(value=f"{prefix}{raw_data.value}") + + +@Flow.model +def hydra_default_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: + """Module-level model with defaults for Hydra testing.""" + return GenericResult(value=context.value + value) + + +@Flow.model +def hydra_source_model(context: SimpleContext, base: int) -> GenericResult[int]: + """Source model for dependency testing.""" + return GenericResult(value=context.value * base) + + +@Flow.model +def hydra_consumer_model( + context: SimpleContext, + source: DepOf[..., GenericResult[int]], + factor: int = 1, +) -> GenericResult[int]: + """Consumer model for dependency testing.""" + return GenericResult(value=source.value * factor) + + +# --- context_args fixtures for Hydra testing --- + + +@Flow.model(context_args=["start_date", "end_date"]) +def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + """Loader using context_args with DateRangeContext.""" + return GenericResult(value=f"{source}:{start_date} to {end_date}") + + +@Flow.model(context_args=["start_date", "end_date"]) +def context_args_processor( + start_date: date, + end_date: date, + data: DepOf[..., GenericResult[str]], + prefix: str = "processed", +) -> GenericResult[str]: + """Processor using context_args with dependency.""" + return GenericResult(value=f"{prefix}:{data.value}") + + +class TestFlowModelHydra(TestCase): + """Tests for Flow.model with Hydra configuration.""" + + def test_hydra_instantiate_basic(self): + """Test that Flow.model factory can be instantiated via Hydra.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create config that references the factory function by module path + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_basic_model", + "value": 100, + "name": "test", + } + ) + + # Instantiate via Hydra + model = instantiate(cfg) + + self.assertIsInstance(model, CallableModel) + result = 
model(SimpleContext(value=10)) + self.assertEqual(result.value, "test:110") + + def test_hydra_instantiate_with_defaults(self): + """Test Hydra instantiation using default parameter values.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_default_model", + # Not specifying value, should use default + } + ) + + model = instantiate(cfg) + result = model(SimpleContext(value=8)) + self.assertEqual(result.value, 50) + + def test_hydra_instantiate_with_dependency(self): + """Test Hydra instantiation with dependencies.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create nested config + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_consumer_model", + "source": { + "_target_": "ccflow.tests.test_flow_model.hydra_source_model", + "base": 10, + }, + "factor": 2, + } + ) + + model = instantiate(cfg) + + result = model(SimpleContext(value=5)) + # source: 5 * 10 = 50, consumer: 50 * 2 = 100 + self.assertEqual(result.value, 100) + + +# ============================================================================= +# Class-based CallableModel with Auto-Resolution Tests +# ============================================================================= + + +class TestClassBasedDepResolution(TestCase): + """Tests for auto-resolution of DepOf fields in class-based CallableModels. + + Key pattern: Fields use DepOf annotation, __call__ only takes context, + and resolved values are accessed via self.field_name during __call__. 
+ """ + + def test_class_based_auto_resolve_basic(self): + """Test that DepOf fields are auto-resolved and accessible via self.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + # DepOf expands to Annotated[Union[ResultType, CallableModel], Dep()] + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + # Access resolved value via self.source + return GenericResult(value=self.source.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + src = data_source() + consumer = Consumer(source=src) + + result = consumer(SimpleContext(value=5)) + # source: 5 * 10 = 50, consumer: 50 + 1 = 51 + self.assertEqual(result.value, 51) + + def test_class_based_with_custom_transform(self): + """Test that custom __deps__ transform is used.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + offset: int = 100 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value + self.offset) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # Apply custom transform + transformed_ctx = SimpleContext(value=context.value + 5) + return [(self.source, [transformed_ctx])] + + src = data_source() + consumer = Consumer(source=src, offset=1) + + result = consumer(SimpleContext(value=5)) + # transformed context: 5 + 5 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result.value, 101) + + def test_class_based_with_annotated_transform(self): + """Test that Dep transform is used when field not in __deps__.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return 
GenericResult(value=context.value * 10) + + def double_value(ctx: SimpleContext) -> SimpleContext: + return SimpleContext(value=ctx.value * 2) + + class Consumer(CallableModel): + source: Annotated[DepOf[..., GenericResult[int]], Dep(transform=double_value)] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [] # Empty - uses Dep annotation transform from field + + src = data_source() + consumer = Consumer(source=src) + + result = consumer(SimpleContext(value=5)) + # transform: 5 * 2 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result.value, 101) + + def test_class_based_multiple_deps(self): + """Test auto-resolution with multiple dependencies.""" + + @Flow.model + def source_a(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def source_b(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + class Aggregator(CallableModel): + a: DepOf[..., GenericResult[int]] + b: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.a.value + self.b.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.a, [context]), (self.b, [context])] + + agg = Aggregator(a=source_a(), b=source_b()) + + result = agg(SimpleContext(value=10)) + # a: 10, b: 20, aggregator: 30 + self.assertEqual(result.value, 30) + + def test_class_based_deps_with_instance_field_access(self): + """Test that __deps__ can access instance fields for configurable transforms. + + This is the key advantage of class-based models over @Flow.model: + transforms can use instance fields like window size. 
+ """ + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + lookback: int = 5 # Configurable instance field + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value * 2) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # Access self.lookback in transform - this is why we use class-based! + transformed = SimpleContext(value=context.value + self.lookback) + return [(self.source, [transformed])] + + src = data_source() + consumer = Consumer(source=src, lookback=10) + + result = consumer(SimpleContext(value=5)) + # transformed: 5 + 10 = 15 + # source: 15 + # consumer: 15 * 2 = 30 + self.assertEqual(result.value, 30) + + def test_class_based_with_direct_value(self): + """Test that DepOf fields can accept pre-resolved values.""" + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value + context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # No deps when source is already resolved + return [] + + # Pass direct value instead of CallableModel + consumer = Consumer(source=GenericResult(value=100)) + + result = consumer(SimpleContext(value=5)) + self.assertEqual(result.value, 105) + + def test_class_based_no_double_call(self): + """Test that dependencies are not called twice during DepOf resolution. + + This verifies that the auto-resolution mechanism doesn't accidentally + evaluate the same dependency multiple times. 
+ """ + call_counts = {"source": 0} + + @Flow.model + def counting_source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.data.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + src = counting_source() + consumer = Consumer(data=src) + + # Call consumer - source should only be called once + result = consumer(SimpleContext(value=5)) + + self.assertEqual(result.value, 51) # 5 * 10 + 1 + self.assertEqual(call_counts["source"], 1, "Source should only be called once") + + def test_class_based_nested_depof_no_double_call(self): + """Test nested DepOf chain (A -> B -> C) has no double-calls at any layer. + + This tests a 3-layer dependency chain where: + - layer_c is the leaf (no dependencies) + - layer_b depends on layer_c + - layer_a depends on layer_b + + Each layer should be called exactly once. 
+ """ + call_counts = {"layer_a": 0, "layer_b": 0, "layer_c": 0} + + # Layer C: leaf node (no dependencies) + @Flow.model + def layer_c(context: SimpleContext) -> GenericResult[int]: + call_counts["layer_c"] += 1 + return GenericResult(value=context.value) + + # Layer B: depends on layer_c + class LayerB(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["layer_b"] += 1 + return GenericResult(value=self.source.value * 10) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + # Layer A: depends on layer_b + class LayerA(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["layer_a"] += 1 + return GenericResult(value=self.source.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + # Build the chain: A -> B -> C + c = layer_c() + b = LayerB(source=c) + a = LayerA(source=b) + + # Call layer_a - each layer should be called exactly once + result = a(SimpleContext(value=5)) + + # Verify result: C returns 5, B returns 5*10=50, A returns 50+1=51 + self.assertEqual(result.value, 51) + + # Verify each layer called exactly once + self.assertEqual(call_counts["layer_c"], 1, "layer_c should be called exactly once") + self.assertEqual(call_counts["layer_b"], 1, "layer_b should be called exactly once") + self.assertEqual(call_counts["layer_a"], 1, "layer_a should be called exactly once") + + def test_flow_model_uses_unified_resolution_path(self): + """Test that @Flow.model uses the same resolution path as class-based CallableModel. + + This verifies the consolidation of resolution logic - both @Flow.model and + class-based models should use _resolve_deps_and_call in callable.py. 
+ """ + call_counts = {"source": 0, "decorator_model": 0, "class_model": 0} + + @Flow.model + def shared_source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 2) + + # @Flow.model consumer + @Flow.model + def decorator_consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + call_counts["decorator_model"] += 1 + return GenericResult(value=data.value + 100) + + # Class-based consumer (same logic) + class ClassConsumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["class_model"] += 1 + return GenericResult(value=self.data.value + 100) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + # Test both consumers with the same source + src = shared_source() + dec_consumer = decorator_consumer(data=src) + cls_consumer = ClassConsumer(data=src) + + ctx = SimpleContext(value=10) + + # Both should produce the same result + dec_result = dec_consumer(ctx) + cls_result = cls_consumer(ctx) + + self.assertEqual(dec_result.value, cls_result.value) + self.assertEqual(dec_result.value, 120) # 10 * 2 + 100 + + # Source should be called exactly twice (once per consumer) + self.assertEqual(call_counts["source"], 2) + self.assertEqual(call_counts["decorator_model"], 1) + self.assertEqual(call_counts["class_model"], 1) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py new file mode 100644 index 0000000..661ac4f --- /dev/null +++ b/ccflow/tests/test_flow_model_hydra.py @@ -0,0 +1,437 @@ +"""Hydra integration tests for Flow.model. + +These tests verify that Flow.model decorated functions work correctly when +loaded from YAML configuration files using ModelRegistry.load_config_from_path(). 
+ +Key feature: Registry name references (e.g., `source: flow_source`) ensure the same +object instance is shared across all consumers. +""" + +from datetime import date +from pathlib import Path +from unittest import TestCase + +from omegaconf import OmegaConf + +from ccflow import CallableModel, DateRangeContext, GenericResult, ModelRegistry + +from .test_flow_model import SimpleContext + +CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") + + +class TestFlowModelHydraYAML(TestCase): + """Tests loading Flow.model from YAML config files using ModelRegistry.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_basic_loader_from_yaml(self): + """Test basic model instantiation from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + + self.assertIsInstance(loader, CallableModel) + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 50) # 10 * 5 + + def test_string_processor_from_yaml(self): + """Test string processor model from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_processor"] + + ctx = SimpleContext(value=42) + result = processor(ctx) + self.assertEqual(result.value, "value=42!") + + def test_two_stage_pipeline_from_yaml(self): + """Test two-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + self.assertIsInstance(transformer, CallableModel) + + ctx = SimpleContext(value=5) + result = transformer(ctx) + # flow_source: 5 + 100 = 105 + # flow_transformer: 105 * 3 = 315 + self.assertEqual(result.value, 315) + + def test_three_stage_pipeline_from_yaml(self): + """Test three-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage3 = r["flow_stage3"] + + ctx = 
SimpleContext(value=10) + result = stage3(ctx) + # stage1: 10 + 10 = 20 + # stage2: 20 * 2 = 40 + # stage3: 40 + 50 = 90 + self.assertEqual(result.value, 90) + + def test_diamond_dependency_from_yaml(self): + """Test diamond dependency pattern from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + aggregator = r["diamond_aggregator"] + + ctx = SimpleContext(value=10) + result = aggregator(ctx) + # source: 10 + 10 = 20 + # branch_a: 20 * 2 = 40 + # branch_b: 20 * 5 = 100 + # aggregator: 40 + 100 = 140 + self.assertEqual(result.value, 140) + + def test_date_range_pipeline_from_yaml(self): + """Test DateRangeContext pipeline with transforms from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + # The transform extends start_date back by one day + self.assertIn("2024-01-09", result.value) + self.assertIn("normalized:", result.value) + + def test_context_args_from_yaml(self): + """Test context_args model from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + + self.assertIsInstance(loader, CallableModel) + # context_args models use DateRangeContext + self.assertEqual(loader.context_type, DateRangeContext) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = loader(ctx) + self.assertEqual(result.value, "data_source:2024-01-01 to 2024-01-31") + + def test_context_args_pipeline_from_yaml(self): + """Test context_args pipeline with dependencies from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["ctx_args_processor"] + + ctx = DateRangeContext(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) + result = processor(ctx) + # loader: "data_source:2024-03-01 to 2024-03-31" + # processor: 
"output:data_source:2024-03-01 to 2024-03-31" + self.assertEqual(result.value, "output:data_source:2024-03-01 to 2024-03-31") + + def test_context_args_shares_instance(self): + """Test that context_args pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + processor = r["ctx_args_processor"] + + self.assertIs(processor.data, loader) + + +class TestFlowModelHydraInstanceSharing(TestCase): + """Tests that registry name references share the same object instance.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_pipeline_shares_instance(self): + """Test that pipeline stages share the same dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + source = r["flow_source"] + + self.assertIs(transformer.source, source) + + def test_three_stage_pipeline_shares_instances(self): + """Test that three-stage pipeline shares instances correctly.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage1 = r["flow_stage1"] + stage2 = r["flow_stage2"] + stage3 = r["flow_stage3"] + + self.assertIs(stage2.stage1_output, stage1) + self.assertIs(stage3.stage2_output, stage2) + + def test_diamond_pattern_shares_source_instance(self): + """Test that diamond pattern branches share the same source instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + source = r["diamond_source"] + branch_a = r["diamond_branch_a"] + branch_b = r["diamond_branch_b"] + aggregator = r["diamond_aggregator"] + + # Both branches should share the SAME source instance + self.assertIs(branch_a.source, source) + self.assertIs(branch_b.source, source) + self.assertIs(branch_a.source, branch_b.source) + + self.assertIs(aggregator.input_a, branch_a) + self.assertIs(aggregator.input_b, branch_b) + + def 
test_date_range_shares_instance(self): + """Test that date range pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_date_loader"] + processor = r["flow_date_processor"] + + self.assertIs(processor.raw_data, loader) + + +class TestFlowModelHydraOmegaConf(TestCase): + """Tests using OmegaConf.create for dynamic config creation.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_instantiate_with_omegaconf(self): + """Test instantiation using OmegaConf.create via ModelRegistry.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "dynamic_source", + "multiplier": 7, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=3) + result = loader(ctx) + self.assertEqual(result.value, 21) # 3 * 7 + + def test_nested_deps_with_omegaconf(self): + """Test nested dependencies using OmegaConf with registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 50, + }, + "transformer": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 4, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + transformer = r["transformer"] + + ctx = SimpleContext(value=10) + result = transformer(ctx) + # source: 10 + 50 = 60 + # transformer: 60 * 4 = 240 + self.assertEqual(result.value, 240) + + self.assertIs(transformer.source, r["source"]) + + def test_diamond_with_omegaconf(self): + """Test diamond pattern with OmegaConf using registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 10, + }, + "branch_a": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 
2, + }, + "branch_b": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 3, + }, + "aggregator": { + "_target_": "ccflow.tests.test_flow_model.data_aggregator", + "input_a": "branch_a", + "input_b": "branch_b", + "operation": "multiply", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + aggregator = r["aggregator"] + + ctx = SimpleContext(value=5) + result = aggregator(ctx) + # source: 5 + 10 = 15 + # branch_a: 15 * 2 = 30 + # branch_b: 15 * 3 = 45 + # aggregator: 30 * 45 = 1350 + self.assertEqual(result.value, 1350) + + # Verify SAME source instance is shared + self.assertIs(r["branch_a"].source, r["source"]) + self.assertIs(r["branch_b"].source, r["source"]) + + +class TestFlowModelHydraDefaults(TestCase): + """Tests that default parameter values work with Hydra.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_defaults_used_when_not_specified(self): + """Test that default values are used when not in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 10) # 10 * 1 (default) + + def test_defaults_can_be_overridden(self): + """Test that defaults can be overridden in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + "multiplier": 100, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 1000) # 10 * 100 + + +class TestFlowModelHydraModelProperties(TestCase): + """Tests that model properties are correct after Hydra instantiation.""" + + def setUp(self) -> 
None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_context_type_property(self): + """Test that context_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.context_type, SimpleContext) + + def test_result_type_property(self): + """Test that result_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.result_type, GenericResult[int]) + + def test_deps_method_works(self): + """Test that __deps__ method works after Hydra instantiation.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + ctx = SimpleContext(value=5) + deps = transformer.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIsInstance(deps[0][0], CallableModel) + self.assertEqual(deps[0][1], [ctx]) + self.assertIs(deps[0][0], r["flow_source"]) + + +class TestFlowModelHydraDateRangeTransforms(TestCase): + """Tests transforms with DateRangeContext from Hydra config.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_transform_applied_from_yaml(self): + """Test that transform is applied when loaded from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + deps = processor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + dep_model, dep_contexts = deps[0] + + # The transform should extend start_date back by one day + transformed_ctx = dep_contexts[0] + self.assertEqual(transformed_ctx.start_date, date(2024, 1, 9)) + self.assertEqual(transformed_ctx.end_date, date(2024, 1, 31)) + + self.assertIs(dep_model, r["flow_date_loader"]) + + +if __name__ == "__main__": + import 
unittest + + unittest.main() diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 616e3d8..a89d8f8 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -22,6 +22,121 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan `CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +### Flow.model Decorator + +The `@Flow.model` decorator provides a simpler way to define `CallableModel`s using plain Python functions instead of classes. It automatically generates a `CallableModel` class with proper `__call__` and `__deps__` methods. + +**Basic Example:** + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +@Flow.model +def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: + # Your data loading logic here + return GenericResult(value=query_db(source, context.start_date, context.end_date)) + +# Create model instance +loader = load_data(source="my_database") + +# Execute with context +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) +``` + +**Composing Dependencies with `Dep` and `DepOf`:** + +Use `Dep()` or `DepOf` to mark parameters that accept other `CallableModel`s as dependencies. The framework automatically resolves the dependency graph. + +> **Tip:** If your function doesn't use the context directly (only passes it to dependencies), use `_` as the parameter name to signal this: `def my_func(_: DateRangeContext, data: DepOf[..., ResultType])`. This is a Python convention for intentionally unused parameters. 
+ +```python +from datetime import date, timedelta +from typing import Annotated +from ccflow import Flow, GenericResult, DateRangeContext, Dep, DepOf + +@Flow.model +def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: + return GenericResult(value={"records": [1, 2, 3]}) + +@Flow.model +def transform_data( + _: DateRangeContext, # Context passed to dependency, not used directly + raw_data: Annotated[GenericResult[dict], Dep( + # Transform context to fetch one extra day for lookback + transform=lambda ctx: ctx.model_copy(update={ + "start_date": ctx.start_date - timedelta(days=1) + }) + )] +) -> GenericResult[dict]: + # raw_data.value contains the resolved result from load_data + return GenericResult(value={"transformed": raw_data.value["records"]}) + +# Or use DepOf shorthand (no transform needed): +@Flow.model +def aggregate_data( + _: DateRangeContext, # Context passed to dependency, not used directly + transformed: DepOf[..., GenericResult[dict]] # Shorthand for Annotated[T, Dep()] +) -> GenericResult[dict]: + return GenericResult(value={"count": len(transformed.value["transformed"])}) + +# Build the pipeline +data = load_data(source="my_database") +transformed = transform_data(raw_data=data) +aggregated = aggregate_data(transformed=transformed) + +# Execute - dependencies are automatically resolved +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = aggregated(ctx) +``` + +**Hydra/YAML Configuration:** + +`Flow.model` decorated functions work seamlessly with Hydra configuration and the `ModelRegistry`: + +```yaml +# config.yaml +data: + _target_: mymodule.load_data + source: my_database + +transformed: + _target_: mymodule.transform_data + raw_data: data # Reference by registry name (same instance is shared) + +aggregated: + _target_: mymodule.aggregate_data + transformed: transformed # Reference by registry name +``` + +When loaded via `ModelRegistry.load_config()`, references by name ensure the 
same object instance is shared across all consumers. + +**Auto-Unpacked Context with `context_args`:** + +Instead of taking an explicit `context` parameter, you can use `context_args` to automatically unpack context fields as function parameters. This is useful when you want cleaner function signatures: + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +# Instead of: def load_data(context: DateRangeContext, source: str) +# Use context_args to unpack the context fields directly: +@Flow.model(context_args=["start_date", "end_date"]) +def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + +# The decorator infers DateRangeContext from the parameter types +loader = load_data(source="my_database") +assert loader.context_type == DateRangeContext + +# Execute with context as usual +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" +``` + +The `context_args` parameter specifies which function parameters should be extracted from the context. The framework automatically determines the context type based on the parameter type annotations. + ## Model Registry A `ModelRegistry` is a named collection of models. diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py new file mode 100644 index 0000000..c3d12d1 --- /dev/null +++ b/examples/flow_model_example.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +"""Example demonstrating Flow.model decorator and class-based CallableModel. 
+ +This example shows: +- Flow.model for simple functions with minimal boilerplate +- Context transforms with Dep annotations +- Class-based CallableModel for complex cases needing instance field access +""" + +from datetime import date, timedelta +from typing import Annotated + +from ccflow import CallableModel, DateRangeContext, Dep, DepOf, Flow, GenericResult + + +# ============================================================================= +# Example 1: Basic Flow.model - No more boilerplate classes! +# ============================================================================= + +@Flow.model +def load_records(context: DateRangeContext, source: str, limit: int = 100) -> GenericResult[list]: + """Load records from a data source for the given date range.""" + print(f" Loading from '{source}' for {context.start_date} to {context.end_date} (limit={limit})") + return GenericResult(value=[ + {"id": i, "date": str(context.start_date), "value": i * 10} + for i in range(min(limit, 5)) + ]) + + +# ============================================================================= +# Example 2: Dependencies with DepOf - Automatic dependency resolution +# ============================================================================= + +@Flow.model +def compute_totals( + _: DateRangeContext, # Context passed to dependency, not used directly here + records: DepOf[..., GenericResult[list]], +) -> GenericResult[dict]: + """Compute totals from loaded records.""" + total = sum(r["value"] for r in records.value) + count = len(records.value) + print(f" Computing totals: {count} records, total={total}") + return GenericResult(value={"total": total, "count": count}) + + +# ============================================================================= +# Example 3: Simple Transform with Flow.model +# When the transform is a fixed function, Flow.model works great +# ============================================================================= + +def lookback_7_days(ctx: DateRangeContext) 
-> DateRangeContext: + """Fixed transform that extends the date range back by 7 days.""" + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=7)}) + + +@Flow.model +def compute_weekly_average( + _: DateRangeContext, + records: Annotated[GenericResult[list], Dep(transform=lookback_7_days)], +) -> GenericResult[float]: + """Compute average using fixed 7-day lookback.""" + values = [r["value"] for r in records.value] + avg = sum(values) / len(values) if values else 0 + print(f" Computing weekly average: {avg:.2f} (from {len(values)} records)") + return GenericResult(value=avg) + + +# ============================================================================= +# Example 4: Class-based CallableModel with Configurable Transform +# When the transform needs access to instance fields (like window size), +# use a class-based approach with auto-resolution +# ============================================================================= + +class ComputeMovingAverage(CallableModel): + """Compute moving average with configurable lookback window. + + This demonstrates: + - Field uses DepOf annotation: accepts either result or CallableModel + - Instance field (window) accessible in __deps__ for custom transforms + - Auto-resolution: self.records returns resolved value during __call__ + """ + + records: DepOf[..., GenericResult[list]] + window: int = 7 # Configurable lookback window + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + """Compute the moving average - self.records is already resolved.""" + values = [r["value"] for r in self.records.value] + avg = sum(values) / len(values) if values else 0 + print(f" Computing {self.window}-day moving average: {avg:.2f} (from {len(values)} records)") + return GenericResult(value=avg) + + @Flow.deps + def __deps__(self, context: DateRangeContext): + """Define dependencies with transform that uses self.window.""" + # This is where we can access instance fields! 
+ lookback_ctx = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.records, [lookback_ctx])] + + +# ============================================================================= +# Example 5: Multi-stage pipeline - Composing models together +# ============================================================================= + +@Flow.model +def generate_report( + context: DateRangeContext, + totals: DepOf[..., GenericResult[dict]], + moving_avg: DepOf[..., GenericResult[float]], + report_name: str = "Daily Report", +) -> GenericResult[str]: + """Generate a report combining multiple data sources.""" + report = f""" +{report_name} +{'=' * len(report_name)} +Date Range: {context.start_date} to {context.end_date} +Total Value: {totals.value['total']} +Record Count: {totals.value['count']} +Moving Avg: {moving_avg.value:.2f} +""" + return GenericResult(value=report.strip()) + + +# ============================================================================= +# Example 6: Using context_args for cleaner signatures +# ============================================================================= + +@Flow.model(context_args=["start_date", "end_date"]) +def fetch_metadata(start_date: date, end_date: date, category: str) -> GenericResult[dict]: + """Fetch metadata - note how start_date/end_date are direct parameters.""" + print(f" Fetching metadata for '{category}' from {start_date} to {end_date}") + return GenericResult(value={ + "category": category, + "days": (end_date - start_date).days, + "generated_at": str(date.today()), + }) + + +# ============================================================================= +# Main: Build and execute the pipeline +# ============================================================================= + +def main(): + print("=" * 60) + print("Flow.model Example - Simplified CallableModel Creation") + print("=" * 60) + + ctx = DateRangeContext( + start_date=date(2024, 1, 15), + 
end_date=date(2024, 1, 31) + ) + + # --- Example 1: Basic model --- + print("\n[1] Basic Flow.model:") + loader = load_records(source="main_db", limit=5) + result = loader(ctx) + print(f" Result: {result.value}") + + # --- Example 2: Simple dependency chain --- + print("\n[2] Dependency chain (loader -> totals):") + loader = load_records(source="main_db") + totals = compute_totals(records=loader) + result = totals(ctx) + print(f" Result: {result.value}") + + # --- Example 3: Fixed transform with Flow.model --- + print("\n[3] Fixed transform (7-day lookback with Flow.model):") + loader = load_records(source="main_db") + weekly_avg = compute_weekly_average(records=loader) + result = weekly_avg(ctx) + print(f" Result: {result.value}") + + # --- Example 4: Configurable transform with class-based model --- + print("\n[4] Configurable transform (class-based with auto-resolution):") + loader = load_records(source="main_db") + + # 14-day window + moving_avg_14 = ComputeMovingAverage(records=loader, window=14) + result = moving_avg_14(ctx) + print(f" 14-day result: {result.value}") + + # 30-day window - same loader, different window + moving_avg_30 = ComputeMovingAverage(records=loader, window=30) + result = moving_avg_30(ctx) + print(f" 30-day result: {result.value}") + + # --- Example 5: Full pipeline --- + print("\n[5] Full pipeline (mixing Flow.model and class-based):") + loader = load_records(source="analytics_db") + totals = compute_totals(records=loader) + moving_avg = ComputeMovingAverage(records=loader, window=7) + report = generate_report( + totals=totals, + moving_avg=moving_avg, + report_name="Analytics Summary" + ) + result = report(ctx) + print(result.value) + + # --- Example 6: context_args --- + print("\n[6] Using context_args (auto-unpacked context):") + metadata = fetch_metadata(category="sales") + result = metadata(ctx) + print(f" Result: {result.value}") + + # --- Bonus: Inspecting models --- + print("\n[Bonus] Inspecting models:") + print(f" 
load_records.context_type = {loader.context_type.__name__}") + print(f" ComputeMovingAverage uses __deps__ for custom transforms") + deps = moving_avg.__deps__(ctx) + for dep_model, dep_contexts in deps: + print(f" - Dependency context start: {dep_contexts[0].start_date} (lookback applied)") + + +if __name__ == "__main__": + main() From 970b4ab24f0f1d2c86ad1a0d3aa652bece4c9187 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Sun, 4 Jan 2026 21:09:54 -0500 Subject: [PATCH 04/17] Inside CallableModel, force calling resolve on DepOf to not do hacky switching out attributes at runtime Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 72 +++++++++++++++++++++++++----- ccflow/dep.py | 2 +- ccflow/flow_model.py | 12 ++++- ccflow/tests/test_flow_model.py | 78 ++++++++++++++++++++++++++++----- examples/flow_model_example.py | 8 ++-- 5 files changed, 144 insertions(+), 28 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 5296bfe..8f6adfe 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,6 +14,7 @@ import abc import inspect import logging +from contextvars import ContextVar from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -47,6 +48,8 @@ "EvaluatorBase", "Evaluator", "WrapperModel", + # Note: resolve() is intentionally not in __all__ to avoid namespace pollution. 
+ # Users who need it can import explicitly: from ccflow.callable import resolve ) log = logging.getLogger(__name__) @@ -239,10 +242,60 @@ def _wrap_with_dep_resolution(fn): return fn +# Context variable for storing resolved dependency values during __call__ +# Maps id(callable_model) -> resolved_value +_resolved_deps: ContextVar[Dict[int, Any]] = ContextVar("resolved_deps", default={}) + +# TypeVar for resolve() function to enable proper type inference +_T = TypeVar("_T") + + +def resolve(dep: Union[_T, "_CallableModel"]) -> _T: + """Access the resolved value of a DepOf dependency during __call__. + + This function is used inside a CallableModel's __call__ method to get + the resolved value of a dependency field. It provides proper type inference - + if the field is `DepOf[..., GenericResult[int]]`, this returns `GenericResult[int]`. + + Args: + dep: The dependency field value (either a CallableModel or already-resolved value) + + Returns: + The resolved value. If dep is already a resolved value (not a CallableModel), + returns it unchanged. + + Raises: + RuntimeError: If called outside of __call__ or if the dependency wasn't resolved. + + Example: + class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: MyContext) -> GenericResult[int]: + # resolve() provides proper type inference + data = resolve(self.data) # type: GenericResult[int] + return GenericResult(value=data.value + 1) + """ + # If it's not a CallableModel, it's already a resolved value - pass through + if not isinstance(dep, _CallableModel): + return dep # type: ignore[return-value] + + # Look up in context var + store = _resolved_deps.get() + dep_id = id(dep) + if dep_id not in store: + raise RuntimeError( + "resolve() can only be used inside __call__ for DepOf fields. Make sure the field is annotated with DepOf and contains a CallableModel." 
+ ) + return store[dep_id] + + def _resolve_deps_and_call(model, context, fn): """Resolve DepOf fields and call the function. This is called from ModelEvaluationContext.__call__ to handle dep resolution. + Resolved values are stored in a context variable and accessed via resolve(). Args: model: The CallableModel instance @@ -269,8 +322,8 @@ def _resolve_deps_and_call(model, context, fn): for dep_model, contexts in deps_result: dep_map[id(dep_model)] = (dep_model, contexts) - # Store original values and resolve - originals = {} + # Resolve dependencies and store in context var + resolved_values = {} for field_name, dep in dep_fields.items(): field_value = getattr(model, field_name, None) if field_value is None: @@ -280,8 +333,6 @@ def _resolve_deps_and_call(model, context, fn): if not isinstance(field_value, _CallableModel): continue # Already a resolved value, skip - originals[field_name] = field_value - # Check if this field is in __deps__ (for custom transforms) if id(field_value) in dep_map: dep_model, contexts = dep_map[id(field_value)] @@ -292,16 +343,17 @@ def _resolve_deps_and_call(model, context, fn): transformed_ctx = dep.apply(context) resolved = field_value(transformed_ctx) - # Temporarily set resolved value on model - object.__setattr__(model, field_name, resolved) + # Store resolved value keyed by the CallableModel's id + resolved_values[id(field_value)] = resolved + # Store in context var and call function + current_store = _resolved_deps.get() + new_store = {**current_store, **resolved_values} + token = _resolved_deps.set(new_store) try: - # Call original function return fn(model, context) finally: - # Restore original CallableModel values - for field_name, original_value in originals.items(): - object.__setattr__(model, field_name, original_value) + _resolved_deps.reset(token) class FlowOptions(BaseModel): diff --git a/ccflow/dep.py b/ccflow/dep.py index a7e0121..b57261e 100644 --- a/ccflow/dep.py +++ b/ccflow/dep.py @@ -154,7 +154,7 @@ class Dep: 
def __init__( self, - transform: Optional[Callable[[ContextBase], ContextBase]] = None, + transform: Optional[Callable[..., ContextBase]] = None, context_type: Optional[Type[ContextBase]] = None, ): """ diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 3a96886..25db69c 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -190,6 +190,9 @@ def decorator(fn: Callable) -> Callable: # Create the __call__ method def make_call_impl(): + # Import resolve here to avoid circular import at module level + from .callable import resolve + def __call__(self, context): # Build kwargs for the original function if use_context_args: @@ -199,9 +202,14 @@ def __call__(self, context): # Pass context directly (using actual parameter name: 'context' or '_') fn_kwargs = {ctx_param_name: context} - # Add model fields (deps are resolved by _resolve_deps_and_call in callable.py) + # Add model fields - use resolve() for dep fields to get resolved values for name in model_fields: - fn_kwargs[name] = getattr(self, name) + value = getattr(self, name) + if name in dep_fields: + # Use resolve() to get the resolved value from context var + fn_kwargs[name] = resolve(value) + else: + fn_kwargs[name] = value return fn(**fn_kwargs) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 75e0899..fe7e32e 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -18,6 +18,7 @@ ModelRegistry, ResultBase, ) +from ccflow.callable import resolve class SimpleContext(ContextBase): @@ -1153,7 +1154,7 @@ class TestClassBasedDepResolution(TestCase): """ def test_class_based_auto_resolve_basic(self): - """Test that DepOf fields are auto-resolved and accessible via self.""" + """Test that DepOf fields are auto-resolved and accessible via resolve().""" @Flow.model def data_source(context: SimpleContext) -> GenericResult[int]: @@ -1165,8 +1166,8 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> 
GenericResult[int]: - # Access resolved value via self.source - return GenericResult(value=self.source.value + 1) + # Access resolved value via resolve() + return GenericResult(value=resolve(self.source).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1192,7 +1193,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value + self.offset) + return GenericResult(value=resolve(self.source).value + self.offset) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1224,7 +1225,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value + 1) + return GenericResult(value=resolve(self.source).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1256,7 +1257,7 @@ class Aggregator(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.a.value + self.b.value) + return GenericResult(value=resolve(self.a).value + resolve(self.b).value) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1285,7 +1286,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value * 2) + return GenericResult(value=resolve(self.source).value * 2) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1310,7 +1311,8 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value + context.value) + # resolve() passes through non-CallableModel values unchanged + return GenericResult(value=resolve(self.source).value + context.value) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1341,7 +1343,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) 
-> GenericResult[int]: - return GenericResult(value=self.data.value + 1) + return GenericResult(value=resolve(self.data).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1381,7 +1383,7 @@ class LayerB(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: call_counts["layer_b"] += 1 - return GenericResult(value=self.source.value * 10) + return GenericResult(value=resolve(self.source).value * 10) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1394,7 +1396,7 @@ class LayerA(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: call_counts["layer_a"] += 1 - return GenericResult(value=self.source.value + 1) + return GenericResult(value=resolve(self.source).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1416,6 +1418,58 @@ def __deps__(self, context: SimpleContext): self.assertEqual(call_counts["layer_b"], 1, "layer_b should be called exactly once") self.assertEqual(call_counts["layer_a"], 1, "layer_a should be called exactly once") + def test_resolve_direct_value_passthrough(self): + """Test that resolve() passes through non-CallableModel values unchanged.""" + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + # resolve() should return the GenericResult directly (pass-through) + resolved = resolve(self.data) + # Verify it's the actual GenericResult, not a CallableModel + assert isinstance(resolved, GenericResult) + return GenericResult(value=resolved.value * 2) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [] + + # Pass a direct value, not a CallableModel + direct_result = GenericResult(value=42) + consumer = Consumer(data=direct_result) + + result = consumer(SimpleContext(value=5)) + self.assertEqual(result.value, 84) # 42 * 2 + + def test_resolve_outside_call_raises_error(self): + """Test that resolve() 
raises RuntimeError when called outside __call__.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.data).value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + src = source() + consumer = Consumer(data=src) + + # Calling resolve() outside of __call__ should raise RuntimeError + with self.assertRaises(RuntimeError) as cm: + resolve(consumer.data) + + self.assertIn("resolve() can only be used inside __call__", str(cm.exception)) + def test_flow_model_uses_unified_resolution_path(self): """Test that @Flow.model uses the same resolution path as class-based CallableModel. @@ -1445,7 +1499,7 @@ class ClassConsumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: call_counts["class_model"] += 1 - return GenericResult(value=self.data.value + 100) + return GenericResult(value=resolve(self.data).value + 100) @Flow.deps def __deps__(self, context: SimpleContext): diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index c3d12d1..e93d452 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -11,6 +11,7 @@ from typing import Annotated from ccflow import CallableModel, DateRangeContext, Dep, DepOf, Flow, GenericResult +from ccflow.callable import resolve # ============================================================================= @@ -77,7 +78,7 @@ class ComputeMovingAverage(CallableModel): This demonstrates: - Field uses DepOf annotation: accepts either result or CallableModel - Instance field (window) accessible in __deps__ for custom transforms - - Auto-resolution: self.records returns resolved value during __call__ + - resolve() to access resolved dependency 
values during __call__ """ records: DepOf[..., GenericResult[list]] @@ -85,8 +86,9 @@ class ComputeMovingAverage(CallableModel): @Flow.call def __call__(self, context: DateRangeContext) -> GenericResult[float]: - """Compute the moving average - self.records is already resolved.""" - values = [r["value"] for r in self.records.value] + """Compute the moving average - use resolve() to get resolved value.""" + records = resolve(self.records) # Get the resolved GenericResult + values = [r["value"] for r in records.value] avg = sum(values) / len(values) if values else 0 print(f" Computing {self.window}-day moving average: {avg:.2f} (from {len(values)} records)") return GenericResult(value=avg) From 2696fcea9993be36470b026e6d70935dc1b8f550 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Sun, 4 Jan 2026 22:00:04 -0500 Subject: [PATCH 05/17] High level design doc Signed-off-by: Nijat Khanbabayev --- docs/design/flow_model_design.md | 440 +++++++++++++++++++++++++++++++ 1 file changed, 440 insertions(+) create mode 100644 docs/design/flow_model_design.md diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md new file mode 100644 index 0000000..76d0eb7 --- /dev/null +++ b/docs/design/flow_model_design.md @@ -0,0 +1,440 @@ +# Flow.model and DepOf: Dependency Injection for CallableModel + +## Overview + +This document describes the `@Flow.model` decorator and `DepOf` annotation system for reducing boilerplate when creating `CallableModel` pipelines with dependencies. 
+ +**Key features:** +- `@Flow.model` - Decorator that generates `CallableModel` classes from plain functions +- `DepOf[ContextType, ResultType]` - Type annotation for dependency fields +- `resolve()` - Function to access resolved dependency values in class-based models + +## Quick Start + +### Pattern 1: `@Flow.model` (Recommended for Simple Cases) + +```python +from datetime import date, timedelta +from typing import Annotated + +from ccflow import Flow, DateRangeContext, GenericResult, DepOf + +@Flow.model +def load_records(context: DateRangeContext, source: str) -> GenericResult[dict]: + return GenericResult(value={"count": 100, "date": str(context.start_date)}) + +@Flow.model +def compute_stats( + context: DateRangeContext, + records: DepOf[..., GenericResult[dict]], # Dependency field +) -> GenericResult[float]: + # records is already resolved - just use it directly + return GenericResult(value=records.value["count"] * 0.05) + +# Build pipeline +loader = load_records(source="main_db") +stats = compute_stats(records=loader) + +# Execute +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = stats(ctx) +``` + +### Pattern 2: Class-Based (For Complex Cases) + +Use class-based when you need **configurable transforms** that depend on instance fields: + +```python +from datetime import timedelta + +from ccflow import CallableModel, DateRangeContext, Flow, GenericResult, DepOf +from ccflow.callable import resolve # Import resolve for class-based models + +class AggregateWithWindow(CallableModel): + """Aggregate records with configurable lookback window.""" + + records: DepOf[..., GenericResult[dict]] + window: int = 7 # Configurable instance field + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + # Use resolve() to get the resolved value + records = resolve(self.records) + return GenericResult(value=records.value["count"] / self.window) + + @Flow.deps + def __deps__(self, context: 
DateRangeContext): + # Transform uses self.window - this is why we need class-based! + lookback_ctx = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.records, [lookback_ctx])] + +# Usage - different window sizes, same source +loader = load_records(source="main_db") +agg_7 = AggregateWithWindow(records=loader, window=7) +agg_30 = AggregateWithWindow(records=loader, window=30) +``` + +## When to Use Which Pattern + +| Use `@Flow.model` when... | Use Class-Based when... | +|--------------------------------|--------------------------------------| +| Simple transformations | Transforms depend on instance fields | +| Fixed context transforms | Need `self.field` in `__deps__` | +| Less boilerplate is priority | Full control over resolution | +| No custom `__deps__` logic | Complex dependency patterns | + +## Core Concepts + +### `DepOf[ContextType, ResultType]` + +Shorthand for declaring dependency fields that can accept either: +- A pre-computed value of `ResultType` +- A `CallableModel` that produces `ResultType` + +```python +# Inherit context type from parent model +data: DepOf[..., GenericResult[dict]] + +# Explicit context type +data: DepOf[DateRangeContext, GenericResult[dict]] + +# Equivalent to: +data: Annotated[Union[GenericResult[dict], CallableModel], Dep()] +``` + +### `Dep(transform=..., context_type=...)` + +For transforms, use the full `Annotated` form: + +```python +from ccflow import Dep + +@Flow.model +def compute_stats( + context: DateRangeContext, + records: Annotated[GenericResult[dict], Dep( + transform=lambda ctx: ctx.model_copy( + update={"start_date": ctx.start_date - timedelta(days=1)} + ) + )], +) -> GenericResult[float]: + return GenericResult(value=records.value["count"] * 0.05) +``` + +### `resolve()` Function + +**Only needed for class-based models.** Accesses the resolved value of a `DepOf` field during `__call__`. 
+ +```python +from ccflow.callable import resolve + +class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: MyContext) -> GenericResult[int]: + # resolve() returns the GenericResult, not the CallableModel + result = resolve(self.data) + return GenericResult(value=result.value + 1) +``` + +**Behavior:** +- Inside `__call__`: Returns the resolved value +- With direct values (not CallableModel): Returns unchanged (no-op) +- Outside `__call__`: Raises `RuntimeError` +- In `@Flow.model`: Not needed - values are passed as function arguments + +**Type inference:** +```python +data: DepOf[..., GenericResult[int]] +resolved = resolve(self.data) # Type: GenericResult[int] +``` + +## How Resolution Works + +### `@Flow.model` Resolution Flow + +1. User calls `model(context)` +2. Generated `__call__` invokes `_resolve_deps_and_call()` +3. For each `DepOf` field containing a `CallableModel`: + - Apply transform (if any) + - Call the dependency + - Store resolved value in context variable +4. Generated `__call__` retrieves resolved values via `resolve()` +5. Original function receives resolved values as arguments + +### Class-Based Resolution Flow + +1. User calls `model(context)` +2. `_resolve_deps_and_call()` runs +3. For each `DepOf` field containing a `CallableModel`: + - Check `__deps__` for custom transforms + - Call the dependency + - Store resolved value in context variable +4. User's `__call__` accesses values via `resolve(self.field)` + +**Important:** Resolution uses a context variable (`contextvars.ContextVar`), making it thread-safe and async-safe. + +## Design Decisions + +### Decision 1: `resolve()` Instead of Temporary Mutation + +**What we chose:** Explicit `resolve()` function with context variables. + +**Alternative considered:** Temporarily mutate `self.field` during `__call__` to hold the resolved value, then restore after. 
+ +**Why we chose this:** +- No mutation of model state +- Thread/async-safe via contextvars +- Explicit about what's happening +- Easier to debug - `self.field` always shows the original value + +**Trade-off:** Slightly more verbose (`resolve(self.data).value` vs `self.data.value`). + +### Decision 2: Unified Resolution Path + +**What we chose:** Both `@Flow.model` and class-based use the same `_resolve_deps_and_call()` function. + +**Why:** +- Single source of truth for resolution logic +- Easier to maintain +- Consistent behavior across patterns + +### Decision 3: `resolve()` Not in Top-Level `__all__` + +**What we chose:** `resolve` must be imported explicitly: `from ccflow.callable import resolve` + +**Why:** +- Only needed for class-based models with `DepOf` +- Keeps top-level namespace clean +- Users who need it can find it easily + +### Decision 4: No Auto-Wrapping Return Values + +**What we chose:** Functions must explicitly return `ResultBase` subclass. + +**Why:** +- Type annotations remain honest +- Consistent with existing `CallableModel` contract +- `GenericResult(value=x)` is minimal overhead + +### Decision 5: Generated Classes Are Real CallableModels + +**What we chose:** Generate actual `CallableModel` subclasses using `type()`. + +**Why:** +- Full compatibility with existing infrastructure +- Caching, registry, serialization work unchanged +- Can mix with hand-written classes + +## Pitfalls and Limitations + +### Pitfall 1: Forgetting `resolve()` in Class-Based Models + +```python +class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context): + # WRONG - self.data is still the CallableModel! 
+ return GenericResult(value=self.data.value + 1) + + # CORRECT + return GenericResult(value=resolve(self.data).value + 1) +``` + +**Error you'll see:** `AttributeError: '_SomeModel' object has no attribute 'value'` + +### Pitfall 2: Calling `resolve()` Outside `__call__` + +```python +model = MyModel(data=some_source()) +resolve(model.data) # RuntimeError! +``` + +`resolve()` only works during `__call__` execution. + +### Pitfall 3: Lambda Transforms Don't Serialize + +```python +# Won't serialize - lambdas can't be pickled +Dep(transform=lambda ctx: ctx.model_copy(...)) + +# Will serialize - use named functions +def shift_start(ctx): + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + +Dep(transform=shift_start) +``` + +### Pitfall 4: GraphEvaluator Requires Caching + +When using `GraphEvaluator` with `DepOf`, dependencies may be called twice (once by GraphEvaluator, once by resolution) unless caching is enabled. + +```python +# Use with caching +from ccflow.evaluators import GraphEvaluator, CachingEvaluator, MultiEvaluator + +evaluator = MultiEvaluator(evaluators=[ + CachingEvaluator(), + GraphEvaluator(), +]) +``` + +### Pitfall 5: Two Mental Models + +Users need to remember: +- `@Flow.model`: Use dependency values directly as function arguments +- Class-based: Use `resolve(self.field)` to access values + +### Limitation: `__deps__` Still Required for Class-Based + +Even without transforms, class-based models need `__deps__`: + +```python +class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context): + return GenericResult(value=resolve(self.data).value) + + @Flow.deps + def __deps__(self, context): + return [(self.data, [context])] # Boilerplate, but required +``` + +## Complete Example: Multi-Stage Pipeline + +```python +from datetime import date, timedelta +from typing import Annotated + +from ccflow import ( + CallableModel, DateRangeContext, Dep, DepOf, + Flow, 
GenericResult +) +from ccflow.callable import resolve + + +# Stage 1: Data loader (simple, use @Flow.model) +@Flow.model +def load_events(context: DateRangeContext, source: str) -> GenericResult[list]: + print(f"Loading from {source} for {context.start_date} to {context.end_date}") + return GenericResult(value=[ + {"date": str(context.start_date), "count": 100 + i} + for i in range(5) + ]) + + +# Stage 2: Transform with fixed lookback (use @Flow.model with Dep transform) +@Flow.model +def compute_daily_totals( + context: DateRangeContext, + events: Annotated[GenericResult[list], Dep( + transform=lambda ctx: ctx.model_copy( + update={"start_date": ctx.start_date - timedelta(days=1)} + ) + )], +) -> GenericResult[float]: + values = [e["count"] for e in events.value] + total = sum(values) / len(values) if values else 0 + return GenericResult(value=total) + + +# Stage 3: Configurable window (use class-based) +class ComputeRollingSummary(CallableModel): + """Summary with configurable lookback window.""" + + totals: DepOf[..., GenericResult[float]] + window: int = 20 + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + totals = resolve(self.totals) + # Scale by window size + summary = totals.value * (self.window ** 0.5) + return GenericResult(value=summary) + + @Flow.deps + def __deps__(self, context: DateRangeContext): + lookback = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.totals, [lookback])] + + +# Build pipeline +events = load_events(source="main_db") +totals = compute_daily_totals(events=events) +summary_20 = ComputeRollingSummary(totals=totals, window=20) +summary_60 = ComputeRollingSummary(totals=totals, window=60) + +# Execute +ctx = DateRangeContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) +print(f"20-day summary: {summary_20(ctx).value}") +print(f"60-day summary: {summary_60(ctx).value}") +``` + +## API Reference + +### `@Flow.model` + 
+```python +@Flow.model( + context_args: list[str] = None, # Unpack context fields as function args + cacheable: bool = False, + volatile: bool = False, + log_level: int = logging.DEBUG, + validate_result: bool = True, + verbose: bool = True, + evaluator: EvaluatorBase = None, +) +def my_function(context: ContextType, ...) -> ResultType: + ... +``` + +### `DepOf[ContextType, ResultType]` + +```python +# Inherit context from parent +field: DepOf[..., GenericResult[int]] + +# Explicit context type +field: DepOf[DateRangeContext, GenericResult[int]] +``` + +### `Dep(transform=..., context_type=...)` + +```python +field: Annotated[GenericResult[int], Dep( + transform=my_transform_func, # Optional: (context) -> transformed_context + context_type=DateRangeContext, # Optional: Expected context type +)] +``` + +### `resolve(dep)` + +```python +from ccflow.callable import resolve + +# Inside __call__ of class-based CallableModel: +resolved_value = resolve(self.dep_field) + +# Type signature: +def resolve(dep: Union[T, CallableModel]) -> T: ... 
+``` + +## File Structure + +``` +ccflow/ +├── callable.py # CallableModel, Flow, resolve(), _resolve_deps_and_call() +├── dep.py # Dep, DepOf, extract_dep() +├── flow_model.py # @Flow.model implementation +└── tests/ + └── test_flow_model.py # Comprehensive tests +``` From d180f3566d53f7d5989f4bf005ea4ec073c73210 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 13 Jan 2026 11:03:22 -0500 Subject: [PATCH 06/17] Add extra stuff, need clean-up Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 1 + ccflow/callable.py | 51 +-- ccflow/context.py | 41 ++- ccflow/flow_model.py | 522 ++++++++++++++++++++++++++---- ccflow/tests/test_context.py | 9 +- ccflow/tests/test_flow_context.py | 467 ++++++++++++++++++++++++++ ccflow/tests/test_flow_model.py | 44 ++- 7 files changed, 1042 insertions(+), 93 deletions(-) create mode 100644 ccflow/tests/test_flow_context.py diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 9916168..c8d2259 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -12,6 +12,7 @@ from .context import * from .dep import * from .enums import Enum +from .flow_model import FlowAPI, BoundModel, Lazy from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/callable.py b/ccflow/callable.py index 8f6adfe..1aa7189 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -312,7 +312,10 @@ def _resolve_deps_and_call(model, context, fn): # Get Dep-annotated fields for this model class dep_fields = _get_dep_fields(model.__class__) - if not dep_fields: + # Check if model has custom deps (from @func.deps decorator) + has_custom_deps = getattr(model.__class__, "__has_custom_deps__", False) + + if not dep_fields and not has_custom_deps: return fn(model, context) # Get dependencies from __deps__ @@ -324,27 +327,37 @@ def _resolve_deps_and_call(model, context, fn): # Resolve dependencies and store in context var resolved_values = {} - for field_name, dep in dep_fields.items(): - field_value = 
getattr(model, field_name, None) - if field_value is None: - continue - - # Check if field is a CallableModel that needs resolution - if not isinstance(field_value, _CallableModel): - continue # Already a resolved value, skip - # Check if this field is in __deps__ (for custom transforms) - if id(field_value) in dep_map: - dep_model, contexts = dep_map[id(field_value)] - # Call dependency with the (transformed) context + # If custom deps, resolve ALL CallableModel fields from dep_map + if has_custom_deps: + for dep_model, contexts in deps_result: resolved = dep_model(contexts[0]) if contexts else dep_model(context) - else: - # Not in __deps__, use Dep annotation transform directly - transformed_ctx = dep.apply(context) - resolved = field_value(transformed_ctx) + # Unwrap GenericResult if present (consistent with auto-detected deps) + if hasattr(resolved, 'value'): + resolved = resolved.value + resolved_values[id(dep_model)] = resolved + else: + # Standard path: iterate over Dep-annotated fields + for field_name, dep in dep_fields.items(): + field_value = getattr(model, field_name, None) + if field_value is None: + continue + + # Check if field is a CallableModel that needs resolution + if not isinstance(field_value, _CallableModel): + continue # Already a resolved value, skip + + # Check if this field is in __deps__ (for custom transforms) + if id(field_value) in dep_map: + dep_model, contexts = dep_map[id(field_value)] + # Call dependency with the (transformed) context + resolved = dep_model(contexts[0]) if contexts else dep_model(context) + else: + # Not in __deps__, use Dep annotation transform directly + transformed_ctx = dep.apply(context) + resolved = field_value(transformed_ctx) - # Store resolved value keyed by the CallableModel's id - resolved_values[id(field_value)] = resolved + resolved_values[id(field_value)] = resolved # Store in context var and call function current_store = _resolved_deps.get() diff --git a/ccflow/context.py b/ccflow/context.py index 
cf17d24..62ce0f7 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -2,10 +2,10 @@ import warnings from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from .base import ContextBase from .exttypes import Frequency @@ -15,6 +15,7 @@ __all__ = ( + "FlowContext", "NullContext", "GenericContext", "DateContext", @@ -93,6 +94,42 @@ # Starting 0.8.0 Nullcontext is an alias to ContextBase NullContext = ContextBase + +class FlowContext(ContextBase): + """Universal context for @Flow.model functions. + + Instead of generating a new ContextBase subclass for each @Flow.model, + this single class with extra="allow" serves as the universal carrier. + Validation happens via TypedDict + TypeAdapter at compute() time. + + This design avoids: + - Proliferation of dynamic _funcname_Context classes + - Class registration overhead for serialization + - Pickling issues with Ray/distributed computing + + Fields are stored in __pydantic_extra__ and accessed via __getattr__. 
+ """ + + model_config = ConfigDict(extra="allow", frozen=True) + + def __getattr__(self, name: str) -> Any: + """Access fields stored in __pydantic_extra__.""" + # Use object.__getattribute__ to avoid infinite recursion + try: + extra = object.__getattribute__(self, "__pydantic_extra__") + if extra is not None and name in extra: + return extra[name] + except AttributeError: + pass + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __repr__(self) -> str: + """Show all fields including extra fields.""" + extra = object.__getattribute__(self, "__pydantic_extra__") or {} + fields = ", ".join(f"{k}={v!r}" for k, v in extra.items()) + return f"FlowContext({fields})" + + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 25db69c..2d3ab3a 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -3,6 +3,10 @@ This module provides the Flow.model decorator that generates CallableModel classes from plain Python functions, reducing boilerplate while maintaining full compatibility with existing ccflow infrastructure. + +Key design: Uses TypedDict + TypeAdapter for context schema validation instead of +generating dynamic ContextBase subclasses. This avoids class registration overhead +and enables clean pickling for distributed computing (e.g., Ray). 
""" import inspect @@ -10,21 +14,219 @@ from functools import wraps from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_origin -from pydantic import Field +from pydantic import Field, TypeAdapter +from typing_extensions import TypedDict from .base import ContextBase, ResultBase +from .context import FlowContext from .dep import Dep, extract_dep -from .local_persistence import register_ccflow_import_path -__all__ = ("flow_model",) +__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy") log = logging.getLogger(__name__) -def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: inspect.Signature) -> Type[ContextBase]: - """Infer or create a context type from context_args parameter names. +class FlowAPI: + """API namespace for deferred computation operations. + + Provides methods for executing models and transforming contexts. + Accessed via model.flow property. + """ + + def __init__(self, model: "CallableModel"): # noqa: F821 + self._model = model + + def compute(self, **kwargs) -> Any: + """Execute the model with the provided context arguments. + + Validates kwargs against the model's context schema using TypeAdapter, + then wraps in FlowContext and calls the model. + + Args: + **kwargs: Context arguments (e.g., start_date, end_date) + + Returns: + The model's result, unwrapped from GenericResult if applicable. + """ + # Get validator from model (lazily created if needed after unpickling) + validator = self._model._get_context_validator() + + # Validate and coerce kwargs via TypeAdapter + validated = validator.validate_python(kwargs) + + # Wrap in FlowContext (single class, always) + ctx = FlowContext(**validated) + + # Call the model + result = self._model(ctx) + + # Unwrap GenericResult if present + if hasattr(result, "value"): + return result.value + return result + + @property + def unbound_inputs(self) -> Dict[str, Type]: + """Return the context schema (field name -> type). 
+ + In deferred mode, this is everything NOT provided at construction. + """ + all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self._model, "_bound_fields", set()) + + # If explicit context_args was provided, use _context_schema + explicit_args = getattr(self._model.__class__, "__flow_model_explicit_context_args__", None) + if explicit_args is not None: + return self._model._context_schema.copy() + + # Otherwise, unbound = all params - bound + return {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + + @property + def bound_inputs(self) -> Dict[str, Any]: + """Return the config values bound at construction time.""" + bound_fields = getattr(self._model, "_bound_fields", set()) + result = {} + for name in bound_fields: + if hasattr(self._model, name): + result[name] = getattr(self._model, name) + return result + + def with_inputs(self, **transforms) -> "BoundModel": + """Create a version of this model with transformed context inputs. + + Args: + **transforms: Mapping of field name to either: + - A callable (ctx) -> value for dynamic transforms + - A static value to bind + + Returns: + A BoundModel that applies the transforms before calling. + """ + return BoundModel(model=self._model, input_transforms=transforms) + + +class BoundModel: + """A model with context transforms applied. + + Created by model.flow.with_inputs(). Applies transforms to context + before delegating to the underlying model. 
+ """ + + def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # noqa: F821 + self._model = model + self._input_transforms = input_transforms + + def __call__(self, context: ContextBase) -> Any: + """Call the model with transformed context.""" + # Build new context dict with transforms applied + ctx_dict = {} + + # Get fields from context + if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: + ctx_dict.update(context.__pydantic_extra__) + for field in context.__class__.model_fields: + ctx_dict[field] = getattr(context, field) + + # Apply transforms + for name, transform in self._input_transforms.items(): + if callable(transform): + ctx_dict[name] = transform(context) + else: + ctx_dict[name] = transform + + # Create new context and call model + new_ctx = FlowContext(**ctx_dict) + return self._model(new_ctx) + + @property + def flow(self) -> FlowAPI: + """Access the flow API.""" + return FlowAPI(self._model) + + +class Lazy: + """Deferred model execution with runtime context overrides. + + Wraps a CallableModel to allow context fields to be determined at + runtime rather than at construction time. Use in with_inputs() when + you need values that aren't available until execution. + + Example: + # Create a model that needs runtime-determined context + market_data = load_market_data(symbols=["AAPL"]) + + # Use Lazy to defer the start_date calculation + lookback_data = market_data.flow.with_inputs( + start_date=Lazy(market_data)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + ) + + # More commonly, use Lazy for self-referential transforms: + adjusted_model = model.flow.with_inputs( + value=Lazy(other_model)(multiplier=2) # Call other_model with multiplier=2 + ) + + The __call__ method returns a callable that, when invoked with a context, + calls the wrapped model with the specified overrides applied. + """ + + def __init__(self, model: "CallableModel"): # noqa: F821 + """Wrap a model for deferred execution. 
+ + Args: + model: The CallableModel to wrap + """ + self._model = model + + def __call__(self, **overrides) -> Callable[[ContextBase], Any]: + """Create a callable that applies overrides to context before execution. + + Args: + **overrides: Context field overrides. Values can be: + - Static values (applied directly) + - Callables (ctx) -> value (called with context at runtime) + + Returns: + A callable (context) -> result that applies overrides and calls the model + """ + model = self._model + + def execute_with_overrides(context: ContextBase) -> Any: + # Build context dict from incoming context + ctx_dict = {} + if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: + ctx_dict.update(context.__pydantic_extra__) + for field in context.__class__.model_fields: + ctx_dict[field] = getattr(context, field) + + # Apply overrides + for name, value in overrides.items(): + if callable(value): + ctx_dict[name] = value(context) + else: + ctx_dict[name] = value + + # Call model with modified context + new_ctx = FlowContext(**ctx_dict) + return model(new_ctx) + + return execute_with_overrides + + @property + def model(self) -> "CallableModel": # noqa: F821 + """Access the wrapped model.""" + return self._model + + +def _build_context_schema( + context_args: List[str], func: Callable, sig: inspect.Signature +) -> Tuple[Dict[str, Type], Type, Optional[Type[ContextBase]]]: + """Build context schema from context_args parameter names. - This attempts to match existing context types or creates a new one. 
+ Instead of creating a dynamic ContextBase subclass, this builds: + - A schema dict mapping field names to types + - A TypedDict for Pydantic TypeAdapter validation + - Optionally, a matched existing ContextBase type for compatibility Args: context_args: List of parameter names that come from context @@ -32,25 +234,22 @@ def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: sig: The function signature Returns: - A ContextBase subclass + Tuple of (schema_dict, TypedDict type, optional matched ContextBase type) """ - from .local_persistence import create_ccflow_model - - # Build field definitions for the context from parameter annotations - fields = {} + # Build schema dict from parameter annotations + schema = {} for name in context_args: if name not in sig.parameters: raise ValueError(f"context_arg '{name}' not found in function parameters") param = sig.parameters[name] if param.annotation is inspect.Parameter.empty: raise ValueError(f"context_arg '{name}' must have a type annotation") - default = ... if param.default is inspect.Parameter.empty else param.default - fields[name] = (param.annotation, default) + schema[name] = param.annotation - # Try to match common context types + # Try to match common context types for compatibility + matched_context_type = None from .context import DateRangeContext - # Check for DateRangeContext pattern if set(context_args) == {"start_date", "end_date"}: from datetime import date @@ -59,15 +258,12 @@ def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: or (isinstance(sig.parameters[name].annotation, type) and sig.parameters[name].annotation is date) for name in context_args ): - return DateRangeContext + matched_context_type = DateRangeContext + + # Create TypedDict for validation (not registered anywhere!) 
+ context_td = TypedDict(f"{func.__name__}Inputs", schema) - # Create a new context type dynamically - context_class = create_ccflow_model( - f"_{func.__name__}_Context", - __base__=ContextBase, - **fields, - ) - return context_class + return schema, context_td, matched_context_type def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: @@ -142,13 +338,8 @@ def decorator(fn: Callable) -> Callable: if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): raise TypeError(f"Function {fn.__name__} must return a ResultBase subclass, got {return_type}") - # Determine context mode and extract info - if context_args is not None: - # Mode 2: Unpacked context args - context_type = _infer_context_type_from_args(context_args, fn, sig) - model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} - use_context_args = True - elif "context" in params or "_" in params: + # Determine context mode + if "context" in params or "_" in params: # Mode 1: Explicit context parameter (named 'context' or '_' for unused) context_param_name = "context" if "context" in params else "_" context_param = params[context_param_name] @@ -159,57 +350,139 @@ def decorator(fn: Callable) -> Callable: raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} use_context_args = False + explicit_context_args = None + elif context_args is not None: + # Mode 2: Explicit context_args - specified params come from context + context_param_name = "context" + # Build context schema early to determine matched_context_type + context_schema_early, _, matched_type = _build_context_schema(context_args, fn, sig) + # Use matched type if available (e.g., DateRangeContext), else FlowContext + context_type = matched_type if matched_type is not None else FlowContext + # 
Exclude context_args from model fields + model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} + use_context_args = True + explicit_context_args = context_args else: - raise TypeError(f"Function {fn.__name__} must either have a 'context' (or '_') parameter or specify context_args in the decorator") + # Mode 3: Dynamic deferred mode - ALL params are potential context or config + # What's provided at construction = config/deps + # What's NOT provided = comes from context at runtime + context_param_name = "context" + context_type = FlowContext + model_field_params = {name: param for name, param in params.items() if name != "self"} + use_context_args = True + explicit_context_args = None # Dynamic - determined at construction # Analyze parameters to find dependencies and regular fields dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) + # In dynamic deferred mode (no explicit context_args), all fields are optional + # because values not provided at construction come from context at runtime + dynamic_deferred_mode = use_context_args and explicit_context_args is None + for name, param in model_field_params.items(): if param.annotation is inspect.Parameter.empty: raise TypeError(f"Parameter '{name}' must have a type annotation") base_type, dep = _get_dep_info(param.annotation) - default = ... if param.default is inspect.Parameter.empty else param.default + if param.default is not inspect.Parameter.empty: + default = param.default + elif dynamic_deferred_mode: + # In dynamic mode, params without defaults are optional (come from context) + default = None + else: + # In explicit mode, params without defaults are required + default = ... 
if dep is not None: - # This is a dependency parameter + # This is an explicit dependency parameter (DepOf annotation) dep_fields[name] = (base_type, dep) # Use Annotated so _resolve_deps_and_call in callable.py can find the Dep - # This consolidates resolution logic into one place model_fields[name] = (Annotated[Union[base_type, CallableModel], dep], default) else: - # Regular model field - model_fields[name] = (param.annotation, default) + # Regular model field - use Any for auto-detection of CallableModels. + # We can't use Union[T, CallableModel] because Pydantic tries to generate + # schema for T, which fails for arbitrary types like pl.DataFrame. + # Using Any allows any value; we do runtime isinstance checks in __call__. + model_fields[name] = (Any, default) - # Capture context_args in local variable for closures - ctx_args_list = context_args or [] - # Capture context parameter name for closures (only used in mode 1) + # Capture variables for closures ctx_param_name = context_param_name if not use_context_args else "context" + all_param_names = list(model_fields.keys()) # All non-context params (model fields) + all_param_types = {name: param.annotation for name, param in model_field_params.items()} + # For explicit context_args mode, we also need the list of context arg names + ctx_args_for_closure = context_args if context_args is not None else [] + is_dynamic_mode = use_context_args and explicit_context_args is None # Create the __call__ method def make_call_impl(): - # Import resolve here to avoid circular import at module level - from .callable import resolve - def __call__(self, context): + # Import here (inside function) to avoid pickling issues with ContextVar + from .callable import _resolved_deps + + # Check if this model has custom deps (from @func.deps decorator) + has_custom_deps = getattr(self.__class__, "__has_custom_deps__", False) + + def resolve_callable_model(name, value, store): + """Resolve a CallableModel field. 
+ + When has_custom_deps is True and the value is NOT in the store, + it means the custom deps function chose not to include this dep. + In that case, we return None (the field's default) instead of + calling the CallableModel directly. + """ + if id(value) in store: + return store[id(value)] + elif has_custom_deps: + # Custom deps excluded this field - use None + return None + else: + # Auto-detection fallback: call directly + resolved = value(context) + if hasattr(resolved, 'value'): + return resolved.value + return resolved + # Build kwargs for the original function - if use_context_args: - # Unpack context into args - fn_kwargs = {name: getattr(context, name) for name in ctx_args_list} + fn_kwargs = {} + store = _resolved_deps.get() + + if not use_context_args: + # Mode 1: Explicit context param - pass context directly + fn_kwargs[ctx_param_name] = context + # Add model fields + for name in all_param_names: + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value + elif not is_dynamic_mode: + # Mode 2: Explicit context_args - get those from context, rest from self + for name in ctx_args_for_closure: + fn_kwargs[name] = getattr(context, name) + # Add model fields + for name in all_param_names: + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value else: - # Pass context directly (using actual parameter name: 'context' or '_') - fn_kwargs = {ctx_param_name: context} - - # Add model fields - use resolve() for dep fields to get resolved values - for name in model_fields: - value = getattr(self, name) - if name in dep_fields: - # Use resolve() to get the resolved value from context var - fn_kwargs[name] = resolve(value) - else: - fn_kwargs[name] = value + # Mode 3: Dynamic deferred mode - unbound from context, bound from self + bound_fields = 
getattr(self, "_bound_fields", set()) + + for name in all_param_names: + if name in bound_fields: + # Bound at construction - get from self + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value + else: + # Unbound - get from context + fn_kwargs[name] = getattr(context, name) return fn(**fn_kwargs) @@ -242,11 +515,18 @@ def __call__(self, context): def make_deps_impl(): def __deps__(self, context) -> GraphDepList: deps = [] - for dep_name, (base_type, dep_obj) in dep_fields.items(): - value = getattr(self, dep_name) + # Check ALL fields for CallableModels (auto-detection) + for name in model_fields: + value = getattr(self, name) if isinstance(value, CallableModel): - transformed_ctx = dep_obj.apply(context) - deps.append((value, [transformed_ctx])) + if name in dep_fields: + # Explicit DepOf with transform (backwards compat) + _, dep_obj = dep_fields[name] + transformed_ctx = dep_obj.apply(context) + deps.append((value, [transformed_ctx])) + else: + # Auto-detected dependency - use context as-is + deps.append((value, [context])) return deps # Set proper signature @@ -311,7 +591,69 @@ def __validate_deps__(self): GeneratedModel.__flow_model_func__ = fn GeneratedModel.__flow_model_dep_fields__ = dep_fields GeneratedModel.__flow_model_use_context_args__ = use_context_args - GeneratedModel.__flow_model_context_args__ = ctx_args_list + GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args + GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + + # Build context_schema and matched_context_type + context_schema: Dict[str, Type] = {} + context_td = None + matched_context_type: Optional[Type[ContextBase]] = None + + if explicit_context_args is not None: + # Explicit context_args provided - use early-computed schema + # (matched_context_type was already used to set context_type above) + context_schema, 
context_td, matched_context_type = _build_context_schema(explicit_context_args, fn, sig) + elif not use_context_args: + # Explicit context mode - schema comes from the context type's fields + if hasattr(context_type, "model_fields"): + context_schema = {name: info.annotation for name, info in context_type.model_fields.items()} + # For dynamic mode (is_dynamic_mode), _context_schema remains empty + # and schema is built dynamically from _bound_fields at runtime + + # Store context schema for TypedDict-based validation (picklable!) + GeneratedModel._context_schema = context_schema + GeneratedModel._context_td = context_td + GeneratedModel._matched_context_type = matched_context_type + # Validator is created lazily to survive pickling + GeneratedModel._cached_context_validator = None + + # Method to get/create context validator (lazy for pickling support) + def _get_context_validator(self) -> TypeAdapter: + """Get or create the context validator. + + For dynamic deferred mode, builds schema from unbound fields. + For explicit context_args or explicit context mode, uses cached schema. 
+ """ + cls = self.__class__ + explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) + + # For explicit context_args or explicit context mode, use cached validator + if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): + if cls._cached_context_validator is None: + if cls._context_td is not None: + cls._cached_context_validator = TypeAdapter(cls._context_td) + elif cls._context_schema: + td = TypedDict(f"{cls.__name__}Inputs", cls._context_schema) + cls._cached_context_validator = TypeAdapter(td) + else: + cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) + return cls._cached_context_validator + + # Dynamic mode: build schema from unbound fields (instance-specific) + # Cache on instance since bound_fields varies per instance + if not hasattr(self, "_instance_context_validator"): + all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self, "_bound_fields", set()) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + if unbound_schema: + td = TypedDict(f"{cls.__name__}Inputs", unbound_schema) + object.__setattr__(self, "_instance_context_validator", TypeAdapter(td)) + else: + # No unbound fields - empty validator + object.__setattr__(self, "_instance_context_validator", TypeAdapter(dict)) + return self._instance_context_validator + + GeneratedModel._get_context_validator = _get_context_validator # Override context_type property after class creation @property @@ -323,10 +665,20 @@ def context_type_getter(self) -> Type[ContextBase]: def result_type_getter(self) -> Type[ResultBase]: return self.__class__.__flow_model_return_type__ + # Add .flow property for the new API + @property + def flow_getter(self) -> FlowAPI: + return FlowAPI(self) + GeneratedModel.context_type = context_type_getter GeneratedModel.result_type = result_type_getter + GeneratedModel.flow = flow_getter + + # Register the MODEL 
class for serialization (needed for model_dump/_target_). + # Note: We do NOT register dynamic context classes anymore - context handling + # uses FlowContext + TypedDict instead, which don't need registration. + from .local_persistence import register_ccflow_import_path - # Register for serialization (local classes need this) register_ccflow_import_path(GeneratedModel) # Rebuild the model to process annotations properly @@ -335,12 +687,56 @@ def result_type_getter(self) -> Type[ResultBase]: # Create factory function that returns model instances @wraps(fn) def factory(**kwargs) -> GeneratedModel: - return GeneratedModel(**kwargs) + instance = GeneratedModel(**kwargs) + # Track which fields were explicitly provided at construction + # These are "bound" - everything else comes from context at runtime + object.__setattr__(instance, "_bound_fields", set(kwargs.keys())) + return instance # Preserve useful attributes on factory factory._generated_model = GeneratedModel factory.__doc__ = fn.__doc__ + # Add .deps decorator for customizing __deps__ + def deps_decorator(deps_fn): + """Decorator to customize the __deps__ method. + + Usage: + @Flow.model + def my_func(start_date: date, prices: dict) -> GenericResult[...]: + ... 
+ + @my_func.deps + def _(self, context): + # Custom context transform + lookback_ctx = FlowContext( + start_date=context.start_date - timedelta(days=30), + end_date=context.end_date, + ) + return [(self.prices, [lookback_ctx])] + """ + from .callable import GraphDepList + + # Rename the function to __deps__ so Flow.deps accepts it + deps_fn.__name__ = "__deps__" + deps_fn.__qualname__ = f"{GeneratedModel.__qualname__}.__deps__" + # Set proper signature to match __call__'s context type + deps_fn.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + ], + return_annotation=GraphDepList, + ) + # Wrap with Flow.deps and replace on the class + decorated = Flow.deps(deps_fn) + GeneratedModel.__deps__ = decorated + # Mark that this model has custom deps (so _resolve_deps_and_call will call it) + GeneratedModel.__has_custom_deps__ = True + return factory # Return factory for chaining + + factory.deps = deps_decorator + return factory # Handle both @Flow.model and @Flow.model(...) syntax diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index ad98bd9..64d71e8 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -275,8 +275,13 @@ def split_camel(name: str): def test_inheritance(self): """Test that if a context has a superset of fields of another context, it is a subclass of that context.""" - for parent_name, parent_class in self.classes.items(): - for child_name, child_class in self.classes.items(): + # Exclude FlowContext from this test - it's a special universal carrier with no + # declared fields (uses extra="allow"), so the "superset implies subclass" logic + # doesn't apply to it. 
+ classes_to_check = {name: cls for name, cls in self.classes.items() if name != "FlowContext"} + + for parent_name, parent_class in classes_to_check.items(): + for child_name, child_class in classes_to_check.items(): if parent_class is child_class: continue diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py new file mode 100644 index 0000000..70af8b2 --- /dev/null +++ b/ccflow/tests/test_flow_context.py @@ -0,0 +1,467 @@ +"""Tests for FlowContext, FlowAPI, and TypedDict-based context validation. + +These tests verify the new deferred computation API that uses: +- FlowContext: Universal context carrier with extra="allow" +- TypedDict + TypeAdapter: Schema validation without dynamic class registration +- FlowAPI: The .flow namespace for compute/with_inputs/etc. +""" + +import pickle +from datetime import date, timedelta + +import cloudpickle +import pytest + +from ccflow import Flow, FlowAPI, FlowContext, GenericResult +from ccflow.context import DateRangeContext + + +class TestFlowContext: + """Tests for the FlowContext universal carrier.""" + + def test_flow_context_basic(self): + """FlowContext accepts arbitrary fields.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + + def test_flow_context_extra_fields(self): + """FlowContext stores fields in __pydantic_extra__.""" + ctx = FlowContext(x=1, y="hello", z=[1, 2, 3]) + assert ctx.x == 1 + assert ctx.y == "hello" + assert ctx.z == [1, 2, 3] + assert ctx.__pydantic_extra__ == {"x": 1, "y": "hello", "z": [1, 2, 3]} + + def test_flow_context_frozen(self): + """FlowContext is immutable (frozen).""" + ctx = FlowContext(value=42) + with pytest.raises(Exception): # ValidationError for frozen model + ctx.value = 100 + + def test_flow_context_repr(self): + """FlowContext has a useful repr.""" + ctx = FlowContext(a=1, b=2) + repr_str = repr(ctx) + assert "FlowContext" in 
repr_str + assert "a=1" in repr_str + assert "b=2" in repr_str + + def test_flow_context_attribute_error(self): + """FlowContext raises AttributeError for missing fields.""" + ctx = FlowContext(x=1) + with pytest.raises(AttributeError, match="no attribute 'missing'"): + _ = ctx.missing + + def test_flow_context_model_dump(self): + """FlowContext can be dumped (includes extra fields).""" + ctx = FlowContext(start_date=date(2024, 1, 1), value=42) + dumped = ctx.model_dump() + assert dumped["start_date"] == date(2024, 1, 1) + assert dumped["value"] == 42 + + def test_flow_context_pickle(self): + """FlowContext pickles cleanly.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + assert unpickled.start_date == date(2024, 1, 1) + assert unpickled.end_date == date(2024, 1, 31) + + def test_flow_context_cloudpickle(self): + """FlowContext works with cloudpickle (for Ray).""" + ctx = FlowContext(data=[1, 2, 3], name="test") + pickled = cloudpickle.dumps(ctx) + unpickled = cloudpickle.loads(pickled) + assert unpickled.data == [1, 2, 3] + assert unpickled.name == "test" + + +class TestFlowAPI: + """Tests for the FlowAPI (.flow namespace).""" + + def test_flow_compute_basic(self): + """FlowAPI.compute() validates and executes.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date, "source": source}) + + model = load_data(source="api") + result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + assert result["source"] == "api" + + def test_flow_compute_type_coercion(self): + """FlowAPI.compute() coerces types via TypeAdapter.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def 
load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + # Pass strings - should be coerced to dates + result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + + def test_flow_compute_validation_error(self): + """FlowAPI.compute() raises on missing required args.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + with pytest.raises(Exception): # ValidationError + model.flow.compute(start_date=date(2024, 1, 1)) # Missing end_date + + def test_flow_unbound_inputs(self): + """FlowAPI.unbound_inputs returns the context schema.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data(source="api") + unbound = model.flow.unbound_inputs + + assert "start_date" in unbound + assert "end_date" in unbound + assert unbound["start_date"] == date + assert unbound["end_date"] == date + # source is not unbound (it has a default/is bound) + assert "source" not in unbound + + def test_flow_bound_inputs(self): + """FlowAPI.bound_inputs returns config values.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data(source="api") + bound = model.flow.bound_inputs + + assert "source" in bound + assert bound["source"] == "api" + # Context args are not in bound_inputs + assert "start_date" not in bound + assert "end_date" not in bound + + +class TestBoundModel: + """Tests for BoundModel (created via .flow.with_inputs()).""" + + def 
test_with_inputs_static_value(self): + """with_inputs can bind static values.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs(start_date=date(2024, 1, 1)) + + # Call with just end_date (start_date is bound) + ctx = FlowContext(end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_transform_function(self): + """with_inputs can use transform functions.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + # Lookback: start_date is 7 days before the context's start_date + bound = model.flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + + ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) # 7 days before + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_multiple_transforms(self): + """with_inputs can apply multiple transforms.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=lambda ctx: ctx.end_date + timedelta(days=1), + ) + + ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 30)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def 
test_bound_model_has_flow_property(self): + """BoundModel has a .flow property.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x * 2) + + model = compute() + bound = model.flow.with_inputs(x=42) + assert isinstance(bound.flow, FlowAPI) + + +class TestTypedDictValidation: + """Tests for TypedDict-based context validation.""" + + def test_schema_stored_on_model(self): + """Model stores _context_schema for validation.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + assert hasattr(model, "_context_schema") + assert model._context_schema == {"start_date": date, "end_date": date} + + def test_validator_created_lazily(self): + """TypeAdapter validator is created lazily.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x) + + model = compute() + # Initially None + assert model.__class__._cached_context_validator is None + + # After getting validator, it's cached + validator = model._get_context_validator() + assert validator is not None + assert model.__class__._cached_context_validator is validator + + def test_matched_context_type(self): + """DateRangeContext pattern is matched for compatibility.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + # Should match DateRangeContext + assert model.context_type == DateRangeContext + + +class TestPicklingSupport: + """Tests for pickling support (important for Ray). + + Note: Regular pickle cannot pickle locally-defined classes (functions decorated + inside test methods). cloudpickle CAN handle this, which is why Ray uses it. + All tests here use cloudpickle to match Ray's behavior. 
+ """ + + def test_model_cloudpickle_roundtrip(self): + """Model works with cloudpickle (for Ray).""" + + @Flow.model(context_args=["x", "y"]) + def compute(x: int, y: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=(x + y) * multiplier) + + model = compute(multiplier=3) + + # cloudpickle roundtrip (what Ray uses) + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Should work after unpickling + result = unpickled.flow.compute(x=1, y=2) + assert result == 9 # (1 + 2) * 3 + + def test_model_cloudpickle_simple(self): + """Simple model cloudpickle test.""" + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + result = unpickled.flow.compute(value=21) + assert result == 42 + + def test_validator_recreated_after_cloudpickle(self): + """TypeAdapter validator is recreated after cloudpickling.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x) + + model = compute() + # Warm up the validator cache + _ = model._get_context_validator() + assert model.__class__._cached_context_validator is not None + + # cloudpickle and unpickle + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Validator should still work (may be lazily recreated) + result = unpickled.flow.compute(x=42) + assert result == 42 + + def test_flow_context_pickle_standard(self): + """FlowContext works with standard pickle.""" + ctx = FlowContext(x=1, y=2, z="test") + + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + + assert unpickled.x == 1 + assert unpickled.y == 2 + assert unpickled.z == "test" + + +class TestIntegrationWithExistingContextTypes: + """Tests for integration with existing ContextBase subclasses.""" + + def test_explicit_context_still_works(self): + 
"""Explicit context parameter mode still works.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date, "source": source}) + + model = load_data(source="api") + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["source"] == "api" + + def test_flow_context_coerces_to_date_range(self): + """FlowContext can be used with models expecting DateRangeContext.""" + + @Flow.model + def load_data(context: DateRangeContext) -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data() + # Use FlowContext - should coerce to DateRangeContext + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_flow_api_with_explicit_context(self): + """FlowAPI.compute works with explicit context mode.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data(source="api") + result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + + +class TestLazy: + """Tests for Lazy (deferred execution with context overrides).""" + + def test_lazy_basic(self): + """Lazy wraps a model for deferred execution.""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def compute(value: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + model = compute(multiplier=3) + lazy = Lazy(model) + + 
assert lazy.model is model + + def test_lazy_call_with_static_override(self): + """Lazy.__call__ with static override values.""" + from ccflow import Lazy + + @Flow.model(context_args=["x", "y"]) + def add(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + + model = add() + lazy_fn = Lazy(model)(y=100) # Override y to 100 + + ctx = FlowContext(x=5, y=10) # Original y=10 + result = lazy_fn(ctx) + assert result.value == 105 # x=5 + y=100 (overridden) + + def test_lazy_call_with_callable_override(self): + """Lazy.__call__ with callable override (computed at runtime).""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + # Override value to be original value + 10 + lazy_fn = Lazy(model)(value=lambda ctx: ctx.value + 10) + + ctx = FlowContext(value=5) + result = lazy_fn(ctx) + assert result.value == 30 # (5 + 10) * 2 = 30 + + def test_lazy_with_date_transforms(self): + """Lazy works with date transforms.""" + from ccflow import Lazy + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + + # Use Lazy to create a transform that shifts dates + lazy_fn = Lazy(model)( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=lambda ctx: ctx.end_date + ) + + ctx = FlowContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) + result = lazy_fn(ctx) + + assert result.value["start"] == date(2024, 1, 8) # 7 days before + assert result.value["end"] == date(2024, 1, 31) + + def test_lazy_multiple_overrides(self): + """Lazy supports multiple overrides at once.""" + from ccflow import Lazy + + @Flow.model(context_args=["a", "b", "c"]) + def compute(a: int, b: int, c: int) -> GenericResult[int]: + return GenericResult(value=a + b + c) + + model 
= compute() + lazy_fn = Lazy(model)( + a=10, # Static + b=lambda ctx: ctx.b * 2, # Transform + # c not overridden, uses context value + ) + + ctx = FlowContext(a=1, b=5, c=100) + result = lazy_fn(ctx) + assert result.value == 10 + 10 + 100 # a=10, b=5*2=10, c=100 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index fe7e32e..b283a2b 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -609,15 +609,45 @@ def bad_return(context: SimpleContext) -> int: self.assertIn("ResultBase", str(cm.exception)) - def test_missing_context_and_context_args(self): - """Test error when neither context param nor context_args provided.""" - with self.assertRaises(TypeError) as cm: + def test_dynamic_deferred_mode(self): + """Test dynamic deferred mode where what you provide at construction = bound.""" + from ccflow import FlowContext - @Flow.model - def no_context(value: int) -> GenericResult[int]: - return GenericResult(value=value) + @Flow.model + def dynamic_model(value: int, multiplier: int) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + # Provide 'multiplier' at construction -> it's bound + # Don't provide 'value' -> comes from context + model = dynamic_model(multiplier=3) + + # Check bound vs unbound + self.assertEqual(model.flow.bound_inputs, {"multiplier": 3}) + self.assertEqual(model.flow.unbound_inputs, {"value": int}) + + # Call with context providing 'value' + ctx = FlowContext(value=10) + result = model(ctx) + self.assertEqual(result.value, 30) # 10 * 3 - self.assertIn("context", str(cm.exception)) + def test_all_defaults_is_valid(self): + """Test that all-defaults function is valid (everything can be pre-bound).""" + from ccflow import FlowContext + + @Flow.model + def all_defaults(value: int = 1, other: str = "x") -> GenericResult[str]: + return GenericResult(value=f"{value}-{other}") + + # No args provided -> everything comes from defaults or context + model = all_defaults() + + 
# All params are unbound (not provided at construction) + self.assertEqual(model.flow.unbound_inputs, {"value": int, "other": str}) + + # Call with context - context values override defaults + ctx = FlowContext(value=5, other="y") + result = model(ctx) + self.assertEqual(result.value, "5-y") def test_invalid_context_arg(self): """Test error when context_args refers to non-existent parameter.""" From 94959158d17d4c3791d66ae0ad8e6fec7b00ac10 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Mon, 16 Mar 2026 18:30:24 -0400 Subject: [PATCH 07/17] Lint fixes Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 2 +- ccflow/flow_model.py | 2 +- ccflow/tests/test_flow_context.py | 5 +---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 1aa7189..aa92ae1 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -333,7 +333,7 @@ def _resolve_deps_and_call(model, context, fn): for dep_model, contexts in deps_result: resolved = dep_model(contexts[0]) if contexts else dep_model(context) # Unwrap GenericResult if present (consistent with auto-detected deps) - if hasattr(resolved, 'value'): + if hasattr(resolved, "value"): resolved = resolved.value resolved_values[id(dep_model)] = resolved else: diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 2d3ab3a..414202e 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -439,7 +439,7 @@ def resolve_callable_model(name, value, store): else: # Auto-detection fallback: call directly resolved = value(context) - if hasattr(resolved, 'value'): + if hasattr(resolved, "value"): return resolved.value return resolved diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 70af8b2..3f613ab 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -436,10 +436,7 @@ def load_data(start_date: date, end_date: date) -> GenericResult[dict]: model = load_data() # Use Lazy to create a transform 
that shifts dates - lazy_fn = Lazy(model)( - start_date=lambda ctx: ctx.start_date - timedelta(days=7), - end_date=lambda ctx: ctx.end_date - ) + lazy_fn = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7), end_date=lambda ctx: ctx.end_date) ctx = FlowContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) result = lazy_fn(ctx) From 765299d9612960bb05c3ba49c5c11e55ab3c66bc Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 05:36:22 -0400 Subject: [PATCH 08/17] Flow.model cleanup Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 2 +- ccflow/callable.py | 52 +- ccflow/context.py | 21 +- ccflow/flow_model.py | 582 +++++++++++++-------- ccflow/tests/test_flow_context.py | 4 +- ccflow/tests/test_flow_model.py | 810 +++++++++++++++++++++++++++++- 6 files changed, 1211 insertions(+), 260 deletions(-) diff --git a/ccflow/__init__.py b/ccflow/__init__.py index c703a1c..4dbe143 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -12,7 +12,7 @@ from .context import * from .dep import * from .enums import Enum -from .flow_model import FlowAPI, BoundModel, Lazy +from .flow_model import * from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/callable.py b/ccflow/callable.py index aa92ae1..d3b22e4 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -312,10 +312,7 @@ def _resolve_deps_and_call(model, context, fn): # Get Dep-annotated fields for this model class dep_fields = _get_dep_fields(model.__class__) - # Check if model has custom deps (from @func.deps decorator) - has_custom_deps = getattr(model.__class__, "__has_custom_deps__", False) - - if not dep_fields and not has_custom_deps: + if not dep_fields: return fn(model, context) # Get dependencies from __deps__ @@ -328,36 +325,27 @@ def _resolve_deps_and_call(model, context, fn): # Resolve dependencies and store in context var resolved_values = {} - # If custom deps, resolve ALL CallableModel fields from 
dep_map - if has_custom_deps: - for dep_model, contexts in deps_result: + # Standard path: iterate over Dep-annotated fields + for field_name, dep in dep_fields.items(): + field_value = getattr(model, field_name, None) + if field_value is None: + continue + + # Check if field is a CallableModel that needs resolution + if not isinstance(field_value, _CallableModel): + continue # Already a resolved value, skip + + # Check if this field is in __deps__ (for custom transforms) + if id(field_value) in dep_map: + dep_model, contexts = dep_map[id(field_value)] + # Call dependency with the (transformed) context resolved = dep_model(contexts[0]) if contexts else dep_model(context) - # Unwrap GenericResult if present (consistent with auto-detected deps) - if hasattr(resolved, "value"): - resolved = resolved.value - resolved_values[id(dep_model)] = resolved - else: - # Standard path: iterate over Dep-annotated fields - for field_name, dep in dep_fields.items(): - field_value = getattr(model, field_name, None) - if field_value is None: - continue - - # Check if field is a CallableModel that needs resolution - if not isinstance(field_value, _CallableModel): - continue # Already a resolved value, skip - - # Check if this field is in __deps__ (for custom transforms) - if id(field_value) in dep_map: - dep_model, contexts = dep_map[id(field_value)] - # Call dependency with the (transformed) context - resolved = dep_model(contexts[0]) if contexts else dep_model(context) - else: - # Not in __deps__, use Dep annotation transform directly - transformed_ctx = dep.apply(context) - resolved = field_value(transformed_ctx) + else: + # Not in __deps__, use Dep annotation transform directly + transformed_ctx = dep.apply(context) + resolved = field_value(transformed_ctx) - resolved_values[id(field_value)] = resolved + resolved_values[id(field_value)] = resolved # Store in context var and call function current_store = _resolved_deps.get() diff --git a/ccflow/context.py b/ccflow/context.py index 
50ff6dc..0d00d2e 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,7 +1,7 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" from datetime import date, datetime -from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated from pydantic import ConfigDict, field_validator, model_validator @@ -102,29 +102,10 @@ class FlowContext(ContextBase): - Proliferation of dynamic _funcname_Context classes - Class registration overhead for serialization - Pickling issues with Ray/distributed computing - - Fields are stored in __pydantic_extra__ and accessed via __getattr__. """ model_config = ConfigDict(extra="allow", frozen=True) - def __getattr__(self, name: str) -> Any: - """Access fields stored in __pydantic_extra__.""" - # Use object.__getattribute__ to avoid infinite recursion - try: - extra = object.__getattribute__(self, "__pydantic_extra__") - if extra is not None and name in extra: - return extra[name] - except AttributeError: - pass - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - def __repr__(self) -> str: - """Show all fields including extra fields.""" - extra = object.__getattribute__(self, "__pydantic_extra__") or {} - fields = ", ".join(f"{k}={v!r}" for k, v in extra.items()) - return f"FlowContext({fields})" - C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 414202e..e9f2704 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,20 +12,121 @@ import inspect import logging from functools import wraps -from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_origin +from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin from pydantic import Field, TypeAdapter from typing_extensions import 
TypedDict from .base import ContextBase, ResultBase +from .callable import CallableModel, Flow, GraphDepList, _CallableModel from .context import FlowContext from .dep import Dep, extract_dep +from .local_persistence import register_ccflow_import_path +from .result import GenericResult + +__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy", "FieldExtractor") + + +class _LazyMarker: + """Sentinel that marks a parameter as lazily evaluated via Lazy[T].""" + + pass + + +def _extract_lazy(annotation) -> Tuple[Any, bool]: + """Check if annotation is Lazy[T]. Returns (base_type, is_lazy). + + Handles nested Annotated types — e.g. Lazy[Annotated[T, Dep(...)]] produces + Annotated[Annotated[T, Dep(...)], _LazyMarker()], so we need to check the + outermost Annotated layer for _LazyMarker. + """ + if get_origin(annotation) is Annotated: + args = get_args(annotation) + for metadata in args[1:]: + if isinstance(metadata, _LazyMarker): + return args[0], True + return annotation, False + + +def _make_lazy_thunk(model, context): + """Create a zero-arg callable that evaluates model(context) on demand. + + The thunk caches its result so repeated calls don't re-evaluate. + """ + _cache = {} + + def thunk(): + if "result" not in _cache: + result = model(context) + if isinstance(result, GenericResult): + result = result.value + _cache["result"] = result + return _cache["result"] + + return thunk -__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy") log = logging.getLogger(__name__) +def _context_values(context: ContextBase) -> Dict[str, Any]: + """Return a plain mapping of all context values. + + `dict(context)` uses pydantic's public iteration behavior, which includes + both declared fields and any allowed extra fields. 
+ """ + + return dict(context) + + +def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: + """Build a TypeAdapter for a runtime TypedDict schema.""" + + if not schema: + return TypeAdapter(dict) + return TypeAdapter(TypedDict(name, schema)) + + +def _build_config_validators( + all_param_types: Dict[str, Type], dep_fields: Dict[str, Tuple[Type, Dep]] +) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: + """Precompute validators for non-dependency config fields.""" + + validatable_types: Dict[str, Type] = {} + for name, typ in all_param_types.items(): + if name in dep_fields: + continue + try: + TypeAdapter(typ) + validatable_types[name] = typ + except Exception: + pass + + validators = {name: TypeAdapter(typ) for name, typ in validatable_types.items()} + return validatable_types, validators + + +def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: + """Validate plain config inputs while still allowing dependency objects.""" + + if not validators: + return + + from .callable import CallableModel as _CM + + for field_name, validator in validators.items(): + if field_name not in kwargs: + continue + value = kwargs[field_name] + if value is None or isinstance(value, (_CM, BoundModel)): + continue + try: + validator.validate_python(value) + except Exception: + expected_type = validatable_types[field_name] + raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") + + class FlowAPI: """API namespace for deferred computation operations. @@ -61,7 +162,7 @@ def compute(self, **kwargs) -> Any: result = self._model(ctx) # Unwrap GenericResult if present - if hasattr(result, "value"): + if isinstance(result, GenericResult): return result.value return result @@ -111,6 +212,20 @@ class BoundModel: Created by model.flow.with_inputs(). Applies transforms to context before delegating to the underlying model. 
+ + Context propagation across dependencies: + Each BoundModel transforms the context locally — only for the model it + wraps. When used as a dependency inside another model, the FlowContext + flows through the chain unchanged until it reaches this BoundModel, + which intercepts it, applies its transforms, and passes the modified + context to the wrapped model. Upstream models never see the transform. + + Chaining with_inputs: + Calling ``bound.flow.with_inputs(...)`` merges the new transforms with + the existing ones (new overrides old for the same key). All transforms + are applied to the incoming context in one pass — they don't compose + sequentially (each transform sees the original context, not the output + of a previous transform). """ def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # noqa: F821 @@ -120,13 +235,7 @@ def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" # Build new context dict with transforms applied - ctx_dict = {} - - # Get fields from context - if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: - ctx_dict.update(context.__pydantic_extra__) - for field in context.__class__.model_fields: - ctx_dict[field] = getattr(context, field) + ctx_dict = _context_values(context) # Apply transforms for name, transform in self._input_transforms.items(): @@ -140,35 +249,125 @@ def __call__(self, context: ContextBase) -> Any: return self._model(new_ctx) @property - def flow(self) -> FlowAPI: + def flow(self) -> "FlowAPI": """Access the flow API.""" - return FlowAPI(self._model) + return _BoundFlowAPI(self) + + +class _BoundFlowAPI(FlowAPI): + """FlowAPI that delegates to a BoundModel, honoring transforms.""" + + def __init__(self, bound_model: BoundModel): + self._bound = bound_model + super().__init__(bound_model._model) + + def compute(self, **kwargs) -> Any: + validator = 
self._model._get_context_validator() + validated = validator.validate_python(kwargs) + ctx = FlowContext(**validated) + result = self._bound(ctx) # Call through BoundModel, not _model + if isinstance(result, GenericResult): + return result.value + return result + + def with_inputs(self, **transforms) -> "BoundModel": + """Chain transforms: merge new transforms with existing ones. + + New transforms override existing ones for the same key. + """ + merged = {**self._bound._input_transforms, **transforms} + return BoundModel(model=self._bound._model, input_transforms=merged) + + +class _FieldExtractorMixin: + """Turn unknown public attributes into FieldExtractors. + + Real model attributes are still resolved by the normal pydantic/base-model + attribute path via ``super().__getattr__``. + """ + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + if name.startswith("_"): + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") from None + return FieldExtractor(source=self, field_name=name) + + +class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): + """Shared behavior for models generated by ``@Flow.model``.""" + + @property + def context_type(self) -> Type[ContextBase]: + return self.__class__.__flow_model_context_type__ + + @property + def result_type(self) -> Type[ResultBase]: + return self.__class__.__flow_model_return_type__ + + @property + def flow(self) -> FlowAPI: + return FlowAPI(self) + + def _get_context_validator(self) -> TypeAdapter: + """Get or create the context validator for this generated model.""" + + cls = self.__class__ + explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) + + if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): + if cls._cached_context_validator is None: + if cls._context_td is not None: + cls._cached_context_validator = TypeAdapter(cls._context_td) + elif cls._context_schema: + 
cls._cached_context_validator = _build_typed_dict_adapter(f"{cls.__name__}Inputs", cls._context_schema) + else: + cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) + return cls._cached_context_validator + + if not hasattr(self, "_instance_context_validator"): + all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self, "_bound_fields", set()) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + object.__setattr__(self, "_instance_context_validator", _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema)) + return self._instance_context_validator class Lazy: """Deferred model execution with runtime context overrides. - Wraps a CallableModel to allow context fields to be determined at - runtime rather than at construction time. Use in with_inputs() when - you need values that aren't available until execution. + Has two distinct uses: - Example: - # Create a model that needs runtime-determined context - market_data = load_market_data(symbols=["AAPL"]) + 1. **Type annotation** — ``Lazy[T]`` marks a parameter as lazily evaluated. + The framework will NOT pre-evaluate the dependency; instead the function + receives a zero-arg thunk that triggers evaluation on demand:: - # Use Lazy to defer the start_date calculation - lookback_data = market_data.flow.with_inputs( - start_date=Lazy(market_data)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) - ) + @Flow.model + def smart_training( + data: PreparedData, + fast_metrics: Metrics, + slow_metrics: Lazy[Metrics], # NOT eagerly evaluated + threshold: float = 0.9, + ) -> Metrics: + if fast_metrics.r2 > threshold: + return fast_metrics + return slow_metrics() # Evaluated on demand + + 2. **Runtime helper** — ``Lazy(model)(overrides)`` creates a callable that + applies context overrides before calling the model. 
Used with + ``with_inputs()`` for deferred execution:: + + lookback = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + """ - # More commonly, use Lazy for self-referential transforms: - adjusted_model = model.flow.with_inputs( - value=Lazy(other_model)(multiplier=2) # Call other_model with multiplier=2 - ) + def __class_getitem__(cls, item): + """Support Lazy[T] syntax as a type annotation marker. - The __call__ method returns a callable that, when invoked with a context, - calls the wrapped model with the specified overrides applied. - """ + Returns Annotated[T, _LazyMarker()] so the framework can detect + lazy parameters during signature analysis. + """ + return Annotated[item, _LazyMarker()] def __init__(self, model: "CallableModel"): # noqa: F821 """Wrap a model for deferred execution. @@ -193,11 +392,7 @@ def __call__(self, **overrides) -> Callable[[ContextBase], Any]: def execute_with_overrides(context: ContextBase) -> Any: # Build context dict from incoming context - ctx_dict = {} - if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: - ctx_dict.update(context.__pydantic_extra__) - for field in context.__class__.model_fields: - ctx_dict[field] = getattr(context, field) + ctx_dict = _context_values(context) # Apply overrides for name, value in overrides.items(): @@ -275,18 +470,23 @@ def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: return extract_dep(annotation) +_UNSET = object() + + def flow_model( func: Callable = None, *, # Context handling context_args: Optional[List[str]] = None, # Flow.call options (passed to generated __call__) - cacheable: bool = False, - volatile: bool = False, - log_level: int = logging.DEBUG, - validate_result: bool = True, - verbose: bool = True, - evaluator: Optional[Any] = None, + # Default to _UNSET so FlowOptionsOverride can control these globally. + # Only explicitly user-provided values are passed to Flow.call. 
+ cacheable: Any = _UNSET, + volatile: Any = _UNSET, + log_level: Any = _UNSET, + validate_result: Any = _UNSET, + verbose: Any = _UNSET, + evaluator: Any = _UNSET, ) -> Callable: """Decorator that generates a CallableModel class from a plain Python function. @@ -297,12 +497,12 @@ def flow_model( Args: func: The function to decorate context_args: List of parameter names that come from context (for unpacked mode) - cacheable: Enable caching of results - volatile: Mark as volatile (always re-execute) - log_level: Logging verbosity - validate_result: Validate return type - verbose: Verbose logging output - evaluator: Custom evaluator + cacheable: Enable caching of results (default: unset, inherits from FlowOptionsOverride) + volatile: Mark as volatile (always re-execute) (default: unset, inherits from FlowOptionsOverride) + log_level: Logging verbosity (default: unset, inherits from FlowOptionsOverride) + validate_result: Validate return type (default: unset, inherits from FlowOptionsOverride) + verbose: Verbose logging output (default: unset, inherits from FlowOptionsOverride) + evaluator: Custom evaluator (default: unset, inherits from FlowOptionsOverride) Two Context Modes: 1. Explicit context parameter: Function has a 'context' parameter annotated @@ -323,29 +523,40 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ """ def decorator(fn: Callable) -> Callable: - # Import here to avoid circular imports - from .callable import CallableModel, Flow, GraphDepList + import typing as _typing sig = inspect.signature(fn) params = sig.parameters + # Resolve string annotations (PEP 563 / from __future__ import annotations) + # into real type objects. include_extras=True preserves Annotated metadata. 
+ try: + _resolved_hints = _typing.get_type_hints(fn, include_extras=True) + except Exception: + _resolved_hints = {} + # Validate return type - return_type = sig.return_annotation + return_type = _resolved_hints.get("return", sig.return_annotation) if return_type is inspect.Signature.empty: raise TypeError(f"Function {fn.__name__} must have a return type annotation") - # Check that return type is a ResultBase subclass + # Check if return type is a ResultBase subclass; if not, auto-wrap in GenericResult return_origin = get_origin(return_type) or return_type + auto_wrap_result = False if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): - raise TypeError(f"Function {fn.__name__} must return a ResultBase subclass, got {return_type}") + auto_wrap_result = True + internal_return_type = GenericResult # unparameterized for safety + else: + internal_return_type = return_type # Determine context mode if "context" in params or "_" in params: # Mode 1: Explicit context parameter (named 'context' or '_' for unused) context_param_name = "context" if "context" in params else "_" context_param = params[context_param_name] - if context_param.annotation is inspect.Parameter.empty: + context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) + if context_annotation is inspect.Parameter.empty: raise TypeError(f"Function {fn.__name__}: '{context_param_name}' parameter must have a type annotation") - context_type = context_param.annotation + context_type = context_annotation if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)): raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} @@ -372,19 +583,28 @@ def decorator(fn: Callable) -> Callable: use_context_args = True explicit_context_args = None # Dynamic - determined at 
construction - # Analyze parameters to find dependencies and regular fields + # Analyze parameters to find dependencies, lazy fields, and regular fields dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) + lazy_fields: set = set() # Names of parameters marked with Lazy[T] # In dynamic deferred mode (no explicit context_args), all fields are optional # because values not provided at construction come from context at runtime dynamic_deferred_mode = use_context_args and explicit_context_args is None for name, param in model_field_params.items(): - if param.annotation is inspect.Parameter.empty: + # Use resolved hint (handles PEP 563 string annotations) + annotation = _resolved_hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty: raise TypeError(f"Parameter '{name}' must have a type annotation") - base_type, dep = _get_dep_info(param.annotation) + # Check for Lazy[T] annotation first + unwrapped_annotation, is_lazy = _extract_lazy(annotation) + if is_lazy: + lazy_fields.add(name) + + # Extract Dep info from the (possibly unwrapped) annotation + base_type, dep = _get_dep_info(unwrapped_annotation) if param.default is not inspect.Parameter.empty: default = param.default elif dynamic_deferred_mode: @@ -409,7 +629,7 @@ def decorator(fn: Callable) -> Callable: # Capture variables for closures ctx_param_name = context_param_name if not use_context_args else "context" all_param_names = list(model_fields.keys()) # All non-context params (model fields) - all_param_types = {name: param.annotation for name, param in model_field_params.items()} + all_param_types = {name: _resolved_hints.get(name, param.annotation) for name, param in model_field_params.items()} # For explicit context_args mode, we also need the list of context arg names ctx_args_for_closure = context_args if context_args is not None else [] is_dynamic_mode = use_context_args and 
explicit_context_args is None @@ -420,26 +640,14 @@ def __call__(self, context): # Import here (inside function) to avoid pickling issues with ContextVar from .callable import _resolved_deps - # Check if this model has custom deps (from @func.deps decorator) - has_custom_deps = getattr(self.__class__, "__has_custom_deps__", False) - def resolve_callable_model(name, value, store): - """Resolve a CallableModel field. - - When has_custom_deps is True and the value is NOT in the store, - it means the custom deps function chose not to include this dep. - In that case, we return None (the field's default) instead of - calling the CallableModel directly. - """ + """Resolve a CallableModel field.""" if id(value) in store: return store[id(value)] - elif has_custom_deps: - # Custom deps excluded this field - use None - return None else: # Auto-detection fallback: call directly resolved = value(context) - if hasattr(resolved, "value"): + if isinstance(resolved, GenericResult): return resolved.value return resolved @@ -447,16 +655,28 @@ def resolve_callable_model(name, value, store): fn_kwargs = {} store = _resolved_deps.get() + def _resolve_field(name, value): + """Resolve a single field value, handling lazy wrapping.""" + is_dep = isinstance(value, (CallableModel, BoundModel)) + if name in lazy_fields: + # Lazy field: wrap in a thunk regardless of type + if is_dep: + return _make_lazy_thunk(value, context) + else: + # Non-dep value: wrap in trivial thunk + return lambda v=value: v + elif is_dep: + return resolve_callable_model(name, value, store) + else: + return value + if not use_context_args: # Mode 1: Explicit context param - pass context directly fn_kwargs[ctx_param_name] = context # Add model fields for name in all_param_names: value = getattr(self, name) - if isinstance(value, CallableModel): - fn_kwargs[name] = resolve_callable_model(name, value, store) - else: - fn_kwargs[name] = value + fn_kwargs[name] = _resolve_field(name, value) elif not is_dynamic_mode: # Mode 
2: Explicit context_args - get those from context, rest from self for name in ctx_args_for_closure: @@ -464,10 +684,7 @@ def resolve_callable_model(name, value, store): # Add model fields for name in all_param_names: value = getattr(self, name) - if isinstance(value, CallableModel): - fn_kwargs[name] = resolve_callable_model(name, value, store) - else: - fn_kwargs[name] = value + fn_kwargs[name] = _resolve_field(name, value) else: # Mode 3: Dynamic deferred mode - unbound from context, bound from self bound_fields = getattr(self, "_bound_fields", set()) @@ -476,15 +693,15 @@ def resolve_callable_model(name, value, store): if name in bound_fields: # Bound at construction - get from self value = getattr(self, name) - if isinstance(value, CallableModel): - fn_kwargs[name] = resolve_callable_model(name, value, store) - else: - fn_kwargs[name] = value + fn_kwargs[name] = _resolve_field(name, value) else: # Unbound - get from context fn_kwargs[name] = getattr(context, name) - return fn(**fn_kwargs) + raw_result = fn(**fn_kwargs) + if auto_wrap_result: + return GenericResult(value=raw_result) + return raw_result # Set proper signature for CallableModel validation __call__.__signature__ = inspect.Signature( @@ -492,22 +709,24 @@ def resolve_callable_model(name, value, store): inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), ], - return_annotation=return_type, + return_annotation=internal_return_type, ) return __call__ call_impl = make_call_impl() - # Apply Flow.call decorator - flow_options = { - "cacheable": cacheable, - "volatile": volatile, - "log_level": log_level, - "validate_result": validate_result, - "verbose": verbose, - } - if evaluator is not None: - flow_options["evaluator"] = evaluator + # Apply Flow.call decorator — only include options the user explicitly set + flow_options = {} + for opt_name, opt_val in [ + ("cacheable", cacheable), + 
("volatile", volatile), + ("log_level", log_level), + ("validate_result", validate_result), + ("verbose", verbose), + ("evaluator", evaluator), + ]: + if opt_val is not _UNSET: + flow_options[opt_name] = opt_val decorated_call = Flow.call(**flow_options)(call_impl) @@ -515,10 +734,12 @@ def resolve_callable_model(name, value, store): def make_deps_impl(): def __deps__(self, context) -> GraphDepList: deps = [] - # Check ALL fields for CallableModels (auto-detection) + # Check ALL fields for CallableModels/BoundModels (auto-detection) for name in model_fields: + if name in lazy_fields: + continue # Lazy deps are NOT pre-evaluated value = getattr(self, name) - if isinstance(value, CallableModel): + if isinstance(value, (CallableModel, BoundModel)): if name in dep_fields: # Explicit DepOf with transform (backwards compat) _, dep_obj = dep_fields[name] @@ -582,17 +803,20 @@ def __validate_deps__(self): namespace["__validate_deps__"] = make_dep_validator(dep_fields, context_type) + _validatable_types, _config_validators = _build_config_validators(all_param_types, dep_fields) + # Create the class using type() - GeneratedModel = type(f"_{fn.__name__}_Model", (CallableModel,), namespace) + GeneratedModel = type(f"_{fn.__name__}_Model", (_GeneratedFlowModelBase,), namespace) # Set class-level attributes after class creation (to avoid pydantic processing) GeneratedModel.__flow_model_context_type__ = context_type - GeneratedModel.__flow_model_return_type__ = return_type + GeneratedModel.__flow_model_return_type__ = internal_return_type GeneratedModel.__flow_model_func__ = fn GeneratedModel.__flow_model_dep_fields__ = dep_fields GeneratedModel.__flow_model_use_context_args__ = use_context_args GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result # Build context_schema and matched_context_type context_schema: 
Dict[str, Type] = {} @@ -617,68 +841,9 @@ def __validate_deps__(self): # Validator is created lazily to survive pickling GeneratedModel._cached_context_validator = None - # Method to get/create context validator (lazy for pickling support) - def _get_context_validator(self) -> TypeAdapter: - """Get or create the context validator. - - For dynamic deferred mode, builds schema from unbound fields. - For explicit context_args or explicit context mode, uses cached schema. - """ - cls = self.__class__ - explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) - - # For explicit context_args or explicit context mode, use cached validator - if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): - if cls._cached_context_validator is None: - if cls._context_td is not None: - cls._cached_context_validator = TypeAdapter(cls._context_td) - elif cls._context_schema: - td = TypedDict(f"{cls.__name__}Inputs", cls._context_schema) - cls._cached_context_validator = TypeAdapter(td) - else: - cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) - return cls._cached_context_validator - - # Dynamic mode: build schema from unbound fields (instance-specific) - # Cache on instance since bound_fields varies per instance - if not hasattr(self, "_instance_context_validator"): - all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) - bound_fields = getattr(self, "_bound_fields", set()) - unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} - if unbound_schema: - td = TypedDict(f"{cls.__name__}Inputs", unbound_schema) - object.__setattr__(self, "_instance_context_validator", TypeAdapter(td)) - else: - # No unbound fields - empty validator - object.__setattr__(self, "_instance_context_validator", TypeAdapter(dict)) - return self._instance_context_validator - - GeneratedModel._get_context_validator = _get_context_validator - - # Override context_type property 
after class creation - @property - def context_type_getter(self) -> Type[ContextBase]: - return self.__class__.__flow_model_context_type__ - - # Override result_type property after class creation - @property - def result_type_getter(self) -> Type[ResultBase]: - return self.__class__.__flow_model_return_type__ - - # Add .flow property for the new API - @property - def flow_getter(self) -> FlowAPI: - return FlowAPI(self) - - GeneratedModel.context_type = context_type_getter - GeneratedModel.result_type = result_type_getter - GeneratedModel.flow = flow_getter - # Register the MODEL class for serialization (needed for model_dump/_target_). # Note: We do NOT register dynamic context classes anymore - context handling # uses FlowContext + TypedDict instead, which don't need registration. - from .local_persistence import register_ccflow_import_path - register_ccflow_import_path(GeneratedModel) # Rebuild the model to process annotations properly @@ -687,6 +852,8 @@ def flow_getter(self) -> FlowAPI: # Create factory function that returns model instances @wraps(fn) def factory(**kwargs) -> GeneratedModel: + _validate_config_kwargs(kwargs, _validatable_types, _config_validators) + instance = GeneratedModel(**kwargs) # Track which fields were explicitly provided at construction # These are "bound" - everything else comes from context at runtime @@ -697,49 +864,68 @@ def factory(**kwargs) -> GeneratedModel: factory._generated_model = GeneratedModel factory.__doc__ = fn.__doc__ - # Add .deps decorator for customizing __deps__ - def deps_decorator(deps_fn): - """Decorator to customize the __deps__ method. - - Usage: - @Flow.model - def my_func(start_date: date, prices: dict) -> GenericResult[...]: - ... 
- - @my_func.deps - def _(self, context): - # Custom context transform - lookback_ctx = FlowContext( - start_date=context.start_date - timedelta(days=30), - end_date=context.end_date, - ) - return [(self.prices, [lookback_ctx])] - """ - from .callable import GraphDepList - - # Rename the function to __deps__ so Flow.deps accepts it - deps_fn.__name__ = "__deps__" - deps_fn.__qualname__ = f"{GeneratedModel.__qualname__}.__deps__" - # Set proper signature to match __call__'s context type - deps_fn.__signature__ = inspect.Signature( - parameters=[ - inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), - ], - return_annotation=GraphDepList, - ) - # Wrap with Flow.deps and replace on the class - decorated = Flow.deps(deps_fn) - GeneratedModel.__deps__ = decorated - # Mark that this model has custom deps (so _resolve_deps_and_call will call it) - GeneratedModel.__has_custom_deps__ = True - return factory # Return factory for chaining - - factory.deps = deps_decorator - return factory # Handle both @Flow.model and @Flow.model(...) syntax if func is not None: return decorator(func) return decorator + + +# ============================================================================= +# FieldExtractor — structured output field access +# ============================================================================= + + +class FieldExtractor(_FieldExtractorMixin, CallableModel): + """Extracts a named field from a source model's result. + + Created automatically by accessing an unknown attribute on a @Flow.model + instance (e.g., ``prepared.X_train``). The extractor is itself a + CallableModel, so it can be wired as a dependency to downstream models. + + When evaluated, it runs the source model and returns + ``GenericResult(value=getattr(source_result, field_name))``. + + Multiple extractors from the same source share the source model instance. 
+ If caching is enabled on the evaluator, the source is evaluated only once. + """ + + source: Any # The source CallableModel + field_name: str # The attribute to extract + + @property + def context_type(self): + if isinstance(self.source, _CallableModel): + return self.source.context_type + return ContextBase + + @property + def result_type(self): + return GenericResult + + @Flow.call + def __call__(self, context: ContextBase) -> GenericResult: + # Lazy import: _resolved_deps is a ContextVar that can't be pickled + from .callable import _resolved_deps + + store = _resolved_deps.get() + if id(self.source) in store: + result = store[id(self.source)] + else: + result = self.source(context) + if isinstance(result, GenericResult): + result = result.value + # Support both attribute access and dict key access + if isinstance(result, dict): + return GenericResult(value=result[self.field_name]) + return GenericResult(value=getattr(result, self.field_name)) + + @Flow.deps + def __deps__(self, context: ContextBase) -> GraphDepList: + if isinstance(self.source, _CallableModel): + return [(self.source, [context])] + return [] + + +register_ccflow_import_path(FieldExtractor) diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 3f613ab..61869f9 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -26,12 +26,12 @@ def test_flow_context_basic(self): assert ctx.end_date == date(2024, 1, 31) def test_flow_context_extra_fields(self): - """FlowContext stores fields in __pydantic_extra__.""" + """FlowContext exposes arbitrary fields through normal model APIs.""" ctx = FlowContext(x=1, y="hello", z=[1, 2, 3]) assert ctx.x == 1 assert ctx.y == "hello" assert ctx.z == [1, 2, 3] - assert ctx.__pydantic_extra__ == {"x": 1, "y": "hello", "z": [1, 2, 3]} + assert dict(ctx) == {"x": 1, "y": "hello", "z": [1, 2, 3]} def test_flow_context_frozen(self): """FlowContext is immutable (frozen).""" diff --git 
a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index b283a2b..b547aee 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -15,6 +15,7 @@ DepOf, Flow, GenericResult, + Lazy, ModelRegistry, ResultBase, ) @@ -599,15 +600,55 @@ def no_return(context: SimpleContext): self.assertIn("return type annotation", str(cm.exception)) - def test_non_result_return_type(self): - """Test error when return type is not ResultBase subclass.""" - with self.assertRaises(TypeError) as cm: + def test_auto_wrap_plain_return_type(self): + """Test that non-ResultBase return types are auto-wrapped in GenericResult.""" - @Flow.model - def bad_return(context: SimpleContext) -> int: - return 42 + @Flow.model + def plain_return(context: SimpleContext) -> int: + return context.value * 2 + + model = plain_return() + result = model(SimpleContext(value=5)) + self.assertIsInstance(result, GenericResult) + self.assertEqual(result.value, 10) + + def test_auto_wrap_unwrap_as_dependency(self): + """Test that auto-wrapped model used as dep delivers unwrapped value downstream. + + Auto-wrapped models have result_type=GenericResult (unparameterized). + When used as an auto-detected dep (no DepOf), the framework resolves + the GenericResult and unwraps .value for the downstream function. 
+ """ + + @Flow.model + def plain_source(context: SimpleContext) -> int: + return context.value * 3 - self.assertIn("ResultBase", str(cm.exception)) + @Flow.model + def consumer( + context: SimpleContext, + data: GenericResult[int], # Auto-detected dep, not DepOf + ) -> GenericResult[int]: + # data is auto-unwrapped to the int value by the framework + return GenericResult(value=data + 1) + + src = plain_source() + model = consumer(data=src) + result = model(SimpleContext(value=10)) + # plain_source: 10 * 3 = 30, auto-wrapped to GenericResult(value=30) + # resolve_callable_model unwraps GenericResult -> 30 + # consumer: 30 + 1 = 31 + self.assertEqual(result.value, 31) + + def test_auto_wrap_result_type_property(self): + """Test that auto-wrapped model has GenericResult as result_type.""" + + @Flow.model + def plain_return(context: SimpleContext) -> int: + return context.value + + model = plain_return() + self.assertEqual(model.result_type, GenericResult) def test_dynamic_deferred_mode(self): """Test dynamic deferred mode where what you provide at construction = bound.""" @@ -953,6 +994,271 @@ def consumer( with self.assertRaises((TypeError, ValidationError)): consumer(data=load2) + def test_config_validation_rejects_bad_type(self): + """Test that config validator rejects wrong types at construction.""" + + @Flow.model + def typed_config(context: SimpleContext, n_estimators: int = 10) -> GenericResult[int]: + return GenericResult(value=n_estimators) + + with self.assertRaises(TypeError) as cm: + typed_config(n_estimators="banana") + + self.assertIn("n_estimators", str(cm.exception)) + + def test_config_validation_accepts_callable_model(self): + """Test that config validator allows CallableModel values for any field.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: + return GenericResult(value=data) + + 
# Passing a CallableModel for an int field should not raise + src = source() + model = consumer(data=src) + self.assertIsNotNone(model) + + def test_config_validation_accepts_correct_types(self): + """Test that config validator accepts correct types.""" + + @Flow.model + def typed_config(context: SimpleContext, n: int = 10, name: str = "x") -> GenericResult[str]: + return GenericResult(value=f"{name}:{n}") + + # Should not raise + model = typed_config(n=42, name="test") + result = model(SimpleContext(value=1)) + self.assertEqual(result.value, "test:42") + + +# ============================================================================= +# BoundModel Tests +# ============================================================================= + + +class TestBoundModel(TestCase): + """Tests for BoundModel and BoundModel.flow.""" + + def test_bound_model_flow_compute(self): + """Test that bound.flow.compute() honors transforms.""" + + @Flow.model + def my_model(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + + model = my_model(x=10) + + # Create bound model with y transform + bound = model.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 2) + + # flow.compute() should go through BoundModel, applying transform + result = bound.flow.compute(y=5) + # y transform: 5 * 2 = 10, x is bound to 10 + # model: 10 + 10 = 20 + self.assertEqual(result, 20) + + def test_bound_model_flow_compute_static_transform(self): + """Test BoundModel.flow.compute() with static value transform.""" + + @Flow.model + def my_model(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x * y) + + model = my_model(x=7) + bound = model.flow.with_inputs(y=3) + + result = bound.flow.compute(y=999) # y should be overridden by transform + # y is statically bound to 3, x=7 + # 7 * 3 = 21 + self.assertEqual(result, 21) + + def test_bound_model_as_dependency(self): + """Test that BoundModel can be passed as a dependency to another model.""" + + @Flow.model + def 
source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: GenericResult[int]) -> GenericResult[int]: + return GenericResult(value=data + 1) + + src = source() + bound_src = src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + + # Pass BoundModel as a dependency + model = consumer(data=bound_src) + result = model.flow.compute(x=5) + # x transform: 5 * 2 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result, 101) + + def test_bound_model_chained_with_inputs(self): + """Test that chaining with_inputs merges transforms correctly.""" + + @Flow.model + def my_model(x: int, y: int, z: int) -> int: + return x + y + z + + model = my_model() + bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + bound2 = bound1.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 3) + + # Both transforms should be active + result = bound2.flow.compute(x=5, y=10, z=1) + # x transform: 5 * 2 = 10 + # y transform: 10 * 3 = 30 + # z from context: 1 + # 10 + 30 + 1 = 41 + self.assertEqual(result, 41) + + def test_bound_model_chained_with_inputs_override(self): + """Test that chaining with_inputs allows overriding transforms.""" + + @Flow.model + def my_model(x: int) -> int: + return x + + model = my_model() + bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + bound2 = bound1.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 10) + + # Second transform should override the first for 'x' + result = bound2.flow.compute(x=5) + self.assertEqual(result, 50) # 5 * 10, not 5 * 2 + + def test_bound_model_with_default_args(self): + """with_inputs works when the model has parameters with default values.""" + + @Flow.model + def load(start_date: str, end_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}-{end_date}" + + # Bind source at construction, leave dates for context + model = load(source="prod_db") + + # with_inputs transforms a 
context param; default-valued 'source' stays bound + lookback = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) + + result = lookback.flow.compute(start_date="2024-01-01", end_date="2024-06-30") + self.assertEqual(result, "prod_db:shifted_2024-01-01-2024-06-30") + + def test_bound_model_with_default_arg_unbound(self): + """with_inputs works when defaulted parameter is left unbound (comes from context).""" + + @Flow.model + def load(start_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}" + + # Don't bind 'source' — it keeps its default in the model, + # but in dynamic deferred mode, unbound params come from context + model = load() + + # Transform start_date; source comes from context (overriding the default) + bound = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) + + result = bound.flow.compute(start_date="2024-01-01", source="s3_bucket") + self.assertEqual(result, "s3_bucket:shifted_2024-01-01") + + def test_bound_model_default_arg_as_dependency(self): + """BoundModel with default args works correctly as a dependency.""" + + @Flow.model + def source(x: int, multiplier: int = 2) -> int: + return x * multiplier + + @Flow.model + def consumer(data: int) -> int: + return data + 1 + + src = source(multiplier=5) + bound_src = src.flow.with_inputs(x=lambda ctx: ctx.x * 10) + model = consumer(data=bound_src) + + result = model.flow.compute(x=3) + # x transform: 3 * 10 = 30 + # source: 30 * 5 (multiplier) = 150 + # consumer: 150 + 1 = 151 + self.assertEqual(result, 151) + + def test_bound_model_as_lazy_dependency(self): + """Test that BoundModel works as a Lazy dependency.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 3) + + @Flow.model + def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: + if data > 100: + return GenericResult(value=data) + return GenericResult(value=slow()) + + src = source() + bound_src = 
src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) + 10) + + # Use BoundModel as lazy dependency + model = consumer(data=5, slow=bound_src) + result = model.flow.compute(x=7) + # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 + self.assertEqual(result, 51) + + +# ============================================================================= +# PEP 563 (from __future__ import annotations) Compatibility Tests +# ============================================================================= + +# These functions are defined at module level to simulate realistic usage. +# Note: We can't use `from __future__ import annotations` at module level +# since it would affect ALL annotations in this file. Instead, we test +# that the annotation resolution code handles string annotations. + + +class TestPEP563Annotations(TestCase): + """Test that Flow.model handles string annotations (PEP 563).""" + + def test_string_annotation_lazy_resolved(self): + """Test that Lazy annotations work even when passed through get_type_hints. + + This verifies the fix for from __future__ import annotations by + confirming the annotation resolution pipeline processes Lazy correctly. 
+ """ + # Verify _extract_lazy handles real type objects (resolved by get_type_hints) + from ccflow.flow_model import _extract_lazy + + lazy_int = Lazy[int] + unwrapped, is_lazy = _extract_lazy(lazy_int) + self.assertTrue(is_lazy) + self.assertEqual(unwrapped, int) + + def test_string_annotation_return_type_resolved(self): + """Test that string return type annotations are resolved correctly.""" + + @Flow.model + def model_func(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=42) + + # If annotation resolution works, this should create successfully + model = model_func() + self.assertEqual(model.result_type, GenericResult[int]) + + def test_auto_wrap_with_resolved_annotations(self): + """Test that auto-wrap works with properly resolved type annotations.""" + + @Flow.model + def plain_model(value: int) -> int: + return value * 2 + + model = plain_model() + result = model.flow.compute(value=5) + self.assertEqual(result, 10) + self.assertEqual(model.result_type, GenericResult) + # ============================================================================= # Hydra Integration Tests @@ -1555,6 +1861,496 @@ def __deps__(self, context: SimpleContext): self.assertEqual(call_counts["class_model"], 1) +# ============================================================================= +# Lazy[T] Type Annotation Tests +# ============================================================================= + + +class TestLazyTypeAnnotation(TestCase): + """Tests for Lazy[T] type annotation (deferred/conditional evaluation).""" + + def test_lazy_type_annotation_basic(self): + """Lazy[T] param receives a thunk (zero-arg callable). + + The thunk unwraps GenericResult.value, so calling thunk() returns + the inner value (e.g., int), not the GenericResult wrapper. 
+ """ + from ccflow import Lazy + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # data() returns the unwrapped value (int) + resolved = data() + return GenericResult(value=resolved + 1) + + src = source() + model = consumer(data=src) + result = model(SimpleContext(value=5)) + + # source: 5 * 10 = 50, consumer: 50 + 1 = 51 + self.assertEqual(result.value, 51) + + def test_lazy_conditional_evaluation(self): + """Mirror the smart_training example: lazy dep only evaluated if needed. + + Note: Non-lazy CallableModel deps are auto-resolved and their .value is + unwrapped by the framework (auto-detected dep resolution). So 'fast' + receives the unwrapped int, while 'slow' receives a thunk that returns + the unwrapped value (GenericResult.value) when called. + """ + from ccflow import Lazy + + call_counts = {"fast": 0, "slow": 0} + + @Flow.model + def fast_path(context: SimpleContext) -> GenericResult[int]: + call_counts["fast"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def slow_path(context: SimpleContext) -> GenericResult[int]: + call_counts["slow"] += 1 + return GenericResult(value=context.value * 100) + + @Flow.model + def smart_selector( + context: SimpleContext, + fast: GenericResult[int], # Auto-resolved: receives unwrapped int + slow: Lazy[GenericResult[int]], # Lazy: receives thunk returning unwrapped value + threshold: int = 10, + ) -> GenericResult[int]: + # fast is auto-unwrapped to the int value by the framework + if fast > threshold: + return GenericResult(value=fast) + else: + return GenericResult(value=slow()) + + fast = fast_path() + slow = slow_path() + + # Case 1: fast path sufficient (value > threshold) + model = smart_selector(fast=fast, slow=slow, threshold=10) + result = model(SimpleContext(value=20)) + 
self.assertEqual(result.value, 20) + self.assertEqual(call_counts["fast"], 1) + self.assertEqual(call_counts["slow"], 0) # Never called! + + # Case 2: fast path insufficient (value <= threshold), slow triggered + call_counts["fast"] = 0 + model2 = smart_selector(fast=fast, slow=slow, threshold=100) + result2 = model2(SimpleContext(value=5)) + self.assertEqual(result2.value, 500) # 5 * 100 + self.assertEqual(call_counts["fast"], 1) + self.assertEqual(call_counts["slow"], 1) + + def test_lazy_thunk_caches_result(self): + """Repeated calls to a thunk return the same value without re-evaluation.""" + from ccflow import Lazy + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # Call thunk multiple times — returns the unwrapped int + val1 = data() + val2 = data() + val3 = data() + self.assertEqual(val1, val2) + self.assertEqual(val2, val3) + return GenericResult(value=val1) + + src = source() + model = consumer(data=src) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 50) + self.assertEqual(call_counts["source"], 1) # Called only once despite 3 thunk() calls + + def test_lazy_with_direct_value(self): + """Pre-computed (non-CallableModel) value wrapped in trivial thunk.""" + from ccflow import Lazy + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[int], + ) -> GenericResult[int]: + # data is a thunk even though the underlying value is a plain int + return GenericResult(value=data() * 2) + + model = consumer(data=42) + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, 84) + + def test_lazy_dep_excluded_from_deps(self): + """__deps__ does NOT include lazy dependencies.""" + from ccflow import Lazy + + @Flow.model + def eager_source(context: SimpleContext) 
-> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def lazy_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + eager: GenericResult[int], # Auto-resolved, unwrapped to int + lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value + ) -> GenericResult[int]: + return GenericResult(value=eager + lazy_dep()) + + eager = eager_source() + lazy = lazy_source() + model = consumer(eager=eager, lazy_dep=lazy) + + ctx = SimpleContext(value=5) + deps = model.__deps__(ctx) + + # Only eager dep should be in __deps__ + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], eager) + + def test_lazy_eager_dep_still_pre_evaluated(self): + """Non-lazy deps are still eagerly resolved via __deps__.""" + from ccflow import Lazy + + call_counts = {"eager": 0, "lazy": 0} + + @Flow.model + def eager_source(context: SimpleContext) -> GenericResult[int]: + call_counts["eager"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def lazy_source(context: SimpleContext) -> GenericResult[int]: + call_counts["lazy"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + eager: GenericResult[int], # Auto-resolved, unwrapped to int + lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value + ) -> GenericResult[int]: + # eager is auto-unwrapped to int, lazy_dep() returns unwrapped value + return GenericResult(value=eager + lazy_dep()) + + model = consumer(eager=eager_source(), lazy_dep=lazy_source()) + result = model(SimpleContext(value=5)) + + self.assertEqual(result.value, 55) # 5 + 50 + self.assertEqual(call_counts["eager"], 1) + self.assertEqual(call_counts["lazy"], 1) + + def test_lazy_in_dynamic_deferred_mode(self): + """Lazy[T] works in dynamic deferred mode (no context_args).""" + from ccflow import FlowContext, Lazy + + call_counts = {"source": 
0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + value: int, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + if value > 10: + return GenericResult(value=value) + return GenericResult(value=data()) # data() returns unwrapped int + + # value comes from context, data is bound at construction + model = consumer(data=source()) + result = model(FlowContext(value=20)) # value > 10, lazy not called + self.assertEqual(result.value, 20) + self.assertEqual(call_counts["source"], 0) + + def test_lazy_in_context_args_mode(self): + """Lazy[T] works with explicit context_args.""" + from ccflow import FlowContext, Lazy + + @Flow.model(context_args=["x"]) + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model(context_args=["x"]) + def consumer( + x: int, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=x + data()) # data() returns unwrapped int + + model = consumer(data=source()) + result = model(FlowContext(x=5)) + self.assertEqual(result.value, 55) # 5 + 50 + + def test_lazy_never_evaluated_if_not_called(self): + """If thunk is never called, the dependency is never evaluated.""" + from ccflow import Lazy + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # Never call data() + return GenericResult(value=42) + + model = consumer(data=source()) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 42) + self.assertEqual(call_counts["source"], 0) + + def test_lazy_with_depof(self): + """Lazy[DepOf[...]] works: lazy dep with explicit DepOf annotation.""" + from ccflow import Lazy + + 
@Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[DepOf[..., GenericResult[int]]], + ) -> GenericResult[int]: + return GenericResult(value=data() + 1) # data() returns unwrapped int + + src = source() + model = consumer(data=src) + + # Lazy dep should NOT be in __deps__ + deps = model.__deps__(SimpleContext(value=5)) + self.assertEqual(len(deps), 0) + + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 51) # 50 + 1 + + +# ============================================================================= +# FieldExtractor Tests (Structured Output Field Access) +# ============================================================================= + + +class TestFieldExtractor(TestCase): + """Tests for structured output field access (prepared.X_train pattern).""" + + def test_field_extraction_basic(self): + """Accessing unknown attr on @Flow.model instance returns FieldExtractor.""" + from ccflow.flow_model import FieldExtractor + + @Flow.model + def prepare(context: SimpleContext, factor: int = 2) -> GenericResult[dict]: + return GenericResult(value={"X_train": context.value * factor, "X_test": context.value}) + + model = prepare(factor=3) + extractor = model.X_train + + self.assertIsInstance(extractor, FieldExtractor) + self.assertIs(extractor.source, model) + self.assertEqual(extractor.field_name, "X_train") + + def test_field_extraction_evaluates_correctly(self): + """FieldExtractor runs source and extracts the named field.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"X_train": [1, 2, 3], "y_train": [4, 5, 6]}) + + model = prepare() + x_train = model.X_train + + result = x_train(SimpleContext(value=0)) + self.assertEqual(result.value, [1, 2, 3]) + + def test_field_extraction_as_dependency(self): + """FieldExtractor wired as a dep to a downstream 
model. + + Note: FieldExtractors are CallableModels, so they're auto-detected as deps + and auto-unwrapped (GenericResult.value). The downstream function receives + the raw extracted value, not a GenericResult wrapper. + """ + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + v = context.value + return GenericResult(value={"X_train": [v, v * 2], "y_train": [v * 10]}) + + @Flow.model + def train(context: SimpleContext, X: list, y: list) -> GenericResult[int]: + # X and y are auto-unwrapped to the raw list values + return GenericResult(value=sum(X) + sum(y)) + + prepared = prepare() + model = train(X=prepared.X_train, y=prepared.y_train) + + result = model(SimpleContext(value=5)) + # X_train = [5, 10], y_train = [50] + # sum(X) + sum(y) = 15 + 50 = 65 + self.assertEqual(result.value, 65) + + def test_field_extraction_multiple_from_same_source(self): + """Multiple extractors from same source share the source instance.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"a": 1, "b": 2, "c": 3}) + + model = prepare() + ext_a = model.a + ext_b = model.b + ext_c = model.c + + # All should reference the same source + self.assertIs(ext_a.source, model) + self.assertIs(ext_b.source, model) + self.assertIs(ext_c.source, model) + + # All should evaluate correctly + ctx = SimpleContext(value=0) + self.assertEqual(ext_a(ctx).value, 1) + self.assertEqual(ext_b(ctx).value, 2) + self.assertEqual(ext_c(ctx).value, 3) + + def test_field_extraction_nested(self): + """Chained extraction (result.a.b) creates nested FieldExtractors.""" + from ccflow.flow_model import FieldExtractor + + class Nested: + def __init__(self): + self.inner_val = 42 + + @Flow.model + def produce(context: SimpleContext) -> GenericResult: + return GenericResult(value={"nested": Nested()}) + + model = produce() + nested_extractor = model.nested + inner_extractor = nested_extractor.inner_val + + 
self.assertIsInstance(nested_extractor, FieldExtractor) + self.assertIsInstance(inner_extractor, FieldExtractor) + + result = inner_extractor(SimpleContext(value=0)) + self.assertEqual(result.value, 42) + + def test_field_extraction_context_type_inherited(self): + """FieldExtractor inherits context_type from source.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.x + + self.assertEqual(extractor.context_type, SimpleContext) + + def test_field_extraction_nonexistent_field_runtime_error(self): + """Non-existent field raises error at evaluation time, not construction. + + For dict results, raises KeyError. For object results, raises AttributeError. + """ + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.nonexistent # No error at construction + + # Error at evaluation time (KeyError for dicts, AttributeError for objects) + with self.assertRaises((KeyError, AttributeError)): + extractor(SimpleContext(value=0)) + + def test_field_extraction_pydantic_fields_not_intercepted(self): + """Accessing real pydantic fields returns the field value, NOT an extractor.""" + from ccflow.flow_model import FieldExtractor + + @Flow.model + def model_with_fields(context: SimpleContext, multiplier: int = 5) -> GenericResult[int]: + return GenericResult(value=context.value * multiplier) + + model = model_with_fields(multiplier=10) + + # 'multiplier' is a real pydantic field — should return the value, not a FieldExtractor + self.assertEqual(model.multiplier, 10) + self.assertNotIsInstance(model.multiplier, FieldExtractor) + + # 'meta' is inherited from CallableModel — should also not be intercepted + self.assertNotIsInstance(model.meta, FieldExtractor) + + def test_field_extraction_with_context_args(self): + """FieldExtractor works with context_args mode models.""" + from 
ccflow import FlowContext + + @Flow.model(context_args=["x"]) + def prepare(x: int) -> GenericResult[dict]: + return GenericResult(value={"doubled": x * 2, "tripled": x * 3}) + + model = prepare() + doubled = model.doubled + + result = doubled(FlowContext(x=5)) + self.assertEqual(result.value, 10) + + def test_field_extraction_has_flow_property(self): + """FieldExtractor has .flow property (inherits from CallableModel).""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.x + + self.assertTrue(hasattr(extractor, "flow")) + + def test_field_extraction_deps(self): + """FieldExtractor.__deps__ returns the source as a dependency.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.x + + ctx = SimpleContext(value=0) + deps = extractor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], model) + self.assertEqual(deps[0][1], [ctx]) + + if __name__ == "__main__": import unittest From 097ae6220cb7b2c70d897c806ab5caf885015199 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 05:52:31 -0400 Subject: [PATCH 09/17] Update docs for @Flow.model Signed-off-by: Nijat Khanbabayev --- docs/design/flow_model_design.md | 141 ++++++++++++++++++++++--------- docs/wiki/Key-Features.md | 110 ++++++++++++++++++------ 2 files changed, 183 insertions(+), 68 deletions(-) diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 76d0eb7..909b597 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -6,38 +6,55 @@ This document describes the `@Flow.model` decorator and `DepOf` annotation syste **Key features:** - `@Flow.model` - Decorator that generates `CallableModel` classes from plain functions +- `FlowContext` - Universal context carrier for unpacked/deferred execution 
+- `model.flow.compute(...)` / `model.flow.with_inputs(...)` - Deferred execution helpers - `DepOf[ContextType, ResultType]` - Type annotation for dependency fields +- `Lazy[T]` - Mark a dependency for lazy, on-demand evaluation +- `FieldExtractor` - Access structured outputs via attribute access on generated models - `resolve()` - Function to access resolved dependency values in class-based models ## Quick Start -### Pattern 1: `@Flow.model` (Recommended for Simple Cases) +### Pattern 1: `@Flow.model` (Recommended for Declarative Cases) ```python from datetime import date, timedelta from typing import Annotated -from ccflow import Flow, DateRangeContext, GenericResult, DepOf +from ccflow import Flow, DateRangeContext, GenericResult, Dep, DepOf + + +def previous_window(ctx: DateRangeContext) -> DateRangeContext: + window = ctx.end_date - ctx.start_date + return ctx.model_copy( + update={ + "start_date": ctx.start_date - window - timedelta(days=1), + "end_date": ctx.start_date - timedelta(days=1), + } + ) @Flow.model -def load_records(context: DateRangeContext, source: str) -> GenericResult[dict]: - return GenericResult(value={"count": 100, "date": str(context.start_date)}) +def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: + return GenericResult(value=125.0) @Flow.model -def compute_stats( +def revenue_growth( context: DateRangeContext, - records: DepOf[..., GenericResult[dict]], # Dependency field -) -> GenericResult[float]: - # records is already resolved - just use it directly - return GenericResult(value=records.value["count"] * 0.05) - -# Build pipeline -loader = load_records(source="main_db") -stats = compute_stats(records=loader) + current: DepOf[..., GenericResult[float]], + previous: Annotated[GenericResult[float], Dep(transform=previous_window)], +) -> GenericResult[dict]: + growth = (current.value - previous.value) / previous.value + return GenericResult(value={"as_of": context.end_date, "growth": growth}) + +# Build 
pipeline. The same upstream model is reused twice: +# - once with the original context +# - once with a fixed lookback transform +revenue = load_revenue(region="us") +growth = revenue_growth(current=revenue, previous=revenue) # Execute ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = stats(ctx) +result = growth(ctx) ``` ### Pattern 2: Class-Based (For Complex Cases) @@ -50,17 +67,17 @@ from datetime import timedelta from ccflow import CallableModel, DateRangeContext, Flow, GenericResult, DepOf from ccflow.callable import resolve # Import resolve for class-based models -class AggregateWithWindow(CallableModel): - """Aggregate records with configurable lookback window.""" +class RevenueAverageWithWindow(CallableModel): + """Aggregate revenue with a configurable lookback window.""" - records: DepOf[..., GenericResult[dict]] + revenue: DepOf[..., GenericResult[float]] window: int = 7 # Configurable instance field @Flow.call def __call__(self, context: DateRangeContext) -> GenericResult[float]: # Use resolve() to get the resolved value - records = resolve(self.records) - return GenericResult(value=records.value["count"] / self.window) + revenue = resolve(self.revenue) + return GenericResult(value=revenue.value / self.window) @Flow.deps def __deps__(self, context: DateRangeContext): @@ -68,22 +85,22 @@ class AggregateWithWindow(CallableModel): lookback_ctx = context.model_copy( update={"start_date": context.start_date - timedelta(days=self.window)} ) - return [(self.records, [lookback_ctx])] + return [(self.revenue, [lookback_ctx])] # Usage - different window sizes, same source -loader = load_records(source="main_db") -agg_7 = AggregateWithWindow(records=loader, window=7) -agg_30 = AggregateWithWindow(records=loader, window=30) +loader = load_revenue(region="us") +avg_7 = RevenueAverageWithWindow(revenue=loader, window=7) +avg_30 = RevenueAverageWithWindow(revenue=loader, window=30) ``` ## When to Use Which Pattern -| Use 
`@Flow.model` when... | Use Class-Based when... | -|--------------------------------|--------------------------------------| -| Simple transformations | Transforms depend on instance fields | -| Fixed context transforms | Need `self.field` in `__deps__` | -| Less boilerplate is priority | Full control over resolution | -| No custom `__deps__` logic | Complex dependency patterns | +| Use `@Flow.model` when... | Use Class-Based when... | +|--------------------------------|---------------------------------------| +| The node still reads like a normal function | The main value is custom graph logic | +| Transforms are fixed/declarative | Transforms depend on instance fields | +| Less boilerplate is priority | You need full control over `__deps__` | +| Dependency wiring fits in the signature | Dependency behavior deserves its own class | ## Core Concepts @@ -104,6 +121,17 @@ data: DepOf[DateRangeContext, GenericResult[dict]] data: Annotated[Union[GenericResult[dict], CallableModel], Dep()] ``` +For `@Flow.model`, plain non-`DepOf` parameters can also be populated with a +`CallableModel` instance. That lets callers either inject a concrete value or +splice in an upstream computation for the same parameter. Use `Dep`/`DepOf` +when you need explicit dependency metadata such as context transforms or +context-type validation. + +That means `DepOf` inside `@Flow.model` is most compelling when the function is +still doing real work and the dependency relationship is simple. If the node is +mostly a vessel for custom dependency graph wiring, a hand-written +`CallableModel` is usually clearer. + ### `Dep(transform=..., context_type=...)` For transforms, use the full `Annotated` form: @@ -158,12 +186,12 @@ resolved = resolve(self.data) # Type: GenericResult[int] 1. User calls `model(context)` 2. Generated `__call__` invokes `_resolve_deps_and_call()` -3. For each `DepOf` field containing a `CallableModel`: +3. 
For each dependency-bearing field containing a `CallableModel`: - Apply transform (if any) - Call the dependency - Store resolved value in context variable -4. Generated `__call__` retrieves resolved values via `resolve()` -5. Original function receives resolved values as arguments +4. Generated `__call__` reads the resolved values from the dependency store +5. Original function receives resolved values directly as normal function arguments ### Class-Based Resolution Flow @@ -171,6 +199,7 @@ resolved = resolve(self.data) # Type: GenericResult[int] 2. `_resolve_deps_and_call()` runs 3. For each `DepOf` field containing a `CallableModel`: - Check `__deps__` for custom transforms + - If not listed in `__deps__`, fall back to the field's `Dep(...)` transform (or the original context) - Call the dependency - Store resolved value in context variable 4. User's `__call__` accesses values via `resolve(self.field)` @@ -211,14 +240,18 @@ resolved = resolve(self.data) # Type: GenericResult[int] - Keeps top-level namespace clean - Users who need it can find it easily -### Decision 4: No Auto-Wrapping Return Values +### Decision 4: Auto-Wrap Plain Return Values -**What we chose:** Functions must explicitly return `ResultBase` subclass. +**What we chose:** If the function's declared return type is not a `ResultBase` +subclass, the generated model wraps the returned value in `GenericResult`. **Why:** -- Type annotations remain honest -- Consistent with existing `CallableModel` contract -- `GenericResult(value=x)` is minimal overhead +- Reduces boilerplate for simple scalar / container-returning functions +- Preserves the `CallableModel` contract that runtime results are `ResultBase` +- Still allows explicit `ResultBase` subclasses when you want a precise result type + +**Trade-off:** The original Python function may be annotated with a plain value +type while the generated model's runtime `result_type` is `GenericResult`. 
### Decision 5: Generated Classes Are Real CallableModels @@ -290,23 +323,47 @@ Users need to remember: - `@Flow.model`: Use dependency values directly as function arguments - Class-based: Use `resolve(self.field)` to access values -### Limitation: `__deps__` Still Required for Class-Based +### Limitation: Custom `__deps__` Is Only Needed for Custom Graph Logic -Even without transforms, class-based models need `__deps__`: +Class-based models do not need a custom `__deps__` override when the default +field-level `Dep(...)` behavior is sufficient. Override `__deps__` only when +you need instance-dependent transforms or a custom dependency graph: ```python class Consumer(CallableModel): data: DepOf[..., GenericResult[int]] + @Flow.call + def __call__(self, context): + return GenericResult(value=resolve(self.data).value) +``` + +If you do need to use instance fields in the transform, then `__deps__` is the +right place to do it: + +```python +class WindowedConsumer(CallableModel): + data: DepOf[..., GenericResult[int]] + window: int = 7 + @Flow.call def __call__(self, context): return GenericResult(value=resolve(self.data).value) @Flow.deps def __deps__(self, context): - return [(self.data, [context])] # Boilerplate, but required + shifted = context.model_copy(update={"value": context.value + self.window}) + return [(self.data, [shifted])] ``` +### Limitation: `context_args` Type Matching Is Best-Effort + +When you use `context_args=[...]`, the framework validates those fields via a +runtime `TypedDict` schema. It only maps to a concrete built-in context type in +special cases such as `DateRangeContext`. Otherwise the generated model's +`context_type` is `FlowContext`, a universal frozen carrier for the validated +context values. + ## Complete Example: Multi-Stage Pipeline ```python @@ -397,6 +454,10 @@ def my_function(context: ContextType, ...) -> ResultType: ... 
``` +If the function is annotated with a plain value type instead of a `ResultBase` +subclass, the generated model will wrap the returned value in `GenericResult` +at runtime. + ### `DepOf[ContextType, ResultType]` ```python diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index a89d8f8..f73ac6b 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -49,46 +49,75 @@ result = loader(ctx) Use `Dep()` or `DepOf` to mark parameters that accept other `CallableModel`s as dependencies. The framework automatically resolves the dependency graph. -> **Tip:** If your function doesn't use the context directly (only passes it to dependencies), use `_` as the parameter name to signal this: `def my_func(_: DateRangeContext, data: DepOf[..., ResultType])`. This is a Python convention for intentionally unused parameters. +For `@Flow.model`, regular parameters can also accept a `CallableModel` value at +construction time. This lets you either inject a literal value or splice in an +upstream computation for the same parameter. Use `Dep`/`DepOf` when you need +context transforms or explicit dependency metadata. + +> **Rule of thumb:** `@Flow.model` works best when the dependency wiring is declarative and local to the signature. If the main point of the node is custom graph logic or transforms that depend on instance fields, use a class-based `CallableModel` instead. 
```python from datetime import date, timedelta from typing import Annotated from ccflow import Flow, GenericResult, DateRangeContext, Dep, DepOf -@Flow.model -def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: - return GenericResult(value={"records": [1, 2, 3]}) +def previous_window(ctx: DateRangeContext) -> DateRangeContext: + window = ctx.end_date - ctx.start_date + return ctx.model_copy( + update={ + "start_date": ctx.start_date - window - timedelta(days=1), + "end_date": ctx.start_date - timedelta(days=1), + } + ) @Flow.model -def transform_data( - _: DateRangeContext, # Context passed to dependency, not used directly - raw_data: Annotated[GenericResult[dict], Dep( - # Transform context to fetch one extra day for lookback - transform=lambda ctx: ctx.model_copy(update={ - "start_date": ctx.start_date - timedelta(days=1) - }) - )] -) -> GenericResult[dict]: - # raw_data.value contains the resolved result from load_data - return GenericResult(value={"transformed": raw_data.value["records"]}) +def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: + # Pretend this queries a warehouse + return GenericResult(value=125.0) -# Or use DepOf shorthand (no transform needed): @Flow.model -def aggregate_data( - _: DateRangeContext, # Context passed to dependency, not used directly - transformed: DepOf[..., GenericResult[dict]] # Shorthand for Annotated[T, Dep()] +def revenue_growth( + context: DateRangeContext, + current: DepOf[..., GenericResult[float]], + previous: Annotated[GenericResult[float], Dep(transform=previous_window)], ) -> GenericResult[dict]: - return GenericResult(value={"count": len(transformed.value["transformed"])}) + growth = (current.value - previous.value) / previous.value + return GenericResult(value={"as_of": context.end_date, "growth": growth}) -# Build the pipeline -data = load_data(source="my_database") -transformed = transform_data(raw_data=data) -aggregated = 
aggregate_data(transformed=transformed) +# Build the pipeline. The same loader is reused with two contexts: +# - current window: original context +# - previous window: transformed via Dep(transform=...) +revenue = load_revenue(region="us") +growth = revenue_growth(current=revenue, previous=revenue) -# Execute - dependencies are automatically resolved ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = aggregated(ctx) +result = growth(ctx) +``` + +`DepOf` is also useful when you want the same parameter to accept either an +upstream model or a precomputed value: + +```python +from ccflow import DateRangeContext, DepOf, Flow, GenericResult + +@Flow.model +def load_signal(context: DateRangeContext, source: str) -> GenericResult[float]: + return GenericResult(value=0.87) + +@Flow.model +def publish_signal( + context: DateRangeContext, + signal: DepOf[..., GenericResult[float]], + threshold: float = 0.8, +) -> GenericResult[dict]: + return GenericResult(value={ + "as_of": context.end_date, + "signal": signal.value, + "go_live": signal.value >= threshold, + }) + +live = publish_signal(signal=load_signal(source="prod")) +override = publish_signal(signal=GenericResult(value=0.95), threshold=0.9) ``` **Hydra/YAML Configuration:** @@ -126,7 +155,7 @@ from ccflow import Flow, GenericResult, DateRangeContext def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date} to {end_date}") -# The decorator infers DateRangeContext from the parameter types +# The decorator matches common built-in context types when possible loader = load_data(source="my_database") assert loader.context_type == DateRangeContext @@ -135,7 +164,32 @@ ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" ``` -The `context_args` parameter specifies which function parameters should be extracted from the context. 
The framework automatically determines the context type based on the parameter type annotations. +The `context_args` parameter specifies which function parameters should be extracted from the context. Those fields are validated through a runtime schema built from the parameter annotations. For well-known shapes such as `start_date` / `end_date`, the generated model uses a concrete built-in context type like `DateRangeContext`; otherwise it uses `FlowContext`, a universal frozen carrier for the validated fields. + +**Deferred Execution Helpers:** + +Generated models also expose a `.flow` helper namespace: + +```python +from ccflow import Flow, GenericResult + +@Flow.model +def add(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + +model = add(x=10) + +# Validate and execute by passing context fields as kwargs +assert model.flow.compute(y=5) == 15 + +# Derive a new model by transforming context inputs +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5) == 20 +``` + +If a `@Flow.model` function returns a plain value instead of a `ResultBase` +subclass, the generated model automatically wraps that value in `GenericResult` +at runtime so it still behaves like a normal `CallableModel`. 
## Model Registry From 3d26896c245f2568b3cbd1aa8f3b7da74699218c Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 15:36:19 -0400 Subject: [PATCH 10/17] Clean up, ty check @Flow.model, add test Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 1 - ccflow/callable.py | 219 +----- ccflow/dep.py | 278 -------- ccflow/flow_model.py | 222 +++--- ccflow/tests/config/conf_flow.yaml | 2 +- ccflow/tests/test_flow_model.py | 927 ++++---------------------- ccflow/tests/test_flow_model_hydra.py | 16 +- docs/design/flow_model_design.md | 552 ++++----------- docs/wiki/Key-Features.md | 200 +++--- examples/flow_model_example.py | 246 ++----- 10 files changed, 565 insertions(+), 2098 deletions(-) delete mode 100644 ccflow/dep.py diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 4dbe143..1bb69fe 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -10,7 +10,6 @@ from .compose import * from .callable import * from .context import * -from .dep import * from .enums import Enum from .flow_model import * from .global_state import * diff --git a/ccflow/callable.py b/ccflow/callable.py index d3b22e4..fd849c5 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,7 +14,6 @@ import abc import inspect import logging -from contextvars import ContextVar from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -29,7 +28,6 @@ ResultBase, ResultType, ) -from .dep import Dep, extract_dep from .local_persistence import create_ccflow_model from .validators import str_to_log_level @@ -48,8 +46,6 @@ "EvaluatorBase", "Evaluator", "WrapperModel", - # Note: resolve() is intentionally not in __all__ to avoid namespace pollution. 
- # Users who need it can import explicitly: from ccflow.callable import resolve ) log = logging.getLogger(__name__) @@ -199,164 +195,6 @@ def _get_logging_evaluator(log_level): return LoggingEvaluator(log_level=log_level) -def _get_dep_fields(model_class) -> Dict[str, Dep]: - """Analyze class fields to find Dep-annotated fields. - - Returns a dict mapping field name to Dep instance for fields that need resolution. - """ - dep_fields = {} - - # Get type hints from the class - hints = {} - for cls in model_class.__mro__: - if hasattr(cls, "__annotations__"): - for name, annotation in cls.__annotations__.items(): - if name not in hints: # Don't override child class annotations - hints[name] = annotation - - for name, annotation in hints.items(): - base_type, dep = extract_dep(annotation) - if dep is not None: - dep_fields[name] = dep - - return dep_fields - - -def _wrap_with_dep_resolution(fn): - """Wrap a function to auto-resolve DepOf fields before calling. - - For each Dep-annotated field on the model that contains a CallableModel, - resolves it using __deps__ and temporarily sets the resolved value on self. - - Note: This wrapper is only applied at runtime when the function is called, - not during decoration. This avoids issues with functools.wraps flattening - the __wrapped__ chain. - - Args: - fn: The original function - - Returns: - The original function unchanged - dep resolution happens at the call site - """ - # Don't modify the function - dep resolution is handled in ModelEvaluationContext - return fn - - -# Context variable for storing resolved dependency values during __call__ -# Maps id(callable_model) -> resolved_value -_resolved_deps: ContextVar[Dict[int, Any]] = ContextVar("resolved_deps", default={}) - -# TypeVar for resolve() function to enable proper type inference -_T = TypeVar("_T") - - -def resolve(dep: Union[_T, "_CallableModel"]) -> _T: - """Access the resolved value of a DepOf dependency during __call__. 
- - This function is used inside a CallableModel's __call__ method to get - the resolved value of a dependency field. It provides proper type inference - - if the field is `DepOf[..., GenericResult[int]]`, this returns `GenericResult[int]`. - - Args: - dep: The dependency field value (either a CallableModel or already-resolved value) - - Returns: - The resolved value. If dep is already a resolved value (not a CallableModel), - returns it unchanged. - - Raises: - RuntimeError: If called outside of __call__ or if the dependency wasn't resolved. - - Example: - class MyModel(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: MyContext) -> GenericResult[int]: - # resolve() provides proper type inference - data = resolve(self.data) # type: GenericResult[int] - return GenericResult(value=data.value + 1) - """ - # If it's not a CallableModel, it's already a resolved value - pass through - if not isinstance(dep, _CallableModel): - return dep # type: ignore[return-value] - - # Look up in context var - store = _resolved_deps.get() - dep_id = id(dep) - if dep_id not in store: - raise RuntimeError( - "resolve() can only be used inside __call__ for DepOf fields. Make sure the field is annotated with DepOf and contains a CallableModel." - ) - return store[dep_id] - - -def _resolve_deps_and_call(model, context, fn): - """Resolve DepOf fields and call the function. - - This is called from ModelEvaluationContext.__call__ to handle dep resolution. - Resolved values are stored in a context variable and accessed via resolve(). 
- - Args: - model: The CallableModel instance - context: The context to pass to the function - fn: The function to call - - Returns: - The result of calling fn(model, context) - """ - # Don't resolve deps for __deps__ method - if fn.__name__ == "__deps__": - return fn(model, context) - - # Get Dep-annotated fields for this model class - dep_fields = _get_dep_fields(model.__class__) - - if not dep_fields: - return fn(model, context) - - # Get dependencies from __deps__ - deps_result = model.__deps__(context) - # Build a map from model instance id to (model, contexts) for lookup - dep_map = {} - for dep_model, contexts in deps_result: - dep_map[id(dep_model)] = (dep_model, contexts) - - # Resolve dependencies and store in context var - resolved_values = {} - - # Standard path: iterate over Dep-annotated fields - for field_name, dep in dep_fields.items(): - field_value = getattr(model, field_name, None) - if field_value is None: - continue - - # Check if field is a CallableModel that needs resolution - if not isinstance(field_value, _CallableModel): - continue # Already a resolved value, skip - - # Check if this field is in __deps__ (for custom transforms) - if id(field_value) in dep_map: - dep_model, contexts = dep_map[id(field_value)] - # Call dependency with the (transformed) context - resolved = dep_model(contexts[0]) if contexts else dep_model(context) - else: - # Not in __deps__, use Dep annotation transform directly - transformed_ctx = dep.apply(context) - resolved = field_value(transformed_ctx) - - resolved_values[id(field_value)] = resolved - - # Store in context var and call function - current_store = _resolved_deps.get() - new_store = {**current_store, **resolved_values} - token = _resolved_deps.set(new_store) - try: - return fn(model, context) - finally: - _resolved_deps.reset(token) - - class FlowOptions(BaseModel): """Options for Flow evaluation. 
@@ -408,9 +246,6 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase": return self._get_evaluator_from_options(options) def __call__(self, fn): - # Wrap function with dependency resolution for DepOf fields - fn = _wrap_with_dep_resolution(fn) - # Used for building a graph of model evaluation contexts without evaluating def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None): # Create the evaluation context. @@ -617,32 +452,6 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: # The generated context inherits from DateContext, so it's compatible # with infrastructure expecting DateContext instances. - Auto-Resolve Dependencies Example: - When __call__ has parameters beyond 'self' and 'context' that match field - names annotated with DepOf/Dep, those dependencies are automatically resolved - using __deps__ (if defined) or auto-generated from Dep annotations. - - class MyModel(CallableModel): - data: Annotated[GenericResult[dict], Dep(transform=my_transform)] - - @Flow.call - def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: - # data is automatically resolved - no manual calling needed - return GenericResult(value=process(data.value)) - - For transforms that need access to instance fields, define __deps__ manually: - - class MyModel(CallableModel): - data: DepOf[..., GenericResult[dict]] - window: int = 7 - - def __deps__(self, context): - # Can access self.window here - return [(self.data, [context.with_lookback(self.window)])] - - @Flow.call - def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: - return GenericResult(value=process(data.value)) """ # Extract auto_context option (not part of FlowOptions) # Can be: False, True, or a ContextBase subclass @@ -728,27 +537,10 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ return GenericResult(value=query_db(source, 
start_date, end_date)) Dependencies: - Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies: - - from ccflow import Dep, DepOf - from typing import Annotated - - @Flow.model - def compute_returns( - context: DateRangeContext, - prices: Annotated[GenericResult[pl.DataFrame], Dep( - transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) - )] - ) -> GenericResult[pl.DataFrame]: - return GenericResult(value=prices.value.pct_change()) - - # Or use DepOf shorthand for no transform: - @Flow.model - def compute_stats( - context: DateRangeContext, - data: DepOf[..., GenericResult[pl.DataFrame]] - ) -> GenericResult[pl.DataFrame]: - return GenericResult(value=data.value.describe()) + Any non-context parameter can be bound either to a literal value or + to another CallableModel. When a CallableModel is supplied, the + generated model treats it as an upstream dependency and resolves it + with the current context before calling the underlying function. Usage: # Create model instances @@ -819,8 +611,7 @@ def _context_validator(cls, values, handler, info): def __call__(self) -> ResultType: fn = getattr(self.model, self.fn) if hasattr(fn, "__wrapped__"): - # Call through _resolve_deps_and_call to handle DepOf field resolution - result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__) + result = fn.__wrapped__(self.model, self.context) # If it's a callable model, then we can validate the result if self.options.get("validate_result", True): if fn.__name__ == "__deps__": diff --git a/ccflow/dep.py b/ccflow/dep.py deleted file mode 100644 index b57261e..0000000 --- a/ccflow/dep.py +++ /dev/null @@ -1,278 +0,0 @@ -"""Dependency annotation markers for Flow.model. 
- -This module provides: -- Dep: Annotation marker for dependency parameters that can accept CallableModel -- DepOf: Shorthand for Annotated[Union[T, CallableModel], Dep()] -""" - -from typing import TYPE_CHECKING, Annotated, Callable, Optional, Type, TypeVar, Union, get_args, get_origin - -from .base import ContextBase - -if TYPE_CHECKING: - from .callable import CallableModel - -__all__ = ("Dep", "DepOf") - -T = TypeVar("T") - -# Lazy reference to CallableModel to avoid circular import -_CallableModel = None - - -def _get_callable_model(): - """Lazily import CallableModel to avoid circular imports.""" - global _CallableModel - if _CallableModel is None: - from .callable import CallableModel - - _CallableModel = CallableModel - return _CallableModel - - -class _DepOfMeta(type): - """Metaclass that makes DepOf[ContextType, ResultType] work.""" - - def __getitem__(cls, item): - if not isinstance(item, tuple) or len(item) != 2: - raise TypeError( - "DepOf requires 2 type arguments: DepOf[ContextType, ResultType]. " - "Use ... for ContextType to inherit from parent: DepOf[..., ResultType]" - ) - context_type, result_type = item - CallableModel = _get_callable_model() - - if context_type is ...: - # DepOf[..., ResultType] - inherit context from parent - return Annotated[Union[result_type, CallableModel], Dep()] - else: - # DepOf[ContextType, ResultType] - explicit context type - return Annotated[Union[result_type, CallableModel], Dep(context_type=context_type)] - - -class DepOf(metaclass=_DepOfMeta): - """ - Shorthand for Annotated[Union[ResultType, CallableModel], Dep(context_type=...)]. 
- - Follows Callable convention: DepOf[InputContext, OutputResult] - - For class fields, accepts either: - - The result type directly (pre-computed value) - - A CallableModel that produces the result type (resolved at call time) - - Usage: - # Inherit context type from parent model (most common) - data: DepOf[..., GenericResult[dict]] - - # Explicit context type validation - data: DepOf[DateRangeContext, GenericResult[dict]] - - At call time, if the field contains a CallableModel, it will be automatically - resolved using __deps__ and the resolved value will be accessible via self.field_name. - - For dependencies with transforms, define them in __deps__: - def __deps__(self, context): - transformed_ctx = context.model_copy(update={...}) - return [(self.data, [transformed_ctx])] - """ - - pass - - -def _is_compatible_type(actual: Type, expected: Type) -> bool: - """Check if actual type is compatible with expected type. - - Handles generic types like GenericResult[pl.DataFrame] where issubclass - would raise TypeError. 
- - Args: - actual: The actual type to check - expected: The expected type to match against - - Returns: - True if actual is compatible with expected - """ - # Handle None/empty types - if actual is None or expected is None: - return actual is expected - - # Get origins for generic types - actual_origin = get_origin(actual) or actual - expected_origin = get_origin(expected) or expected - - # Check if origins are compatible - try: - if not (isinstance(actual_origin, type) and isinstance(expected_origin, type)): - return False - if not issubclass(actual_origin, expected_origin): - return False - except TypeError: - # issubclass can fail for certain types - return False - - # Check generic args if present - actual_args = get_args(actual) - expected_args = get_args(expected) - - if expected_args and actual_args: - if len(actual_args) != len(expected_args): - return False - return all(_is_compatible_type(a, e) for a, e in zip(actual_args, expected_args)) - - return True - - -class Dep: - """ - Annotation marker for dependency parameters. - - Marks a parameter as accepting either the declared type or a CallableModel - that produces that type. Supports optional context transform and - construction-time type validation. 
- - Usage: - # No transform, no explicit validation (uses parent's context_type) - prices: Annotated[GenericResult[pl.DataFrame], Dep()] - - # With transform - prices: Annotated[GenericResult[pl.DataFrame], Dep( - transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) - )] - - # With explicit context_type validation - prices: Annotated[GenericResult[pl.DataFrame], Dep( - context_type=DateRangeContext, - transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) - )] - - # Cross-context dependency (transform changes context type) - sim_data: Annotated[GenericResult[pl.DataFrame], Dep( - context_type=SimulationContext, - transform=date_to_simulation_context - )] - """ - - def __init__( - self, - transform: Optional[Callable[..., ContextBase]] = None, - context_type: Optional[Type[ContextBase]] = None, - ): - """ - Args: - transform: Optional function to transform context before calling dependency. - Signature: (context) -> transformed_context - context_type: Expected context_type of the dependency CallableModel. - If None, defaults to the parent model's context_type. - Validated at construction time when a CallableModel is passed. - """ - self.transform = transform - self.context_type = context_type - - def apply(self, context: ContextBase) -> ContextBase: - """Apply the transform to a context, or return unchanged if no transform.""" - if self.transform is not None: - return self.transform(context) - return context - - def validate_dependency( - self, - value: "CallableModel", # noqa: F821 - expected_result_type: Type, - parent_context_type: Type[ContextBase], - param_name: str, - ) -> None: - """ - Validate a CallableModel dependency at construction time. 
- - Args: - value: The CallableModel being passed as a dependency - expected_result_type: The result type from the Annotated type hint - parent_context_type: The context_type of the parent model - param_name: Name of the parameter (for error messages) - - Raises: - TypeError: If context_type or result_type don't match - """ - # Import here to avoid circular import - from .callable import CallableModel - - if not isinstance(value, CallableModel): - return # Not a CallableModel, skip validation - - # Determine expected context type - expected_ctx = self.context_type if self.context_type is not None else parent_context_type - - # Validate context_type - the dependency's context_type should be compatible - # with what we'll pass to it (expected_ctx) - dep_context_type = value.context_type - try: - if not issubclass(expected_ctx, dep_context_type): - raise TypeError( - f"Dependency '{param_name}': expected context_type compatible with " - f"{dep_context_type.__name__}, but will pass {expected_ctx.__name__}" - ) - except TypeError: - # issubclass can fail for certain types, try alternate check - if expected_ctx != dep_context_type: - raise TypeError(f"Dependency '{param_name}': context_type mismatch - expected {dep_context_type}, got {expected_ctx}") - - # Validate result_type using the generic-safe comparison - # If expected_result_type is Union[T, CallableModel], extract T for validation - dep_result_type = value.result_type - actual_expected_type = expected_result_type - - # Handle Union[T, CallableModel] from DepOf expansion - if get_origin(expected_result_type) is Union: - union_args = get_args(expected_result_type) - # Filter out CallableModel from the union - non_callable_types = [t for t in union_args if t is not CallableModel] - if non_callable_types: - actual_expected_type = non_callable_types[0] - - if not _is_compatible_type(dep_result_type, actual_expected_type): - raise TypeError( - f"Dependency '{param_name}': expected result_type compatible with " - 
f"{actual_expected_type}, but got CallableModel with result_type {dep_result_type}" - ) - - def __repr__(self): - parts = [] - if self.transform is not None: - parts.append(f"transform={self.transform}") - if self.context_type is not None: - parts.append(f"context_type={self.context_type.__name__}") - return f"Dep({', '.join(parts)})" if parts else "Dep()" - - def __eq__(self, other): - if not isinstance(other, Dep): - return False - return self.transform == other.transform and self.context_type == other.context_type - - def __hash__(self): - # Make Dep hashable for use in sets/dicts - return hash((id(self.transform), self.context_type)) - - -def extract_dep(annotation) -> tuple: - """Extract Dep from Annotated[T, Dep(...)] or DepOf[ContextType, T]. - - When multiple Dep annotations exist (e.g., from nested Annotated that flattens), - returns the LAST one, which represents the outermost user annotation. - - Args: - annotation: A type annotation, possibly Annotated with Dep - - Returns: - Tuple of (base_type, Dep instance or None) - """ - if get_origin(annotation) is Annotated: - args = get_args(annotation) - base_type = args[0] - # Find the LAST Dep - nested Annotated flattens, so outer annotation comes last - last_dep = None - for metadata in args[1:]: - if isinstance(metadata, Dep): - last_dep = metadata - if last_dep is not None: - return base_type, last_dep - return annotation, None diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index e9f2704..44d9cfa 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,20 +12,29 @@ import inspect import logging from functools import wraps -from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, cast, get_args, get_origin -from pydantic import Field, TypeAdapter +from pydantic import Field, TypeAdapter, model_validator from typing_extensions import TypedDict from 
.base import ContextBase, ResultBase from .callable import CallableModel, Flow, GraphDepList, _CallableModel from .context import FlowContext -from .dep import Dep, extract_dep from .local_persistence import register_ccflow_import_path from .result import GenericResult __all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy", "FieldExtractor") +_AnyCallable = Callable[..., Any] + + +def _callable_name(func: _AnyCallable) -> str: + return getattr(func, "__name__", type(func).__name__) + + +def _callable_module(func: _AnyCallable) -> str: + return getattr(func, "__module__", __name__) + class _LazyMarker: """Sentinel that marks a parameter as lazily evaluated via Lazy[T].""" @@ -36,9 +45,8 @@ class _LazyMarker: def _extract_lazy(annotation) -> Tuple[Any, bool]: """Check if annotation is Lazy[T]. Returns (base_type, is_lazy). - Handles nested Annotated types — e.g. Lazy[Annotated[T, Dep(...)]] produces - Annotated[Annotated[T, Dep(...)], _LazyMarker()], so we need to check the - outermost Annotated layer for _LazyMarker. + Handles nested Annotated types, so we need to check the outermost + Annotated layer for _LazyMarker. 
""" if get_origin(annotation) is Annotated: args = get_args(annotation) @@ -87,15 +95,11 @@ def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter return TypeAdapter(TypedDict(name, schema)) -def _build_config_validators( - all_param_types: Dict[str, Type], dep_fields: Dict[str, Tuple[Type, Dep]] -) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: - """Precompute validators for non-dependency config fields.""" +def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: + """Precompute validators for constructor fields.""" validatable_types: Dict[str, Type] = {} for name, typ in all_param_types.items(): - if name in dep_fields: - continue try: TypeAdapter(typ) validatable_types[name] = typ @@ -112,6 +116,7 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, if not validators: return + from .base import ModelRegistry as _MR from .callable import CallableModel as _CM for field_name, validator in validators.items(): @@ -120,6 +125,8 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, value = kwargs[field_name] if value is None or isinstance(value, (_CM, BoundModel)): continue + if isinstance(value, str) and value in _MR.root(): + continue try: validator.validate_python(value) except Exception: @@ -134,7 +141,7 @@ class FlowAPI: Accessed via model.flow property. """ - def __init__(self, model: "CallableModel"): # noqa: F821 + def __init__(self, model: "_GeneratedFlowModelBase"): self._model = model def compute(self, **kwargs) -> Any: @@ -228,25 +235,23 @@ class BoundModel: of a previous transform). 
""" - def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # noqa: F821 + def __init__(self, model: "_GeneratedFlowModelBase", input_transforms: Dict[str, Any]): self._model = model self._input_transforms = input_transforms - def __call__(self, context: ContextBase) -> Any: - """Call the model with transformed context.""" - # Build new context dict with transforms applied + def _transform_context(self, context: ContextBase) -> FlowContext: + """Return a FlowContext with this model's input transforms applied.""" ctx_dict = _context_values(context) - - # Apply transforms for name, transform in self._input_transforms.items(): if callable(transform): ctx_dict[name] = transform(context) else: ctx_dict[name] = transform + return FlowContext(**ctx_dict) - # Create new context and call model - new_ctx = FlowContext(**ctx_dict) - return self._model(new_ctx) + def __call__(self, context: ContextBase) -> Any: + """Call the model with transformed context.""" + return self._model(self._transform_context(context)) @property def flow(self) -> "FlowAPI": @@ -288,7 +293,10 @@ class _FieldExtractorMixin: def __getattr__(self, name): try: - return super().__getattr__(name) + super_getattr = getattr(super(), "__getattr__", None) + if super_getattr is None: + raise AttributeError(name) + return super_getattr(name) except AttributeError: if name.startswith("_"): raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") from None @@ -298,6 +306,43 @@ def __getattr__(self, name): class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): """Shared behavior for models generated by ``@Flow.model``.""" + __flow_model_context_type__: ClassVar[Type[ContextBase]] = FlowContext + __flow_model_return_type__: ClassVar[Type[ResultBase]] = GenericResult + __flow_model_func__: ClassVar[_AnyCallable | None] = None + __flow_model_use_context_args__: ClassVar[bool] = True + __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None + 
__flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_auto_wrap__: ClassVar[bool] = False + _context_schema: ClassVar[Dict[str, Type]] = {} + _context_td: ClassVar[Any | None] = None + _matched_context_type: ClassVar[Optional[Type[ContextBase]]] = None + _cached_context_validator: ClassVar[TypeAdapter | None] = None + + @model_validator(mode="before") + def _resolve_registry_refs(cls, values, info): + if not isinstance(values, dict): + return values + + from .base import BaseModel as _BM + + param_types = getattr(cls, "__flow_model_all_param_types__", {}) + resolved = dict(values) + for field_name, expected_type in param_types.items(): + if field_name not in resolved: + continue + value = resolved[field_name] + if not isinstance(value, str): + continue + if expected_type is str: + continue + try: + candidate = _BM.model_validate(value) + except Exception: + continue + if isinstance(candidate, _BM): + resolved[field_name] = candidate + return resolved + @property def context_type(self) -> Type[ContextBase]: return self.__class__.__flow_model_context_type__ @@ -414,8 +459,8 @@ def model(self) -> "CallableModel": # noqa: F821 def _build_context_schema( - context_args: List[str], func: Callable, sig: inspect.Signature -) -> Tuple[Dict[str, Type], Type, Optional[Type[ContextBase]]]: + context_args: List[str], func: _AnyCallable, sig: inspect.Signature +) -> Tuple[Dict[str, Type], Any, Optional[Type[ContextBase]]]: """Build context schema from context_args parameter names. Instead of creating a dynamic ContextBase subclass, this builds: @@ -456,25 +501,16 @@ def _build_context_schema( matched_context_type = DateRangeContext # Create TypedDict for validation (not registered anywhere!) 
- context_td = TypedDict(f"{func.__name__}Inputs", schema) + context_td = TypedDict(f"{_callable_name(func)}Inputs", schema) return schema, context_td, matched_context_type -def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: - """Extract dependency info from an annotation. - - Returns: - Tuple of (base_type, Dep instance or None) - """ - return extract_dep(annotation) - - _UNSET = object() def flow_model( - func: Callable = None, + func: Optional[_AnyCallable] = None, *, # Context handling context_args: Optional[List[str]] = None, @@ -487,7 +523,7 @@ def flow_model( validate_result: Any = _UNSET, verbose: Any = _UNSET, evaluator: Any = _UNSET, -) -> Callable: +) -> _AnyCallable: """Decorator that generates a CallableModel class from a plain Python function. This is syntactic sugar over CallableModel. The decorator generates a real @@ -522,7 +558,7 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ A factory function that creates CallableModel instances """ - def decorator(fn: Callable) -> Callable: + def decorator(fn: _AnyCallable) -> _AnyCallable: import typing as _typing sig = inspect.signature(fn) @@ -538,7 +574,7 @@ def decorator(fn: Callable) -> Callable: # Validate return type return_type = _resolved_hints.get("return", sig.return_annotation) if return_type is inspect.Signature.empty: - raise TypeError(f"Function {fn.__name__} must have a return type annotation") + raise TypeError(f"Function {_callable_name(fn)} must have a return type annotation") # Check if return type is a ResultBase subclass; if not, auto-wrap in GenericResult return_origin = get_origin(return_type) or return_type auto_wrap_result = False @@ -555,10 +591,10 @@ def decorator(fn: Callable) -> Callable: context_param = params[context_param_name] context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) if context_annotation is inspect.Parameter.empty: - raise TypeError(f"Function {fn.__name__}: '{context_param_name}' 
parameter must have a type annotation") + raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' parameter must have a type annotation") context_type = context_annotation if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)): - raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") + raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} use_context_args = False explicit_context_args = None @@ -583,10 +619,9 @@ def decorator(fn: Callable) -> Callable: use_context_args = True explicit_context_args = None # Dynamic - determined at construction - # Analyze parameters to find dependencies, lazy fields, and regular fields - dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) + # Analyze parameters to find lazy fields and regular fields. model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) - lazy_fields: set = set() # Names of parameters marked with Lazy[T] + lazy_fields: set[str] = set() # Names of parameters marked with Lazy[T] # In dynamic deferred mode (no explicit context_args), all fields are optional # because values not provided at construction come from context at runtime @@ -603,8 +638,6 @@ def decorator(fn: Callable) -> Callable: if is_lazy: lazy_fields.add(name) - # Extract Dep info from the (possibly unwrapped) annotation - base_type, dep = _get_dep_info(unwrapped_annotation) if param.default is not inspect.Parameter.empty: default = param.default elif dynamic_deferred_mode: @@ -614,17 +647,7 @@ def decorator(fn: Callable) -> Callable: # In explicit mode, params without defaults are required default = ... 
- if dep is not None: - # This is an explicit dependency parameter (DepOf annotation) - dep_fields[name] = (base_type, dep) - # Use Annotated so _resolve_deps_and_call in callable.py can find the Dep - model_fields[name] = (Annotated[Union[base_type, CallableModel], dep], default) - else: - # Regular model field - use Any for auto-detection of CallableModels. - # We can't use Union[T, CallableModel] because Pydantic tries to generate - # schema for T, which fails for arbitrary types like pl.DataFrame. - # Using Any allows any value; we do runtime isinstance checks in __call__. - model_fields[name] = (Any, default) + model_fields[name] = (Any, default) # Capture variables for closures ctx_param_name = context_param_name if not use_context_args else "context" @@ -637,23 +660,15 @@ def decorator(fn: Callable) -> Callable: # Create the __call__ method def make_call_impl(): def __call__(self, context): - # Import here (inside function) to avoid pickling issues with ContextVar - from .callable import _resolved_deps - - def resolve_callable_model(name, value, store): + def resolve_callable_model(value): """Resolve a CallableModel field.""" - if id(value) in store: - return store[id(value)] - else: - # Auto-detection fallback: call directly - resolved = value(context) - if isinstance(resolved, GenericResult): - return resolved.value - return resolved + resolved = value(context) + if isinstance(resolved, GenericResult): + return resolved.value + return resolved # Build kwargs for the original function fn_kwargs = {} - store = _resolved_deps.get() def _resolve_field(name, value): """Resolve a single field value, handling lazy wrapping.""" @@ -666,7 +681,7 @@ def _resolve_field(name, value): # Non-dep value: wrap in trivial thunk return lambda v=value: v elif is_dep: - return resolve_callable_model(name, value, store) + return resolve_callable_model(value) else: return value @@ -704,7 +719,7 @@ def _resolve_field(name, value): return raw_result # Set proper signature for 
CallableModel validation - __call__.__signature__ = inspect.Signature( + cast(Any, __call__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), @@ -739,19 +754,14 @@ def __deps__(self, context) -> GraphDepList: if name in lazy_fields: continue # Lazy deps are NOT pre-evaluated value = getattr(self, name) - if isinstance(value, (CallableModel, BoundModel)): - if name in dep_fields: - # Explicit DepOf with transform (backwards compat) - _, dep_obj = dep_fields[name] - transformed_ctx = dep_obj.apply(context) - deps.append((value, [transformed_ctx])) - else: - # Auto-detected dependency - use context as-is - deps.append((value, [context])) + if isinstance(value, BoundModel): + deps.append((value._model, [value._transform_context(context)])) + elif isinstance(value, CallableModel): + deps.append((value, [context])) return deps # Set proper signature - __deps__.__signature__ = inspect.Signature( + cast(Any, __deps__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), @@ -767,8 +777,8 @@ def __deps__(self, context) -> GraphDepList: annotations = {} namespace = { - "__module__": fn.__module__, - "__qualname__": f"_{fn.__name__}_Model", + "__module__": _callable_module(fn), + "__qualname__": f"_{_callable_name(fn)}_Model", "__call__": decorated_call, "__deps__": decorated_deps, } @@ -783,36 +793,15 @@ def __deps__(self, context) -> GraphDepList: namespace["__annotations__"] = annotations - # Add model validator for dependency validation if we have dep fields - if dep_fields: - from pydantic import model_validator - - # Create validator function that captures dep_fields and context_type - def make_dep_validator(d_fields, ctx_type): - 
@model_validator(mode="after") - def __validate_deps__(self): - from .callable import CallableModel - - for dep_name, (base_type, dep_obj) in d_fields.items(): - value = getattr(self, dep_name) - if isinstance(value, CallableModel): - dep_obj.validate_dependency(value, base_type, ctx_type, dep_name) - return self - - return __validate_deps__ - - namespace["__validate_deps__"] = make_dep_validator(dep_fields, context_type) - - _validatable_types, _config_validators = _build_config_validators(all_param_types, dep_fields) + _validatable_types, _config_validators = _build_config_validators(all_param_types) # Create the class using type() - GeneratedModel = type(f"_{fn.__name__}_Model", (_GeneratedFlowModelBase,), namespace) + GeneratedModel = cast(type[_GeneratedFlowModelBase], type(f"_{_callable_name(fn)}_Model", (_GeneratedFlowModelBase,), namespace)) # Set class-level attributes after class creation (to avoid pydantic processing) GeneratedModel.__flow_model_context_type__ = context_type GeneratedModel.__flow_model_return_type__ = internal_return_type - GeneratedModel.__flow_model_func__ = fn - GeneratedModel.__flow_model_dep_fields__ = dep_fields + setattr(GeneratedModel, "__flow_model_func__", fn) GeneratedModel.__flow_model_use_context_args__ = use_context_args GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type @@ -851,7 +840,7 @@ def __validate_deps__(self): # Create factory function that returns model instances @wraps(fn) - def factory(**kwargs) -> GeneratedModel: + def factory(**kwargs) -> _GeneratedFlowModelBase: _validate_config_kwargs(kwargs, _validatable_types, _config_validators) instance = GeneratedModel(**kwargs) @@ -861,7 +850,7 @@ def factory(**kwargs) -> GeneratedModel: return instance # Preserve useful attributes on factory - factory._generated_model = GeneratedModel + cast(Any, factory)._generated_model = GeneratedModel factory.__doc__ = 
fn.__doc__ return factory @@ -906,16 +895,9 @@ def result_type(self): @Flow.call def __call__(self, context: ContextBase) -> GenericResult: - # Lazy import: _resolved_deps is a ContextVar that can't be pickled - from .callable import _resolved_deps - - store = _resolved_deps.get() - if id(self.source) in store: - result = store[id(self.source)] - else: - result = self.source(context) - if isinstance(result, GenericResult): - result = result.value + result = self.source(context) + if isinstance(result, GenericResult): + result = result.value # Support both attribute access and dict key access if isinstance(result, dict): return GenericResult(value=result[self.field_name]) diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml index 781bd24..41acfaf 100644 --- a/ccflow/tests/config/conf_flow.yaml +++ b/ccflow/tests/config/conf_flow.yaml @@ -60,7 +60,7 @@ diamond_aggregator: # DateRangeContext with transform flow_date_loader: - _target_: ccflow.tests.test_flow_model.date_range_loader + _target_: ccflow.tests.test_flow_model.date_range_loader_previous_day source: market_data include_weekends: false diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index b547aee..458569d 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -1,25 +1,22 @@ """Tests for Flow.model decorator.""" from datetime import date, timedelta -from typing import Annotated from unittest import TestCase -from pydantic import ValidationError from ray.cloudpickle import dumps as rcpdumps, loads as rcploads from ccflow import ( CallableModel, ContextBase, DateRangeContext, - Dep, - DepOf, Flow, + FlowOptionsOverride, GenericResult, Lazy, ModelRegistry, ResultBase, ) -from ccflow.callable import resolve +from ccflow.evaluators.common import MemoryCacheEvaluator class SimpleContext(ContextBase): @@ -136,9 +133,9 @@ def loader(context: SimpleContext, base: int) -> GenericResult[int]: return 
GenericResult(value=context.value + base) @Flow.model - def consumer(_: SimpleContext, data: DepOf[..., GenericResult[int]]) -> GenericResult[int]: + def consumer(_: SimpleContext, data: int) -> GenericResult[int]: # Context not used directly, just passed to dependency - return GenericResult(value=data.value * 2) + return GenericResult(value=data * 2) load = loader(base=100) consume = consumer(data=load) @@ -214,10 +211,10 @@ def model_with_ctx_default(value: int = 42, extra: str = "foo") -> GenericResult class TestFlowModelDependencies(TestCase): - """Tests for Flow.model with dependencies.""" + """Tests for Flow.model with upstream CallableModel inputs.""" - def test_simple_dependency_with_depof(self): - """Test simple dependency using DepOf shorthand.""" + def test_simple_dependency(self): + """Test passing an upstream model as a normal parameter.""" @Flow.model def loader(context: SimpleContext, value: int) -> GenericResult[int]: @@ -226,10 +223,10 @@ def loader(context: SimpleContext, value: int) -> GenericResult[int]: @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, multiplier: int = 1, ) -> GenericResult[int]: - return GenericResult(value=data.value * multiplier) + return GenericResult(value=data * multiplier) # Create pipeline load = loader(value=10) @@ -241,39 +238,17 @@ def consumer( # loader returns 10 + 5 = 15, consumer multiplies by 2 = 30 self.assertEqual(result.value, 30) - def test_dependency_with_explicit_dep(self): - """Test dependency using explicit Dep() annotation.""" - - @Flow.model - def loader(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 2) - - @Flow.model - def consumer( - context: SimpleContext, - data: Annotated[GenericResult[int], Dep()], - ) -> GenericResult[int]: - return GenericResult(value=data.value + 100) - - load = loader() - consume = consumer(data=load) - - result = consume(SimpleContext(value=10)) - # loader: 10 * 2 = 20, 
consumer: 20 + 100 = 120 - self.assertEqual(result.value, 120) - def test_dependency_with_direct_value(self): - """Test that Dep fields can accept direct values (not CallableModel).""" + """Test that dependency-shaped parameters can also take direct values.""" @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value + context.value) + return GenericResult(value=data + context.value) - # Pass direct value instead of CallableModel - consume = consumer(data=GenericResult(value=100)) + consume = consumer(data=100) result = consume(SimpleContext(value=5)) self.assertEqual(result.value, 105) @@ -288,9 +263,9 @@ def loader(context: SimpleContext) -> GenericResult[int]: @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value) + return GenericResult(value=data) load = loader() consume = consumer(data=load) @@ -309,105 +284,80 @@ def test_no_deps_when_direct_value(self): @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value) + return GenericResult(value=data) - consume = consumer(data=GenericResult(value=100)) + consume = consumer(data=100) deps = consume.__deps__(SimpleContext(value=10)) self.assertEqual(len(deps), 0) # ============================================================================= -# Transform Tests +# with_inputs Tests # ============================================================================= -class TestFlowModelTransforms(TestCase): - """Tests for Flow.model with context transforms.""" +class TestFlowModelWithInputs(TestCase): + """Tests for Flow.model with .flow.with_inputs().""" - def test_transform_in_dep(self): - """Test dependency with context transform.""" + def test_transformed_dependency_with_inputs(self): + 
"""Test dependency context transformation via .flow.with_inputs().""" @Flow.model def loader(context: SimpleContext) -> GenericResult[int]: return GenericResult(value=context.value) @Flow.model - def consumer( - context: SimpleContext, - data: Annotated[ - GenericResult[int], - Dep(transform=lambda ctx: ctx.model_copy(update={"value": ctx.value + 10})), - ], - ) -> GenericResult[int]: - return GenericResult(value=data.value * 2) + def consumer(context: SimpleContext, data: int) -> GenericResult[int]: + return GenericResult(value=data * 2) - load = loader() + load = loader().flow.with_inputs(value=lambda ctx: ctx.value + 10) consume = consumer(data=load) - ctx = SimpleContext(value=5) - result = consume(ctx) - - # Transform adds 10 to context.value: 5 + 10 = 15 - # Loader returns that: 15 - # Consumer multiplies by 2: 30 + result = consume(SimpleContext(value=5)) self.assertEqual(result.value, 30) - def test_transform_in_deps_method(self): - """Test that transform is applied in __deps__ method.""" - - def transform_fn(ctx): - return ctx.model_copy(update={"value": ctx.value * 3}) + def test_with_inputs_changes_dependency_context_in_deps(self): + """Test that BoundModel contributes transformed dependency contexts.""" @Flow.model def loader(context: SimpleContext) -> GenericResult[int]: return GenericResult(value=context.value) @Flow.model - def consumer( - context: SimpleContext, - data: Annotated[GenericResult[int], Dep(transform=transform_fn)], - ) -> GenericResult[int]: - return GenericResult(value=data.value) + def consumer(context: SimpleContext, data: int) -> GenericResult[int]: + return GenericResult(value=data) - load = loader() + load = loader().flow.with_inputs(value=lambda ctx: ctx.value * 3) consume = consumer(data=load) - ctx = SimpleContext(value=7) - deps = consume.__deps__(ctx) - - # Transform should be applied + deps = consume.__deps__(SimpleContext(value=7)) self.assertEqual(len(deps), 1) transformed_ctx = deps[0][1][0] - 
self.assertEqual(transformed_ctx.value, 21) # 7 * 3 + self.assertEqual(transformed_ctx.value, 21) - def test_date_range_transform(self): - """Test transform pattern with date ranges using context_args.""" + def test_date_range_transform_with_inputs(self): + """Test date-range lookback wiring via .flow.with_inputs().""" @Flow.model(context_args=["start_date", "end_date"]) def range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date}") - def lookback_transform(ctx: DateRangeContext) -> DateRangeContext: - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) - @Flow.model(context_args=["start_date", "end_date"]) def range_processor( start_date: date, end_date: date, - data: Annotated[GenericResult[str], Dep(transform=lookback_transform)], + data: str, ) -> GenericResult[str]: - return GenericResult(value=f"processed:{data.value}") + return GenericResult(value=f"processed:{data}") - loader = range_loader(source="db") + loader = range_loader(source="db").flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=1)) processor = range_processor(data=loader) ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) result = processor(ctx) - - # Transform should shift start_date back by 1 day self.assertEqual(result.value, "processed:db:2024-01-09") @@ -429,18 +379,18 @@ def stage1(context: SimpleContext, base: int) -> GenericResult[int]: @Flow.model def stage2( context: SimpleContext, - input_data: DepOf[..., GenericResult[int]], + input_data: int, multiplier: int, ) -> GenericResult[int]: - return GenericResult(value=input_data.value * multiplier) + return GenericResult(value=input_data * multiplier) @Flow.model def stage3( context: SimpleContext, - input_data: DepOf[..., GenericResult[int]], + input_data: int, offset: int = 0, ) -> GenericResult[int]: - return GenericResult(value=input_data.value + offset) + return 
GenericResult(value=input_data + offset) # Build pipeline s1 = stage1(base=100) @@ -465,24 +415,24 @@ def source(context: SimpleContext) -> GenericResult[int]: @Flow.model def branch_a( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value * 2) + return GenericResult(value=data * 2) @Flow.model def branch_b( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value + 100) + return GenericResult(value=data + 100) @Flow.model def merger( context: SimpleContext, - a: DepOf[..., GenericResult[int]], - b: DepOf[..., GenericResult[int]], + a: int, + b: int, ) -> GenericResult[int]: - return GenericResult(value=a.value + b.value) + return GenericResult(value=a + b) src = source() a = branch_a(data=src) @@ -568,10 +518,10 @@ def __call__(self, context: SimpleContext) -> GenericResult[int]: @Flow.model def generated_consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, multiplier: int, ) -> GenericResult[int]: - return GenericResult(value=data.value * multiplier) + return GenericResult(value=data * multiplier) manual = ManualModel(offset=50) generated = generated_consumer(data=manual, multiplier=2) @@ -616,7 +566,7 @@ def test_auto_wrap_unwrap_as_dependency(self): """Test that auto-wrapped model used as dep delivers unwrapped value downstream. Auto-wrapped models have result_type=GenericResult (unparameterized). - When used as an auto-detected dep (no DepOf), the framework resolves + When used as an auto-detected dep, the framework resolves the GenericResult and unwraps .value for the downstream function. 
""" @@ -627,7 +577,7 @@ def plain_source(context: SimpleContext) -> int: @Flow.model def consumer( context: SimpleContext, - data: GenericResult[int], # Auto-detected dep, not DepOf + data: GenericResult[int], # Auto-detected dep ) -> GenericResult[int]: # data is auto-unwrapped to the int value by the framework return GenericResult(value=data + 1) @@ -711,288 +661,13 @@ def untyped_context_arg(x) -> GenericResult[int]: self.assertIn("type annotation", str(cm.exception)) -# ============================================================================= -# Dep and DepOf Tests -# ============================================================================= - - -class TestDepAndDepOf(TestCase): - """Tests for Dep and DepOf classes.""" - - def test_depof_creates_annotated(self): - """Test that DepOf[..., T] creates Annotated[Union[T, CallableModel], Dep()].""" - from typing import Union as TypingUnion, get_args, get_origin - - annotation = DepOf[..., GenericResult[int]] - self.assertEqual(get_origin(annotation), Annotated) - - args = get_args(annotation) - # First arg is Union[ResultType, CallableModel] - self.assertEqual(get_origin(args[0]), TypingUnion) - union_args = get_args(args[0]) - self.assertIn(GenericResult[int], union_args) - self.assertIn(CallableModel, union_args) - # Second arg is Dep() - self.assertIsInstance(args[1], Dep) - self.assertIsNone(args[1].context_type) # ... 
means inherit from parent - - def test_depof_with_generic_type(self): - """Test DepOf with nested generic types.""" - from typing import List as TypingList, Union as TypingUnion, get_args, get_origin - - annotation = DepOf[..., GenericResult[TypingList[str]]] - self.assertEqual(get_origin(annotation), Annotated) - - args = get_args(annotation) - # First arg is Union[ResultType, CallableModel] - self.assertEqual(get_origin(args[0]), TypingUnion) - union_args = get_args(args[0]) - self.assertIn(GenericResult[TypingList[str]], union_args) - self.assertIn(CallableModel, union_args) - - def test_depof_with_context_type(self): - """Test DepOf[ContextType, ResultType] syntax.""" - from typing import Union as TypingUnion, get_args, get_origin - - annotation = DepOf[SimpleContext, GenericResult[int]] - self.assertEqual(get_origin(annotation), Annotated) - - args = get_args(annotation) - # First arg is Union[ResultType, CallableModel] - self.assertEqual(get_origin(args[0]), TypingUnion) - union_args = get_args(args[0]) - self.assertIn(GenericResult[int], union_args) - self.assertIn(CallableModel, union_args) - # Second arg is Dep with context_type - self.assertIsInstance(args[1], Dep) - self.assertEqual(args[1].context_type, SimpleContext) - - def test_extract_dep_with_annotated(self): - """Test extract_dep with Annotated type.""" - from ccflow.dep import extract_dep - - dep = Dep(context_type=SimpleContext) - annotation = Annotated[GenericResult[int], dep] - - base_type, extracted_dep = extract_dep(annotation) - self.assertEqual(base_type, GenericResult[int]) - self.assertEqual(extracted_dep, dep) - - def test_extract_dep_with_depof(self): - """Test extract_dep with DepOf type.""" - from typing import Union as TypingUnion, get_args, get_origin - - from ccflow.dep import extract_dep - - annotation = DepOf[..., GenericResult[str]] - base_type, extracted_dep = extract_dep(annotation) - - # base_type is Union[ResultType, CallableModel] - self.assertEqual(get_origin(base_type), 
TypingUnion) - union_args = get_args(base_type) - self.assertIn(GenericResult[str], union_args) - self.assertIn(CallableModel, union_args) - self.assertIsInstance(extracted_dep, Dep) - - def test_extract_dep_without_dep(self): - """Test extract_dep with regular type (no Dep).""" - from ccflow.dep import extract_dep - - base_type, extracted_dep = extract_dep(int) - self.assertEqual(base_type, int) - self.assertIsNone(extracted_dep) - - def test_extract_dep_annotated_without_dep(self): - """Test extract_dep with Annotated but no Dep marker.""" - from ccflow.dep import extract_dep - - annotation = Annotated[int, "some metadata"] - base_type, extracted_dep = extract_dep(annotation) - - # When no Dep marker is found, returns original annotation unchanged - self.assertEqual(base_type, annotation) - self.assertIsNone(extracted_dep) - - def test_is_compatible_type_simple(self): - """Test _is_compatible_type with simple types.""" - from ccflow.dep import _is_compatible_type - - self.assertTrue(_is_compatible_type(int, int)) - self.assertFalse(_is_compatible_type(int, str)) - self.assertTrue(_is_compatible_type(bool, int)) # bool is subclass of int - - def test_is_compatible_type_generic(self): - """Test _is_compatible_type with generic types.""" - from ccflow.dep import _is_compatible_type - - self.assertTrue(_is_compatible_type(GenericResult[int], GenericResult[int])) - self.assertFalse(_is_compatible_type(GenericResult[int], GenericResult[str])) - self.assertTrue(_is_compatible_type(GenericResult, GenericResult)) - - def test_is_compatible_type_none(self): - """Test _is_compatible_type with None.""" - from ccflow.dep import _is_compatible_type - - self.assertTrue(_is_compatible_type(None, None)) - self.assertFalse(_is_compatible_type(None, int)) - self.assertFalse(_is_compatible_type(int, None)) - - def test_is_compatible_type_subclass(self): - """Test _is_compatible_type with subclasses.""" - from ccflow.dep import _is_compatible_type - - 
self.assertTrue(_is_compatible_type(MyResult, ResultBase)) - self.assertFalse(_is_compatible_type(ResultBase, MyResult)) - - def test_dep_validate_dependency_success(self): - """Test Dep.validate_dependency with valid dependency.""" - - @Flow.model - def valid_dep(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - dep = Dep() - model = valid_dep() - - # Should not raise - dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") - - def test_dep_validate_dependency_context_mismatch(self): - """Test Dep.validate_dependency with context type mismatch.""" - - class OtherContext(ContextBase): - other: str - - @Flow.model - def other_dep(context: OtherContext) -> GenericResult[int]: - return GenericResult(value=42) - - dep = Dep(context_type=SimpleContext) - model = other_dep() - - with self.assertRaises(TypeError) as cm: - dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") - - self.assertIn("context_type", str(cm.exception)) - - def test_dep_validate_dependency_result_mismatch(self): - """Test Dep.validate_dependency with result type mismatch.""" - - @Flow.model - def wrong_result(context: SimpleContext) -> MyResult: - return MyResult(data="test") - - dep = Dep() - model = wrong_result() - - with self.assertRaises(TypeError) as cm: - dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") - - self.assertIn("result_type", str(cm.exception)) - - def test_dep_validate_dependency_non_callable(self): - """Test Dep.validate_dependency with non-CallableModel value.""" - dep = Dep() - # Should not raise for non-CallableModel values - dep.validate_dependency(GenericResult(value=42), GenericResult[int], SimpleContext, "data") - dep.validate_dependency("string", GenericResult[int], SimpleContext, "data") - dep.validate_dependency(123, GenericResult[int], SimpleContext, "data") - - def test_dep_hash(self): - """Test Dep is hashable for use in sets/dicts.""" - dep1 = Dep() - dep2 = 
Dep(context_type=SimpleContext) - - # Should be hashable - dep_set = {dep1, dep2} - self.assertEqual(len(dep_set), 2) - - dep_dict = {dep1: "value1", dep2: "value2"} - self.assertEqual(dep_dict[dep1], "value1") - self.assertEqual(dep_dict[dep2], "value2") - - def test_dep_apply_with_transform(self): - """Test Dep.apply with transform function.""" - - def transform(ctx): - return ctx.model_copy(update={"value": ctx.value * 2}) - - dep = Dep(transform=transform) - - ctx = SimpleContext(value=10) - result = dep.apply(ctx) - - self.assertEqual(result.value, 20) - - def test_dep_apply_without_transform(self): - """Test Dep.apply without transform (identity).""" - dep = Dep() - - ctx = SimpleContext(value=10) - result = dep.apply(ctx) - - self.assertIs(result, ctx) - - def test_dep_repr(self): - """Test Dep string representation.""" - dep1 = Dep() - self.assertEqual(repr(dep1), "Dep()") - - dep2 = Dep(context_type=SimpleContext) - self.assertIn("SimpleContext", repr(dep2)) - - dep3 = Dep(transform=lambda x: x) - self.assertIn("transform=", repr(dep3)) - - def test_dep_equality(self): - """Test Dep equality comparison.""" - dep1 = Dep() - dep2 = Dep() - dep3 = Dep(context_type=SimpleContext) - - # Note: Two Dep() instances with no arguments are equal - self.assertEqual(dep1, dep2) - self.assertNotEqual(dep1, dep3) - - # ============================================================================= # Validation Tests # ============================================================================= class TestFlowModelValidation(TestCase): - """Tests for dependency validation in Flow.model.""" - - def test_context_type_validation(self): - """Test that context_type mismatch is detected.""" - - class OtherContext(ContextBase): - other: str - - @Flow.model - def simple_loader(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def other_loader(context: OtherContext) -> GenericResult[int]: - return GenericResult(value=42) - - 
@Flow.model - def consumer( - context: SimpleContext, - data: Annotated[GenericResult[int], Dep(context_type=SimpleContext)], - ) -> GenericResult[int]: - return GenericResult(value=data.value) - - # Should work with matching context - load1 = simple_loader() - consume1 = consumer(data=load1) - self.assertIsNotNone(consume1) - - # Should fail with mismatched context - load2 = other_loader() - with self.assertRaises((TypeError, ValidationError)): - consumer(data=load2) + """Tests for Flow.model validation behavior.""" def test_config_validation_rejects_bad_type(self): """Test that config validator rejects wrong types at construction.""" @@ -1208,6 +883,36 @@ def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 self.assertEqual(result, 51) + def test_bound_and_unbound_models_share_memory_cache(self): + """Shifted and unshifted models should share one evaluator cache. + + They should not share the same cache key when the effective contexts + differ, but repeated evaluations of either model should still hit the + same underlying MemoryCacheEvaluator instance rather than re-executing. + """ + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + base = source() + shifted = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + evaluator = MemoryCacheEvaluator() + ctx = SimpleContext(value=5) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + self.assertEqual(base(ctx).value, 50) + self.assertEqual(shifted(ctx).value, 60) + self.assertEqual(base(ctx).value, 50) + self.assertEqual(shifted(ctx).value, 60) + + # One execution for the unshifted context and one for the shifted context. 
+ self.assertEqual(call_counts["source"], 2) + self.assertEqual(len(evaluator.cache), 2) + # ============================================================================= # PEP 563 (from __future__ import annotations) Compatibility Tests @@ -1296,27 +1001,27 @@ def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: @Flow.model def data_transformer( context: SimpleContext, - source: DepOf[..., GenericResult[int]], + source: int, factor: int = 2, ) -> GenericResult[int]: """Transform data by multiplying with factor.""" - return GenericResult(value=source.value * factor) + return GenericResult(value=source * factor) @Flow.model def data_aggregator( context: SimpleContext, - input_a: DepOf[..., GenericResult[int]], - input_b: DepOf[..., GenericResult[int]], + input_a: int, + input_b: int, operation: str = "add", ) -> GenericResult[int]: """Aggregate two inputs.""" if operation == "add": - return GenericResult(value=input_a.value + input_b.value) + return GenericResult(value=input_a + input_b) elif operation == "multiply": - return GenericResult(value=input_a.value * input_b.value) + return GenericResult(value=input_a * input_b) else: - return GenericResult(value=input_a.value - input_b.value) + return GenericResult(value=input_a - input_b) @Flow.model @@ -1328,26 +1033,21 @@ def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: @Flow.model def pipeline_stage2( context: SimpleContext, - stage1_output: DepOf[..., GenericResult[int]], + stage1_output: int, multiplier: int = 2, ) -> GenericResult[int]: """Second stage of pipeline.""" - return GenericResult(value=stage1_output.value * multiplier) + return GenericResult(value=stage1_output * multiplier) @Flow.model def pipeline_stage3( context: SimpleContext, - stage2_output: DepOf[..., GenericResult[int]], + stage2_output: int, offset: int = 0, ) -> GenericResult[int]: """Third stage of pipeline.""" - return GenericResult(value=stage2_output.value + offset) - - -def 
lookback_one_day(ctx: DateRangeContext) -> DateRangeContext: - """Transform that extends start_date back by one day.""" - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + return GenericResult(value=stage2_output + offset) @Flow.model @@ -1355,20 +1055,37 @@ def date_range_loader( context: DateRangeContext, source: str, include_weekends: bool = True, -) -> GenericResult[str]: +) -> GenericResult[dict]: """Load data for a date range.""" - return GenericResult(value=f"{source}:{context.start_date} to {context.end_date}") + return GenericResult( + value={ + "source": source, + "start_date": str(context.start_date), + "end_date": str(context.end_date), + } + ) + + +@Flow.model +def date_range_loader_previous_day( + context: DateRangeContext, + source: str, + include_weekends: bool = True, +) -> dict: + """Hydra helper that applies a one-day lookback before delegating.""" + shifted = context.model_copy(update={"start_date": context.start_date - timedelta(days=1)}) + return date_range_loader(source=source, include_weekends=include_weekends)(shifted).value @Flow.model def date_range_processor( context: DateRangeContext, - raw_data: Annotated[GenericResult[str], Dep(transform=lookback_one_day)], + raw_data: dict, normalize: bool = False, ) -> GenericResult[str]: - """Process date range data with lookback.""" + """Process date range data.""" prefix = "normalized:" if normalize else "raw:" - return GenericResult(value=f"{prefix}{raw_data.value}") + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") @Flow.model @@ -1386,31 +1103,37 @@ def hydra_source_model(context: SimpleContext, base: int) -> GenericResult[int]: @Flow.model def hydra_consumer_model( context: SimpleContext, - source: DepOf[..., GenericResult[int]], + source: int, factor: int = 1, ) -> GenericResult[int]: """Consumer model for dependency testing.""" - return GenericResult(value=source.value * factor) + return 
GenericResult(value=source * factor) # --- context_args fixtures for Hydra testing --- @Flow.model(context_args=["start_date", "end_date"]) -def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: +def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[dict]: """Loader using context_args with DateRangeContext.""" - return GenericResult(value=f"{source}:{start_date} to {end_date}") + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) @Flow.model(context_args=["start_date", "end_date"]) def context_args_processor( start_date: date, end_date: date, - data: DepOf[..., GenericResult[str]], + data: dict, prefix: str = "processed", ) -> GenericResult[str]: """Processor using context_args with dependency.""" - return GenericResult(value=f"{prefix}:{data.value}") + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") class TestFlowModelHydra(TestCase): @@ -1477,390 +1200,6 @@ def test_hydra_instantiate_with_dependency(self): self.assertEqual(result.value, 100) -# ============================================================================= -# Class-based CallableModel with Auto-Resolution Tests -# ============================================================================= - - -class TestClassBasedDepResolution(TestCase): - """Tests for auto-resolution of DepOf fields in class-based CallableModels. - - Key pattern: Fields use DepOf annotation, __call__ only takes context, - and resolved values are accessed via self.field_name during __call__. 
- """ - - def test_class_based_auto_resolve_basic(self): - """Test that DepOf fields are auto-resolved and accessible via resolve().""" - - @Flow.model - def data_source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - class Consumer(CallableModel): - # DepOf expands to Annotated[Union[ResultType, CallableModel], Dep()] - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - # Access resolved value via resolve() - return GenericResult(value=resolve(self.source).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.source, [context])] - - src = data_source() - consumer = Consumer(source=src) - - result = consumer(SimpleContext(value=5)) - # source: 5 * 10 = 50, consumer: 50 + 1 = 51 - self.assertEqual(result.value, 51) - - def test_class_based_with_custom_transform(self): - """Test that custom __deps__ transform is used.""" - - @Flow.model - def data_source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - class Consumer(CallableModel): - source: DepOf[..., GenericResult[int]] - offset: int = 100 - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.source).value + self.offset) - - @Flow.deps - def __deps__(self, context: SimpleContext): - # Apply custom transform - transformed_ctx = SimpleContext(value=context.value + 5) - return [(self.source, [transformed_ctx])] - - src = data_source() - consumer = Consumer(source=src, offset=1) - - result = consumer(SimpleContext(value=5)) - # transformed context: 5 + 5 = 10 - # source: 10 * 10 = 100 - # consumer: 100 + 1 = 101 - self.assertEqual(result.value, 101) - - def test_class_based_with_annotated_transform(self): - """Test that Dep transform is used when field not in __deps__.""" - - @Flow.model - def data_source(context: SimpleContext) -> 
GenericResult[int]: - return GenericResult(value=context.value * 10) - - def double_value(ctx: SimpleContext) -> SimpleContext: - return SimpleContext(value=ctx.value * 2) - - class Consumer(CallableModel): - source: Annotated[DepOf[..., GenericResult[int]], Dep(transform=double_value)] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.source).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [] # Empty - uses Dep annotation transform from field - - src = data_source() - consumer = Consumer(source=src) - - result = consumer(SimpleContext(value=5)) - # transform: 5 * 2 = 10 - # source: 10 * 10 = 100 - # consumer: 100 + 1 = 101 - self.assertEqual(result.value, 101) - - def test_class_based_multiple_deps(self): - """Test auto-resolution with multiple dependencies.""" - - @Flow.model - def source_a(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def source_b(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 2) - - class Aggregator(CallableModel): - a: DepOf[..., GenericResult[int]] - b: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.a).value + resolve(self.b).value) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.a, [context]), (self.b, [context])] - - agg = Aggregator(a=source_a(), b=source_b()) - - result = agg(SimpleContext(value=10)) - # a: 10, b: 20, aggregator: 30 - self.assertEqual(result.value, 30) - - def test_class_based_deps_with_instance_field_access(self): - """Test that __deps__ can access instance fields for configurable transforms. - - This is the key advantage of class-based models over @Flow.model: - transforms can use instance fields like window size. 
- """ - - @Flow.model - def data_source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - class Consumer(CallableModel): - source: DepOf[..., GenericResult[int]] - lookback: int = 5 # Configurable instance field - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.source).value * 2) - - @Flow.deps - def __deps__(self, context: SimpleContext): - # Access self.lookback in transform - this is why we use class-based! - transformed = SimpleContext(value=context.value + self.lookback) - return [(self.source, [transformed])] - - src = data_source() - consumer = Consumer(source=src, lookback=10) - - result = consumer(SimpleContext(value=5)) - # transformed: 5 + 10 = 15 - # source: 15 - # consumer: 15 * 2 = 30 - self.assertEqual(result.value, 30) - - def test_class_based_with_direct_value(self): - """Test that DepOf fields can accept pre-resolved values.""" - - class Consumer(CallableModel): - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - # resolve() passes through non-CallableModel values unchanged - return GenericResult(value=resolve(self.source).value + context.value) - - @Flow.deps - def __deps__(self, context: SimpleContext): - # No deps when source is already resolved - return [] - - # Pass direct value instead of CallableModel - consumer = Consumer(source=GenericResult(value=100)) - - result = consumer(SimpleContext(value=5)) - self.assertEqual(result.value, 105) - - def test_class_based_no_double_call(self): - """Test that dependencies are not called twice during DepOf resolution. - - This verifies that the auto-resolution mechanism doesn't accidentally - evaluate the same dependency multiple times. 
- """ - call_counts = {"source": 0} - - @Flow.model - def counting_source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value * 10) - - class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.data).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.data, [context])] - - src = counting_source() - consumer = Consumer(data=src) - - # Call consumer - source should only be called once - result = consumer(SimpleContext(value=5)) - - self.assertEqual(result.value, 51) # 5 * 10 + 1 - self.assertEqual(call_counts["source"], 1, "Source should only be called once") - - def test_class_based_nested_depof_no_double_call(self): - """Test nested DepOf chain (A -> B -> C) has no double-calls at any layer. - - This tests a 3-layer dependency chain where: - - layer_c is the leaf (no dependencies) - - layer_b depends on layer_c - - layer_a depends on layer_b - - Each layer should be called exactly once. 
- """ - call_counts = {"layer_a": 0, "layer_b": 0, "layer_c": 0} - - # Layer C: leaf node (no dependencies) - @Flow.model - def layer_c(context: SimpleContext) -> GenericResult[int]: - call_counts["layer_c"] += 1 - return GenericResult(value=context.value) - - # Layer B: depends on layer_c - class LayerB(CallableModel): - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - call_counts["layer_b"] += 1 - return GenericResult(value=resolve(self.source).value * 10) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.source, [context])] - - # Layer A: depends on layer_b - class LayerA(CallableModel): - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - call_counts["layer_a"] += 1 - return GenericResult(value=resolve(self.source).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.source, [context])] - - # Build the chain: A -> B -> C - c = layer_c() - b = LayerB(source=c) - a = LayerA(source=b) - - # Call layer_a - each layer should be called exactly once - result = a(SimpleContext(value=5)) - - # Verify result: C returns 5, B returns 5*10=50, A returns 50+1=51 - self.assertEqual(result.value, 51) - - # Verify each layer called exactly once - self.assertEqual(call_counts["layer_c"], 1, "layer_c should be called exactly once") - self.assertEqual(call_counts["layer_b"], 1, "layer_b should be called exactly once") - self.assertEqual(call_counts["layer_a"], 1, "layer_a should be called exactly once") - - def test_resolve_direct_value_passthrough(self): - """Test that resolve() passes through non-CallableModel values unchanged.""" - - class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - # resolve() should return the GenericResult directly (pass-through) - resolved = 
resolve(self.data) - # Verify it's the actual GenericResult, not a CallableModel - assert isinstance(resolved, GenericResult) - return GenericResult(value=resolved.value * 2) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [] - - # Pass a direct value, not a CallableModel - direct_result = GenericResult(value=42) - consumer = Consumer(data=direct_result) - - result = consumer(SimpleContext(value=5)) - self.assertEqual(result.value, 84) # 42 * 2 - - def test_resolve_outside_call_raises_error(self): - """Test that resolve() raises RuntimeError when called outside __call__.""" - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.data).value) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.data, [context])] - - src = source() - consumer = Consumer(data=src) - - # Calling resolve() outside of __call__ should raise RuntimeError - with self.assertRaises(RuntimeError) as cm: - resolve(consumer.data) - - self.assertIn("resolve() can only be used inside __call__", str(cm.exception)) - - def test_flow_model_uses_unified_resolution_path(self): - """Test that @Flow.model uses the same resolution path as class-based CallableModel. - - This verifies the consolidation of resolution logic - both @Flow.model and - class-based models should use _resolve_deps_and_call in callable.py. 
- """ - call_counts = {"source": 0, "decorator_model": 0, "class_model": 0} - - @Flow.model - def shared_source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value * 2) - - # @Flow.model consumer - @Flow.model - def decorator_consumer( - context: SimpleContext, - data: DepOf[..., GenericResult[int]], - ) -> GenericResult[int]: - call_counts["decorator_model"] += 1 - return GenericResult(value=data.value + 100) - - # Class-based consumer (same logic) - class ClassConsumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - call_counts["class_model"] += 1 - return GenericResult(value=resolve(self.data).value + 100) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.data, [context])] - - # Test both consumers with the same source - src = shared_source() - dec_consumer = decorator_consumer(data=src) - cls_consumer = ClassConsumer(data=src) - - ctx = SimpleContext(value=10) - - # Both should produce the same result - dec_result = dec_consumer(ctx) - cls_result = cls_consumer(ctx) - - self.assertEqual(dec_result.value, cls_result.value) - self.assertEqual(dec_result.value, 120) # 10 * 2 + 100 - - # Source should be called exactly twice (once per consumer) - self.assertEqual(call_counts["source"], 2) - self.assertEqual(call_counts["decorator_model"], 1) - self.assertEqual(call_counts["class_model"], 1) - - # ============================================================================= # Lazy[T] Type Annotation Tests # ============================================================================= @@ -2128,8 +1467,8 @@ def consumer( self.assertEqual(result.value, 42) self.assertEqual(call_counts["source"], 0) - def test_lazy_with_depof(self): - """Lazy[DepOf[...]] works: lazy dep with explicit DepOf annotation.""" + def test_lazy_with_upstream_model(self): + """Lazy[T] works when bound to an 
upstream model.""" from ccflow import Lazy @Flow.model @@ -2139,7 +1478,7 @@ def source(context: SimpleContext) -> GenericResult[int]: @Flow.model def consumer( context: SimpleContext, - data: Lazy[DepOf[..., GenericResult[int]]], + data: Lazy[GenericResult[int]], ) -> GenericResult[int]: return GenericResult(value=data() + 1) # data() returns unwrapped int diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py index 661ac4f..28f2883 100644 --- a/ccflow/tests/test_flow_model_hydra.py +++ b/ccflow/tests/test_flow_model_hydra.py @@ -124,7 +124,14 @@ def test_context_args_from_yaml(self): ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) - self.assertEqual(result.value, "data_source:2024-01-01 to 2024-01-31") + self.assertEqual( + result.value, + { + "source": "data_source", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + }, + ) def test_context_args_pipeline_from_yaml(self): """Test context_args pipeline with dependencies from YAML.""" @@ -423,12 +430,9 @@ def test_transform_applied_from_yaml(self): self.assertEqual(len(deps), 1) dep_model, dep_contexts = deps[0] - # The transform should extend start_date back by one day - transformed_ctx = dep_contexts[0] - self.assertEqual(transformed_ctx.start_date, date(2024, 1, 9)) - self.assertEqual(transformed_ctx.end_date, date(2024, 1, 31)) - self.assertIs(dep_model, r["flow_date_loader"]) + self.assertEqual(dep_contexts[0], ctx) + self.assertEqual(dep_model(ctx).value["start_date"], "2024-01-09") if __name__ == "__main__": diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 909b597..dca3fbe 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -1,501 +1,245 @@ -# Flow.model and DepOf: Dependency Injection for CallableModel +# Flow.model Design ## Overview -This document describes the `@Flow.model` decorator and `DepOf` annotation system for reducing 
boilerplate when creating `CallableModel` pipelines with dependencies. +`@Flow.model` turns a plain Python function into a real `CallableModel`. -**Key features:** -- `@Flow.model` - Decorator that generates `CallableModel` classes from plain functions -- `FlowContext` - Universal context carrier for unpacked/deferred execution -- `model.flow.compute(...)` / `model.flow.with_inputs(...)` - Deferred execution helpers -- `DepOf[ContextType, ResultType]` - Type annotation for dependency fields -- `Lazy[T]` - Mark a dependency for lazy, on-demand evaluation -- `FieldExtractor` - Access structured outputs via attribute access on generated models -- `resolve()` - Function to access resolved dependency values in class-based models +The core goals are: -## Quick Start +- keep the authoring model close to an ordinary function, +- preserve the existing evaluator / registry / serialization machinery, +- make deferred execution explicit with `.flow.compute(...)` and `.flow.with_inputs(...)`, +- allow callers to pass either literal values or upstream models for ordinary parameters. -### Pattern 1: `@Flow.model` (Recommended for Declarative Cases) +`@Flow.model` is syntactic sugar over the existing ccflow framework. The +generated object is still a standard `CallableModel`, so you can execute it the +same way as any other model by calling it with a context object. The +`.flow.compute(...)` helper is an explicit, ergonomic way to mark the deferred +execution boundary when supplying runtime inputs as keyword arguments. 
-```python -from datetime import date, timedelta -from typing import Annotated - -from ccflow import Flow, DateRangeContext, GenericResult, Dep, DepOf - - -def previous_window(ctx: DateRangeContext) -> DateRangeContext: - window = ctx.end_date - ctx.start_date - return ctx.model_copy( - update={ - "start_date": ctx.start_date - window - timedelta(days=1), - "end_date": ctx.start_date - timedelta(days=1), - } - ) +## Core Patterns -@Flow.model -def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: - return GenericResult(value=125.0) +### Default Deferred Style -@Flow.model -def revenue_growth( - context: DateRangeContext, - current: DepOf[..., GenericResult[float]], - previous: Annotated[GenericResult[float], Dep(transform=previous_window)], -) -> GenericResult[dict]: - growth = (current.value - previous.value) / previous.value - return GenericResult(value={"as_of": context.end_date, "growth": growth}) - -# Build pipeline. The same upstream model is reused twice: -# - once with the original context -# - once with a fixed lookback transform -revenue = load_revenue(region="us") -growth = revenue_growth(current=revenue, previous=revenue) - -# Execute -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = growth(ctx) -``` - -### Pattern 2: Class-Based (For Complex Cases) - -Use class-based when you need **configurable transforms** that depend on instance fields: +This is the most ergonomic mode. Bind some parameters up front, then provide +the remaining runtime inputs later. 
```python -from datetime import timedelta - -from ccflow import CallableModel, DateRangeContext, Flow, GenericResult, DepOf -from ccflow.callable import resolve # Import resolve for class-based models - -class RevenueAverageWithWindow(CallableModel): - """Aggregate revenue with a configurable lookback window.""" - - revenue: DepOf[..., GenericResult[float]] - window: int = 7 # Configurable instance field - - @Flow.call - def __call__(self, context: DateRangeContext) -> GenericResult[float]: - # Use resolve() to get the resolved value - revenue = resolve(self.revenue) - return GenericResult(value=revenue.value / self.window) - - @Flow.deps - def __deps__(self, context: DateRangeContext): - # Transform uses self.window - this is why we need class-based! - lookback_ctx = context.model_copy( - update={"start_date": context.start_date - timedelta(days=self.window)} - ) - return [(self.revenue, [lookback_ctx])] - -# Usage - different window sizes, same source -loader = load_revenue(region="us") -avg_7 = RevenueAverageWithWindow(revenue=loader, window=7) -avg_30 = RevenueAverageWithWindow(revenue=loader, window=30) -``` +from ccflow import Flow, FlowContext -## When to Use Which Pattern -| Use `@Flow.model` when... | Use Class-Based when... 
| -|--------------------------------|---------------------------------------| -| The node still reads like a normal function | The main value is custom graph logic | -| Transforms are fixed/declarative | Transforms depend on instance fields | -| Less boilerplate is priority | You need full control over `__deps__` | -| Dependency wiring fits in the signature | Dependency behavior deserves its own class | +@Flow.model +def add(x: int, y: int) -> int: + return x + y -## Core Concepts -### `DepOf[ContextType, ResultType]` +model = add(x=10) -Shorthand for declaring dependency fields that can accept either: -- A pre-computed value of `ResultType` -- A `CallableModel` that produces `ResultType` +# Explicit deferred entry point +assert model.flow.compute(y=5) == 15 -```python -# Inherit context type from parent model -data: DepOf[..., GenericResult[dict]] +# Standard CallableModel call path +assert model(FlowContext(y=5)).value == 15 -# Explicit context type -data: DepOf[DateRangeContext, GenericResult[dict]] - -# Equivalent to: -data: Annotated[Union[GenericResult[dict], CallableModel], Dep()] +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5) == 20 ``` -For `@Flow.model`, plain non-`DepOf` parameters can also be populated with a -`CallableModel` instance. That lets callers either inject a concrete value or -splice in an upstream computation for the same parameter. Use `Dep`/`DepOf` -when you need explicit dependency metadata such as context transforms or -context-type validation. - -That means `DepOf` inside `@Flow.model` is most compelling when the function is -still doing real work and the dependency relationship is simple. If the node is -mostly a vessel for custom dependency graph wiring, a hand-written -`CallableModel` is usually clearer. +In this mode: -### `Dep(transform=..., context_type=...)` +- bound parameters are model configuration, +- unbound parameters become runtime inputs for that model instance. 
-For transforms, use the full `Annotated` form: +### Explicit Context Parameter ```python -from ccflow import Dep +from ccflow import DateRangeContext, Flow + @Flow.model -def compute_stats( - context: DateRangeContext, - records: Annotated[GenericResult[dict], Dep( - transform=lambda ctx: ctx.model_copy( - update={"start_date": ctx.start_date - timedelta(days=1)} - ) - )], -) -> GenericResult[float]: - return GenericResult(value=records.value["count"] * 0.05) +def load_revenue(context: DateRangeContext, region: str) -> float: + return 125.0 ``` -### `resolve()` Function +This is the most direct mode. The function receives a normal context object and +returns either a `ResultBase` subclass or a plain value. Plain values are +wrapped into `GenericResult` automatically by the generated model. -**Only needed for class-based models.** Accesses the resolved value of a `DepOf` field during `__call__`. +### `context_args` ```python -from ccflow.callable import resolve +from datetime import date -class MyModel(CallableModel): - data: DepOf[..., GenericResult[int]] +from ccflow import Flow - @Flow.call - def __call__(self, context: MyContext) -> GenericResult[int]: - # resolve() returns the GenericResult, not the CallableModel - result = resolve(self.data) - return GenericResult(value=result.value + 1) -``` -**Behavior:** -- Inside `__call__`: Returns the resolved value -- With direct values (not CallableModel): Returns unchanged (no-op) -- Outside `__call__`: Raises `RuntimeError` -- In `@Flow.model`: Not needed - values are passed as function arguments - -**Type inference:** -```python -data: DepOf[..., GenericResult[int]] -resolved = resolve(self.data) # Type: GenericResult[int] +@Flow.model(context_args=["start_date", "end_date"]) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + return 125.0 ``` -## How Resolution Works - -### `@Flow.model` Resolution Flow - -1. User calls `model(context)` -2. 
Generated `__call__` invokes `_resolve_deps_and_call()` -3. For each dependency-bearing field containing a `CallableModel`: - - Apply transform (if any) - - Call the dependency - - Store resolved value in context variable -4. Generated `__call__` reads the resolved values from the dependency store -5. Original function receives resolved values directly as normal function arguments - -### Class-Based Resolution Flow - -1. User calls `model(context)` -2. `_resolve_deps_and_call()` runs -3. For each `DepOf` field containing a `CallableModel`: - - Check `__deps__` for custom transforms - - If not listed in `__deps__`, fall back to the field's `Dep(...)` transform (or the original context) - - Call the dependency - - Store resolved value in context variable -4. User's `__call__` accesses values via `resolve(self.field)` - -**Important:** Resolution uses a context variable (`contextvars.ContextVar`), making it thread-safe and async-safe. - -## Design Decisions - -### Decision 1: `resolve()` Instead of Temporary Mutation +This keeps the function signature focused on the inputs it actually uses while +still producing a `CallableModel` that accepts a context at runtime. -**What we chose:** Explicit `resolve()` function with context variables. +Use `context_args` when certain parameters are semantically the execution +context and you want that split to be explicit and stable across model +instances. -**Alternative considered:** Temporarily mutate `self.field` during `__call__` to hold the resolved value, then restore after. +When the requested shape matches a built-in context like +`DateRangeContext(start_date, end_date)`, the generated model uses that type. +Otherwise it falls back to `FlowContext`. 
-**Why we chose this:** -- No mutation of model state -- Thread/async-safe via contextvars -- Explicit about what's happening -- Easier to debug - `self.field` always shows the original value +### Upstream Models as Normal Arguments -**Trade-off:** Slightly more verbose (`resolve(self.data).value` vs `self.data.value`). +Any non-context parameter can be given either: -### Decision 2: Unified Resolution Path +- a literal value, or +- another `CallableModel` / `BoundModel`. -**What we chose:** Both `@Flow.model` and class-based use the same `_resolve_deps_and_call()` function. +If a model is passed, it is evaluated with the current context and its result is +unwrapped before the function is called. -**Why:** -- Single source of truth for resolution logic -- Easier to maintain -- Consistent behavior across patterns - -### Decision 3: `resolve()` Not in Top-Level `__all__` - -**What we chose:** `resolve` must be imported explicitly: `from ccflow.callable import resolve` - -**Why:** -- Only needed for class-based models with `DepOf` -- Keeps top-level namespace clean -- Users who need it can find it easily +```python +from ccflow import DateRangeContext, Flow -### Decision 4: Auto-Wrap Plain Return Values -**What we chose:** If the function's declared return type is not a `ResultBase` -subclass, the generated model wraps the returned value in `GenericResult`. +@Flow.model +def load_revenue(context: DateRangeContext, region: str) -> float: + return 125.0 -**Why:** -- Reduces boilerplate for simple scalar / container-returning functions -- Preserves the `CallableModel` contract that runtime results are `ResultBase` -- Still allows explicit `ResultBase` subclasses when you want a precise result type -**Trade-off:** The original Python function may be annotated with a plain value -type while the generated model's runtime `result_type` is `GenericResult`. 
+@Flow.model +def double_revenue(_: DateRangeContext, revenue: float) -> float: + return revenue * 2 -### Decision 5: Generated Classes Are Real CallableModels -**What we chose:** Generate actual `CallableModel` subclasses using `type()`. +revenue = load_revenue(region="us") +model = double_revenue(revenue=revenue) +result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") +``` -**Why:** -- Full compatibility with existing infrastructure -- Caching, registry, serialization work unchanged -- Can mix with hand-written classes +This is the main composition story for the core API. -## Pitfalls and Limitations +### `.flow.with_inputs(...)` -### Pitfall 1: Forgetting `resolve()` in Class-Based Models +`with_inputs` is how a caller rewires context locally for one upstream model. ```python -class MyModel(CallableModel): - data: DepOf[..., GenericResult[int]] +from datetime import date, timedelta - @Flow.call - def __call__(self, context): - # WRONG - self.data is still the CallableModel! - return GenericResult(value=self.data.value + 1) +from ccflow import DateRangeContext, Flow - # CORRECT - return GenericResult(value=resolve(self.data).value + 1) -``` -**Error you'll see:** `AttributeError: '_SomeModel' object has no attribute 'value'` +@Flow.model(context_args=["start_date", "end_date"]) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 -### Pitfall 2: Calling `resolve()` Outside `__call__` -```python -model = MyModel(data=some_source()) -resolve(model.data) # RuntimeError! -``` +@Flow.model(context_args=["start_date", "end_date"]) +def revenue_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: + return { + "window_end": end_date, + "growth_pct": round((current - previous) / previous * 100, 2), + } -`resolve()` only works during `__call__` execution. 
-### Pitfall 3: Lambda Transforms Don't Serialize +current = load_revenue(region="us") +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) -```python -# Won't serialize - lambdas can't be pickled -Dep(transform=lambda ctx: ctx.model_copy(...)) +model = revenue_growth(current=current, previous=previous) +ctx = DateRangeContext( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) -# Will serialize - use named functions -def shift_start(ctx): - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) +direct = model(ctx).value +computed = model.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) -Dep(transform=shift_start) +assert direct == computed ``` -### Pitfall 4: GraphEvaluator Requires Caching - -When using `GraphEvaluator` with `DepOf`, dependencies may be called twice (once by GraphEvaluator, once by resolution) unless caching is enabled. - -```python -# Use with caching -from ccflow.evaluators import GraphEvaluator, CachingEvaluator, MultiEvaluator +The transform is local to the bound upstream model. The parent model continues +to receive the original context. -evaluator = MultiEvaluator(evaluators=[ - CachingEvaluator(), - GraphEvaluator(), -]) -``` +### `.flow.compute(...)` -### Pitfall 5: Two Mental Models +`compute` is the ergonomic entry point for deferred execution: -Users need to remember: -- `@Flow.model`: Use dependency values directly as function arguments -- Class-based: Use `resolve(self.field)` to access values +```python +from ccflow import Flow -### Limitation: Custom `__deps__` Is Only Needed for Custom Graph Logic -Class-based models do not need a custom `__deps__` override when the default -field-level `Dep(...)` behavior is sufficient. 
Override `__deps__` only when -you need instance-dependent transforms or a custom dependency graph: +@Flow.model +def add(x: int, y: int) -> int: + return x + y -```python -class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - @Flow.call - def __call__(self, context): - return GenericResult(value=resolve(self.data).value) +model = add(x=10) +assert model.flow.compute(y=5) == 15 ``` -If you do need to use instance fields in the transform, then `__deps__` is the -right place to do it: +It validates the supplied keyword arguments against the generated context +schema, creates a `FlowContext`, executes the model, and unwraps +`GenericResult.value` if needed. -```python -class WindowedConsumer(CallableModel): - data: DepOf[..., GenericResult[int]] - window: int = 7 - - @Flow.call - def __call__(self, context): - return GenericResult(value=resolve(self.data).value) - - @Flow.deps - def __deps__(self, context): - shifted = context.model_copy(update={"value": context.value + self.window}) - return [(self.data, [shifted])] -``` +It is not the only execution path. Because the generated object is still a +standard `CallableModel`, calling `model(context)` remains fully supported. -### Limitation: `context_args` Type Matching Is Best-Effort +## Lazy Inputs -When you use `context_args=[...]`, the framework validates those fields via a -runtime `TypedDict` schema. It only maps to a concrete built-in context type in -special cases such as `DateRangeContext`. Otherwise the generated model's -`context_type` is `FlowContext`, a universal frozen carrier for the validated -context values. - -## Complete Example: Multi-Stage Pipeline +`Lazy[T]` marks a parameter as on-demand. Instead of eagerly resolving an +upstream model, the generated model passes a zero-argument thunk. The thunk +caches its first result. 
```python -from datetime import date, timedelta -from typing import Annotated - -from ccflow import ( - CallableModel, DateRangeContext, Dep, DepOf, - Flow, GenericResult -) -from ccflow.callable import resolve +from ccflow import Flow, Lazy -# Stage 1: Data loader (simple, use @Flow.model) @Flow.model -def load_events(context: DateRangeContext, source: str) -> GenericResult[list]: - print(f"Loading from {source} for {context.start_date} to {context.end_date}") - return GenericResult(value=[ - {"date": str(context.start_date), "count": 100 + i} - for i in range(5) - ]) +def source(value: int) -> int: + return value * 10 -# Stage 2: Transform with fixed lookback (use @Flow.model with Dep transform) @Flow.model -def compute_daily_totals( - context: DateRangeContext, - events: Annotated[GenericResult[list], Dep( - transform=lambda ctx: ctx.model_copy( - update={"start_date": ctx.start_date - timedelta(days=1)} - ) - )], -) -> GenericResult[float]: - values = [e["count"] for e in events.value] - total = sum(values) / len(values) if values else 0 - return GenericResult(value=total) - - -# Stage 3: Configurable window (use class-based) -class ComputeRollingSummary(CallableModel): - """Summary with configurable lookback window.""" - - totals: DepOf[..., GenericResult[float]] - window: int = 20 - - @Flow.call - def __call__(self, context: DateRangeContext) -> GenericResult[float]: - totals = resolve(self.totals) - # Scale by window size - summary = totals.value * (self.window ** 0.5) - return GenericResult(value=summary) - - @Flow.deps - def __deps__(self, context: DateRangeContext): - lookback = context.model_copy( - update={"start_date": context.start_date - timedelta(days=self.window)} - ) - return [(self.totals, [lookback])] - - -# Build pipeline -events = load_events(source="main_db") -totals = compute_daily_totals(events=events) -summary_20 = ComputeRollingSummary(totals=totals, window=20) -summary_60 = ComputeRollingSummary(totals=totals, window=60) - -# Execute 
-ctx = DateRangeContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) -print(f"20-day summary: {summary_20(ctx).value}") -print(f"60-day summary: {summary_60(ctx).value}") +def maybe_use_source(value: int, data: Lazy[int]) -> int: + if value > 10: + return value + return data() ``` -## API Reference - -### `@Flow.model` - -```python -@Flow.model( - context_args: list[str] = None, # Unpack context fields as function args - cacheable: bool = False, - volatile: bool = False, - log_level: int = logging.DEBUG, - validate_result: bool = True, - verbose: bool = True, - evaluator: EvaluatorBase = None, -) -def my_function(context: ContextType, ...) -> ResultType: - ... -``` - -If the function is annotated with a plain value type instead of a `ResultBase` -subclass, the generated model will wrap the returned value in `GenericResult` -at runtime. - -### `DepOf[ContextType, ResultType]` +## FlowContext -```python -# Inherit context from parent -field: DepOf[..., GenericResult[int]] +`FlowContext` is the universal frozen carrier for generated contexts that do +not map to a dedicated built-in context type. -# Explicit context type -field: DepOf[DateRangeContext, GenericResult[int]] -``` +The implementation stays intentionally small: -### `Dep(transform=..., context_type=...)` +- context validation is driven by `TypedDict` + `TypeAdapter`, +- runtime execution uses one reusable `FlowContext` type, +- public pydantic iteration (`dict(context)`) is used instead of pydantic + internals. -```python -field: Annotated[GenericResult[int], Dep( - transform=my_transform_func, # Optional: (context) -> transformed_context - context_type=DateRangeContext, # Optional: Expected context type -)] -``` +## BoundModel -### `resolve(dep)` +`.flow.with_inputs(...)` returns a `BoundModel`, which is just a thin wrapper +around: -```python -from ccflow.callable import resolve +- the original model, and +- a mapping of input transforms. 
-# Inside __call__ of class-based CallableModel: -resolved_value = resolve(self.dep_field) +At call time it: -# Type signature: -def resolve(dep: Union[T, CallableModel]) -> T: ... -``` - -## File Structure +1. converts the incoming context into a plain dictionary, +1. applies the configured transforms, +1. rebuilds a `FlowContext`, +1. delegates to the wrapped model. -``` -ccflow/ -├── callable.py # CallableModel, Flow, resolve(), _resolve_deps_and_call() -├── dep.py # Dep, DepOf, extract_dep() -├── flow_model.py # @Flow.model implementation -└── tests/ - └── test_flow_model.py # Comprehensive tests -``` +That keeps transformed dependency wiring explicit without adding special +annotation machinery to the core API. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index f73ac6b..5f1b502 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -24,7 +24,19 @@ As an example, you may have a `SQLReader` callable model that when called with a ### Flow.model Decorator -The `@Flow.model` decorator provides a simpler way to define `CallableModel`s using plain Python functions instead of classes. It automatically generates a `CallableModel` class with proper `__call__` and `__deps__` methods. +The `@Flow.model` decorator provides a simpler way to define `CallableModel`s +using plain Python functions instead of classes. It automatically generates a +standard `CallableModel` class with proper `__call__` and `__deps__` methods, +so it still uses the normal ccflow framework for evaluation, caching, +serialization, and registry loading. + +You can execute a generated model in two equivalent ways: + +- call it directly with a context object: `model(ctx)` +- use `.flow.compute(...)` to supply runtime inputs as keyword arguments + +`.flow.compute(...)` is mainly an explicit, ergonomic way to mark the deferred +execution point. 
**Basic Example:** @@ -45,79 +57,130 @@ ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) ``` -**Composing Dependencies with `Dep` and `DepOf`:** +**Default `@Flow.model` Style:** + +Use this when you want the simplest API and do not need to declare a formal +context shape up front. + +```python +from ccflow import Flow + +@Flow.model +def add(x: int, y: int) -> int: + return x + y + +model = add(x=10) + +# `x` is bound when the model is created. +# `y` is supplied later at execution time. +assert model.flow.compute(y=5) == 15 + +# `.flow.with_inputs(...)` rewrites runtime inputs for this call path. +doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert doubled_y.flow.compute(y=5) == 20 +``` -Use `Dep()` or `DepOf` to mark parameters that accept other `CallableModel`s as dependencies. The framework automatically resolves the dependency graph. +In this mode: -For `@Flow.model`, regular parameters can also accept a `CallableModel` value at -construction time. This lets you either inject a literal value or splice in an -upstream computation for the same parameter. Use `Dep`/`DepOf` when you need -context transforms or explicit dependency metadata. +- bound parameters are model configuration +- unbound parameters are runtime inputs for that model instance -> **Rule of thumb:** `@Flow.model` works best when the dependency wiring is declarative and local to the signature. If the main point of the node is custom graph logic or transforms that depend on instance fields, use a class-based `CallableModel` instead. +**Composing Dependencies with Normal Parameters:** + +Any non-context parameter can be bound either to a literal value or to another +`CallableModel`. If you pass an upstream model, `@Flow.model` evaluates it with +the current context and passes the resolved value into your function. 
```python from datetime import date, timedelta -from typing import Annotated -from ccflow import Flow, GenericResult, DateRangeContext, Dep, DepOf - -def previous_window(ctx: DateRangeContext) -> DateRangeContext: - window = ctx.end_date - ctx.start_date - return ctx.model_copy( - update={ - "start_date": ctx.start_date - window - timedelta(days=1), - "end_date": ctx.start_date - timedelta(days=1), - } - ) +from ccflow import DateRangeContext, Flow -@Flow.model -def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: - # Pretend this queries a warehouse - return GenericResult(value=125.0) +@Flow.model(context_args=["start_date", "end_date"]) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 -@Flow.model +@Flow.model(context_args=["start_date", "end_date"]) def revenue_growth( - context: DateRangeContext, - current: DepOf[..., GenericResult[float]], - previous: Annotated[GenericResult[float], Dep(transform=previous_window)], -) -> GenericResult[dict]: - growth = (current.value - previous.value) / previous.value - return GenericResult(value={"as_of": context.end_date, "growth": growth}) - -# Build the pipeline. The same loader is reused with two contexts: -# - current window: original context -# - previous window: transformed via Dep(transform=...) 
-revenue = load_revenue(region="us") -growth = revenue_growth(current=revenue, previous=revenue) - -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = growth(ctx) + start_date: date, + end_date: date, + current: float, + previous: float, +) -> dict: + return { + "window_end": end_date, + "growth_pct": round((current - previous) / previous * 100, 2), + } + +current = load_revenue(region="us") +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) +growth = revenue_growth(current=current, previous=previous) + +ctx = DateRangeContext( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +# Standard ccflow execution +direct = growth(ctx).value + +# Equivalent explicit deferred entry point +computed = growth.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +assert direct == computed ``` -`DepOf` is also useful when you want the same parameter to accept either an -upstream model or a precomputed value: +This pattern is the main story for transformed dependencies. `@Flow.model` +still produces an ordinary `CallableModel`; `.flow.compute(...)` is just a +clearer way to say "supply the runtime inputs here." + +**Why `context_args` Exists:** + +Without `context_args`, runtime inputs are inferred from whichever parameters +are still unbound on a particular model instance. That is flexible and +ergonomic. 
+ +Use `context_args` when some parameters are semantically "the execution +context" and you want that split to stay stable and explicit: + +- the runtime context should be stable across instances +- the split between config and runtime inputs matters semantically +- the model is naturally "run over a context" such as date windows, + partitions, or scenarios +- you want the generated model to match a built-in context type like + `DateRangeContext` when possible + +**Deferred Execution Helpers:** ```python -from ccflow import DateRangeContext, DepOf, Flow, GenericResult +from ccflow import Flow @Flow.model -def load_signal(context: DateRangeContext, source: str) -> GenericResult[float]: - return GenericResult(value=0.87) +def add(x: int, y: int) -> int: + return x + y -@Flow.model -def publish_signal( - context: DateRangeContext, - signal: DepOf[..., GenericResult[float]], - threshold: float = 0.8, -) -> GenericResult[dict]: - return GenericResult(value={ - "as_of": context.end_date, - "signal": signal.value, - "go_live": signal.value >= threshold, - }) - -live = publish_signal(signal=load_signal(source="prod")) -override = publish_signal(signal=GenericResult(value=0.95), threshold=0.9) +model = add(x=10) +assert model.flow.compute(y=5) == 15 + +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5) == 20 +``` + +If you already have a real context object, you can call the model directly +instead: + +```python +from ccflow import FlowContext + +ctx = FlowContext(y=5) +assert model(ctx).value == 15 +assert shifted(ctx).value == 20 ``` **Hydra/YAML Configuration:** @@ -166,27 +229,6 @@ result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" The `context_args` parameter specifies which function parameters should be extracted from the context. Those fields are validated through a runtime schema built from the parameter annotations. 
For well-known shapes such as `start_date` / `end_date`, the generated model uses a concrete built-in context type like `DateRangeContext`; otherwise it uses `FlowContext`, a universal frozen carrier for the validated fields. -**Deferred Execution Helpers:** - -Generated models also expose a `.flow` helper namespace: - -```python -from ccflow import Flow, GenericResult - -@Flow.model -def add(x: int, y: int) -> GenericResult[int]: - return GenericResult(value=x + y) - -model = add(x=10) - -# Validate and execute by passing context fields as kwargs -assert model.flow.compute(y=5) == 15 - -# Derive a new model by transforming context inputs -shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5) == 20 -``` - If a `@Flow.model` function returns a plain value instead of a `ResultBase` subclass, the generated model automatically wraps that value in `GenericResult` at runtime so it still behaves like a normal `CallableModel`. diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index e93d452..f2616dc 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -1,220 +1,64 @@ #!/usr/bin/env python -"""Example demonstrating Flow.model decorator and class-based CallableModel. - -This example shows: -- Flow.model for simple functions with minimal boilerplate -- Context transforms with Dep annotations -- Class-based CallableModel for complex cases needing instance field access -""" +"""Example demonstrating the core Flow.model workflow.""" from datetime import date, timedelta -from typing import Annotated - -from ccflow import CallableModel, DateRangeContext, Dep, DepOf, Flow, GenericResult -from ccflow.callable import resolve - - -# ============================================================================= -# Example 1: Basic Flow.model - No more boilerplate classes! 
-# ============================================================================= - -@Flow.model -def load_records(context: DateRangeContext, source: str, limit: int = 100) -> GenericResult[list]: - """Load records from a data source for the given date range.""" - print(f" Loading from '{source}' for {context.start_date} to {context.end_date} (limit={limit})") - return GenericResult(value=[ - {"id": i, "date": str(context.start_date), "value": i * 10} - for i in range(min(limit, 5)) - ]) - - -# ============================================================================= -# Example 2: Dependencies with DepOf - Automatic dependency resolution -# ============================================================================= - -@Flow.model -def compute_totals( - _: DateRangeContext, # Context passed to dependency, not used directly here - records: DepOf[..., GenericResult[list]], -) -> GenericResult[dict]: - """Compute totals from loaded records.""" - total = sum(r["value"] for r in records.value) - count = len(records.value) - print(f" Computing totals: {count} records, total={total}") - return GenericResult(value={"total": total, "count": count}) - - -# ============================================================================= -# Example 3: Simple Transform with Flow.model -# When the transform is a fixed function, Flow.model works great -# ============================================================================= - -def lookback_7_days(ctx: DateRangeContext) -> DateRangeContext: - """Fixed transform that extends the date range back by 7 days.""" - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=7)}) - - -@Flow.model -def compute_weekly_average( - _: DateRangeContext, - records: Annotated[GenericResult[list], Dep(transform=lookback_7_days)], -) -> GenericResult[float]: - """Compute average using fixed 7-day lookback.""" - values = [r["value"] for r in records.value] - avg = sum(values) / len(values) if values else 0 - print(f" 
Computing weekly average: {avg:.2f} (from {len(values)} records)") - return GenericResult(value=avg) - - -# ============================================================================= -# Example 4: Class-based CallableModel with Configurable Transform -# When the transform needs access to instance fields (like window size), -# use a class-based approach with auto-resolution -# ============================================================================= - -class ComputeMovingAverage(CallableModel): - """Compute moving average with configurable lookback window. - - This demonstrates: - - Field uses DepOf annotation: accepts either result or CallableModel - - Instance field (window) accessible in __deps__ for custom transforms - - resolve() to access resolved dependency values during __call__ - """ - - records: DepOf[..., GenericResult[list]] - window: int = 7 # Configurable lookback window - - @Flow.call - def __call__(self, context: DateRangeContext) -> GenericResult[float]: - """Compute the moving average - use resolve() to get resolved value.""" - records = resolve(self.records) # Get the resolved GenericResult - values = [r["value"] for r in records.value] - avg = sum(values) / len(values) if values else 0 - print(f" Computing {self.window}-day moving average: {avg:.2f} (from {len(values)} records)") - return GenericResult(value=avg) - - @Flow.deps - def __deps__(self, context: DateRangeContext): - """Define dependencies with transform that uses self.window.""" - # This is where we can access instance fields! 
- lookback_ctx = context.model_copy( - update={"start_date": context.start_date - timedelta(days=self.window)} - ) - return [(self.records, [lookback_ctx])] - - -# ============================================================================= -# Example 5: Multi-stage pipeline - Composing models together -# ============================================================================= - -@Flow.model -def generate_report( - context: DateRangeContext, - totals: DepOf[..., GenericResult[dict]], - moving_avg: DepOf[..., GenericResult[float]], - report_name: str = "Daily Report", -) -> GenericResult[str]: - """Generate a report combining multiple data sources.""" - report = f""" -{report_name} -{'=' * len(report_name)} -Date Range: {context.start_date} to {context.end_date} -Total Value: {totals.value['total']} -Record Count: {totals.value['count']} -Moving Avg: {moving_avg.value:.2f} -""" - return GenericResult(value=report.strip()) - - -# ============================================================================= -# Example 6: Using context_args for cleaner signatures -# ============================================================================= + +from ccflow import DateRangeContext, Flow + @Flow.model(context_args=["start_date", "end_date"]) -def fetch_metadata(start_date: date, end_date: date, category: str) -> GenericResult[dict]: - """Fetch metadata - note how start_date/end_date are direct parameters.""" - print(f" Fetching metadata for '{category}' from {start_date} to {end_date}") - return GenericResult(value={ - "category": category, - "days": (end_date - start_date).days, - "generated_at": str(date.today()), - }) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + """Pretend to load revenue for a date window.""" + days = (end_date - start_date).days + 1 + baseline = 1000.0 if region == "us" else 800.0 + return baseline + days * 10.0 + +@Flow.model(context_args=["start_date", "end_date"]) +def summarize_growth(start_date: date, 
end_date: date, current: float, previous: float) -> dict: + """Compare the current and previous windows.""" + growth_pct = round((current - previous) / previous * 100, 2) + return { + "start_date": start_date, + "end_date": end_date, + "current": current, + "previous": previous, + "growth_pct": growth_pct, + } -# ============================================================================= -# Main: Build and execute the pipeline -# ============================================================================= def main(): print("=" * 60) - print("Flow.model Example - Simplified CallableModel Creation") + print("Flow.model Example") print("=" * 60) + current_window = load_revenue(region="us") + previous_window = current_window.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), + ) + + growth = summarize_growth(current=current_window, previous=previous_window) + ctx = DateRangeContext( - start_date=date(2024, 1, 15), - end_date=date(2024, 1, 31) + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), ) - # --- Example 1: Basic model --- - print("\n[1] Basic Flow.model:") - loader = load_records(source="main_db", limit=5) - result = loader(ctx) - print(f" Result: {result.value}") - - # --- Example 2: Simple dependency chain --- - print("\n[2] Dependency chain (loader -> totals):") - loader = load_records(source="main_db") - totals = compute_totals(records=loader) - result = totals(ctx) - print(f" Result: {result.value}") - - # --- Example 3: Fixed transform with Flow.model --- - print("\n[3] Fixed transform (7-day lookback with Flow.model):") - loader = load_records(source="main_db") - weekly_avg = compute_weekly_average(records=loader) - result = weekly_avg(ctx) - print(f" Result: {result.value}") - - # --- Example 4: Configurable transform with class-based model --- - print("\n[4] Configurable transform (class-based with auto-resolution):") - loader = 
load_records(source="main_db") - - # 14-day window - moving_avg_14 = ComputeMovingAverage(records=loader, window=14) - result = moving_avg_14(ctx) - print(f" 14-day result: {result.value}") - - # 30-day window - same loader, different window - moving_avg_30 = ComputeMovingAverage(records=loader, window=30) - result = moving_avg_30(ctx) - print(f" 30-day result: {result.value}") - - # --- Example 5: Full pipeline --- - print("\n[5] Full pipeline (mixing Flow.model and class-based):") - loader = load_records(source="analytics_db") - totals = compute_totals(records=loader) - moving_avg = ComputeMovingAverage(records=loader, window=7) - report = generate_report( - totals=totals, - moving_avg=moving_avg, - report_name="Analytics Summary" + print("\n[1] Execute as a normal CallableModel:") + print(growth(ctx).value) + + print("\n[2] Execute via .flow.compute(...):") + print( + growth.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), + ) ) - result = report(ctx) - print(result.value) - - # --- Example 6: context_args --- - print("\n[6] Using context_args (auto-unpacked context):") - metadata = fetch_metadata(category="sales") - result = metadata(ctx) - print(f" Result: {result.value}") - - # --- Bonus: Inspecting models --- - print("\n[Bonus] Inspecting models:") - print(f" load_records.context_type = {loader.context_type.__name__}") - print(f" ComputeMovingAverage uses __deps__ for custom transforms") - deps = moving_avg.__deps__(ctx) - for dep_model, dep_contexts in deps: - print(f" - Dependency context start: {dep_contexts[0].start_date} (lookback applied)") + + print("\n[3] Inspect bound and unbound inputs:") + print(" bound_inputs:", growth.flow.bound_inputs) + print(" unbound_inputs:", growth.flow.unbound_inputs) if __name__ == "__main__": From 0c274a1f2d564ba856fcb9981a9c99c75649262f Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 17:02:40 -0400 Subject: [PATCH 11/17] Clean up more, make repr nicer for BoundModel 
Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 101 ++++++++---- ccflow/flow_model.py | 89 ++++++++--- ccflow/tests/test_flow_context.py | 59 ++++++- docs/design/flow_model_design.md | 60 ++++++- docs/wiki/Key-Features.md | 253 ++++++++++++++++++++++-------- examples/flow_model_example.py | 100 ++++++++---- 6 files changed, 516 insertions(+), 146 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index fd849c5..185c229 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -16,7 +16,7 @@ import logging from functools import lru_cache, wraps from inspect import Signature, isclass, signature -from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override @@ -31,6 +31,9 @@ from .local_persistence import create_ccflow_model from .validators import str_to_log_level +if TYPE_CHECKING: + from .flow_model import FlowAPI + __all__ = ( "GraphDepType", "GraphDepList", @@ -62,6 +65,25 @@ def _cached_signature(fn): return signature(fn) +def _callable_qualname(fn: Callable[..., Any]) -> str: + return getattr(fn, "__qualname__", type(fn).__qualname__) + + +def _declared_type_matches(actual: Any, expected: Any) -> bool: + if isinstance(expected, TypeVar): + return True + if get_origin(expected) is Union: + expected_args = tuple(arg for arg in get_args(expected) if isinstance(arg, type)) + if not expected_args: + return False + if get_origin(actual) is Union: + actual_args = tuple(arg for arg in get_args(actual) if isinstance(arg, type)) + return set(actual_args) == set(expected_args) + return isinstance(actual, type) and any(issubclass(actual, arg) for arg in 
expected_args) + + return isinstance(actual, type) and isinstance(expected, type) and issubclass(actual, expected) + + class MetaData(BaseModel): """Class to represent metadata for all callable models""" @@ -329,15 +351,16 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = return result wrap = wraps(fn)(wrapper) - wrap.get_evaluator = self.get_evaluator - wrap.get_options = self.get_options - wrap.get_evaluation_context = get_evaluation_context + wrap_any = cast(Any, wrap) + wrap_any.get_evaluator = self.get_evaluator + wrap_any.get_options = self.get_options + wrap_any.get_evaluation_context = get_evaluation_context # Preserve auto context attributes for introspection if hasattr(fn, "__auto_context__"): - wrap.__auto_context__ = fn.__auto_context__ + wrap_any.__auto_context__ = fn.__auto_context__ if hasattr(fn, "__result_type__"): - wrap.__result_type__ = fn.__result_type__ + wrap_any.__result_type__ = fn.__result_type__ return wrap @@ -480,7 +503,7 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: # Note that the code below is executed only once if auto_context_enabled: # Return a decorator that first applies auto_context, then FlowOptions - def auto_context_decorator(fn): + def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]: wrapped = _apply_auto_context(fn, parent=context_parent) # FlowOptions.__call__ already applies wraps, so we just return its result return FlowOptions(**kwargs)(wrapped) @@ -592,13 +615,15 @@ class ModelEvaluationContext( # TODO: Make the instance check compatible with the generic types instead of the base type @model_validator(mode="wrap") - def _context_validator(cls, values, handler, info): + @classmethod + def _context_validator(cls, values: Any, handler: Any, info: Any): """Override _context_validator from parent""" # Validate the context with the model, if possible - model = values.get("model") - if model and isinstance(model, CallableModel) and not 
isinstance(values.get("context"), model.context_type): - values["context"] = model.context_type.model_validate(values.get("context")) + if isinstance(values, dict): + model = values.get("model") + if model and isinstance(model, CallableModel) and not isinstance(values.get("context"), model.context_type): + values["context"] = model.context_type.model_validate(values.get("context")) # Apply standard pydantic validation context = handler(values) @@ -626,9 +651,9 @@ def __call__(self) -> ResultType: raise TypeError(f"Model result_type {result_type} is not a subclass of ResultBase") result = result_type.model_validate(result) - return result + return cast(ResultType, result) else: - return fn(self.context) + return cast(ResultType, fn(self.context)) class EvaluatorBase(_CallableModel, abc.ABC): @@ -716,7 +741,7 @@ def context_type(self) -> Type[ContextType]: if not isclass(type_to_check) or not issubclass(type_to_check, ContextBase): raise TypeError(f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received {type_to_check}.") - return typ + return cast(Type[ContextType], typ) @property def result_type(self) -> Type[ResultType]: @@ -759,7 +784,7 @@ def result_type(self) -> Type[ResultType]: # Ensure subclass of ResultBase if not isclass(typ) or not issubclass(typ, ResultBase): raise TypeError(f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received {typ}.") - return typ + return cast(Type[ResultType], typ) @Flow.deps def __deps__( @@ -775,6 +800,13 @@ def __deps__( """ return [] + @property + def flow(self) -> "FlowAPI": + """Access flow helpers for execution, context transforms, and introspection.""" + from .flow_model import FlowAPI + + return FlowAPI(self) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. 
@@ -787,12 +819,12 @@ class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): @property def context_type(self) -> Type[ContextType]: """Return the context type of the underlying model.""" - return self.model.context_type + return cast(CallableModel, self.model).context_type @property def result_type(self) -> Type[ResultType]: """Return the result type of the underlying model.""" - return self.model.result_type + return cast(CallableModel, self.model).result_type class CallableModelGeneric(CallableModel, Generic[ContextType, ResultType]): @@ -864,32 +896,36 @@ def _determine_context_result(cls): if new_context_type is not None: # Set on class - cls._context_generic_type = new_context_type + setattr(cls, "_context_generic_type", new_context_type) if new_result_type is not None: # Set on class - cls._result_generic_type = new_result_type + setattr(cls, "_result_generic_type", new_result_type) @model_validator(mode="wrap") - def _validate_callable_model_generic_type(cls, m, handler, info): + @classmethod + def _validate_callable_model_generic_type(cls, m: Any, handler: Any, info: Any): from ccflow.base import resolve_str if isinstance(m, str): m = resolve_str(m) - if isinstance(m, dict): - m = handler(m) - elif isinstance(m, cls): - m = handler(m) + validated_cls = cast(Any, cls) + if isinstance(m, (dict, CallableModel)): + if isinstance(m, dict): + m = handler(m) + elif isinstance(m, validated_cls): + m = handler(m) # Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/ if not isinstance(m, CallableModel): raise ValueError(f"{m} is not a CallableModel: {type(m)}") subtypes = cls.__pydantic_generic_metadata__["args"] - if subtypes: - TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type) - TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type) + if len(subtypes) >= 1 and not _declared_type_matches(m.context_type, subtypes[0]): + raise ValueError(f"{m} context_type {m.context_type} does not match 
{subtypes[0]}") + if len(subtypes) >= 2 and not _declared_type_matches(m.result_type, subtypes[1]): + raise ValueError(f"{m} result_type {m.result_type} does not match {subtypes[1]}") return m @@ -902,7 +938,7 @@ def _validate_callable_model_generic_type(cls, m, handler, info): # ***************************************************************************** -def _apply_auto_context(func: Callable, *, parent: Type[ContextBase] = None) -> Callable: +def _apply_auto_context(func: Callable[..., Any], *, parent: Optional[Type[ContextBase]] = None) -> Callable[..., Any]: """Internal function that creates an auto context class from function parameters. This function extracts the parameters from a function signature and creates @@ -941,7 +977,7 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: fields[name] = (param.annotation, default) # Create auto context class - auto_context_class = create_ccflow_model(f"{func.__qualname__}_AutoContext", __base__=base_class, **fields) + auto_context_class = create_ccflow_model(f"{_callable_qualname(func)}_AutoContext", __base__=base_class, **fields) @wraps(func) def wrapper(self, context): @@ -949,13 +985,14 @@ def wrapper(self, context): return func(self, **fn_kwargs) # Must set __signature__ so CallableModel validation sees 'context' parameter - wrapper.__signature__ = inspect.Signature( + wrapper_any = cast(Any, wrapper) + wrapper_any.__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class), ], return_annotation=sig.return_annotation, ) - wrapper.__auto_context__ = auto_context_class - wrapper.__result_type__ = sig.return_annotation + wrapper_any.__auto_context__ = auto_context_class + wrapper_any.__result_type__ = sig.return_annotation return wrapper diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 44d9cfa..e2496ea 100644 --- 
a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,7 +12,7 @@ import inspect import logging from functools import wraps -from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, cast, get_args, get_origin +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin from pydantic import Field, TypeAdapter, model_validator from typing_extensions import TypedDict @@ -87,6 +87,17 @@ def _context_values(context: ContextBase) -> Dict[str, Any]: return dict(context) +def _transform_repr(transform: Any) -> str: + """Render an input transform without noisy object addresses.""" + + if callable(transform): + name = _callable_name(transform) + if name.startswith("<") and name.endswith(">"): + return name + return f"<{name}>" + return repr(transform) + + def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" @@ -95,6 +106,22 @@ def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter return TypeAdapter(TypedDict(name, schema)) +def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: + """Extract a concrete ContextBase subclass from a context annotation.""" + + if isinstance(context_type, type) and issubclass(context_type, ContextBase): + return context_type + + if get_origin(context_type) in (Optional, Union): + for arg in get_args(context_type): + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, ContextBase): + return arg + + return None + + def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: """Precompute validators for constructor fields.""" @@ -141,9 +168,20 @@ class FlowAPI: Accessed via model.flow property. 
""" - def __init__(self, model: "_GeneratedFlowModelBase"): + def __init__(self, model: CallableModel): self._model = model + def _build_context(self, kwargs: Dict[str, Any]) -> ContextBase: + """Construct a runtime context for either generated or hand-written models.""" + get_validator = getattr(self._model, "_get_context_validator", None) + if get_validator is not None: + validator = get_validator() + validated = validator.validate_python(kwargs) + return FlowContext(**validated) + + validator = TypeAdapter(self._model.context_type) + return validator.validate_python(kwargs) + def compute(self, **kwargs) -> Any: """Execute the model with the provided context arguments. @@ -156,14 +194,7 @@ def compute(self, **kwargs) -> Any: Returns: The model's result, unwrapped from GenericResult if applicable. """ - # Get validator from model (lazily created if needed after unpickling) - validator = self._model._get_context_validator() - - # Validate and coerce kwargs via TypeAdapter - validated = validator.validate_python(kwargs) - - # Wrap in FlowContext (single class, always) - ctx = FlowContext(**validated) + ctx = self._build_context(kwargs) # Call the model result = self._model(ctx) @@ -181,23 +212,41 @@ def unbound_inputs(self) -> Dict[str, Type]: """ all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) bound_fields = getattr(self._model, "_bound_fields", set()) + model_cls = self._model.__class__ # If explicit context_args was provided, use _context_schema - explicit_args = getattr(self._model.__class__, "__flow_model_explicit_context_args__", None) + explicit_args = getattr(model_cls, "__flow_model_explicit_context_args__", None) if explicit_args is not None: - return self._model._context_schema.copy() + context_schema = getattr(model_cls, "_context_schema", None) + return context_schema.copy() if context_schema is not None else {} - # Otherwise, unbound = all params - bound - return {name: typ for name, typ in all_param_types.items() 
if name not in bound_fields} + # Dynamic @Flow.model: unbound = all params - bound + if all_param_types: + return {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + + # Generic CallableModel: runtime inputs are the context schema. + context_cls = _concrete_context_type(self._model.context_type) + if context_cls is None or not hasattr(context_cls, "model_fields"): + return {} + return {name: info.annotation for name, info in context_cls.model_fields.items()} @property def bound_inputs(self) -> Dict[str, Any]: """Return the config values bound at construction time.""" bound_fields = getattr(self._model, "_bound_fields", set()) - result = {} + result: Dict[str, Any] = {} for name in bound_fields: if hasattr(self._model, name): result[name] = getattr(self._model, name) + if result: + return result + + # Generic CallableModel: configured model fields are the bound inputs. + model_fields = getattr(self._model.__class__, "model_fields", {}) + for name in model_fields: + if name == "meta": + continue + result[name] = getattr(self._model, name) return result def with_inputs(self, **transforms) -> "BoundModel": @@ -235,7 +284,7 @@ class BoundModel: of a previous transform). 
""" - def __init__(self, model: "_GeneratedFlowModelBase", input_transforms: Dict[str, Any]): + def __init__(self, model: CallableModel, input_transforms: Dict[str, Any]): self._model = model self._input_transforms = input_transforms @@ -253,6 +302,10 @@ def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" return self._model(self._transform_context(context)) + def __repr__(self) -> str: + transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) + return f"{self._model!r}.flow.with_inputs({transforms})" + @property def flow(self) -> "FlowAPI": """Access the flow API.""" @@ -267,9 +320,7 @@ def __init__(self, bound_model: BoundModel): super().__init__(bound_model._model) def compute(self, **kwargs) -> Any: - validator = self._model._get_context_validator() - validated = validator.validate_python(kwargs) - ctx = FlowContext(**validated) + ctx = self._build_context(kwargs) result = self._bound(ctx) # Call through BoundModel, not _model if isinstance(result, GenericResult): return result.value diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 61869f9..bd526b1 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -12,10 +12,22 @@ import cloudpickle import pytest -from ccflow import Flow, FlowAPI, FlowContext, GenericResult +from ccflow import CallableModel, ContextBase, Flow, FlowAPI, FlowContext, GenericResult from ccflow.context import DateRangeContext +class NumberContext(ContextBase): + x: int + + +class OffsetModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: NumberContext) -> GenericResult[int]: + return GenericResult(value=context.x + self.offset) + + class TestFlowContext: """Tests for the FlowContext universal carrier.""" @@ -152,6 +164,30 @@ def load_data(start_date: date, end_date: date, source: str = "db") -> GenericRe assert "start_date" not in bound 
assert "end_date" not in bound + def test_flow_compute_regular_callable_model(self): + """Regular CallableModels also expose .flow.compute().""" + + model = OffsetModel(offset=10) + result = model.flow.compute(x=5) + + assert result == 15 + + def test_flow_unbound_inputs_regular_callable_model(self): + """Regular CallableModels expose their context schema as unbound inputs.""" + + model = OffsetModel(offset=10) + unbound = model.flow.unbound_inputs + + assert unbound == {"x": int} + + def test_flow_bound_inputs_regular_callable_model(self): + """Regular CallableModels expose their configured fields as bound inputs.""" + + model = OffsetModel(offset=10) + bound = model.flow.bound_inputs + + assert bound["offset"] == 10 + class TestBoundModel: """Tests for BoundModel (created via .flow.with_inputs()).""" @@ -217,6 +253,27 @@ def compute(x: int) -> GenericResult[int]: bound = model.flow.with_inputs(x=42) assert isinstance(bound.flow, FlowAPI) + def test_bound_model_repr_looks_like_with_inputs_call(self): + """BoundModel repr should mirror the API users wrote.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x * 2) + + model = compute() + bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) + + assert repr(bound) == f"{model!r}.flow.with_inputs(x=)" + + def test_with_inputs_regular_callable_model(self): + """Regular CallableModels support .flow.with_inputs().""" + + model = OffsetModel(offset=1) + shifted = model.flow.with_inputs(x=lambda ctx: ctx.x * 2) + + result = shifted(NumberContext(x=5)) + assert result.value == 11 + class TestTypedDictValidation: """Tests for TypedDict-based context validation.""" diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index dca3fbe..76adbea 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -192,11 +192,45 @@ schema, creates a `FlowContext`, executes the model, and unwraps It is not the only 
execution path. Because the generated object is still a standard `CallableModel`, calling `model(context)` remains fully supported. +## FieldExtractor + +Accessing an unknown public attribute on a `@Flow.model` instance returns a +`FieldExtractor`. It is itself a `CallableModel` that runs the source model, +then extracts the named field from the result (via `getattr` or dict key +access). + +```python +from ccflow import ContextBase, Flow, GenericResult + + +class TrainingContext(ContextBase): + seed: int + + +@Flow.model +def prepare(context: TrainingContext) -> GenericResult[dict]: + s = context.seed + return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) + + +@Flow.model +def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: + return GenericResult(value=sum(X) + sum(y)) + + +prepared = prepare() +model = train(X=prepared.X_train, y=prepared.y_train) +``` + +Multiple extractors from the same source share the source model instance. If +caching is enabled the source is evaluated only once. + ## Lazy Inputs `Lazy[T]` marks a parameter as on-demand. Instead of eagerly resolving an upstream model, the generated model passes a zero-argument thunk. The thunk -caches its first result. +caches its first result. Lazy dependencies are excluded from the `__deps__` +graph, so they are not pre-evaluated by the evaluator infrastructure. ```python from ccflow import Flow, Lazy @@ -243,3 +277,27 @@ At call time it: That keeps transformed dependency wiring explicit without adding special annotation machinery to the core API. + +## Flow.call with `auto_context` + +Separately from `@Flow.model`, `Flow.call(auto_context=...)` provides a similar +convenience for class-based `CallableModel`s. Instead of defining a separate +`ContextBase` subclass, the decorator generates one from the function's +keyword-only parameters. 
+ +```python +from ccflow import CallableModel, Flow, GenericResult + + +class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") +``` + +Passing a `ContextBase` subclass (e.g., `auto_context=DateContext`) makes the +generated context inherit from that class, so it remains compatible with +infrastructure that expects the parent type. + +The generated class is registered via `create_ccflow_model` for serialization +support. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 5f1b502..4e34fc3 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -30,6 +30,10 @@ standard `CallableModel` class with proper `__call__` and `__deps__` methods, so it still uses the normal ccflow framework for evaluation, caching, serialization, and registry loading. +If a `@Flow.model` function returns a plain value instead of a `ResultBase` +subclass, the generated model automatically wraps it in `GenericResult` at +runtime so it still behaves like a normal `CallableModel`. + You can execute a generated model in two equivalent ways: - call it directly with a context object: `model(ctx)` @@ -38,7 +42,16 @@ You can execute a generated model in two equivalent ways: `.flow.compute(...)` is mainly an explicit, ergonomic way to mark the deferred execution point. -**Basic Example:** +#### Context Modes + +There are three ways to define how a `@Flow.model` function receives its +runtime context. + +**Mode 1 — Explicit context parameter:** + +The function takes a `context` parameter (or `_` if unused) annotated with a +`ContextBase` subclass. This is the most direct mode and behaves like a +traditional `CallableModel.__call__`. 
```python from datetime import date @@ -46,21 +59,56 @@ from ccflow import Flow, GenericResult, DateRangeContext @Flow.model def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: - # Your data loading logic here return GenericResult(value=query_db(source, context.start_date, context.end_date)) -# Create model instance loader = load_data(source="my_database") -# Execute with context ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) ``` -**Default `@Flow.model` Style:** +**Mode 2 — Unpacked context with `context_args`:** -Use this when you want the simplest API and do not need to declare a formal -context shape up front. +Instead of receiving a context object, you list which parameters should come +from the context at runtime. The remaining parameters are model configuration. + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +@Flow.model(context_args=["start_date", "end_date"]) +def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + +loader = load_data(source="my_database") + +# For well-known field sets the decorator matches a built-in context type +assert loader.context_type == DateRangeContext + +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) +``` + +For well-known shapes such as `start_date` / `end_date` with `date` +annotations, the generated model uses a concrete built-in context type like +`DateRangeContext`. Otherwise it falls back to `FlowContext`, a universal +frozen carrier for the validated fields. 
+ +Use `context_args` when some parameters are semantically "the execution +context" and you want that split to stay stable and explicit: + +- the runtime context should be stable across instances +- the split between config and runtime inputs matters semantically +- the model is naturally "run over a context" such as date windows, + partitions, or scenarios +- you want the generated model to match a built-in context type like + `DateRangeContext` when possible + +**Mode 3 — Default deferred style (no explicit context):** + +When there is no `context` parameter and no `context_args`, all parameters are +potential configuration or runtime inputs. Parameters provided at construction +are bound (configuration); everything else comes from the context at runtime. ```python from ccflow import Flow @@ -80,12 +128,7 @@ doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) assert doubled_y.flow.compute(y=5) == 20 ``` -In this mode: - -- bound parameters are model configuration -- unbound parameters are runtime inputs for that model instance - -**Composing Dependencies with Normal Parameters:** +#### Composing Dependencies Any non-context parameter can be bound either to a literal value or to another `CallableModel`. If you pass an upstream model, `@Flow.model` evaluates it with @@ -136,30 +179,19 @@ computed = growth.flow.compute( assert direct == computed ``` -This pattern is the main story for transformed dependencies. `@Flow.model` -still produces an ordinary `CallableModel`; `.flow.compute(...)` is just a -clearer way to say "supply the runtime inputs here." - -**Why `context_args` Exists:** +#### Deferred Execution Helpers -Without `context_args`, runtime inputs are inferred from whichever parameters -are still unbound on a particular model instance. That is flexible and -ergonomic. 
+**`.flow.compute(**kwargs)`** validates the keyword arguments against the +generated context schema, wraps them in a `FlowContext`, calls the model, and +unwraps `GenericResult.value` if present. -Use `context_args` when some parameters are semantically "the execution -context" and you want that split to stay stable and explicit: - -- the runtime context should be stable across instances -- the split between config and runtime inputs matters semantically -- the model is naturally "run over a context" such as date windows, - partitions, or scenarios -- you want the generated model to match a built-in context type like - `DateRangeContext` when possible - -**Deferred Execution Helpers:** +**`.flow.with_inputs(**transforms)`** returns a `BoundModel` that applies +context transforms before delegating to the underlying model. Each transform +is either a static value or a `(ctx) -> value` callable. Transforms are local +to the wrapped model — upstream models never see them. ```python -from ccflow import Flow +from ccflow import Flow, FlowContext @Flow.model def add(x: int, y: int) -> int: @@ -170,22 +202,105 @@ assert model.flow.compute(y=5) == 15 shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) assert shifted.flow.compute(y=5) == 20 + +# You can also call with a context object directly +ctx = FlowContext(y=5) +assert model(ctx).value == 15 +assert shifted(ctx).value == 20 ``` -If you already have a real context object, you can call the model directly -instead: +#### Field Extraction + +Accessing an unknown attribute on a `@Flow.model` instance returns a +`FieldExtractor` — a `CallableModel` that runs the source model and extracts +the named field from its result. This makes it easy to wire individual output +fields into downstream models. 
```python -from ccflow import FlowContext +from ccflow import ContextBase, Flow, GenericResult -ctx = FlowContext(y=5) -assert model(ctx).value == 15 -assert shifted(ctx).value == 20 +class TrainingContext(ContextBase): + seed: int + +@Flow.model +def prepare(context: TrainingContext) -> GenericResult[dict]: + s = context.seed + return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) + +@Flow.model +def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: + return GenericResult(value=sum(X) + sum(y)) + +prepared = prepare() +model = train(X=prepared.X_train, y=prepared.y_train) +result = model(TrainingContext(seed=5)) +# X_train = [5, 10], y_train = [50] -> 15 + 50 = 65 +assert result.value == 65 +``` + +Multiple extractors from the same source share the source model instance, so +with caching enabled the source is only evaluated once. + +#### Lazy Dependencies with `Lazy[T]` + +Mark a parameter with `Lazy[T]` to defer its evaluation. Instead of eagerly +resolving the upstream model, the generated model passes a zero-argument thunk +that evaluates on first call and caches the result. The thunk unwraps +`GenericResult` automatically, so `T` should be the inner value type. 
+ +```python +from ccflow import ContextBase, Flow, GenericResult, Lazy + +class SimpleContext(ContextBase): + value: int + +@Flow.model +def fast_path(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + +@Flow.model +def slow_path(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 100) + +@Flow.model +def smart_selector( + context: SimpleContext, + fast: int, # Eagerly resolved and unwrapped + slow: Lazy[int], # Deferred — receives a thunk returning unwrapped int + threshold: int = 10, +) -> GenericResult[int]: + if fast > threshold: + return GenericResult(value=fast) + return GenericResult(value=slow()) # Evaluated only when called + +model = smart_selector( + fast=fast_path(), + slow=slow_path(), + threshold=10, +) ``` -**Hydra/YAML Configuration:** +`Lazy` dependencies are excluded from the model's `__deps__` graph, so they +are not pre-evaluated by the evaluator infrastructure. -`Flow.model` decorated functions work seamlessly with Hydra configuration and the `ModelRegistry`: +#### Decorator Options + +`@Flow.model(...)` accepts the same options as `Flow.call` to control execution +behavior: + +- `cacheable` — enable caching of results +- `volatile` — mark as volatile (always re-execute) +- `log_level` — logging verbosity +- `validate_result` — validate return type +- `verbose` — verbose logging output +- `evaluator` — custom evaluator + +When not explicitly set, these inherit from any active `FlowOptionsOverride`. + +#### Hydra / YAML Configuration + +`@Flow.model` decorated functions work seamlessly with Hydra configuration and +the `ModelRegistry`: ```yaml # config.yaml @@ -202,36 +317,50 @@ aggregated: transformed: transformed # Reference by registry name ``` -When loaded via `ModelRegistry.load_config()`, references by name ensure the same object instance is shared across all consumers. 
+```python +from ccflow import ModelRegistry -**Auto-Unpacked Context with `context_args`:** +registry = ModelRegistry.root() +registry.load_config_from_path("config.yaml") -Instead of taking an explicit `context` parameter, you can use `context_args` to automatically unpack context fields as function parameters. This is useful when you want cleaner function signatures: +# References by name ensure the same object instance is shared +model = registry["aggregated"] +``` -```python -from datetime import date -from ccflow import Flow, GenericResult, DateRangeContext +### Flow.call with `auto_context` -# Instead of: def load_data(context: DateRangeContext, source: str) -# Use context_args to unpack the context fields directly: -@Flow.model(context_args=["start_date", "end_date"]) -def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: - return GenericResult(value=f"{source}:{start_date} to {end_date}") +For class-based `CallableModel`s, `Flow.call(auto_context=...)` provides a +similar convenience. Instead of defining a separate `ContextBase` subclass, the +decorator generates one from the function's keyword-only parameters. -# The decorator matches common built-in context types when possible -loader = load_data(source="my_database") -assert loader.context_type == DateRangeContext +```python +from ccflow import CallableModel, Flow, GenericResult -# Execute with context as usual -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" +class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + +model = MyModel() +result = model(x=42, y="hello") +assert result.value == "42-hello" ``` -The `context_args` parameter specifies which function parameters should be extracted from the context. 
Those fields are validated through a runtime schema built from the parameter annotations. For well-known shapes such as `start_date` / `end_date`, the generated model uses a concrete built-in context type like `DateRangeContext`; otherwise it uses `FlowContext`, a universal frozen carrier for the validated fields. +You can also pass a parent context class so the generated context inherits +from it: -If a `@Flow.model` function returns a plain value instead of a `ResultBase` -subclass, the generated model automatically wraps that value in `GenericResult` -at runtime so it still behaves like a normal `CallableModel`. +```python +from datetime import date +from ccflow import CallableModel, DateContext, Flow, GenericResult + +class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> GenericResult: + return GenericResult(value=date.day + extra) +``` + +The generated context class is a proper `ContextBase` subclass, so it works +with all existing evaluator and registry infrastructure. ## Model Registry diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index f2616dc..27e31bb 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -1,5 +1,16 @@ #!/usr/bin/env python -"""Example demonstrating the core Flow.model workflow.""" +"""Canonical Flow.model example. + +This is the main `@Flow.model` story: + +1. define workflow steps as plain Python functions, +2. wire them together by passing upstream models as normal arguments, +3. use a small Python builder for reusable composition, +4. execute either as a normal CallableModel or via `.flow.compute(...)`. 
+ +Run with: + python examples/flow_model_example.py +""" from datetime import date, timedelta @@ -8,57 +19,84 @@ @Flow.model(context_args=["start_date", "end_date"]) def load_revenue(start_date: date, end_date: date, region: str) -> float: - """Pretend to load revenue for a date window.""" + """Return synthetic revenue for one reporting window.""" days = (end_date - start_date).days + 1 - baseline = 1000.0 if region == "us" else 800.0 - return baseline + days * 10.0 + region_base = {"us": 1000.0, "eu": 850.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) @Flow.model(context_args=["start_date", "end_date"]) -def summarize_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: - """Compare the current and previous windows.""" +def revenue_change( + start_date: date, + end_date: date, + current: float, + previous: float, + label: str, + days_back: int, +) -> dict: + """Compare the current window against a shifted previous window.""" + previous_start = start_date - timedelta(days=days_back) + previous_end = end_date - timedelta(days=days_back) growth_pct = round((current - previous) / previous * 100, 2) return { - "start_date": start_date, - "end_date": end_date, + "comparison": label, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", "current": current, "previous": previous, "growth_pct": growth_pct, } -def main(): - print("=" * 60) - print("Flow.model Example") - print("=" * 60) +def shifted_window(model, *, days_back: int): + """Reuse one upstream model with a shifted runtime window.""" + return model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=days_back), + end_date=lambda ctx: ctx.end_date - timedelta(days=days_back), + ) + - current_window = load_revenue(region="us") - previous_window = current_window.flow.with_inputs( - 
start_date=lambda ctx: ctx.start_date - timedelta(days=30), - end_date=lambda ctx: ctx.end_date - timedelta(days=30), +def build_week_over_week_pipeline(region: str): + """Build one reusable pipeline from plain Flow.model functions.""" + current = load_revenue(region=region) + previous = shifted_window(current, days_back=7) + return revenue_change( + current=current, + previous=previous, + label="week_over_week", + days_back=7, ) - growth = summarize_growth(current=current_window, previous=previous_window) +def main() -> None: + print("=" * 64) + print("Flow.model Example") + print("=" * 64) + + pipeline = build_week_over_week_pipeline(region="us") ctx = DateRangeContext( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), ) - print("\n[1] Execute as a normal CallableModel:") - print(growth(ctx).value) - - print("\n[2] Execute via .flow.compute(...):") - print( - growth.flow.compute( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), - ) + direct = pipeline(ctx).value + computed = pipeline.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, ) - print("\n[3] Inspect bound and unbound inputs:") - print(" bound_inputs:", growth.flow.bound_inputs) - print(" unbound_inputs:", growth.flow.unbound_inputs) + print("\nPipeline wired from plain functions:") + print(" current input:", pipeline.current) + print(" previous input:", pipeline.previous) + + print("\nDirect call and .flow.compute(...) 
are equivalent:") + print(f" direct == computed: {direct == computed}") + + print("\nResult:") + for key, value in computed.items(): + print(f" {key}: {value}") if __name__ == "__main__": From 9c7bc7debb94728ca80c8b5cebcbf99645b32e89 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 17:30:01 -0400 Subject: [PATCH 12/17] Small bug fixes for @Flow.model Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 62 +++++++++++++++++++++----- ccflow/tests/test_flow_model.py | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 12 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index e2496ea..9d426fa 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -98,6 +98,36 @@ def _transform_repr(transform: Any) -> str: return repr(transform) +def _is_model_dependency(value: Any) -> bool: + return isinstance(value, (CallableModel, BoundModel)) + + +def _resolve_registry_candidate(value: str) -> Any: + from .base import BaseModel as _BM + + try: + candidate = _BM.model_validate(value) + except Exception: + return None + return candidate if isinstance(candidate, _BM) else None + + +def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: + if _is_model_dependency(candidate): + return True + try: + TypeAdapter(expected_type).validate_python(candidate) + except Exception: + return False + return True + + +def _make_field_extractor(source: Any, name: str) -> "FieldExtractor": + if name.startswith("_"): + raise AttributeError(f"'{type(source).__name__}' has no attribute '{name}'") + return FieldExtractor(source=source, field_name=name) + + def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" @@ -144,16 +174,18 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, return from .base import ModelRegistry as _MR - from .callable import CallableModel as _CM for field_name, 
validator in validators.items(): if field_name not in kwargs: continue value = kwargs[field_name] - if value is None or isinstance(value, (_CM, BoundModel)): + if value is None or _is_model_dependency(value): continue if isinstance(value, str) and value in _MR.root(): - continue + candidate = _resolve_registry_candidate(value) + expected_type = validatable_types[field_name] + if candidate is not None and _registry_candidate_allowed(expected_type, candidate): + continue try: validator.validate_python(value) except Exception: @@ -302,6 +334,9 @@ def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" return self._model(self._transform_context(context)) + def __getattr__(self, name): + return _make_field_extractor(self, name) + def __repr__(self) -> str: transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) return f"{self._model!r}.flow.with_inputs({transforms})" @@ -311,6 +346,10 @@ def flow(self) -> "FlowAPI": """Access the flow API.""" return _BoundFlowAPI(self) + @property + def context_type(self) -> Type[ContextBase]: + return self._model.context_type + class _BoundFlowAPI(FlowAPI): """FlowAPI that delegates to a BoundModel, honoring transforms.""" @@ -349,9 +388,7 @@ def __getattr__(self, name): raise AttributeError(name) return super_getattr(name) except AttributeError: - if name.startswith("_"): - raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") from None - return FieldExtractor(source=self, field_name=name) + return _make_field_extractor(self, name) class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): @@ -374,8 +411,6 @@ def _resolve_registry_refs(cls, values, info): if not isinstance(values, dict): return values - from .base import BaseModel as _BM - param_types = getattr(cls, "__flow_model_all_param_types__", {}) resolved = dict(values) for field_name, expected_type in param_types.items(): @@ -386,11 +421,10 @@ def 
_resolve_registry_refs(cls, values, info): continue if expected_type is str: continue - try: - candidate = _BM.model_validate(value) - except Exception: + candidate = _resolve_registry_candidate(value) + if candidate is None: continue - if isinstance(candidate, _BM): + if _registry_candidate_allowed(expected_type, candidate): resolved[field_name] = candidate return resolved @@ -936,6 +970,8 @@ class FieldExtractor(_FieldExtractorMixin, CallableModel): @property def context_type(self): + if isinstance(self.source, BoundModel): + return self.source.context_type if isinstance(self.source, _CallableModel): return self.source.context_type return ContextBase @@ -956,6 +992,8 @@ def __call__(self, context: ContextBase) -> GenericResult: @Flow.deps def __deps__(self, context: ContextBase) -> GraphDepList: + if isinstance(self.source, BoundModel): + return [(self.source._model, [self.source._transform_context(context)])] if isinstance(self.source, _CallableModel): return [(self.source, [context])] return [] diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 458569d..018052a 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -6,6 +6,7 @@ from ray.cloudpickle import dumps as rcpdumps, loads as rcploads from ccflow import ( + BaseModel, CallableModel, ContextBase, DateRangeContext, @@ -709,6 +710,51 @@ def typed_config(context: SimpleContext, n: int = 10, name: str = "x") -> Generi result = model(SimpleContext(value=1)) self.assertEqual(result.value, "test:42") + def test_config_validation_rejects_registry_alias_for_incompatible_type(self): + """Registry aliases should not silently bypass scalar type validation.""" + + class DummyConfig(BaseModel): + x: int = 1 + + registry = ModelRegistry.root() + registry.clear() + try: + registry.add("dummy_config", DummyConfig()) + + @Flow.model + def typed_config(context: SimpleContext, n: int = 10) -> GenericResult[int]: + return GenericResult(value=n) + + with 
self.assertRaises(TypeError) as cm: + typed_config(n="dummy_config") + + self.assertIn("n", str(cm.exception)) + finally: + registry.clear() + + def test_config_validation_accepts_registry_alias_for_callable_dependency(self): + """Registry aliases still work for CallableModel dependencies.""" + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + @Flow.model + def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: + return GenericResult(value=data + 1) + + registry.add("source_model", source()) + model = consumer(data="source_model") + + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 11) + finally: + registry.clear() + # ============================================================================= # BoundModel Tests @@ -1689,6 +1735,38 @@ def prepare(context: SimpleContext) -> GenericResult[dict]: self.assertIs(deps[0][0], model) self.assertEqual(deps[0][1], [ctx]) + def test_field_extraction_from_bound_model(self): + """Field extraction should still work after .flow.with_inputs().""" + + @Flow.model + def prepare(x: int) -> GenericResult[dict]: + return GenericResult(value={"doubled": x * 2}) + + bound = prepare().flow.with_inputs(x=lambda ctx: ctx.x + 1) + extractor = bound.doubled + + result = extractor.flow.compute(x=5) + self.assertEqual(result, 12) + + def test_field_extraction_deps_from_bound_model(self): + """Bound-model extractors should preserve transformed dependency contexts.""" + from ccflow import FlowContext + + @Flow.model + def prepare(x: int) -> GenericResult[dict]: + return GenericResult(value={"doubled": x * 2}) + + model = prepare() + bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) + extractor = bound.doubled + + ctx = FlowContext(x=5) + deps = extractor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], model) + 
self.assertEqual(deps[0][1][0].x, 6) + if __name__ == "__main__": import unittest From 5569ed09d63f2036994ab895a39679413a7399cb Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 18 Mar 2026 14:20:35 -0400 Subject: [PATCH 13/17] Temp progress for cleaning up Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 16 +- ccflow/context.py | 29 +- ccflow/flow_model.py | 480 ++++++++++++++++---------- ccflow/tests/test_callable.py | 34 +- ccflow/tests/test_flow_context.py | 43 ++- ccflow/tests/test_flow_model.py | 538 ++++++++++++++++-------------- docs/design/flow_model_design.md | 57 +--- docs/wiki/Key-Features.md | 65 +--- examples/flow_model_example.py | 8 +- 9 files changed, 720 insertions(+), 550 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 185c229..54f4a9d 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -536,6 +536,7 @@ def model(*args, **kwargs): Args: context_args: List of parameter names that come from context (for unpacked mode) + context_type: Explicit ContextBase subclass to use with context_args mode cacheable: Enable caching of results (default: False) volatile: Mark as volatile (default: False) log_level: Logging verbosity (default: logging.DEBUG) @@ -555,7 +556,7 @@ def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.Data Mode 2 - Unpacked context_args: Context fields are unpacked into function parameters. 
- @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: return GenericResult(value=query_db(source, start_date, end_date)) @@ -807,6 +808,12 @@ def flow(self) -> "FlowAPI": return FlowAPI(self) + def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: + """Wire this model into a downstream generated ``@Flow.model`` stage.""" + from .flow_model import pipe_model + + return pipe_model(self, stage, param=param, **bindings) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. @@ -960,6 +967,9 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: sig = signature(func) base_class = parent or ContextBase + if sig.return_annotation is inspect.Signature.empty: + raise TypeError(f"Function {_callable_qualname(func)} must have a return type annotation when auto_context=True") + # Validate parent fields are in function signature if parent is not None: parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) @@ -973,6 +983,10 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: for name, param in sig.parameters.items(): if name == "self": continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError(f"Function {_callable_qualname(func)} does not support {param.kind.description} when auto_context=True") + if param.annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation when auto_context=True") default = ... 
if param.default is inspect.Parameter.empty else param.default fields[name] = (param.annotation, default) diff --git a/ccflow/context.py b/ccflow/context.py index 0d00d2e..ae69e22 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,7 +1,8 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" +from collections.abc import Mapping from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated from pydantic import ConfigDict, field_validator, model_validator @@ -106,6 +107,32 @@ class FlowContext(ContextBase): model_config = ConfigDict(extra="allow", frozen=True) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FlowContext): + return False + return self.model_dump(mode="python") == other.model_dump(mode="python") + + def __hash__(self) -> int: + return hash(_freeze_for_hash(self.model_dump(mode="python"))) + + +def _freeze_for_hash(value: Any) -> Hashable: + if isinstance(value, Mapping): + return tuple(sorted((key, _freeze_for_hash(item)) for key, item in value.items())) + if isinstance(value, (list, tuple)): + return tuple(_freeze_for_hash(item) for item in value) + if isinstance(value, (set, frozenset)): + return frozenset(_freeze_for_hash(item) for item in value) + if hasattr(value, "model_dump"): + return (type(value), _freeze_for_hash(value.model_dump(mode="python"))) + try: + hash(value) + except TypeError as exc: + if hasattr(value, "__dict__"): + return (type(value), _freeze_for_hash(vars(value))) + raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}") from exc + return value + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 9d426fa..da9d1bb 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,22 +12,32 @@ import inspect 
import logging from functools import wraps -from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, get_type_hints from pydantic import Field, TypeAdapter, model_validator from typing_extensions import TypedDict from .base import ContextBase, ResultBase -from .callable import CallableModel, Flow, GraphDepList, _CallableModel +from .callable import CallableModel, Flow, GraphDepList from .context import FlowContext from .local_persistence import register_ccflow_import_path from .result import GenericResult -__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy", "FieldExtractor") +__all__ = ("FlowAPI", "BoundModel", "Lazy") _AnyCallable = Callable[..., Any] +class _DeferredInput: + """Sentinel for dynamic @Flow.model inputs left for runtime context.""" + + def __repr__(self) -> str: + return "" + + +_DEFERRED_INPUT = _DeferredInput() + + def _callable_name(func: _AnyCallable) -> str: return getattr(func, "__name__", type(func).__name__) @@ -102,6 +112,34 @@ def _is_model_dependency(value: Any) -> bool: return isinstance(value, (CallableModel, BoundModel)) +def _bound_field_names(model: Any) -> set[str]: + fields_set = getattr(model, "model_fields_set", None) + if fields_set is not None: + return set(fields_set) + return set(getattr(model, "_bound_fields", set())) + + +def _has_deferred_input(value: Any) -> bool: + return isinstance(value, _DeferredInput) + + +def _deferred_input_factory() -> _DeferredInput: + return _DEFERRED_INPUT + + +def _effective_bound_field_names(model: Any) -> set[str]: + fields = _bound_field_names(model) + defaults = getattr(model.__class__, "__flow_model_default_param_names__", set()) + return fields | set(defaults) + + +def _runtime_input_names(model: Any) -> set[str]: + all_param_names = set(getattr(model.__class__, 
"__flow_model_all_param_types__", {})) + if not all_param_names: + return set() + return all_param_names - _effective_bound_field_names(model) + + def _resolve_registry_candidate(value: str) -> Any: from .base import BaseModel as _BM @@ -122,18 +160,12 @@ def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: return True -def _make_field_extractor(source: Any, name: str) -> "FieldExtractor": - if name.startswith("_"): - raise AttributeError(f"'{type(source).__name__}' has no attribute '{name}'") - return FieldExtractor(source=source, field_name=name) - - -def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: +def _build_typed_dict_adapter(name: str, schema: Dict[str, Type], *, total: bool = True) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" if not schema: return TypeAdapter(dict) - return TypeAdapter(TypedDict(name, schema)) + return TypeAdapter(TypedDict(name, schema, total=total)) def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: @@ -193,6 +225,129 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") +def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: + if isinstance(stage, BoundModel): + model = stage._model + else: + model = stage + if isinstance(model, _GeneratedFlowModelBase): + return model + return None + + +def _generated_model_class(stage: Any) -> Optional[type["_GeneratedFlowModelBase"]]: + model = _generated_model_instance(stage) + if model is not None: + return type(model) + + generated_model = getattr(stage, "_generated_model", None) + if isinstance(generated_model, type) and issubclass(generated_model, _GeneratedFlowModelBase): + return generated_model + return None + + +def _describe_pipe_stage(stage: Any) -> str: + if isinstance(stage, BoundModel): + return repr(stage) + if 
isinstance(stage, _GeneratedFlowModelBase): + return repr(stage) + if callable(stage): + return _callable_name(stage) + return repr(stage) + + +def _generated_model_explicit_kwargs(model: "_GeneratedFlowModelBase") -> Dict[str, Any]: + return cast(Dict[str, Any], model.model_dump(mode="python", exclude_unset=True)) + + +def _infer_pipe_param( + stage_name: str, + param_names: List[str], + default_param_names: set[str], + occupied_names: set[str], +) -> str: + required_candidates = [name for name in param_names if name not in occupied_names and name not in default_param_names] + if len(required_candidates) == 1: + return required_candidates[0] + if len(required_candidates) > 1: + candidates = ", ".join(required_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." + ) + + fallback_candidates = [name for name in param_names if name not in occupied_names] + if len(fallback_candidates) == 1: + return fallback_candidates[0] + if len(fallback_candidates) > 1: + candidates = ", ".join(fallback_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." + ) + + raise TypeError(f"pipe() could not find an available target parameter for {stage_name}.") + + +def _resolve_pipe_param(source: Any, stage: Any, param: Optional[str], bindings: Dict[str, Any]) -> Tuple[str, type["_GeneratedFlowModelBase"]]: + del source # Source only matters when binding, not during target resolution. 
+ + generated_model_cls = _generated_model_class(stage) + if generated_model_cls is None: + raise TypeError("pipe() only supports downstream stages created by @Flow.model or bound versions of those stages.") + + stage_name = _describe_pipe_stage(stage) + all_param_types = getattr(generated_model_cls, "__flow_model_all_param_types__", {}) + if not all_param_types: + raise TypeError(f"pipe() could not determine bindable parameters for {stage_name}.") + + param_names = list(all_param_types.keys()) + default_param_names = set(getattr(generated_model_cls, "__flow_model_default_param_names__", set())) + + generated_model = _generated_model_instance(stage) + occupied_names = set(bindings) + if generated_model is not None: + occupied_names |= _bound_field_names(generated_model) + if isinstance(stage, BoundModel): + occupied_names |= set(stage._input_transforms) + + if param is not None: + if param not in all_param_types: + valid = ", ".join(param_names) + raise TypeError(f"pipe() target parameter '{param}' is not valid for {stage_name}. 
Available parameters: {valid}.") + if param in occupied_names: + raise TypeError(f"pipe() target parameter '{param}' is already bound for {stage_name}.") + return param, generated_model_cls + + return _infer_pipe_param(stage_name, param_names, default_param_names, occupied_names), generated_model_cls + + +def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: + """Wire ``source`` into a downstream generated ``@Flow.model`` stage.""" + + if not _is_model_dependency(source): + raise TypeError(f"pipe() source must be a CallableModel or BoundModel, got {type(source).__name__}.") + + target_param, generated_model_cls = _resolve_pipe_param(source, stage, param, bindings) + build_kwargs = dict(bindings) + build_kwargs[target_param] = source + + if isinstance(stage, BoundModel): + generated_model = _generated_model_instance(stage) + if generated_model is None: + raise TypeError("pipe() only supports downstream BoundModel stages created from @Flow.model.") + explicit_kwargs = _generated_model_explicit_kwargs(generated_model) + explicit_kwargs.update(build_kwargs) + rebound_model = generated_model_cls(**explicit_kwargs) + return BoundModel(model=rebound_model, input_transforms=dict(stage._input_transforms)) + + generated_model = _generated_model_instance(stage) + if generated_model is not None: + explicit_kwargs = _generated_model_explicit_kwargs(generated_model) + explicit_kwargs.update(build_kwargs) + return generated_model_cls(**explicit_kwargs) + + return stage(**build_kwargs) + + class FlowAPI: """API namespace for deferred computation operations. @@ -224,26 +379,18 @@ def compute(self, **kwargs) -> Any: **kwargs: Context arguments (e.g., start_date, end_date) Returns: - The model's result, unwrapped from GenericResult if applicable. + The model's result, using the same return contract as ``model(context)``. 
""" ctx = self._build_context(kwargs) - - # Call the model - result = self._model(ctx) - - # Unwrap GenericResult if present - if isinstance(result, GenericResult): - return result.value - return result + return self._model(ctx) @property def unbound_inputs(self) -> Dict[str, Type]: """Return the context schema (field name -> type). - In deferred mode, this is everything NOT provided at construction. + In deferred mode, this is everything that must still come from runtime context. """ all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) - bound_fields = getattr(self._model, "_bound_fields", set()) model_cls = self._model.__class__ # If explicit context_args was provided, use _context_schema @@ -252,9 +399,10 @@ def unbound_inputs(self) -> Dict[str, Type]: context_schema = getattr(model_cls, "_context_schema", None) return context_schema.copy() if context_schema is not None else {} - # Dynamic @Flow.model: unbound = all params - bound + # Dynamic @Flow.model: unbound = params with no explicit value and no declared default if all_param_types: - return {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + runtime_inputs = _runtime_input_names(self._model) + return {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} # Generic CallableModel: runtime inputs are the context schema. 
context_cls = _concrete_context_type(self._model.context_type) @@ -264,13 +412,15 @@ def unbound_inputs(self) -> Dict[str, Type]: @property def bound_inputs(self) -> Dict[str, Any]: - """Return the config values bound at construction time.""" - bound_fields = getattr(self._model, "_bound_fields", set()) + """Return the effective config values for this model.""" result: Dict[str, Any] = {} - for name in bound_fields: - if hasattr(self._model, name): - result[name] = getattr(self._model, name) - if result: + flow_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) + if flow_param_types: + for name in flow_param_types: + value = getattr(self._model, name, _DEFERRED_INPUT) + if _has_deferred_input(value): + continue + result[name] = value return result # Generic CallableModel: configured model fields are the bound inputs. @@ -320,22 +470,26 @@ def __init__(self, model: CallableModel, input_transforms: Dict[str, Any]): self._model = model self._input_transforms = input_transforms - def _transform_context(self, context: ContextBase) -> FlowContext: - """Return a FlowContext with this model's input transforms applied.""" + def _transform_context(self, context: ContextBase) -> ContextBase: + """Return this model's preferred context type with input transforms applied.""" ctx_dict = _context_values(context) for name, transform in self._input_transforms.items(): if callable(transform): ctx_dict[name] = transform(context) else: ctx_dict[name] = transform + context_type = _concrete_context_type(self._model.context_type) + if context_type is not None and context_type is not FlowContext: + return context_type.model_validate(ctx_dict) return FlowContext(**ctx_dict) def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" return self._model(self._transform_context(context)) - def __getattr__(self, name): - return _make_field_extractor(self, name) + def pipe(self, stage: Any, /, *, param: Optional[str] = None, 
**bindings: Any) -> Any: + """Wire this bound model into a downstream generated ``@Flow.model`` stage.""" + return pipe_model(self, stage, param=param, **bindings) def __repr__(self) -> str: transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) @@ -360,10 +514,7 @@ def __init__(self, bound_model: BoundModel): def compute(self, **kwargs) -> Any: ctx = self._build_context(kwargs) - result = self._bound(ctx) # Call through BoundModel, not _model - if isinstance(result, GenericResult): - return result.value - return result + return self._bound(ctx) # Call through BoundModel, not _model def with_inputs(self, **transforms) -> "BoundModel": """Chain transforms: merge new transforms with existing ones. @@ -374,24 +525,7 @@ def with_inputs(self, **transforms) -> "BoundModel": return BoundModel(model=self._bound._model, input_transforms=merged) -class _FieldExtractorMixin: - """Turn unknown public attributes into FieldExtractors. - - Real model attributes are still resolved by the normal pydantic/base-model - attribute path via ``super().__getattr__``. 
- """ - - def __getattr__(self, name): - try: - super_getattr = getattr(super(), "__getattr__", None) - if super_getattr is None: - raise AttributeError(name) - return super_getattr(name) - except AttributeError: - return _make_field_extractor(self, name) - - -class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): +class _GeneratedFlowModelBase(CallableModel): """Shared behavior for models generated by ``@Flow.model``.""" __flow_model_context_type__: ClassVar[Type[ContextBase]] = FlowContext @@ -400,10 +534,10 @@ class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): __flow_model_use_context_args__: ClassVar[bool] = True __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None __flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_default_param_names__: ClassVar[set[str]] = set() __flow_model_auto_wrap__: ClassVar[bool] = False _context_schema: ClassVar[Dict[str, Type]] = {} _context_td: ClassVar[Any | None] = None - _matched_context_type: ClassVar[Optional[Type[ContextBase]]] = None _cached_context_validator: ClassVar[TypeAdapter | None] = None @model_validator(mode="before") @@ -458,10 +592,14 @@ def _get_context_validator(self) -> TypeAdapter: if not hasattr(self, "_instance_context_validator"): all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) - bound_fields = getattr(self, "_bound_fields", set()) - unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} - object.__setattr__(self, "_instance_context_validator", _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema)) - return self._instance_context_validator + runtime_inputs = _runtime_input_names(self) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} + object.__setattr__( + self, + "_instance_context_validator", + _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema, total=False), + ) + return cast(TypeAdapter, 
getattr(self, "_instance_context_validator")) class Lazy: @@ -544,8 +682,8 @@ def model(self) -> "CallableModel": # noqa: F821 def _build_context_schema( - context_args: List[str], func: _AnyCallable, sig: inspect.Signature -) -> Tuple[Dict[str, Type], Any, Optional[Type[ContextBase]]]: + context_args: List[str], func: _AnyCallable, sig: inspect.Signature, resolved_hints: Dict[str, Any] +) -> Tuple[Dict[str, Type], Any]: """Build context schema from context_args parameter names. Instead of creating a dynamic ContextBase subclass, this builds: @@ -559,7 +697,7 @@ def _build_context_schema( sig: The function signature Returns: - Tuple of (schema_dict, TypedDict type, optional matched ContextBase type) + Tuple of (schema_dict, TypedDict type) """ # Build schema dict from parameter annotations schema = {} @@ -567,28 +705,54 @@ def _build_context_schema( if name not in sig.parameters: raise ValueError(f"context_arg '{name}' not found in function parameters") param = sig.parameters[name] - if param.annotation is inspect.Parameter.empty: + annotation = resolved_hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty: raise ValueError(f"context_arg '{name}' must have a type annotation") - schema[name] = param.annotation + schema[name] = annotation - # Try to match common context types for compatibility - matched_context_type = None - from .context import DateRangeContext + # Create TypedDict for validation (not registered anywhere!) + context_td = TypedDict(f"{_callable_name(func)}Inputs", schema) - if set(context_args) == {"start_date", "end_date"}: - from datetime import date + return schema, context_td - if all( - sig.parameters[name].annotation in (date, "date") - or (isinstance(sig.parameters[name].annotation, type) and sig.parameters[name].annotation is date) - for name in context_args - ): - matched_context_type = DateRangeContext - # Create TypedDict for validation (not registered anywhere!) 
-    context_td = TypedDict(f"{_callable_name(func)}Inputs", schema)
+def _validate_context_type_override(context_type: Any, context_args: List[str], func_schema: Dict[str, Type]) -> Type[ContextBase]:
+    """Validate an explicit ``context_type`` override for ``context_args`` mode."""
 
-    return schema, context_td, matched_context_type
+    if not isinstance(context_type, type) or not issubclass(context_type, ContextBase):
+        raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}")
+
+    context_fields = getattr(context_type, "model_fields", {})
+    missing = sorted(name for name in context_args if name not in context_fields)
+    if missing:
+        raise TypeError(f"context_type {context_type.__name__} must define fields for context_args: {', '.join(missing)}")
+
+    required_extra_fields = sorted(
+        name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in context_args and info.is_required()
+    )
+    if required_extra_fields:
+        raise TypeError(f"context_type {context_type.__name__} has required fields not listed in context_args: {', '.join(required_extra_fields)}")
+
+    # Raise (not warn) when the function's annotation for a context_arg doesn't
+    # match the context_type's field annotation. A mismatch means the function
+    # declares one type but will silently receive whatever Pydantic coerces to.
+ for name in context_args: + func_ann = func_schema.get(name) + ctx_field = context_fields.get(name) + if func_ann is None or ctx_field is None: + continue + ctx_ann = ctx_field.annotation + if func_ann is ctx_ann: + continue + # Both are concrete types — check subclass relationship + if isinstance(func_ann, type) and isinstance(ctx_ann, type): + if not (issubclass(func_ann, ctx_ann) or issubclass(ctx_ann, func_ann)): + raise TypeError( + f"context_arg '{name}': function annotates {func_ann.__name__} " + f"but context_type {context_type.__name__} declares {ctx_ann.__name__}" + ) + + return context_type _UNSET = object() @@ -599,6 +763,7 @@ def flow_model( *, # Context handling context_args: Optional[List[str]] = None, + context_type: Optional[Type[ContextBase]] = None, # Flow.call options (passed to generated __call__) # Default to _UNSET so FlowOptionsOverride can control these globally. # Only explicitly user-provided values are passed to Flow.call. @@ -618,6 +783,7 @@ def flow_model( Args: func: The function to decorate context_args: List of parameter names that come from context (for unpacked mode) + context_type: Explicit ContextBase subclass to use with ``context_args`` mode. cacheable: Enable caching of results (default: unset, inherits from FlowOptionsOverride) volatile: Mark as volatile (always re-execute) (default: unset, inherits from FlowOptionsOverride) log_level: Logging verbosity (default: unset, inherits from FlowOptionsOverride) @@ -635,7 +801,7 @@ def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.Data 2. Unpacked context_args: Context fields are unpacked into function parameters. - @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: ... 
@@ -644,15 +810,13 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ """ def decorator(fn: _AnyCallable) -> _AnyCallable: - import typing as _typing - sig = inspect.signature(fn) params = sig.parameters # Resolve string annotations (PEP 563 / from __future__ import annotations) # into real type objects. include_extras=True preserves Annotated metadata. try: - _resolved_hints = _typing.get_type_hints(fn, include_extras=True) + _resolved_hints = get_type_hints(fn, include_extras=True) except Exception: _resolved_hints = {} @@ -670,15 +834,19 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: internal_return_type = return_type # Determine context mode + context_schema_early: Dict[str, Type] = {} + context_td_early = None if "context" in params or "_" in params: # Mode 1: Explicit context parameter (named 'context' or '_' for unused) + if context_type is not None: + raise TypeError("context_type=... is only supported when using context_args=[...]") context_param_name = "context" if "context" in params else "_" context_param = params[context_param_name] context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) if context_annotation is inspect.Parameter.empty: raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' parameter must have a type annotation") - context_type = context_annotation - if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)): + resolved_context_type = context_annotation + if not (isinstance(resolved_context_type, type) and issubclass(resolved_context_type, ContextBase)): raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} use_context_args = False @@ -686,20 +854,22 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: elif context_args is not None: # Mode 2: Explicit 
context_args - specified params come from context context_param_name = "context" - # Build context schema early to determine matched_context_type - context_schema_early, _, matched_type = _build_context_schema(context_args, fn, sig) - # Use matched type if available (e.g., DateRangeContext), else FlowContext - context_type = matched_type if matched_type is not None else FlowContext + context_schema_early, context_td_early = _build_context_schema(context_args, fn, sig, _resolved_hints) + explicit_context_type = ( + _validate_context_type_override(context_type, context_args, context_schema_early) if context_type is not None else None + ) + resolved_context_type = explicit_context_type if explicit_context_type is not None else FlowContext # Exclude context_args from model fields model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} use_context_args = True explicit_context_args = context_args else: - # Mode 3: Dynamic deferred mode - ALL params are potential context or config - # What's provided at construction = config/deps - # What's NOT provided = comes from context at runtime + # Mode 3: Dynamic deferred mode - every param can be configured on the model, + # but only params without Python defaults remain runtime inputs when omitted. + if context_type is not None: + raise TypeError("context_type=... is only supported when using context_args=[...]") context_param_name = "context" - context_type = FlowContext + resolved_context_type = FlowContext model_field_params = {name: param for name, param in params.items() if name != "self"} use_context_args = True explicit_context_args = None # Dynamic - determined at construction @@ -707,9 +877,10 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: # Analyze parameters to find lazy fields and regular fields. 
model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) lazy_fields: set[str] = set() # Names of parameters marked with Lazy[T] + default_param_names: set[str] = set() - # In dynamic deferred mode (no explicit context_args), all fields are optional - # because values not provided at construction come from context at runtime + # In dynamic deferred mode (no explicit context_args), fields without Python defaults + # are internally represented by a deferred sentinel until runtime context supplies them. dynamic_deferred_mode = use_context_args and explicit_context_args is None for name, param in model_field_params.items(): @@ -724,10 +895,11 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: lazy_fields.add(name) if param.default is not inspect.Parameter.empty: + default_param_names.add(name) default = param.default elif dynamic_deferred_mode: - # In dynamic mode, params without defaults are optional (come from context) - default = None + # In dynamic mode, params without defaults remain deferred to runtime context. + default = Field(default_factory=_deferred_input_factory, exclude_if=_has_deferred_input) else: # In explicit mode, params without defaults are required default = ... @@ -786,17 +958,32 @@ def _resolve_field(name, value): value = getattr(self, name) fn_kwargs[name] = _resolve_field(name, value) else: - # Mode 3: Dynamic deferred mode - unbound from context, bound from self - bound_fields = getattr(self, "_bound_fields", set()) + # Mode 3: Dynamic deferred mode - explicit values or Python defaults from self, + # otherwise values come from runtime context. + explicit_fields = _bound_field_names(self) + missing_fields = [] for name in all_param_names: - if name in bound_fields: - # Bound at construction - get from self + value = getattr(self, name, _DEFERRED_INPUT) + if name in explicit_fields or name in default_param_names: + # Explicitly provided or implicitly bound via Python default. 
value = getattr(self, name) fn_kwargs[name] = _resolve_field(name, value) - else: - # Unbound - get from context - fn_kwargs[name] = getattr(context, name) + continue + + if _has_deferred_input(value): + value = getattr(context, name, _UNSET) + if value is _UNSET: + missing_fields.append(name) + continue + fn_kwargs[name] = value + + if missing_fields: + missing = ", ".join(sorted(missing_fields)) + raise TypeError( + f"Missing runtime input(s) for {_callable_name(fn)}: {missing}. " + "Provide them in the call context or bind them at construction time." + ) raw_result = fn(**fn_kwargs) if auto_wrap_result: @@ -807,7 +994,7 @@ def _resolve_field(name, value): cast(Any, __call__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), ], return_annotation=internal_return_type, ) @@ -849,7 +1036,7 @@ def __deps__(self, context) -> GraphDepList: cast(Any, __deps__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), ], return_annotation=GraphDepList, ) @@ -884,34 +1071,32 @@ def __deps__(self, context) -> GraphDepList: GeneratedModel = cast(type[_GeneratedFlowModelBase], type(f"_{_callable_name(fn)}_Model", (_GeneratedFlowModelBase,), namespace)) # Set class-level attributes after class creation (to avoid pydantic processing) - GeneratedModel.__flow_model_context_type__ = context_type + GeneratedModel.__flow_model_context_type__ = resolved_context_type GeneratedModel.__flow_model_return_type__ = internal_return_type setattr(GeneratedModel, 
"__flow_model_func__", fn) GeneratedModel.__flow_model_use_context_args__ = use_context_args GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + GeneratedModel.__flow_model_default_param_names__ = default_param_names GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result - # Build context_schema and matched_context_type + # Build context_schema context_schema: Dict[str, Type] = {} context_td = None - matched_context_type: Optional[Type[ContextBase]] = None if explicit_context_args is not None: # Explicit context_args provided - use early-computed schema - # (matched_context_type was already used to set context_type above) - context_schema, context_td, matched_context_type = _build_context_schema(explicit_context_args, fn, sig) + context_schema, context_td = context_schema_early, context_td_early elif not use_context_args: # Explicit context mode - schema comes from the context type's fields - if hasattr(context_type, "model_fields"): - context_schema = {name: info.annotation for name, info in context_type.model_fields.items()} + if hasattr(resolved_context_type, "model_fields"): + context_schema = {name: info.annotation for name, info in resolved_context_type.model_fields.items()} # For dynamic mode (is_dynamic_mode), _context_schema remains empty - # and schema is built dynamically from _bound_fields at runtime + # and schema is built dynamically from the instance's unresolved runtime inputs. # Store context schema for TypedDict-based validation (picklable!) 
GeneratedModel._context_schema = context_schema GeneratedModel._context_td = context_td - GeneratedModel._matched_context_type = matched_context_type # Validator is created lazily to survive pickling GeneratedModel._cached_context_validator = None @@ -927,12 +1112,7 @@ def __deps__(self, context) -> GraphDepList: @wraps(fn) def factory(**kwargs) -> _GeneratedFlowModelBase: _validate_config_kwargs(kwargs, _validatable_types, _config_validators) - - instance = GeneratedModel(**kwargs) - # Track which fields were explicitly provided at construction - # These are "bound" - everything else comes from context at runtime - object.__setattr__(instance, "_bound_fields", set(kwargs.keys())) - return instance + return GeneratedModel(**kwargs) # Preserve useful attributes on factory cast(Any, factory)._generated_model = GeneratedModel @@ -944,59 +1124,3 @@ def factory(**kwargs) -> _GeneratedFlowModelBase: if func is not None: return decorator(func) return decorator - - -# ============================================================================= -# FieldExtractor — structured output field access -# ============================================================================= - - -class FieldExtractor(_FieldExtractorMixin, CallableModel): - """Extracts a named field from a source model's result. - - Created automatically by accessing an unknown attribute on a @Flow.model - instance (e.g., ``prepared.X_train``). The extractor is itself a - CallableModel, so it can be wired as a dependency to downstream models. - - When evaluated, it runs the source model and returns - ``GenericResult(value=getattr(source_result, field_name))``. - - Multiple extractors from the same source share the source model instance. - If caching is enabled on the evaluator, the source is evaluated only once. 
- """ - - source: Any # The source CallableModel - field_name: str # The attribute to extract - - @property - def context_type(self): - if isinstance(self.source, BoundModel): - return self.source.context_type - if isinstance(self.source, _CallableModel): - return self.source.context_type - return ContextBase - - @property - def result_type(self): - return GenericResult - - @Flow.call - def __call__(self, context: ContextBase) -> GenericResult: - result = self.source(context) - if isinstance(result, GenericResult): - result = result.value - # Support both attribute access and dict key access - if isinstance(result, dict): - return GenericResult(value=result[self.field_name]) - return GenericResult(value=getattr(result, self.field_name)) - - @Flow.deps - def __deps__(self, context: ContextBase) -> GraphDepList: - if isinstance(self.source, BoundModel): - return [(self.source._model, [self.source._transform_context(context)])] - if isinstance(self.source, _CallableModel): - return [(self.source, [context])] - return [] - - -register_ccflow_import_path(FieldExtractor) diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 29f4524..6d8f53e 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -1024,4 +1024,36 @@ def bad_func(self, *, x: int) -> GenericResult: error_msg = str(cm.exception) self.assertIn("auto_context must be False, True, or a ContextBase subclass", error_msg) - self.assertIn("invalid", error_msg) + + def test_auto_context_rejects_var_args(self): + """auto_context should reject *args early with a clear error.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *args: int) -> GenericResult: + return GenericResult(value=len(args)) + + self.assertIn("variadic positional", str(cm.exception)) + + def test_auto_context_rejects_var_kwargs(self): + """auto_context should reject **kwargs early with a clear error.""" + 
with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, **kwargs: int) -> GenericResult: + return GenericResult(value=len(kwargs)) + + self.assertIn("variadic keyword", str(cm.exception)) + + def test_auto_context_requires_return_annotation(self): + """auto_context should reject missing return annotations immediately.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int): + return GenericResult(value=value) + + self.assertIn("must have a return type annotation", str(cm.exception)) diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index bd526b1..718f8de 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -72,6 +72,20 @@ def test_flow_context_model_dump(self): assert dumped["start_date"] == date(2024, 1, 1) assert dumped["value"] == 42 + def test_flow_context_value_semantics_include_extra_fields(self): + """Equality should reflect the actual extra payload.""" + assert FlowContext(x=1) == FlowContext(x=1) + assert FlowContext(x=1) != FlowContext(x=2) + assert FlowContext(x=1) != FlowContext(y=1) + + def test_flow_context_hash_uses_extra_fields(self): + """Distinct extra payloads should remain distinct in hashed collections.""" + first = FlowContext(values=[1, 2], label="a") + second = FlowContext(values=[1, 3], label="a") + third = FlowContext(values=[1, 2], label="b") + + assert len({first, second, third}) == 3 + def test_flow_context_pickle(self): """FlowContext pickles cleanly.""" ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) @@ -102,9 +116,9 @@ def load_data(start_date: date, end_date: date, source: str = "db") -> GenericRe model = load_data(source="api") result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - assert result["start"] == date(2024, 1, 
1) - assert result["end"] == date(2024, 1, 31) - assert result["source"] == "api" + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + assert result.value["source"] == "api" def test_flow_compute_type_coercion(self): """FlowAPI.compute() coerces types via TypeAdapter.""" @@ -117,8 +131,8 @@ def load_data(start_date: date, end_date: date) -> GenericResult[dict]: # Pass strings - should be coerced to dates result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") - assert result["start"] == date(2024, 1, 1) - assert result["end"] == date(2024, 1, 31) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) def test_flow_compute_validation_error(self): """FlowAPI.compute() raises on missing required args.""" @@ -170,7 +184,7 @@ def test_flow_compute_regular_callable_model(self): model = OffsetModel(offset=10) result = model.flow.compute(x=5) - assert result == 15 + assert result.value == 15 def test_flow_unbound_inputs_regular_callable_model(self): """Regular CallableModels expose their context schema as unbound inputs.""" @@ -305,15 +319,14 @@ def compute(x: int) -> GenericResult[int]: assert validator is not None assert model.__class__._cached_context_validator is validator - def test_matched_context_type(self): - """DateRangeContext pattern is matched for compatibility.""" + def test_explicit_context_type_override(self): + """context_type can opt into an existing ContextBase subclass.""" - @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_data(start_date: date, end_date: date) -> GenericResult[dict]: return GenericResult(value={}) model = load_data() - # Should match DateRangeContext assert model.context_type == DateRangeContext @@ -340,7 +353,7 @@ def compute(x: int, y: int, multiplier: int = 2) -> GenericResult[int]: # Should work after unpickling result = 
unpickled.flow.compute(x=1, y=2) - assert result == 9 # (1 + 2) * 3 + assert result.value == 9 # (1 + 2) * 3 def test_model_cloudpickle_simple(self): """Simple model cloudpickle test.""" @@ -355,7 +368,7 @@ def double(value: int) -> GenericResult[int]: unpickled = cloudpickle.loads(pickled) result = unpickled.flow.compute(value=21) - assert result == 42 + assert result.value == 42 def test_validator_recreated_after_cloudpickle(self): """TypeAdapter validator is recreated after cloudpickling.""" @@ -375,7 +388,7 @@ def compute(x: int) -> GenericResult[int]: # Validator should still work (may be lazily recreated) result = unpickled.flow.compute(x=42) - assert result == 42 + assert result.value == 42 def test_flow_context_pickle_standard(self): """FlowContext works with standard pickle.""" @@ -431,8 +444,8 @@ def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[di model = load_data(source="api") result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - assert result["start"] == date(2024, 1, 1) - assert result["end"] == date(2024, 1, 31) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) class TestLazy: diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 018052a..a2e788e 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -11,6 +11,7 @@ ContextBase, DateRangeContext, Flow, + FlowContext, FlowOptionsOverride, GenericResult, Lazy, @@ -160,13 +161,13 @@ class TestFlowModelContextArgs(TestCase): def test_context_args_basic(self): """Test basic context_args usage.""" - @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def date_range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date} to {end_date}") loader = date_range_loader(source="db") - # 
Should use DateRangeContext + # Explicit context_type keeps compatibility with existing contexts. self.assertEqual(loader.context_type, DateRangeContext) ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) @@ -182,6 +183,9 @@ def unpacked_model(x: int, y: str, multiplier: int = 1) -> GenericResult[str]: model = unpacked_model(multiplier=2) + # Default context_args mode uses FlowContext unless overridden explicitly. + self.assertEqual(model.context_type, FlowContext) + # Create context with generated type ctx_type = model.context_type ctx = ctx_type(x=5, y="test") @@ -490,6 +494,40 @@ def serializable_model(context: SimpleContext, value: int = 42) -> GenericResult self.assertEqual(dumped["value"], 100) self.assertIn("type_", dumped) + def test_serialization_roundtrip_preserves_bound_inputs(self): + """Round-tripping should preserve which inputs were bound at construction.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) + + self.assertEqual(dumped["x"], 10) + self.assertNotIn("y", dumped) + self.assertEqual(restored.flow.bound_inputs, {"x": 10}) + self.assertEqual(restored.flow.unbound_inputs, {"y": int}) + self.assertEqual(restored.flow.compute(y=5).value, 15) + + def test_serialization_roundtrip_preserves_defaults_and_deferred_inputs(self): + """Default-valued params should serialize normally without binding runtime-only inputs.""" + + @Flow.model + def load(start_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}" + + model = load() + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) + + self.assertEqual(dumped["source"], "warehouse") + self.assertNotIn("start_date", dumped) + self.assertEqual(restored.flow.bound_inputs, {"source": "warehouse"}) + self.assertEqual(restored.flow.unbound_inputs, {"start_date": str}) + 
self.assertEqual(restored.flow.compute(start_date="2024-01-01").value, "warehouse:2024-01-01") + def test_pickle_roundtrip(self): """Test cloudpickle serialization of generated models.""" @@ -568,7 +606,7 @@ def test_auto_wrap_unwrap_as_dependency(self): Auto-wrapped models have result_type=GenericResult (unparameterized). When used as an auto-detected dep, the framework resolves - the GenericResult and unwraps .value for the downstream function. + the GenericResult to its inner value for the downstream function. """ @Flow.model @@ -622,24 +660,36 @@ def dynamic_model(value: int, multiplier: int) -> GenericResult[int]: result = model(ctx) self.assertEqual(result.value, 30) # 10 * 3 + def test_dynamic_deferred_mode_missing_runtime_inputs_is_clear(self): + """Missing deferred inputs should fail at the framework boundary.""" + + @Flow.model + def dynamic_model(value: int, multiplier: int) -> int: + return value * multiplier + + model = dynamic_model() + + with self.assertRaises(TypeError) as cm: + model.flow.compute() + + self.assertIn("Missing runtime input(s) for dynamic_model: multiplier, value", str(cm.exception)) + def test_all_defaults_is_valid(self): - """Test that all-defaults function is valid (everything can be pre-bound).""" + """All-default functions should treat those defaults as bound config.""" from ccflow import FlowContext @Flow.model def all_defaults(value: int = 1, other: str = "x") -> GenericResult[str]: return GenericResult(value=f"{value}-{other}") - # No args provided -> everything comes from defaults or context model = all_defaults() - # All params are unbound (not provided at construction) - self.assertEqual(model.flow.unbound_inputs, {"value": int, "other": str}) + self.assertEqual(model.flow.bound_inputs, {"value": 1, "other": "x"}) + self.assertEqual(model.flow.unbound_inputs, {}) - # Call with context - context values override defaults ctx = FlowContext(value=5, other="y") result = model(ctx) - self.assertEqual(result.value, "5-y") + 
self.assertEqual(result.value, "1-x") def test_invalid_context_arg(self): """Test error when context_args refers to non-existent parameter.""" @@ -661,6 +711,30 @@ def untyped_context_arg(x) -> GenericResult[int]: self.assertIn("type annotation", str(cm.exception)) + def test_context_type_requires_context_args_mode(self): + """context_type is only valid alongside context_args.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_type=DateRangeContext) + def dynamic_model(value: int) -> GenericResult[int]: + return GenericResult(value=value) + + self.assertIn("context_args", str(cm.exception)) + + def test_context_type_must_cover_context_args(self): + """context_type must expose all named context_args fields.""" + + class StartOnlyContext(ContextBase): + start_date: date + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["start_date", "end_date"], context_type=StartOnlyContext) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + self.assertIn("end_date", str(cm.exception)) + # ============================================================================= # Validation Tests @@ -755,6 +829,32 @@ def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: finally: registry.clear() + def test_context_type_annotation_mismatch_raises(self): + """context_type validation should reject incompatible field annotations.""" + + class StringIdContext(ContextBase): + item_id: str + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["item_id"], context_type=StringIdContext) + def load(item_id: int) -> int: + return item_id + + self.assertIn("item_id", str(cm.exception)) + self.assertIn("int", str(cm.exception)) + self.assertIn("str", str(cm.exception)) + + def test_context_type_compatible_annotations_accepted(self): + """context_type validation should accept matching or subclass annotations.""" + + # Exact match should work + 
@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) + def load_exact(start_date: date, end_date: date) -> str: + return f"{start_date}" + + self.assertIsNotNone(load_exact) + # ============================================================================= # BoundModel Tests @@ -780,7 +880,7 @@ def my_model(x: int, y: int) -> GenericResult[int]: result = bound.flow.compute(y=5) # y transform: 5 * 2 = 10, x is bound to 10 # model: 10 + 10 = 20 - self.assertEqual(result, 20) + self.assertEqual(result.value, 20) def test_bound_model_flow_compute_static_transform(self): """Test BoundModel.flow.compute() with static value transform.""" @@ -795,7 +895,19 @@ def my_model(x: int, y: int) -> GenericResult[int]: result = bound.flow.compute(y=999) # y should be overridden by transform # y is statically bound to 3, x=7 # 7 * 3 = 21 - self.assertEqual(result, 21) + self.assertEqual(result.value, 21) + + def test_bound_model_cloudpickle_with_lambda_transform(self): + """BoundModel with lambda transforms should survive cloudpickle round-trip.""" + + @Flow.model + def my_model(x: int, y: int) -> int: + return x + y + + bound = my_model(x=10).flow.with_inputs(y=lambda ctx: ctx.y * 2) + restored = rcploads(rcpdumps(bound, protocol=5)) + + self.assertEqual(restored.flow.compute(y=6).value, 22) def test_bound_model_as_dependency(self): """Test that BoundModel can be passed as a dependency to another model.""" @@ -817,7 +929,21 @@ def consumer(data: GenericResult[int]) -> GenericResult[int]: # x transform: 5 * 2 = 10 # source: 10 * 10 = 100 # consumer: 100 + 1 = 101 - self.assertEqual(result, 101) + self.assertEqual(result.value, 101) + + def test_flow_compute_with_upstream_callable_model_dependency(self): + """flow.compute() should resolve upstream generated-model dependencies.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: GenericResult[int], offset: int = 1) -> 
int: + return data + offset + + model = consumer(data=source(), offset=3) + self.assertEqual(model.flow.compute(x=5).value, 53) def test_bound_model_chained_with_inputs(self): """Test that chaining with_inputs merges transforms correctly.""" @@ -836,7 +962,7 @@ def my_model(x: int, y: int, z: int) -> int: # y transform: 10 * 3 = 30 # z from context: 1 # 10 + 30 + 1 = 41 - self.assertEqual(result, 41) + self.assertEqual(result.value, 41) def test_bound_model_chained_with_inputs_override(self): """Test that chaining with_inputs allows overriding transforms.""" @@ -851,7 +977,7 @@ def my_model(x: int) -> int: # Second transform should override the first for 'x' result = bound2.flow.compute(x=5) - self.assertEqual(result, 50) # 5 * 10, not 5 * 2 + self.assertEqual(result.value, 50) # 5 * 10, not 5 * 2 def test_bound_model_with_default_args(self): """with_inputs works when the model has parameters with default values.""" @@ -867,24 +993,24 @@ def load(start_date: str, end_date: str, source: str = "warehouse") -> str: lookback = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) result = lookback.flow.compute(start_date="2024-01-01", end_date="2024-06-30") - self.assertEqual(result, "prod_db:shifted_2024-01-01-2024-06-30") + self.assertEqual(result.value, "prod_db:shifted_2024-01-01-2024-06-30") - def test_bound_model_with_default_arg_unbound(self): - """with_inputs works when defaulted parameter is left unbound (comes from context).""" + def test_bound_model_with_default_arg_uses_default(self): + """with_inputs should preserve omitted Python defaults as bound config.""" @Flow.model def load(start_date: str, source: str = "warehouse") -> str: return f"{source}:{start_date}" - # Don't bind 'source' — it keeps its default in the model, - # but in dynamic deferred mode, unbound params come from context model = load() - # Transform start_date; source comes from context (overriding the default) bound = model.flow.with_inputs(start_date=lambda ctx: 
"shifted_" + ctx.start_date) - result = bound.flow.compute(start_date="2024-01-01", source="s3_bucket") - self.assertEqual(result, "s3_bucket:shifted_2024-01-01") + self.assertEqual(model.flow.bound_inputs, {"source": "warehouse"}) + self.assertEqual(model.flow.unbound_inputs, {"start_date": str}) + + result = bound.flow.compute(start_date="2024-01-01") + self.assertEqual(result.value, "warehouse:shifted_2024-01-01") def test_bound_model_default_arg_as_dependency(self): """BoundModel with default args works correctly as a dependency.""" @@ -905,7 +1031,7 @@ def consumer(data: int) -> int: # x transform: 3 * 10 = 30 # source: 30 * 5 (multiplier) = 150 # consumer: 150 + 1 = 151 - self.assertEqual(result, 151) + self.assertEqual(result.value, 151) def test_bound_model_as_lazy_dependency(self): """Test that BoundModel works as a Lazy dependency.""" @@ -927,7 +1053,7 @@ def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: model = consumer(data=5, slow=bound_src) result = model.flow.compute(x=7) # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 - self.assertEqual(result, 51) + self.assertEqual(result.value, 51) def test_bound_and_unbound_models_share_memory_cache(self): """Shifted and unshifted models should share one evaluator cache. 
@@ -959,6 +1085,135 @@ def source(context: SimpleContext) -> GenericResult[int]: self.assertEqual(call_counts["source"], 2) self.assertEqual(len(evaluator.cache), 2) + def test_transform_error_propagates(self): + """A buggy transform should raise, not silently fall back to FlowContext.""" + + @Flow.model + def load(context: DateRangeContext, source: str = "db") -> str: + return f"{source}:{context.start_date}" + + model = load() + # Transform has a typo — ctx.sart_date instead of ctx.start_date + bound = model.flow.with_inputs(start_date=lambda ctx: ctx.sart_date) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + with self.assertRaises(AttributeError): + bound(ctx) + + def test_transform_validation_error_propagates(self): + """If transforms produce invalid context data, the error should surface.""" + from pydantic import ValidationError + + @Flow.model + def load(context: DateRangeContext, source: str = "db") -> str: + return f"{source}:{context.start_date}" + + model = load() + # Transform returns a string where a date is expected + bound = model.flow.with_inputs(start_date="not-a-date") + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + # Pydantic validation should raise, not silently fall back to FlowContext + with self.assertRaises(ValidationError): + bound(ctx) + + +class TestFlowModelPipe(TestCase): + """Tests for the ``.pipe(..., param=...)`` convenience API.""" + + def test_pipe_infers_single_required_parameter(self): + """pipe() should infer the only required downstream parameter.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: int, offset: int = 1) -> int: + return data + offset + + pipeline = source().pipe(consumer, offset=3) + self.assertEqual(pipeline.flow.compute(x=5).value, 53) + + def test_pipe_infers_single_defaulted_parameter(self): + """pipe() should fall back to a single defaulted 
downstream parameter.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int = 0) -> int: + return data + 1 + + pipeline = source().pipe(consumer) + self.assertEqual(pipeline.flow.compute(x=5).value, 51) + + def test_pipe_param_disambiguates_multiple_parameters(self): + """param= should identify the downstream argument to bind.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def combine(left: int, right: int) -> int: + return left + right + + pipeline = source().pipe(combine, param="right", left=7) + self.assertEqual(pipeline.flow.compute(x=5).value, 57) + + def test_pipe_rejects_ambiguous_downstream_stage(self): + """pipe() should require param= when multiple targets are available.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def combine(left: int, right: int) -> int: + return left + right + + with self.assertRaisesRegex( + TypeError, + r"pipe\(\) could not infer a target parameter for combine; unbound candidates are: left, right", + ): + source().pipe(combine) + + def test_manual_callable_model_can_pipe_into_generated_stage(self): + """Hand-written CallableModels should be usable as pipe sources.""" + + class ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + @Flow.model + def consumer(data: int, multiplier: int) -> int: + return data * multiplier + + pipeline = ManualModel(offset=5).pipe(consumer, multiplier=2) + self.assertEqual(pipeline.flow.compute(value=10).value, 30) + + def test_bound_model_pipe_preserves_downstream_transforms(self): + """pipe() should keep downstream with_inputs transforms intact.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int, scale: int) -> int: + return data + scale + + shifted_source = source().flow.with_inputs(x=lambda ctx: 
ctx.scale + 1) + scaled_consumer = consumer().flow.with_inputs(scale=lambda ctx: ctx.scale * 3) + + pipeline = shifted_source.pipe(scaled_consumer) + self.assertEqual(pipeline.flow.compute(scale=2).value, 76) + # ============================================================================= # PEP 563 (from __future__ import annotations) Compatibility Tests @@ -1007,7 +1262,7 @@ def plain_model(value: int) -> int: model = plain_model() result = model.flow.compute(value=5) - self.assertEqual(result, 10) + self.assertEqual(result.value, 10) self.assertEqual(model.result_type, GenericResult) @@ -1159,7 +1414,7 @@ def hydra_consumer_model( # --- context_args fixtures for Hydra testing --- -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[dict]: """Loader using context_args with DateRangeContext.""" return GenericResult( @@ -1171,7 +1426,7 @@ def context_args_loader(start_date: date, end_date: date, source: str) -> Generi ) -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def context_args_processor( start_date: date, end_date: date, @@ -1539,235 +1794,6 @@ def consumer( self.assertEqual(result.value, 51) # 50 + 1 -# ============================================================================= -# FieldExtractor Tests (Structured Output Field Access) -# ============================================================================= - - -class TestFieldExtractor(TestCase): - """Tests for structured output field access (prepared.X_train pattern).""" - - def test_field_extraction_basic(self): - """Accessing unknown attr on @Flow.model instance returns FieldExtractor.""" - from ccflow.flow_model import FieldExtractor - - @Flow.model - def prepare(context: SimpleContext, factor: int = 2) -> GenericResult[dict]: - 
return GenericResult(value={"X_train": context.value * factor, "X_test": context.value}) - - model = prepare(factor=3) - extractor = model.X_train - - self.assertIsInstance(extractor, FieldExtractor) - self.assertIs(extractor.source, model) - self.assertEqual(extractor.field_name, "X_train") - - def test_field_extraction_evaluates_correctly(self): - """FieldExtractor runs source and extracts the named field.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"X_train": [1, 2, 3], "y_train": [4, 5, 6]}) - - model = prepare() - x_train = model.X_train - - result = x_train(SimpleContext(value=0)) - self.assertEqual(result.value, [1, 2, 3]) - - def test_field_extraction_as_dependency(self): - """FieldExtractor wired as a dep to a downstream model. - - Note: FieldExtractors are CallableModels, so they're auto-detected as deps - and auto-unwrapped (GenericResult.value). The downstream function receives - the raw extracted value, not a GenericResult wrapper. 
- """ - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - v = context.value - return GenericResult(value={"X_train": [v, v * 2], "y_train": [v * 10]}) - - @Flow.model - def train(context: SimpleContext, X: list, y: list) -> GenericResult[int]: - # X and y are auto-unwrapped to the raw list values - return GenericResult(value=sum(X) + sum(y)) - - prepared = prepare() - model = train(X=prepared.X_train, y=prepared.y_train) - - result = model(SimpleContext(value=5)) - # X_train = [5, 10], y_train = [50] - # sum(X) + sum(y) = 15 + 50 = 65 - self.assertEqual(result.value, 65) - - def test_field_extraction_multiple_from_same_source(self): - """Multiple extractors from same source share the source instance.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"a": 1, "b": 2, "c": 3}) - - model = prepare() - ext_a = model.a - ext_b = model.b - ext_c = model.c - - # All should reference the same source - self.assertIs(ext_a.source, model) - self.assertIs(ext_b.source, model) - self.assertIs(ext_c.source, model) - - # All should evaluate correctly - ctx = SimpleContext(value=0) - self.assertEqual(ext_a(ctx).value, 1) - self.assertEqual(ext_b(ctx).value, 2) - self.assertEqual(ext_c(ctx).value, 3) - - def test_field_extraction_nested(self): - """Chained extraction (result.a.b) creates nested FieldExtractors.""" - from ccflow.flow_model import FieldExtractor - - class Nested: - def __init__(self): - self.inner_val = 42 - - @Flow.model - def produce(context: SimpleContext) -> GenericResult: - return GenericResult(value={"nested": Nested()}) - - model = produce() - nested_extractor = model.nested - inner_extractor = nested_extractor.inner_val - - self.assertIsInstance(nested_extractor, FieldExtractor) - self.assertIsInstance(inner_extractor, FieldExtractor) - - result = inner_extractor(SimpleContext(value=0)) - self.assertEqual(result.value, 42) - - def 
test_field_extraction_context_type_inherited(self): - """FieldExtractor inherits context_type from source.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.x - - self.assertEqual(extractor.context_type, SimpleContext) - - def test_field_extraction_nonexistent_field_runtime_error(self): - """Non-existent field raises error at evaluation time, not construction. - - For dict results, raises KeyError. For object results, raises AttributeError. - """ - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.nonexistent # No error at construction - - # Error at evaluation time (KeyError for dicts, AttributeError for objects) - with self.assertRaises((KeyError, AttributeError)): - extractor(SimpleContext(value=0)) - - def test_field_extraction_pydantic_fields_not_intercepted(self): - """Accessing real pydantic fields returns the field value, NOT an extractor.""" - from ccflow.flow_model import FieldExtractor - - @Flow.model - def model_with_fields(context: SimpleContext, multiplier: int = 5) -> GenericResult[int]: - return GenericResult(value=context.value * multiplier) - - model = model_with_fields(multiplier=10) - - # 'multiplier' is a real pydantic field — should return the value, not a FieldExtractor - self.assertEqual(model.multiplier, 10) - self.assertNotIsInstance(model.multiplier, FieldExtractor) - - # 'meta' is inherited from CallableModel — should also not be intercepted - self.assertNotIsInstance(model.meta, FieldExtractor) - - def test_field_extraction_with_context_args(self): - """FieldExtractor works with context_args mode models.""" - from ccflow import FlowContext - - @Flow.model(context_args=["x"]) - def prepare(x: int) -> GenericResult[dict]: - return GenericResult(value={"doubled": x * 2, "tripled": x * 3}) - - model = prepare() - doubled = 
model.doubled - - result = doubled(FlowContext(x=5)) - self.assertEqual(result.value, 10) - - def test_field_extraction_has_flow_property(self): - """FieldExtractor has .flow property (inherits from CallableModel).""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.x - - self.assertTrue(hasattr(extractor, "flow")) - - def test_field_extraction_deps(self): - """FieldExtractor.__deps__ returns the source as a dependency.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.x - - ctx = SimpleContext(value=0) - deps = extractor.__deps__(ctx) - - self.assertEqual(len(deps), 1) - self.assertIs(deps[0][0], model) - self.assertEqual(deps[0][1], [ctx]) - - def test_field_extraction_from_bound_model(self): - """Field extraction should still work after .flow.with_inputs().""" - - @Flow.model - def prepare(x: int) -> GenericResult[dict]: - return GenericResult(value={"doubled": x * 2}) - - bound = prepare().flow.with_inputs(x=lambda ctx: ctx.x + 1) - extractor = bound.doubled - - result = extractor.flow.compute(x=5) - self.assertEqual(result, 12) - - def test_field_extraction_deps_from_bound_model(self): - """Bound-model extractors should preserve transformed dependency contexts.""" - from ccflow import FlowContext - - @Flow.model - def prepare(x: int) -> GenericResult[dict]: - return GenericResult(value={"doubled": x * 2}) - - model = prepare() - bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) - extractor = bound.doubled - - ctx = FlowContext(x=5) - deps = extractor.__deps__(ctx) - - self.assertEqual(len(deps), 1) - self.assertIs(deps[0][0], model) - self.assertEqual(deps[0][1][0].x, 6) - - if __name__ == "__main__": import unittest diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 76adbea..7b6ac9f 100644 --- 
a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -36,13 +36,13 @@ def add(x: int, y: int) -> int: model = add(x=10) # Explicit deferred entry point -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 # Standard CallableModel call path assert model(FlowContext(y=5)).value == 15 shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5) == 20 +assert shifted.flow.compute(y=5).value == 20 ``` In this mode: @@ -73,7 +73,7 @@ from datetime import date from ccflow import Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: return 125.0 ``` @@ -85,9 +85,8 @@ Use `context_args` when certain parameters are semantically the execution context and you want that split to be explicit and stable across model instances. -When the requested shape matches a built-in context like -`DateRangeContext(start_date, end_date)`, the generated model uses that type. -Otherwise it falls back to `FlowContext`. +By default, `context_args` models use `FlowContext`. If you want compatibility +with an existing context class, pass `context_type=...` explicitly. 
### Upstream Models as Normal Arguments @@ -130,13 +129,13 @@ from datetime import date, timedelta from ccflow import DateRangeContext, Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: days = (end_date - start_date).days + 1 return 1000.0 + days * 10.0 -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def revenue_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: return { "window_end": end_date, @@ -156,7 +155,7 @@ ctx = DateRangeContext( end_date=date(2024, 1, 31), ) -direct = model(ctx).value +direct = model(ctx) computed = model.flow.compute( start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), @@ -182,49 +181,17 @@ def add(x: int, y: int) -> int: model = add(x=10) -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 ``` It validates the supplied keyword arguments against the generated context -schema, creates a `FlowContext`, executes the model, and unwraps -`GenericResult.value` if needed. +schema, creates a `FlowContext`, and executes the model. + +It returns the same result object you would get from calling `model(context)`. It is not the only execution path. Because the generated object is still a standard `CallableModel`, calling `model(context)` remains fully supported. -## FieldExtractor - -Accessing an unknown public attribute on a `@Flow.model` instance returns a -`FieldExtractor`. It is itself a `CallableModel` that runs the source model, -then extracts the named field from the result (via `getattr` or dict key -access). 
- -```python -from ccflow import ContextBase, Flow, GenericResult - - -class TrainingContext(ContextBase): - seed: int - - -@Flow.model -def prepare(context: TrainingContext) -> GenericResult[dict]: - s = context.seed - return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) - - -@Flow.model -def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: - return GenericResult(value=sum(X) + sum(y)) - - -prepared = prepare() -model = train(X=prepared.X_train, y=prepared.y_train) -``` - -Multiple extractors from the same source share the source model instance. If -caching is enabled the source is evaluated only once. - ## Lazy Inputs `Lazy[T]` marks a parameter as on-demand. Instead of eagerly resolving an diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 4e34fc3..e0aac45 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -76,23 +76,22 @@ from the context at runtime. The remaining parameters are model configuration. from datetime import date from ccflow import Flow, GenericResult, DateRangeContext -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date} to {end_date}") loader = load_data(source="my_database") -# For well-known field sets the decorator matches a built-in context type +# Opt in explicitly when you want compatibility with an existing context type assert loader.context_type == DateRangeContext ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) ``` -For well-known shapes such as `start_date` / `end_date` with `date` -annotations, the generated model uses a concrete built-in context type like -`DateRangeContext`. Otherwise it falls back to `FlowContext`, a universal -frozen carrier for the validated fields. 
+By default, `context_args` models use `FlowContext`, a universal frozen carrier +for the validated fields. If you want the generated model to advertise and +accept an existing `ContextBase` subclass, pass `context_type=...` explicitly. Use `context_args` when some parameters are semantically "the execution context" and you want that split to stay stable and explicit: @@ -101,8 +100,8 @@ context" and you want that split to stay stable and explicit: - the split between config and runtime inputs matters semantically - the model is naturally "run over a context" such as date windows, partitions, or scenarios -- you want the generated model to match a built-in context type like - `DateRangeContext` when possible +- you want the generated model to accept a specific existing context type + such as `DateRangeContext` **Mode 3 — Default deferred style (no explicit context):** @@ -121,11 +120,11 @@ model = add(x=10) # `x` is bound when the model is created. # `y` is supplied later at execution time. -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 # `.flow.with_inputs(...)` rewrites runtime inputs for this call path. doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert doubled_y.flow.compute(y=5) == 20 +assert doubled_y.flow.compute(y=5).value == 20 ``` #### Composing Dependencies @@ -138,12 +137,12 @@ the current context and passes the resolved value into your function. 
from datetime import date, timedelta from ccflow import DateRangeContext, Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: days = (end_date - start_date).days + 1 return 1000.0 + days * 10.0 -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def revenue_growth( start_date: date, end_date: date, @@ -168,7 +167,7 @@ ctx = DateRangeContext( ) # Standard ccflow execution -direct = growth(ctx).value +direct = growth(ctx) # Equivalent explicit deferred entry point computed = growth.flow.compute( @@ -182,8 +181,8 @@ assert direct == computed #### Deferred Execution Helpers **`.flow.compute(**kwargs)`** validates the keyword arguments against the -generated context schema, wraps them in a `FlowContext`, calls the model, and -unwraps `GenericResult.value` if present. +generated context schema, wraps them in a `FlowContext`, and calls the model. +It returns the same result object you would get from `model(context)`. **`.flow.with_inputs(**transforms)`** returns a `BoundModel` that applies context transforms before delegating to the underlying model. 
Each transform @@ -198,10 +197,10 @@ def add(x: int, y: int) -> int: return x + y model = add(x=10) -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5) == 20 +assert shifted.flow.compute(y=5).value == 20 # You can also call with a context object directly ctx = FlowContext(y=5) @@ -209,38 +208,6 @@ assert model(ctx).value == 15 assert shifted(ctx).value == 20 ``` -#### Field Extraction - -Accessing an unknown attribute on a `@Flow.model` instance returns a -`FieldExtractor` — a `CallableModel` that runs the source model and extracts -the named field from its result. This makes it easy to wire individual output -fields into downstream models. - -```python -from ccflow import ContextBase, Flow, GenericResult - -class TrainingContext(ContextBase): - seed: int - -@Flow.model -def prepare(context: TrainingContext) -> GenericResult[dict]: - s = context.seed - return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) - -@Flow.model -def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: - return GenericResult(value=sum(X) + sum(y)) - -prepared = prepare() -model = train(X=prepared.X_train, y=prepared.y_train) -result = model(TrainingContext(seed=5)) -# X_train = [5, 10], y_train = [50] -> 15 + 50 = 65 -assert result.value == 65 -``` - -Multiple extractors from the same source share the source model instance, so -with caching enabled the source is only evaluated once. - #### Lazy Dependencies with `Lazy[T]` Mark a parameter with `Lazy[T]` to defer its evaluation. 
Instead of eagerly diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index 27e31bb..27d5d0e 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -17,7 +17,7 @@ from ccflow import DateRangeContext, Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: """Return synthetic revenue for one reporting window.""" days = (end_date - start_date).days + 1 @@ -27,7 +27,7 @@ def load_revenue(start_date: date, end_date: date, region: str) -> float: return round(region_base + days * 8.0 + trend, 2) -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def revenue_change( start_date: date, end_date: date, @@ -81,7 +81,7 @@ def main() -> None: end_date=date(2024, 3, 31), ) - direct = pipeline(ctx).value + direct = pipeline(ctx) computed = pipeline.flow.compute( start_date=ctx.start_date, end_date=ctx.end_date, @@ -95,7 +95,7 @@ def main() -> None: print(f" direct == computed: {direct == computed}") print("\nResult:") - for key, value in computed.items(): + for key, value in computed.value.items(): print(f" {key}: {value}") From 20612e27ed5ca40a12907cf15b1d7c77567ccd2d Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 19 Mar 2026 14:19:14 -0400 Subject: [PATCH 14/17] Further clean-up Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 259 ++++++++++++++++++--- ccflow/tests/test_flow_model.py | 389 +++++++++++++++++++++++++++++++- 2 files changed, 611 insertions(+), 37 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index da9d1bb..ab0b3f4 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -11,14 +11,15 @@ import inspect import logging +import threading from functools import wraps from typing import Annotated, Any, Callable, 
ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, get_type_hints -from pydantic import Field, TypeAdapter, model_validator -from typing_extensions import TypedDict +from pydantic import Field, PrivateAttr, TypeAdapter, model_serializer, model_validator +from typing_extensions import NotRequired, TypedDict from .base import ContextBase, ResultBase -from .callable import CallableModel, Flow, GraphDepList +from .callable import CallableModel, Flow, GraphDepList, WrapperModel from .context import FlowContext from .local_persistence import register_ccflow_import_path from .result import GenericResult @@ -109,7 +110,7 @@ def _transform_repr(transform: Any) -> str: def _is_model_dependency(value: Any) -> bool: - return isinstance(value, (CallableModel, BoundModel)) + return isinstance(value, CallableModel) def _bound_field_names(model: Any) -> set[str]: @@ -160,6 +161,22 @@ def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: return True +def _type_accepts_str(annotation) -> bool: + """Return True when ``str`` is a valid type for *annotation*. + + Handles ``str``, ``Union[str, ...]``, ``Optional[str]``, and + ``Annotated[str, ...]``. + """ + if annotation is str: + return True + origin = get_origin(annotation) + if origin is Annotated: + return _type_accepts_str(get_args(annotation)[0]) + if origin is Union: + return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None)) + return False + + def _build_typed_dict_adapter(name: str, schema: Dict[str, Type], *, total: bool = True) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" @@ -199,6 +216,17 @@ def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str return validatable_types, validators +def _coerce_context_value(name: str, value: Any, validators: Dict[str, TypeAdapter], validatable_types: Dict[str, Type]) -> Any: + """Validate/coerce a single context-sourced value. 
Returns coerced value or raises TypeError.""" + if name not in validators: + return value + try: + return validators[name].validate_python(value) + except Exception: + expected = validatable_types.get(name, "unknown") + raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") + + def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: """Validate plain config inputs while still allowing dependency objects.""" @@ -214,8 +242,10 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, if value is None or _is_model_dependency(value): continue if isinstance(value, str) and value in _MR.root(): - candidate = _resolve_registry_candidate(value) expected_type = validatable_types[field_name] + if _type_accepts_str(expected_type): + continue + candidate = _resolve_registry_candidate(value) if candidate is not None and _registry_candidate_allowed(expected_type, candidate): continue try: @@ -227,7 +257,7 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: if isinstance(stage, BoundModel): - model = stage._model + model = stage.model else: model = stage if isinstance(model, _GeneratedFlowModelBase): @@ -324,7 +354,7 @@ def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bin """Wire ``source`` into a downstream generated ``@Flow.model`` stage.""" if not _is_model_dependency(source): - raise TypeError(f"pipe() source must be a CallableModel or BoundModel, got {type(source).__name__}.") + raise TypeError(f"pipe() source must be a CallableModel, got {type(source).__name__}.") target_param, generated_model_cls = _resolve_pipe_param(source, stage, param, bindings) build_kwargs = dict(bindings) @@ -364,6 +394,8 @@ def _build_context(self, kwargs: Dict[str, Any]) -> ContextBase: if get_validator is 
not None: validator = get_validator() validated = validator.validate_python(kwargs) + if isinstance(validated, ContextBase): + return validated return FlowContext(**validated) validator = TypeAdapter(self._model.context_type) @@ -393,22 +425,27 @@ def unbound_inputs(self) -> Dict[str, Type]: all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) model_cls = self._model.__class__ - # If explicit context_args was provided, use _context_schema + # If explicit context_args was provided, use _context_schema minus + # fields that have function defaults (they aren't truly required). explicit_args = getattr(model_cls, "__flow_model_explicit_context_args__", None) if explicit_args is not None: context_schema = getattr(model_cls, "_context_schema", None) - return context_schema.copy() if context_schema is not None else {} + if context_schema is None: + return {} + ctx_arg_defaults = getattr(model_cls, "__flow_model_context_arg_defaults__", {}) + return {name: typ for name, typ in context_schema.items() if name not in ctx_arg_defaults} # Dynamic @Flow.model: unbound = params with no explicit value and no declared default if all_param_types: runtime_inputs = _runtime_input_names(self._model) return {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} - # Generic CallableModel: runtime inputs are the context schema. + # Generic CallableModel / Mode 1: runtime inputs are the required + # context fields (fields with defaults are not required). 
context_cls = _concrete_context_type(self._model.context_type) if context_cls is None or not hasattr(context_cls, "model_fields"): return {} - return {name: info.annotation for name, info in context_cls.model_fields.items()} + return {name: info.annotation for name, info in context_cls.model_fields.items() if info.is_required()} @property def bound_inputs(self) -> Dict[str, Any]: @@ -445,7 +482,25 @@ def with_inputs(self, **transforms) -> "BoundModel": return BoundModel(model=self._model, input_transforms=transforms) -class BoundModel: +_bound_model_restore = threading.local() + + +def _fingerprint_transforms(transforms: Dict[str, Any]) -> Dict[str, str]: + """Create a stable, hashable fingerprint of input transforms for cache key differentiation. + + Callable transforms are identified by their id() (unique per object), which is + stable within a process lifetime. Static values are repr'd directly. + """ + result = {} + for name, transform in sorted(transforms.items()): + if callable(transform): + result[name] = f"callable:{id(transform)}" + else: + result[name] = repr(transform) + return result + + +class BoundModel(WrapperModel): """A model with context transforms applied. Created by model.flow.with_inputs(). Applies transforms to context @@ -466,9 +521,44 @@ class BoundModel: of a previous transform). """ - def __init__(self, model: CallableModel, input_transforms: Dict[str, Any]): - self._model = model - self._input_transforms = input_transforms + _input_transforms: Dict[str, Any] = PrivateAttr(default_factory=dict) + + @model_validator(mode="wrap") + @classmethod + def _restore_serialized_transforms(cls, values, handler): + """Strip serialization-injected keys, restore static transforms, guarantee cleanup. + + Uses thread-local storage to pass static transforms to __init__ because + pydantic rejects unknown keys in the input dict. The wrap validator's + try/finally ensures the thread-local is always cleaned up, even if + validation fails before __init__ runs. 
+ """ + if isinstance(values, dict): + values = dict(values) # Don't mutate the caller's dict + values.pop("_input_transforms_token", None) + static = values.pop("_static_transforms", None) + else: + static = None + + if static is not None: + _bound_model_restore.pending = static + try: + return handler(values) + except Exception: + _bound_model_restore.pending = None + raise + + def __init__(self, *, model: CallableModel, input_transforms: Optional[Dict[str, Any]] = None, **kwargs): + super().__init__(model=model, **kwargs) + restore = getattr(_bound_model_restore, "pending", None) + if restore is not None: + _bound_model_restore.pending = None + if input_transforms is not None: + self._input_transforms = input_transforms + elif restore is not None: + self._input_transforms = restore + else: + self._input_transforms = {} def _transform_context(self, context: ContextBase) -> ContextBase: """Return this model's preferred context type with input transforms applied.""" @@ -478,14 +568,35 @@ def _transform_context(self, context: ContextBase) -> ContextBase: ctx_dict[name] = transform(context) else: ctx_dict[name] = transform - context_type = _concrete_context_type(self._model.context_type) + context_type = _concrete_context_type(self.model.context_type) if context_type is not None and context_type is not FlowContext: return context_type.model_validate(ctx_dict) return FlowContext(**ctx_dict) - def __call__(self, context: ContextBase) -> Any: + @Flow.call + def __call__(self, context: ContextBase) -> ResultBase: """Call the model with transformed context.""" - return self._model(self._transform_context(context)) + return self.model(self._transform_context(context)) + + @Flow.deps + def __deps__(self, context: ContextBase) -> GraphDepList: + """Declare the wrapped model as an upstream dependency with transformed context.""" + return [(self.model, [self._transform_context(context)])] + + @model_serializer(mode="wrap") + def _serialize_with_transforms(self, handler): + 
"""Include transforms in serialization for cache keys and faithful roundtrips. + + Static (non-callable) transforms are serialized in _static_transforms for + faithful restoration. A fingerprint token covers all transforms (including + callables) for cache key differentiation. + """ + data = handler(self) + static = {k: v for k, v in self._input_transforms.items() if not callable(v)} + if static: + data["_static_transforms"] = static + data["_input_transforms_token"] = _fingerprint_transforms(self._input_transforms) + return data def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: """Wire this bound model into a downstream generated ``@Flow.model`` stage.""" @@ -493,28 +604,24 @@ def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) - def __repr__(self) -> str: transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) - return f"{self._model!r}.flow.with_inputs({transforms})" + return f"{self.model!r}.flow.with_inputs({transforms})" @property def flow(self) -> "FlowAPI": """Access the flow API.""" return _BoundFlowAPI(self) - @property - def context_type(self) -> Type[ContextBase]: - return self._model.context_type - class _BoundFlowAPI(FlowAPI): """FlowAPI that delegates to a BoundModel, honoring transforms.""" def __init__(self, bound_model: BoundModel): self._bound = bound_model - super().__init__(bound_model._model) + super().__init__(bound_model.model) def compute(self, **kwargs) -> Any: ctx = self._build_context(kwargs) - return self._bound(ctx) # Call through BoundModel, not _model + return self._bound(ctx) # Call through BoundModel, not inner model def with_inputs(self, **transforms) -> "BoundModel": """Chain transforms: merge new transforms with existing ones. @@ -522,7 +629,7 @@ def with_inputs(self, **transforms) -> "BoundModel": New transforms override existing ones for the same key. 
""" merged = {**self._bound._input_transforms, **transforms} - return BoundModel(model=self._bound._model, input_transforms=merged) + return BoundModel(model=self._bound.model, input_transforms=merged) class _GeneratedFlowModelBase(CallableModel): @@ -535,7 +642,10 @@ class _GeneratedFlowModelBase(CallableModel): __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None __flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} __flow_model_default_param_names__: ClassVar[set[str]] = set() + __flow_model_context_arg_defaults__: ClassVar[Dict[str, Any]] = {} __flow_model_auto_wrap__: ClassVar[bool] = False + __flow_model_validatable_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_config_validators__: ClassVar[Dict[str, TypeAdapter]] = {} _context_schema: ClassVar[Dict[str, Type]] = {} _context_td: ClassVar[Any | None] = None _cached_context_validator: ClassVar[TypeAdapter | None] = None @@ -553,7 +663,7 @@ def _resolve_registry_refs(cls, values, info): value = resolved[field_name] if not isinstance(value, str): continue - if expected_type is str: + if _type_accepts_str(expected_type): continue candidate = _resolve_registry_candidate(value) if candidate is None: @@ -562,6 +672,30 @@ def _resolve_registry_refs(cls, values, info): resolved[field_name] = candidate return resolved + @model_validator(mode="after") + def _validate_field_types(self): + """Validate field values against their declared types. + + This catches type mismatches in the model_validate/deserialization path, + where fields are typed as Any and pydantic won't reject wrong types. 
+ """ + cls = self.__class__ + config_validators = getattr(cls, "__flow_model_config_validators__", {}) + validatable_types = getattr(cls, "__flow_model_validatable_types__", {}) + if not config_validators: + return self + + for field_name, validator in config_validators.items(): + value = getattr(self, field_name, _DEFERRED_INPUT) + if _has_deferred_input(value) or value is None or _is_model_dependency(value): + continue + try: + validator.validate_python(value) + except Exception: + expected_type = validatable_types[field_name] + raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") + return self + @property def context_type(self) -> Type[ContextBase]: return self.__class__.__flow_model_context_type__ @@ -582,7 +716,13 @@ def _get_context_validator(self) -> TypeAdapter: if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): if cls._cached_context_validator is None: - if cls._context_td is not None: + use_ctx_args = getattr(cls, "__flow_model_use_context_args__", True) + ctx_type = cls.__flow_model_context_type__ + if not use_ctx_args and isinstance(ctx_type, type) and issubclass(ctx_type, ContextBase) and ctx_type is not FlowContext: + # Mode 1 with concrete context type — use TypeAdapter(context_type) + # directly so defaults on the context type are respected. 
+ cls._cached_context_validator = TypeAdapter(ctx_type) + elif cls._context_td is not None: cls._cached_context_validator = TypeAdapter(cls._context_td) elif cls._context_schema: cls._cached_context_validator = _build_typed_dict_adapter(f"{cls.__name__}Inputs", cls._context_schema) @@ -701,6 +841,7 @@ def _build_context_schema( """ # Build schema dict from parameter annotations schema = {} + td_schema = {} for name in context_args: if name not in sig.parameters: raise ValueError(f"context_arg '{name}' not found in function parameters") @@ -709,14 +850,25 @@ def _build_context_schema( if annotation is inspect.Parameter.empty: raise ValueError(f"context_arg '{name}' must have a type annotation") schema[name] = annotation + # Use NotRequired in the TypedDict for params that have a default in the + # function signature, so compute() doesn't require them. + if param.default is not inspect.Parameter.empty: + td_schema[name] = NotRequired[annotation] + else: + td_schema[name] = annotation # Create TypedDict for validation (not registered anywhere!) 
- context_td = TypedDict(f"{_callable_name(func)}Inputs", schema) + context_td = TypedDict(f"{_callable_name(func)}Inputs", td_schema) return schema, context_td -def _validate_context_type_override(context_type: Any, context_args: List[str], func_schema: Dict[str, Type]) -> Type[ContextBase]: +def _validate_context_type_override( + context_type: Any, + context_args: List[str], + func_schema: Dict[str, Type], + func_defaults: set[str] = frozenset(), +) -> Type[ContextBase]: """Validate an explicit ``context_type`` override for ``context_args`` mode.""" if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): @@ -752,6 +904,14 @@ def _validate_context_type_override(context_type: Any, context_args: List[str], f"but context_type {context_type.__name__} declares {ctx_ann.__name__}" ) + # Reject if the function has a default for a context_arg but the + # context_type declares that field as required — this is contradictory. + for name in context_args: + if name in func_defaults: + ctx_field = context_fields.get(name) + if ctx_field is not None and ctx_field.is_required(): + raise TypeError(f"context_arg '{name}': function has a default but context_type {context_type.__name__} requires this field") + return context_type @@ -855,8 +1015,11 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: # Mode 2: Explicit context_args - specified params come from context context_param_name = "context" context_schema_early, context_td_early = _build_context_schema(context_args, fn, sig, _resolved_hints) + _func_defaults_set = {name for name in context_args if sig.parameters[name].default is not inspect.Parameter.empty} explicit_context_type = ( - _validate_context_type_override(context_type, context_args, context_schema_early) if context_type is not None else None + _validate_context_type_override(context_type, context_args, context_schema_early, _func_defaults_set) + if context_type is not None + else None ) resolved_context_type = explicit_context_type if 
explicit_context_type is not None else FlowContext # Exclude context_args from model fields @@ -914,6 +1077,17 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: ctx_args_for_closure = context_args if context_args is not None else [] is_dynamic_mode = use_context_args and explicit_context_args is None + # Compute context_arg defaults and validators for Mode 2 (context_args) + context_arg_defaults: Dict[str, Any] = {} + _ctx_validatable_types: Dict[str, Type] = {} + _ctx_validators: Dict[str, TypeAdapter] = {} + if context_args is not None: + for name in context_args: + p = sig.parameters[name] + if p.default is not inspect.Parameter.empty: + context_arg_defaults[name] = p.default + _ctx_validatable_types, _ctx_validators = _build_config_validators(context_schema_early) + # Create the __call__ method def make_call_impl(): def __call__(self, context): @@ -929,7 +1103,7 @@ def resolve_callable_model(value): def _resolve_field(name, value): """Resolve a single field value, handling lazy wrapping.""" - is_dep = isinstance(value, (CallableModel, BoundModel)) + is_dep = isinstance(value, CallableModel) if name in lazy_fields: # Lazy field: wrap in a thunk regardless of type if is_dep: @@ -952,7 +1126,14 @@ def _resolve_field(name, value): elif not is_dynamic_mode: # Mode 2: Explicit context_args - get those from context, rest from self for name in ctx_args_for_closure: - fn_kwargs[name] = getattr(context, name) + value = getattr(context, name, _UNSET) + if value is _UNSET: + if name in context_arg_defaults: + fn_kwargs[name] = context_arg_defaults[name] + else: + raise TypeError(f"Missing context field '{name}'") + else: + fn_kwargs[name] = _coerce_context_value(name, value, _ctx_validators, _ctx_validatable_types) # Add model fields for name in all_param_names: value = getattr(self, name) @@ -976,7 +1157,10 @@ def _resolve_field(name, value): if value is _UNSET: missing_fields.append(name) continue - fn_kwargs[name] = value + # Validate/coerce context-sourced value, 
skip CallableModel deps + if not _is_model_dependency(value): + value = _coerce_context_value(name, value, _config_validators, _validatable_types) + fn_kwargs[name] = _resolve_field(name, value) if missing_fields: missing = ", ".join(sorted(missing_fields)) @@ -1021,13 +1205,13 @@ def _resolve_field(name, value): def make_deps_impl(): def __deps__(self, context) -> GraphDepList: deps = [] - # Check ALL fields for CallableModels/BoundModels (auto-detection) + # Check ALL fields for CallableModel dependencies (auto-detection) for name in model_fields: if name in lazy_fields: continue # Lazy deps are NOT pre-evaluated value = getattr(self, name) if isinstance(value, BoundModel): - deps.append((value._model, [value._transform_context(context)])) + deps.append((value.model, [value._transform_context(context)])) elif isinstance(value, CallableModel): deps.append((value, [context])) return deps @@ -1078,7 +1262,10 @@ def __deps__(self, context) -> GraphDepList: GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type GeneratedModel.__flow_model_default_param_names__ = default_param_names + GeneratedModel.__flow_model_context_arg_defaults__ = context_arg_defaults GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result + GeneratedModel.__flow_model_validatable_types__ = _validatable_types + GeneratedModel.__flow_model_config_validators__ = _config_validators # Build context_schema context_schema: Dict[str, Type] = {} diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index a2e788e..ad30824 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -845,6 +845,49 @@ def load(item_id: int) -> int: self.assertIn("int", str(cm.exception)) self.assertIn("str", str(cm.exception)) + def test_model_validate_rejects_bad_scalar_type(self): + """model_validate should reject wrong scalar types, not silently accept them.""" 
+ + @Flow.model + def source(context: SimpleContext, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + cls = type(source(x=1)) + with self.assertRaises(TypeError) as cm: + cls.model_validate({"x": "abc"}) + + self.assertIn("x", str(cm.exception)) + + def test_model_validate_accepts_correct_type(self): + """model_validate should accept correct types.""" + + @Flow.model + def source(context: SimpleContext, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + cls = type(source(x=1)) + restored = cls.model_validate({"x": 42}) + self.assertEqual(restored(SimpleContext(value=0)).value, 42) + + def test_model_validate_rejects_bad_registry_alias(self): + """Typoed registry aliases should not silently pass through model_validate.""" + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def consumer(context: SimpleContext, n: int = 10) -> GenericResult[int]: + return GenericResult(value=n) + + cls = type(consumer(n=1)) + # "not_in_registry" is not a valid int and not a valid registry key + with self.assertRaises(TypeError) as cm: + cls.model_validate({"n": "not_in_registry"}) + self.assertIn("n", str(cm.exception)) + finally: + registry.clear() + def test_context_type_compatible_annotations_accepted(self): """context_type validation should accept matching or subclass annotations.""" @@ -864,6 +907,16 @@ def load_exact(start_date: date, end_date: date) -> str: class TestBoundModel(TestCase): """Tests for BoundModel and BoundModel.flow.""" + def test_bound_model_is_callable_model(self): + """BoundModel should be a proper CallableModel subclass.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + bound = source().flow.with_inputs(x=lambda ctx: ctx.x * 2) + self.assertIsInstance(bound, CallableModel) + def test_bound_model_flow_compute(self): """Test that bound.flow.compute() honors transforms.""" @@ -897,6 +950,66 @@ def my_model(x: int, y: int) -> GenericResult[int]: # 7 * 3 = 21 
self.assertEqual(result.value, 21) + def test_bound_model_dump_validate_roundtrip_static(self): + """Static transforms survive model_dump → model_validate roundtrip.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + dump = bound.model_dump(mode="python") + restored = type(bound).model_validate(dump) + + ctx = SimpleContext(value=1) + self.assertEqual(bound(ctx).value, 420) + self.assertEqual(restored(ctx).value, 420) + + def test_bound_model_validate_same_payload_twice(self): + """Validating the same serialized BoundModel payload twice should work both times.""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + dump = bound.model_dump(mode="python") + + r1 = BoundModel.model_validate(dump) + r2 = BoundModel.model_validate(dump) + + ctx = SimpleContext(value=1) + self.assertEqual(r1(ctx).value, 420) + self.assertEqual(r2(ctx).value, 420) + + def test_bound_model_failed_validate_does_not_poison_next_construction(self): + """A failed model_validate must not leak static transforms to subsequent constructions.""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + base = source() + + # Attempt a model_validate that will fail (invalid model field) + try: + BoundModel.model_validate( + { + "model": "not-a-real-model", + "_static_transforms": {"value": 42}, + "_input_transforms_token": {"value": "42"}, + } + ) + except Exception: + pass # Expected to fail + + # Now construct a fresh BoundModel normally — must NOT inherit stale transforms + clean = BoundModel(model=base, input_transforms={}) + ctx = SimpleContext(value=1) + self.assertEqual(clean(ctx).value, 10) # 1 
* 10, no transform applied + def test_bound_model_cloudpickle_with_lambda_transform(self): """BoundModel with lambda transforms should survive cloudpickle round-trip.""" @@ -1055,6 +1168,33 @@ def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 self.assertEqual(result.value, 51) + def test_differently_transformed_bound_models_have_distinct_cache_keys(self): + """Two BoundModels with different transforms must not collide under caching.""" + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + base = source() + b1 = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + b2 = base.flow.with_inputs(value=lambda ctx: ctx.value + 2) + evaluator = MemoryCacheEvaluator() + ctx = SimpleContext(value=5) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + r1 = b1(ctx) + r2 = b2(ctx) + + # b1 transforms value to 6, source: 6*10=60 + # b2 transforms value to 7, source: 7*10=70 + self.assertEqual(r1.value, 60) + self.assertEqual(r2.value, 70) + # Source called twice (once per distinct transformed context) + self.assertEqual(call_counts["source"], 2) + def test_bound_and_unbound_models_share_memory_cache(self): """Shifted and unshifted models should share one evaluator cache. @@ -1083,7 +1223,9 @@ def source(context: SimpleContext) -> GenericResult[int]: # One execution for the unshifted context and one for the shifted context. self.assertEqual(call_counts["source"], 2) - self.assertEqual(len(evaluator.cache), 2) + # Cache has 3 entries: base(ctx), BoundModel(ctx), and base(shifted_ctx). + # BoundModel is a proper CallableModel now, so it gets its own cache entry. 
+ self.assertEqual(len(evaluator.cache), 3) def test_transform_error_propagates(self): """A buggy transform should raise, not silently fall back to FlowContext.""" @@ -1794,6 +1936,251 @@ def consumer( self.assertEqual(result.value, 51) # 50 + 1 +# ============================================================================= +# Bug Fix Regression Tests +# ============================================================================= + + +class TestFlowModelBugFixes(TestCase): + """Regression tests for four bugs identified during code review.""" + + # ----- Issue 1: .flow.compute() drops context defaults ----- + + def test_compute_respects_explicit_context_defaults(self): + """Mode 1: compute(x=1) should use ExtendedContext's default y='default'.""" + + @Flow.model + def model_fn(context: ExtendedContext, factor: int = 1) -> str: + return f"{context.x}-{context.y}-{factor}" + + model = model_fn() + result = model.flow.compute(x=1) + self.assertEqual(result.value, "1-default-1") + + def test_compute_respects_context_args_defaults(self): + """Mode 2: compute(x=1) should use function default y=42.""" + + @Flow.model(context_args=["x", "y"]) + def model_fn(x: int, y: int = 42) -> int: + return x + y + + model = model_fn() + result = model.flow.compute(x=1) + self.assertEqual(result.value, 43) + + def test_unbound_inputs_excludes_context_args_with_defaults(self): + """Mode 2: unbound_inputs should not include context_args that have function defaults.""" + + @Flow.model(context_args=["x", "y"]) + def model_fn(x: int, y: int = 42) -> int: + return x + y + + model = model_fn() + self.assertEqual(model.flow.unbound_inputs, {"x": int}) + + def test_unbound_inputs_excludes_context_type_defaults(self): + """Mode 1: unbound_inputs should not include context fields that have defaults.""" + + @Flow.model + def model_fn(context: ExtendedContext) -> str: + return f"{context.x}-{context.y}" + + model = model_fn() + # ExtendedContext has x: int (required) and y: str = "default" + 
self.assertEqual(model.flow.unbound_inputs, {"x": int}) + + def test_context_type_rejects_required_field_with_function_default(self): + """Decoration should fail when function has default but context_type requires the field.""" + + class StrictContext(ContextBase): + x: int # required + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["x"], context_type=StrictContext) + def model_fn(x: int = 5) -> int: + return x + + self.assertIn("x", str(cm.exception)) + self.assertIn("requires", str(cm.exception)) + + def test_context_type_accepts_optional_field_with_function_default(self): + """Both context_type and function have defaults — should work.""" + + class OptionalContext(ContextBase): + x: int = 10 + + @Flow.model(context_args=["x"], context_type=OptionalContext) + def model_fn(x: int = 5) -> int: + return x + + model = model_fn() + result = model(OptionalContext()) + self.assertEqual(result.value, 10) # context default wins + + # ----- Issue 2: Lazy[...] broken in dynamic deferred mode ----- + + def test_lazy_from_runtime_context_in_dynamic_mode(self): + """Lazy[int] provided via FlowContext should be wrapped in a thunk.""" + + @Flow.model + def model_fn(x: int, y: Lazy[int]) -> int: + return x + y() + + model = model_fn(x=10) + result = model(FlowContext(y=32)) + self.assertEqual(result.value, 42) + + def test_callable_model_from_runtime_context_in_dynamic_mode(self): + """CallableModel provided in FlowContext should be resolved.""" + + @Flow.model + def source(value: int) -> int: + return value * 10 + + @Flow.model + def consumer(x: int, data: int) -> int: + return x + data + + model = consumer(x=1) + src = source() + result = model(FlowContext(data=src, value=5)) + # source resolves with value=5 → 50, consumer: 1 + 50 = 51 + self.assertEqual(result.value, 51) + + # ----- Issue 3: FlowContext-backed models skip schema validation ----- + + def test_direct_call_validates_flowcontext_dynamic_mode(self): + """Dynamic mode: 
FlowContext(y='hello') for int param should raise TypeError.""" + + @Flow.model + def model_fn(x: int, y: int) -> int: + return x + y + + model = model_fn() + with self.assertRaises(TypeError) as cm: + model(FlowContext(x=1, y="hello")) + + self.assertIn("y", str(cm.exception)) + + def test_direct_call_validates_flowcontext_context_args_mode(self): + """context_args mode: FlowContext(x='hello') for int param should raise TypeError.""" + + @Flow.model(context_args=["x"]) + def model_fn(x: int) -> int: + return x + + model = model_fn() + with self.assertRaises(TypeError) as cm: + model(FlowContext(x="hello")) + + self.assertIn("x", str(cm.exception)) + + def test_with_inputs_validates_transformed_fields_dynamic(self): + """Dynamic mode: with_inputs(y='hello') for int param should raise TypeError.""" + + @Flow.model + def model_fn(x: int, y: int) -> int: + return x + y + + model = model_fn(x=1) + bound = model.flow.with_inputs(y="hello") + + with self.assertRaises(TypeError) as cm: + bound(FlowContext()) + + self.assertIn("y", str(cm.exception)) + + def test_with_inputs_validates_transformed_fields_context_args(self): + """context_args mode: with_inputs(x='hello') for int param should raise TypeError.""" + + @Flow.model(context_args=["x"]) + def model_fn(x: int) -> int: + return x + + model = model_fn() + bound = model.flow.with_inputs(x="hello") + + with self.assertRaises(TypeError) as cm: + bound(FlowContext()) + + self.assertIn("x", str(cm.exception)) + + # ----- Issue 4: Registry-name resolution too aggressive for union strings ----- + + def test_registry_resolution_skips_union_str_annotation(self): + """Union[str, int] field with a registry key string should keep the string.""" + from typing import Union + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, 
tag: Union[str, int] = "none") -> str: + return f"tag={tag}" + + model = consumer(tag="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "tag=my_key") + finally: + registry.clear() + + def test_registry_resolution_skips_optional_str_annotation(self): + """Optional[str] field with a registry key string should keep the string.""" + from typing import Optional + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, label: Optional[str] = None) -> str: + return f"label={label}" + + model = consumer(label="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "label=my_key") + finally: + registry.clear() + + def test_registry_resolution_skips_union_annotated_str(self): + """Union[Annotated[str, ...], int] field with a registry key should keep the string.""" + from typing import Annotated, Union + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, tag: Union[Annotated[str, "label"], int] = "none") -> str: + return f"tag={tag}" + + model = consumer(tag="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "tag=my_key") + finally: + registry.clear() + + if __name__ == "__main__": import unittest From 587b26f68eb9af7ad5e042400957b168b6646fee Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 19 Mar 2026 14:57:25 -0400 Subject: [PATCH 15/17] Update docs and small flow_model changes Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 27 ++++++++++++++++++++++++--- ccflow/tests/test_flow_context.py | 22 ++++++++++++++++++++++ docs/wiki/Key-Features.md | 2 +- 3 files 
changed, 47 insertions(+), 4 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index ab0b3f4..da05d8e 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -222,9 +222,9 @@ def _coerce_context_value(name: str, value: Any, validators: Dict[str, TypeAdapt return value try: return validators[name].validate_python(value) - except Exception: + except Exception as exc: expected = validatable_types.get(name, "unknown") - raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") + raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: @@ -767,6 +767,13 @@ def smart_training( ``with_inputs()`` for deferred execution:: lookback = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + + **Which to use:** + + - Use ``Lazy[T]`` in a ``@Flow.model`` signature when you want conditional/ + on-demand evaluation of an expensive upstream dependency. + - Use ``Lazy(model)(...)`` when you need to rewire context fields before + passing them to an existing model (e.g., shifting a date window). """ def __class_getitem__(cls, item): @@ -993,7 +1000,21 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: else: internal_return_type = return_type - # Determine context mode + # ── Context mode selection ── + # The decorator supports three mutually exclusive context modes: + # + # Mode 1 (explicit context): Function has a 'context' (or '_') parameter + # annotated with a ContextBase subclass. Behaves like a traditional + # CallableModel.__call__. Other params become model fields. + # + # Mode 2 (context_args): Decorator specifies context_args=[...] listing + # which params come from the context at runtime. Remaining params become + # model fields. Uses FlowContext unless context_type= overrides it. 
+ # + # Mode 3 (dynamic deferred): No 'context' param and no context_args. + # Every param is a potential model field. Params bound at construction + # are config; unbound params become runtime inputs from FlowContext. + # context_schema_early: Dict[str, Type] = {} context_td_early = None if "context" in params or "_" in params: diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 718f8de..c9a9811 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -86,6 +86,28 @@ def test_flow_context_hash_uses_extra_fields(self): assert len({first, second, third}) == 3 + def test_flow_context_hash_raises_for_unhashable_values(self): + """FlowContext with truly unhashable values (no __dict__) should raise TypeError.""" + + class Unhashable: + __hash__ = None # type: ignore[assignment] + + def __init__(self): + pass + + # Deliberately no __dict__ suppression — but __hash__ is None, + # so the fallback path in _freeze_for_hash should use __dict__. + # To trigger the actual TypeError path, we need an object with + # no __dict__ and no __hash__. 
+ + class UnhashableSlots: + __slots__ = () + __hash__ = None # type: ignore[assignment] + + ctx = FlowContext(val=UnhashableSlots()) + with pytest.raises(TypeError, match="unhashable value"): + hash(ctx) + def test_flow_context_pickle(self): """FlowContext pickles cleanly.""" ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index e0aac45..5fb27de 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -103,7 +103,7 @@ context" and you want that split to stay stable and explicit: - you want the generated model to accept a specific existing context type such as `DateRangeContext` -**Mode 3 — Default deferred style (no explicit context):** +**Mode 3 — Dynamic deferred style (no explicit context):** When there is no `context` parameter and no `context_args`, all parameters are potential configuration or runtime inputs. Parameters provided at construction From 517f45ec39cc9eae76a1d79ef2fc702bdd76eb5f Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 19 Mar 2026 15:03:58 -0400 Subject: [PATCH 16/17] Add more examples Signed-off-by: Nijat Khanbabayev --- .../config/flow_model_hydra_builder_demo.yaml | 24 ++ examples/evaluator_demo.py | 186 ++++++++++ examples/flow_model_hydra_builder_demo.py | 160 ++++++++ examples/ml_pipeline_demo.py | 351 ++++++++++++++++++ 4 files changed, 721 insertions(+) create mode 100644 examples/config/flow_model_hydra_builder_demo.yaml create mode 100644 examples/evaluator_demo.py create mode 100644 examples/flow_model_hydra_builder_demo.py create mode 100644 examples/ml_pipeline_demo.py diff --git a/examples/config/flow_model_hydra_builder_demo.yaml b/examples/config/flow_model_hydra_builder_demo.yaml new file mode 100644 index 0000000..5579a5c --- /dev/null +++ b/examples/config/flow_model_hydra_builder_demo.yaml @@ -0,0 +1,24 @@ +# Hydra config for examples/flow_model_hydra_builder_demo.py +# +# Pattern: +# - configure 
static pipeline specs in YAML +# - use model_alias to pass already-registered models into a plain Python builder +# - keep runtime context as runtime inputs, supplied later at execution time + +current_revenue: + _target_: examples.flow_model_hydra_builder_demo.load_revenue + region: us + +week_over_week: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: week_over_week + +month_over_month: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: month_over_month diff --git a/examples/evaluator_demo.py b/examples/evaluator_demo.py new file mode 100644 index 0000000..a85b087 --- /dev/null +++ b/examples/evaluator_demo.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +""" +Evaluator Demo: Caching & Execution Strategies +=============================================== + +Shows how to change execution behavior (caching, graph evaluation, logging) +WITHOUT changing user code. The same @Flow.model functions work with any +evaluator stack — you just configure it at the top level. + +Key insight: "default lazy" is an evaluator concern, not a wiring concern. +Users write plain functions and wire them by passing outputs as inputs. +The evaluator layer controls how they execute. + +Demonstrates: + 1. Default execution (eager, no caching) — diamond dep calls load twice + 2. MemoryCacheEvaluator — deduplicates shared deps in a diamond + 3. GraphEvaluator + Cache — topological evaluation + deduplication + 4. LoggingEvaluator — adds tracing around every model call + 5. 
Per-model opt-out — @Flow.model(cacheable=False) overrides global + +Run with: python examples/evaluator_demo.py +""" + +from __future__ import annotations + +import logging +import sys + +# Suppress default debug logging from ccflow evaluators for clean demo output +logging.disable(logging.DEBUG) + +from ccflow import Flow, FlowOptionsOverride # noqa: E402 +from ccflow.evaluators.common import ( # noqa: E402 + GraphEvaluator, + LoggingEvaluator, + MemoryCacheEvaluator, + MultiEvaluator, +) + +# ============================================================================= +# Plain @Flow.model functions — no evaluator concerns in the code +# ============================================================================= + +call_counts: dict[str, int] = {} + + +def _track(name: str) -> None: + call_counts[name] = call_counts.get(name, 0) + 1 + + +@Flow.model +def load_data(x: int, source: str = "warehouse") -> list: + """Load raw data. Expensive — we want to avoid calling this twice.""" + _track("load_data") + return [x, x * 2, x * 3] + + +@Flow.model +def compute_sum(data: list) -> int: + """Branch A: sum the data.""" + _track("compute_sum") + return sum(data) + + +@Flow.model +def compute_max(data: list) -> int: + """Branch B: max of the data.""" + _track("compute_max") + return max(data) + + +@Flow.model +def combine(sum_result: int, max_result: int) -> dict: + """Combine results from both branches.""" + _track("combine") + return {"sum": sum_result, "max": max_result, "total": sum_result + max_result} + + +@Flow.model(cacheable=False) +def volatile_timestamp(seed: int) -> str: + """Explicitly non-cacheable — always re-executes even with global caching.""" + _track("volatile_timestamp") + from datetime import datetime + + return datetime.now().isoformat() + + +# ============================================================================= +# Wire the pipeline — diamond dependency on load_data +# +# load_data ──┬── compute_sum ──┐ +# └── compute_max ──┴── combine +# 
============================================================================= + +shared = load_data(source="prod") +branch_a = compute_sum(data=shared) +branch_b = compute_max(data=shared) +pipeline = combine(sum_result=branch_a, max_result=branch_b) + + +def run() -> dict: + call_counts.clear() + result = pipeline.flow.compute(x=5) + loads = call_counts.get("load_data", 0) + print(f" Result: {result.value}") + print(f" load_data called: {loads}x | total model calls: {sum(call_counts.values())}") + return result.value + + +# ============================================================================= +# Demo 1: Default — no evaluator +# ============================================================================= + +print("=" * 70) +print("1. Default (eager, no caching)") +print(" load_data is called TWICE — once per branch") +print("=" * 70) +run() + +# ============================================================================= +# Demo 2: MemoryCacheEvaluator — deduplicates shared deps +# ============================================================================= + +print() +print("=" * 70) +print("2. MemoryCacheEvaluator (global override)") +print(" load_data is called ONCE — second branch hits cache") +print("=" * 70) +with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): + run() + +# ============================================================================= +# Demo 3: Cache + GraphEvaluator — topological order + deduplication +# ============================================================================= + +print() +print("=" * 70) +print("3. 
GraphEvaluator + MemoryCacheEvaluator") +print(" Evaluates in dependency order: load_data → branches → combine") +print("=" * 70) +evaluator = MultiEvaluator(evaluators=[MemoryCacheEvaluator(), GraphEvaluator()]) +with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + run() + +# ============================================================================= +# Demo 4: Logging — trace every model call +# ============================================================================= + +print() +print("=" * 70) +print("4. LoggingEvaluator + MemoryCacheEvaluator") +print(" Adds timing/tracing around every evaluation") +print("=" * 70) + +# Re-enable logging for this demo (use stdout so log lines interleave with print correctly) +logging.disable(logging.NOTSET) +logging.basicConfig(level=logging.INFO, format=" LOG: %(message)s", stream=sys.stdout) + +evaluator = MultiEvaluator(evaluators=[LoggingEvaluator(log_level=logging.INFO), MemoryCacheEvaluator()]) +with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + run() + +# Suppress again for clean output +logging.disable(logging.DEBUG) +logging.getLogger().handlers.clear() + +# ============================================================================= +# Demo 5: Per-model opt-out — cacheable=False overrides global +# ============================================================================= + +print() +print("=" * 70) +print("5. Per-model opt-out: @Flow.model(cacheable=False)") +print(" volatile_timestamp always re-executes despite global cacheable=True") +print("=" * 70) + +ts = volatile_timestamp(seed=0) + +with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): + call_counts.clear() + r1 = ts.flow.compute(seed=0) + r2 = ts.flow.compute(seed=0) + print(f" Call 1: {r1.value}") + print(f" Call 2: {r2.value}") + print(f" volatile_timestamp called: {call_counts.get('volatile_timestamp', 0)}x") + print(f" (Same result? 
{r1.value == r2.value} — called twice, timestamps may differ)") diff --git a/examples/flow_model_hydra_builder_demo.py b/examples/flow_model_hydra_builder_demo.py new file mode 100644 index 0000000..00c9571 --- /dev/null +++ b/examples/flow_model_hydra_builder_demo.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +"""Hydra + Flow.model builder demo. + +This example shows a clean way to mix: + +1. ergonomic `@Flow.model` pipeline wiring in Python, and +2. Hydra / ModelRegistry configuration for static pipeline specs. + +The pattern is: + +- keep runtime context (`start_date`, `end_date`) as runtime inputs, +- use a plain Python builder function for graph construction, +- let Hydra instantiate that builder and register the returned model. + +Run with: + python examples/flow_model_hydra_builder_demo.py +""" + +from calendar import monthrange +from datetime import date, timedelta +from pathlib import Path +from typing import Literal, Protocol, cast + +from ccflow import BoundModel, CallableModel, DateRangeContext, Flow, FlowAPI, GenericResult, ModelRegistry +from typing_extensions import TypedDict + +CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" +ComparisonName = Literal["week_over_week", "month_over_month"] + + +class RevenueChangeResult(TypedDict): + comparison: ComparisonName + current_window: str + previous_window: str + current: float + previous: float + delta: float + growth_pct: float + + +class RevenueChangeModel(Protocol): + flow: FlowAPI + + def __call__(self, context: DateRangeContext) -> GenericResult[RevenueChangeResult]: ... 
+ + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + """Return synthetic revenue for a date window.""" + days = (end_date - start_date).days + 1 + region_base = {"us": 1000.0, "eu": 850.0, "apac": 920.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def revenue_change( + start_date: date, + end_date: date, + current: float, + previous: float, + comparison: ComparisonName, +) -> RevenueChangeResult: + """Compare the current window against a shifted previous window.""" + growth = (current - previous) / previous + previous_start, previous_end = comparison_window(start_date, end_date, comparison) + return { + "comparison": comparison, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", + "current": current, + "previous": previous, + "delta": round(current - previous, 2), + "growth_pct": round(growth * 100, 2), + } + + +def comparison_window(start_date: date, end_date: date, comparison: ComparisonName) -> tuple[date, date]: + """Return the previous window for a named comparison policy.""" + if comparison == "week_over_week": + return start_date - timedelta(days=7), end_date - timedelta(days=7) + + if start_date.day != 1: + raise ValueError("month_over_month requires start_date to be the first day of a month") + if start_date.year != end_date.year or start_date.month != end_date.month: + raise ValueError("month_over_month requires the current window to stay within one calendar month") + expected_end = date(end_date.year, end_date.month, monthrange(end_date.year, end_date.month)[1]) + if end_date != expected_end: + raise ValueError("month_over_month requires end_date to be the last day of that month") + + 
previous_year = start_date.year if start_date.month > 1 else start_date.year - 1 + previous_month = start_date.month - 1 if start_date.month > 1 else 12 + previous_start = date(previous_year, previous_month, 1) + previous_end = date(previous_year, previous_month, monthrange(previous_year, previous_month)[1]) + return previous_start, previous_end + + +def comparison_input(model: CallableModel, comparison: ComparisonName) -> BoundModel: + """Apply a named comparison policy to one dependency.""" + return model.flow.with_inputs( + start_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[0], + end_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[1], + ) + + +def build_comparison(current: CallableModel, *, comparison: ComparisonName) -> RevenueChangeModel: + """Hydra-friendly builder that returns a configured comparison model.""" + previous = comparison_input(current, comparison) + return revenue_change( + current=current, + previous=previous, + comparison=comparison, + ) + + +def main() -> None: + registry = ModelRegistry.root() + registry.clear() + try: + registry.load_config_from_path(str(CONFIG_PATH), overwrite=True) + + week_over_week = cast(RevenueChangeModel, registry["week_over_week"]) + month_over_month = cast(RevenueChangeModel, registry["month_over_month"]) + + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), + ) + + print("=" * 68) + print("Hydra + Flow.model Builder Demo") + print("=" * 68) + print("\nLoaded from config:") + print(" current_revenue:", registry["current_revenue"]) + print(" week_over_week:", week_over_week) + print(" month_over_month:", month_over_month) + + week_over_week_result = cast( + RevenueChangeResult, + week_over_week.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ).value, + ) + month_over_month_result = month_over_month(ctx).value + + print("\nWeek-over-week:") + for key, value in week_over_week_result.items(): + print(f" 
{key}: {value}") + + print("\nMonth-over-month:") + for key, value in month_over_month_result.items(): + print(f" {key}: {value}") + finally: + registry.clear() + + +if __name__ == "__main__": + main() diff --git a/examples/ml_pipeline_demo.py b/examples/ml_pipeline_demo.py new file mode 100644 index 0000000..0c8920f --- /dev/null +++ b/examples/ml_pipeline_demo.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +""" +ML Pipeline Demo: Smart Model Selection +======================================== + +This is the example from the original design conversation — a realistic ML +pipeline that demonstrates how Flow.model lets you write plain functions, +wire them by passing outputs as inputs, and execute with .flow.compute(). + +Features demonstrated: + 1. @Flow.model with auto-wrap (plain return types, no GenericResult needed) + 2. Lazy[T] for conditional evaluation (skip slow model if fast is good enough) + 3. .flow.compute() for execution with automatic context propagation + 4. .flow.with_inputs() for context transforms (lookback windows) + 5. 
Factored wiring — build_pipeline() shows how to reuse the same graph + structure with different data sources + +The pipeline: + + load_dataset ──> prepare_features ──> train_linear ──> evaluate ──> fast_metrics ──┐ + └──> train_forest ──> evaluate ──> slow_metrics ──┴──> smart_training + +Run with: python examples/ml_pipeline_demo.py +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date, timedelta +from math import sin + +from ccflow import Flow, Lazy + + +# ============================================================================= +# Domain types (stand-ins for real ML objects) +# ============================================================================= + + +@dataclass +class PreparedData: + """Container for train/test split data.""" + + X_train: list # list of feature vectors + X_test: list + y_train: list # list of target values + y_test: list + + +@dataclass +class TrainedModel: + """A fitted model (placeholder).""" + + name: str + coefficients: list + intercept: float + augment: bool # Whether to add sin feature during prediction + + +@dataclass +class Metrics: + """Evaluation metrics.""" + + r2: float + mse: float + model_name: str + + +# ============================================================================= +# Data Loading +# ============================================================================= + + +@Flow.model +def load_dataset(start_date: date, end_date: date, source: str = "warehouse") -> list: + """Load raw dataset for a date range. + + Returns a list of dicts (standing in for a DataFrame). + Auto-wrapped: returns plain list, framework wraps in GenericResult. 
+ """ + n_days = (end_date - start_date).days + 1 + print(f" [load_dataset] Loading {n_days} days from '{source}' ({start_date} to {end_date})") + # True relationship: target = 2.0 * x + 10.0 + 15.0 * sin(x * 0.2) + # Linear model captures the trend (R^2 ~0.93), forest also captures the sin wave (~0.99) + return [ + { + "date": str(start_date + timedelta(days=i)), + "x": float(i), + "target": 2.0 * i + 10.0 + 15.0 * sin(i * 0.2), + } + for i in range(n_days) + ] + + +# ============================================================================= +# Feature Engineering +# ============================================================================= + + +@Flow.model +def prepare_features(raw_data: list) -> PreparedData: + """Split data into train/test. + + Returns a PreparedData dataclass — the framework auto-wraps it in GenericResult. + Downstream models can request individual fields via prepared["X_train"] etc. + """ + n = len(raw_data) + split = int(n * 0.8) + print(f" [prepare_features] {n} rows, split at {split}") + + X = [[r["x"]] for r in raw_data] + y = [r["target"] for r in raw_data] + + return PreparedData( + X_train=X[:split], + X_test=X[split:], + y_train=y[:split], + y_test=y[split:], + ) + + +# ============================================================================= +# Model Training +# ============================================================================= + + +def _ols_fit(X, y): + """Simple OLS: compute coefficients and intercept.""" + n = len(X) + n_feat = len(X[0]) + y_mean = sum(y) / n + x_means = [sum(row[j] for row in X) / n for j in range(n_feat)] + + coefficients = [] + for j in range(n_feat): + cov = sum((X[i][j] - x_means[j]) * (y[i] - y_mean) for i in range(n)) / n + var = sum((X[i][j] - x_means[j]) ** 2 for i in range(n)) / n + coefficients.append(cov / var if var > 1e-10 else 0.0) + + intercept = y_mean - sum(c * m for c, m in zip(coefficients, x_means)) + return coefficients, intercept + + +def _augment(X): + """Add sin(x*0.2) 
feature to capture non-linearity.""" + return [row + [sin(row[0] * 0.2)] for row in X] + + +@Flow.model +def train_linear(prepared: PreparedData) -> TrainedModel: + """Train a fast linear model (linear features only).""" + print(f" [train_linear] Fitting on {len(prepared.X_train)} samples") + coefficients, intercept = _ols_fit(prepared.X_train, prepared.y_train) + return TrainedModel(name="LinearRegression", coefficients=coefficients, intercept=intercept, augment=False) + + +@Flow.model +def train_forest(prepared: PreparedData, n_estimators: int = 100) -> TrainedModel: + """Train a model that also captures non-linear patterns (simulated).""" + print(f" [train_forest] Fitting {n_estimators} trees on {len(prepared.X_train)} samples") + # Augment with sin feature to capture non-linearity + X_aug = _augment(prepared.X_train) + coefficients, intercept = _ols_fit(X_aug, prepared.y_train) + return TrainedModel( + name=f"RandomForest(n={n_estimators})", + coefficients=coefficients, + intercept=intercept, + augment=True, + ) + + +# ============================================================================= +# Model Evaluation +# ============================================================================= + + +@Flow.model +def evaluate_model(model: TrainedModel, prepared: PreparedData) -> Metrics: + """Evaluate a trained model on test data.""" + X_test = prepared.X_test + y_test = prepared.y_test + X_eval = _augment(X_test) if model.augment else X_test + + y_pred = [ + model.intercept + sum(c * x for c, x in zip(model.coefficients, row)) + for row in X_eval + ] + + y_mean = sum(y_test) / len(y_test) if y_test else 0 + ss_tot = sum((y - y_mean) ** 2 for y in y_test) or 1 + ss_res = sum((yt - yp) ** 2 for yt, yp in zip(y_test, y_pred)) + r2 = 1.0 - ss_res / ss_tot + mse = ss_res / len(y_test) if y_test else 0 + + print(f" [evaluate_model] {model.name}: R^2={r2:.4f}, MSE={mse:.2f}") + return Metrics(r2=r2, mse=mse, model_name=model.name) + + +# 
============================================================================= +# Smart Pipeline with Conditional Execution +# ============================================================================= + + +@Flow.model +def smart_training( + # data: PreparedData, + fast_metrics: Metrics, + slow_metrics: Lazy[Metrics], # Only evaluated if fast isn't good enough + threshold: float = 0.9, +) -> Metrics: + """Use fast model if good enough, else fall back to slow. + + The slow_metrics parameter is Lazy — it receives a zero-arg thunk. + If the fast model exceeds the threshold, the slow model is never + trained or evaluated at all. + """ + print(f" [smart_training] Fast R^2={fast_metrics.r2:.4f}, threshold={threshold}") + if fast_metrics.r2 >= threshold: + print(" [smart_training] Fast model is good enough! Skipping slow model.") + return fast_metrics + else: + print(" [smart_training] Fast model below threshold, evaluating slow model...") + return slow_metrics() + + +# ============================================================================= +# Pipeline Wiring Helper +# ============================================================================= + + +def build_pipeline(raw, *, n_estimators=200, threshold=0.95): + """Wire a complete train/evaluate/select pipeline from a data source. + + This function shows the flexibility of the approach: the same wiring + logic can be applied to different data sources (raw, lookback_raw, etc.) + without duplicating code. Everything here is just wiring — no computation + happens until .flow.compute() is called. + + Args: + raw: A CallableModel or BoundModel that produces raw data (list of dicts) + n_estimators: Number of trees for the forest model + threshold: R^2 threshold for the fast/slow model selection + + Returns: + A smart_training model instance ready for .flow.compute() + """ + # Feature engineering — returns a PreparedData with X_train, X_test, etc. 
+ prepared = prepare_features(raw_data=raw) + + # Train both models — each receives the whole PreparedData and extracts + # the fields it needs internally. + linear = train_linear(prepared=prepared) + forest = train_forest(prepared=prepared, n_estimators=n_estimators) + + # Evaluate both + linear_metrics = evaluate_model(model=linear, prepared=prepared) + forest_metrics = evaluate_model(model=forest, prepared=prepared) + + # Smart selection with Lazy — forest is only evaluated if linear isn't good enough + return smart_training( + fast_metrics=linear_metrics, + slow_metrics=forest_metrics, + threshold=threshold, + ) + + +# ============================================================================= +# Main: Wire and execute the pipeline +# ============================================================================= + + +def main(): + print("=" * 70) + print("ML Pipeline Demo: Smart Model Selection with Flow.model") + print("=" * 70) + + # ------------------------------------------------------------------ + # Step 1: Wire the pipeline (no computation happens here) + # ------------------------------------------------------------------ + print("\n--- Wiring the pipeline (lazy, no computation yet) ---\n") + + raw = load_dataset(source="prod_warehouse") + + # build_pipeline factors out the repeated wiring logic. + # Linear R^2 ≈ 0.93. Threshold is 0.95 → falls through to forest. + pipeline = build_pipeline(raw, n_estimators=200, threshold=0.95) + + print("Pipeline wired. 
No functions have been called yet.") + + # ------------------------------------------------------------------ + # Step 2: Execute — linear not good enough, falls back to forest + # ------------------------------------------------------------------ + print("\n--- Executing pipeline (Jan-Jun 2024) ---\n") + result = pipeline.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + + print(f"\n Best model: {result.value.model_name}") + print(f" R^2: {result.value.r2:.4f}") + print(f" MSE: {result.value.mse:.2f}") + + # ------------------------------------------------------------------ + # Step 3: Context transforms (lookback) — reuse build_pipeline + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("With Lookback: Same pipeline structure, extra history for loading") + print("=" * 70) + + # flow.with_inputs() creates a BoundModel that transforms the context + # before calling the underlying model. start_date is shifted 30 days earlier. + lookback_raw = raw.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30) + ) + + # Same wiring logic, different data source — no duplication. 
+ lookback_pipeline = build_pipeline(lookback_raw, n_estimators=200, threshold=0.95) + + print("\n--- Executing lookback pipeline ---\n") + result2 = lookback_pipeline.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + # Notice: load_dataset gets start_date=2023-12-02 (30 days earlier) + + print(f"\n Best model: {result2.value.model_name}") + print(f" R^2: {result2.value.r2:.4f}") + print(f" MSE: {result2.value.mse:.2f}") + + # ------------------------------------------------------------------ + # Step 4: Lower threshold — linear is good enough, skip forest + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("Lazy Evaluation: Lower threshold so fast model is good enough") + print("=" * 70) + + # With threshold=0.80, the linear model's R^2 (~0.93) passes. + # The forest is NEVER trained or evaluated — Lazy skips it entirely. + fast_pipeline = build_pipeline(raw, n_estimators=200, threshold=0.80) + + print("\n--- Executing (slow model should NOT be trained) ---\n") + result3 = fast_pipeline.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 6, 30), + ) + print(f"\n Selected: {result3.value.model_name} (R^2={result3.value.r2:.4f})") + print(" (Notice: train_forest and its evaluate_model were never called)") + + +if __name__ == "__main__": + main() From 324fc4ec93a3058a8289a21b612b85c590eacc4c Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Fri, 20 Mar 2026 13:53:06 -0400 Subject: [PATCH 17/17] More test coverage Signed-off-by: Nijat Khanbabayev --- ccflow/tests/test_callable.py | 97 ++++ ccflow/tests/test_flow_context.py | 40 ++ ccflow/tests/test_flow_model.py | 799 ++++++++++++++++++++++++++++++ 3 files changed, 936 insertions(+) diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 6d8f53e..9b51592 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -1057,3 +1057,100 @@ def __call__(self, *, value: int): 
return GenericResult(value=value) self.assertIn("must have a return type annotation", str(cm.exception)) + + def test_auto_context_rejects_missing_annotation(self): + """auto_context should reject params without type annotations.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value) -> GenericResult: + return GenericResult(value=value) + + self.assertIn("must have a type annotation", str(cm.exception)) + + +class TestDeclaredTypeMatches(TestCase): + """Tests for _declared_type_matches helper in callable.py.""" + + def test_typevar_always_matches(self): + from ccflow.callable import _declared_type_matches + + T = TypeVar("T") + self.assertTrue(_declared_type_matches(int, T)) + + def test_union_expected_no_type_args(self): + """Union with no concrete type args should return False.""" + from ccflow.callable import _declared_type_matches + + # Union[None] after filtering out NoneType has no concrete args + self.assertFalse(_declared_type_matches(int, Union[None])) + + def test_union_expected_with_actual_type(self): + """Concrete type matching Union expected.""" + from ccflow.callable import _declared_type_matches + + self.assertTrue(_declared_type_matches(int, Union[int, str])) + self.assertFalse(_declared_type_matches(float, Union[int, str])) + + def test_union_both_sides(self): + """Both actual and expected are Unions.""" + from ccflow.callable import _declared_type_matches + + self.assertTrue(_declared_type_matches(Union[int, str], Union[int, str])) + self.assertTrue(_declared_type_matches(Union[str, int], Union[int, str])) # order independent + self.assertFalse(_declared_type_matches(Union[int, float], Union[int, str])) + + def test_non_type_actual(self): + """Non-type actual should return False.""" + from ccflow.callable import _declared_type_matches + + self.assertFalse(_declared_type_matches("not_a_type", int)) + + def test_non_type_expected(self): + """Non-type 
expected should return False.""" + from ccflow.callable import _declared_type_matches + + self.assertFalse(_declared_type_matches(int, "not_a_type")) + + +class TestCallableModelGenericValidation(TestCase): + """Tests for CallableModelGeneric type validation paths.""" + + def test_context_type_mismatch_raises(self): + """Generic type validation should reject context type mismatch.""" + + class ContextA(ContextBase): + a: int + + class ContextB(ContextBase): + b: int + + class ModelA(CallableModel): + @Flow.call + def __call__(self, context: ContextA) -> GenericResult[int]: + return GenericResult(value=context.a) + + with self.assertRaises(ValidationError): + # Expect ContextB but model has ContextA + CallableModelGenericType[ContextB, GenericResult[int]].model_validate(ModelA()) + + def test_result_type_mismatch_raises(self): + """Generic type validation should reject result type mismatch.""" + + class MyContext(ContextBase): + x: int + + class ResultA(ResultBase): + a: int + + class ResultB(ResultBase): + b: int + + class ModelA(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> ResultA: + return ResultA(a=context.x) + + with self.assertRaises(ValidationError): + CallableModelGenericType[MyContext, ResultB].model_validate(ModelA()) diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index c9a9811..970cc08 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -108,6 +108,46 @@ class UnhashableSlots: with pytest.raises(TypeError, match="unhashable value"): hash(ctx) + def test_flow_context_eq_non_flow_context(self): + """FlowContext.__eq__ returns False for non-FlowContext objects.""" + ctx = FlowContext(x=1) + assert ctx != 42 + assert ctx != "hello" + assert ctx != None # noqa: E711 + assert ctx != NumberContext(x=1) + + def test_flow_context_hash_with_set_value(self): + """FlowContext with set values should hash correctly via frozenset.""" + ctx = 
FlowContext(tags=frozenset({"a", "b"})) + # Should not raise + h = hash(ctx) + assert isinstance(h, int) + + def test_flow_context_hash_with_model_dump_object(self): + """_freeze_for_hash should handle objects with model_dump attribute.""" + from ccflow.context import _freeze_for_hash + + # Directly test _freeze_for_hash with an object that has model_dump + # (FlowContext.__hash__ goes through model_dump first which serializes + # nested models, so we test the helper directly) + inner = NumberContext(x=42) + result = _freeze_for_hash(inner) + assert isinstance(result, tuple) + assert result[0] is NumberContext + + def test_flow_context_hash_unhashable_with_dict_fallback(self): + """Objects with __dict__ but no __hash__ should use __dict__ fallback.""" + + class UnhashableWithDict: + __hash__ = None # type: ignore[assignment] + + def __init__(self, val): + self.val = val + + ctx = FlowContext(obj=UnhashableWithDict(42)) + h = hash(ctx) + assert isinstance(h, int) + def test_flow_context_pickle(self): """FlowContext pickles cleanly.""" ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index ad30824..57a54df 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -2181,6 +2181,805 @@ def consumer(context: SimpleContext, tag: Union[Annotated[str, "label"], int] = registry.clear() +# ============================================================================= +# Coverage Gap Tests +# ============================================================================= + + +class TestExtractLazyLoopBody(TestCase): + """Group 1: _extract_lazy loop body with non-LazyMarker metadata.""" + + def test_annotated_with_extra_metadata_before_lazy_marker(self): + """Annotated type where _LazyMarker is NOT the first metadata element.""" + from typing import Annotated + + from ccflow.flow_model import _extract_lazy, _LazyMarker + + # _LazyMarker is the 
second metadata element — loop must iterate past "other" + ann = Annotated[int, "other_metadata", _LazyMarker()] + base_type, is_lazy = _extract_lazy(ann) + self.assertTrue(is_lazy) + self.assertIs(base_type, int) + + def test_annotated_without_lazy_marker(self): + """Annotated type with no _LazyMarker returns is_lazy=False.""" + from typing import Annotated + + from ccflow.flow_model import _extract_lazy + + ann = Annotated[int, "just_metadata"] + base_type, is_lazy = _extract_lazy(ann) + self.assertFalse(is_lazy) + + def test_lazy_type_annotation_with_extra_annotated(self): + """End-to-end: Lazy wrapping of an Annotated type.""" + + @Flow.model + def model_with_lazy( + x: int, + dep: Lazy[int], + ) -> int: + return x + dep() + + @Flow.model + def upstream(x: int) -> int: + return x * 10 + + model = model_with_lazy(x=1, dep=upstream()) + result = model.flow.compute(x=1) + self.assertEqual(result.value, 11) + + def test_lazy_dep_returning_custom_result(self): + """Lazy dep returning custom ResultBase (not GenericResult) should return raw result.""" + + @Flow.model + def upstream(context: SimpleContext) -> MyResult: + return MyResult(data=f"v={context.value}") + + @Flow.model + def consumer(context: SimpleContext, dep: Lazy[MyResult]) -> GenericResult[str]: + result = dep() + return GenericResult(value=result.data) + + model = consumer(dep=upstream()) + result = model(SimpleContext(value=42)) + self.assertEqual(result.value, "v=42") + + +class TestTransformReprNamedCallable(TestCase): + """Group 2: _transform_repr with a named callable.""" + + def test_named_function_transform_in_repr(self): + """Named functions should appear in BoundModel repr wrapped in angle brackets.""" + from ccflow.flow_model import _transform_repr + + def my_custom_transform(ctx): + return ctx.value + 1 + + result = _transform_repr(my_custom_transform) + self.assertIn("my_custom_transform", result) + self.assertTrue(result.startswith("<")) + self.assertTrue(result.endswith(">")) + + def 
test_static_value_repr(self): + """Static (non-callable) values should use repr().""" + from ccflow.flow_model import _transform_repr + + self.assertEqual(_transform_repr(42), "42") + self.assertEqual(_transform_repr("hello"), "'hello'") + + +class TestBoundFieldNamesFallback(TestCase): + """Group 3: _bound_field_names fallback for objects without model_fields_set.""" + + def test_fallback_to_bound_fields_attr(self): + from ccflow.flow_model import _bound_field_names + + class FakeModel: + _bound_fields = {"x", "y"} + + result = _bound_field_names(FakeModel()) + self.assertEqual(result, {"x", "y"}) + + def test_fallback_no_attrs(self): + from ccflow.flow_model import _bound_field_names + + class Empty: + pass + + result = _bound_field_names(Empty()) + self.assertEqual(result, set()) + + +class TestRuntimeInputNamesEmpty(TestCase): + """Group 4: _runtime_input_names when all_param_names is empty.""" + + def test_non_flow_model_returns_empty(self): + from ccflow.flow_model import _runtime_input_names + + class ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + model = ManualModel(offset=5) + self.assertEqual(_runtime_input_names(model), set()) + + +class TestRegistryCandidateAllowed(TestCase): + """Group 5: _registry_candidate_allowed TypeAdapter success path.""" + + def test_non_callable_model_passes_type_check(self): + """Registry value that is not a CallableModel but passes TypeAdapter validation.""" + from ccflow.flow_model import _registry_candidate_allowed + + # int value passes TypeAdapter(int).validate_python + self.assertTrue(_registry_candidate_allowed(int, 42)) + + def test_non_callable_model_fails_type_check(self): + from ccflow.flow_model import _registry_candidate_allowed + + self.assertFalse(_registry_candidate_allowed(int, "not_an_int")) + + +class TestConcreteContextTypeOptional(TestCase): + """Group 6: 
_concrete_context_type with Optional/Union types.""" + + def test_optional_context_type(self): + """Optional[T] has NoneType that should be skipped to find T.""" + from typing import Optional + + from ccflow.flow_model import _concrete_context_type + + # Optional[SimpleContext] = Union[SimpleContext, None] + # The NoneType arg must be skipped (line 196-197) + result = _concrete_context_type(Optional[SimpleContext]) + self.assertIs(result, SimpleContext) + + def test_union_with_none_first(self): + """Union[None, T] should skip NoneType and find T.""" + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + # NoneType comes first, must be skipped + result = _concrete_context_type(Union[None, SimpleContext]) + self.assertIs(result, SimpleContext) + + def test_union_context_type(self): + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type(Union[SimpleContext, None]) + self.assertIs(result, SimpleContext) + + def test_union_no_context_base(self): + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type(Union[int, str]) + self.assertIsNone(result) + + def test_returns_none_for_non_type(self): + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type("not_a_type") + self.assertIsNone(result) + + +class TestBuildConfigValidatorsException(TestCase): + """Group 7: _build_config_validators when TypeAdapter fails.""" + + def test_unadaptable_type_skipped(self): + """Types that TypeAdapter can't handle should be silently skipped.""" + from ccflow.flow_model import _build_config_validators + + # type(...) 
(EllipsisType) makes TypeAdapter fail + validatable, validators = _build_config_validators({"x": int, "y": type(...)}) + self.assertIn("x", validatable) + self.assertNotIn("y", validatable) + self.assertIn("x", validators) + self.assertNotIn("y", validators) + + +class TestCoerceContextValueNoValidator(TestCase): + """Group 8: _coerce_context_value early return for fields without validators.""" + + def test_field_without_validator_passes_through(self): + from ccflow.flow_model import _coerce_context_value + + # When name is not in validators, value should pass through unchanged + result = _coerce_context_value("unknown_field", 42, {}, {}) + self.assertEqual(result, 42) + + +class TestGeneratedModelClassFactoryPath(TestCase): + """Group 9: _generated_model_class when stage has no generated model.""" + + def test_returns_none_for_plain_callable(self): + from ccflow.flow_model import _generated_model_class + + def plain_func(): + pass + + self.assertIsNone(_generated_model_class(plain_func)) + + +class TestDescribePipeStagePaths(TestCase): + """Group 10: _describe_pipe_stage for different stage types.""" + + def test_generated_model_instance(self): + from ccflow.flow_model import _describe_pipe_stage + + @Flow.model + def my_stage(x: int) -> int: + return x + + desc = _describe_pipe_stage(my_stage()) + self.assertIn("my_stage", desc) + + def test_callable_stage(self): + from ccflow.flow_model import _describe_pipe_stage + + @Flow.model + def factory_stage(x: int) -> int: + return x + + desc = _describe_pipe_stage(factory_stage) + self.assertIn("factory_stage", desc) + + def test_non_callable_stage(self): + from ccflow.flow_model import _describe_pipe_stage + + desc = _describe_pipe_stage(42) + self.assertEqual(desc, "42") + + +class TestInferPipeParamAmbiguousDefaults(TestCase): + """Cover _infer_pipe_param fallback path with multiple defaulted candidates.""" + + def test_ambiguous_defaulted_candidates(self): + """When all candidates have defaults but multiple are 
unoccupied.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(a: int = 1, b: int = 2) -> int: + return a + b + + # Both a and b have defaults, both are unoccupied -> ambiguous + with self.assertRaisesRegex(TypeError, "could not infer a target parameter"): + source().pipe(consumer) + + +class TestPipeErrorPaths(TestCase): + """Group 11: pipe() error paths not covered by existing tests.""" + + def test_pipe_non_callable_model_source(self): + """pipe() should reject non-CallableModel source.""" + from ccflow.flow_model import pipe_model + + @Flow.model + def consumer(data: int) -> int: + return data + + with self.assertRaisesRegex(TypeError, "pipe\\(\\) source must be a CallableModel"): + pipe_model("not_a_model", consumer) + + def test_pipe_non_flow_model_target(self): + """pipe() should reject non-@Flow.model target.""" + from ccflow.flow_model import pipe_model + + @Flow.model + def source(x: int) -> int: + return x + + class ManualTarget(CallableModel): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=0) + + with self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream stages"): + pipe_model(source(), ManualTarget()) + + def test_pipe_invalid_param_name(self): + """pipe() should reject invalid target parameter names.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(data: int) -> int: + return data + + with self.assertRaisesRegex(TypeError, "is not valid for"): + source().pipe(consumer, param="nonexistent") + + def test_pipe_already_bound_param(self): + """pipe() should reject already-bound parameters.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(data: int) -> int: + return data + + model = consumer(data=5) + with self.assertRaisesRegex(TypeError, "is already bound"): + source().pipe(model, param="data") + + def test_pipe_no_available_target_parameter(self): + 
"""pipe() should error when all downstream params are occupied.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def consumer(data: int) -> int: + return data + + model = consumer(data=5) + with self.assertRaisesRegex(TypeError, "could not find an available target parameter"): + source().pipe(model) + + def test_pipe_into_generated_instance_rebuilds(self): + """pipe() into an existing generated model instance should rebuild.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int, extra: int = 1) -> int: + return data + extra + + instance = consumer(extra=5) + pipeline = source().pipe(instance) + result = pipeline.flow.compute(x=3) + self.assertEqual(result.value, 35) # 3*10 + 5 + + def test_pipe_bound_model_wrapping_non_generated_rejects(self): + """pipe() into BoundModel wrapping a non-generated model should fail.""" + from ccflow.flow_model import BoundModel, pipe_model + + @Flow.model + def source(x: int) -> int: + return x + + class ManualModel(CallableModel): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + bound = BoundModel(model=ManualModel(), input_transforms={"value": 42}) + with self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream"): + pipe_model(source(), bound) + + +class TestFlowAPIBuildContextFallback(TestCase): + """Group 12: FlowAPI._build_context when _context_schema is None/unset.""" + + def test_unbound_inputs_on_manual_callable_model(self): + """Manual CallableModel with context should show required fields.""" + + class ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + model = ManualModel(offset=5) + unbound = model.flow.unbound_inputs + self.assertIn("value", unbound) + + +class TestBoundModelRestoreNonDict(TestCase): + """Group 13: 
BoundModel._restore_serialized_transforms non-dict path.""" + + def test_restore_from_model_instance(self): + """model_validate from an existing BoundModel instance (non-dict).""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + # Pass existing instance through model_validate (non-dict path) + restored = BoundModel.model_validate(bound) + ctx = SimpleContext(value=1) + self.assertEqual(restored(ctx).value, 420) + + +class TestBoundModelInitEmptyTransforms(TestCase): + """Group 14: BoundModel.__init__ with no transforms.""" + + def test_init_without_transforms(self): + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + bound = BoundModel(model=source()) + self.assertEqual(bound._input_transforms, {}) + result = bound(SimpleContext(value=5)) + self.assertEqual(result.value, 5) + + +class TestBoundModelDeps(TestCase): + """Group 15: BoundModel.__deps__.""" + + def test_deps_returns_wrapped_model(self): + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + deps = bound.__deps__(SimpleContext(value=1)) + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], bound.model) + + +class TestValidateFieldTypesAfterValidator(TestCase): + """Group 16: _validate_field_types in the model_validate path.""" + + def test_model_validate_rejects_wrong_type(self): + """model_validate should reject wrong scalar types.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + cls = type(source(x=5)) + with self.assertRaisesRegex(TypeError, "Field 'x'"): + cls.model_validate({"x": "not_an_int"}) + + +class TestGetContextValidatorPaths(TestCase): + """Group 17: 
_get_context_validator fallback paths.""" + + def test_mode2_context_validator_from_schema(self): + """Mode 2 model should build validator from _context_schema.""" + + @Flow.model(context_args=["start_date"]) + def loader(start_date: str, source: str = "db") -> str: + return f"{source}:{start_date}" + + model = loader() + # Trigger validator creation by calling flow.compute + result = model.flow.compute(start_date="2024-01-01") + self.assertEqual(result.value, "db:2024-01-01") + + def test_mode1_context_validator_uses_context_type_directly(self): + """Mode 1 should use TypeAdapter(context_type) directly.""" + + @Flow.model + def model_fn(context: SimpleContext, offset: int = 0) -> GenericResult[int]: + return GenericResult(value=context.value + offset) + + model = model_fn() + # compute with SimpleContext fields + result = model.flow.compute(value=5) + self.assertEqual(result.value, 5) + + +class TestValidateContextTypeOverrideErrors(TestCase): + """Group 18: _validate_context_type_override error paths.""" + + def test_non_context_base_raises(self): + with self.assertRaisesRegex(TypeError, "context_type must be a ContextBase subclass"): + + @Flow.model(context_args=["x"], context_type=int) + def bad_model(x: int) -> int: + return x + + def test_context_type_missing_context_args_fields(self): + """context_type missing required context_args fields.""" + + class TinyContext(ContextBase): + a: int + + with self.assertRaisesRegex(TypeError, "must define fields for context_args"): + + @Flow.model(context_args=["a", "b"], context_type=TinyContext) + def bad_model(a: int, b: int) -> int: + return a + b + + def test_context_type_extra_required_fields(self): + """context_type has required fields not listed in context_args.""" + + class BigContext(ContextBase): + a: int + b: int + extra: str + + with self.assertRaisesRegex(TypeError, "has required fields not listed in context_args"): + + @Flow.model(context_args=["a"], context_type=BigContext) + def bad_model(a: int) -> int: 
+ return a + + def test_annotation_type_mismatch(self): + """Function and context_type disagree on annotation type.""" + + class TypedContext(ContextBase): + x: str + + with self.assertRaisesRegex(TypeError, "context_arg 'x'"): + + @Flow.model(context_args=["x"], context_type=TypedContext) + def bad_model(x: int) -> int: + return x + + def test_annotation_skip_when_func_ann_is_none(self): + """Annotation check should skip when function annotation is absent from schema.""" + from ccflow.flow_model import _validate_context_type_override + + class CompatContext(ContextBase): + a: int + + # context_args has 'a', schema has 'a': int. Compatible, no error. + result = _validate_context_type_override(CompatContext, ["a"], {"a": int}) + self.assertIs(result, CompatContext) + + def test_subclass_annotations_allowed(self): + """context_type with subclass-compatible annotations should pass.""" + from ccflow.flow_model import _validate_context_type_override + + class ContextWithBase(ContextBase): + ctx: ContextBase + + # Function declares SimpleContext which is a subclass of ContextBase — should pass + result = _validate_context_type_override(ContextWithBase, ["ctx"], {"ctx": SimpleContext}) + self.assertIs(result, ContextWithBase) + + def test_default_vs_required_field_conflict(self): + """Function has default for context_arg but context_type requires it.""" + + class StrictContext(ContextBase): + x: int + + with self.assertRaisesRegex(TypeError, "function has a default but context_type"): + + @Flow.model(context_args=["x"], context_type=StrictContext) + def bad_model(x: int = 5) -> int: + return x + + +class TestDecoratorErrorPaths(TestCase): + """Group 19: Decorator error paths.""" + + def test_context_type_with_explicit_context_param(self): + """context_type= with explicit context param should raise.""" + with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + + @Flow.model(context_type=SimpleContext) + def bad_model(context: SimpleContext) -> 
GenericResult[int]: + return GenericResult(value=0) + + def test_context_type_without_context_args(self): + """context_type= without context_args should raise in dynamic mode.""" + with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + + @Flow.model(context_type=SimpleContext) + def bad_model(x: int) -> int: + return x + + def test_missing_context_annotation(self): + """Missing type annotation on context param should raise.""" + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + @Flow.model + def bad_model(context) -> int: + return 0 + + def test_missing_param_annotation(self): + """Missing type annotation on a model field param should raise.""" + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + @Flow.model + def bad_model(context: SimpleContext, untyped_param) -> int: + return 0 + + def test_context_param_not_context_base(self): + """context param annotated with non-ContextBase type should raise.""" + with self.assertRaisesRegex(TypeError, "must be annotated with a ContextBase subclass"): + + @Flow.model + def bad_model(context: int) -> int: + return 0 + + def test_pep563_fallback_on_failed_get_type_hints(self): + """When get_type_hints fails, falls back to raw annotations.""" + + # This is hard to trigger directly, but we can test that string annotations work + @Flow.model + def model_with_string_return(x: int) -> "int": + return x * 2 + + result = model_with_string_return().flow.compute(x=5) + self.assertEqual(result.value, 10) + + +class TestMode1CallPath(TestCase): + """Group 20: Mode 1 explicit context pass-through in __call__.""" + + def test_mode1_resolve_callable_model_returns_non_generic_result(self): + """Mode 1 should handle deps that return raw ResultBase (not GenericResult).""" + + @Flow.model + def upstream(context: SimpleContext) -> MyResult: + return MyResult(data=f"value={context.value}") + + @Flow.model + def downstream(context: SimpleContext, dep: CallableModel) -> 
GenericResult[str]: + # dep is resolved to MyResult since it's not GenericResult + return GenericResult(value=f"got:{dep}") + + model = downstream(dep=upstream()) + result = model(SimpleContext(value=42)) + self.assertIn("value=42", result.value) + + +class TestDynamicModeContextLookup(TestCase): + """Group 21: Dynamic mode context lookup for deferred values.""" + + def test_deferred_value_from_context(self): + """Dynamic mode should pull deferred values from context.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + # y is deferred — pulled from context + result = model.flow.compute(y=5) + self.assertEqual(result.value, 15) + + def test_missing_deferred_value_raises(self): + """Dynamic mode should raise for missing deferred values.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + with self.assertRaisesRegex(TypeError, "Missing runtime input"): + model.flow.compute() # y not provided + + def test_context_sourced_value_coercion(self): + """Dynamic mode should coerce context-sourced values through validators.""" + + @Flow.model + def typed_model(x: int, y: int) -> int: + return x + y + + model = typed_model(x=10) + # y provided as a value that can be coerced to int + result = model.flow.compute(y=5) + self.assertEqual(result.value, 15) + + def test_deferred_value_from_context_object(self): + """Dynamic mode should look up deferred values from context attributes.""" + + @Flow.model + def multiply(x: int, y: int) -> int: + return x * y + + model = multiply(x=3) + # Call directly with a FlowContext — y must come from context + result = model(FlowContext(y=7)) + self.assertEqual(result.value, 21) + + +class TestGetContextValidatorFallbacks(TestCase): + """Group 17 additional: _get_context_validator edge cases.""" + + def test_mode2_with_context_type_override(self): + """Mode 2 with explicit context_type should use that type's validator.""" + + @Flow.model(context_args=["value"], 
context_type=SimpleContext) + def typed_model(value: int) -> int: + return value * 2 + + model = typed_model() + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 10) + + def test_dynamic_mode_instance_validator(self): + """Dynamic mode should create instance-specific validator.""" + + @Flow.model + def add(x: int, y: int, z: int = 0) -> int: + return x + y + z + + m1 = add(x=1) + m2 = add(x=1, y=2) + # Different bound fields => different runtime inputs + self.assertIn("y", m1.flow.unbound_inputs) + self.assertNotIn("y", m2.flow.unbound_inputs) + + +class TestRegistryResolutionInValidateFieldTypes(TestCase): + """Group 16: _resolve_registry_refs and _validate_field_types paths.""" + + def test_registry_string_not_resolving_passes_through(self): + """String value that doesn't resolve from registry should fail type validation.""" + + @Flow.model + def model_fn(x: int) -> int: + return x + + cls = type(model_fn(x=1)) + with self.assertRaisesRegex(TypeError, "Field 'x'"): + cls.model_validate({"x": "nonexistent_registry_key"}) + + def test_registry_ref_resolves_to_callable_model(self): + """String value resolving to a CallableModel should be substituted.""" + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def upstream(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def downstream(context: SimpleContext, dep: CallableModel) -> GenericResult[int]: + return GenericResult(value=0) + + registry.add("my_upstream", upstream()) + cls = type(downstream(dep=upstream())) + restored = cls.model_validate({"dep": "my_upstream"}) + self.assertIsNotNone(restored) + finally: + registry.clear() + + +class TestMode2MissingContextField(TestCase): + """Line 1155: Mode 2 missing context field error.""" + + def test_mode2_missing_required_context_field(self): + """Mode 2 model called with context missing a required field should raise.""" + + 
@Flow.model(context_args=["start_date", "end_date"]) + def loader(start_date: str, end_date: str, source: str = "db") -> str: + return f"{source}:{start_date}-{end_date}" + + model = loader() + # Call with a FlowContext missing end_date + with self.assertRaisesRegex(TypeError, "Missing context field"): + model(FlowContext(start_date="2024-01-01")) + + +class TestDynamicModeContextObjectLookup(TestCase): + """Line 1155/1176: Dynamic mode pulling deferred values from context object.""" + + def test_deferred_value_coercion_through_context(self): + """Dynamic mode should coerce values from FlowContext through validators.""" + + @Flow.model + def typed_add(x: int, y: int) -> int: + return x + y + + model = typed_add(x=10) + # Calling with a FlowContext — y pulled from context and coerced + result = model(FlowContext(y=5)) + self.assertEqual(result.value, 15) + + if __name__ == "__main__": import unittest