diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index ef065d8a3c42..2b745babad02 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -796,12 +796,72 @@ def to_row(element): def expand_composite_transform(spec, scope): spec = normalize_inputs_outputs(normalize_source_sink(spec)) + original_transforms = spec['transforms'] + # Check if any transform has a NON-EMPTY explicit input or output. + # Note: {} (empty dict) means "no explicit input specified" and should + # NOT count as having explicit io. + # However, if the composite has no input, we can't do implicit chaining. + has_explicit_io = any( + io is not None and not is_empty(t.get(io, {})) + for t in original_transforms for io in ('input', 'output')) + + # If the composite has no input, we can't do implicit chaining + composite_has_input = not is_empty(spec.get('input', {})) + + # Only do implicit chaining if: + # 1. No transform has explicit io, AND + # 2. The composite has an input to chain from + if not has_explicit_io and composite_has_input: + new_transforms = [] + for ix, transform in enumerate(original_transforms): + transform = dict(transform) + if ix == 0: + composite_input = spec.get('input', {}) + if is_explicitly_empty(composite_input): + transform['input'] = composite_input + elif is_empty(composite_input): + # No explicit input - the composite input IS the pipeline input. + # Reference the 'input' key from the Scope's inputs. + transform['input'] = 'input' + else: + transform['input'] = {key: key for key in composite_input.keys()} + else: + transform['input'] = new_transforms[-1]['__uuid__'] + new_transforms.append(transform) + + if new_transforms: + spec = dict(spec, transforms=new_transforms) + # Check if output is empty, not just present (normalization sets it to {}) + if is_empty(spec.get('output', {})): + spec['output'] = { + '__implicit_outputs__': new_transforms[-1]['__uuid__'] + } + + # Compute the inputs for the inner scope. + # If the composite has an empty input dict ({}), it means the composite + # should use the parent scope's inputs directly. + composite_input = spec.get('input', {}) + + if is_empty(composite_input): + # No explicit input - use the parent scope's inputs directly + inner_scope_inputs = dict(scope._inputs) + else: + # The composite has explicit input references + # They can reference either: + # 1. A parent scope input (e.g., 'input' key in scope._inputs) + # 2. A transform output (e.g., 'uuid' -> the output of a transform) + inner_scope_inputs = {} + for key, value in composite_input.items(): + if isinstance(value, str) and value in scope._inputs: + # Reference to a parent scope input + inner_scope_inputs[key] = scope._inputs[value] + else: + # Reference to a transform output + inner_scope_inputs[key] = scope.get_pcollection(value) + inner_scope = Scope( scope.root, - { - key: scope.get_pcollection(value) - for (key, value) in empty_if_explicitly_empty(spec['input']).items() - }, + inner_scope_inputs, spec['transforms'], # TODO(robertwb): Are scoped providers ever used? Worth supporting? yaml_provider.merge_providers( @@ -814,7 +874,8 @@ class CompositePTransform(beam.PTransform): def expand(inputs): inner_scope.compute_all() if '__implicit_outputs__' in spec['output']: - return inner_scope.get_outputs(spec['output']['__implicit_outputs__']) + result = inner_scope.get_outputs(spec['output']['__implicit_outputs__']) + return result else: return { key: inner_scope.get_pcollection(value) @@ -826,16 +887,25 @@ def expand(inputs): transform = transform.with_resource_hints( **SafeLineLoader.strip_metadata(spec['resource_hints'])) + # Always set a name for the composite to ensure proper return value if 'name' not in spec: spec['name'] = 'Composite' if spec['name'] is None: # top-level pipeline, don't nest return transform.expand(None) else: _LOGGER.info("Expanding %s ", identify_object(spec)) - return ({ - key: scope.get_pcollection(value) - for (key, value) in empty_if_explicitly_empty(spec['input']).items() - } or scope.root) | scope.unique_name(spec, None) >> transform + # When the input references a scope input (not a transform output), + # we need to use the scope's inputs directly + input_dict = {} + for key, value in empty_if_explicitly_empty(spec['input']).items(): + if isinstance(value, str) and value in scope._inputs: + # Reference to a scope input + input_dict[key] = scope._inputs[value] + else: + # Reference to a transform output + input_dict[key] = scope.get_pcollection(value) + return (input_dict or + scope.root) | scope.unique_name(spec, None) >> transform def expand_chain_transform(spec, scope): diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index 2afb5e7d8e33..a4da97f7f50e 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -122,6 +122,26 @@ def test_composite(self): providers=TEST_PROVIDERS) assert_that(result, equal_to([1, 4, 9, 1, 8, 27])) + def test_composite_implicit_input_chaining(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([1, 2, 3]) + result = elements | YamlTransform( + ''' + type: composite + transforms: + - type: PyMap + name: Square + config: + fn: "lambda x: x * x" + - type: PyMap + name: Increment + config: + fn: "lambda x: x + 1" + ''', + providers=TEST_PROVIDERS) + assert_that(result, equal_to([2, 5, 10])) + def test_chain_with_input(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: