Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorNetworkSolvers"
uuid = "62d5c68e-057c-4e85-9e57-968d0954630a"
version = "0.1.0"
version = "0.1.1"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
9 changes: 5 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using TensorNetworkSolvers: TensorNetworkSolvers
using Documenter: Documenter, DocMeta, deploydocs, makedocs
using TensorNetworkSolvers: TensorNetworkSolvers

DocMeta.setdocmeta!(
TensorNetworkSolvers, :DocTestSetup, :(using TensorNetworkSolvers); recursive = true
Expand All @@ -14,11 +14,12 @@ makedocs(;
format = Documenter.HTML(;
canonical = "https://itensor.github.io/TensorNetworkSolvers.jl",
edit_link = "main",
assets = ["assets/favicon.ico", "assets/extras.css"],
assets = ["assets/favicon.ico", "assets/extras.css"]
),
pages = ["Home" => "index.md", "Reference" => "reference.md"],
pages = ["Home" => "index.md", "Reference" => "reference.md"]
)

deploydocs(;
repo = "github.com/ITensor/TensorNetworkSolvers.jl", devbranch = "main", push_preview = true
repo = "github.com/ITensor/TensorNetworkSolvers.jl", devbranch = "main",
push_preview = true
)
2 changes: 1 addition & 1 deletion docs/make_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ Literate.markdown(
joinpath(pkgdir(TensorNetworkSolvers), "docs", "src");
flavor = Literate.DocumenterFlavor(),
name = "index",
postprocess = ccq_logo,
postprocess = ccq_logo
)
2 changes: 1 addition & 1 deletion docs/make_readme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ Literate.markdown(
joinpath(pkgdir(TensorNetworkSolvers));
flavor = Literate.CommonMarkFlavor(),
name = "README",
postprocess = ccq_logo,
postprocess = ccq_logo
)
38 changes: 11 additions & 27 deletions src/AlgorithmsInterfaceExtensions.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module AlgorithmsInterfaceExtensions

import AlgorithmsInterface as AI

#========================== Patches for AlgorithmsInterface.jl ============================#
import AlgorithmsInterface as AI #========================== Patches for AlgorithmsInterface.jl ============================#

abstract type Problem <: AI.Problem end
abstract type Algorithm <: AI.Algorithm end
Expand All @@ -25,26 +23,20 @@ function AI.initialize_state(
problem, algorithm, algorithm.stopping_criterion
)
return DefaultState(; stopping_criterion_state, kwargs...)
end

#============================ DefaultState ================================================#
end #============================ DefaultState ================================================#

@kwdef mutable struct DefaultState{
Iterate, StoppingCriterionState <: AI.StoppingCriterionState,
} <: State
iterate::Iterate
iteration::Int = 0
stopping_criterion_state::StoppingCriterionState
end

#============================ increment! ==================================================#
end #============================ increment! ==================================================#

# Custom version of `increment!` that also takes the problem and algorithm as arguments.
function AI.increment!(problem::Problem, algorithm::Algorithm, state::State)
return AI.increment!(state)
end

#============================ solve! ======================================================#
end #============================ solve! ======================================================#

# Custom version of `solve!` that allows specifying the logger and also overloads
# `increment!` on the problem and algorithm.
Expand All @@ -55,13 +47,13 @@ default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_)
function default_logging_context_prefix(problem::Problem, algorithm::Algorithm)
return Symbol(
default_logging_context_prefix(problem),
default_logging_context_prefix(algorithm),
default_logging_context_prefix(algorithm)
)
end
function AI.solve!(
problem::Problem, algorithm::Algorithm, state::State;
logging_context_prefix = default_logging_context_prefix(problem, algorithm),
kwargs...,
kwargs...
)
logger = AI.algorithm_logger()

Expand Down Expand Up @@ -94,13 +86,11 @@ end
function AI.solve(
problem::Problem, algorithm::Algorithm;
logging_context_prefix = default_logging_context_prefix(problem, algorithm),
kwargs...,
kwargs...
)
state = AI.initialize_state(problem, algorithm; kwargs...)
return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...)
end

#============================ AlgorithmIterator ===========================================#
end #============================ AlgorithmIterator ===========================================#

abstract type AlgorithmIterator end

Expand Down Expand Up @@ -133,19 +123,15 @@ struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator
problem::Problem
algorithm::Algorithm
state::State
end

#============================ with_algorithmlogger ========================================#
end #============================ with_algorithmlogger ========================================#

# Allow passing functions, not just CallbackActions.
@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...)
return AI.with_algorithmlogger(f, args...)
end
@inline function with_algorithmlogger(f, args::Pair{Symbol}...)
return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...)
end

#============================ NestedAlgorithm =============================================#
end #============================ NestedAlgorithm =============================================#

abstract type NestedAlgorithm <: Algorithm end

Expand Down Expand Up @@ -205,9 +191,7 @@ from a list of stored algorithms.
end
function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...)
return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...)
end

#============================ FlattenedAlgorithm ==========================================#
end #============================ FlattenedAlgorithm ==========================================#

# Flatten a nested algorithm.
abstract type FlattenedAlgorithm <: Algorithm end
Expand Down
10 changes: 6 additions & 4 deletions src/eigenproblem.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import AlgorithmsInterface as AI
import .AlgorithmsInterfaceExtensions as AIE
import AlgorithmsInterface as AI

maybe_fill(value, len::Int) = fill(value, len)
function maybe_fill(v::AbstractVector, len::Int)
Expand Down Expand Up @@ -43,7 +43,7 @@ function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, reg
return Sweeping(nsweeps) do i
return select_algorithm(
dmrg_sweep, operator, state;
regions, region_kwargs = region_kwargs′[i],
regions, region_kwargs = region_kwargs′[i]
)
end
end
Expand All @@ -60,7 +60,8 @@ end

function AI.step!(problem::EigenProblem, algorithm::Sweep, state::AI.State; kwargs...)
iterate = solve_region!!(
problem, algorithm.region_algorithms[state.iteration](state.iterate), state.iterate
problem, algorithm.region_algorithms[state.iteration](state.iterate),
state.iterate
)
state.iterate = iterate
return state
Expand All @@ -84,7 +85,8 @@ function solve_region!!(problem::EigenProblem, algorithm::RegionAlgorithm, state
=#

# Dummy update for demonstration purposes.
state′ = "region = $region" *
state′ =
"region = $region" *
", update_kwargs = $(region_kwargs.update)" *
", insert_kwargs = $(region_kwargs.insert)"
state = [state; [state′]]
Expand Down
2 changes: 1 addition & 1 deletion src/sweep.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import AlgorithmsInterface as AI
import .AlgorithmsInterfaceExtensions as AIE
import AlgorithmsInterface as AI

@kwdef struct Sweeping{
Algorithms <: AbstractVector{<:AI.Algorithm},
Expand Down
15 changes: 10 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@ const GROUP = uppercase(
arg == "" ? "ALL" : arg
else
only(match(pat, ARGS[arg_id]).captures)
end,
end
)

"match files of the form `test_*.jl`, but exclude `*setup*.jl`"
"""
match files of the form `test_*.jl`, but exclude `*setup*.jl`
"""
function istestfile(path)
fn = basename(path)
return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup")
return endswith(fn, ".jl") && startswith(basename(fn), "test_") &&
!contains(fn, "setup")
end
"match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`"
"""
match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`
"""
function isexamplefile(path)
fn = basename(path)
return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup")
Expand Down Expand Up @@ -60,7 +65,7 @@ end
:macrocall,
GlobalRef(Suppressor, Symbol("@suppress")),
LineNumberNode(@__LINE__, @__FILE__),
:(include($filename)),
:(include($filename))
)
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_aqua.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using TensorNetworkSolvers: TensorNetworkSolvers
using Aqua: Aqua
using TensorNetworkSolvers: TensorNetworkSolvers
using Test: @testset

@testset "Code quality (Aqua.jl)" begin
Expand Down
4 changes: 2 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import AlgorithmsInterface as AI
import TensorNetworkSolvers.AlgorithmsInterfaceExtensions as AIE
using Graphs: path_graph
using TensorNetworkSolvers: EigenProblem, Region, Sweep, Sweeping, dmrg, dmrg_sweep
import TensorNetworkSolvers.AlgorithmsInterfaceExtensions as AIE
using Test: @test, @testset

@testset "TensorNetworkSolvers" begin
Expand Down Expand Up @@ -160,7 +160,7 @@ using Test: @test, @testset
:EigenProblem_Sweeping_PreStep => print_dmrg_prestep,
:EigenProblem_Sweeping_PostStep => print_dmrg_poststep,
:EigenProblem_Sweeping_Sweep_Start => print_sweep_start,
:EigenProblem_Sweeping_Sweep_PostStep => print_sweep_poststep,
:EigenProblem_Sweeping_Sweep_PostStep => print_sweep_poststep
) do
x = dmrg(operator, x0; nsweeps, regions, region_kwargs)
return x
Expand Down