Support for BATCHSIZE, MINBATCHSIZE, INPUTS and OUTPUTS on AI.MODELGET (#9)

* [add] Support for BATCHSIZE, MINBATCHSIZE, INPUTS and OUTPUTS on AI.MODELGET
filipecosta90 authored Jun 7, 2020
1 parent e4b0a6a commit ab65392
Showing 4 changed files with 176 additions and 45 deletions.
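
For orientation, the batching support added in this commit could be exercised roughly as below. This is a sketch only: the import path and key names are assumptions, while Client, Model, Backend, modelset and modelget are taken from the files changed in this diff.

// Sketch: import path is an assumption; client construction is elided.
import * as fs from 'fs';
import { Client, Model, Backend } from 'redisai-js';

async function batchingExample(aiclient: Client): Promise<void> {
  const blob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
  // BATCHSIZE 100 / MINBATCHSIZE 5 are now accepted by the Model constructor
  // and round-tripped through AI.MODELSET / AI.MODELGET.
  const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], blob, 100, 5);
  model.tag = 'v1';
  await aiclient.modelset('mymodel', model);
  const fetched: Model = await aiclient.modelget('mymodel');
  console.log(fetched.batchsize, fetched.minbatchsize, fetched.inputs, fetched.outputs);
}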
2 changes: 1 addition & 1 deletion README.md
@@ -150,7 +150,7 @@ AI._SCRIPTSCAN | N/A
AI.DAGRUN | N/A
AI.DAGRUN_RO | N/A
AI.INFO | info and infoResetStat (for resetting stats)
AI.CONFIG * | N/A
AI.CONFIG * | configLoadBackend and configBackendsPath
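
The AI.CONFIG row just above maps to two client helpers; a minimal usage sketch follows (the import path and filesystem paths are assumptions, mirroring the test at the end of this diff):

// Sketch: configBackendsPath and configLoadBackend are the wrappers the row
// above refers to; the import path and paths on disk are assumptions.
import { Client, Backend } from 'redisai-js';

async function configExample(aiclient: Client): Promise<void> {
  await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
  await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
}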


### Running tests
18 changes: 2 additions & 16 deletions src/client.ts
@@ -49,22 +49,8 @@ export class Client {
});
}

public modelset(keName: string, m: Model): Promise<any> {
const args: any[] = [keName, m.backend.toString(), m.device];
if (m.tag !== undefined) {
args.push('TAG');
args.push(m.tag.toString());
}
if (m.inputs.length > 0) {
args.push('INPUTS');
m.inputs.forEach((value) => args.push(value));
}
if (m.outputs.length > 0) {
args.push('OUTPUTS');
m.outputs.forEach((value) => args.push(value));
}
args.push('BLOB');
args.push(m.blob);
public modelset(keyName: string, m: Model): Promise<any> {
const args: any[] = m.modelSetFlatArgs(keyName);
return this._sendCommand('ai.modelset', args);
}

87 changes: 84 additions & 3 deletions src/model.ts
@@ -11,14 +11,32 @@ export class Model {
* @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow models)
* @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow models)
* @param blob - the Protobuf-serialized model
* @param batchsize - when provided with a batchsize greater than 0, the engine will batch incoming requests from multiple clients that use the model with input tensors of the same shape.
* @param minbatchsize - when provided with a minbatchsize greater than 0, the engine will postpone calls to AI.MODELRUN until the batch's size has reached minbatchsize
*/
constructor(backend: Backend, device: string, inputs: string[], outputs: string[], blob: Buffer | undefined) {
constructor(
backend: Backend,
device: string,
inputs: string[],
outputs: string[],
blob: Buffer | undefined,
batchsize?: number,
minbatchsize?: number,
) {
this._backend = backend;
this._device = device;
this._inputs = inputs;
this._outputs = outputs;
this._blob = blob;
this._tag = undefined;
this._batchsize = batchsize || 0;
if (this._batchsize < 0) {
this._batchsize = 0;
}
this._minbatchsize = minbatchsize || 0;
if (this._minbatchsize < 0) {
this._minbatchsize = 0;
}
}

// tag is an optional string for tagging the model such as a version number or any arbitrary identifier
@@ -86,14 +104,39 @@ export class Model {
this._blob = value;
}

private _batchsize: number;

get batchsize(): number {
return this._batchsize;
}

set batchsize(value: number) {
this._batchsize = value;
}

private _minbatchsize: number;

get minbatchsize(): number {
return this._minbatchsize;
}

set minbatchsize(value: number) {
this._minbatchsize = value;
}

static NewModelFromModelGetReply(reply: any[]) {
let backend = null;
let device = null;
let tag = null;
let blob = null;
let batchsize: number = 0;
let minbatchsize: number = 0;
const inputs: string[] = [];
const outputs: string[] = [];
for (let i = 0; i < reply.length; i += 2) {
const key = reply[i];
const obj = reply[i + 1];

switch (key.toString()) {
case 'backend':
backend = BackendMap[obj.toString()];
@@ -106,9 +149,20 @@
tag = obj.toString();
break;
case 'blob':
// blob = obj;
blob = Buffer.from(obj);
break;
case 'batchsize':
batchsize = parseInt(obj.toString(), 10);
break;
case 'minbatchsize':
minbatchsize = parseInt(obj.toString(), 10);
break;
case 'inputs':
obj.forEach((input) => inputs.push(input));
break;
case 'outputs':
obj.forEach((output) => outputs.push(output));
break;
}
}
if (backend == null || device == null || blob == null) {
@@ -126,10 +180,37 @@
'AI.MODELGET reply did not have all the elements needed to build the Model. Missing ' + missingArr.join(',') + '.',
);
}
const model = new Model(backend, device, [], [], blob);
const model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
if (tag !== null) {
model.tag = tag;
}
return model;
}

modelSetFlatArgs(keyName: string) {
const args: any[] = [keyName, this.backend.toString(), this.device];
if (this.tag !== undefined) {
args.push('TAG');
args.push(this.tag.toString());
}
if (this.batchsize > 0) {
args.push('BATCHSIZE');
args.push(this.batchsize);
if (this.minbatchsize > 0) {
args.push('MINBATCHSIZE');
args.push(this.minbatchsize);
}
}
if (this.inputs.length > 0) {
args.push('INPUTS');
this.inputs.forEach((value) => args.push(value));
}
if (this.outputs.length > 0) {
args.push('OUTPUTS');
this.outputs.forEach((value) => args.push(value));
}
args.push('BLOB');
args.push(this.blob);
return args;
}
}
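
To make the new Model surface concrete, here is a rough sketch of what the pieces above do. The flat-argument layout comment and the reply array are illustrative assumptions rather than captured server output, and the import path is an assumption.

import { Model, Backend } from 'redisai-js';   // import path is an assumption

const blob = Buffer.from('...protobuf bytes...');   // placeholder blob

// Negative batching values are clamped to 0 by the constructor.
const clamped = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], blob, -10, -1);
// clamped.batchsize === 0 && clamped.minbatchsize === 0

// With batching enabled, modelSetFlatArgs lays out the AI.MODELSET arguments:
const m = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], blob, 100, 5);
m.tag = 'v1';
const args = m.modelSetFlatArgs('mymodel');
// e.g. ['mymodel', <backend>, 'CPU', 'TAG', 'v1', 'BATCHSIZE', 100, 'MINBATCHSIZE', 5,
//       'INPUTS', 'a', 'b', 'OUTPUTS', 'c', 'BLOB', <blob>]

// NewModelFromModelGetReply walks a flat key/value reply; the array below only
// illustrates the expected shape of an AI.MODELGET response.
const reply: any[] = [
  'backend', 'TF', 'device', 'CPU', 'tag', 'v1',
  'batchsize', '100', 'minbatchsize', '5',
  'inputs', ['a', 'b'], 'outputs', ['c'], 'blob', blob,
];
const parsed = Model.NewModelFromModelGetReply(reply);
// parsed.batchsize === 100 && parsed.minbatchsize === 5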
114 changes: 89 additions & 25 deletions tests/test_client.ts
@@ -331,13 +331,77 @@ it(
const aiclient = new Client(nativeClient);

const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], modelBlob);
const inputs: string[] = ['a', 'b'];
const outputs: string[] = ['c'];
const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
model.tag = 'test_tag';
const resultModelSet = await aiclient.modelset('mymodel', model);
expect(resultModelSet).to.equal('OK');

const modelOut = await aiclient.modelget('mymodel');
const modelOut: Model = await aiclient.modelget('mymodel');
expect(modelOut.blob.toString()).to.equal(modelBlob.toString());
for (let index = 0; index < modelOut.outputs.length; index++) {
expect(modelOut.outputs[index]).to.equal(outputs[index]);
expect(modelOut.outputs[index]).to.equal(model.outputs[index]);
}
for (let index = 0; index < modelOut.inputs.length; index++) {
expect(modelOut.inputs[index]).to.equal(inputs[index]);
expect(modelOut.inputs[index]).to.equal(model.inputs[index]);
}
expect(modelOut.batchsize).to.equal(model.batchsize);
expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
aiclient.end(true);
}),
);

it(
'ai.modelget batching positive testing',
mochaAsync(async () => {
const nativeClient = createClient();
const aiclient = new Client(nativeClient);

const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
const inputs: string[] = ['a', 'b'];
const outputs: string[] = ['c'];
const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
model.tag = 'test_tag';
model.batchsize = 100;
model.minbatchsize = 5;
const resultModelSet = await aiclient.modelset('mymodel-batching', model);
expect(resultModelSet).to.equal('OK');
const modelOut: Model = await aiclient.modelget('mymodel-batching');
const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop', modelOut);
expect(resultModelSet2).to.equal('OK');
const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop');
expect(modelOut.batchsize).to.equal(model.batchsize);
expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
aiclient.end(true);
}),
);

it(
'ai.modelget batching via constructor positive testing',
mochaAsync(async () => {
const nativeClient = createClient();
const aiclient = new Client(nativeClient);

const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
const inputs: string[] = ['a', 'b'];
const outputs: string[] = ['c'];
const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 100, 5);
model.tag = 'test_tag';
const resultModelSet = await aiclient.modelset('mymodel-batching-t2', model);
expect(resultModelSet).to.equal('OK');
const modelOut: Model = await aiclient.modelget('mymodel-batching-t2');
const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop-t2', modelOut);
expect(resultModelSet2).to.equal('OK');
const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop-t2');
expect(modelOut.batchsize).to.equal(model.batchsize);
expect(modelOut.minbatchsize).to.equal(model.minbatchsize);

const model2 = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 1000);
expect(model2.batchsize).to.equal(1000);
expect(model2.minbatchsize).to.equal(0);
aiclient.end(true);
}),
);
@@ -624,26 +688,26 @@ it(
);

it(
'ai.config positive and negative testing',
mochaAsync(async () => {
const nativeClient = createClient();
const aiclient = new Client(nativeClient);
const result = await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
expect(result).to.equal('OK');
// negative test
try {
const loadReply = await aiclient.configLoadBackend(Backend.TF, 'notexist/redisai_tensorflow.so');
} catch (e) {
expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
}

try {
// may throw error if backend already loaded
const loadResult = await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
expect(loadResult).to.equal('OK');
} catch (e) {
expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
}
aiclient.end(true);
}),
);
