Skip to content

Commit

Permalink
Merge pull request #2 from yutanagano/add_new_variants
Browse files Browse the repository at this point in the history
Add new variants
  • Loading branch information
yutanagano authored Mar 26, 2024
2 parents f39288d + 58553e3 commit b61dfb9
Show file tree
Hide file tree
Showing 12 changed files with 634 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ Currently available variants:
- `sceptr.variant.ab_sceptr` (default model used by the functional API)
- `sceptr.variant.ab_sceptr_large` (larger variant of the paired-chain model, with model dimensionality 128)
- `sceptr.variant.ab_sceptr_blosum` (variant using BLOSUM62 embeddings instead of one-hot)
- `sceptr.variant.ab_sceptr_cdr3_only` (only uses the CDR3 loops as input)
- `sceptr.variant.ab_sceptr_cdr3_only_mlm_only` (only uses CDR3 loops as input, and did not receive contrastive learning)
- `sceptr.variant.ab_sceptr_xlarge_cdr3_only_mlm_only` (extra larger variant using only the CDR3 sequences as input, only trained on MLM, with model dimensionality 768)
- `sceptr.variant.a_sceptr` (alpha-chain only variant)
- `sceptr.variant.b_sceptr` (beta-chain only variant)
102 changes: 102 additions & 0 deletions src/sceptr/_model_saves/AB_SCEPTR_CDR3_only/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
{
"training_delegate": {
"class": "ClTrainingDelegate",
"initargs": {}
},
"model": {
"name": "AB SCEPTR (CDR3 only)",
"path_to_pretrained_state_dict": "model_saves/AB_SCEPTR_CDR3_only_MLM_only/state_dict.pt",
"token_embedder": {
"class": "Cdr3SimpleEmbedder",
"initargs": {}
},
"self_attention_stack": {
"class": "SelfAttentionStackWithInitialProjection",
"initargs": {
"num_layers": 3,
"embedding_dim": 25,
"d_model": 64,
"nhead": 8
}
},
"mlm_token_prediction_projector": {
"class": "AminoAcidTokenProjector",
"initargs": {
"d_model": 64
}
},
"vector_representation_delegate": {
"class": "ClsVectorRepresentationDelegate",
"initargs": {}
},
"trainable_model": {
"class": "ClTrainableModel",
"initargs": {}
}
},
"data": {
"training_data": {
"dataset": {
"class": "TcrDataset",
"initargs": {}
},
"dataloader": {
"class": "SingleDatasetDataLoader",
"initargs": {
"batch_size": 1024,
"num_workers": 4
}
},
"csv_paths": [
"tcr_data/preprocessed/tanno/train.csv"
]
},
"validation_data": {
"dataset": {
"class": "TcrDataset",
"initargs": {}
},
"dataloader": {
"class": "SingleDatasetDataLoader",
"initargs": {
"batch_size": 1024,
"num_workers": 4
}
},
"csv_paths": [
"tcr_data/preprocessed/tanno/test.csv"
]
},
"tokeniser": {
"class": "Cdr3Tokeniser",
"initargs": {}
},
"batch_collator": {
"class": "ClBatchCollator",
"initargs": {
"drop_chains": true
}
}
},
"loss": {
"cross_entropy_loss": {
"class": "AdjustedCrossEntropyLoss",
"initargs": {
"label_smoothing": 0.1
}
},
"contrastive_loss": {
"class": "DotProductLoss",
"initargs": {
"temp": 0.05
}
}
},
"optimiser": {
"initargs": {
"n_warmup_steps": 10000,
"d_model": 64
}
},
"num_epochs": 100
}
102 changes: 102 additions & 0 deletions src/sceptr/_model_saves/AB_SCEPTR_CDR3_only/log.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
epoch,loss,lr,valid_cont_loss,valid_positive_distance,valid_negative_distance,valid_mlm_loss,valid_mlm_acc
0,,,7.762381200838441,0.05279639546512531,0.05718007618408098,1.2419513030994118,0.7635881860224767
1,9.112978323366486,4.493055062954425e-05,7.621762994500321,0.02003439685079406,0.024134573539402862,1.243119986134588,0.7628986654676354
2,8.90173319200559,0.00013443055062954424,7.2041078361658455,0.21147233492835893,0.3569256290828469,1.2500688523835919,0.761477219117162
3,7.156940219263341,0.00022393055062954424,3.337632958097402,0.700572789874395,1.1925218426387145,1.3000093400375448,0.7460907500230439
4,5.122295899748735,0.0003134305506295442,2.8898283667379516,0.736344215124063,1.2543521461101397,1.294579391642509,0.747388816354365
5,4.791155044319623,0.0004029305506295442,2.715480305173625,0.7485661253354917,1.2729882126773826,1.29394688080361,0.747282472020365
6,4.62729550767433,0.0004924305506295441,2.6416309727344194,0.7562005902652843,1.2865301209457605,1.2951804744928865,0.7462057159862548
7,4.53575318273899,0.0005819305506295443,2.5344046885511107,0.7566226585544534,1.2925042948797807,1.2937717441250338,0.7472594567094377
8,4.451724638058108,0.0006714305506295441,2.4831349676731485,0.7650843031384059,1.30031503809169,1.2954695905588829,0.7460533787717044
9,4.396852007827698,0.0007609305506295441,2.45976722854139,0.767622909615852,1.30590520040279,1.2961659586837422,0.746401779294975
10,4.352536110266097,0.0008504305506295442,2.432826508495789,0.7689183104948004,1.313472760384712,1.3002393342876888,0.7447102197270231
11,4.306390221640894,0.0009399305506295442,2.3573539382697764,0.7706730601988542,1.3173947905314713,1.2947371563767045,0.7468065448403167
12,4.258934861979767,0.0010294305506295441,2.34357655327888,0.7690763722095874,1.3186334652451432,1.2934570876075087,0.7464628864851264
13,4.220289997479037,0.0011189305506295443,2.32019594293367,0.7759701830083187,1.3266125088482144,1.2957579856513433,0.7458785558615914
14,4.199515611401721,0.0012083461937734932,2.292261097817209,0.775192920054914,1.3300602156874155,1.2946272743584495,0.746502878195242
15,4.169210983684444,0.001226884932938532,2.299946471483616,0.7723563952669642,1.3311815372873792,1.2952108289462456,0.7460212630177361
16,4.13768655340723,0.0011866314992028263,2.274164061317273,0.7772363164572837,1.337867815872176,1.293664887612968,0.7463298423598372
17,4.111515886329867,0.0011500979307215047,2.2511911254448167,0.7715759422474768,1.328113073834814,1.291946166566322,0.7472491428226048
18,4.087757030417955,0.0011167443205193882,2.2412600506342697,0.7712943388324164,1.3348377158079807,1.2929385880950741,0.7477090486582577
19,4.066611263959385,0.0010861344888781837,2.259292568367124,0.7717083772199539,1.3337252850452124,1.295763463522537,0.7462432307638528
20,4.051268791330541,0.0010579116864878114,2.222501818315182,0.7713129445380797,1.3313135950084578,1.2900873361413276,0.748110922243782
21,4.024202073704881,0.0010317809114094956,2.2441669039305,0.7715686062379105,1.3310894183323758,1.288971638744518,0.7479170140547358
22,4.003224204231743,0.0010074958119580654,2.23775050119152,0.7738775886898404,1.3311825098191419,1.289019656433441,0.7485954179156981
23,3.9831246494929786,0.000984848832905516,2.2249009120301295,0.7707060171420829,1.3292037129800087,1.2908269648714958,0.747236661994393
24,3.9636486920362435,0.0009636636964983663,2.22256381268114,0.7741636406601843,1.3348050692879827,1.2904465948928243,0.7473379093530166
25,3.9519243949670173,0.0009437895913492682,2.199859188723674,0.7729352396690198,1.3359631829468468,1.288483151235008,0.7476837338631755
26,3.927755118746585,0.0009250966288375309,2.1930931243067318,0.7755514399966928,1.3369727923071586,1.2859577685937016,0.7489536850826521
27,3.926465194860524,0.0009074722526913594,2.1646620368127176,0.7709330636668378,1.3345097423465007,1.286610760881438,0.7488208563542789
28,3.9112838406783768,0.0008908183740754731,2.1707241156365193,0.770531909768061,1.337011253630735,1.2852896711067991,0.74932011149869
29,3.891847966349471,0.0008750490650422323,2.1675154697516184,0.7721986772106176,1.337264511071075,1.2871684978038336,0.7485575223735876
30,3.883887992012604,0.0008600886861199685,2.19489818539154,0.7749513125680523,1.3391671771473823,1.2875629777212991,0.7489312301273551
31,3.8784893897170787,0.0008458703546512784,2.1629789283377554,0.7752403102040718,1.3387697808315604,1.2882882281523045,0.7479766152529334
32,3.860915168573728,0.0008323346829334197,2.153606718536965,0.7700055455256986,1.3359170970198553,1.2880122103395444,0.7484088907163294
33,3.8519079181158524,0.0008194287317307121,2.1397062347887736,0.7682272004479871,1.3330188637342926,1.2823840005517122,0.7498121081027518
34,3.842736325365515,0.0008071051370188623,2.1405153903770797,0.7703682377173774,1.3365645853935013,1.2821266454167732,0.7511674640375396
35,3.8331822492920637,0.0007953213770578917,2.178359575468395,0.7746631318788487,1.3367616669970368,1.2875275824052062,0.7479080102230455
36,3.820642862479701,0.0007840391538973042,2.1417715968537983,0.7732939821734165,1.336992701482007,1.283531817603312,0.749598027919431
37,3.8172812534997935,0.0007732238687796371,2.1355069160836178,0.7705719641574866,1.3386313196113562,1.281844297340762,0.7504793525148107
38,3.8079702705141223,0.0007628441750461128,2.1433389237675926,0.7711487405106717,1.3389207403838181,1.283274957309216,0.7499048128671044
39,3.8054058937039867,0.000752871595365571,2.1068212915258093,0.766890015474526,1.334197239968474,1.283080711118489,0.7499391075395934
40,3.804591294610655,0.0007432801926278408,2.129849029313564,0.7701278266642823,1.3353779563532628,1.2846699109817474,0.7494325245648502
41,3.7884655847813136,0.000734046285830236,2.1285221047319642,0.7705890382028509,1.3383333157129647,1.280544791026975,0.7513947305179767
42,3.791377397505723,0.0007251482038634115,2.1062713827666846,0.7680589544195218,1.3350307060570625,1.2803563729623595,0.7508250988062498
43,3.7848194654564815,0.0007165660713628845,2.127872078393237,0.7672711427350392,1.3340649511744702,1.283205334994501,0.7503666460700734
44,3.7694308174053734,0.0007082816218046972,2.1204620947051875,0.7669331006947344,1.3368249476999534,1.2825407864705949,0.7499421935149108
45,3.7737389379690733,0.0007002780338413631,2.130829363060976,0.7671799429824664,1.3377758401892883,1.282753877562038,0.7506571872455874
46,3.7660545417524105,0.0006925397875382193,2.1464810834314787,0.7719004261662463,1.3356136029177625,1.2805999808384199,0.7509058764234703
47,3.7605208362762155,0.0006850525377121659,2.110271296473167,0.771535982863223,1.3402624808242818,1.2796646512743277,0.7518908305603318
48,3.7567992293042907,0.0006778030020191236,2.1459811338376436,0.7749366190291012,1.3407848323748655,1.2790841081494777,0.7513563181240536
49,3.7531569709437247,0.0006707788618025577,2.110209801634663,0.7715622685321519,1.340665858690836,1.2770644134664275,0.7522690970938641
50,3.7572914033796856,0.0006639686740182302,2.104587899580546,0.7686209039925793,1.337795538139737,1.2814823500300585,0.7510009904471834
51,3.7508138354266576,0.0006573617928019408,2.1333978848263486,0.7694736387206724,1.333348288212023,1.280015651901976,0.7515902945811049
52,3.746899056971591,0.0006509482994567963,2.089636543020974,0.7658247139762803,1.3336317169083245,1.2793679966896083,0.7515962448683834
53,3.7434352002689018,0.00064471893981229,2.0916940408273157,0.763576082134609,1.3321884178647259,1.276611538224293,0.7525974309038873
54,3.737763185992981,0.0006386650680550355,2.078343730104343,0.7674884063219581,1.338259790454553,1.2790974633743293,0.7511096631517198
55,3.7340273343144434,0.0006327785962555395,2.110350710735367,0.7671728735461018,1.3372308666932167,1.278304951159634,0.7517096826576193
56,3.7360605685769155,0.0006270519489206681,2.0844620974646992,0.7660577733029601,1.335017218949528,1.2784418795000032,0.7512342914189956
57,3.732170315994262,0.0006214780219909308,2.0951151139012594,0.7669938714127305,1.337316729503759,1.2780945121459113,0.7526324986070059
58,3.7344442342159514,0.0006160501457777644,2.1009307516497167,0.7685156612128009,1.3355836534983927,1.277049074074573,0.7527485151092673
59,3.727959960087887,0.0006107620514010679,2.103755388968409,0.7674557139021542,1.3378475934700382,1.2779728179972818,0.7514565464624954
60,3.7260847608274954,0.0006056078403428885,2.0743731619848154,0.7664868982074519,1.3356119054268762,1.2791487680749813,0.7519802999193536
61,3.717876913568103,0.0006005819567810015,2.076300912789498,0.7660934031268736,1.3351145226212109,1.274983458165864,0.7530415396767931
62,3.7234010957313073,0.0005956791624072976,2.0953971295478886,0.7649160693177532,1.3344627185379916,1.2758442911919878,0.7531323733255845
63,3.7170657583081597,0.0005908945134714726,2.077151996027498,0.7672471723828043,1.336235655367009,1.2786986731635746,0.7512502608127608
64,3.7144685032614655,0.0005862233398212821,2.0902870287814745,0.7692040037201189,1.3366618404152146,1.2803997760626507,0.750965719346293
65,3.7182075023690153,0.0005816612257373163,2.0727226285261655,0.766370471918733,1.3381836398804998,1.2749658276969207,0.7528075374257693
66,3.7067992840754846,0.0005772039923835236,2.068878932121725,0.7676702518876725,1.338306531426511,1.2764666392966464,0.7524991848203644
67,3.7117608274932103,0.0005728476817148925,2.0861694891699587,0.766012964816576,1.3374556662592825,1.2781786190472508,0.7515926202595348
68,3.7081695108105244,0.000568588541701453,2.080389078822681,0.7663479120458041,1.337436549615449,1.2788272076235416,0.7512154039087551
69,3.7037022618451134,0.0005644230127431929,2.078158610021454,0.7635414428128596,1.3375895476591992,1.2776458820570298,0.751488356371013
70,3.699084381419278,0.000560347715164093,2.0710010827985452,0.7701691225825662,1.3427296814144138,1.2773414814846524,0.7521338678574286
71,3.707335336657783,0.0005563594376854221,2.0668262955043124,0.7690962548410875,1.3419071740544748,1.2767616018742878,0.7528286917522004
72,3.6963874278165525,0.0005524551267889306,2.066828113417061,0.7673438935756859,1.3378775975866781,1.2770937033346543,0.7522350374792571
73,3.702084476760743,0.0005486318768898732,2.1096120933234577,0.7731786887640242,1.3425909613615201,1.2769620291615724,0.7531764254897176
74,3.690986733693714,0.0005448869212480005,2.066500256628905,0.7645823565538863,1.3383528468756392,1.27619258900284,0.7525033039733774
75,3.691247302048633,0.0005412176235518742,2.0634797885368994,0.7660988624038654,1.3367605923185355,1.2750035435975207,0.7534378321847148
76,3.6954951726551353,0.0005376214701183853,2.0563709501295824,0.7648936888769566,1.3393141682333456,1.277334020451471,0.7527339150111846
77,3.690793166634838,0.0005340960626549686,2.0966378375656314,0.7656858025491575,1.337969245086248,1.2725407623322291,0.7540012773723757
78,3.6837293342751285,0.0005306391115372118,2.075730923075861,0.765152458322467,1.3370160959864206,1.2788842372531233,0.7516897849899269
79,3.6944885791477424,0.0005272484295589966,2.082516486303579,0.7678278329439719,1.3383380379355279,1.2754139388728565,0.7528577899449943
80,3.6803325010968595,0.0005239219261164537,2.0535234679553045,0.7686489648364385,1.3434281202727,1.2779384956116389,0.7520321872875425
81,3.6818773785121657,0.0005206576017905405,2.047808034389224,0.7628892557297268,1.3347870561630626,1.2738468031473908,0.7531669717447741
82,3.677086910085993,0.0005174535432964054,2.0551705017571797,0.76616168861088,1.3401155046294464,1.2746346503353154,0.7527900240318778
83,3.680452075655629,0.0005143079187704905,2.085066824044578,0.7657991541441392,1.3358125651829602,1.2744321378443448,0.7528847644333585
84,3.6795739013809463,0.0005112189733690399,2.04495018581916,0.7632438464592802,1.3369268351061683,1.2758554264871662,0.7524665824431146
85,3.6818065166675034,0.0005081850251539766,2.0628927771565113,0.7669954214755973,1.3391431512224123,1.281490843158557,0.7507655288082401
86,3.6803051068700947,0.0005052044612442199,2.0446594784676977,0.7641745722514304,1.336466554723952,1.2762467643138327,0.7526367500657363
87,3.663054018041315,0.0005022757342124591,2.040313660314199,0.7644576188968027,1.3371143136243349,1.272762205370238,0.7541795910802265
88,3.669665906942736,0.0004993973587090738,2.064276907274198,0.7636698755192898,1.3362635723195235,1.2736869076488173,0.75343412369045
89,3.6722765982051024,0.0004965679082964922,2.0604519461822886,0.7662919184322347,1.3405033760080352,1.2741566049117041,0.7534863166592237
90,3.6653336802803453,0.0004937860124786321,2.05101285224263,0.7653170765978623,1.3373038387500185,1.2745056950110105,0.7535270191706446
91,3.668098697401907,0.0004910503539114016,2.08317401507053,0.7663760342466969,1.3403219451728705,1.2769238625556305,0.7522618106984095
92,3.6699733951906897,0.0004883596657813364,2.0264513960214425,0.7611791260291094,1.3373937090769905,1.274494820513035,0.7534017255711969
93,3.6677283933819758,0.0004857127293405453,2.0401880402870707,0.7649393848088545,1.337613813050143,1.2733763550940957,0.7535563865804833
94,3.665543499435497,0.00048310837158704845,2.05391259995566,0.7682618185224647,1.34243581222608,1.2712898571726683,0.7540985717383809
95,3.663732008304565,0.0004805454630805032,2.0341231750040043,0.764708075514689,1.3391495328355825,1.2744927962973092,0.7533479525076279
96,3.66345698646294,0.0004780229158840507,2.071742994847143,0.768103684594208,1.338785864192772,1.2733639867104838,0.7533030768378186
97,3.6569189551809806,0.0004755396816237931,2.05041911147769,0.7643333404578809,1.339491959128433,1.2727547076071553,0.7540891961681523
98,3.661707510742569,0.00047309474965801337,2.073534300643597,0.7642273457109177,1.3333704930613164,1.2714290554762884,0.7542138446521336
99,3.6613128041252527,0.0004706871453488936,2.037688012972204,0.7621832658186745,1.334718887871471,1.273594506474309,0.753979969649102
100,3.6568468123533324,0.00046831592843001697,2.0355441350520063,0.7641465336939267,1.339444780993077,1.2735859129550062,0.7539072484471787
3 changes: 3 additions & 0 deletions src/sceptr/_model_saves/AB_SCEPTR_CDR3_only/state_dict.pt
Git LFS file not shown
Loading

0 comments on commit b61dfb9

Please sign in to comment.