Skip to content

Commit

Permalink
Broadcast refactor (#3220)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvd101x authored Jun 19, 2024
1 parent 0602198 commit 5785cb9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 79 deletions.
108 changes: 38 additions & 70 deletions src/type/matrix/utils/broadcast.js
Original file line number Diff line number Diff line change
@@ -1,72 +1,40 @@
import { checkBroadcastingRules } from '../../../utils/array.js'
import { factory } from '../../../utils/factory.js'

const name = 'broadcast'

const dependancies = ['concat']

export const createBroadcast = /* #__PURE__ */ factory(
name, dependancies,
({ concat }) => {
/**
* Broadcasts two matrices, and return both in an array
* It checks if it's possible with broadcasting rules
*
* @param {Matrix} A First Matrix
* @param {Matrix} B Second Matrix
*
* @return {Matrix[]} [ broadcastedA, broadcastedB ]
*/
return function (A, B) {
const N = Math.max(A._size.length, B._size.length) // max number of dims
if (A._size.length === B._size.length) {
if (A._size.every((dim, i) => dim === B._size[i])) {
// If matrices have the same size return them
return [A, B]
}
}

const sizeA = _padLeft(A._size, N, 0) // pad to the left to align dimensions to the right
const sizeB = _padLeft(B._size, N, 0) // pad to the left to align dimensions to the right

// calculate the max dimensions
const sizeMax = []

for (let dim = 0; dim < N; dim++) {
sizeMax[dim] = Math.max(sizeA[dim], sizeB[dim])
}

// check if the broadcasting rules applyes for both matrices
checkBroadcastingRules(sizeA, sizeMax)
checkBroadcastingRules(sizeB, sizeMax)

// reshape A or B if needed to make them ready for concat
let AA = A.clone()
let BB = B.clone()
if (AA._size.length < N) {
AA.reshape(_padLeft(AA._size, N, 1))
} else if (BB._size.length < N) {
BB.reshape(_padLeft(BB._size, N, 1))
}

// stretches the matrices on each dimension to make them the same size
for (let dim = 0; dim < N; dim++) {
if (AA._size[dim] < sizeMax[dim]) { AA = _stretch(AA, sizeMax[dim], dim) }
if (BB._size[dim] < sizeMax[dim]) { BB = _stretch(BB, sizeMax[dim], dim) }
}

// return the array with the two broadcasted matrices
return [AA, BB]
}

function _padLeft (shape, N, filler) {
// pads an array of dimensions with numbers to the left, unitl the number of dimensions is N
return [...Array(N - shape.length).fill(filler), ...shape]
}
import { broadcastSizes, broadcastTo } from '../../../utils/array.js'
import { deepStrictEqual } from '../../../utils/object.js'

/**
* Broadcasts two matrices, and return both in an array
* It checks if it's possible with broadcasting rules
*
* @param {Matrix} A First Matrix
* @param {Matrix} B Second Matrix
*
* @return {Matrix[]} [ broadcastedA, broadcastedB ]
*/

export function broadcast (A, B) {
if (deepStrictEqual(A.size(), B.size())) {
// If matrices have the same size return them
return [A, B]
}

function _stretch (arrayToStretch, sizeToStretch, dimToStretch) {
// stretches a matrix up to a certain size in a certain dimension
return concat(...Array(sizeToStretch).fill(arrayToStretch), dimToStretch)
}
// calculate the broadcasted sizes
const newSize = broadcastSizes(A.size(), B.size())

// return the array with the two broadcasted matrices
return [A, B].map(M => _broadcastTo(M, newSize))
}

/**
* Broadcasts a matrix to the given size.
*
* @param {Matrix} M - The matrix to be broadcasted.
* @param {number[]} size - The desired size of the broadcasted matrix.
* @returns {Matrix} The broadcasted matrix.
* @throws {Error} If the size parameter is not an array of numbers.
*/
function _broadcastTo (M, size) {
if (deepStrictEqual(M.size(), size)) {
return M
}
)
return M.create(broadcastTo(M.valueOf(), size), M.datatype())
}
7 changes: 3 additions & 4 deletions src/type/matrix/utils/matrixAlgorithmSuite.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@ import { factory } from '../../../utils/factory.js'
import { extend } from '../../../utils/object.js'
import { createMatAlgo13xDD } from './matAlgo13xDD.js'
import { createMatAlgo14xDs } from './matAlgo14xDs.js'
import { createBroadcast } from './broadcast.js'
import { broadcast } from './broadcast.js'

const name = 'matrixAlgorithmSuite'
const dependencies = ['typed', 'matrix', 'concat']
const dependencies = ['typed', 'matrix']

export const createMatrixAlgorithmSuite = /* #__PURE__ */ factory(
name, dependencies, ({ typed, matrix, concat }) => {
name, dependencies, ({ typed, matrix }) => {
const matAlgo13xDD = createMatAlgo13xDD({ typed })
const matAlgo14xDs = createMatAlgo14xDs({ typed })
const broadcast = createBroadcast({ concat })

/**
* Return a signatures object with the usual boilerplate of
Expand Down
4 changes: 2 additions & 2 deletions test/unit-tests/function/arithmetic/dotDivide.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ describe('dotDivide', function () {
assert.deepStrictEqual(dotDivide(a, b), math.matrix([[1 / 5, Infinity], [0, 4 / 8]]))
})

it('should throw an error when dividing element-wise with differing size', function () {
assert.throws(function () { dotDivide(math.sparse([[1, 2], [3, 4]]), math.sparse([[1]])) })
it('should throw an error when dividing element-wise with differing size is not broadcastable', function () {
assert.throws(function () { dotDivide(math.sparse([[1, 2], [3, 4]]), math.sparse([1, 2, 3])) })
})
})

Expand Down
4 changes: 1 addition & 3 deletions test/unit-tests/type/matrix/utils/broadcast.test.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import assert from 'assert'
import math from '../../../../../src/defaultInstance.js'
import { createBroadcast } from '../../../../../src/type/matrix/utils/broadcast.js'
const concat = math.concat
import { broadcast } from '../../../../../src/type/matrix/utils/broadcast.js'
const matrix = math.matrix
const broadcast = createBroadcast({ concat })

describe('broadcast', function () {
it('should return matrices as such if they are the same size', function () {
Expand Down

0 comments on commit 5785cb9

Please sign in to comment.