Skip to content

Commit f50737c

Browse files
committed
Rename Sweeping to NestedAlgorithm
1 parent 5ddf922 commit f50737c

4 files changed

Lines changed: 121 additions & 117 deletions

File tree

src/AlgorithmsInterfaceExtensions.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,100 @@ end
141141
return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...)
142142
end
143143

144+
#============================ NestedAlgorithm =============================================#
145+
146+
#=
147+
NestedAlgorithm(sweeps::AbstractVector{<:Algorithm})
148+
149+
An algorithm that consists of running an algorithm at each iteration
150+
from a list of stored algorithms.
151+
=#
152+
@kwdef struct NestedAlgorithm{
153+
Algorithms <: AbstractVector{<:Algorithm},
154+
StoppingCriterion <: AI.StoppingCriterion,
155+
} <: Algorithm
156+
algorithms::Algorithms
157+
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms))
158+
end
159+
function NestedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
160+
return NestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
161+
end
162+
163+
max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms)
164+
165+
function AI.step!(
166+
problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State;
167+
logging_context_prefix = Symbol()
168+
)
169+
# Perform the current sweep.
170+
sub_algorithm = algorithm.algorithms[state.iteration]
171+
sub_state = AI.initialize_state(problem, sub_algorithm; state.iterate)
172+
logging_context_prefix = Symbol(logging_context_prefix, :Sweep_)
173+
AI.solve!(problem, sub_algorithm, sub_state; logging_context_prefix)
174+
state.iterate = sub_state.iterate
175+
return state
176+
end
177+
178+
#============================ FlattenedAlgorithm ==========================================#
179+
180+
# Flatten a nested algorithm.
181+
function default_flattened_stopping_criterion(algorithm::NestedAlgorithm)
182+
return AI.StopAfterIteration(sum(max_iterations, algorithm.algorithms))
183+
end
184+
@kwdef struct FlattenedAlgorithm{
185+
ParentAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion,
186+
} <: Algorithm
187+
parent_algorithm::ParentAlgorithm
188+
stopping_criterion::StoppingCriterion =
189+
default_flattened_stopping_criterion(parent_algorithm)
190+
end
191+
192+
@kwdef mutable struct FlattenedAlgorithmState{
193+
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
194+
} <: State
195+
iterate::Iterate
196+
iteration::Int = 0
197+
parent_iteration::Int = 1
198+
child_iteration::Int = 0
199+
stopping_criterion_state::StoppingCriterionState
200+
end
201+
202+
function AI.initialize_state(
203+
problem::Problem, algorithm::FlattenedAlgorithm; kwargs...
204+
)
205+
stopping_criterion_state = AI.initialize_state(
206+
problem, algorithm, algorithm.stopping_criterion
207+
)
208+
return FlattenedAlgorithmState(; stopping_criterion_state, kwargs...)
209+
end
210+
function AI.increment!(
211+
problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState
212+
)
213+
# Increment the total iteration count.
214+
state.iteration += 1
215+
if state.child_iteration max_iterations(algorithm.parent_algorithm.algorithms[state.parent_iteration])
216+
# We're on the last iteration of the child algorithm, so move to the next
217+
# child algorithm.
218+
state.parent_iteration += 1
219+
state.child_iteration = 1
220+
else
221+
# Iterate the child algorithm.
222+
state.child_iteration += 1
223+
end
224+
return state
225+
end
226+
function AI.step!(
227+
problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState;
228+
logging_context_prefix = Symbol()
229+
)
230+
algorithm_sweep = algorithm.parent_algorithm.algorithms[state.parent_iteration]
231+
state_sweep = AI.initialize_state(
232+
problem, algorithm_sweep;
233+
state.iterate, iteration = state.child_iteration
234+
)
235+
AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix)
236+
state.iterate = state_sweep.iterate
237+
return state
238+
end
239+
144240
end

src/eigenproblem.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
import AlgorithmsInterface as AI
22
import .AlgorithmsInterfaceExtensions as AIE
33

4+
function dmrg_sweep(operator, state; regions, region_kwargs)
5+
problem = EigenProblem(operator)
6+
algorithm = Sweep(; regions, region_kwargs)
7+
return AI.solve(problem, algorithm; iterate = state).iterate
8+
end
9+
10+
function dmrg(operator, state; nsweeps, regions, region_kwargs, kwargs...)
11+
problem = EigenProblem(operator)
12+
algorithm = AIE.NestedAlgorithm(nsweeps) do i
13+
return Sweep(; regions, region_kwargs = region_kwargs[i])
14+
end
15+
return AI.solve(problem, algorithm; iterate = state, kwargs...).iterate
16+
end
17+
418
#=
519
EigenProblem(operator)
620
@@ -36,17 +50,3 @@ function update!(problem::EigenProblem, algorithm::Sweep, state::AI.State)
3650

3751
return state
3852
end
39-
40-
function dmrg_sweep(operator, state; regions, region_kwargs)
41-
problem = EigenProblem(operator)
42-
algorithm = Sweep(; regions, region_kwargs)
43-
return AI.solve(problem, algorithm; iterate = state).iterate
44-
end
45-
46-
function dmrg(operator, state; nsweeps, regions, region_kwargs, kwargs...)
47-
problem = EigenProblem(operator)
48-
algorithm = Sweeping(nsweeps) do i
49-
return Sweep(; regions, region_kwargs = region_kwargs[i])
50-
end
51-
return AI.solve(problem, algorithm; iterate = state, kwargs...).iterate
52-
end

src/sweeping.jl

Lines changed: 1 addition & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function Sweep(
2727
return Sweep(regions, Returns(region_kwargs), stopping_criterion)
2828
end
2929

30-
maxiter(algorithm::Sweep) = length(algorithm.regions)
30+
AIE.max_iterations(algorithm::Sweep) = length(algorithm.regions)
3131

3232
function AI.step!(
3333
problem::AI.Problem, algorithm::Sweep, state::AI.State; kwargs...
@@ -56,95 +56,3 @@ function insert!(
5656
# Insert step goes here.
5757
return state
5858
end
59-
60-
#=
61-
Sweeping(sweeps::Vector{<:Sweep})
62-
63-
The sweeping algorithm, which just stores a list of sweeps defined above.
64-
=#
65-
@kwdef struct Sweeping{
66-
Sweeps <: Vector{<:Sweep}, StoppingCriterion <: AI.StoppingCriterion,
67-
} <: AIE.Algorithm
68-
sweeps::Sweeps
69-
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(sweeps))
70-
end
71-
function Sweeping(f::Function, nsweeps::Int; kwargs...)
72-
return Sweeping(; sweeps = f.(1:nsweeps), kwargs...)
73-
end
74-
75-
maxiter(algorithm::Sweeping) = length(algorithm.sweeps)
76-
nregions(algorithm::Sweeping) = sum(maxiter, algorithm.sweeps)
77-
78-
function AI.step!(
79-
problem::AI.Problem, algorithm::Sweeping, state::AI.State;
80-
logging_context_prefix = Symbol()
81-
)
82-
# Perform the current sweep.
83-
algorithm_sweep = algorithm.sweeps[state.iteration]
84-
state_sweep = AI.initialize_state(problem, algorithm_sweep; state.iterate)
85-
logging_context_prefix = Symbol(logging_context_prefix, :Sweep_)
86-
AI.solve!(problem, algorithm_sweep, state_sweep; logging_context_prefix)
87-
state.iterate = state_sweep.iterate
88-
return state
89-
end
90-
91-
# TODO: Use a proper stopping criterion.
92-
function AI.is_finished(
93-
problem::AI.Problem, algorithm::Sweeping, state::AI.State
94-
)
95-
state.iteration == 0 && return false
96-
return state.iteration >= length(algorithm.sweeps)
97-
end
98-
99-
# Sweeping by region.
100-
@kwdef struct ByRegion{
101-
ParentAlgorithm <: Sweeping, StoppingCriterion <: AI.StoppingCriterion,
102-
} <: AIE.Algorithm
103-
sweeping::ParentAlgorithm
104-
stopping_criterion::StoppingCriterion = AI.StopAfterIteration(nregions(sweeping))
105-
end
106-
107-
@kwdef mutable struct ByRegionState{
108-
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
109-
} <: AIE.State
110-
iterate::Iterate
111-
iteration::Int = 0
112-
sweeping_iteration::Int = 1
113-
sweep_iteration::Int = 0
114-
stopping_criterion_state::StoppingCriterionState
115-
end
116-
117-
function AI.initialize_state(
118-
problem::AIE.Problem, algorithm::ByRegion; kwargs...
119-
)
120-
stopping_criterion_state = AI.initialize_state(
121-
problem, algorithm, algorithm.stopping_criterion
122-
)
123-
return ByRegionState(; stopping_criterion_state, kwargs...)
124-
end
125-
function AI.increment!(problem::AIE.Problem, algorithm::AIE.Algorithm, state::ByRegionState)
126-
# Increment the total iteration count.
127-
state.iteration += 1
128-
if state.sweep_iteration maxiter(algorithm.sweeping.sweeps[state.sweeping_iteration])
129-
# We're on the last region of the sweep, so move to the next sweep.
130-
state.sweeping_iteration += 1
131-
state.sweep_iteration = 1
132-
else
133-
# Move to the next region in the current sweep.
134-
state.sweep_iteration += 1
135-
end
136-
return state
137-
end
138-
function AI.step!(
139-
problem::AI.Problem, algorithm::ByRegion, state::ByRegionState;
140-
logging_context_prefix = Symbol()
141-
)
142-
algorithm_sweep = algorithm.sweeping.sweeps[state.sweeping_iteration]
143-
state_sweep = AI.initialize_state(
144-
problem, algorithm_sweep;
145-
state.iterate, iteration = state.sweep_iteration
146-
)
147-
AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix)
148-
state.iterate = state_sweep.iterate
149-
return state
150-
end

test/test_basics.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import AlgorithmsInterface as AI
22
using Graphs: path_graph
3-
using TensorNetworkSolvers: ByRegion, EigenProblem, Sweeping, Sweep, dmrg, dmrg_sweep
3+
using TensorNetworkSolvers: EigenProblem, Sweep, dmrg, dmrg_sweep
44
import TensorNetworkSolvers.AlgorithmsInterfaceExtensions as AIE
55
using Test: @test, @testset
66

@@ -66,7 +66,7 @@ using Test: @test, @testset
6666
end
6767
x = []
6868
problem = EigenProblem(operator)
69-
algorithm = Sweeping(nsweeps) do i
69+
algorithm = AIE.NestedAlgorithm(nsweeps) do i
7070
Sweep(; regions, region_kwargs = region_kwargs[i])
7171
end
7272
state = AI.initialize_state(problem, algorithm; iterate = x)
@@ -78,7 +78,7 @@ using Test: @test, @testset
7878
@test iterations == 1:nsweeps
7979
@test length(state.iterate) == nsweeps * length(regions)
8080
end
81-
@testset "ByRegion" begin
81+
@testset "FlattenedAlgorithm" begin
8282
operator = path_graph(4)
8383
regions = [(1, 2), (2, 3), (3, 4)]
8484
nsweeps = 3
@@ -92,10 +92,10 @@ using Test: @test, @testset
9292
end
9393
x = []
9494
problem = EigenProblem(operator)
95-
sweeping = Sweeping(nsweeps) do i
95+
sweeping = AIE.NestedAlgorithm(nsweeps) do i
9696
Sweep(; regions, region_kwargs = region_kwargs[i])
9797
end
98-
algorithm = ByRegion(; sweeping)
98+
algorithm = AIE.FlattenedAlgorithm(; parent_algorithm = sweeping)
9999
state = AI.initialize_state(problem, algorithm; iterate = x)
100100
iterator = AIE.algorithm_iterator(problem, algorithm, state)
101101
iterations = Int[]
@@ -147,11 +147,11 @@ using Test: @test, @testset
147147
return nothing
148148
end
149149
x = AIE.with_algorithmlogger(
150-
:EigenProblem_Sweeping_Start => print_dmrg_start,
151-
:EigenProblem_Sweeping_PreStep => print_dmrg_prestep,
152-
:EigenProblem_Sweeping_PostStep => print_dmrg_poststep,
153-
:EigenProblem_Sweeping_Sweep_Start => print_sweep_start,
154-
:EigenProblem_Sweeping_Sweep_PostStep => print_sweep_poststep,
150+
:EigenProblem_NestedAlgorithm_Start => print_dmrg_start,
151+
:EigenProblem_NestedAlgorithm_PreStep => print_dmrg_prestep,
152+
:EigenProblem_NestedAlgorithm_PostStep => print_dmrg_poststep,
153+
:EigenProblem_NestedAlgorithm_Sweep_Start => print_sweep_start,
154+
:EigenProblem_NestedAlgorithm_Sweep_PostStep => print_sweep_poststep,
155155
) do
156156
x = dmrg(operator, x0; nsweeps, regions, region_kwargs)
157157
return x

0 commit comments

Comments
 (0)