Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add support for Glm #1060

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ function getNormalizedConfig(config) {
break;
case 'gemma':
case 'gemma2':
case 'glm':
mapping['num_heads'] = 'num_key_value_heads';
mapping['num_layers'] = 'num_hidden_layers';
mapping['dim_kv'] = 'head_dim';
Expand Down
19 changes: 19 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -4037,6 +4037,23 @@ export class Gemma2Model extends Gemma2PreTrainedModel { }
export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// Glm models

/**
* The bare Glm Model outputting raw hidden-states without any specific head on top.
*/
export class GlmPreTrainedModel extends PreTrainedModel { }
/**
* The bare Glm Model outputting raw hidden-states without any specific head on top.
*/
export class GlmModel extends GlmPreTrainedModel { }

export class GlmForCausalLM extends GlmPreTrainedModel { }
//////////////////////////////////////////////////


//////////////////////////////////////////////////
export class OpenELMPreTrainedModel extends PreTrainedModel { }
export class OpenELMModel extends OpenELMPreTrainedModel { }
Expand Down Expand Up @@ -6765,6 +6782,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['cohere', ['CohereModel', CohereModel]],
['gemma', ['GemmaModel', GemmaModel]],
['gemma2', ['Gemma2Model', Gemma2Model]],
['glm', ['GlmModel', GlmModel]],
['openelm', ['OpenELMModel', OpenELMModel]],
['qwen2', ['Qwen2Model', Qwen2Model]],
['phi', ['PhiModel', PhiModel]],
Expand Down Expand Up @@ -6856,6 +6874,7 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
['cohere', ['CohereForCausalLM', CohereForCausalLM]],
['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
['glm', ['GlmForCausalLM', GlmForCausalLM]],
['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]],
['qwen2', ['Qwen2ForCausalLM', Qwen2ForCausalLM]],
['phi', ['PhiForCausalLM', PhiForCausalLM]],
Expand Down
55 changes: 53 additions & 2 deletions tests/tiny_random.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import {
BertTokenizer,
T5Tokenizer,
WhisperTokenizer,
BartTokenizer,
MarianTokenizer,
PreTrainedTokenizer,
AutoTokenizer,
Expand All @@ -29,6 +28,7 @@ import {
CohereForCausalLM,
GemmaForCausalLM,
Gemma2ForCausalLM,
GlmForCausalLM,
OPTForCausalLM,
GPTNeoXForCausalLM,
GPTJForCausalLM,
Expand Down Expand Up @@ -1366,7 +1366,7 @@ describe("Tiny random models", () => {
});
});

describe("gemma", () => {
describe("gemma2", () => {
describe("Gemma2ForCausalLM", () => {
const model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM";
/** @type {Gemma2ForCausalLM} */
Expand Down Expand Up @@ -1417,6 +1417,57 @@ describe("Tiny random models", () => {
});
});

describe("glm", () => {
describe("GlmForCausalLM", () => {
const model_id = "hf-internal-testing/tiny-random-GlmForCausalLM";
/** @type {GlmForCausalLM} */
let model;
/** @type {PreTrainedTokenizer} */
let tokenizer;
beforeAll(async () => {
model = await GlmForCausalLM.from_pretrained(model_id, {
// TODO move to config
...DEFAULT_MODEL_OPTIONS,
});
tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
// tokenizer.padding_side = "left";
}, MAX_MODEL_LOAD_TIME);

it(
"batch_size=1",
async () => {
const inputs = tokenizer("hello");
const outputs = await model.generate({
...inputs,
max_length: 10,
});
expect(outputs.tolist()).toEqual([[23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n, 34489n]]);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"batch_size>1",
async () => {
const inputs = tokenizer(["hello", "hello world"], { padding: true });
const outputs = await model.generate({
...inputs,
max_length: 10,
});
expect(outputs.tolist()).toEqual([
[59246n, 23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n],
[23582n, 2901n, 39936n, 25036n, 55411n, 10337n, 3424n, 39183n, 30430n, 37285n]
]);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await model?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});
});

describe("gpt_neo", () => {
describe("GPTNeoForCausalLM", () => {
const model_id = "hf-internal-testing/tiny-random-GPTNeoForCausalLM";
Expand Down