Skip to content

Commit f781fac

Browse files
mohanchenabacus_fixer
andauthored
refactor esolver_ks_pw (deepmodeling#7008)
* refactor(esolver): extract update_cell_pw function from esolver_ks_pw - Create new files update_cell_pw.h and update_cell_pw.cpp in source_pw/module_pwdft - Extract cell parameter update logic from ESolver_KS_PW::before_scf() - The new function handles: 1. Rescaling non-local pseudopotential (ppcell.rescale_vnl) 2. Reinitializing plane wave basis grids (pw_wfc->initgrids/initparameters/collect_local_pw) - Keep psi initialization (p_psi_init->prepare_init) in esolver to avoid template dependency - Update CMakeLists.txt and Makefile.Objects for new source files This refactoring improves code organization by moving PW-specific cell update logic out of the esolver, making the esolver code cleaner and more focused on high-level workflow control. * refactor(esolver): extract EXX initialization into Exx_Helper::init - Add init() function to Exx_Helper class for EXX initialization - The init function handles: 1. Check if calculation type is scf/relax/cell-relax/md 2. Check if cal_exx is enabled 3. Set XC first loop if separate_loop is true 4. Set wg pointer for EXX calculation - Simplify ESolver_KS_PW::before_all_runners() by calling exx_helper.init() - Move EXX-specific logic out of esolver, improving code organization This refactoring makes the esolver code cleaner and more focused on high-level workflow control. * refactor(esolver): extract DFT+U initialization into pw::iter_init_dftu_pw - Create new files dftu_pw.h and dftu_pw.cpp in source_pw/module_pwdft - Extract DFT+U occupation update logic from ESolver_KS_PW::iter_init() - The new function handles: 1. Check if DFT+U is enabled 2. Check iteration and step conditions 3. Call cal_occ_pw for occupation calculation 4. Output DFT+U results - Use void* for psi parameter to avoid template dependency - Update CMakeLists.txt and Makefile.Objects for new source files This refactoring improves code organization by moving DFT+U specific logic out of the esolver, making the esolver code cleaner and more focused on high-level workflow control. * refactor(esolver): extract DeltaSpin lambda loop into pw::run_deltaspin_lambda_loop - Create new files deltaspin_pw.h and deltaspin_pw.cpp in source_pw/module_pwdft - Extract DeltaSpin lambda loop logic from ESolver_KS_PW::hamilt2rho_single() - The new function handles: 1. Check if DeltaSpin (sc_mag_switch) is enabled 2. Get SpinConstrain singleton instance 3. Run lambda loop to constrain atomic magnetic moments 4. Return skip_solve flag to control solver execution - Add Doxygen-style comments in English - Update CMakeLists.txt and Makefile.Objects for new source files This refactoring improves code organization by moving DeltaSpin-specific logic out of the esolver, making the esolver code cleaner and more focused on high-level workflow control. * refactor(esolver): extract DeltaSpin oscillation check into pw::check_deltaspin_oscillation - Add check_deltaspin_oscillation() function to deltaspin_pw.h/cpp - Extract DeltaSpin SCF oscillation check logic from ESolver_KS_PW::iter_finish() - The new function handles: 1. Check if DeltaSpin (sc_mag_switch) is enabled 2. Get SpinConstrain singleton instance 3. Detect SCF oscillation using if_scf_oscillate() 4. Set mixing_restart_step if oscillation detected - Add Doxygen-style comments in English This refactoring consolidates all DeltaSpin-related functions in one place, making the code more modular and easier to maintain. * refactor(esolver): extract EXX before_scf setup into Exx_Helper::before_scf - Add before_scf() function to Exx_Helper class - Extract EXX setup logic from ESolver_KS_PW::before_scf() - The new function handles: 1. Check if calculation type is valid (scf/relax/cell-relax/md) 2. Check if EXX is enabled and basis type is PW 3. Set EXX helper to Hamiltonian 4. Set psi for EXX calculation - Use void* for p_hamilt parameter to avoid circular dependency - Add Doxygen-style comments in English This refactoring consolidates EXX-related setup logic in the Exx_Helper class, making the code more modular and easier to maintain. * refactor(esolver): extract EXX iter_finish logic into Exx_Helper::iter_finish - Add iter_finish() function to Exx_Helper class - Extract EXX convergence handling logic from ESolver_KS_PW::iter_finish() - The new function handles: 1. Check if EXX is enabled 2. Handle separate_loop mode for EXX convergence 3. Calculate EXX energy difference for energy threshold 4. Update potential if SCF not converged 5. Increment EXX iteration counter - Use Charge* and void* parameters to avoid circular dependency - Add Doxygen-style comments in English This refactoring consolidates all EXX-related functions in the Exx_Helper class, making the code more modular and easier to maintain. --------- Co-authored-by: abacus_fixer <mohanchen@pku.eud.cn>
1 parent 6de8e35 commit f781fac

11 files changed

Lines changed: 394 additions & 121 deletions

File tree

source/Makefile.Objects

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,9 @@ OBJS_SRCPW=H_Ewald_pw.o\
705705
setup_pot.o\
706706
setup_pwrho.o\
707707
setup_pwwfc.o\
708+
update_cell_pw.o\
709+
dftu_pw.o\
710+
deltaspin_pw.o\
708711
forces.o\
709712
forces_us.o\
710713
forces_nl.o\

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 13 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
#include "source_io/module_ctrl/ctrl_output_pw.h" // mohan add 20250927
2929
#include "source_estate/module_charge/chgmixing.h" // use charge mixing, mohan add 20251006
3030
#include "source_estate/update_pot.h" // mohan add 20251016
31+
#include "source_pw/module_pwdft/update_cell_pw.h" // mohan add 20250309
32+
#include "source_pw/module_pwdft/dftu_pw.h" // mohan add 20250309
33+
#include "source_pw/module_pwdft/deltaspin_pw.h" // mohan add 20250309
3134

3235
#include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info
3336

@@ -93,20 +96,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
9396
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS");
9497

9598
//! Initialize exx pw
96-
if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax"
97-
|| inp.calculation == "md")
98-
{
99-
if (GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop == true)
100-
{
101-
XC_Functional::set_xc_first_loop(ucell);
102-
exx_helper.set_firstiter();
103-
}
104-
105-
if (GlobalC::exx_info.info_global.cal_exx)
106-
{
107-
exx_helper.set_wg(&this->pelec->wg);
108-
}
109-
}
99+
this->exx_helper.init(ucell, inp, this->pelec->wg);
110100
}
111101

112102
template <typename T, typename Device>
@@ -119,17 +109,10 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
119109
ESolver_KS<T, Device>::before_scf(ucell, istep);
120110

121111
//! Init variables (once the cell has changed)
112+
pw::update_cell_pw(ucell, this->ppcell, this->kv, this->pw_wfc, PARAM.inp);
113+
122114
if (ucell.cell_parameter_updated)
123115
{
124-
this->ppcell.rescale_vnl(ucell.omega);
125-
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL");
126-
127-
this->pw_wfc->initgrids(ucell.lat0, ucell.latvec, this->pw_wfc->nx, this->pw_wfc->ny, this->pw_wfc->nz);
128-
129-
this->pw_wfc->initparameters(false, PARAM.inp.ecutwfc, this->kv.get_nks(), this->kv.kvec_d.data());
130-
131-
this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);
132-
133116
this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed);
134117
}
135118

@@ -151,17 +134,8 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
151134
// setup psi (electronic wave functions)
152135
this->stp.init(this->p_hamilt);
153136

154-
//! Exx calculations
155-
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
156-
|| PARAM.inp.calculation == "cell-relax" || PARAM.inp.calculation == "md")
157-
{
158-
if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.basis_type == "pw")
159-
{
160-
auto hamilt_pw = reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
161-
hamilt_pw->set_exx_helper(exx_helper);
162-
exx_helper.set_psi(this->stp.psi_t);
163-
}
164-
}
137+
//! Setup EXX helper for Hamiltonian and psi
138+
exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp);
165139

166140
ModuleBase::timer::tick("ESolver_KS_PW", "before_scf");
167141
}
@@ -181,16 +155,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const
181155

182156
// 4) update local occupations for DFT+U
183157
// should before lambda loop in DeltaSpin
184-
if (PARAM.inp.dft_plus_u && (iter != 1 || istep != 0))
185-
{
186-
// only old DFT+U method should calculate energy correction in esolver,
187-
// new DFT+U method will calculate energy when evaluating the Hamiltonian
188-
if (this->dftu.omc != 2)
189-
{
190-
this->dftu.cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta);
191-
}
192-
this->dftu.output(ucell);
193-
}
158+
pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp);
194159
}
195160

196161
// Temporary, it should be replaced by hsolver later.
@@ -218,26 +183,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
218183
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
219184

220185
// run the inner lambda loop to contrain atomic moments with the DeltaSpin method
221-
bool skip_solve = false;
222-
223-
if (PARAM.inp.sc_mag_switch)
224-
{
225-
spinconstrain::SpinConstrain<std::complex<double>>& sc
226-
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();
227-
if (!sc.mag_converged() && this->drho > 0 && this->drho < PARAM.inp.sc_scf_thr)
228-
{
229-
// optimize lambda to get target magnetic moments, but the lambda is not near target
230-
sc.run_lambda_loop(iter - 1);
231-
sc.set_mag_converged(true);
232-
skip_solve = true;
233-
}
234-
else if (sc.mag_converged())
235-
{
236-
// optimize lambda to get target magnetic moments, but the lambda is not near target
237-
sc.run_lambda_loop(iter - 1);
238-
skip_solve = true;
239-
}
240-
}
186+
bool skip_solve = pw::run_deltaspin_lambda_loop(iter - 1, this->drho, PARAM.inp);
241187

242188
if (!skip_solve)
243189
{
@@ -293,65 +239,11 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
293239
this->ppcell.cal_effective_D(veff, this->pw_rhod, ucell);
294240
}
295241

296-
// Related to EXX
297-
if (GlobalC::exx_info.info_global.cal_exx)
298-
{
299-
if (GlobalC::exx_info.info_global.separate_loop)
300-
{
301-
if (conv_esolver)
302-
{
303-
auto start = std::chrono::high_resolution_clock::now();
304-
exx_helper.set_firstiter(false);
305-
exx_helper.op_exx->first_iter = false;
306-
double dexx = 0.0;
307-
if (PARAM.inp.exx_thr_type == "energy")
308-
{
309-
dexx = exx_helper.cal_exx_energy(this->stp.psi_t);
310-
exx_helper.set_psi(this->stp.psi_t);
311-
dexx -= exx_helper.cal_exx_energy(this->stp.psi_t);
312-
// std::cout << "dexx = " << dexx << std::endl;
313-
}
314-
bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr;
315-
316-
conv_esolver = exx_helper.exx_after_converge(iter, conv_ene);
317-
if (!conv_esolver)
318-
{
319-
if (PARAM.inp.exx_thr_type != "energy")
320-
{
321-
exx_helper.set_psi(this->stp.psi_t);
322-
}
323-
auto duration = std::chrono::high_resolution_clock::now() - start;
324-
std::cout << " Setting Psi for EXX PW Inner Loop took "
325-
<< std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() / 1000.0 << "s"
326-
<< std::endl;
327-
exx_helper.op_exx->first_iter = false;
328-
XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func);
329-
elecstate::update_pot(ucell, this->pelec, this->chr, conv_esolver);
330-
exx_helper.iter_inc();
331-
}
332-
}
333-
}
334-
else
335-
{
336-
exx_helper.set_psi(this->stp.psi_t);
337-
}
338-
}
242+
// Handle EXX-related operations after SCF iteration
243+
exx_helper.iter_finish(this->pelec, &this->chr, this->stp.psi_t, ucell, PARAM.inp, conv_esolver, iter);
339244

340245
// check if oscillate for delta_spin method
341-
if (PARAM.inp.sc_mag_switch)
342-
{
343-
spinconstrain::SpinConstrain<std::complex<double>>& sc
344-
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();
345-
if (!sc.higher_mag_prec)
346-
{
347-
sc.higher_mag_prec = this->p_chgmix->if_scf_oscillate(iter,
348-
this->drho, PARAM.inp.sc_os_ndim, PARAM.inp.scf_os_thr);
349-
if (sc.higher_mag_prec)
350-
{ // if oscillate, increase the precision of magnetization and do mixing_restart in next iteration
351-
this->p_chgmix->mixing_restart_step = iter + 1;
352-
}
353-
}
354-
}
246+
pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp);
355247

356248
// the output quantities
357249
ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu,

source/source_pw/module_pwdft/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ list(APPEND objects
1515
setup_pot.cpp
1616
setup_pwrho.cpp
1717
setup_pwwfc.cpp
18+
update_cell_pw.cpp
19+
dftu_pw.cpp
20+
deltaspin_pw.cpp
1821
forces_nl.cpp
1922
forces_cc.cpp
2023
forces_scc.cpp
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#include "source_pw/module_pwdft/deltaspin_pw.h"
2+
#include "source_lcao/module_deltaspin/spin_constrain.h"
3+
#include "source_estate/module_charge/charge_mixing.h"
4+
5+
namespace pw
6+
{
7+
8+
bool run_deltaspin_lambda_loop(const int iter,
9+
const double drho,
10+
const Input_para& inp)
11+
{
12+
/// Return false if DeltaSpin is not enabled
13+
if (!inp.sc_mag_switch)
14+
{
15+
return false;
16+
}
17+
18+
/// Get the singleton instance of SpinConstrain
19+
spinconstrain::SpinConstrain<std::complex<double>>& sc
20+
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();
21+
22+
/// Case 1: Magnetic moments not yet converged and SCF is close to convergence.
23+
/// This is the first time we enter the lambda loop after SCF is nearly converged.
24+
if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr)
25+
{
26+
/// Optimize lambda to get target magnetic moments
27+
sc.run_lambda_loop(iter);
28+
sc.set_mag_converged(true);
29+
return true;
30+
}
31+
/// Case 2: Magnetic moments already converged in previous iteration.
32+
/// Continue to refine lambda in subsequent SCF iterations.
33+
else if (sc.mag_converged())
34+
{
35+
sc.run_lambda_loop(iter);
36+
return true;
37+
}
38+
39+
/// Default: run the normal solver
40+
return false;
41+
}
42+
43+
void check_deltaspin_oscillation(const int iter,
44+
const double drho,
45+
Charge_Mixing* p_chgmix,
46+
const Input_para& inp)
47+
{
48+
/// Return if DeltaSpin is not enabled
49+
if (!inp.sc_mag_switch)
50+
{
51+
return;
52+
}
53+
54+
/// Get the singleton instance of SpinConstrain
55+
spinconstrain::SpinConstrain<std::complex<double>>& sc
56+
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();
57+
58+
/// Check if higher magnetization precision is needed
59+
if (!sc.higher_mag_prec)
60+
{
61+
/// Detect SCF oscillation
62+
sc.higher_mag_prec = p_chgmix->if_scf_oscillate(iter, drho, inp.sc_os_ndim, inp.scf_os_thr);
63+
64+
/// If oscillation detected, set mixing restart step for next iteration
65+
if (sc.higher_mag_prec)
66+
{
67+
p_chgmix->mixing_restart_step = iter + 1;
68+
}
69+
}
70+
}
71+
72+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef DELTASPIN_PW_H
2+
#define DELTASPIN_PW_H
3+
4+
#include "source_io/module_parameter/parameter.h"
5+
6+
class Charge_Mixing;
7+
8+
namespace pw
9+
{
10+
11+
/**
12+
* @brief Run the inner lambda loop for DeltaSpin method to constrain atomic magnetic moments.
13+
*
14+
* This function is used in the PW basis SCF iteration to optimize lambda parameters
15+
* for constraining atomic magnetic moments to target values using the DeltaSpin method.
16+
*
17+
* @param iter The current iteration number (0-indexed).
18+
* @param drho The current charge density difference.
19+
* @param inp The input parameters.
20+
* @return true if the solver should be skipped (lambda loop was executed),
21+
* false otherwise.
22+
*/
23+
bool run_deltaspin_lambda_loop(const int iter,
24+
const double drho,
25+
const Input_para& inp);
26+
27+
/**
28+
* @brief Check if SCF oscillation occurs for DeltaSpin method.
29+
*
30+
* This function checks if the SCF iteration is oscillating and sets the
31+
* mixing restart step if oscillation is detected. This is used to increase
32+
* the precision of magnetization calculation.
33+
*
34+
* @param iter The current iteration number (1-indexed).
35+
* @param drho The current charge density difference.
36+
* @param p_chgmix Pointer to the Charge_Mixing object.
37+
* @param inp The input parameters.
38+
*/
39+
void check_deltaspin_oscillation(const int iter,
40+
const double drho,
41+
Charge_Mixing* p_chgmix,
42+
const Input_para& inp);
43+
44+
}
45+
46+
#endif
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "source_pw/module_pwdft/dftu_pw.h"
2+
#include "source_lcao/module_dftu/dftu.h"
3+
4+
namespace pw
5+
{
6+
7+
void iter_init_dftu_pw(const int iter,
8+
const int istep,
9+
Plus_U& dftu,
10+
const void* psi,
11+
const ModuleBase::matrix& wg,
12+
const UnitCell& ucell,
13+
const Input_para& inp)
14+
{
15+
if (!inp.dft_plus_u)
16+
{
17+
return;
18+
}
19+
20+
if (iter == 1 && istep == 0)
21+
{
22+
return;
23+
}
24+
25+
if (dftu.omc != 2)
26+
{
27+
dftu.cal_occ_pw(iter, psi, wg, ucell, inp.mixing_beta);
28+
}
29+
dftu.output(ucell);
30+
}
31+
32+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef DFTU_PW_H
2+
#define DFTU_PW_H
3+
4+
#include "source_io/module_parameter/parameter.h"
5+
#include "source_cell/unitcell.h"
6+
#include "source_base/matrix.h"
7+
8+
class Plus_U;
9+
10+
namespace pw
11+
{
12+
13+
void iter_init_dftu_pw(const int iter,
14+
const int istep,
15+
Plus_U& dftu,
16+
const void* psi,
17+
const ModuleBase::matrix& wg,
18+
const UnitCell& ucell,
19+
const Input_para& inp);
20+
21+
}
22+
23+
#endif

0 commit comments

Comments
 (0)