Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 86 additions & 19 deletions src/pcms/adapter/omega_h/omega_h_field2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,49 +158,102 @@ struct OmegaHField2LocalizationHint
{
OmegaHField2LocalizationHint(
Omega_h::Mesh& mesh,
Kokkos::View<GridPointSearch::Result*, HostMemorySpace> search_results)
: offsets_("", mesh.nelems() + 1),
coordinates_("", search_results.size(), mesh.dim() + 1),
indices_("", search_results.size())
Kokkos::View<GridPointSearch::Result*, HostMemorySpace> search_results,
OutOfBoundsMode mode)
: mode_(mode), num_valid_(0), num_missing_(0)
{
Kokkos::View<LO*, HostMemorySpace> elem_counts("", mesh.nelems());
// First pass: count valid and invalid points
std::vector<size_t> valid_point_indices;
std::vector<size_t> missing_point_indices;

if (mode_ == OutOfBoundsMode::ERROR) {
// Error mode - throw error immediately if any point is out of bounds
for (size_t i = 0; i < search_results.size(); ++i) {
auto [dim, elem_idx, coord] = search_results(i);
bool is_missing =
(static_cast<int>(dim) != mesh.dim()) || (elem_idx < 0);
PCMS_ALWAYS_ASSERT(!is_missing && "Points found outside mesh domain");
valid_point_indices.push_back(i);
}
} else {
// Other modes - collect valid and missing points separately
for (size_t i = 0; i < search_results.size(); ++i) {
auto [dim, elem_idx, coord] = search_results(i);
bool is_missing =
(static_cast<int>(dim) != mesh.dim()) || (elem_idx < 0);
if (is_missing) {
missing_point_indices.push_back(i);
} else {
valid_point_indices.push_back(i);
}
}
}

num_valid_ = valid_point_indices.size();
num_missing_ = missing_point_indices.size();

// Handle missing points based on mode
if (num_missing_ > 0 && mode_ == OutOfBoundsMode::NEAREST_BOUNDARY) {
PCMS_ALWAYS_ASSERT(false && "NEAREST_BOUNDARY mode not implemented yet");
}

// Allocate arrays for valid points only
offsets_ = Kokkos::View<LO*, HostMemorySpace>("offsets", mesh.nelems() + 1);
coordinates_ = Kokkos::View<Real**, HostMemorySpace>(
"coordinates", num_valid_, mesh.dim() + 1);
indices_ = Kokkos::View<LO*, HostMemorySpace>("indices", num_valid_);

// Store missing point indices
if (num_missing_ > 0) {
missing_indices_ =
Kokkos::View<LO*, HostMemorySpace>("missing_indices", num_missing_);
for (size_t i = 0; i < num_missing_; ++i) {
missing_indices_(i) = static_cast<LO>(missing_point_indices[i]);
}
}

for (size_t i = 0; i < search_results.size(); ++i) {
auto [dim, elem_idx, coord] = search_results(i);
// Count points per element (valid points only)
Kokkos::View<LO*, HostMemorySpace> elem_counts("elem_counts",
mesh.nelems());
for (size_t i = 0; i < num_valid_; ++i) {
auto [dim, elem_idx, coord] = search_results(valid_point_indices[i]);
elem_counts[elem_idx] += 1;
}

// Compute offsets
LO total;

ComputeOffsetsFunctor functor(offsets_, elem_counts);
Kokkos::parallel_scan(
"ComputeOffsets",
Kokkos::RangePolicy<HostMemorySpace::execution_space>(0, mesh.nelems()),
functor, total);
offsets_(mesh.nelems()) = total;

for (size_t i = 0; i < search_results.size(); ++i) {
auto [dim, elem_idx, coord] = search_results(i);
// currently don't handle case where point is on a boundary
PCMS_ALWAYS_ASSERT(static_cast<int>(dim) == mesh.dim());
// element should be inside the domain (positive)
PCMS_ALWAYS_ASSERT(elem_idx >= 0 && elem_idx < mesh.nelems());
// Fill coordinates and indices for valid points
for (size_t i = 0; i < num_valid_; ++i) {
size_t orig_idx = valid_point_indices[i];
auto [dim, elem_idx, coord] = search_results(orig_idx);
elem_counts(elem_idx) -= 1;
LO index = offsets_(elem_idx) + elem_counts(elem_idx);
for (int j = 0; j < (mesh.dim() + 1); ++j) {
coordinates_(index, j) = coord[j];
}
// coordinates_(index, mesh.dim()) = coord[0];
indices_(index) = i;
indices_(index) = static_cast<LO>(orig_idx);
}
}

OutOfBoundsMode mode_;
size_t num_valid_;
size_t num_missing_;

// offsets is the number of points in each element
Kokkos::View<LO*, HostMemorySpace> offsets_;
// coordinates are the parametric coordinates of each point
Kokkos::View<Real**, HostMemorySpace> coordinates_;
// indices are the index of the original point
// indices are the index of the original point (for valid points)
Kokkos::View<LO*, HostMemorySpace> indices_;
// indices of points not found in mesh
Kokkos::View<LO*, HostMemorySpace> missing_indices_;
};

/*
Expand All @@ -210,7 +263,7 @@ OmegaHField2::OmegaHField2(const OmegaHFieldLayout& layout)
: layout_(layout),
mesh_(layout.GetMesh()),
search_(mesh_, 10, 10),
dof_holder_data_("", static_cast<size_t>(layout.OwnedSize()))
dof_holder_data_("dof_holder_data", static_cast<size_t>(layout.OwnedSize()))
{
auto nodes_per_dim = layout.GetNodesPerDim();
if (nodes_per_dim[2] == 0 && nodes_per_dim[3] == 0) {
Expand Down Expand Up @@ -302,7 +355,8 @@ LocalizationHint OmegaHField2::GetLocalizationHint(
Kokkos::View<GridPointSearch::Result*, HostMemorySpace> results_h(
"results_h", results.size());
Kokkos::deep_copy(results_h, results);
auto hint = std::make_shared<OmegaHField2LocalizationHint>(mesh_, results_h);
auto hint = std::make_shared<OmegaHField2LocalizationHint>(
mesh_, results_h, out_of_bounds_mode_);

return LocalizationHint{hint};
}
Expand Down Expand Up @@ -332,11 +386,23 @@ void OmegaHField2::Evaluate(LocalizationHint location,
"eval_results_h", eval_results.extent(0), eval_results.extent(1));
deep_copy_mismatch_layouts(eval_results_h, eval_results);
Rank1View<Real, HostMemorySpace> values = results.GetValues();

// Copy results for valid points
Kokkos::parallel_for(
"CopyEvalResultsToValues",
Kokkos::RangePolicy<HostMemorySpace::execution_space>(
0, eval_results_h.extent(0)),
KOKKOS_LAMBDA(LO i) { values[hint.indices_(i)] = eval_results_h(i, 0); });

// Handle missing points based on mode
if (hint.num_missing_ > 0 && hint.mode_ == OutOfBoundsMode::FILL) {
auto fill_val = fill_value_;
Kokkos::parallel_for(
"FillMissingValues",
Kokkos::RangePolicy<HostMemorySpace::execution_space>(0,
hint.num_missing_),
KOKKOS_LAMBDA(LO i) { values[hint.missing_indices_(i)] = fill_val; });
}
}

void OmegaHField2::EvaluateGradient(
Expand Down Expand Up @@ -388,4 +454,5 @@ void OmegaHField2::Deserialize(

SetDOFHolderData(pcms::make_const_array_view(sorted_buffer));
}

} // namespace pcms
21 changes: 21 additions & 0 deletions src/pcms/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ struct LocalizationHint
std::shared_ptr<void> data = nullptr;
};

enum class OutOfBoundsMode
{
ERROR, // Throw error when points are out of bounds
FILL, // Fill out-of-bounds points with a fill value
NEAREST_BOUNDARY // Map to nearest boundary cell (extrapolate)
};

class FieldLayout;

/*
Expand Down Expand Up @@ -127,7 +134,21 @@ class FieldT
Rank1View<const T, pcms::HostMemorySpace> buffer,
Rank1View<const pcms::LO, pcms::HostMemorySpace> permutation) = 0;

// Out-of-bounds handling
void SetOutOfBoundsMode(OutOfBoundsMode mode, Real fill_value = 0.0)
{
out_of_bounds_mode_ = mode;
fill_value_ = fill_value;
}

OutOfBoundsMode GetOutOfBoundsMode() const { return out_of_bounds_mode_; }
Real GetFillValue() const { return fill_value_; }

virtual ~FieldT() noexcept = default;

protected:
OutOfBoundsMode out_of_bounds_mode_ = OutOfBoundsMode::ERROR;
Real fill_value_ = 0.0;
};
// Should statically instantiate types
using FieldPtr =
Expand Down
3 changes: 2 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,8 @@ if(Catch2_FOUND)
test_normalisation.cpp
test_spline_interpolator.cpp
test_svd_serial.cpp
test_interpolation_class.cpp)
test_interpolation_class.cpp
test_omega_h_field2_outofbounds.cpp)
endif()
add_executable(unit_tests ${PCMS_UNIT_TEST_SOURCES})
target_link_libraries(unit_tests PUBLIC Catch2::Catch2 pcms::core
Expand Down
82 changes: 82 additions & 0 deletions test/test_omega_h_field2_outofbounds.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <Omega_h_mesh.hpp>
#include <Omega_h_build.hpp>
#include <Omega_h_for.hpp>
#include "pcms/adapter/omega_h/omega_h_field2.h"
#include "pcms/create_field.h"
#include <Kokkos_Core.hpp>
#include <vector>

using pcms::Real;

TEST_CASE("omega_h_field2 out of bounds FILL mode")
{
auto lib = Omega_h::Library{};
auto world = lib.world();
// Create a 1x1 box mesh (coords from 0 to 1)
auto mesh =
Omega_h::build_box(world, OMEGA_H_SIMPLEX, 1, 1, 0, 10, 10, 0, false);
auto layout =
pcms::CreateLagrangeLayout(mesh, 1, 1, pcms::CoordinateSystem::Cartesian);
const auto nverts = mesh.nents(0);
auto mesh_coords = mesh.coords();

// Set up a simple linear field
auto f = KOKKOS_LAMBDA(Real x, Real y)
{
return x + y;
};
Omega_h::Write<Real> test_f(nverts);
Omega_h::parallel_for(
nverts, OMEGA_H_LAMBDA(int i) {
Real x = mesh_coords[2 * i + 0];
Real y = mesh_coords[2 * i + 1];
test_f[i] = f(x, y);
});
Omega_h::HostWrite<Real> test_f_host(test_f);
auto field = layout->CreateField();
field->SetDOFHolderData(pcms::make_const_array_view(test_f_host));

// Set FILL mode with fill value of -999.0
Real fill_value = -999.0;
field->SetOutOfBoundsMode(pcms::OutOfBoundsMode::FILL, fill_value);

// Test points - mix of inside and outside
std::vector<Real> coords = {
0.5, 0.5, // inside - should evaluate normally
1.5, 0.5, // outside (x > 1) - should return fill_value
0.5, -0.1, // outside (y < 0) - should return fill_value
0.3, 0.7, // inside - should evaluate normally
-0.1, 0.5, // outside (x < 0) - should return fill_value
};

std::vector<Real> evaluation(coords.size() / 2);
pcms::Rank1View<Real, pcms::HostMemorySpace> eval_view{evaluation.data(),
evaluation.size()};
pcms::Rank2View<const Real, pcms::HostMemorySpace> coords_view(
coords.data(), coords.size() / 2, 2);
pcms::FieldDataView<Real, pcms::HostMemorySpace> data_view(
eval_view, field->GetCoordinateSystem());
pcms::CoordinateView<pcms::HostMemorySpace> coordinate_view{
field->GetCoordinateSystem(), coords_view};

auto locale = field->GetLocalizationHint(coordinate_view);
field->Evaluate(locale, data_view);

// Check results
// Point 0: (0.5, 0.5) - inside, should be close to f(0.5, 0.5) = 1.0
REQUIRE(std::abs(evaluation[0] - 1.0) < 0.1);

// Point 1: (1.5, 0.5) - outside, should be fill_value
REQUIRE(evaluation[1] == fill_value);

// Point 2: (0.5, -0.1) - outside, should be fill_value
REQUIRE(evaluation[2] == fill_value);

// Point 3: (0.3, 0.7) - inside, should be close to f(0.3, 0.7) = 1.0
REQUIRE(std::abs(evaluation[3] - 1.0) < 0.1);

// Point 4: (-0.1, 0.5) - outside, should be fill_value
REQUIRE(evaluation[4] == fill_value);
}