
Commit 5489eac

Implement combination of MPI parallelization and unit cell threading
1 parent fe9d514 · commit 5489eac

7 files changed

Lines changed: 51 additions & 34 deletions


examples/J1J2_mpi.jl

Lines changed: 9 additions & 2 deletions
@@ -8,7 +8,7 @@ using MPIHelper
 using MPI
 MPSKit.Defaults.set_scheduler!(:serial)

-MPI.Init()
+MPI.Init(;threadlevel=:multiple)
 mpi_rank() = MPI.Comm_rank(MPI.COMM_WORLD)
 mpi_size() = MPI.Comm_size(MPI.COMM_WORLD)

@@ -35,13 +35,20 @@ else
 end
 H_mpi = MPIOperator(H_mpi)

-
 println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes.")

 ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(state, H_mpi, verbosity=1); ## This tests VUMPS and GradientGrassmann

 println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")

+MPSKit.Defaults.set_scheduler!(:dynamic)
+
+ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(state, H_mpi, verbosity=1); ## This tests VUMPS and GradientGrassmann with unit cell parallelization
+
+println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")
+
+
 ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(state, H_mpi, IDMRG2(; maxiter = 20, tol = 1.0e-12, verbosity=1, trscheme=truncrank(50)));

 println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")

src/MPIOperator/derivatives.jl

Lines changed: 15 additions & 13 deletions
@@ -1,27 +1,29 @@
 # This does not work, as this is not a more specific type than the generic constructors in MPSKit, so it will never be executed:
-
 # @forward_3_1_astype MPIOperator.parent MPSKit.C_hamiltonian, MPSKit.AC_hamiltonian, MPSKit.AC2_hamiltonian, MPSKit.C_projection, MPSKit.AC_projection, MPSKit.AC2_projection

-function MPSKit.C_hamiltonian(site::Int, below, operator::MPIOperator, above, envs)
-    return MPIOperator(MPSKit.C_hamiltonian(site, below, parent(operator), above, envs))
+# In case the scheduler is parallel, we create new communicators for each site to separate the communication. Therefore, the MPIOperator has a Vector{MPI.Comm}, which we parse to the MPO derivatives here.
+# In that case, it is crucial that MPI.ThreadLevel(3) is used, otherwise the communication will deadlock or fail!
+
+function MPSKit.C_hamiltonian(site::Int, below, operator::MPIOperator{O, F}, above, envs) where {O, F}
+    return MPIOperator(MPSKit.C_hamiltonian(site, below, parent(operator), above, envs), operator.reduction, operator.comm[site])
 end

-function MPSKit.AC_hamiltonian(site::Int, below, operator::MPIOperator, above, envs)
-    return MPIOperator(MPSKit.AC_hamiltonian(site, below, parent(operator), above, envs))
+function MPSKit.AC_hamiltonian(site::Int, below, operator::MPIOperator{O, F}, above, envs) where {O, F}
+    return MPIOperator(MPSKit.AC_hamiltonian(site, below, parent(operator), above, envs), operator.reduction, operator.comm[site])
 end

-function MPSKit.AC2_hamiltonian(site::Int, below, operator::MPIOperator, above, envs)
-    return MPIOperator(MPSKit.AC2_hamiltonian(site, below, parent(operator), above, envs))
+function MPSKit.AC2_hamiltonian(site::Int, below, operator::MPIOperator{O, F}, above, envs) where {O, F}
+    return MPIOperator(MPSKit.AC2_hamiltonian(site, below, parent(operator), above, envs), operator.reduction, operator.comm[site])
 end

-function MPSKit.C_projection(site::Int, below, operator::MPIOperator, above, envs)
-    return MPIOperator(MPSKit.C_projection(site, below, parent(operator), above, envs))
+function MPSKit.C_projection(site::Int, below, operator::MPIOperator{O, F}, above, envs) where {O, F}
+    return MPIOperator(MPSKit.C_projection(site, below, parent(operator), above, envs), operator.reduction, operator.comm[site])
 end

-function MPSKit.AC_projection(site::Int, below, operator::MPIOperator, above, envs)
-    return MPIOperator(MPSKit.AC_projection(site, below, parent(operator), above, envs))
+function MPSKit.AC_projection(site::Int, below, operator::MPIOperator{O, F}, above, envs) where {O, F}
+    return MPIOperator(MPSKit.AC_projection(site, below, parent(operator), above, envs), operator.reduction, operator.comm[site])
 end

-function MPSKit.AC2_projection(site::Int, below, operator::MPIOperator, above, envs)
-    return MPIOperator(MPSKit.AC2_projection(site, below, parent(operator), above, envs))
+function MPSKit.AC2_projection(site::Int, below, operator::MPIOperator{O, F}, above, envs) where {O, F}
+    return MPIOperator(MPSKit.AC2_projection(site, below, parent(operator), above, envs), operator.reduction, operator.comm[site])
 end
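The per-site communicators address a real hazard: under a threaded scheduler, several sites call MPI.Allreduce at the same time, and on one shared communicator those collectives can be matched across ranks in different orders, which deadlocks or mixes results. A standalone sketch of the pattern, independent of MPSKit (names and values are illustrative only):

    using MPI

    MPI.Init(; threadlevel = :multiple)   # ThreadLevel(3): required for concurrent collectives

    nsites = 4
    # One duplicated communicator per site keeps each site's collectives separated.
    comms = [MPI.Comm_dup(MPI.COMM_WORLD) for _ in 1:nsites]

    results = Vector{Float64}(undef, nsites)
    Threads.@threads for site in 1:nsites
        contribution = Float64(site) + MPI.Comm_rank(MPI.COMM_WORLD)
        # Each thread reduces on its own communicator, so collectives issued by
        # different sites cannot be confused with one another.
        results[site] = MPI.Allreduce(contribution, +, comms[site])
    end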

src/MPIOperator/mpioperator.jl

Lines changed: 12 additions & 13 deletions
@@ -1,36 +1,35 @@
 ## This shallow struct is used to indicate that each LazyMIPOperator should be evaluated on each rank and the result is to be reduced across all ranks using MPI.Allreduce
 ## This is the MPI-parallelized version of a linear operator
-## If one added the flexibilty of choosing the reduction, one could also parallelize over products of functions etc...
-struct MPIOperator{O}
+struct MPIOperator{O, F, C}
     parent::O
-    function MPIOperator(parent::O) where {O}
-        if !MPI.Initialized()
-            @warn "MPI is currently not initialized. Please initialize MPI by running \n `using MPI; MPI.Init()` \n before creating an MPIOperator." maxlog=1
-        end
-        return new{O}(parent)
-    end
+    reduction::F
+    comm::C
+    MPIOperator{O, F, C}(parent::O, reduction::F=Base.:+, comm::C=MPI.COMM_WORLD) where {O, F, C} = new{O, F, C}(parent, reduction, comm)
+    MPIOperator{O, F}(parent::O, reduction::F=Base.:+, comm::C=MPI.COMM_WORLD) where {O, F, C} = new{O, F, C}(parent, reduction, comm)
+    MPIOperator{O}(parent::O, reduction::F=Base.:+, comm::C=MPI.COMM_WORLD) where {O, F, C} = new{O, F, C}(parent, reduction, comm)
+    MPIOperator(parent::O, reduction::F=Base.:+, comm::C=MPI.COMM_WORLD) where {O, F, C} = new{O, F, C}(parent, reduction, comm)
 end

-function Base.parent(op::MPIOperator{O})::O where {O}
+function Base.parent(op::MPIOperator{O, F})::O where {O, F}
     return op.parent
 end

-function (Op::MPIOperator{O})(x::S) where {O,S}
+function (Op::MPIOperator{O, F})(x::S) where {O, F, S}
     y_per_rank = parent(Op)(x)
-    y = MPIHelper.allreduce(y_per_rank, Base.:+, MPI.COMM_WORLD)
+    y = MPIHelper.allreduce(y_per_rank, Op.reduction, Op.comm)
     return y
 end

 Base.:*(Op::MPIOperator, v) = Op(v)
 (Op::MPIOperator)(x, ::Number) = Op(x)

 function Base.show(io::IO, ::MIME"text/plain", op::MPIOperator)
-    print(io, "MPIOperator wrapping:\n")
+    print(io, "MPIOperator with communicator $(op.comm) and reduction $(op.reduction) wrapping:\n")
     show(io, MIME"text/plain"(), parent(op))
 end
 Base.show(io::IO, op::MPIOperator) = show(convert(IOContext, io), op)
 function Base.show(io::IOContext, op::MPIOperator)
-    print(io, "MPIOperator wrapping:\n")
+    print(io, "MPIOperator with communicator $(op.comm) and reduction $(op.reduction) wrapping:\n")
     show(io, parent(op))
 end
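With the reduction and communicator now stored on the struct, applying an MPIOperator still means: apply this rank's share of the operator, then combine the results across ranks, but the combination is now configurable. A rough usage sketch; H_local and x are placeholders for a per-rank operator slice and a vector:

    Op = MPIOperator(H_local)          # defaults: reduction = Base.:+, comm = MPI.COMM_WORLD
    y  = Op * x                        # parent(Op)(x) on every rank, then MPIHelper.allreduce with +

    # The reduction can now be swapped out, e.g. combining per-rank factors by multiplication:
    Op_prod = MPIOperator(H_local, Base.:*, MPI.Comm_dup(MPI.COMM_WORLD))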

src/MPIOperator/mpskit.jl

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+
+MPIOperator(parent::FiniteMPO, reduction::F=Base.:+, comm::C=[MPI.Comm_dup(MPI.COMM_WORLD) for _ in eachsite(parent)]) where {F, C} = MPIOperator{FiniteMPO, F, C}(parent, reduction, comm)
+
+MPIOperator(parent::O, reduction::F=Base.:+, comm::C=MPSKit.PeriodicVector([MPI.Comm_dup(MPI.COMM_WORLD) for _ in eachsite(parent)])) where {O <: InfiniteMPO, F, C} = MPIOperator{O, F, C}(parent, reduction, comm)
+
+MPIOperator(parent::O, reduction::F=Base.:+, comm::C=MPSKit.PeriodicVector([MPI.Comm_dup(MPI.COMM_WORLD) for _ in eachsite(parent)])) where {O <: InfiniteMPOHamiltonian, F, C} = MPIOperator{O, F, C}(parent, reduction, comm)
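These outer constructors give every site of the wrapped MPO its own duplicated communicator up front; for the infinite variants the vector is wrapped in MPSKit.PeriodicVector so that comm[site] stays valid for any site index modulo the unit cell. A tiny illustration of the intended wrapping behaviour (symbols stand in for real MPI.Comm handles):

    using MPSKit: PeriodicVector

    comms = PeriodicVector([:comm_site1, :comm_site2])
    comms[3] == comms[1]   # true: site 3 wraps back to site 1 of the two-site unit cell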

src/algorithms/groundstate/vumps.jl

Lines changed: 5 additions & 3 deletions
@@ -11,6 +11,7 @@ function MPSKit.localupdate_step!(
     ACs = similar(mps.AC)
     dst_ACs = mps isa Multiline ? eachcol(ACs) : ACs

+    # TODO: If we have parallel = true, then one would even need 2 communicators per site!
     tforeach(eachsite(mps), src_ACs, src_Cs; scheduler) do site, AC₀, C₀
         dst_ACs[site] = MPSKit._localupdate_vumps_step!(
             site, mps, state.operator, state.envs, AC₀, C₀;
@@ -27,12 +28,13 @@ function MPSKit._localupdate_vumps_step!(
         parallel::Bool = false, alg_orth = MPSKit.Defaults.alg_qr(),
         alg_eigsolve = MPSKit.Defaults.eigsolver, which
     )
+    comm = operator.comm[site]
     if !parallel
         Hac = AC_hamiltonian(site, mps, operator, mps, envs)
         _, AC = fixedpoint(Hac, AC₀, which, alg_eigsolve)
         Hc = C_hamiltonian(site, mps, operator, mps, envs)
         _, C = fixedpoint(Hc, C₀, which, alg_eigsolve)
-        return mpi_execute_on_root_and_bcast(regauge!, AC, C; alg = alg_orth)
+        return mpi_execute_on_root_and_bcast(regauge!, AC, C; comm = comm, alg = alg_orth)
     end

     local AC, C
@@ -46,7 +48,7 @@ function MPSKit._localupdate_vumps_step!(
             _, C = fixedpoint(Hc, C₀, which, alg_eigsolve)
         end
     end
-    return mpi_execute_on_root_and_bcast(regauge!, AC, C; alg = alg_orth)
+    return mpi_execute_on_root_and_bcast(regauge!, AC, C; comm = comm, alg = alg_orth)
 end

 function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS}, state::VUMPSState{S, MPIOperator{O}, E}, ACs::AbstractVector) where {S, O, E}
@@ -67,6 +69,6 @@ function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS}, state::VUMPSState{S, M
     else
         psi = nothing
     end
-    psi = MPIHelper.bcast(psi, 0, MPI.COMM_WORLD)
+    psi = MPIHelper.bcast(psi, MPI.COMM_WORLD)
     return psi
 end
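The TODO in localupdate_step! flags the remaining gap: with parallel = true, the AC and C eigensolves of a single site run concurrently, so even the per-site communicator would carry two simultaneous collectives. A hedged sketch of the communicator layout this implies; the helper name is hypothetical and not part of the commit:

    using MPI

    # Hypothetical helper: how many duplicated communicators each unit-cell site
    # would need, assuming MPI is already initialized with threadlevel = :multiple.
    function site_communicators(nsites::Int; parallel_within_site::Bool = false)
        ncomms = parallel_within_site ? 2 : 1   # the TODO case: AC and C solves reduce concurrently
        return [[MPI.Comm_dup(MPI.COMM_WORLD) for _ in 1:ncomms] for _ in 1:nsites]
    end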

src/includes.jl

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 include("utility/forward.jl")

 include("MPIOperator/mpioperator.jl")
+include("MPIOperator/mpskit.jl")
 include("MPIOperator/derivatives.jl")
 include("MPIOperator/environments.jl")
 include("MPIOperator/ortho.jl")

src/utility/forward.jl

Lines changed: 3 additions & 3 deletions
@@ -69,7 +69,7 @@ macro forward_1_1_astype(ex, fs)
     T = esc(T)
     fs = isexpr(fs, :tuple) ? map(esc, fs.args) : [esc(fs)]
     :($([:($f(a,x::$T, args...; kwargs...) =
-            (Base.@inline; $T($f(a, x.$field, args...; kwargs...))))
+            (Base.@inline; $T($f(a, x.$field, args...; kwargs...), x.reduction, x.comm)))
          for f in fs]...);
       nothing)
 end
@@ -106,7 +106,7 @@ macro forward_3_1_astype(ex, fs)
     T = esc(T)
     fs = isexpr(fs, :tuple) ? map(esc, fs.args) : [esc(fs)]
     :($([:($f(a,b,c,x::$T, y::$T, args...; kwargs...) =
-            (Base.@inline; $T($f(a, b, c, x.$field, y.$field, args...; kwargs...))))
+            (Base.@inline; $T($f(a, b, c, x.$field, y.$field, args...; kwargs...), x.reduction, x.comm)))
          for f in fs]...);
       nothing)
 end
@@ -115,7 +115,7 @@ macro forward_astype(ex, fs)
     T = esc(T)
     fs = isexpr(fs, :tuple) ? map(esc, fs.args) : [esc(fs)]
     :($([:($f(x::$T, args...; kwargs...) =
-            (Base.@inline; $T($f(x.$field, args...; kwargs...))))
+            (Base.@inline; $T($f(x.$field, args...; kwargs...), x.reduction, x.comm)))
          for f in fs]...);
       nothing)
 end
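The forwarding macros now rebuild the wrapper from all three fields, so a forwarded call preserves the reduction and communicator of its argument. Roughly, a generated method for MPIOperator has this shape (illustrative expansion with Base.copy as an example target, not the literal macro output):

    Base.copy(x::MPIOperator, args...; kwargs...) =
        (Base.@inline; MPIOperator(Base.copy(x.parent, args...; kwargs...), x.reduction, x.comm))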
