diff --git a/source/source_pw/module_stodft/sto_wf.cpp b/source/source_pw/module_stodft/sto_wf.cpp index 2de8a8c28c..ec8d8ea022 100644 --- a/source/source_pw/module_stodft/sto_wf.cpp +++ b/source/source_pw/module_stodft/sto_wf.cpp @@ -7,6 +7,8 @@ #include #include +#include + #include "source_base/global_function.h" template @@ -19,7 +21,7 @@ Stochastic_WF::~Stochastic_WF() { delete chi0_cpu; Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { delete chi0; } @@ -60,18 +62,22 @@ void Stochastic_WF::clean_chiallorder() delete[] chiallorder; chiallorder = nullptr; } + template void Stochastic_WF::init_sto_orbitals(const int seed_in) { - const unsigned int rank_seed_offset = 10000; + unsigned int final_seed; if (seed_in == 0 || seed_in == -1) { - srand(static_cast(time(nullptr)) + GlobalV::MY_RANK * rank_seed_offset); // GlobalV global variables are reserved + final_seed = (unsigned)time(nullptr) + GlobalV::MY_RANK * 10000; } else { - srand(static_cast(std::abs(seed_in)) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * rank_seed_offset); + final_seed = (unsigned)std::abs(seed_in) + (GlobalV::MY_BNDGROUP * GlobalV::NPROC_IN_BNDGROUP + GlobalV::RANK_IN_BPGROUP) * 10000; } + + // initialize the random number generator with the final seed + this->rng.seed(final_seed); this->allocate_chi0(); this->update_sto_orbitals(seed_in); @@ -119,7 +125,7 @@ void Stochastic_WF::allocate_chi0() // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -134,11 +140,17 @@ void Stochastic_WF::update_sto_orbitals(const int seed_in) { const int nchi = PARAM.inp.nbands_sto; this->chi0_cpu->fix_k(0); + + // Uniform distribution to generate random phases between 0 and 2*pi + std::uniform_real_distribution dist_phi(0.0, 2.0 * ModuleBase::PI); + // Bernoulli distribution to generate +1/sqrt(nchi) or -1/sqrt(nchi) with equal probability + std::bernoulli_distribution dist_coin(0.5); + if (seed_in >= 0) { for (int i = 0; i < this->chi0_cpu->size(); ++i) { - const double phi = 2 * ModuleBase::PI * rand() / double(RAND_MAX); + const double phi = dist_phi(this->rng); this->chi0_cpu->get_pointer()[i] = std::complex(cos(phi), sin(phi)) / sqrt(double(nchi)); } } @@ -146,7 +158,8 @@ void Stochastic_WF::update_sto_orbitals(const int seed_in) { for (int i = 0; i < this->chi0_cpu->size(); ++i) { - if (rand() / double(RAND_MAX) < 0.5) + // use Bernoulli distribution to generate +1/sqrt(nchi) or -1/sqrt(nchi) with equal probability + if (dist_coin(this->rng)) { this->chi0_cpu->get_pointer()[i] = -1.0 / sqrt(double(nchi)); } diff --git a/source/source_pw/module_stodft/sto_wf.h b/source/source_pw/module_stodft/sto_wf.h index 5146ff7f56..c008ed91e8 100644 --- a/source/source_pw/module_stodft/sto_wf.h +++ b/source/source_pw/module_stodft/sto_wf.h @@ -59,6 +59,10 @@ class Stochastic_WF void init_com_orbitals(); // sync chi0 from CPU to GPU void sync_chi0(); + + private: + // random number generator + std::mt19937 rng; protected: using setmem_complex_op = base_device::memory::set_memory_op;