Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for DAGRUN and DAGRUN_RO #8

Merged
merged 11 commits into from
Jun 8, 2020
71 changes: 69 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,73 @@ Example of AI.SCRIPTSET and AI.SCRIPTRUN
})();
```

Example of AI.DAGRUN enqueuing multiple SCRIPTRUN and MODELRUN commands

A common pattern is enqueuing multiple SCRIPTRUN and MODELRUN commands within a DAG. The following example uses ResNet-50,to classify images into 1000 object categories.

Given that our input tensor contains each color represented as a 8-bit integer and that neural networks usually work with floating-point tensors as their input we need to cast a tensor to floating-point and normalize the values of the pixels - for that we will use `pre_process_4ch` function.

To optimize the classification process we can use a post process script to return only the category position with the maximum classification - for that we will use `post_process` script.

Using the DAG capabilities we've removed the necessity of storing the intermediate tensors in the keyspace. You can even run the entire process without storing the output tensor, as follows:


```javascript
var redis = require('redis');
var redisai = require('redisai-js');
var fs = require("fs");

(async () => {
const nativeClient = redis.createClient();
const aiclient = new redisai.Client(nativeClient);
const scriptFileStr = fs.readFileSync('./tests/test_data/imagenet/data_processing_script.txt').toString();
const jsonLabels = fs.readFileSync('./tests/test_data/imagenet/imagenet_class_index.json');
const labels = JSON.parse(jsonLabels);

const dataProcessingScript = new redisai.Script('CPU', scriptFileStr);
const resultScriptSet = await aiclient.scriptset('data_processing_script', dataProcessingScript);
// AI.SCRIPTSET result: OK
console.log(`AI.SCRIPTSET result: ${resultScriptSet}`)

const modelBlob = fs.readFileSync('./tests/test_data/imagenet/resnet50.pb');
const imagenetModel = new redisai.Model(Backend.TF, 'CPU', ['images'], ['output'], modelBlob);
const resultModelSet = await aiclient.modelset('imagenet_model', imagenetModel);

// AI.MODELSET result: OK
console.log(`AI.MODELSET result: ${resultModelSet}`)

const inputImage = await Jimp.read('./tests/test_data/imagenet/cat.jpg');
const imageWidth = 224;
const imageHeight = 224;
const image = inputImage.cover(imageWidth, imageHeight);
const tensor = new redisai.Tensor(Dtype.uint8, [imageWidth, imageHeight, 4], Buffer.from(image.bitmap.data));

///
// Prepare the DAG enqueuing multiple SCRIPTRUN and MODELRUN commands
const dag = new redisai.Dag();

dag.tensorset('tensor-image', tensor);
dag.scriptrun('data_processing_script', 'pre_process_4ch', ['tensor-image'], ['temp_key1']);
dag.modelrun('imagenet_model', ['temp_key1'], ['temp_key2']);
dag.scriptrun('data_processing_script', 'post_process', ['temp_key2'], ['classification']);
dag.tensorget('classification');

// Send the AI.DAGRUN command to RedisAI server
const resultDagRun = await aiclient.dagrun_ro(null, dag);

// The 5th element of the reply will be the `classification` tensor
const classTensor = resultDagRun[4];

// Print the category in the position with the max classification
const idx = classTensor.data[0];

// 281 [ 'n02123045', 'tabby' ]
console.log(idx, labels[idx.toString()]);

await aiclient.end();
})();
```

### Further examples

The [RedisAI examples repo](https://github.com/RedisAI/redisai-examples) shows more advanced examples
Expand All @@ -147,8 +214,8 @@ AI.SCRIPTGET | scriptget
AI.SCRIPTDEL | scriptdel
AI.SCRIPTRUN | scriptrun
AI._SCRIPTSCAN | N/A
AI.DAGRUN | N/A
AI.DAGRUN_RO | N/A
AI.DAGRUN | dagrun
AI.DAGRUN_RO | dagrun_ro
AI.INFO | info and infoResetStat (for resetting stats)
AI.CONFIG * | configLoadBackend and configBackendsPath

Expand Down
81 changes: 47 additions & 34 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ import { Model } from './model';
import * as util from 'util';
import { Script } from './script';
import { Stats } from './stats';
import { Dag } from './dag';

export class Client {
private _sendCommand: any;
private readonly _sendCommand: any;

constructor(client: RedisClient) {
this._client = client;
this._sendCommand = util.promisify(this._client.send_command).bind(this._client);
}

private _client: RedisClient;
private readonly _client: RedisClient;

get client(): RedisClient {
return this._client;
Expand All @@ -23,23 +24,13 @@ export class Client {
this._client.end(flush);
}

public tensorset(keName: string, t: Tensor): Promise<any> {
const args: any[] = [keName, t.dtype];
t.shape.forEach((value) => args.push(value.toString()));
if (t.data != null) {
if (t.data instanceof Buffer) {
args.push('BLOB');
args.push(t.data);
} else {
args.push('VALUES');
t.data.forEach((value) => args.push(value.toString()));
}
}
public tensorset(keyName: string, t: Tensor): Promise<any> {
const args: any[] = t.tensorSetFlatArgs(keyName);
return this._sendCommand('ai.tensorset', args);
}

public tensorget(keName: string): Promise<any> {
const args: any[] = [keName, 'META', 'VALUES'];
public tensorget(keyName: string): Promise<any> {
const args: any[] = Tensor.tensorGetFlatArgs(keyName);
return this._sendCommand('ai.tensorget', args)
.then((reply: any[]) => {
return Tensor.NewTensorFromTensorGetReply(reply);
Expand All @@ -55,10 +46,7 @@ export class Client {
}

public modelrun(modelName: string, inputs: string[], outputs: string[]): Promise<any> {
const args: any[] = [modelName, 'INPUTS'];
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
const args: any[] = Model.modelRunFlatArgs(modelName, inputs, outputs);
return this._sendCommand('ai.modelrun', args);
}

Expand All @@ -68,7 +56,7 @@ export class Client {
}

public modelget(modelName: string): Promise<any> {
const args: any[] = [modelName, 'META', 'BLOB'];
const args: any[] = Model.modelGetFlatArgs(modelName);
return this._sendCommand('ai.modelget', args)
.then((reply: any[]) => {
return Model.NewModelFromModelGetReply(reply);
Expand All @@ -78,22 +66,13 @@ export class Client {
});
}

public scriptset(keName: string, s: Script): Promise<any> {
const args: any[] = [keName, s.device];
if (s.tag !== undefined) {
args.push('TAG');
args.push(s.tag);
}
args.push('SOURCE');
args.push(s.script);
public scriptset(keyName: string, s: Script): Promise<any> {
const args: any[] = s.scriptSetFlatArgs(keyName);
return this._sendCommand('ai.scriptset', args);
}

public scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]): Promise<any> {
const args: any[] = [scriptName, functionName, 'INPUTS'];
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
const args: any[] = Script.scriptRunFlatArgs(scriptName, functionName, inputs, outputs);
return this._sendCommand('ai.scriptrun', args);
}

Expand All @@ -103,7 +82,7 @@ export class Client {
}

public scriptget(scriptName: string): Promise<any> {
const args: any[] = [scriptName, 'META', 'SOURCE'];
const args: any[] = Script.scriptGetFlatArgs(scriptName);
return this._sendCommand('ai.scriptget', args)
.then((reply: any[]) => {
return Script.NewScriptFromScriptGetReply(reply);
Expand Down Expand Up @@ -137,6 +116,40 @@ export class Client {
});
}

/**
* specifies a direct acyclic graph of operations to run within RedisAI
*
* @param loadKeys
* @param persistKeys
* @param dag
*/
public dagrun(loadKeys: string[] | null, persistKeys: string[] | null, dag: Dag): Promise<any> {
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
const args: any[] = dag.dagRunFlatArgs(loadKeys, persistKeys);
return this._sendCommand('ai.dagrun', args)
.then((reply: any[]) => {
return dag.ProcessDagReply(reply);
})
.catch((error: any) => {
throw error;
});
}

/**
* specifies a Read Only direct acyclic graph of operations to run within RedisAI
*
* @param loadKeys
* @param dag
*/
public dagrun_ro(loadKeys: string[] | null, dag: Dag): Promise<any> {
const args: any[] = dag.dagRunFlatArgs(loadKeys, null);
return this._sendCommand('ai.dagrun_ro', args)
.then((reply: any[]) => {
return dag.ProcessDagReply(reply);
})
.catch((error: any) => {
throw error;
});
}
/**
* Loads the DL/ML backend specified by the backend identifier from path.
*
Expand Down
88 changes: 88 additions & 0 deletions src/dag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import { Model } from './model';
import { Script } from './script';
import { Tensor } from './tensor';

export interface DagCommandInterface {
tensorset(keyName: string, t: Tensor): DagCommandInterface;

tensorget(keyName: string): DagCommandInterface;

tensorget(keyName: string): DagCommandInterface;

modelrun(modelName: string, inputs: string[], outputs: string[]): DagCommandInterface;

scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]): DagCommandInterface;
}

/**
* Direct mapping to RedisAI DAGs
*/
export class Dag implements DagCommandInterface {
private _commands: any[][];
private readonly _tensorgetflag: boolean[];

constructor() {
this._commands = [];
this._tensorgetflag = [];
}

public tensorset(keyName: string, t: Tensor): Dag {
const args: any[] = ['AI.TENSORSET'];
t.tensorSetFlatArgs(keyName).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(false);
return this;
}

public tensorget(keyName: string): Dag {
const args: any[] = ['AI.TENSORGET'];
Tensor.tensorGetFlatArgs(keyName).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(true);
return this;
}

public modelrun(modelName: string, inputs: string[], outputs: string[]): Dag {
const args: any[] = ['AI.MODELRUN'];
Model.modelRunFlatArgs(modelName, inputs, outputs).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(false);
return this;
}

public scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]): Dag {
const args: any[] = ['AI.SCRIPTRUN'];
Script.scriptRunFlatArgs(scriptName, functionName, inputs, outputs).forEach((arg) => args.push(arg));
this._commands.push(args);
this._tensorgetflag.push(false);
return this;
}

public dagRunFlatArgs(loadKeys: string[] | null, persistKeys: string[] | null): string[] {
const args: any[] = [];
if (loadKeys != null && loadKeys.length > 0) {
args.push('LOAD');
args.push(loadKeys.length);
loadKeys.forEach((value) => args.push(value));
}
if (persistKeys != null && persistKeys.length > 0) {
args.push('PERSIST');
args.push(persistKeys.length);
persistKeys.forEach((value) => args.push(value));
}
this._commands.forEach((value) => {
args.push('|>');
value.forEach((arg) => args.push(arg));
});
return args;
}

public ProcessDagReply(reply: any[]): any[] {
for (let i = 0; i < reply.length; i++) {
if (this._tensorgetflag[i] === true) {
reply[i] = Tensor.NewTensorFromTensorGetReply(reply[i]);
}
}
return reply;
}
}
3 changes: 2 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import { Backend, BackendMap } from './backend';
import { Tensor } from './tensor';
import { Model } from './model';
import { Script } from './script';
import { Dag } from './dag';
import { Client } from './client';
import { Stats } from './stats';
import { Helpers } from './helpers';

export { DTypeMap, Dtype, BackendMap, Backend, Model, Script, Tensor, Client, Stats, Helpers };
export { DTypeMap, Dtype, BackendMap, Backend, Model, Script, Tensor, Dag, Client, Stats, Helpers };
14 changes: 13 additions & 1 deletion src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,19 @@ export class Model {
return model;
}

modelSetFlatArgs(keyName: string) {
static modelGetFlatArgs(keyName: string): string[] {
return [keyName, 'META', 'BLOB'];
}

static modelRunFlatArgs(modelName: string, inputs: string[], outputs: string[]): string[] {
const args: string[] = [modelName, 'INPUTS'];
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
return args;
}

modelSetFlatArgs(keyName: string): any[] {
const args: any[] = [keyName, this.backend.toString(), this.device];
filipecosta90 marked this conversation as resolved.
Show resolved Hide resolved
if (this.tag !== undefined) {
args.push('TAG');
Expand Down
24 changes: 24 additions & 0 deletions src/script.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,28 @@ export class Script {
}
return script;
}

scriptSetFlatArgs(keyName: string): string[] {
const args: string[] = [keyName, this.device];
if (this.tag !== undefined) {
args.push('TAG');
args.push(this.tag);
}
args.push('SOURCE');
args.push(this.script);
return args;
}

static scriptRunFlatArgs(scriptName: string, functionName: string, inputs: string[], outputs: string[]): string[] {
const args: string[] = [scriptName, functionName, 'INPUTS'];
inputs.forEach((value) => args.push(value));
args.push('OUTPUTS');
outputs.forEach((value) => args.push(value));
return args;
}

static scriptGetFlatArgs(scriptName: string): string[] {
const args: string[] = [scriptName, 'META', 'SOURCE'];
return args;
}
}
Loading