diff --git a/src/configs.js b/src/configs.js index 2c277aeb1..479fa0e29 100644 --- a/src/configs.js +++ b/src/configs.js @@ -117,6 +117,7 @@ function getNormalizedConfig(config) { break; case 'gemma': case 'gemma2': + case 'glm': mapping['num_heads'] = 'num_key_value_heads'; mapping['num_layers'] = 'num_hidden_layers'; mapping['dim_kv'] = 'head_dim'; diff --git a/src/models.js b/src/models.js index 7133e64d5..2138ea004 100644 --- a/src/models.js +++ b/src/models.js @@ -4037,6 +4037,23 @@ export class Gemma2Model extends Gemma2PreTrainedModel { } export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { } ////////////////////////////////////////////////// + +////////////////////////////////////////////////// +// Glm models + +/** + * The bare Glm Model outputting raw hidden-states without any specific head on top. + */ +export class GlmPreTrainedModel extends PreTrainedModel { } +/** + * The bare Glm Model outputting raw hidden-states without any specific head on top. + */ +export class GlmModel extends GlmPreTrainedModel { } + +export class GlmForCausalLM extends GlmPreTrainedModel { } +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// export class OpenELMPreTrainedModel extends PreTrainedModel { } export class OpenELMModel extends OpenELMPreTrainedModel { } @@ -6765,6 +6782,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ ['cohere', ['CohereModel', CohereModel]], ['gemma', ['GemmaModel', GemmaModel]], ['gemma2', ['Gemma2Model', Gemma2Model]], + ['glm', ['GlmModel', GlmModel]], ['openelm', ['OpenELMModel', OpenELMModel]], ['qwen2', ['Qwen2Model', Qwen2Model]], ['phi', ['PhiModel', PhiModel]], @@ -6856,6 +6874,7 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([ ['cohere', ['CohereForCausalLM', CohereForCausalLM]], ['gemma', ['GemmaForCausalLM', GemmaForCausalLM]], ['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]], + ['glm', ['GlmForCausalLM', GlmForCausalLM]], ['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]], ['qwen2', ['Qwen2ForCausalLM', Qwen2ForCausalLM]], ['phi', ['PhiForCausalLM', PhiForCausalLM]], diff --git a/tests/tiny_random.test.js b/tests/tiny_random.test.js index bd2fe1c60..8d6e6489f 100644 --- a/tests/tiny_random.test.js +++ b/tests/tiny_random.test.js @@ -10,7 +10,6 @@ import { BertTokenizer, T5Tokenizer, WhisperTokenizer, - BartTokenizer, MarianTokenizer, PreTrainedTokenizer, AutoTokenizer, @@ -29,6 +28,7 @@ import { CohereForCausalLM, GemmaForCausalLM, Gemma2ForCausalLM, + GlmForCausalLM, OPTForCausalLM, GPTNeoXForCausalLM, GPTJForCausalLM, @@ -1366,7 +1366,7 @@ describe("Tiny random models", () => { }); }); - describe("gemma", () => { + describe("gemma2", () => { describe("Gemma2ForCausalLM", () => { const model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM"; /** @type {Gemma2ForCausalLM} */ @@ -1417,6 +1417,57 @@ describe("Tiny random models", () => { }); }); + describe("glm", () => { + describe("GlmForCausalLM", () => { + const model_id = "hf-internal-testing/tiny-random-GlmForCausalLM"; + /** @type {GlmForCausalLM} */ + let model; + /** @type {PreTrainedTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GlmForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await PreTrainedTokenizer.from_pretrained(model_id); + // tokenizer.padding_side = "left"; + }, MAX_MODEL_LOAD_TIME); + + it( + "batch_size=1", + async () => { + const inputs = tokenizer("hello"); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([[23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n, 34489n]]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "batch_size>1", + async () => { + const inputs = tokenizer(["hello", "hello world"], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [59246n, 23582n, 5797n, 38238n, 24486n, 36539n, 34489n, 6948n, 34489n, 6948n], + [23582n, 2901n, 39936n, 25036n, 55411n, 10337n, 3424n, 39183n, 30430n, 37285n] + ]); + }, + MAX_TEST_EXECUTION_TIME, + ); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + describe("gpt_neo", () => { describe("GPTNeoForCausalLM", () => { const model_id = "hf-internal-testing/tiny-random-GPTNeoForCausalLM";