Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Add math.concat[1-4]D and math.slice[1-4]D #151

Merged
merged 7 commits into from
Sep 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ index.ts
npm-debug.log
.DS_Store
dist/
.idea/
.idea/
6 changes: 2 additions & 4 deletions src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,10 @@ export class Concat3DNode extends Node {
public axis: number) {
super(
graph, 'Concat3D', {x1, x2},
new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis)));
new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis)));
}
validate() {
concat_util.assertConcatShapesMatch(
this.x1.shape, this.x2.shape, 3, this.axis);
concat_util.assertParams(this.x1.shape, this.x2.shape, this.axis);
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/graph/ops/concat3d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ export class Concat3D extends Operation {
private x1Tensor: Tensor, private x2Tensor: Tensor, private axis: number,
private yTensor: Tensor) {
super();
concat_util.assertConcatShapesMatch(
x1Tensor.shape, x2Tensor.shape, 3, axis);
concat_util.assertParams(x1Tensor.shape, x2Tensor.shape, axis);
}

feedForward(math: NDArrayMath, inferenceArrays: TensorArrayMap) {
Expand Down
9 changes: 3 additions & 6 deletions src/graph/ops/concat3d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ describe('concat3d operation', () => {

x1Tensor = new Tensor(x1.shape);
x2Tensor = new Tensor(x2.shape);
yTensor = new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis));
yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis));

tensorArrayMap.set(x1Tensor, x1);
tensorArrayMap.set(x2Tensor, x2);
Expand All @@ -75,8 +74,7 @@ describe('concat3d operation', () => {

x1Tensor = new Tensor(x1.shape);
x2Tensor = new Tensor(x2.shape);
yTensor = new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis));
yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis));

tensorArrayMap.set(x1Tensor, x1);
tensorArrayMap.set(x2Tensor, x2);
Expand All @@ -99,8 +97,7 @@ describe('concat3d operation', () => {

x1Tensor = new Tensor(x1.shape);
x2Tensor = new Tensor(x2.shape);
yTensor = new Tensor(
concat_util.computeConcatOutputShape(x1.shape, x2.shape, axis));
yTensor = new Tensor(concat_util.computeOutShape(x1.shape, x2.shape, axis));

tensorArrayMap.set(x1Tensor, x1);
tensorArrayMap.set(x2Tensor, x2);
Expand Down
34 changes: 16 additions & 18 deletions src/math/concat_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,33 @@

import * as util from '../util';

export function assertConcatShapesMatch(
x1Shape: number[], x2Shape: number[], rank: number, axis: number,
errorMessagePrefix = '') {
export function assertParams(aShape: number[], bShape: number[], axis: number) {
const aRank = aShape.length;
const bRank = bShape.length;
util.assert(
x1Shape.length === rank,
errorMessagePrefix + `x1 shape should be of rank ${rank}.`);
util.assert(
x2Shape.length === rank,
errorMessagePrefix + `x2 shape should be of rank ${rank}.`);
aShape.length === bShape.length,
`Error in concat${aRank}D: rank of x1 (${aRank}) and x2 (${bRank}) ` +
`must be the same.`);

util.assert(
axis >= 0 && axis < rank, `axis must be between 0 and ${rank - 1}.`);
axis >= 0 && axis < aRank,
`Error in concat${aRank}D: axis must be ` +
`between 0 and ${aRank - 1}.`);

for (let i = 0; i < rank; i++) {
for (let i = 0; i < aRank; i++) {
util.assert(
(i === axis) || (x1Shape[i] === x2Shape[i]),
errorMessagePrefix +
`Shape (${x1Shape}) does not match (${x2Shape}) along ` +
`the non-concatenated axis ${i}.`);
(i === axis) || (aShape[i] === bShape[i]),
`Error in concat${aRank}D: Shape (${aShape}) does not match ` +
`(${bShape}) along the non-concatenated axis ${i}.`);
}
}

export function computeConcatOutputShape(
x1Shape: number[], x2Shape: number[],
axis: number): [number, number, number] {
export function computeOutShape(
x1Shape: number[], x2Shape: number[], axis: number): number[] {
util.assert(
x1Shape.length === x2Shape.length,
'x1 and x2 should have the same rank.');
const outputShape = x1Shape.slice();
outputShape[axis] += x2Shape[axis];
return outputShape as [number, number, number];
return outputShape;
}
15 changes: 8 additions & 7 deletions src/math/concat_util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,39 @@ import * as concat_util from './concat_util';
describe('concat_util.assertConcatShapesMatch rank=3D', () => {
it('Non-3D tensor x1', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([1], [1, 2, 3], 3, 1);
concat_util.assertParams([1], [1, 2, 3], 1);
};

expect(assertFn).toThrow();
});

it('Non-3D tensor x2', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([1, 2, 3], [2, 3], 3, 1);
concat_util.assertParams([1, 2, 3], [2, 3], 1);
};

expect(assertFn).toThrow();
});

it('axis out of bound', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([1, 2, 3], [1, 2, 3], 3, 4);
concat_util.assertParams([1, 2, 3], [1, 2, 3], 4);
};

expect(assertFn).toThrow();
});

it('non-axis shape mismatch', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([2, 3, 3], [2, 2, 4], 3, 2);
concat_util.assertParams([2, 3, 3], [2, 2, 4], 2);
};

expect(assertFn).toThrow();
});

it('shapes line up', () => {
const assertFn = () => {
concat_util.assertConcatShapesMatch([2, 3, 3], [2, 3, 4], 3, 2);
concat_util.assertParams([2, 3, 3], [2, 3, 4], 2);
};

expect(assertFn).not.toThrow();
Expand All @@ -61,7 +61,8 @@ describe('concat_util.assertConcatShapesMatch rank=3D', () => {

describe('concat_util.computeConcatOutputShape', () => {
it('compute output shape, axis=0', () => {
expect(concat_util.computeConcatOutputShape([2, 2, 3], [1, 2, 3], 0))
.toEqual([3, 2, 3]);
expect(concat_util.computeOutShape([2, 2, 3], [1, 2, 3], 0)).toEqual([
3, 2, 3
]);
});
});
Loading