diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 62f3054a..b1766bbd 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -22,8 +22,20 @@ namespace infini_train::nn::init { namespace { -static std::random_device rd; -static std::mt19937 gen(rd()); +constexpr int kRandomSeed = 42; + +// FIXME: RNG design is incomplete. +// +// Current implementation lacks: +// - unified Generator abstraction +// - global default generator and seed control +// - reproducible / clonable RNG state +// +// TODO: +// - introduce Generator interface and backend impl +// - add default generator management (per device) +// - refactor random ops to consume Generator +static std::mt19937 gen(kRandomSeed); } // namespace std::shared_ptr Normal(const std::shared_ptr &tensor, float mean, float std, @@ -34,7 +46,7 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean #ifdef USE_OMP #pragma omp parallel { - std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num()); + std::mt19937 local_gen(kRandomSeed + omp_get_thread_num()); std::normal_distribution local_dis(mean, std); #pragma omp for for (int i = 0; i < buffer.size(); ++i) { @@ -126,7 +138,7 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, #ifdef USE_OMP #pragma omp parallel { - std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num()); + std::mt19937 local_gen(kRandomSeed + omp_get_thread_num()); std::uniform_real_distribution local_dis(a, b); #pragma omp for for (int i = 0; i < buffer.size(); ++i) {