Skip to content

Commit

Permalink
feat: implement norm method
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Jan 22, 2024
1 parent f001cbe commit 07b33c0
Showing 1 changed file with 126 additions and 5 deletions.
131 changes: 126 additions & 5 deletions lib/numo/tiny_linalg.rb
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,132 @@ def eigh(a, b = nil, vals_only: false, vals_range: nil, uplo: 'U', turbo: false)
[vals, vecs]
end

# @!visibility private
# rubocop: disable Metrics/BlockNesting
def norm(a, ord = nil, axis: nil, keepdims: false) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
a = Numo::NArray.asarray(a) unless a.is_a?(Numo::NArray)

if axis.nil?
norm = case a.ndim
when 1
Numo::TinyLinalg::Blas.send(:"#{blas_char(a)}nrm2", a) if ord.nil? || ord == 2
when 2
if ord.nil? || ord == 'fro'
Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'F')
elsif ord.is_a?(Numeric)
if ord == 1
Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: '1')
elsif !ord.infinite?.nil? && ord.infinite?.positive?
Numo::TinyLinalg::Lapack.send(:"#{blas_char(a)}lange", a, norm: 'I')
end
end
else
if ord.nil?
b = a.flatten.dup
Numo::TinyLinalg::Blas.send(:"#{blas_char(b)}nrm2", b)
end
end
unless norm.nil?
norm = Numo::NArray.asarray(norm).reshape(*([1] * a.ndim)) if keepdims
return norm
end
end

if axis.nil?
axis = Array.new(a.ndim) { |d| d }
else
case axis
when Integer
axis = [axis]
when Array, Numo::NArray
axis = axis.flatten.to_a
else
raise ArgumentError, "invalid axis: #{axis}"
end
end

raise ArgumentError, "the number of dimensions of axis is inappropriate for the norm: #{axis.size}" unless axis.size == 1 || axis.size == 2
raise ArgumentError, "axis is out of range: #{axis}" unless axis.all? { |ax| (-a.ndim...a.ndim).cover?(ax) }

if axis.size == 1
ord ||= 2
raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(Numeric)

ord_inf = ord.infinite?
if ord_inf.nil?
case ord
when 0
a.class.cast(a.ne(0)).sum(axis: axis, keepdims: keepdims)
when 1
a.abs.sum(axis: axis, keepdims: keepdims)
else
(a.abs**ord).sum(axis: axis, keepdims: keepdims)**1.fdiv(ord)
end
elsif ord_inf.positive?
a.abs.max(axis: axis, keepdims: keepdims)
else
a.abs.min(axis: axis, keepdims: keepdims)
end
else
ord ||= 'fro'
raise ArgumentError, "invalid ord: #{ord}" unless ord.is_a?(String) || ord.is_a?(Numeric)
raise ArgumentError, "invalid axis: #{axis}" if axis.uniq.size == 1

r_axis, c_axis = axis.map { |ax| ax.negative? ? ax + a.ndim : ax }

norm = if ord.is_a?(String)
raise ArgumentError, "invalid ord: #{ord}" unless %w[fro nuc].include?(ord)

if ord == 'fro'
Math.sqrt((a.abs**2).sum(axis: axis))
else
b = a.transpose(c_axis, r_axis).dup
gesvd = :"#{blas_char(b)}gesvd"
s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
s.sum(axis: -1)
end
else
ord_inf = ord.infinite?
if ord_inf.nil?
case ord
when -2
b = a.transpose(c_axis, r_axis).dup
gesvd = :"#{blas_char(b)}gesvd"
s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
s.min(axis: -1)
when -1
c_axis -= 1 if c_axis > r_axis
a.abs.sum(axis: r_axis).min(axis: c_axis)
when 1
c_axis -= 1 if c_axis > r_axis
a.abs.sum(axis: r_axis).max(axis: c_axis)
when 2
b = a.transpose(c_axis, r_axis).dup
gesvd = :"#{blas_char(b)}gesvd"
s, = Numo::TinyLinalg::Lapack.send(gesvd, b, jobu: 'N', jobvt: 'N')
s.max(axis: -1)
else
raise ArgumentError, "invalid ord: #{ord}"
end
else
r_axis -= 1 if r_axis > c_axis
if ord_inf.positive?
a.abs.sum(axis: r_axis).max(axis: c_axis)
else
a.abs.sum(axis: r_axis).min(axis: c_axis)
end
end
end
if keepdims
norm = Numo::NArray.asarray(norm) unless norm.is_a?(Numo::NArray)
norm = norm.reshape(*([1] * a.ndim))
end

norm
end
end
# rubocop: enable Metrics/BlockNesting

# Computes the Cholesky decomposition of a symmetric / Hermitian positive-definite matrix.
#
# @example
Expand Down Expand Up @@ -575,11 +701,6 @@ def eigvalsh(*args)
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def norm(*args)
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
end

# @!visibility private
def cond(*args)
raise NotImplementedError, "#{__method__} is not yet implemented in Numo::TinyLinalg"
Expand Down

0 comments on commit 07b33c0

Please sign in to comment.