diff --git a/Seq2SeqSharp/Models/Model.cs b/Seq2SeqSharp/Models/Model.cs index 38a5dba..6e29afa 100644 --- a/Seq2SeqSharp/Models/Model.cs +++ b/Seq2SeqSharp/Models/Model.cs @@ -17,9 +17,17 @@ using Seq2SeqSharp.Utils; using Seq2SeqSharp.Enums; using TensorSharp; +using System.Runtime.InteropServices; namespace Seq2SeqSharp.Models { + [StructLayout(LayoutKind.Explicit)] + public class Name2WeightsHalf + { + [FieldOffset(0)] public Dictionary usDict; + [FieldOffset(0)] public Dictionary halfDict; + } + [Serializable] public abstract class Model : IModel { @@ -56,7 +64,7 @@ public abstract class Model : IModel public Dictionary Name2Weights { get; set; } - public Dictionary Name2WeightsHalf { get; set; } + public Name2WeightsHalf Name2WeightsHalf { get; set; } public VQTypeEnums VQType { get; set; } public Dictionary Name2WeightsVQ { get; set; } @@ -92,7 +100,7 @@ public Model(Options opts,Vocab srcVocab, Vocab tgtVocab) KVGroupNum = opts.KVGroupNum; Name2Weights = new Dictionary(); - Name2WeightsHalf= new Dictionary(); + Name2WeightsHalf = new Name2WeightsHalf(); Name2WeightsVQ = new Dictionary(); Name2CodeBook = new Dictionary(); } @@ -117,7 +125,8 @@ public Model(Model_4_ProtoBufSerializer m) VQType = m.VQType; Name2Weights = m.Name2Weights; - Name2WeightsHalf = m.Name2WeightsHalf; + Name2WeightsHalf = new Name2WeightsHalf(); + Name2WeightsHalf.usDict = m.Name2WeightsHalf; Name2WeightsVQ = m.Name2WeightsVQ; Name2CodeBook = m.Name2CodeBook; PEType = m.PEType; @@ -132,9 +141,9 @@ public Model(Model_4_ProtoBufSerializer m) Name2Weights = new Dictionary(); } - if (Name2WeightsHalf == null) + if (Name2WeightsHalf.usDict == null) { - Name2WeightsHalf = new Dictionary(); + Name2WeightsHalf.usDict = new Dictionary(); } if (Name2WeightsVQ == null) @@ -159,7 +168,7 @@ public void AddWeights(string name, float[] weights) { weightsHalf[i] = (new half(weights[i])).x; } - Name2WeightsHalf.Add(name, weightsHalf); + Name2WeightsHalf.usDict.Add(name, weightsHalf); } else if (VQType == VQTypeEnums.INT8) { @@ -237,7 +246,7 @@ public float[] GetWeights(string name) { weight = Name2Weights[name]; } - else if (Name2WeightsHalf.ContainsKey(name)) + else if (Name2WeightsHalf.halfDict.ContainsKey(name)) { throw new InvalidCastException($"The model is saved as Float16 type, so please enable AMP for model loading."); } @@ -298,14 +307,16 @@ public half[] GetWeightsHalfType(string name) weights[i] = new half(values[i]); } } - else if (Name2WeightsHalf.ContainsKey(name)) + else if (Name2WeightsHalf.halfDict.ContainsKey(name)) { - var values = Name2WeightsHalf[name]; - weights = new half[values.Length]; - for (int i = 0; i < values.Length; i++) - { - weights[i] = new half(values[i]); - } + weights = Name2WeightsHalf.halfDict[name]; + + //var values = Name2WeightsHalf[name]; + //weights = new half[values.Length]; + //for (int i = 0; i < values.Length; i++) + //{ + // weights[i] = new half(values[i]); + //} } else if (VQType == VQTypeEnums.INT8) { @@ -368,9 +379,9 @@ public void DeleteWeights(string name) Name2Weights.Remove(name); } - if (Name2WeightsHalf != null && Name2WeightsHalf.ContainsKey(name)) + if (Name2WeightsHalf != null && Name2WeightsHalf.halfDict.ContainsKey(name)) { - Name2WeightsHalf.Remove(name); + Name2WeightsHalf.halfDict.Remove(name); } } @@ -379,7 +390,7 @@ public void ClearWeights() Name2WeightsVQ.Clear(); Name2CodeBook.Clear(); Name2Weights.Clear(); - Name2WeightsHalf.Clear(); + Name2WeightsHalf.halfDict.Clear(); } public void ShowModelInfo() diff --git a/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs b/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs index 1bc5af0..00cbb0f 100644 --- a/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs +++ b/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs @@ -241,7 +241,7 @@ public Model_4_ProtoBufSerializer(Model m) Name2Weights = new Dictionary(); } - Name2WeightsHalf = m.Name2WeightsHalf; + Name2WeightsHalf = m.Name2WeightsHalf.usDict; if (Name2WeightsHalf == null) { Name2WeightsHalf = new Dictionary();