Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from typing_extensions import Self, Any

from vtkmodules.vtkCommonDataModel import vtkMultiBlockDataSet, vtkDataSet
from vtkmodules.vtkCommonCore import vtkIdTypeArray
from vtkmodules.vtkCommonDataModel import vtkSelection, vtkSelectionNode
from vtkmodules.vtkFiltersExtraction import vtkExtractSelection

from geos.mesh.utils.arrayModifiers import createAttribute
from geos.mesh.utils.arrayHelpers import ( getAttributeSet, getNumberOfComponents, getArrayInObject )
Expand Down Expand Up @@ -53,7 +56,7 @@

# Set the attributes to compare:
dictAttributesToCompare: dict[ Piece, set[ str ] ]
attributesDiffFilter.setDicAttributesToCompare( dicAttributesToCompare )
attributesDiffFilter.setDictAttributesToCompare( dictAttributesToCompare )

# Set the inf norm computation (if wanted):
computeInfNorm: bool
Expand All @@ -75,6 +78,8 @@ class AttributesDiff:
def __init__(
self: Self,
speHandler: bool = False,
computePoints: bool = True,
computeCells: bool = True,
) -> None:
"""Compute differences (L1 and inf norm) between two identical meshes attributes.

Expand All @@ -83,6 +88,8 @@ def __init__(
Args:
speHandler (bool, optional): True to use a specific handler, False to use the internal handler.
Defaults to False.
computePoints (bool, optional): True to compute attributes differences on points, False otherwise. Defaults to True.
computeCells (bool, optional): True to compute attributes differences on cells, False otherwise.
"""
self.listMeshes: list[ vtkMultiBlockDataSet | vtkDataSet ] = []
self.dictNbElements: dict[ Piece, int ] = {}
Expand All @@ -96,6 +103,9 @@ def __init__(

self.outputMesh: vtkMultiBlockDataSet | vtkDataSet = vtkMultiBlockDataSet()

self.computeCells: bool = computeCells
self.computePoints: bool = computePoints

# Logger.
self.logger: Logger
if not speHandler:
Expand All @@ -119,6 +129,53 @@ def setLoggerHandler( self: Self, handler: logging.Handler ) -> None:
self.logger.warning( "The logger already has an handler, to use yours set the argument 'speHandler' to True"
" during the filter initialization." )

@staticmethod
def _filterVolumeCells( mesh: vtkDataSet ) -> vtkDataSet:
"""Keep only 3D volume cells; optionally save 2D cells to VTU.

Args:
mesh (vtkDataSet): input mesh to filter
"""
volumeIds = vtkIdTypeArray()
surfaceIds = vtkIdTypeArray()
nVolume = nSurface = nOther = 0

for i in range( mesh.GetNumberOfCells() ):
dim = mesh.GetCell( i ).GetCellDimension()
if dim == 3:
volumeIds.InsertNextValue( i )
nVolume += 1
elif dim == 2:
surfaceIds.InsertNextValue( i )
nSurface += 1
else:
nOther += 1

getLogger( loggerTitle, True ).info( f" Cell types: {nVolume} volume (3D) | "
f"{nSurface} surface (2D) | {nOther} other" )

if nSurface == 0 and nOther == 0:
getLogger( loggerTitle, True ).info( "No filtering needed (all cells are 3D)" )
return vtkDataSet.SafeDownCast( mesh )

sn = vtkSelectionNode()
sn.SetFieldType( vtkSelectionNode.CELL )
sn.SetContentType( vtkSelectionNode.INDICES )
sn.SetSelectionList( volumeIds )

sel = vtkSelection()
sel.AddNode( sn )

ext = vtkExtractSelection()
ext.SetInputData( 0, mesh )
ext.SetInputData( 1, sel )
ext.Update()

getLogger( loggerTitle, True ).info( f"Filtered → {nVolume} cells "
f"(removed {nSurface + nOther})" )

return vtkDataSet.SafeDownCast( ext.GetOutput() )

def setMeshes(
self: Self,
listMeshes: list[ vtkMultiBlockDataSet | vtkDataSet ],
Expand All @@ -138,12 +195,16 @@ def setMeshes(
raise ValueError( "The list of meshes must contain two meshes." )

if listMeshes[ 0 ].GetClassName() != listMeshes[ 1 ].GetClassName():
raise TypeError( "The meshes must have the same type." )
raise TypeError(
f"The meshes must have the same type. {listMeshes[0].GetClassName()} and {listMeshes[1].GetClassName()}"
)

dictMeshesMaxElementId: dict[ Piece, list[ int ] ] = {}
if self.computeCells:
dictMeshesMaxElementId.update( { Piece.CELLS: [ 0, 0 ] } )
if self.computePoints:
dictMeshesMaxElementId.update( { Piece.POINTS: [ 0, 0 ] } )

dictMeshesMaxElementId: dict[ Piece, list[ int ] ] = {
Piece.CELLS: [ 0, 0 ],
Piece.POINTS: [ 0, 0 ],
}
if isinstance( listMeshes[ 0 ], vtkDataSet ):
for meshId, mesh in enumerate( listMeshes ):
for piece in dictMeshesMaxElementId:
Expand Down Expand Up @@ -171,8 +232,10 @@ def setMeshes(
raise ValueError( f"The total number of { piece.value } in the meshes must be the same." )

self.listMeshes = listMeshes
self.dictNbElements[ Piece.CELLS ] = dictMeshesMaxElementId[ Piece.CELLS ][ 0 ] + 1
self.dictNbElements[ Piece.POINTS ] = dictMeshesMaxElementId[ Piece.POINTS ][ 0 ] + 1
if self.computeCells:
self.dictNbElements[ Piece.CELLS ] = dictMeshesMaxElementId[ Piece.CELLS ][ 0 ] + 1
if self.computePoints:
self.dictNbElements[ Piece.POINTS ] = dictMeshesMaxElementId[ Piece.POINTS ][ 0 ] + 1
self.outputMesh = listMeshes[ 0 ].NewInstance()
self.outputMesh.ShallowCopy( listMeshes[ 0 ] )
self._computeDictSharedAttributes()
Expand Down Expand Up @@ -219,6 +282,12 @@ def setDictAttributesToCompare( self: Self, dictAttributesToCompare: dict[ Piece
Raises:
ValueError: At least one attribute to compare is not a shared attribute.
"""
if not ( ( Piece.CELLS in dictAttributesToCompare ) == self.computeCells and
( Piece.POINTS in dictAttributesToCompare ) == self.computePoints ):
raise LookupError(
"While instructed on cell/point the attribute to diff is either absent or not on the right support \n. "
)

for piece, setSharedAttributesToCompare in dictAttributesToCompare.items():
if not setSharedAttributesToCompare.issubset( self.dictSharedAttributes[ piece ] ):
wrongAttributes: set[ str ] = setSharedAttributesToCompare.difference(
Expand Down Expand Up @@ -298,6 +367,9 @@ def _computeDictAttributesArray( self: Self ) -> None:
nbAttributeComponents: int
for meshId, mesh in enumerate( self.listMeshes ):
if isinstance( mesh, vtkDataSet ):
mesh = AttributesDiff._filterVolumeCells( mesh )
if mesh.GetNumberOfCells() == 0 and piece == Piece.CELLS:
continue
arrayAttributeData = getArrayInObject( mesh, attributeName, piece )
nbAttributeComponents = getNumberOfComponents( mesh, attributeName, piece )
self.dictAttributesArray[ piece ][ :, idComponents:idComponents + nbAttributeComponents,
Expand All @@ -306,7 +378,10 @@ def _computeDictAttributesArray( self: Self ) -> None:
else:
listMeshBlockId: list[ int ] = getBlockElementIndexesFlatten( mesh )
for meshBlockId in listMeshBlockId:
dataset: vtkDataSet = vtkDataSet.SafeDownCast( mesh.GetDataSet( meshBlockId ) )
dataset: vtkDataSet = AttributesDiff._filterVolumeCells(
vtkDataSet.SafeDownCast( mesh.GetDataSet( meshBlockId ) ) )
if dataset.GetNumberOfCells() == 0 and piece == Piece.CELLS:
continue
arrayAttributeData = getArrayInObject( dataset, attributeName, piece )
nbAttributeComponents = getNumberOfComponents( dataset, attributeName, piece )
lToG: npt.NDArray[ Any ] = getArrayInObject( dataset, "localToGlobalMap", piece )
Expand Down
2 changes: 1 addition & 1 deletion geos-processing/tests/test_AttributesDiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_AttributesDiff( dataSetTest: vtkMultiBlockDataSet, ) -> None:
mesh1: vtkMultiBlockDataSet = dataSetTest( "2Ranks" )
mesh2: vtkMultiBlockDataSet = dataSetTest( "4Ranks" )

attributesDiffFilter: AttributesDiff = AttributesDiff()
attributesDiffFilter: AttributesDiff = AttributesDiff( computePoints=True, computeCells=True )
attributesDiffFilter.setMeshes( [ mesh1, mesh2 ] )
attributesDiffFilter.logSharedAttributeInfo()
dictAttributesToCompare: dict[ Piece, set[ str ] ] = {
Expand Down
Loading