Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 8 additions & 72 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -12121,22 +12121,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 24,
"endColumn": 27,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 41,
"endColumn": 76,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -12209,38 +12193,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 35,
"endColumn": 45,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 35,
"endColumn": 45,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 34,
"endColumn": 43,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 34,
"endColumn": 43,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down Expand Up @@ -16885,14 +16837,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 19,
"endColumn": 41,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
Expand Down Expand Up @@ -17157,6 +17101,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 8,
"endColumn": 19,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down Expand Up @@ -17189,14 +17141,6 @@
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 58,
"endColumn": 74,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand All @@ -17213,14 +17157,6 @@
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 36,
"endColumn": 54,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
59 changes: 41 additions & 18 deletions pytential/symbolic/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast

Expand All @@ -33,7 +33,12 @@
from pymbolic import ArithmeticExpression
from pytools.obj_array import ObjectArray, ObjectArray1D, ShapeT, from_numpy

from pytential.symbolic.mappers import CachedIdentityMapper, DependencyMapper
from pytential.symbolic.mappers import (
CachedIdentityMapper,
DependencyMapper,
PrettyStringifyMapper,
StringifyMapper,
)
from pytential.symbolic.primitives import (
DOFDescriptor,
IntG,
Expand All @@ -44,6 +49,7 @@

if TYPE_CHECKING:
from collections.abc import (
Callable,
Collection,
Hashable,
Iterator,
Expand Down Expand Up @@ -76,7 +82,7 @@
# {{{ statements

@dataclass(frozen=True, eq=False)
class Statement:
class Statement(ABC):
"""
.. autoattribute:: names
.. autoattribute:: exprs
Expand All @@ -93,23 +99,25 @@ class Statement:
priority: int
"""The priority of the statement."""

@abstractmethod
def get_assignees(self) -> set[str]:
"""
:returns: names of variables that are assigned to in this statement.
"""
raise NotImplementedError(
f"get_assignees for '{self.__class__.__name__}'")

@abstractmethod
def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]:
"""
:returns: variables that are dependencies of the assignees.
"""
raise NotImplementedError(
f"get_dependencies for '{self.__class__.__name__}'")

@abstractmethod
def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str:
...

@override
def __str__(self) -> str:
raise NotImplementedError
def __str__(self):
return self.stringify(StringifyMapper())


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -152,14 +160,17 @@ def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]:
return deps

@override
def __str__(self) -> str:
def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str:
comment = self.comment

if len(self.names) == 1:
if comment:
comment = f"/* {comment} */ "

return "{} <- {}{}".format(self.names[0], comment, self.exprs[0])
return "{} <- {}{}".format(
self.names[0],
comment,
expr_mapper(self.exprs[0]))
else:
do_not_return = self.do_not_return
if do_not_return is None:
Expand All @@ -176,7 +187,7 @@ def __str__(self) -> str:
else:
dnr_indicator = ""

lines.append(f" {n} <{dnr_indicator}- {e}")
lines.append(f" {n} <{dnr_indicator}- {expr_mapper(e)}")
lines.append("}")

return "\n".join(lines)
Expand Down Expand Up @@ -266,14 +277,12 @@ def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]:
return result

@override
def __str__(self) -> str:
def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str:
args = [f"source={self.source}"]
for i, density in enumerate(self.densities):
args.append(f"density{i}={density}")

from pytential.symbolic.mappers import StringifyMapper, stringify_where
strify = StringifyMapper()

from pytential.symbolic.mappers import stringify_where
lines: list[str] = []
for o in self.outputs:
if o.target_name != self.source:
Expand Down Expand Up @@ -308,7 +317,7 @@ def __str__(self) -> str:
lines.append(line)

for arg_name, arg_expr in self.kernel_arguments.items():
arg_expr_lines = strify(arg_expr).split("\n")
arg_expr_lines = expr_mapper(arg_expr).split("\n")
lines.append(" {} = {}".format(arg_name, arg_expr_lines[0]))
lines.extend(" " + s for s in arg_expr_lines[1:])

Expand Down Expand Up @@ -417,9 +426,23 @@ def statements(self) -> list[Statement]:

@override
def __str__(self) -> str:
strify_mapper = PrettyStringifyMapper()
lines: list[str] = []
for insn in self.statements:
lines.extend(str(insn).split("\n"))
lines.extend(insn.stringify(strify_mapper).split("\n"))

if strify_mapper.cse_name_list:
# FIXME: There's potential here for name clashes between the 'code'
# and 'discretization CSE' parts. It's just presentation, so if it's
# bothersome, near here is the place to fix it.
lines = [
"DISCRETIZATION-LEVEL COMMON SUBEXPRESSIONS:",
*[
f"{name} <- {cse_expr_str}"
for name, cse_expr_str in strify_mapper.cse_name_list],
"-"*75,
*lines]

lines.append(f"RESULT: {self.result}")

return "\n".join(lines)
Expand Down
Loading
Loading