Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FunctionImplementations"
uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
version = "0.4.1"
version = "0.4.2"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
12 changes: 7 additions & 5 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using FunctionImplementations: FunctionImplementations
using Documenter: Documenter, DocMeta, deploydocs, makedocs
using FunctionImplementations: FunctionImplementations

DocMeta.setdocmeta!(
FunctionImplementations, :DocTestSetup, :(using FunctionImplementations); recursive = true
FunctionImplementations, :DocTestSetup, :(using FunctionImplementations);
recursive = true
)

include("make_index.jl")
Expand All @@ -14,11 +15,12 @@ makedocs(;
format = Documenter.HTML(;
canonical = "https://itensor.github.io/FunctionImplementations.jl",
edit_link = "main",
assets = ["assets/favicon.ico", "assets/extras.css"],
assets = ["assets/favicon.ico", "assets/extras.css"]
),
pages = ["Home" => "index.md", "Reference" => "reference.md"],
pages = ["Home" => "index.md", "Reference" => "reference.md"]
)

deploydocs(;
repo = "github.com/ITensor/FunctionImplementations.jl", devbranch = "main", push_preview = true
repo = "github.com/ITensor/FunctionImplementations.jl", devbranch = "main",
push_preview = true
)
4 changes: 2 additions & 2 deletions docs/make_index.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Literate: Literate
using FunctionImplementations: FunctionImplementations
using Literate: Literate

function ccq_logo(content)
include_ccq_logo = """
Expand All @@ -17,5 +17,5 @@ Literate.markdown(
joinpath(pkgdir(FunctionImplementations), "docs", "src");
flavor = Literate.DocumenterFlavor(),
name = "index",
postprocess = ccq_logo,
postprocess = ccq_logo
)
4 changes: 2 additions & 2 deletions docs/make_readme.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Literate: Literate
using FunctionImplementations: FunctionImplementations
using Literate: Literate

function ccq_logo(content)
include_ccq_logo = """
Expand All @@ -17,5 +17,5 @@ Literate.markdown(
joinpath(pkgdir(FunctionImplementations));
flavor = Literate.CommonMarkFlavor(),
name = "README",
postprocess = ccq_logo,
postprocess = ccq_logo
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module FunctionImplementationsFillArraysExt

using FillArrays: FillArrays as FA, AbstractFill, RectDiagonal
import FunctionImplementations as FI
using FillArrays: FillArrays as FA, AbstractFill, RectDiagonal

function check_perm(a::AbstractArray, perm)
(ndims(a) == length(perm) && isperm(perm)) ||
Expand Down
27 changes: 20 additions & 7 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ reminiscent of how Broadcast works.

The various entry points for specializing behavior are:

* Destination selection can be achieved through:
- Destination selection can be achieved through:

```julia
Base.similar(concat::Concatenated{Style}, ::Type{T}, axes) where {Style}
```

* Custom implementations:
- Custom implementations:

```julia
Base.copy(concat::Concatenated{Style}) # custom implementation of cat
Expand All @@ -28,11 +28,12 @@ Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat
module Concatenate

export concatenate
VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public Concatenated, cat, cat!, concatenated"))
VERSION >= v"1.11.0-DEV.469" &&
eval(Meta.parse("public Concatenated, cat, cat!, concatenated"))

using Base: promote_eltypeof
import Base.Broadcast as BC
using ..FunctionImplementations: zero!
using Base: promote_eltypeof

unval(::Val{x}) where {x} = x

Expand Down Expand Up @@ -116,7 +117,11 @@ end

function cat_axes(dims, a::AbstractArray, as::AbstractArray...)
return ntuple(cat_ndims(dims, a, as...)) do dim
return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim)
return if dim in dims
cat_axis(map(Base.Fix2(axes, dim), (a, as...))...)
else
axes(a, dim)
end
end
end
function cat_axes(dims::Val, as::AbstractArray...)
Expand Down Expand Up @@ -191,11 +196,19 @@ end
__cat_offset!(A, shape, catdims, offsets) = A
function __cat_offset1!(A, shape, catdims, offsets, x)
inds = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
return if (i <= length(catdims) && catdims[i])
offsets[i] .+ cat_indices(x, i)
else
1:shape[i]
end
end
_copy_or_fill!(A, inds, x)
newoffsets = ntuple(length(offsets)) do i
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
return if (i <= length(catdims) && catdims[i])
offsets[i] + cat_size(x, i)
else
offsets[i]
end
end
return newoffsets
end
Expand Down
68 changes: 54 additions & 14 deletions src/style.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ by defining a type/method pair

struct MyContainerImplementationStyle <: ImplementationStyle end
FunctionImplementations.ImplementationStyle(::Type{<:MyContainer}) = MyContainerImplementationStyle()

"""
abstract type ImplementationStyle end
ImplementationStyle(::Type{T}) where {T} = throw(MethodError(ImplementationStyle, (T,)))
Expand Down Expand Up @@ -69,20 +68,37 @@ The result does not have to be one of the input arguments, it could be a third t
"""
ImplementationStyle(::S, ::S) where {S <: ImplementationStyle} = S() # homogeneous types preserved
# Fall back to UnknownImplementationStyle. This is necessary to implement argument-swapping
ImplementationStyle(::ImplementationStyle, ::ImplementationStyle) = UnknownImplementationStyle()
function ImplementationStyle(::ImplementationStyle, ::ImplementationStyle)
return UnknownImplementationStyle()
end
# UnknownImplementationStyle loses to everything
ImplementationStyle(::UnknownImplementationStyle, ::UnknownImplementationStyle) = UnknownImplementationStyle()
ImplementationStyle(::S, ::UnknownImplementationStyle) where {S <: ImplementationStyle} = S()
function ImplementationStyle(::UnknownImplementationStyle, ::UnknownImplementationStyle)
return UnknownImplementationStyle()
end
function ImplementationStyle(
::S,
::UnknownImplementationStyle
) where {S <: ImplementationStyle}
return S()
end
# Precedence rules
ImplementationStyle(::A, ::A) where {A <: AbstractArrayImplementationStyle} = A()
function ImplementationStyle(a::A, b::B) where {A <: AbstractArrayImplementationStyle, B <: AbstractArrayImplementationStyle}
function ImplementationStyle(
a::A,
b::B
) where {A <: AbstractArrayImplementationStyle, B <: AbstractArrayImplementationStyle}
if Base.typename(A) ≡ Base.typename(B)
return A()
end
return UnknownImplementationStyle()
end
# Any specific array type beats DefaultArrayImplementationStyle
ImplementationStyle(a::AbstractArrayImplementationStyle, ::DefaultArrayImplementationStyle) = a
function ImplementationStyle(
a::AbstractArrayImplementationStyle,
::DefaultArrayImplementationStyle
)
return a
end

## logic for deciding the ImplementationStyle

Expand All @@ -94,6 +110,7 @@ Uses [`ImplementationStyle`](@ref) to get the style for each argument, and uses
[`result_style`](@ref) to combine styles.

# Examples

```jldoctest
julia> FunctionImplementations.style([1], [1 2; 3 4])
FunctionImplementations.DefaultArrayImplementationStyle()
Expand All @@ -115,10 +132,16 @@ determine a common `ImplementationStyle`.
# Examples

```jldoctest
julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayImplementationStyle(), FunctionImplementations.DefaultArrayImplementationStyle())
julia> FunctionImplementations.result_style(
FunctionImplementations.DefaultArrayImplementationStyle(),
FunctionImplementations.DefaultArrayImplementationStyle()
)
FunctionImplementations.DefaultArrayImplementationStyle()

julia> FunctionImplementations.result_style(FunctionImplementations.UnknownImplementationStyle(), FunctionImplementations.DefaultArrayImplementationStyle())
julia> FunctionImplementations.result_style(
FunctionImplementations.UnknownImplementationStyle(),
FunctionImplementations.DefaultArrayImplementationStyle()
)
FunctionImplementations.DefaultArrayImplementationStyle()
```
"""
Expand All @@ -129,24 +152,41 @@ function result_style(s1::S, s2::S) where {S <: ImplementationStyle}
return s1 ≡ s2 ? s1 : error("inconsistent styles, custom rule needed")
end
# Test both orders so users typically only have to declare one order
result_style(s1, s2) = result_join(s1, s2, ImplementationStyle(s1, s2), ImplementationStyle(s2, s1))
function result_style(s1, s2)
return result_join(s1, s2, ImplementationStyle(s1, s2), ImplementationStyle(s2, s1))
end

# result_join is the final arbiter. Because `ImplementationStyle` for undeclared pairs results in UnknownImplementationStyle,
# we defer to any case where the result of `ImplementationStyle` is known.
result_join(::Any, ::Any, ::UnknownImplementationStyle, ::UnknownImplementationStyle) = UnknownImplementationStyle()
function result_join(
::Any,
::Any,
::UnknownImplementationStyle,
::UnknownImplementationStyle
)
return UnknownImplementationStyle()
end
result_join(::Any, ::Any, ::UnknownImplementationStyle, s::ImplementationStyle) = s
result_join(::Any, ::Any, s::ImplementationStyle, ::UnknownImplementationStyle) = s
# For AbstractArray types with undefined precedence rules,
# we have to signal conflict. Because ArrayImplementationConflict is a subtype of AbstractArray,
# this will "poison" any future operations (if we instead returned `DefaultArrayImplementationStyle`, then for
# 3-array functions returned type would depend on argument order).
result_join(::AbstractArrayImplementationStyle, ::AbstractArrayImplementationStyle, ::UnknownImplementationStyle, ::UnknownImplementationStyle) =
ArrayImplementationConflict()
function result_join(
::AbstractArrayImplementationStyle,
::AbstractArrayImplementationStyle,
::UnknownImplementationStyle,
::UnknownImplementationStyle
)
return ArrayImplementationConflict()
end
# Fallbacks in case users define `rule` for both argument-orders (not recommended)
result_join(::Any, ::Any, s1::S, s2::S) where {S <: ImplementationStyle} = result_style(s1, s2)
function result_join(::Any, ::Any, s1::S, s2::S) where {S <: ImplementationStyle}
return result_style(s1, s2)
end

@noinline function result_join(::S, ::T, ::U, ::V) where {S, T, U, V}
error(
return error(
"""
conflicting rules defined
FunctionImplementations.ImplementationStyle(::$S, ::$T) = $U()
Expand Down
15 changes: 10 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@ const GROUP = uppercase(
arg == "" ? "ALL" : arg
else
only(match(pat, ARGS[arg_id]).captures)
end,
end
)

"match files of the form `test_*.jl`, but exclude `*setup*.jl`"
"""
match files of the form `test_*.jl`, but exclude `*setup*.jl`
"""
function istestfile(path)
fn = basename(path)
return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup")
return endswith(fn, ".jl") && startswith(basename(fn), "test_") &&
!contains(fn, "setup")
end
"match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`"
"""
match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`
"""
function isexamplefile(path)
fn = basename(path)
return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup")
Expand Down Expand Up @@ -60,7 +65,7 @@ end
:macrocall,
GlobalRef(Suppressor, Symbol("@suppress")),
LineNumberNode(@__LINE__, @__FILE__),
:(include($filename)),
:(include($filename))
)
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_aqua.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using FunctionImplementations: FunctionImplementations
using Aqua: Aqua
using FunctionImplementations: FunctionImplementations
using Test: @testset

@testset "Code quality (Aqua.jl)" begin
Expand Down
12 changes: 8 additions & 4 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ using Test: @test, @testset
end
@testset "ImplementationStyle" begin
# Test basic ImplementationStyle trait for different array types
@test FI.ImplementationStyle(typeof([1, 2, 3])) ≡ FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(typeof([1, 2, 3])) ≡
FI.DefaultArrayImplementationStyle()
@test FI.style([1, 2, 3]) ≡ FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(typeof([1 2; 3 4])) ≡ FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(typeof(rand(2, 3, 4))) ≡ FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(typeof([1 2; 3 4])) ≡
FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(typeof(rand(2, 3, 4))) ≡
FI.DefaultArrayImplementationStyle()

# Test custom ImplementationStyle definition
struct CustomImplementationStyle <: FI.ImplementationStyle end
Expand Down Expand Up @@ -75,7 +78,8 @@ using Test: @test, @testset
@test result ≡ FI.DefaultArrayImplementationStyle()

# Test result_style with single argument
@test FI.result_style(FI.DefaultArrayImplementationStyle()) isa FI.DefaultArrayImplementationStyle
@test FI.result_style(FI.DefaultArrayImplementationStyle()) isa
FI.DefaultArrayImplementationStyle

# Test result_style with two identical styles
s = FI.DefaultArrayImplementationStyle()
Expand Down