Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Tuple

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next import common, named_collections
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction


Expand All @@ -21,8 +21,8 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi
@WhereBuiltinFunction
def concat_where(
cond: common.Domain,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
true_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection,
false_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection,
/,
) -> common.Field | Tuple:
"""
Expand Down
13 changes: 9 additions & 4 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from numpy import float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next import common, named_collections
from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS
from gt4py.next.iterator import runtime
from gt4py.next.type_system import type_specifications as ts
Expand Down Expand Up @@ -78,6 +78,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp
types = [_type_conversion_helper(e) for e in t.__args__] # type: ignore[attr-defined]
assert all(type(t) is type and issubclass(t, ts.TypeSpec) for t in types)
return cast(tuple[type[ts.TypeSpec], ...], tuple(types)) # `cast` to break the recursion
elif t in named_collections.CUSTOM_NAMED_COLLECTION_TYPES:
return ts.NamedCollectionType
else:
raise AssertionError("Illegal type encountered.")

Expand Down Expand Up @@ -138,7 +140,10 @@ def __gt_type__(self) -> ts.FunctionType:


CondT = TypeVar("CondT", bound=Union[common.Field, common.Domain])
FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple])
FieldT = TypeVar(
"FieldT",
bound=Union[common.Field, core_defs.Scalar, Tuple, named_collections.CustomNamedCollection],
)


class WhereBuiltinFunction(
Expand Down Expand Up @@ -188,8 +193,8 @@ def broadcast(
@WhereBuiltinFunction
def where(
mask: common.Field,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
true_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection,
false_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection,
/,
) -> common.Field | Tuple:
raise NotImplementedError()
Expand Down
205 changes: 68 additions & 137 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Optional, TypeAlias, TypeVar, cast
import textwrap
from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast

import gt4py.next.ffront.field_operator_ast as foast
from gt4py import eve
from gt4py.eve import NodeTranslator, NodeVisitor, traits
from gt4py.next import errors, utils
from gt4py.next.common import DimensionKind, promote_dims
from gt4py.next.ffront import ( # noqa
from gt4py.next import errors
from gt4py.next.common import Dimension, DimensionKind, promote_dims
from gt4py.next.ffront import (
dialect_ast_enums,
experimental,
fbuiltins,
Expand All @@ -35,7 +37,6 @@ def with_altered_scalar_kind(

Examples:
---------
>>> from gt4py.next import Dimension
>>> scalar_t = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
>>> print(with_altered_scalar_kind(scalar_t, ts.ScalarKind.BOOL))
bool
Expand All @@ -57,80 +58,6 @@ def with_altered_scalar_kind(
raise ValueError(f"Expected field or scalar type, got '{type_spec}'.")


def construct_tuple_type(
true_branch_types: list, false_branch_types: list, mask_type: ts.FieldType
) -> list:
"""
Recursively construct the return types for the tuple return branch.

Examples:
---------
>>> from gt4py.next import Dimension
>>> mask_type = ts.FieldType(
... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)
... )
>>> true_branch_types = [
... ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
... ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
... ]
>>> false_branch_types = [
... ts.FieldType(
... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
... ),
... ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
... ]
>>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type))
[FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None)), FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))]
"""
element_types_new = true_branch_types
for i, element in enumerate(true_branch_types):
if isinstance(element, ts.TupleType):
element_types_new[i] = ts.TupleType(
types=construct_tuple_type(element.types, false_branch_types[i].types, mask_type)
)
else:
element_types_new[i] = promote_to_mask_type(
mask_type, type_info.promote(element_types_new[i], false_branch_types[i])
)
return element_types_new


def promote_to_mask_type(
mask_type: ts.FieldType, input_type: ts.FieldType | ts.ScalarType
) -> ts.FieldType:
"""
Promote mask type with the input type.

The input type being the result of promoting the left and right types in a conditional clause.

If the input type is a scalar, the return type takes the dimensions of the mask_type, while retaining the dtype of
the input type. The behavior is similar when the input type is a field type with fewer dimensions than the mask_type.
In all other cases, the return type takes the dimensions and dtype of the input type.

>>> from gt4py.next import Dimension
>>> I, J = (Dimension(value=dim) for dim in ["I", "J"])
>>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL)
>>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
>>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype)
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
>>> promote_to_mask_type(
... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype)
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
>>> promote_to_mask_type(
... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype)
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 11>, shape=None))
"""
if isinstance(input_type, ts.ScalarType) or not all(
item in input_type.dims for item in mask_type.dims
):
return_dtype = input_type.dtype if isinstance(input_type, ts.FieldType) else input_type
return type_info.promote(input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype)) # type: ignore
else:
return input_type


def deduce_stmt_return_type(
node: foast.BlockStmt, *, requires_unconditional_return: bool = True
) -> Optional[ts.TypeSpec]:
Expand Down Expand Up @@ -227,7 +154,7 @@ class FieldOperatorTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTransla
---------
>>> import ast
>>> import typing
>>> from gt4py.next import Field, Dimension
>>> from gt4py.next import Field
>>> from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function
>>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser
>>> IDim = Dimension("IDim")
Expand Down Expand Up @@ -983,46 +910,68 @@ def _visit_as_offset(self, node: foast.Call, **kwargs: Any) -> foast.Call:
func=node.func, args=node.args, kwargs=node.kwargs, type=arg_0, location=node.location
)

def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
mask_type = cast(ts.FieldType, node.args[0].type)
true_branch_type = node.args[1].type
false_branch_type = node.args[2].type
return_type: ts.TupleType | ts.FieldType
if not type_info.is_logical(mask_type):
def _deduce_where_return_type(
self,
func_name: str,
cond_dims: Sequence[Dimension],
true_branch: ts.FieldType | ts.TupleType | ts.NamedCollectionType,
false_branch: ts.FieldType | ts.TupleType | ts.NamedCollectionType,
location: eve.SourceLocation,
) -> ts.FieldType | ts.TupleType | ts.NamedCollectionType:
assert all(
isinstance(el, (ts.FieldType, ts.ScalarType))
for arg in (true_branch, false_branch)
for el in type_info.primitive_constituents(arg)
)

# replace all primitive constituents by the same type, `ts.DeferredType()` for convenience,
# to capture the structure of the two branches
extract_structure = ti_ffront.tree_map_type(lambda x: ts.DeferredType(constraint=None))
tb_structure = extract_structure(true_branch)
fb_structure = extract_structure(false_branch)

if tb_structure != fb_structure:
raise errors.DSLError(
node.location,
f"Incompatible argument in call to '{node.func!s}': expected "
f"a field with dtype 'bool', got '{mask_type}'.",
location,
f"Second and third argument to '{func_name}' must have the same tuple/collection "
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would

Suggested change
f"Second and third argument to '{func_name}' must have the same tuple/collection "
f"Second and third argument to '{func_name}' must have the same collection "

Let's include @egparedes. It's not about this line, but how we want to use collection/named collection etc when we talk and in documentation.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer collection as a generic term but I don't thin it's specific enough for our use case so we might need to find a better name.... Maybe static collection, because the attributes need to be known at compile time. Other random ideas: shaped collection, structured collection, ...
For now, as I said, collection might be just be good enough

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we agree in some term (I still leaning towards StaticCollection) we should create a TypeAlias in the same spirit of NumericValue or PrimitiveValue and use it in the type annotations (e.g. NumericValue | StaticCollection[NumericValue]). This can be also done in a different PR, since it likely needs a bit of discussion...

f"structure.\n"
+ textwrap.indent(
f"true branch: '{true_branch}'\nfalse branch: '{false_branch}'", " "
),
)

try:
# TODO(tehrengruber): the construct_tuple_type function doesn't look correct
if isinstance(true_branch_type, ts.TupleType) and isinstance(
false_branch_type, ts.TupleType
):
return_type = ts.TupleType(
types=construct_tuple_type(
true_branch_type.types, false_branch_type.types, mask_type
)
)
elif isinstance(true_branch_type, ts.TupleType) or isinstance(
false_branch_type, ts.TupleType
):
@ti_ffront.tree_map_type
def deduce_return_type(
tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType
) -> ts.FieldType:
if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)):
raise errors.DSLError(
node.location,
f"Return arguments need to be of same type in '{node.func!s}', got "
f"'{node.args[1].type}' and '{node.args[2].type}'.",
location,
f"Field arguments to '{func_name}' must be of same dtype, got '{t_dtype}' != "
f"'{f_dtype}'.",
)
else:
true_branch_fieldtype = cast(ts.FieldType, true_branch_type)
false_branch_fieldtype = cast(ts.FieldType, false_branch_type)
promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype)
return_type = promote_to_mask_type(mask_type, promoted_type)
return_dims = promote_dims(cond_dims, type_info.extract_dims(type_info.promote(tb, fb)))
return_type = ts.FieldType(dims=return_dims, dtype=t_dtype)
return return_type

except ValueError as ex:
return deduce_return_type(true_branch, false_branch) # type: ignore[return-value]

def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args)
assert isinstance(mask_type, ts.FieldType)
if not type_info.is_logical(mask_type):
raise errors.DSLError(
node.location, f"Incompatible argument in call to '{node.func!s}'."
) from ex
node.location,
f"Incompatible argument in call to '{node.func!s}': expected "
f"a field with dtype 'bool', got '{mask_type}'.",
)
return_type = self._deduce_where_return_type(
"where",
mask_type.dims,
true_branch_type, # type: ignore[arg-type]
false_branch_type, # type: ignore[arg-type]
node.location,
)

return foast.Call(
func=node.func,
Expand All @@ -1036,32 +985,14 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
cond_type, true_branch_type, false_branch_type = (arg.type for arg in node.args)

assert isinstance(cond_type, ts.DomainType)
assert all(
isinstance(el, (ts.FieldType, ts.ScalarType))
for arg in (true_branch_type, false_branch_type)
for el in type_info.primitive_constituents(arg)
return_type = self._deduce_where_return_type(
"concat_where",
cond_type.dims,
true_branch_type, # type: ignore[arg-type]
false_branch_type, # type: ignore[arg-type]
node.location,
)

@utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: ts.TupleType(types=list(elts)),
)
def deduce_return_type(
tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType
) -> ts.FieldType:
if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)):
raise errors.DSLError(
node.location,
f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.",
)
return_dims = promote_dims(
cond_type.dims, type_info.extract_dims(type_info.promote(tb, fb))
)
return_type = ts.FieldType(dims=return_dims, dtype=t_dtype)
return return_type

return_type = deduce_return_type(true_branch_type, false_branch_type)

return foast.Call(
func=node.func,
args=node.args,
Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def process_elements(
arg_types: Optional[Iterable[ts.TypeSpec]] = None,
) -> itir.FunCall:
"""
Recursively applies a processing function to all primitive constituents of a tuple.
Recursively applies a processing function to all primitive constituents of a tuple or
named collection.

Arguments:
process_func: A callable that takes an itir.Expr representing a leaf-element of `objs`.
Expand Down Expand Up @@ -60,7 +61,7 @@ def _process_elements_impl(
current_el_type: ts.TypeSpec,
arg_types: Optional[Iterable[ts.TypeSpec]],
) -> itir.Expr:
if isinstance(current_el_type, ts.TupleType):
if isinstance(current_el_type, (ts.TupleType, ts.NamedCollectionType)):
result = im.make_tuple(
Comment thread
tehrengruber marked this conversation as resolved.
*(
_process_elements_impl(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.S
tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb])

assert tb_dtype == fb_dtype, (
f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'."
f"Field arguments to 'concat_where' must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'."
)
dtype = tb_dtype

Expand Down
Loading