-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors]
NamedDimsArrays
submodule (#1267)
- Loading branch information
Showing
20 changed files
with
483 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
style = "blue" | ||
indent = 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# NamedDimsArrays.jl | ||
|
||
````julia | ||
using NDTensors.NamedDimsArrays: align, dimnames, named, unname | ||
using NDTensors.TensorAlgebra: TensorAlgebra | ||
|
||
# Named dimensions | ||
i = named(2, "i") | ||
j = named(2, "j") | ||
k = named(2, "k") | ||
|
||
# Arrays with named dimensions | ||
na1 = randn(i, j) | ||
na2 = randn(j, k) | ||
|
||
@show dimnames(na1) == ("i", "j") | ||
|
||
# Indexing | ||
@show na1[j => 2, i => 1] == na1[1, 2] | ||
|
||
# Tensor contraction | ||
na_dest = TensorAlgebra.contract(na1, na2) | ||
|
||
@show issetequal(dimnames(na_dest), ("i", "k")) | ||
# `unname` removes the names and returns an `Array` | ||
@show unname(na_dest, (i, k)) ≈ unname(na1) * unname(na2) | ||
|
||
# Permute dimensions (like `ITensors.permute`) | ||
na1 = align(na1, (j, i)) | ||
@show na1[i => 1, j => 2] == na1[2, 1] | ||
```` | ||
|
||
--- | ||
|
||
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# # NamedDimsArrays.jl | ||
|
||
using NDTensors.NamedDimsArrays: align, dimnames, named, unname | ||
using NDTensors.TensorAlgebra: TensorAlgebra | ||
|
||
## Named dimensions | ||
i = named(2, "i") | ||
j = named(2, "j") | ||
k = named(2, "k") | ||
|
||
## Arrays with named dimensions | ||
na1 = randn(i, j) | ||
na2 = randn(j, k) | ||
|
||
@show dimnames(na1) == ("i", "j") | ||
|
||
## Indexing | ||
@show na1[j => 2, i => 1] == na1[1, 2] | ||
|
||
## Tensor contraction | ||
na_dest = TensorAlgebra.contract(na1, na2) | ||
|
||
@show issetequal(dimnames(na_dest), ("i", "k")) | ||
## `unname` removes the names and returns an `Array` | ||
@show unname(na_dest, (i, k)) ≈ unname(na1) * unname(na2) | ||
|
||
## Permute dimensions (like `ITensors.permute`) | ||
na1 = align(na1, (j, i)) | ||
@show na1[i => 1, j => 2] == na1[2, 1] |
6 changes: 6 additions & 0 deletions
6
...amedDimsArrays/ext/NamedDimsArraysTensorAlgebraExt/src/NamedDimsArraysTensorAlgebraExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
module NamedDimsArraysTensorAlgebraExt | ||
using ..NamedDimsArrays: NamedDimsArrays | ||
using ...NDTensors.TensorAlgebra: TensorAlgebra | ||
|
||
include("contract.jl") | ||
end |
18 changes: 18 additions & 0 deletions
18
NDTensors/src/NamedDimsArrays/ext/NamedDimsArraysTensorAlgebraExt/src/contract.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
using NDTensors.NamedDimsArrays: AbstractNamedDimsArray, dimnames, named, unname | ||
using NDTensors.TensorAlgebra: contract | ||
|
||
function TensorAlgebra.contract( | ||
na1::AbstractNamedDimsArray, na2::AbstractNamedDimsArray, α, β; kwargs... | ||
) | ||
a_dest, names_dest = contract( | ||
unname(na1), dimnames(na1), unname(na2), dimnames(na2), α, β; kwargs... | ||
) | ||
# TODO: Automate `Tuple` conversion of names? | ||
return named(a_dest, Tuple(names_dest)) | ||
end | ||
|
||
function TensorAlgebra.contract( | ||
na1::AbstractNamedDimsArray, na2::AbstractNamedDimsArray; kwargs... | ||
) | ||
return contract(na1, na2, true, false; kwargs...) | ||
end |
13 changes: 13 additions & 0 deletions
13
NDTensors/src/NamedDimsArrays/ext/NamedDimsArraysTensorAlgebraExt/test/runtests.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
using Test: @test, @testset | ||
using NDTensors.NamedDimsArrays: named, unname | ||
using NDTensors.TensorAlgebra: TensorAlgebra | ||
|
||
@testset "NamedDimsArraysTensorAlgebraExt" begin | ||
i = named(2, "i") | ||
j = named(2, "j") | ||
k = named(2, "k") | ||
na1 = randn(i, j) | ||
na2 = randn(j, k) | ||
na_dest = TensorAlgebra.contract(na1, na2) | ||
@test unname(na_dest, (i, k)) ≈ unname(na1) * unname(na2) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using Literate | ||
using NDTensors.NamedDimsArrays: NamedDimsArrays | ||
Literate.markdown( | ||
joinpath( | ||
pkgdir(NamedDimsArrays), "src", "NamedDimsArrays", "examples", "example_readme.jl" | ||
), | ||
joinpath(pkgdir(NamedDimsArrays), "src", "NamedDimsArrays"); | ||
flavor=Literate.CommonMarkFlavor(), | ||
name="README", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
module NamedDimsArrays | ||
include("traits.jl") | ||
include("abstractnamedint.jl") | ||
include("abstractnamedunitrange.jl") | ||
include("abstractnameddimsarray.jl") | ||
include("namedint.jl") | ||
include("namedunitrange.jl") | ||
include("nameddimsarray.jl") | ||
|
||
# Extensions | ||
include("../ext/NamedDimsArraysTensorAlgebraExt/src/NamedDimsArraysTensorAlgebraExt.jl") | ||
end |
126 changes: 126 additions & 0 deletions
126
NDTensors/src/NamedDimsArrays/src/abstractnameddimsarray.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# https://github.com/invenia/NamedDims.jl | ||
# https://github.com/mcabbott/NamedPlus.jl | ||
|
||
abstract type AbstractNamedDimsArray{T,N,Names} <: AbstractArray{T,N} end | ||
|
||
# Required interface | ||
|
||
# Output the names. | ||
dimnames(a::AbstractNamedDimsArray) = error("Not implemented") | ||
|
||
# Unwrapping the names | ||
Base.parent(::AbstractNamedDimsArray) = error("Not implemented") | ||
|
||
# Set the names of an unnamed AbstractArray | ||
# `ndims(a) == length(names)` | ||
# This is a constructor | ||
## named(a::AbstractArray, names) = error("Not implemented") | ||
|
||
# Traits | ||
isnamed(::AbstractNamedDimsArray) = true | ||
|
||
# AbstractArray interface | ||
# TODO: Use `unname` instead of `parent`? | ||
|
||
# Helper function, move to `utils.jl`. | ||
named_tuple(t::Tuple, names) = ntuple(i -> named(t[i], names[i]), length(t)) | ||
|
||
# TODO: Use the proper type, `namedaxistype(a)`. | ||
Base.axes(a::AbstractNamedDimsArray) = named_tuple(axes(unname(a)), dimnames(a)) | ||
# TODO: Use the proper type, `namedlengthtype(a)`. | ||
Base.size(a::AbstractNamedDimsArray) = length.(axes(a)) | ||
Base.getindex(a::AbstractNamedDimsArray, I...) = unname(a)[I...] | ||
function Base.setindex!(a::AbstractNamedDimsArray, x, I...) | ||
unname(a)[I...] = x | ||
return a | ||
end | ||
|
||
# Derived interface | ||
|
||
# Output the names. | ||
dimname(a::AbstractNamedDimsArray, i) = dimnames(a)[i] | ||
|
||
# Renaming | ||
# Unname and set new naems | ||
rename(a::AbstractNamedDimsArray, names) = named(unname(a), names) | ||
|
||
# replacenames(a, :i => :a, :j => :b) | ||
# `rename` in `NamedPlus.jl`. | ||
replacenames(a::AbstractNamedDimsArray, names::Pair) = error("Not implemented yet") | ||
|
||
# Either define new names or replace names | ||
setnames(a::AbstractArray, names) = named(a, names) | ||
setnames(a::AbstractNamedDimsArray, names) = rename(a, names) | ||
|
||
function getperm(x, y) | ||
return map(xᵢ -> findfirst(isequal(xᵢ), y), x) | ||
end | ||
|
||
function get_name_perm(a::AbstractNamedDimsArray, names::Tuple) | ||
return getperm(dimnames(a), names) | ||
end | ||
|
||
function get_name_perm( | ||
a::AbstractNamedDimsArray, namedints::Tuple{Vararg{AbstractNamedInt}} | ||
) | ||
return getperm(size(a), namedints) | ||
end | ||
|
||
function get_name_perm( | ||
a::AbstractNamedDimsArray, namedaxes::Tuple{Vararg{AbstractNamedUnitRange}} | ||
) | ||
return getperm(axes(a), namedaxes) | ||
end | ||
|
||
# Indexing | ||
# a[:i => 2, :j => 3] | ||
# TODO: Write a generic version using `dim`. | ||
# TODO: Define a `NamedIndex` type for indexing? | ||
function Base.getindex(a::AbstractNamedDimsArray, I::Pair...) | ||
perm = get_name_perm(a, first.(I)) | ||
i = last.(I) | ||
return unname(a)[map(p -> i[p], perm)...] | ||
end | ||
|
||
# a[:i => 2, :j => 3] = 12 | ||
# TODO: Write a generic version using `dim`. | ||
function Base.setindex!(a::AbstractNamedDimsArray, value, I::Pair...) | ||
perm = get_name_perm(a, first.(I)) | ||
i = last.(I) | ||
unname(a)[map(p -> i[p], perm)...] = value | ||
return a | ||
end | ||
|
||
# Output the dimension of the specified name. | ||
dim(a::AbstractNamedDimsArray, name) = findfirst(==(name), dimnames(a)) | ||
|
||
# Output the dimensions of the specified names. | ||
dims(a::AbstractNamedDimsArray, names) = map(name -> dim(a, name), names) | ||
|
||
# Unwrapping the names | ||
unname(a::AbstractNamedDimsArray) = parent(a) | ||
unname(a::AbstractArray) = a | ||
|
||
# Permute into a certain order. | ||
# align(a, (:j, :k, :i)) | ||
# Like `named(nameless(a, names), names)` | ||
function align(a::AbstractNamedDimsArray, names) | ||
perm = get_name_perm(a, names) | ||
# TODO: Avoid permutation if it is a trivial permutation? | ||
return typeof(a)(permutedims(unname(a), perm), names) | ||
end | ||
|
||
# Unwrapping names and permuting | ||
# nameless(a, (:j, :i)) | ||
# Could just call `unname`? | ||
## nameless(a::AbstractNamedDimsArray, names) = unname(align(a, names)) | ||
unname(a::AbstractNamedDimsArray, names) = unname(align(a, names)) | ||
|
||
# In `TensorAlgebra` this this `fuse` and `unfuse`, | ||
# in `NDTensors`/`ITensors` this is `combine` and `uncombine`. | ||
# t = split(g, :n => (j=4, k=5)) | ||
# join(t, (:i, :k) => :χ) | ||
|
||
# TensorAlgebra | ||
# contract, fusedims, unfusedims, qr, eigen, svd, add, etc. | ||
# Some of these can simply wrap `TensorAlgebra.jl` functions. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
abstract type AbstractNamedInt{Value,Name} <: Integer end | ||
|
||
# Interface | ||
unname(i::AbstractNamedInt) = error("Not implemented") | ||
name(i::AbstractNamedInt) = error("Not implemented") | ||
|
||
# Derived | ||
unname(::Type{<:AbstractNamedInt{Value}}) where {Value} = Value | ||
|
||
# Integer interface | ||
# TODO: Should this make a random name, or require defining a way | ||
# to combine names? | ||
Base.:*(i1::AbstractNamedInt, i2::AbstractNamedInt) = unname(i1) * unname(i2) | ||
Base.:-(i::AbstractNamedInt) = typeof(i)(-unname(i), name(i)) | ||
|
||
# TODO: Define for `NamedInt`, `NamedUnitRange` fallback? | ||
# Base.OneTo(stop::AbstractNamedInt) = namedoneto(stop) | ||
## nameduniterange_type(::Type{<:AbstractNamedInt}) = error("Not implemented") | ||
|
||
# TODO: Use conversion from `AbstractNamedInt` to `AbstractNamedUnitRange` | ||
# instead of general `named`. | ||
# Base.OneTo(stop::AbstractNamedInt) = namedoneto(stop) | ||
Base.OneTo(stop::AbstractNamedInt) = named(Base.OneTo(unname(stop)), name(stop)) | ||
|
||
# TODO: Is this needed? | ||
# Include the name as well? | ||
Base.:<(i1::AbstractNamedInt, i2::AbstractNamedInt) = unname(i1) < unname(i2) | ||
## Base.zero(type::Type{<:AbstractNamedInt}) = zero(unname(type)) | ||
|
||
function Base.promote_rule(type1::Type{<:AbstractNamedInt}, type2::Type{<:Integer}) | ||
return promote_type(unname(type1), type2) | ||
end | ||
(type::Type{<:Integer})(i::AbstractNamedInt) = type(unname(i)) | ||
# TODO: Use conversion from `AbstractNamedInt` to `AbstractNamedUnitRange` | ||
# instead of general `named`. | ||
function Base.oftype(i1::AbstractNamedInt, i2::Integer) | ||
return named(convert(typeof(unname(i1)), i2), name(i1)) | ||
end | ||
|
||
# Traits | ||
isnamed(::AbstractNamedInt) = true |
14 changes: 14 additions & 0 deletions
14
NDTensors/src/NamedDimsArrays/src/abstractnamedunitrange.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
abstract type AbstractNamedUnitRange{T,Value<:AbstractUnitRange{T},Name} <: | ||
AbstractUnitRange{T} end | ||
|
||
# Required interface | ||
unname(::AbstractNamedUnitRange) = error("Not implemented") | ||
name(::AbstractNamedUnitRange) = error("Not implemented") | ||
|
||
# Traits | ||
isnamed(::AbstractNamedUnitRange) = true | ||
|
||
# Unit range | ||
Base.first(i::AbstractNamedUnitRange) = first(unname(i)) | ||
Base.last(i::AbstractNamedUnitRange) = last(unname(i)) | ||
Base.length(i::AbstractNamedUnitRange) = named(length(unname(i)), name(i)) |
Oops, something went wrong.