diff --git a/examples/elm-pb/elm_pb.cxx b/examples/elm-pb/elm_pb.cxx index 2cb373c52b..dc9a6bb394 100644 --- a/examples/elm-pb/elm_pb.cxx +++ b/examples/elm-pb/elm_pb.cxx @@ -1356,7 +1356,8 @@ class ELMpb : public PhysicsModel { // Only update if simulation time has advanced // Uses an exponential decay of the weighting of the value in the boundary // so that the solution is well behaved for arbitrary steps - BoutReal const weight = exp(-(t - phi_boundary_last_update) / phi_boundary_timescale); + const BoutReal weight = + exp(-(t - phi_boundary_last_update) / phi_boundary_timescale); phi_boundary_last_update = t; if (mesh->firstX()) { @@ -1385,11 +1386,11 @@ class ELMpb : public PhysicsModel { } // Old value of phi at boundary. Note: this is constant in Z - BoutReal const oldvalue = + const BoutReal oldvalue = 0.5 * (phi(mesh->xstart - 1, j, 0) + phi(mesh->xstart, j, 0)); // New value of phi at boundary, relaxing towards phivalue - BoutReal const newvalue = weight * oldvalue + (1. - weight) * phivalue; + const BoutReal newvalue = weight * oldvalue + (1. - weight) * phivalue; // Set phi at the boundary to this value for (int k = mesh->zstart; k <= mesh->zend; k++) { @@ -1412,7 +1413,7 @@ class ELMpb : public PhysicsModel { 0.5 * (phi(mesh->xend + 1, j, 0) + phi(mesh->xend, j, 0)); // New value of phi at boundary, relaxing towards phivalue - BoutReal const newvalue = weight * oldvalue + (1. - weight) * phivalue; + const BoutReal newvalue = weight * oldvalue + (1. - weight) * phivalue; // Set phi at the boundary to this value for (int k = mesh->zstart; k <= mesh->zend; k++) { @@ -1625,7 +1626,7 @@ class ELMpb : public PhysicsModel { for (int jz = 0; jz < mesh->LocalNz; jz++) { // Zero-gradient potential - BoutReal const phisheath = phi_fa(r.ind, mesh->ystart, jz); + const BoutReal phisheath = phi_fa(r.ind, mesh->ystart, jz); BoutReal jsheath = -(sqrt(mi_me) / (2. * sqrt(PI))) * phisheath; @@ -1646,7 +1647,7 @@ class ELMpb : public PhysicsModel { for (int jz = 0; jz < mesh->LocalNz; jz++) { // Zero-gradient potential - BoutReal const phisheath = phi_fa(r.ind, mesh->yend, jz); + const BoutReal phisheath = phi_fa(r.ind, mesh->yend, jz); BoutReal jsheath = (sqrt(mi_me) / (2. * sqrt(PI))) * phisheath; @@ -2068,7 +2069,8 @@ class ELMpb : public PhysicsModel { ddt(P).applyBoundary("neumann"); Field3D U1 = ddt(U); - U1 += (gamma * B0 * B0) * Grad_par(Jrhs, CELL_CENTRE) + (gamma * b0xcv) * Grad(ddt(P)); + U1 += + (gamma * B0 * B0) * Grad_par(Jrhs, CELL_CENTRE) + (gamma * b0xcv) * Grad(ddt(P)); // Second matrix, solving Alfven wave dynamics static std::unique_ptr invU{nullptr}; diff --git a/include/bout/field2d.hxx b/include/bout/field2d.hxx index 2af5c519d2..418a1dbbde 100644 --- a/include/bout/field2d.hxx +++ b/include/bout/field2d.hxx @@ -4,9 +4,9 @@ * \brief Definition of 2D scalar field class * ************************************************************************** - * Copyright 2010 B.D.Dudson, S.Farley, M.V.Umansky, X.Q.Xu + * Copyright 2010 - 2026 BOUT++ developers * - * Contact: Ben Dudson, bd512@york.ac.uk + * Contact: Ben Dudson, dudson2@llnl.gov * * This file is part of BOUT++. * @@ -33,13 +33,16 @@ class Field2D; #define BOUT_FIELD2D_H #include "bout/array.hxx" -#include "bout/build_config.hxx" +#include "bout/bout_types.hxx" +#include "bout/build_defines.hxx" #include "bout/field.hxx" #include "bout/field_data.hxx" #include "bout/fieldperp.hxx" #include "bout/region.hxx" #include +#include +#include #if BOUT_HAS_RAJA #include "RAJA/RAJA.hpp" // using RAJA lib @@ -206,7 +209,7 @@ public: } #endif - return data[jx * ny + jy]; + return data[(jx * ny) + jy]; } inline const BoutReal& operator()(int jx, int jy) const { #if CHECK > 2 && !BOUT_HAS_CUDA @@ -220,7 +223,7 @@ public: } #endif - return data[jx * ny + jy]; + return data[(jx * ny) + jy]; } /*! diff --git a/include/bout/field3d.hxx b/include/bout/field3d.hxx index 7ba6ae9575..ad3249270a 100644 --- a/include/bout/field3d.hxx +++ b/include/bout/field3d.hxx @@ -1,7 +1,7 @@ /************************************************************************** - * Copyright 2010 B.D.Dudson, S.Farley, M.V.Umansky, X.Q.Xu + * Copyright 2010 - 2026 BOUT++ contributors * - * Contact: Ben Dudson, bd512@york.ac.uk + * Contact: Ben Dudson, dudson2@llnl.gov * * This file is part of BOUT++. * @@ -33,6 +33,7 @@ class Field3D; #include "bout/bout_types.hxx" #include "bout/field.hxx" #include "bout/field2d.hxx" +#include "bout/field_data.hxx" #include "bout/fieldperp.hxx" #include "bout/region.hxx" #include "bout/traits.hxx" diff --git a/include/bout/invertable_operator.hxx b/include/bout/invertable_operator.hxx index d61c258654..6dcd92897a 100644 --- a/include/bout/invertable_operator.hxx +++ b/include/bout/invertable_operator.hxx @@ -135,9 +135,9 @@ public: : operatorFunction(func), preconditionerFunction(func), opt(optIn == nullptr ? Options::getRoot()->getSection("invertableOperator") : optIn), - localmesh(localmeshIn == nullptr ? bout::globals::mesh : localmeshIn), lib(opt){ + localmesh(localmeshIn == nullptr ? bout::globals::mesh : localmeshIn), lib(opt) { - }; + }; /// Destructor just has to cleanup the PETSc owned objects. ~InvertableOperator() { diff --git a/include/bout/mesh.hxx b/include/bout/mesh.hxx index 25e66fe5d2..ca03548a31 100644 --- a/include/bout/mesh.hxx +++ b/include/bout/mesh.hxx @@ -597,7 +597,7 @@ public: virtual int getLocalZIndexNoBoundaries(int zglobal) const = 0; /// Size of the mesh on this processor including guard/boundary cells - int LocalNx, LocalNy, LocalNz; + int LocalNx{0}, LocalNy{0}, LocalNz{0}; /// Local ranges of data (inclusive), excluding guard cells int xstart, xend, ystart, yend, zstart, zend; diff --git a/include/bout/options.hxx b/include/bout/options.hxx index 03e95488a8..48840a7b6f 100644 --- a/include/bout/options.hxx +++ b/include/bout/options.hxx @@ -74,19 +74,19 @@ class Options; * which can be used as a map. * * Options options; - * + * * // Set values * options["key"] = 1.0; * * // Get values. Throws BoutException if not found - * int val = options["key"]; // Sets val to 1 + * int val = options["key"]; // Sets val to 1 * * // Return as specified type. Throws BoutException if not found * BoutReal var = options["key"].as(); * * // A default value can be used if key is not found * BoutReal value = options["pi"].withDefault(3.14); - * + * * // Assign value with source label. Throws if already has a value from same source * options["newkey"].assign(1.0, "some source"); * @@ -94,7 +94,7 @@ class Options; * options["newkey"].force(2.0, "some source"); * * A legacy interface is also supported: - * + * * options.set("key", 1.0, "code"); // Sets a key from source "code" * * int val; @@ -119,9 +119,9 @@ class Options; * * Each Options object can also contain any number of sections, which are * themselves Options objects. - * + * * Options §ion = options["section"]; - * + * * which can be nested: * * options["section"]["subsection"]["value"] = 3; @@ -134,13 +134,13 @@ class Options; * * e.g. * options->getSection("section")->getSection("subsection")->set("value", 3); - * + * * Options also know about their parents: * * Options &parent = section.parent(); - * + * * or - * + * * Options *parent = section->getParent(); * * Root options object @@ -150,8 +150,8 @@ class Options; * there is a global singleton Options object which can be accessed with a static function * * Options &root = Options::root(); - * - * or + * + * or * * Options *root = Options::getRoot(); * @@ -193,7 +193,7 @@ public: /// @param[in] parent Parent object /// @param[in] sectionName Name of the section, including path from the root Options(Options* parent_instance, std::string full_name) - : parent_instance(parent_instance), full_name(std::move(full_name)){}; + : parent_instance(parent_instance), full_name(std::move(full_name)) {}; /// Initialise with a value /// These enable Options to be constructed using initializer lists diff --git a/include/bout/output.hxx b/include/bout/output.hxx index 9416ab411b..43fc6c59bf 100644 --- a/include/bout/output.hxx +++ b/include/bout/output.hxx @@ -166,13 +166,13 @@ class ConditionalOutput : public Output { public: /// @param[in] base The Output object which will be written to if enabled /// @param[in] enabled Should this be enabled by default? - ConditionalOutput(Output* base, bool enabled = true) : base(base), enabled(enabled){}; + ConditionalOutput(Output* base, bool enabled = true) : base(base), enabled(enabled) {}; /// Constuctor taking ConditionalOutput. This allows several layers of conditions /// /// @param[in] base A ConditionalOutput which will be written to if enabled /// - ConditionalOutput(ConditionalOutput* base) : base(base), enabled(base->enabled){}; + ConditionalOutput(ConditionalOutput* base) : base(base), enabled(base->enabled) {}; /// If enabled, writes a string using fmt formatting /// by calling base->write @@ -237,7 +237,7 @@ private: /// output_debug << "debug message"; /// compile but have no effect if BOUT_USE_OUTPUT_DEBUG is false template -DummyOutput& operator<<(DummyOutput& out, T const& UNUSED(t)) { +DummyOutput& operator<<(DummyOutput& out, const T& UNUSED(t)) { return out; } @@ -261,7 +261,7 @@ inline ConditionalOutput& operator<<(ConditionalOutput& out, stream_manipulator } template -ConditionalOutput& operator<<(ConditionalOutput& out, T const& t) { +ConditionalOutput& operator<<(ConditionalOutput& out, const T& t) { if (out.isEnabled()) { *out.getBase() << t; } diff --git a/include/bout/physicsmodel.hxx b/include/bout/physicsmodel.hxx index d653d96d68..8e8e3814cd 100644 --- a/include/bout/physicsmodel.hxx +++ b/include/bout/physicsmodel.hxx @@ -1,20 +1,20 @@ /*!************************************************************************ * \file physicsmodel.hxx - * + * * @brief Base class for Physics Models - * - * + * + * * * Changelog: - * + * * 2013-08 Ben Dudson * * Initial version - * + * ************************************************************************** * Copyright 2013 B.D.Dudson * * Contact: Ben Dudson, bd512@york.ac.uk - * + * * This file is part of BOUT++. * * BOUT++ is free software: you can redistribute it and/or modify @@ -167,9 +167,9 @@ public: * * Output * ------ - * + * * The time derivatives will be put in the ddt() variables - * + * * Returns a flag: 0 indicates success, non-zero an error flag */ int runRHS(BoutReal time, bool linear = false); @@ -203,7 +203,7 @@ public: bool hasPrecon(); /*! - * Run the preconditioner. The system state should be in the + * Run the preconditioner. The system state should be in the * evolving variables, and the vector to be solved in the ddt() variables. * The result will be put in the ddt() variables. * @@ -219,7 +219,7 @@ public: /*! * Run the Jacobian-vector multiplication function - * + * * Note: this is usually only called by the Solver */ int runJacobian(BoutReal t); @@ -241,10 +241,10 @@ protected: // The init and rhs functions are implemented by user code to specify problem /*! * @brief This function is called once by the solver at the start of a simulation. - * + * * A valid PhysicsModel must implement this function - * - * Variables should be read from the inputs, and the variables to + * + * Variables should be read from the inputs, and the variables to * be evolved should be specified. */ virtual int init(bool restarting) = 0; @@ -258,7 +258,7 @@ protected: /*! * @brief This function is called by the time integration solver * at least once per time step - * + * * Variables being evolved will be set by the solver * before the call, and this function must calculate * and set the time-derivatives. @@ -278,10 +278,10 @@ protected: /// Add additional variables other than the evolving variables to the restart files virtual void restartVars(Options& options); - /* + /* If split operator is set to true, then convective() and diffusive() are called instead of rhs() - + For implicit-explicit schemes, convective() will typically be treated explicitly, whilst diffusive() will be treated implicitly. For unsplit methods, both convective and diffusive will be called @@ -334,7 +334,7 @@ protected: * * @param[in] var The variable to evolve * @param[in] name The name to use for variable initialisation and output - * + * * Note that the variable must not be destroyed (e.g. go out of scope) * after this call, since a pointer to \p var is stored in the solver. * @@ -358,11 +358,11 @@ protected: * Specify a constrained variable \p var, which will be * adjusted to make \p F_var equal to zero. * If the solver does not support constraints then this will throw an exception - * + * * @param[in] var The variable the solver should modify * @param[in] F_var The control variable, which the user will set * @param[in] name The name to use for initialisation and output - * + * */ bool bout_constrain(Field3D& var, Field3D& F_var, const char* name); @@ -491,8 +491,7 @@ private: /// Add fields to the solver. /// This should accept up to ten arguments -#define SOLVE_FOR(...) \ - { MACRO_FOR_EACH(SOLVE_FOR1, __VA_ARGS__) } +#define SOLVE_FOR(...) {MACRO_FOR_EACH(SOLVE_FOR1, __VA_ARGS__)} /// Write this variable once to the grid file #define SAVE_ONCE1(var) dump.addOnce(var, #var); @@ -532,8 +531,7 @@ private: dump.addOnce(var6, #var6); \ } -#define SAVE_ONCE(...) \ - { MACRO_FOR_EACH(SAVE_ONCE1, __VA_ARGS__) } +#define SAVE_ONCE(...) {MACRO_FOR_EACH(SAVE_ONCE1, __VA_ARGS__)} /// Write this variable every timestep #define SAVE_REPEAT1(var) dump.addRepeat(var, #var); @@ -573,7 +571,6 @@ private: dump.addRepeat(var6, #var6); \ } -#define SAVE_REPEAT(...) \ - { MACRO_FOR_EACH(SAVE_REPEAT1, __VA_ARGS__) } +#define SAVE_REPEAT(...) {MACRO_FOR_EACH(SAVE_REPEAT1, __VA_ARGS__)} #endif // BOUT_PHYSICS_MODEL_H diff --git a/include/bout/sys/generator_context.hxx b/include/bout/sys/generator_context.hxx index bbe60fab65..e39eea1835 100644 --- a/include/bout/sys/generator_context.hxx +++ b/include/bout/sys/generator_context.hxx @@ -24,12 +24,8 @@ public: : Context(i.x(), i.y(), 0, (loc == CELL_ZLOW) ? CELL_CENTRE : loc, msh, t) {} /// Specify a cell index, together with the cell location, mesh and time - /// Context(int ix, int iy, int iz, CELL_LOC loc, Mesh* msh, BoutReal t); - /// Specify the values directly - Context(BoutReal x, BoutReal y, BoutReal z, Mesh* msh, BoutReal t); - /// If constructed without parameters, contains no values (null). /// Requesting x,y,z or t throws an exception Context() = default; @@ -44,6 +40,11 @@ public: BoutReal z() const { return get("z"); } BoutReal t() const { return get("t"); } + /// Cell indices + int ix() const { return ix_; } + int jy() const { return jy_; } + int kz() const { return kz_; } + /// Set the value of a parameter with given name Context& set(const std::string& name, BoutReal value) { parameters[name] = value; @@ -76,6 +77,10 @@ public: } private: + int ix_{0}; + int jy_{0}; + int kz_{0}; + Mesh* localmesh{nullptr}; ///< The mesh on which the position is defined /// Contains user-set values which can be set and retrieved diff --git a/include/bout/utils.hxx b/include/bout/utils.hxx index 5088a025c1..eae3b27730 100644 --- a/include/bout/utils.hxx +++ b/include/bout/utils.hxx @@ -8,7 +8,7 @@ * Copyright 2010 - 2026 BOUT++ contributors * * Contact: Ben Dudson, dudson2@llnl.gov - * + * * This file is part of BOUT++. * * BOUT++ is free software: you can redistribute it and/or modify @@ -484,7 +484,7 @@ inline bool is_pow2(int x) { return x && !((x - 1) & x); } /*! * Return the sign of a number \p a - * by testing if a > 0 + * by testing if a > 0 */ template T SIGN(T a) { // Return +1 or -1 (0 -> +1) @@ -511,7 +511,7 @@ inline void checkData(BoutReal f) { } #else /// Ignored with disabled CHECK; Throw an exception if \p f is not finite -inline void checkData(BoutReal UNUSED(f)){}; +inline void checkData(BoutReal UNUSED(f)) {}; #endif /*! @@ -587,7 +587,7 @@ BoutReal stringToReal(const std::string& s); /*! * Convert a string to an int - * + * * Throws BoutException if can't be done */ int stringToInt(const std::string& s); @@ -604,7 +604,7 @@ std::list& strsplit(const std::string& s, char delim, /*! * Split a string on a given delimiter - * + * * @param[in] s The string to split (not modified by call) * @param[in] delim The delimiter to split on (single char) */ @@ -612,7 +612,7 @@ std::list strsplit(const std::string& s, char delim); /*! * Strips leading and trailing spaces from a string - * + * * @param[in] s The string to trim (not modified) * @param[in] c Collection of characters to remove */ @@ -620,7 +620,7 @@ std::string trim(const std::string& s, const std::string& c = " \t\r"); /*! * Strips leading spaces from a string - * + * * @param[in] s The string to trim (not modified) * @param[in] c Collection of characters to remove */ @@ -628,7 +628,7 @@ std::string trimLeft(const std::string& s, const std::string& c = " \t"); /*! * Strips leading spaces from a string - * + * * @param[in] s The string to trim (not modified) * @param[in] c Collection of characters to remove */ @@ -636,7 +636,7 @@ std::string trimRight(const std::string& s, const std::string& c = " \t\r"); /*! * Strips the comments from a string - * + * * @param[in] s The string to trim (not modified) * @param[in] c Collection of characters to remove */ diff --git a/manual/sphinx/user_docs/input_grids.rst b/manual/sphinx/user_docs/input_grids.rst index 3d6be2cf77..983346f14e 100644 --- a/manual/sphinx/user_docs/input_grids.rst +++ b/manual/sphinx/user_docs/input_grids.rst @@ -155,6 +155,21 @@ The only quantities which are required are the sizes of the grid. If these are the only quantities specified, then the coordinates revert to Cartesian. +You can read additional quantities from the grid and make them available in +expressions in the input file by listing them in the ``input:grid_variables`` +section, with the key being the name in the grid file (``mesh:file``) and the +value being the type (one of ``field3d``, ``field2d``, ``boutreal``): + +.. code-block:: cfg + + [input:grid_variables] + rho = field2d + theta = field2d + scale = boutreal + + [mesh] + B = (scale / rho) * cos(theta) + This section describes how to generate inputs for tokamak equilibria. If you’re not interested in tokamaks then you can skip to the next section. diff --git a/src/field/field2d.cxx b/src/field/field2d.cxx index c2d35162b3..a269306c3a 100644 --- a/src/field/field2d.cxx +++ b/src/field/field2d.cxx @@ -4,7 +4,7 @@ * Class for 2D X-Y profiles * ************************************************************************** - * Copyright 2010 - 2025 BOUT++ developers + * Copyright 2010 - 2026 BOUT++ developers * * Contact: Ben Dudson, dudson2@llnl.gov * @@ -29,6 +29,7 @@ #include "bout/build_defines.hxx" #include "bout/unused.hxx" +#include #include #include #include @@ -46,27 +47,25 @@ Field2D::Field2D(Mesh* localmesh, CELL_LOC location_in, DirectionTypes directions_in, std::optional UNUSED(regionID)) : Field(localmesh, location_in, directions_in) { - - if (fieldmesh) { + if (fieldmesh != nullptr) { + // Note: Even if fieldmesh is not null, LocalNx and LocalNy may + // not be initialised. nx = fieldmesh->LocalNx; ny = fieldmesh->LocalNy; } - #if BOUT_USE_TRACK name = ""; #endif } Field2D::Field2D(const Field2D& f) : Field(f), data(f.data) { - -#if BOUT_USE_TRACK - name = f.name; -#endif - - if (fieldmesh) { + if (fieldmesh != nullptr) { nx = fieldmesh->LocalNx; ny = fieldmesh->LocalNy; } +#if BOUT_USE_TRACK + name = f.name; +#endif } Field2D::Field2D(BoutReal val, Mesh* localmesh) : Field2D(localmesh) { *this = val; } @@ -89,13 +88,16 @@ Field2D::~Field2D() { delete deriv; } Field2D& Field2D::allocate() { if (data.empty()) { - if (!fieldmesh) { + if (fieldmesh == nullptr) { // fieldmesh was not initialized when this field was initialized, so use - // the global mesh and set some members to default values + // the global mesh fieldmesh = bout::globals::mesh; - nx = fieldmesh->LocalNx; - ny = fieldmesh->LocalNy; } + // Get size from the mesh. + nx = fieldmesh->LocalNx; + ny = fieldmesh->LocalNy; + ASSERT1(nx > 0); + ASSERT1(ny > 0); data.reallocate(nx * ny); #if CHECK > 2 invalidateGuards(*this); diff --git a/src/field/field3d.cxx b/src/field/field3d.cxx index 9154f78a4f..7cbf135e25 100644 --- a/src/field/field3d.cxx +++ b/src/field/field3d.cxx @@ -4,7 +4,7 @@ * Class for 3D X-Y-Z scalar fields * ************************************************************************** - * Copyright 2010 - 2025 BOUT++ developers + * Copyright 2010 - 2026 BOUT++ developers * * Contact: Ben Dudson, dudson2@llnl.gov * @@ -25,8 +25,10 @@ * **************************************************************************/ +#include "bout/array.hxx" #include "bout/bout_types.hxx" #include "bout/build_defines.hxx" +#include "bout/field2d.hxx" #include #include @@ -62,7 +64,7 @@ Field3D::Field3D(Mesh* localmesh, CELL_LOC location_in, DirectionTypes direction name = ""; #endif - if (fieldmesh) { + if (fieldmesh != nullptr) { nx = fieldmesh->LocalNx; ny = fieldmesh->LocalNy; nz = fieldmesh->LocalNz; @@ -74,35 +76,27 @@ Field3D::Field3D(Mesh* localmesh, CELL_LOC location_in, DirectionTypes direction Field3D::Field3D(const Field3D& f) : Field(f), data(f.data), yup_fields(f.yup_fields), ydown_fields(f.ydown_fields), regionID(f.regionID) { - - if (fieldmesh) { + if (fieldmesh != nullptr) { nx = fieldmesh->LocalNx; ny = fieldmesh->LocalNy; nz = fieldmesh->LocalNz; } } -Field3D::Field3D(const Field2D& f) : Field(f) { - - nx = fieldmesh->LocalNx; - ny = fieldmesh->LocalNy; - nz = fieldmesh->LocalNz; +Field3D::Field3D(const Field2D& f) + : Field(f), nx(fieldmesh->LocalNx), ny(fieldmesh->LocalNy), nz(fieldmesh->LocalNz) { *this = f; } Field3D::Field3D(const BoutReal val, Mesh* localmesh) : Field3D(localmesh) { - *this = val; } Field3D::Field3D(Array data_in, Mesh* localmesh, CELL_LOC datalocation, DirectionTypes directions_in) - : Field(localmesh, datalocation, directions_in), data(std::move(data_in)) { - - nx = fieldmesh->LocalNx; - ny = fieldmesh->LocalNy; - nz = fieldmesh->LocalNz; + : Field(localmesh, datalocation, directions_in), nx(fieldmesh->LocalNx), + ny(fieldmesh->LocalNy), nz(fieldmesh->LocalNz), data(std::move(data_in)) { ASSERT1(data.size() == nx * ny * nz); } @@ -111,14 +105,17 @@ Field3D::~Field3D() { delete deriv; } Field3D& Field3D::allocate() { if (data.empty()) { - if (!fieldmesh) { + if (fieldmesh == nullptr) { // fieldmesh was not initialized when this field was initialized, so use // the global mesh and set some members to default values fieldmesh = bout::globals::mesh; - nx = fieldmesh->LocalNx; - ny = fieldmesh->LocalNy; - nz = fieldmesh->LocalNz; } + nx = fieldmesh->LocalNx; + ny = fieldmesh->LocalNy; + nz = fieldmesh->LocalNz; + ASSERT1(nx > 0); + ASSERT1(ny > 0); + ASSERT1(nz > 0); data.reallocate(nx * ny * nz); #if CHECK > 2 invalidateGuards(*this); diff --git a/src/field/field_data.cxx b/src/field/field_data.cxx index b925cb1426..8f6a808482 100644 --- a/src/field/field_data.cxx +++ b/src/field/field_data.cxx @@ -53,11 +53,6 @@ FieldData::FieldData(Mesh* localmesh, CELL_LOC location_in) location_in, fieldmesh)) { // Need to check for nullptr again, because the // fieldmesh might still be // nullptr if the global mesh hasn't been initialized yet - if (fieldmesh != nullptr) { - // sets fieldCoordinates by getting Coordinates for our location from - // fieldmesh - getCoordinates(); - } } FieldData::FieldData(const FieldData& other) { diff --git a/src/field/field_factory.cxx b/src/field/field_factory.cxx index 188a64cf0c..793ea9f278 100644 --- a/src/field/field_factory.cxx +++ b/src/field/field_factory.cxx @@ -1,5 +1,5 @@ /************************************************************************** - * Copyright 2010-2025 BOUT++ contributors + * Copyright 2010 - 2026 BOUT++ contributors * * Contact: Ben Dudson, dudson2@llnl.gov * @@ -19,20 +19,28 @@ * along with BOUT++. If not, see . * **************************************************************************/ -#include #include -#include - +#include +#include +#include #include +#include +#include +#include +#include #include +#include +#include #include -#include "bout/constants.hxx" - #include "fieldgenerators.hxx" +#include +#include +#include + using bout::generator::Context; /// Helper function to create a FieldValue generator from a BoutReal @@ -45,6 +53,8 @@ FieldGeneratorPtr generator(BoutReal* ptr) { return std::make_shared(ptr); } +BOUT_ENUM_CLASS(GridVariableFunction, field3d, field2d, boutreal); + namespace { /// Provides a placeholder whose target can be changed after creation. /// This enables recursive FieldGenerator expressions to be generated @@ -80,6 +90,46 @@ class FieldIndirect : public FieldGenerator { FieldGeneratorPtr target; }; + +// Read variables from the grid file and make them available in expressions +template +auto add_grid_variable(FieldFactory& factory, Mesh& mesh, const std::string& name) { + factory.addGenerator(name, std::make_shared>(&mesh, name)); +} + +auto read_grid_variables(FieldFactory& factory, Mesh& mesh, Options& options) { + auto& field_variables = options["input"]["grid_variables"].doc( + "Variables to read from the grid file and make available in expressions"); + + for (const auto& [name, value] : field_variables) { + if (not mesh.isDataSourceGridFile()) { + throw BoutException( + "A grid file ('mesh:file') is required for `input:grid_variables`"); + } + + if (not mesh.sourceHasVar(name)) { + const auto filename = Options::root()["mesh"]["file"].as(); + throw BoutException( + "Grid file '{}' missing `{}` specified in `input:grid_variables`", filename, + name); + } + + const auto func = value.as(); + switch (func) { + case GridVariableFunction::field3d: + add_grid_variable(factory, mesh, name); + break; + case GridVariableFunction::field2d: + add_grid_variable(factory, mesh, name); + break; + case GridVariableFunction::boutreal: + BoutReal var{}; + mesh.get(var, name); + factory.addGenerator(name, std::make_shared(var)); + break; + } + } +} } // namespace ////////////////////////////////////////////////////////// @@ -179,6 +229,9 @@ FieldFactory::FieldFactory(Mesh* localmesh, Options* opt) // Periodic in the Y direction? addGenerator("is_periodic_y", std::make_shared()); + + // Variables from the grid file + read_grid_variables(*this, *fieldmesh, nonconst_options); } Field2D FieldFactory::create2D(const std::string& value, const Options* opt, diff --git a/src/field/fieldgenerators.hxx b/src/field/fieldgenerators.hxx index 050e335448..6866808df9 100644 --- a/src/field/fieldgenerators.hxx +++ b/src/field/fieldgenerators.hxx @@ -7,11 +7,20 @@ #ifndef BOUT_FIELDGENERATORS_H #define BOUT_FIELDGENERATORS_H +#include #include #include +#include +#include +#include #include +#include #include +#include +#include +#include +#include ////////////////////////////////////////////////////////// // Generators from values @@ -39,7 +48,7 @@ template class FieldGenOneArg : public FieldGenerator { ///< Template for single-argument function public: FieldGenOneArg(FieldGeneratorPtr g, const std::string& name = "function") - : gen(g), name(name) {} + : gen(std::move(g)), name(name) {} FieldGeneratorPtr clone(const std::list args) override { if (args.size() != 1) { throw ParseException("Incorrect number of arguments to {:s}. Expecting 1, got {:d}", @@ -66,7 +75,7 @@ class FieldGenTwoArg : public FieldGenerator { ///< Template for two-argument fu public: FieldGenTwoArg(FieldGeneratorPtr a, FieldGeneratorPtr b, const std::string& name = "function") - : A(a), B(b), name(name) {} + : A(std::move(a)), B(std::move(b)), name(name) {} FieldGeneratorPtr clone(const std::list args) override { if (args.size() != 2) { throw ParseException("Incorrect number of arguments to {:s}. Expecting 2, got {:d}", @@ -89,11 +98,13 @@ private: /// Arc (Inverse) tangent. Either one or two argument versions class FieldATan : public FieldGenerator { public: - FieldATan(FieldGeneratorPtr a, FieldGeneratorPtr b = nullptr) : A(a), B(b) {} + FieldATan(FieldGeneratorPtr a, FieldGeneratorPtr b = nullptr) + : A(std::move(a)), B(std::move(b)) {} FieldGeneratorPtr clone(const std::list args) override { if (args.size() == 1) { return std::make_shared(args.front()); - } else if (args.size() == 2) { + } + if (args.size() == 2) { return std::make_shared(args.front(), args.back()); } throw ParseException( @@ -144,7 +155,7 @@ public: FieldMin() = default; FieldMin(const std::list args) : input(args) {} FieldGeneratorPtr clone(const std::list args) override { - if (args.size() == 0) { + if (args.empty()) { throw ParseException("min function must have some inputs"); } return std::make_shared(args); @@ -153,10 +164,7 @@ public: auto it = input.begin(); BoutReal result = (*it)->generate(pos); for (; it != input.end(); it++) { - BoutReal val = (*it)->generate(pos); - if (val < result) { - result = val; - } + result = std::min(result, (*it)->generate(pos)); } return result; } @@ -171,7 +179,7 @@ public: FieldMax() = default; FieldMax(const std::list args) : input(args) {} FieldGeneratorPtr clone(const std::list args) override { - if (args.size() == 0) { + if (args.empty()) { throw ParseException("max function must have some inputs"); } return std::make_shared(args); @@ -180,10 +188,7 @@ public: auto it = input.begin(); BoutReal result = (*it)->generate(pos); for (; it != input.end(); it++) { - BoutReal val = (*it)->generate(pos); - if (val > result) { - result = val; - } + result = std::max(result, (*it)->generate(pos)); } return result; } @@ -202,7 +207,7 @@ class FieldClamp : public FieldGenerator { public: FieldClamp() = default; FieldClamp(FieldGeneratorPtr value, FieldGeneratorPtr low, FieldGeneratorPtr high) - : value(value), low(low), high(high) {} + : value(std::move(value)), low(std::move(low)), high(std::move(high)) {} FieldGeneratorPtr clone(const std::list args) override { if (args.size() != 3) { throw ParseException( @@ -238,7 +243,7 @@ private: /// Generator to round to the nearest integer class FieldRound : public FieldGenerator { public: - FieldRound(FieldGeneratorPtr g) : gen(g) {} + FieldRound(FieldGeneratorPtr g) : gen(std::move(g)) {} FieldGeneratorPtr clone(const std::list args) override { if (args.size() != 1) { @@ -307,7 +312,7 @@ public: // Constructor FieldTanhHat(FieldGeneratorPtr xin, FieldGeneratorPtr widthin, FieldGeneratorPtr centerin, FieldGeneratorPtr steepnessin) - : X(xin), width(widthin), center(centerin), steepness(steepnessin){}; + : X(xin), width(widthin), center(centerin), steepness(steepnessin) {}; // Clone containing the list of arguments FieldGeneratorPtr clone(const std::list args) override; BoutReal generate(const bout::generator::Context& pos) override; @@ -322,7 +327,7 @@ private: class FieldWhere : public FieldGenerator { public: FieldWhere(FieldGeneratorPtr test, FieldGeneratorPtr gt0, FieldGeneratorPtr lt0) - : test(test), gt0(gt0), lt0(lt0){}; + : test(test), gt0(gt0), lt0(lt0) {}; FieldGeneratorPtr clone(const std::list args) override { if (args.size() != 3) { @@ -352,6 +357,54 @@ private: FieldGeneratorPtr test, gt0, lt0; }; +/// A `Field3D` or `Field2D` that can be used in expressions +/// +/// The variable is read from Mesh when first used and shared between +/// clones. This is to avoid circular dependencies in construction. +template > +class GridVariable : public FieldGenerator { +private: + struct LazyLoaded { + LazyLoaded(Mesh* mesh, std::string name) : mesh(mesh), name(std::move(name)) {} + + Field3D get() { + if (!var.isAllocated()) { + // Read variable from mesh + if (this->mesh->get(this->var, this->name) != 0) { + throw BoutException("Couldn't read GridVariable '{}'", this->name); + } + } + return this->var; + } + + Mesh* mesh; + std::string name; + T var{}; + }; + + std::shared_ptr variable; + +public: + GridVariable(Mesh* mesh, std::string name) + : variable(std::make_shared(mesh, std::move(name))) {} + + GridVariable(std::shared_ptr variable) : variable(std::move(variable)) {} + + double generate(const bout::generator::Context& ctx) override { + return variable->get()(ctx.ix(), ctx.jy(), ctx.kz()); + } + + FieldGeneratorPtr clone(const std::list args) override { + if (!args.empty()) { + throw ParseException("Variable '{}' takes no arguments but got {:d}", + variable->name, args.size()); + } + return std::make_shared>(variable); + } + + std::string str() const override { return variable->name; } +}; + /// Function that evaluates to 1 when Y is periodic (i.e. in the core), 0 otherwise /// Note: Assumes symmetricGlobalX class FieldPeriodicY : public FieldGenerator { diff --git a/src/invert/laplace/common_transform.cxx b/src/invert/laplace/common_transform.cxx index 847add4fea..98571624a1 100644 --- a/src/invert/laplace/common_transform.cxx +++ b/src/invert/laplace/common_transform.cxx @@ -28,8 +28,8 @@ FFTTransform::FFTTransform(const Mesh& mesh, int nmode, int xs, int xe, int ys, auto FFTTransform::forward(const Laplacian& laplacian, const Field3D& rhs, const Field3D& x0, const Field2D& Acoef, const Field2D& C1coef, - const Field2D& C2coef, - const Field2D& Dcoef) const -> Matrices { + const Field2D& C2coef, const Field2D& Dcoef) const + -> Matrices { Matrices result(nsys, nx); @@ -79,8 +79,8 @@ auto FFTTransform::forward(const Laplacian& laplacian, const Field3D& rhs, return result; } -auto FFTTransform::backward(const Field3D& rhs, - const Matrix& xcmplx3D) const -> Field3D { +auto FFTTransform::backward(const Field3D& rhs, const Matrix& xcmplx3D) const + -> Field3D { Field3D x{emptyFrom(rhs)}; // FFT back to real space @@ -162,8 +162,8 @@ auto FFTTransform::forward(const Laplacian& laplacian, const FieldPerp& rhs, return result; } -auto FFTTransform::backward(const FieldPerp& rhs, - const Matrix& xcmplx) const -> FieldPerp { +auto FFTTransform::backward(const FieldPerp& rhs, const Matrix& xcmplx) const + -> FieldPerp { FieldPerp x{emptyFrom(rhs)}; // FFT back to real space @@ -209,8 +209,8 @@ DSTTransform::DSTTransform(const Mesh& mesh, int nmode, int xs, int xe, int ys, auto DSTTransform::forward(const Laplacian& laplacian, const Field3D& rhs, const Field3D& x0, const Field2D& Acoef, const Field2D& C1coef, - const Field2D& C2coef, - const Field2D& Dcoef) const -> Matrices { + const Field2D& C2coef, const Field2D& Dcoef) const + -> Matrices { Matrices result(nsys, nx); @@ -260,8 +260,8 @@ auto DSTTransform::forward(const Laplacian& laplacian, const Field3D& rhs, return result; } -auto DSTTransform::backward(const Field3D& rhs, - const Matrix& xcmplx3D) const -> Field3D { +auto DSTTransform::backward(const Field3D& rhs, const Matrix& xcmplx3D) const + -> Field3D { Field3D x{emptyFrom(rhs)}; // DST back to real space @@ -346,8 +346,8 @@ auto DSTTransform::forward(const Laplacian& laplacian, const FieldPerp& rhs, return result; } -auto DSTTransform::backward(const FieldPerp& rhs, - const Matrix& xcmplx) const -> FieldPerp { +auto DSTTransform::backward(const FieldPerp& rhs, const Matrix& xcmplx) const + -> FieldPerp { FieldPerp x{emptyFrom(rhs)}; // DST back to real space diff --git a/src/invert/laplace/impls/multigrid/multigrid_laplace.cxx b/src/invert/laplace/impls/multigrid/multigrid_laplace.cxx index 4be5bfbad3..c06a83585b 100644 --- a/src/invert/laplace/impls/multigrid/multigrid_laplace.cxx +++ b/src/invert/laplace/impls/multigrid/multigrid_laplace.cxx @@ -1,5 +1,5 @@ /************************************************************************** - * Perpendicular Laplacian inversion. + * Perpendicular Laplacian inversion. * Using Geometrical Multigrid Solver * * Equation solved is: @@ -9,7 +9,7 @@ * Copyright 2016 K.S. Kang * * Contact: Ben Dudson, bd512@york.ac.uk - * + * * This file is part of BOUT++. * * BOUT++ is free software: you can redistribute it and/or modify @@ -32,8 +32,11 @@ #if not BOUT_USE_METRIC_3D +#include +#include #include #include +#include #if BOUT_USE_OPENMP #include @@ -107,11 +110,7 @@ LaplaceMultigrid::LaplaceMultigrid(Options* opt, const CELL_LOC loc, Mesh* mesh_ } Nz_global = localmesh->GlobalNz; Nz_local = Nz_global; // No parallelization in z-direction (for now) - // - //else { - // Nz_local = localmesh->zend - localmesh->zstart + 1; // excluding guard cells - // Nz_global = localmesh->GlobalNz - 2*localmesh->zstart; // excluding guard cells - // } + if (mgcount == 0) { output << "Nz=" << Nz_global << "(" << Nz_local << ")" << endl; } @@ -214,7 +213,9 @@ LaplaceMultigrid::LaplaceMultigrid(Options* opt, const CELL_LOC loc, Mesh* mesh_ } BOUT_OMP_SAFE(parallel) BOUT_OMP_SAFE(master) - { output << "Num threads = " << omp_get_num_threads() << endl; } + { + output << "Num threads = " << omp_get_num_threads() << endl; + } } } @@ -384,11 +385,9 @@ FieldPerp LaplaceMultigrid::solve(const FieldPerp& b_in, const FieldPerp& x0) { } if ((pcheck == 3) && (mgcount == 0)) { - FILE* outf; - char outfile[256]; - sprintf(outfile, "test_matF_%d.mat", kMG->rProcI); + std::string outfile = fmt::format("test_matF_{:d}.mat", kMG->rProcI); output << "Out file= " << outfile << endl; - outf = fopen(outfile, "w"); + FILE* outf = fopen(outfile.c_str(), "w"); int dim = (lxx + 2) * (lzz + 2); fprintf(outf, "dim = %d (%d, %d)\n", dim, lxx, lzz); @@ -411,11 +410,9 @@ FieldPerp LaplaceMultigrid::solve(const FieldPerp& b_in, const FieldPerp& x0) { output << i << "dimension= " << kMG->lnx[i - 1] << "(" << kMG->gnx[i - 1] << ")," << kMG->lnz[i - 1] << endl; - FILE* outf; - char outfile[256]; - sprintf(outfile, "test_matC%1d_%d.mat", i, kMG->rProcI); + std::string outfile = fmt::format("test_matC{:1d}_{:d}.mat", i, kMG->rProcI); output << "Out file= " << outfile << endl; - outf = fopen(outfile, "w"); + FILE* outf = fopen(outfile.c_str(), "w"); int dim = (kMG->lnx[i - 1] + 2) * (kMG->lnz[i - 1] + 2); fprintf(outf, "dim = %d (%d,%d)\n", dim, kMG->lnx[i - 1], kMG->lnz[i - 1]); diff --git a/src/invert/laplace/impls/multigrid/multigrid_solver.cxx b/src/invert/laplace/impls/multigrid/multigrid_solver.cxx index 5ff616aa54..074e482320 100644 --- a/src/invert/laplace/impls/multigrid/multigrid_solver.cxx +++ b/src/invert/laplace/impls/multigrid/multigrid_solver.cxx @@ -1,5 +1,5 @@ /************************************************************************** - * Perpendicular Laplacian inversion. + * Perpendicular Laplacian inversion. * Using Geometrical Multigrid Solver * * Equation solved is: @@ -9,7 +9,7 @@ * Copyright 2015 K.S. Kang * * Contact: Ben Dudson, bd512@york.ac.uk - * + * * This file is part of BOUT++. * * BOUT++ is free software: you can redistribute it and/or modify @@ -35,6 +35,9 @@ #include "bout/unused.hxx" #include +#include +#include + Multigrid1DP::Multigrid1DP(int level, int lx, int lz, int gx, int dl, int merge, MPI_Comm comm, int check) : MultigridAlg(level, lx, lz, gx, lz, comm, check) { @@ -195,11 +198,9 @@ void Multigrid1DP::setMultigridC(int UNUSED(plag)) { if (pcheck == 2) { for (int i = level; i >= 0; i--) { - FILE* outf; - char outfile[256]; - sprintf(outfile, "2DP_matC%1d_%d.mat", i, rMG->rProcI); + std::string outfile = fmt::format("2DP_matC{:1d}_{:d}.mat", i, rMG->rProcI); output << "Out file= " << outfile << endl; - outf = fopen(outfile, "w"); + FILE* outf = fopen(outfile.c_str(), "w"); int dim = (rMG->lnx[i] + 2) * (rMG->lnz[i] + 2); fprintf(outf, "dim = %d (%d, %d)\n", dim, rMG->lnx[i], rMG->lnz[i]); @@ -222,11 +223,9 @@ void Multigrid1DP::setMultigridC(int UNUSED(plag)) { } if (pcheck == 3) { for (int i = level; i >= 0; i--) { - FILE* outf; - char outfile[256]; - sprintf(outfile, "S1D_matC%1d_%d.mat", i, sMG->rProcI); + std::string outfile = fmt::format("S1D_matC{:1d}_{:d}.mat", i, sMG->rProcI); output << "Out file= " << outfile << endl; - outf = fopen(outfile, "w"); + FILE* outf = fopen(outfile.c_str(), "w"); int dim = (sMG->lnx[i] + 2) * (sMG->lnz[i] + 2); fprintf(outf, "dim = %d\n", dim); @@ -456,11 +455,9 @@ void Multigrid1DP::convertMatrixF2D(int level) { } } if (pcheck == 3) { - FILE* outf; - char outfile[256]; - sprintf(outfile, "2DP_CP_%d.mat", rProcI); + std::string outfile = fmt::format("2DP_CP_{}.mat", rProcI); output << "Out file= " << outfile << endl; - outf = fopen(outfile, "w"); + FILE* outf = fopen(outfile.c_str(), "w"); fprintf(outf, "dim = (%d, %d)\n", ggx, gnz[0]); for (int ii = 0; ii < dim; ii++) { @@ -476,11 +473,10 @@ void Multigrid1DP::convertMatrixF2D(int level) { MPI_SUM, comm2D); if (pcheck == 3) { - FILE* outf; - char outfile[256]; - sprintf(outfile, "2DP_Conv_%d.mat", rProcI); + + std::string outfile = fmt::format("2DP_Conv_{}.mat", rProcI); output << "Out file= " << outfile << endl; - outf = fopen(outfile, "w"); + FILE* outf = fopen(outfile.c_str(), "w"); fprintf(outf, "dim = (%d, %d)\n", ggx, gnz[0]); for (int ii = 0; ii < dim; ii++) { @@ -625,10 +621,8 @@ void Multigrid2DPf1D::setMultigridC(int UNUSED(plag)) { } if (pcheck == 2) { for (int i = level; i >= 0; i--) { - FILE* outf; - char outfile[256]; - sprintf(outfile, "S2D_matC%1d_%d.mat", i, sMG->rProcI); - outf = fopen(outfile, "w"); + std::string outfile = fmt::format("S2D_matC{:1d}_{:d}.mat", i, sMG->rProcI); + FILE* outf = fopen(outfile.c_str(), "w"); output << "Out file= " << outfile << endl; int dim = (sMG->lnx[i] + 2) * (sMG->lnz[i] + 2); fprintf(outf, "dim = %d\n", dim); diff --git a/src/invert/laplace/impls/pcr/pcr.cxx b/src/invert/laplace/impls/pcr/pcr.cxx index eb786afa2c..325242a3b2 100644 --- a/src/invert/laplace/impls/pcr/pcr.cxx +++ b/src/invert/laplace/impls/pcr/pcr.cxx @@ -45,11 +45,14 @@ #include "../../common_transform.hxx" #include +#include +#include #include #include #include #include #include +#include #include #include #include @@ -124,8 +127,8 @@ LaplacePCR::LaplacePCR(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED // (unless periodic in x) xe = localmesh->LocalNx - 1; } - int n = xe - xs + 1; // Number of X points on this processor, - // including boundaries but not guard cells + const int n = xe - xs + 1; // Number of X points on this processor, + // including boundaries but not guard cells a.reallocate(nmode, n); b.reallocate(nmode, n); @@ -148,7 +151,7 @@ FieldPerp LaplacePCR::solve(const FieldPerp& rhs, const FieldPerp& x0) { FieldPerp x{emptyFrom(rhs)}; // Result - int jy = rhs.getIndex(); // Get the Y index + const int jy = rhs.getIndex(); // Get the Y index x.setIndex(jy); // Get the width of the boundary @@ -215,7 +218,7 @@ Field3D LaplacePCR::solve(const Field3D& rhs, const Field3D& x0) { outbndry = 1; } - int nx = xe - xs + 1; // Number of X points on this processor + const int nx = xe - xs + 1; // Number of X points on this processor // Get range of Y indices int ys = localmesh->ystart; @@ -313,7 +316,7 @@ void LaplacePCR ::cr_pcr_solver(Matrix& a_mpi, Matrix& b_mpi // don't want to copy them. // xs = xstart if a proc has no boundary points // xs = 0 if a proc has boundary points - int offset = localmesh->xstart - xs; + const int offset = localmesh->xstart - xs; aa(kz, ix + 1) = a_mpi(kz, ix + offset); bb(kz, ix + 1) = b_mpi(kz, ix + offset); cc(kz, ix + 1) = c_mpi(kz, ix + offset); @@ -400,7 +403,7 @@ void LaplacePCR ::apply_boundary_conditions(const Matrix& a, } } if (localmesh->lastX()) { - int n = xe - xs + 1; // actual length of array + const int n = xe - xs + 1; // actual length of array for (int kz = 0; kz < nsys; kz++) { for (int ix = n - localmesh->xstart; ix < n; ix++) { x(kz, ix) = (r(kz, ix) - a(kz, ix) * x(kz, ix - 1)) / b(kz, ix); @@ -419,7 +422,7 @@ void LaplacePCR ::cr_forward_multiple_row(Matrix& a, Matrix& Matrix& r) const { const int nsys = std::get<0>(a.shape()); - MPI_Comm comm = BoutComm::get(); + const MPI_Comm comm = BoutComm::get(); Array alpha(nsys); Array gamma(nsys); Array sbuf(4 * nsys); @@ -494,7 +497,7 @@ void LaplacePCR ::cr_backward_multiple_row(Matrix& a, Matrix Matrix& x) const { const int nsys = std::get<0>(a.shape()); - MPI_Comm comm = BoutComm::get(); + const MPI_Comm comm = BoutComm::get(); MPI_Status status; MPI_Request request[2]; @@ -533,7 +536,7 @@ void LaplacePCR ::cr_backward_multiple_row(Matrix& a, Matrix dist_row = dist_row / 2; } if (xproc < nprocs - 1) { - MPI_Wait(request + 1, &status); + MPI_Wait(&request[1], &status); } } @@ -555,12 +558,11 @@ void LaplacePCR ::pcr_forward_single_row(Matrix& a, Matrix& MPI_Status status; Array request(4); - MPI_Comm comm = BoutComm::get(); + const MPI_Comm comm = BoutComm::get(); const int nlevel = log2(nprocs); const int nhprocs = nprocs / 2; int dist_rank = 1; - int dist2_rank = 2; /// Parallel cyclic reduction continues until 2x2 matrix are made between a pair of /// rank, (myrank, myrank+nhprocs). @@ -678,7 +680,6 @@ void LaplacePCR ::pcr_forward_single_row(Matrix& a, Matrix& } dist_rank *= 2; - dist2_rank *= 2; } /// Solving 2x2 matrix. All pair of ranks, myrank and myrank+nhprocs, solves it diff --git a/src/invert/laplace/impls/serial_tri/serial_tri.hxx b/src/invert/laplace/impls/serial_tri/serial_tri.hxx index 4aed777b7c..332c272f3f 100644 --- a/src/invert/laplace/impls/serial_tri/serial_tri.hxx +++ b/src/invert/laplace/impls/serial_tri/serial_tri.hxx @@ -52,7 +52,7 @@ class LaplaceSerialTri : public Laplacian { public: LaplaceSerialTri(Options* opt = nullptr, const CELL_LOC loc = CELL_CENTRE, Mesh* mesh_in = nullptr, Solver* solver = nullptr); - ~LaplaceSerialTri(){}; + ~LaplaceSerialTri() {}; using Laplacian::setCoefA; void setCoefA(const Field2D& val) override { diff --git a/src/invert/laplace/impls/spt/spt.hxx b/src/invert/laplace/impls/spt/spt.hxx index a3b1acb7c3..e951f431e7 100644 --- a/src/invert/laplace/impls/spt/spt.hxx +++ b/src/invert/laplace/impls/spt/spt.hxx @@ -1,16 +1,16 @@ /************************************************************************** - * Perpendicular Laplacian inversion. + * Perpendicular Laplacian inversion. * PARALLEL CODE - SIMPLE ALGORITHM - * + * * I'm just calling this Simple Parallel Tridag. Naive parallelisation of * the serial code. For use as a reference case. - * + * * Overlap calculation / communication of poloidal slices to achieve some * parallelism. * * Changelog * --------- - * + * * 2014-06 Ben Dudson * * Removed static variables in functions, changing to class members. * @@ -18,7 +18,7 @@ * Copyright 2010 B.D.Dudson, S.Farley, M.V.Umansky, X.Q.Xu * * Contact: Ben Dudson, bd512@york.ac.uk - * + * * This file is part of BOUT++. * * BOUT++ is free software: you can redistribute it and/or modify @@ -142,8 +142,8 @@ private: Array buffer; }; - int ys, ye; // Range of Y indices - SPT_data slicedata; // Used to solve for a single FieldPerp + int ys, ye; // Range of Y indices + SPT_data slicedata; // Used to solve for a single FieldPerp Array alldata; // Used to solve a Field3D Array dc1d; ///< 1D in Z for taking FFTs diff --git a/src/invert/laplacexy/impls/hypre/laplacexy-hypre.cxx b/src/invert/laplacexy/impls/hypre/laplacexy-hypre.cxx index 439d07a4e5..61632a332d 100644 --- a/src/invert/laplacexy/impls/hypre/laplacexy-hypre.cxx +++ b/src/invert/laplacexy/impls/hypre/laplacexy-hypre.cxx @@ -23,8 +23,10 @@ #include #if BOUT_HAS_CUDA && defined(__CUDACC__) -#define gpuErrchk(ans) \ - { gpuAssert((ans), __FILE__, __LINE__); } +#define gpuErrchk(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } inline void gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) { if (code != cudaSuccess) { fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); diff --git a/src/mesh/coordinates_accessor.cxx b/src/mesh/coordinates_accessor.cxx index 0ce4b664b5..c874ea3ac2 100644 --- a/src/mesh/coordinates_accessor.cxx +++ b/src/mesh/coordinates_accessor.cxx @@ -45,8 +45,10 @@ CoordinatesAccessor::CoordinatesAccessor(const Coordinates* coords) { data[stripe_size * ind.ind + static_cast(Offset::symbol)] = coords->symbol[ind]; // Implement copy for each argument -#define COPY_STRIPE(...) \ - { MACRO_FOR_EACH(COPY_STRIPE1, __VA_ARGS__) } +#define COPY_STRIPE(...) \ + { \ + MACRO_FOR_EACH(COPY_STRIPE1, __VA_ARGS__) \ + } // Iterate over all points in the field // Note this could be 2D or 3D, depending on FieldMetric type diff --git a/src/mesh/data/gridfromfile.cxx b/src/mesh/data/gridfromfile.cxx index 10f6f6926e..ee09d0055f 100644 --- a/src/mesh/data/gridfromfile.cxx +++ b/src/mesh/data/gridfromfile.cxx @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -120,7 +121,7 @@ bool GridFile::get(Mesh* UNUSED(m), BoutReal& rval, const std::string& name, /*! * Reads a 2D, 3D or FieldPerp field variable from a file - * + * * Successfully reads Field2D or FieldPerp if the variable in the file is 0-D or 2-D. * Successfully reads Field3D if the variable in the file is 0-D, 2-D or 3-D. */ @@ -197,7 +198,10 @@ bool GridFile::getField(Mesh* m, T& var, const std::string& name, BoutReal def, ///Ghost region widths. const int mxg = (m->LocalNx - (m->xend - m->xstart + 1)) / 2; const int myg = (m->LocalNy - (m->yend - m->ystart + 1)) / 2; - ///Check that ghost region widths are in fact integers + // Check grid has cells + ASSERT1(m->LocalNx > 0); + ASSERT1(m->LocalNy > 0); + // Check that ghost region widths are in fact integers ASSERT1((m->LocalNx - (m->xend - m->xstart + 1)) % 2 == 0); ASSERT1((m->LocalNy - (m->yend - m->ystart + 1)) % 2 == 0); @@ -378,7 +382,7 @@ void GridFile::readField(Mesh* m, const std::string& name, int ys, int yd, int n for (int x = xs; x < xs + nx_to_read; ++x) { for (int y = ys; y < ys + ny_to_read; ++y) { - BoutReal const value = full_var(x, y); + const BoutReal value = full_var(x, y); for (int z = 0; z < var.getNz(); z++) { var(x - xs + xd, y - ys + yd, z) = value; } diff --git a/src/mesh/impls/bout/boutmesh.cxx b/src/mesh/impls/bout/boutmesh.cxx index e1e34d2f4f..83cd16d76c 100644 --- a/src/mesh/impls/bout/boutmesh.cxx +++ b/src/mesh/impls/bout/boutmesh.cxx @@ -2566,9 +2566,8 @@ bool BoutMesh::periodicY(int jx, BoutReal& ts) const { int BoutMesh::numberOfYBoundaries() const { if (jyseps2_1 != jyseps1_2) { return 2; - } else { - return 1; } + return 1; } std::pair BoutMesh::hasBranchCutLower(int jx) const { diff --git a/src/mesh/parallel/fci_comm.hxx b/src/mesh/parallel/fci_comm.hxx index 324cae8a22..34512b18c4 100644 --- a/src/mesh/parallel/fci_comm.hxx +++ b/src/mesh/parallel/fci_comm.hxx @@ -67,7 +67,7 @@ struct ProcLocal { struct GlobalToLocal1D { GlobalToLocal1D(int mg, int npe, int localwith, bool periodic) : mg(mg), npe(npe), localwith(localwith), local(localwith - (2 * mg)), - global(local * npe), globalwith(global + (2 * mg)), periodic(periodic){}; + global(local * npe), globalwith(global + (2 * mg)), periodic(periodic) {}; ProcLocal convert(int id) const; int getLocalWith() const { return localwith; } int getGlobalWith() const { return globalwith; } @@ -104,7 +104,7 @@ public: const BoutReal& operator[](IndG3D ind) const; GlobalField3DAccessInstance(const GlobalField3DAccess* gfa, std::vector&& data) - : gfa(gfa), data(std::move(data)){}; + : gfa(gfa), data(std::move(data)) {}; private: const GlobalField3DAccess* gfa; diff --git a/src/mesh/parallel/shiftedmetric.cxx b/src/mesh/parallel/shiftedmetric.cxx index 64c6d9a2ce..f167824541 100644 --- a/src/mesh/parallel/shiftedmetric.cxx +++ b/src/mesh/parallel/shiftedmetric.cxx @@ -39,8 +39,8 @@ void ShiftedMetric::checkInputGrid() { "Should be 'shiftedmetric'."); } } // else: parallel_transform variable not found in grid input, indicates older input - // file or grid from options so must rely on the user having ensured the type is - // correct + // file or grid from options so must rely on the user having ensured the type is + // correct } void ShiftedMetric::outputVars(Options& output_options) { diff --git a/src/mesh/parallel/shiftedmetricinterp.cxx b/src/mesh/parallel/shiftedmetricinterp.cxx index c71618ab19..3e187f29bf 100644 --- a/src/mesh/parallel/shiftedmetricinterp.cxx +++ b/src/mesh/parallel/shiftedmetricinterp.cxx @@ -215,7 +215,7 @@ void ShiftedMetricInterp::checkInputGrid() { "Should be 'orthogonal'."); } } // else: coordinate_system variable not found in grid input, indicates older input - // file so must rely on the user having ensured the type is correct + // file so must rely on the user having ensured the type is correct } /*! diff --git a/src/solver/impls/petsc/petsc.cxx b/src/solver/impls/petsc/petsc.cxx index 1e1e05596b..cc884c240c 100644 --- a/src/solver/impls/petsc/petsc.cxx +++ b/src/solver/impls/petsc/petsc.cxx @@ -64,10 +64,10 @@ class ColoringStencil { private: - bool static isInSquare(int const i, int const j, int const n_square) { + bool static isInSquare(const int i, const int j, const int n_square) { return std::abs(i) <= n_square && std::abs(j) <= n_square; } - bool static isInCross(int const i, int const j, int const n_cross) { + bool static isInCross(const int i, const int j, const int n_cross) { if (i == 0) { return std::abs(j) <= n_cross; } @@ -76,7 +76,7 @@ class ColoringStencil { } return false; } - bool static isInTaxi(int const i, int const j, int const n_taxi) { + bool static isInTaxi(const int i, const int j, const int n_taxi) { return std::abs(i) + std::abs(j) <= n_taxi; } @@ -614,7 +614,7 @@ int PetscSolver::init() { n_taxi = 2; } - auto const xy_offsets = ColoringStencil::getOffsets(n_square, n_taxi, n_cross); + const auto xy_offsets = ColoringStencil::getOffsets(n_square, n_taxi, n_cross); { // This is ugly but can't think of a better and robust way to // count the non-zeros for some arbitrary stencil @@ -734,7 +734,7 @@ int PetscSolver::init() { // Mark non-zero entries output_progress.write("Marking non-zero Jacobian entries\n"); - PetscScalar const val = 1.0; + const PetscScalar val = 1.0; for (int x = mesh->xstart; x <= mesh->xend; x++) { for (int y = mesh->ystart; y <= mesh->yend; y++) { @@ -753,21 +753,21 @@ int PetscSolver::init() { continue; } - int const ind2 = ROUND(index(xi, yi, 0)); + const int ind2 = ROUND(index(xi, yi, 0)); if (ind2 < 0) { continue; // A boundary point } // Depends on all variables on this cell for (int j = 0; j < n2d; j++) { - PetscInt const col = ind2 + j; + const PetscInt col = ind2 + j; PetscCall(MatSetValues(Jfd, 1, &row, 1, &col, &val, INSERT_VALUES)); } } } // 3D fields for (int z = mesh->zstart; z <= mesh->zend; z++) { - int const ind = ROUND(index(x, y, z)); + const int ind = ROUND(index(x, y, z)); for (int i = 0; i < n3d; i++) { PetscInt row = ind + i; @@ -777,7 +777,7 @@ int PetscSolver::init() { // Depends on 2D fields for (int j = 0; j < n2d; j++) { - PetscInt const col = ind0 + j; + const PetscInt col = ind0 + j; PetscCall(MatSetValues(Jfd, 1, &row, 1, &col, &val, INSERT_VALUES)); } @@ -802,7 +802,7 @@ int PetscSolver::init() { // 3D fields on this cell for (int j = 0; j < n3d; j++) { - PetscInt const col = ind2 + j; + const PetscInt col = ind2 + j; int ierr = MatSetValues(Jfd, 1, &row, 1, &col, &val, INSERT_VALUES); if (ierr != 0) { diff --git a/src/solver/impls/snes/snes.hxx b/src/solver/impls/snes/snes.hxx index 40412f83b7..10ac4783e7 100644 --- a/src/solver/impls/snes/snes.hxx +++ b/src/solver/impls/snes/snes.hxx @@ -201,8 +201,8 @@ private: BoutReal kI; ///< (0.2 - 0.4) Integral parameter (smooths history of changes) BoutReal kD; ///< (0.1 - 0.3) Derivative (dampens oscillation - optional) bool pid_consider_failures; ///< Reduce timestep increases if recent solves have failed - BoutReal recent_failure_rate; ///< Rolling average of recent failure rate - BoutReal last_failure_weight; ///< 1 / number of recent solves + BoutReal recent_failure_rate; ///< Rolling average of recent failure rate + BoutReal last_failure_weight; ///< 1 / number of recent solves BoutReal nl_its_prev; BoutReal nl_its_prev2; @@ -244,7 +244,7 @@ private: bool matrix_free_operator; ///< Use matrix free Jacobian in the operator? int lag_jacobian; ///< Re-use Jacobian bool jacobian_persists; ///< Re-use Jacobian and preconditioner across nonlinear solves - bool use_coloring; ///< Use matrix coloring + bool use_coloring; ///< Use matrix coloring bool jacobian_recalculated; ///< Flag set when Jacobian is recalculated bool prune_jacobian; ///< Remove small elements in the Jacobian? diff --git a/src/sys/generator_context.cxx b/src/sys/generator_context.cxx index 31a5662378..01090daa51 100644 --- a/src/sys/generator_context.cxx +++ b/src/sys/generator_context.cxx @@ -9,7 +9,7 @@ namespace bout { namespace generator { Context::Context(int ix, int iy, int iz, CELL_LOC loc, Mesh* msh, BoutReal t) - : localmesh(msh) { + : ix_(ix), jy_(iy), kz_(iz), localmesh(msh) { parameters["x"] = (loc == CELL_XLOW) ? 0.5 * (msh->GlobalX(ix) + msh->GlobalX(ix - 1)) : msh->GlobalX(ix); @@ -24,20 +24,17 @@ Context::Context(int ix, int iy, int iz, CELL_LOC loc, Mesh* msh, BoutReal t) } Context::Context(const BoundaryRegion* bndry, int iz, CELL_LOC loc, BoutReal t, Mesh* msh) - : localmesh(msh) { - - // Add one to X index if boundary is in -x direction, so that XLOW is on the boundary - const int ix = (bndry->bx < 0) ? bndry->x + 1 : bndry->x; + : // Add one to X index if boundary is in -x direction, so that XLOW is on the boundary + ix_((bndry->bx < 0) ? bndry->x + 1 : bndry->x), + jy_((bndry->by < 0) ? bndry->y + 1 : bndry->y), kz_(iz), localmesh(msh) { parameters["x"] = ((loc == CELL_XLOW) || (bndry->bx != 0)) - ? 0.5 * (msh->GlobalX(ix) + msh->GlobalX(ix - 1)) - : msh->GlobalX(ix); - - const int iy = (bndry->by < 0) ? bndry->y + 1 : bndry->y; + ? 0.5 * (msh->GlobalX(ix_) + msh->GlobalX(ix_ - 1)) + : msh->GlobalX(ix_); parameters["y"] = ((loc == CELL_YLOW) || (bndry->by != 0)) - ? PI * (msh->GlobalY(iy) + msh->GlobalY(iy - 1)) - : TWOPI * msh->GlobalY(iy); + ? PI * (msh->GlobalY(jy_) + msh->GlobalY(jy_ - 1)) + : TWOPI * msh->GlobalY(jy_); parameters["z"] = (loc == CELL_ZLOW) ? PI * (msh->GlobalZ(iz) + msh->GlobalZ(iz - 1)) : TWOPI * msh->GlobalZ(iz); @@ -45,8 +42,5 @@ Context::Context(const BoundaryRegion* bndry, int iz, CELL_LOC loc, BoutReal t, parameters["t"] = t; } -Context::Context(BoutReal x, BoutReal y, BoutReal z, Mesh* msh, BoutReal t) - : localmesh(msh), parameters{{"x", x}, {"y", y}, {"z", z}, {"t", t}} {} - } // namespace generator } // namespace bout diff --git a/src/sys/msg_stack.cxx b/src/sys/msg_stack.cxx index 3dbd7c2797..bd14a766ed 100644 --- a/src/sys/msg_stack.cxx +++ b/src/sys/msg_stack.cxx @@ -59,7 +59,9 @@ void MsgStack::pop() { return; } BOUT_OMP_SAFE(single) - { --position; } + { + --position; + } } void MsgStack::pop(int id) { @@ -87,7 +89,9 @@ void MsgStack::clear() { void MsgStack::dump() { BOUT_OMP_SAFE(single) - { output << this->getDump(); } + { + output << this->getDump(); + } } std::string MsgStack::getDump() { diff --git a/src/sys/options/options_netcdf.cxx b/src/sys/options/options_netcdf.cxx index f21f461c09..9170f7c284 100644 --- a/src/sys/options/options_netcdf.cxx +++ b/src/sys/options/options_netcdf.cxx @@ -132,25 +132,27 @@ void readGroup(const std::string& filename, const NcGroup& group, Options& resul {s2i(dims[0].getSize()), s2i(dims[1].getSize()), s2i(dims[2].getSize())}); // We need to explicitly copy file, so that there is a pointer to the file, and // the file does not get closed, which would prevent us from reading. - result[var_name].setLazyLoad(std::make_unique( - int, int, int, int, int, int)>>( - [file, var](int xstart, int xend, int ystart, int yend, int zstart, - int zend) { - const auto i2s = [](int i) { - if (i < 0) { - throw BoutException("BadCast {} < 0", i); - } - return static_cast(i); - }; - Tensor value(xend - xstart + 1, yend - ystart + 1, - zend - zstart + 1); - const std::vector index{i2s(xstart), i2s(ystart), i2s(zstart)}; - const std::vector count{i2s(xend - xstart + 1), - i2s(yend - ystart + 1), - i2s(zend - zstart + 1)}; - var.getVar(index, count, value.begin()); - return value; - })); + result[var_name].setLazyLoad( + std::make_unique< + std::function(int, int, int, int, int, int)>>( + [file, var](int xstart, int xend, int ystart, int yend, int zstart, + int zend) { + const auto i2s = [](int i) { + if (i < 0) { + throw BoutException("BadCast {} < 0", i); + } + return static_cast(i); + }; + Tensor value(xend - xstart + 1, yend - ystart + 1, + zend - zstart + 1); + const std::vector index{i2s(xstart), i2s(ystart), + i2s(zstart)}; + const std::vector count{i2s(xend - xstart + 1), + i2s(yend - ystart + 1), + i2s(zend - zstart + 1)}; + var.getVar(index, count, value.begin()); + return value; + })); } else { Tensor value(static_cast(dims[0].getSize()), static_cast(dims[1].getSize()), diff --git a/tests/unit/fake_mesh.hxx b/tests/unit/fake_mesh.hxx index 6dbbd6200b..bf749a80be 100644 --- a/tests/unit/fake_mesh.hxx +++ b/tests/unit/fake_mesh.hxx @@ -149,6 +149,7 @@ public: return jx < ix_separatrix; } int numberOfYBoundaries() const override { return 1; } + std::pair hasBranchCutLower(int UNUSED(jx)) const override { return std::make_pair(false, 0.); } diff --git a/tests/unit/field/test_field_factory.cxx b/tests/unit/field/test_field_factory.cxx index b45206b979..dfe4c5858a 100644 --- a/tests/unit/field/test_field_factory.cxx +++ b/tests/unit/field/test_field_factory.cxx @@ -1,17 +1,29 @@ #include "gtest/gtest.h" +#include "fake_mesh.hxx" #include "test_extras.hxx" +#include "bout/bout_types.hxx" #include "bout/boutexception.hxx" #include "bout/constants.hxx" +#include "bout/coordinates.hxx" #include "bout/field2d.hxx" #include "bout/field3d.hxx" #include "bout/field_factory.hxx" +#include "bout/globals.hxx" #include "bout/mesh.hxx" +#include "bout/options_io.hxx" #include "bout/output.hxx" #include "bout/paralleltransform.hxx" +#include "bout/sys/expressionparser.hxx" +#include "bout/sys/generator_context.hxx" #include "bout/traits.hxx" +#include "bout/utils.hxx" #include "fake_mesh_fixture.hxx" +#include "test_tmpfiles.hxx" + +#include +#include // The unit tests use the global mesh using namespace bout::globals; @@ -770,7 +782,7 @@ TEST_F(FieldFactoryTest, Recursion) { opt["input"]["max_recursion_depth"] = 4; // Should be sufficient for n=6 // Create a factory with a max_recursion_depth != 0 - FieldFactory factory_rec(nullptr, &opt); + const FieldFactory factory_rec(nullptr, &opt); // Fibonacci sequence: 1 1 2 3 5 8 opt["fib"] = "where({n} - 2.5, [n={n}-1](fib) + [n={n}-2](fib), 1)"; @@ -802,7 +814,7 @@ TEST_F(FieldFactoryTest, ResolveLocalOptions) { options["f"] = "2 + 2"; options["g"] = "f * f"; - FieldFactoryExposer factory_local(mesh, &options); + const FieldFactoryExposer factory_local(mesh, &options); auto g = factory_local.resolve("g"); EXPECT_EQ(g->generate({}), 16); @@ -881,7 +893,7 @@ TEST_F(FieldFactoryCreateAndTransformTest, Create2D) { mesh->getCoordinates()->setParallelTransform( bout::utils::make_unique(*mesh, true)); - FieldFactory factory; + const FieldFactory factory; auto output = factory.create2D("x"); @@ -896,7 +908,7 @@ TEST_F(FieldFactoryCreateAndTransformTest, Create3D) { mesh->getCoordinates()->setParallelTransform( bout::utils::make_unique(*mesh, true)); - FieldFactory factory; + const FieldFactory factory; auto output = factory.create3D("x"); @@ -912,7 +924,7 @@ TEST_F(FieldFactoryCreateAndTransformTest, Create2DNoTransform) { Options options; options["input"]["transform_from_field_aligned"] = false; - FieldFactory factory{mesh, &options}; + const FieldFactory factory{mesh, &options}; auto output = factory.create2D("x"); @@ -929,7 +941,7 @@ TEST_F(FieldFactoryCreateAndTransformTest, Create3DNoTransform) { Options options; options["input"]["transform_from_field_aligned"] = false; - FieldFactory factory{mesh, &options}; + const FieldFactory factory{mesh, &options}; auto output = factory.create3D("x"); @@ -943,7 +955,7 @@ TEST_F(FieldFactoryCreateAndTransformTest, Create2DCantTransform) { mesh->getCoordinates()->setParallelTransform( bout::utils::make_unique(*mesh, false)); - FieldFactory factory{mesh}; + const FieldFactory factory{mesh}; auto output = factory.create2D("x"); @@ -958,7 +970,7 @@ TEST_F(FieldFactoryCreateAndTransformTest, Create3DCantTransform) { mesh->getCoordinates()->setParallelTransform( bout::utils::make_unique(*mesh, false)); - FieldFactory factory{mesh}; + const FieldFactory factory{mesh}; auto output = factory.create3D("x"); @@ -968,6 +980,121 @@ TEST_F(FieldFactoryCreateAndTransformTest, Create3DCantTransform) { EXPECT_TRUE(IsFieldEqual(output, expected)); } +struct FieldFactoryFieldVariableTest : public FakeMeshFixture { + WithQuietOutput quiet{output_info}; +}; + +TEST_F(FieldFactoryFieldVariableTest, CreateField3D) { + const bout::testing::TempFile filename; + + { + // Write some fields to a grid file + const FieldFactory factory{mesh}; + const auto rho = factory.create3D("sqrt(x^2 + y^2)"); + const auto theta = factory.create3D("atan(y, x)"); + const Options grid{{"rho", rho}, {"theta", theta}, + {"nx", mesh->LocalNx}, {"ny", mesh->LocalNy - 2}, + {"nz", mesh->LocalNz}, {"y_boundary_guards", 1}}; + bout::OptionsIO::create(filename)->write(grid); + } + + { + Options options{ + {"mesh", {{"file", filename.string()}}}, + {"input", {{"grid_variables", {{"rho", "field3d"}, {"theta", "field3d"}}}}}}; + + dynamic_cast(mesh)->setGridDataSource(new GridFile{filename}); + auto factory = FieldFactory{mesh, &options}; + + const auto output = factory.create3D("rho * cos(theta)"); + const auto x = factory.create3D("x"); + EXPECT_TRUE(IsFieldEqual(output, x, "RGN_NOBNDRY", 1e-14)); + } +} + +TEST_F(FieldFactoryFieldVariableTest, CreateField2D) { + const bout::testing::TempFile filename; + + { + // Write some fields to a grid file + const FieldFactory factory{mesh}; + const auto rho = factory.create2D("sqrt(x^2 + y^2)"); + const auto theta = factory.create2D("atan(y, x)"); + const Options grid{{"rho", rho}, {"theta", theta}, + {"nx", mesh->LocalNx}, {"ny", mesh->LocalNy - 2}, + {"nz", mesh->LocalNz}, {"y_boundary_guards", 1}}; + bout::OptionsIO::create(filename)->write(grid); + } + + { + Options options{ + {"mesh", {{"file", filename.string()}}}, + {"input", {{"grid_variables", {{"rho", "field2d"}, {"theta", "field2d"}}}}}}; + + dynamic_cast(mesh)->setGridDataSource(new GridFile{filename}); + auto factory = FieldFactory{mesh, &options}; + + const auto output = factory.create2D("rho * cos(theta)"); + const auto x = factory.create2D("x"); + EXPECT_TRUE(IsFieldEqual(output, x, "RGN_ALL", 1e-14)); + } +} + +TEST_F(FieldFactoryFieldVariableTest, ReadBoutReal) { + const bout::testing::TempFile filename; + + { + const Options grid{{"rho", 4}, + {"theta", 5}, + {"nx", mesh->LocalNx}, + {"ny", mesh->LocalNy}, + {"nz", mesh->LocalNz}}; + bout::OptionsIO::create(filename)->write(grid); + } + + { + Options options{ + {"mesh", {{"file", filename.string()}}}, + {"input", {{"grid_variables", {{"rho", "boutreal"}, {"theta", "boutreal"}}}}}}; + + dynamic_cast(mesh)->setGridDataSource(new GridFile{filename}); + auto factory = FieldFactory{mesh, &options}; + + const auto output = factory.create3D("rho * theta"); + EXPECT_TRUE(IsFieldEqual(output, 4 * 5)); + } +} + +TEST_F(FieldFactoryFieldVariableTest, NoMeshFile) { + Options options{{"input", {{"grid_variables", {{"rho", "field3d"}}}}}}; + + EXPECT_THROW((FieldFactory(mesh, &options)), BoutException); +} + +TEST_F(FieldFactoryFieldVariableTest, MissingVariable) { + const bout::testing::TempFile filename; + + { + // Write some fields to a grid file + const FieldFactory factory{mesh}; + const auto rho = factory.create3D("sqrt(x^2 + y^2)"); + const Options grid{{"rho", rho}, + {"nx", mesh->LocalNx}, + {"ny", mesh->LocalNy}, + {"nz", mesh->LocalNz}}; + bout::OptionsIO::create(filename)->write(grid); + } + + { + Options options{ + {"mesh", {{"file", filename.string()}}}, + {"input", {{"grid_variables", {{"rho", "field3d"}, {"theta", "field3d"}}}}}}; + + dynamic_cast(mesh)->setGridDataSource(new GridFile{filename}); + EXPECT_THROW((FieldFactory{mesh, &options}), BoutException); + } +} + TYPED_TEST(FieldFactoryCreationTest, CreatePeriodicY) { auto output = this->create("is_periodic_y"); diff --git a/tests/unit/test_tmpfiles.hxx b/tests/unit/test_tmpfiles.hxx index 6c37e7c792..db7387bb45 100644 --- a/tests/unit/test_tmpfiles.hxx +++ b/tests/unit/test_tmpfiles.hxx @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -30,7 +31,11 @@ public: TempFile& operator=(const TempFile&) = delete; TempFile& operator=(TempFile&&) = delete; - ~TempFile() { std::filesystem::remove_all(filename); } + ~TempFile() { + if (std::uncaught_exceptions() <= 0) { + std::filesystem::remove_all(filename); + } + } // Enable conversions to std::string / const char* operator std::string() const { return filename.string(); }