Skip to content

Commit

Permalink
feat: add norm method
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Jan 25, 2024
1 parent ef8d532 commit c94237f
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 5 deletions.
171 changes: 166 additions & 5 deletions lib/numo/tiny_linalg.rb
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,172 @@ def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false)
[vals, vecs]
end

# Computes the matrix or vector norm.
#
# | ord | matrix norm | vector norm |
# | ----- | ---------------------- | --------------------------- |
# | nil | Frobenius norm | 2-norm |
# | 'fro' | Frobenius norm | - |
# | 'nuc' | nuclear norm | - |
# | 'inf' | x.abs.sum(axis:-1).max | x.abs.max |
# | 0 | - | (x.ne 0).sum |
# | 1 | x.abs.sum(axis:-2).max | same as below |
# | 2 | 2-norm (max sing_vals) | same as below |
# | other | - | (x.abs**ord).sum**(1.0/ord) |
#
# @example
# require 'numo/tiny_linalg'
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
# # matrix norm
# x = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]
# pp Numo::Linalg.norm(x)
# # => 10
#
# # vector norm
# x = Numo::DFloat[3, -4]
# pp Numo::Linalg.norm(x)
# # => 5
#
# @param a [Numo::NArray] The matrix or vector (>= 1-dimensinal NArray)
# @param ord [String/Numeric] The order of the norm.
# @param axis [Integer/Array] The applied axes.
# @param keepdims [Bool] The flag indicating whether to leave the normed axes in the result as dimensions with size one.
# @return [Numo::NArray/Numeric] The norm of the matrix or vectors.
def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)

return 0.0 if a.empty?

# for compatibility with Numo::Linalg.norm
if ord.is_a?(String)
if ord == 'inf'
ord = Float::INFINITY
elsif ord == '-inf'
ord = -Float::INFINITY
end
end

if axis.nil?
norm = case a.ndim
when 1
Numo::TinyLinalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
when 2
if ord.nil? || ord == 'fro'
Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
elsif ord.is_a?(Numeric)
if ord == 1
Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
elsif !ord.infinite?.nil? && ord.infinite?.positive?
Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
end
end
else
if ord.nil?
b = a.flatten.dup
Numo::TinyLinalg::Blas.send(:"#{blas_char(b)}nrm2", b)
end
end
unless norm.nil?
norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
return norm
end
end

if axis.nil?
axis = Array.new(a.ndim) { |d| d }
else
case axis
when Integer
axis = [axis]
when Array, Numo::NArray
axis = axis.flatten.to_a
else
raise ArgumentError, "invalid axis: #{axis}"
end
end

raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless axis.size == 1 || axis.size == 2
raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }

if axis.size == 1
ord ||= 2
raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)

ord_inf = ord.infinite?
if ord_inf.nil?
case ord
when 0
a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
when 1
a.abs.sum(axis: axis, keepdims: keepdims)
else
(a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
end
elsif ord_inf.positive?
a.abs.max(axis: axis, keepdims: keepdims)
else
a.abs.min(axis: axis, keepdims: keepdims)
end
else
ord ||= 'fro'
raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)
raise ArgumentError, "invalid axis: #{axis}" if axis.uniq.size == 1

r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }

norm = if ord.is_a?(String)
raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)

if ord == 'fro'
Numo::NMath.sqrt((a.abs**2).sum(axis: axis))
else
b = a.transpose(c_axis, r_axis).dup
gesvd = :"#{blas_char(b)}gesvd"
s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
s.sum(axis: -1)
end
else
ord_inf = ord.infinite?
if ord_inf.nil?
case ord
when -2
b = a.transpose(c_axis, r_axis).dup
gesvd = :"#{blas_char(b)}gesvd"
s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
s.min(axis: -1)
when -1
c_axis -= 1 if c_axis > r_axis
a.abs.sum(axis: r_axis).min(axis: c_axis)
when 1
c_axis -= 1 if c_axis > r_axis
a.abs.sum(axis: r_axis).max(axis: c_axis)
when 2
b = a.transpose(c_axis, r_axis).dup
gesvd = :"#{blas_char(b)}gesvd"
s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
s.max(axis: -1)
else
raise ArgumentError, "invalid ord: #{ord}"
end
else
r_axis -= 1 if r_axis > c_axis
if ord_inf.positive?
a.abs.sum(axis: c_axis).max(axis: r_axis)
else
a.abs.sum(axis: c_axis).min(axis: r_axis)
end
end
end
if keepdims
norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
norm = norm.reshape(*([1] * a.ndim))
end

norm
end
end

# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
#
# @example
Expand Down Expand Up @@ -575,11 +741,6 @@ def eigvalsh(*args)
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def norm(*args)
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cond(*args)
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
Expand Down
55 changes: 55 additions & 0 deletions test/test_tiny_linalg.rb
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,61 @@ def test_eigh
assert((e - 1).abs.max < 1e-7)
end

def test_norm
# empty array
assert_equal(0, Numo::TinyLinalg.norm([]))
assert_equal(0, Numo::TinyLinalg.norm(Numo::DFloat[]))

# vector
a = Numo::DFloat[3, -4]
b = Numo::DFloat[1, 0, 2, 0, 3]

assert_equal(5, Numo::TinyLinalg.norm(a))
assert_equal(5, Numo::TinyLinalg.norm(a, 2))
assert_equal(7, Numo::TinyLinalg.norm(a, 1))
assert_equal(3, Numo::TinyLinalg.norm(b, 0))
assert_in_delta(2.4, Numo::TinyLinalg.norm(a, -2))
assert_equal(4, Numo::TinyLinalg.norm(a, Float::INFINITY))
assert_equal(3, Numo::TinyLinalg.norm(a, -Float::INFINITY))
assert_equal(4, Numo::TinyLinalg.norm(a, 'inf'))
assert_equal(3, Numo::TinyLinalg.norm(a, '-inf'))
assert_equal(Numo::DFloat[5], Numo::TinyLinalg.norm(a, keepdims: true))
assert_equal(Numo::DFloat[5], Numo::TinyLinalg.norm(a, 2, keepdims: true))
assert_equal(5, Numo::TinyLinalg.norm(a, axis: 0))
assert_match(/axis is out of range/, assert_raises(ArgumentError) { Numo::TinyLinalg.norm(a, axis: 1) }.message)
assert_match(/invalid axis/, assert_raises(ArgumentError) { Numo::TinyLinalg.norm(a, axis: '1') }.message)

# matrix
a = Numo::DFloat[[1, 2, -3, 1], [-4, 1, 8, 2]]

assert_equal(10, Numo::TinyLinalg.norm(a))
assert_equal(10, Numo::TinyLinalg.norm(a, 'fro'))
assert((Numo::TinyLinalg.norm(a, 'nuc') - 12.3643).abs < 1e-4)
assert((Numo::TinyLinalg.norm(a, 2) - 9.6144).abs < 1e-4)
assert_equal(11, Numo::TinyLinalg.norm(a, 1))
assert_match(/invalid ord/, assert_raises(ArgumentError) { Numo::TinyLinalg.norm(a, 0) }.message)
assert_equal(3, Numo::TinyLinalg.norm(a, -1))
assert((Numo::TinyLinalg.norm(a, -2) - 2.7498).abs < 1e-4)
assert_equal(15, Numo::TinyLinalg.norm(a, Float::INFINITY))
assert_equal(7, Numo::TinyLinalg.norm(a, -Float::INFINITY))
assert_equal(15, Numo::TinyLinalg.norm(a, 'inf'))
assert_equal(7, Numo::TinyLinalg.norm(a, '-inf'))
assert_equal(Numo::DFloat[[10]], Numo::TinyLinalg.norm(a, keepdims: true))
assert_equal(Numo::DFloat[5, 3, 11, 3], Numo::TinyLinalg.norm(a, 1, axis: 0))
assert_equal(Numo::DFloat[7, 15], Numo::TinyLinalg.norm(a, 1, axis: 1))
assert_equal(Numo::DFloat[[5, 3, 11, 3]], Numo::TinyLinalg.norm(a, 1, axis: 0, keepdims: true))
assert_equal(Numo::DFloat[[7], [15]], Numo::TinyLinalg.norm(a, 1, axis: 1, keepdims: true))
assert_equal(10, Numo::TinyLinalg.norm(a, 'fro', axis: [0, 1]))
assert_equal(11, Numo::TinyLinalg.norm(a, 1, axis: [0, 1]))
assert_equal(Numo::DFloat[[15]], Numo::TinyLinalg.norm(a, Float::INFINITY, axis: [0, 1], keepdims: true))

# tensor
a = Numo::DFloat[[[2, 3, 1], [1, 2, 4]], [[2, 2, 3], [3, 2, 4]]]

assert_equal(9, Numo::TinyLinalg.norm(a))
assert_equal(Numo::DFloat[[[9]]], Numo::TinyLinalg.norm(a, keepdims: true))
end

def test_cholesky
a = Numo::DFloat.new(3, 3).rand - 0.5
b = a.transpose.dot(a)
Expand Down

0 comments on commit c94237f

Please sign in to comment.