-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from yutanagano/add_new_variants
Add new variants
- Loading branch information
Showing
12 changed files
with
634 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
src/sceptr/_model_saves/AB_SCEPTR_CDR3_only/config.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
{ | ||
"training_delegate": { | ||
"class": "ClTrainingDelegate", | ||
"initargs": {} | ||
}, | ||
"model": { | ||
"name": "AB SCEPTR (CDR3 only)", | ||
"path_to_pretrained_state_dict": "model_saves/AB_SCEPTR_CDR3_only_MLM_only/state_dict.pt", | ||
"token_embedder": { | ||
"class": "Cdr3SimpleEmbedder", | ||
"initargs": {} | ||
}, | ||
"self_attention_stack": { | ||
"class": "SelfAttentionStackWithInitialProjection", | ||
"initargs": { | ||
"num_layers": 3, | ||
"embedding_dim": 25, | ||
"d_model": 64, | ||
"nhead": 8 | ||
} | ||
}, | ||
"mlm_token_prediction_projector": { | ||
"class": "AminoAcidTokenProjector", | ||
"initargs": { | ||
"d_model": 64 | ||
} | ||
}, | ||
"vector_representation_delegate": { | ||
"class": "ClsVectorRepresentationDelegate", | ||
"initargs": {} | ||
}, | ||
"trainable_model": { | ||
"class": "ClTrainableModel", | ||
"initargs": {} | ||
} | ||
}, | ||
"data": { | ||
"training_data": { | ||
"dataset": { | ||
"class": "TcrDataset", | ||
"initargs": {} | ||
}, | ||
"dataloader": { | ||
"class": "SingleDatasetDataLoader", | ||
"initargs": { | ||
"batch_size": 1024, | ||
"num_workers": 4 | ||
} | ||
}, | ||
"csv_paths": [ | ||
"tcr_data/preprocessed/tanno/train.csv" | ||
] | ||
}, | ||
"validation_data": { | ||
"dataset": { | ||
"class": "TcrDataset", | ||
"initargs": {} | ||
}, | ||
"dataloader": { | ||
"class": "SingleDatasetDataLoader", | ||
"initargs": { | ||
"batch_size": 1024, | ||
"num_workers": 4 | ||
} | ||
}, | ||
"csv_paths": [ | ||
"tcr_data/preprocessed/tanno/test.csv" | ||
] | ||
}, | ||
"tokeniser": { | ||
"class": "Cdr3Tokeniser", | ||
"initargs": {} | ||
}, | ||
"batch_collator": { | ||
"class": "ClBatchCollator", | ||
"initargs": { | ||
"drop_chains": true | ||
} | ||
} | ||
}, | ||
"loss": { | ||
"cross_entropy_loss": { | ||
"class": "AdjustedCrossEntropyLoss", | ||
"initargs": { | ||
"label_smoothing": 0.1 | ||
} | ||
}, | ||
"contrastive_loss": { | ||
"class": "DotProductLoss", | ||
"initargs": { | ||
"temp": 0.05 | ||
} | ||
} | ||
}, | ||
"optimiser": { | ||
"initargs": { | ||
"n_warmup_steps": 10000, | ||
"d_model": 64 | ||
} | ||
}, | ||
"num_epochs": 100 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
epoch,loss,lr,valid_cont_loss,valid_positive_distance,valid_negative_distance,valid_mlm_loss,valid_mlm_acc | ||
0,,,7.762381200838441,0.05279639546512531,0.05718007618408098,1.2419513030994118,0.7635881860224767 | ||
1,9.112978323366486,4.493055062954425e-05,7.621762994500321,0.02003439685079406,0.024134573539402862,1.243119986134588,0.7628986654676354 | ||
2,8.90173319200559,0.00013443055062954424,7.2041078361658455,0.21147233492835893,0.3569256290828469,1.2500688523835919,0.761477219117162 | ||
3,7.156940219263341,0.00022393055062954424,3.337632958097402,0.700572789874395,1.1925218426387145,1.3000093400375448,0.7460907500230439 | ||
4,5.122295899748735,0.0003134305506295442,2.8898283667379516,0.736344215124063,1.2543521461101397,1.294579391642509,0.747388816354365 | ||
5,4.791155044319623,0.0004029305506295442,2.715480305173625,0.7485661253354917,1.2729882126773826,1.29394688080361,0.747282472020365 | ||
6,4.62729550767433,0.0004924305506295441,2.6416309727344194,0.7562005902652843,1.2865301209457605,1.2951804744928865,0.7462057159862548 | ||
7,4.53575318273899,0.0005819305506295443,2.5344046885511107,0.7566226585544534,1.2925042948797807,1.2937717441250338,0.7472594567094377 | ||
8,4.451724638058108,0.0006714305506295441,2.4831349676731485,0.7650843031384059,1.30031503809169,1.2954695905588829,0.7460533787717044 | ||
9,4.396852007827698,0.0007609305506295441,2.45976722854139,0.767622909615852,1.30590520040279,1.2961659586837422,0.746401779294975 | ||
10,4.352536110266097,0.0008504305506295442,2.432826508495789,0.7689183104948004,1.313472760384712,1.3002393342876888,0.7447102197270231 | ||
11,4.306390221640894,0.0009399305506295442,2.3573539382697764,0.7706730601988542,1.3173947905314713,1.2947371563767045,0.7468065448403167 | ||
12,4.258934861979767,0.0010294305506295441,2.34357655327888,0.7690763722095874,1.3186334652451432,1.2934570876075087,0.7464628864851264 | ||
13,4.220289997479037,0.0011189305506295443,2.32019594293367,0.7759701830083187,1.3266125088482144,1.2957579856513433,0.7458785558615914 | ||
14,4.199515611401721,0.0012083461937734932,2.292261097817209,0.775192920054914,1.3300602156874155,1.2946272743584495,0.746502878195242 | ||
15,4.169210983684444,0.001226884932938532,2.299946471483616,0.7723563952669642,1.3311815372873792,1.2952108289462456,0.7460212630177361 | ||
16,4.13768655340723,0.0011866314992028263,2.274164061317273,0.7772363164572837,1.337867815872176,1.293664887612968,0.7463298423598372 | ||
17,4.111515886329867,0.0011500979307215047,2.2511911254448167,0.7715759422474768,1.328113073834814,1.291946166566322,0.7472491428226048 | ||
18,4.087757030417955,0.0011167443205193882,2.2412600506342697,0.7712943388324164,1.3348377158079807,1.2929385880950741,0.7477090486582577 | ||
19,4.066611263959385,0.0010861344888781837,2.259292568367124,0.7717083772199539,1.3337252850452124,1.295763463522537,0.7462432307638528 | ||
20,4.051268791330541,0.0010579116864878114,2.222501818315182,0.7713129445380797,1.3313135950084578,1.2900873361413276,0.748110922243782 | ||
21,4.024202073704881,0.0010317809114094956,2.2441669039305,0.7715686062379105,1.3310894183323758,1.288971638744518,0.7479170140547358 | ||
22,4.003224204231743,0.0010074958119580654,2.23775050119152,0.7738775886898404,1.3311825098191419,1.289019656433441,0.7485954179156981 | ||
23,3.9831246494929786,0.000984848832905516,2.2249009120301295,0.7707060171420829,1.3292037129800087,1.2908269648714958,0.747236661994393 | ||
24,3.9636486920362435,0.0009636636964983663,2.22256381268114,0.7741636406601843,1.3348050692879827,1.2904465948928243,0.7473379093530166 | ||
25,3.9519243949670173,0.0009437895913492682,2.199859188723674,0.7729352396690198,1.3359631829468468,1.288483151235008,0.7476837338631755 | ||
26,3.927755118746585,0.0009250966288375309,2.1930931243067318,0.7755514399966928,1.3369727923071586,1.2859577685937016,0.7489536850826521 | ||
27,3.926465194860524,0.0009074722526913594,2.1646620368127176,0.7709330636668378,1.3345097423465007,1.286610760881438,0.7488208563542789 | ||
28,3.9112838406783768,0.0008908183740754731,2.1707241156365193,0.770531909768061,1.337011253630735,1.2852896711067991,0.74932011149869 | ||
29,3.891847966349471,0.0008750490650422323,2.1675154697516184,0.7721986772106176,1.337264511071075,1.2871684978038336,0.7485575223735876 | ||
30,3.883887992012604,0.0008600886861199685,2.19489818539154,0.7749513125680523,1.3391671771473823,1.2875629777212991,0.7489312301273551 | ||
31,3.8784893897170787,0.0008458703546512784,2.1629789283377554,0.7752403102040718,1.3387697808315604,1.2882882281523045,0.7479766152529334 | ||
32,3.860915168573728,0.0008323346829334197,2.153606718536965,0.7700055455256986,1.3359170970198553,1.2880122103395444,0.7484088907163294 | ||
33,3.8519079181158524,0.0008194287317307121,2.1397062347887736,0.7682272004479871,1.3330188637342926,1.2823840005517122,0.7498121081027518 | ||
34,3.842736325365515,0.0008071051370188623,2.1405153903770797,0.7703682377173774,1.3365645853935013,1.2821266454167732,0.7511674640375396 | ||
35,3.8331822492920637,0.0007953213770578917,2.178359575468395,0.7746631318788487,1.3367616669970368,1.2875275824052062,0.7479080102230455 | ||
36,3.820642862479701,0.0007840391538973042,2.1417715968537983,0.7732939821734165,1.336992701482007,1.283531817603312,0.749598027919431 | ||
37,3.8172812534997935,0.0007732238687796371,2.1355069160836178,0.7705719641574866,1.3386313196113562,1.281844297340762,0.7504793525148107 | ||
38,3.8079702705141223,0.0007628441750461128,2.1433389237675926,0.7711487405106717,1.3389207403838181,1.283274957309216,0.7499048128671044 | ||
39,3.8054058937039867,0.000752871595365571,2.1068212915258093,0.766890015474526,1.334197239968474,1.283080711118489,0.7499391075395934 | ||
40,3.804591294610655,0.0007432801926278408,2.129849029313564,0.7701278266642823,1.3353779563532628,1.2846699109817474,0.7494325245648502 | ||
41,3.7884655847813136,0.000734046285830236,2.1285221047319642,0.7705890382028509,1.3383333157129647,1.280544791026975,0.7513947305179767 | ||
42,3.791377397505723,0.0007251482038634115,2.1062713827666846,0.7680589544195218,1.3350307060570625,1.2803563729623595,0.7508250988062498 | ||
43,3.7848194654564815,0.0007165660713628845,2.127872078393237,0.7672711427350392,1.3340649511744702,1.283205334994501,0.7503666460700734 | ||
44,3.7694308174053734,0.0007082816218046972,2.1204620947051875,0.7669331006947344,1.3368249476999534,1.2825407864705949,0.7499421935149108 | ||
45,3.7737389379690733,0.0007002780338413631,2.130829363060976,0.7671799429824664,1.3377758401892883,1.282753877562038,0.7506571872455874 | ||
46,3.7660545417524105,0.0006925397875382193,2.1464810834314787,0.7719004261662463,1.3356136029177625,1.2805999808384199,0.7509058764234703 | ||
47,3.7605208362762155,0.0006850525377121659,2.110271296473167,0.771535982863223,1.3402624808242818,1.2796646512743277,0.7518908305603318 | ||
48,3.7567992293042907,0.0006778030020191236,2.1459811338376436,0.7749366190291012,1.3407848323748655,1.2790841081494777,0.7513563181240536 | ||
49,3.7531569709437247,0.0006707788618025577,2.110209801634663,0.7715622685321519,1.340665858690836,1.2770644134664275,0.7522690970938641 | ||
50,3.7572914033796856,0.0006639686740182302,2.104587899580546,0.7686209039925793,1.337795538139737,1.2814823500300585,0.7510009904471834 | ||
51,3.7508138354266576,0.0006573617928019408,2.1333978848263486,0.7694736387206724,1.333348288212023,1.280015651901976,0.7515902945811049 | ||
52,3.746899056971591,0.0006509482994567963,2.089636543020974,0.7658247139762803,1.3336317169083245,1.2793679966896083,0.7515962448683834 | ||
53,3.7434352002689018,0.00064471893981229,2.0916940408273157,0.763576082134609,1.3321884178647259,1.276611538224293,0.7525974309038873 | ||
54,3.737763185992981,0.0006386650680550355,2.078343730104343,0.7674884063219581,1.338259790454553,1.2790974633743293,0.7511096631517198 | ||
55,3.7340273343144434,0.0006327785962555395,2.110350710735367,0.7671728735461018,1.3372308666932167,1.278304951159634,0.7517096826576193 | ||
56,3.7360605685769155,0.0006270519489206681,2.0844620974646992,0.7660577733029601,1.335017218949528,1.2784418795000032,0.7512342914189956 | ||
57,3.732170315994262,0.0006214780219909308,2.0951151139012594,0.7669938714127305,1.337316729503759,1.2780945121459113,0.7526324986070059 | ||
58,3.7344442342159514,0.0006160501457777644,2.1009307516497167,0.7685156612128009,1.3355836534983927,1.277049074074573,0.7527485151092673 | ||
59,3.727959960087887,0.0006107620514010679,2.103755388968409,0.7674557139021542,1.3378475934700382,1.2779728179972818,0.7514565464624954 | ||
60,3.7260847608274954,0.0006056078403428885,2.0743731619848154,0.7664868982074519,1.3356119054268762,1.2791487680749813,0.7519802999193536 | ||
61,3.717876913568103,0.0006005819567810015,2.076300912789498,0.7660934031268736,1.3351145226212109,1.274983458165864,0.7530415396767931 | ||
62,3.7234010957313073,0.0005956791624072976,2.0953971295478886,0.7649160693177532,1.3344627185379916,1.2758442911919878,0.7531323733255845 | ||
63,3.7170657583081597,0.0005908945134714726,2.077151996027498,0.7672471723828043,1.336235655367009,1.2786986731635746,0.7512502608127608 | ||
64,3.7144685032614655,0.0005862233398212821,2.0902870287814745,0.7692040037201189,1.3366618404152146,1.2803997760626507,0.750965719346293 | ||
65,3.7182075023690153,0.0005816612257373163,2.0727226285261655,0.766370471918733,1.3381836398804998,1.2749658276969207,0.7528075374257693 | ||
66,3.7067992840754846,0.0005772039923835236,2.068878932121725,0.7676702518876725,1.338306531426511,1.2764666392966464,0.7524991848203644 | ||
67,3.7117608274932103,0.0005728476817148925,2.0861694891699587,0.766012964816576,1.3374556662592825,1.2781786190472508,0.7515926202595348 | ||
68,3.7081695108105244,0.000568588541701453,2.080389078822681,0.7663479120458041,1.337436549615449,1.2788272076235416,0.7512154039087551 | ||
69,3.7037022618451134,0.0005644230127431929,2.078158610021454,0.7635414428128596,1.3375895476591992,1.2776458820570298,0.751488356371013 | ||
70,3.699084381419278,0.000560347715164093,2.0710010827985452,0.7701691225825662,1.3427296814144138,1.2773414814846524,0.7521338678574286 | ||
71,3.707335336657783,0.0005563594376854221,2.0668262955043124,0.7690962548410875,1.3419071740544748,1.2767616018742878,0.7528286917522004 | ||
72,3.6963874278165525,0.0005524551267889306,2.066828113417061,0.7673438935756859,1.3378775975866781,1.2770937033346543,0.7522350374792571 | ||
73,3.702084476760743,0.0005486318768898732,2.1096120933234577,0.7731786887640242,1.3425909613615201,1.2769620291615724,0.7531764254897176 | ||
74,3.690986733693714,0.0005448869212480005,2.066500256628905,0.7645823565538863,1.3383528468756392,1.27619258900284,0.7525033039733774 | ||
75,3.691247302048633,0.0005412176235518742,2.0634797885368994,0.7660988624038654,1.3367605923185355,1.2750035435975207,0.7534378321847148 | ||
76,3.6954951726551353,0.0005376214701183853,2.0563709501295824,0.7648936888769566,1.3393141682333456,1.277334020451471,0.7527339150111846 | ||
77,3.690793166634838,0.0005340960626549686,2.0966378375656314,0.7656858025491575,1.337969245086248,1.2725407623322291,0.7540012773723757 | ||
78,3.6837293342751285,0.0005306391115372118,2.075730923075861,0.765152458322467,1.3370160959864206,1.2788842372531233,0.7516897849899269 | ||
79,3.6944885791477424,0.0005272484295589966,2.082516486303579,0.7678278329439719,1.3383380379355279,1.2754139388728565,0.7528577899449943 | ||
80,3.6803325010968595,0.0005239219261164537,2.0535234679553045,0.7686489648364385,1.3434281202727,1.2779384956116389,0.7520321872875425 | ||
81,3.6818773785121657,0.0005206576017905405,2.047808034389224,0.7628892557297268,1.3347870561630626,1.2738468031473908,0.7531669717447741 | ||
82,3.677086910085993,0.0005174535432964054,2.0551705017571797,0.76616168861088,1.3401155046294464,1.2746346503353154,0.7527900240318778 | ||
83,3.680452075655629,0.0005143079187704905,2.085066824044578,0.7657991541441392,1.3358125651829602,1.2744321378443448,0.7528847644333585 | ||
84,3.6795739013809463,0.0005112189733690399,2.04495018581916,0.7632438464592802,1.3369268351061683,1.2758554264871662,0.7524665824431146 | ||
85,3.6818065166675034,0.0005081850251539766,2.0628927771565113,0.7669954214755973,1.3391431512224123,1.281490843158557,0.7507655288082401 | ||
86,3.6803051068700947,0.0005052044612442199,2.0446594784676977,0.7641745722514304,1.336466554723952,1.2762467643138327,0.7526367500657363 | ||
87,3.663054018041315,0.0005022757342124591,2.040313660314199,0.7644576188968027,1.3371143136243349,1.272762205370238,0.7541795910802265 | ||
88,3.669665906942736,0.0004993973587090738,2.064276907274198,0.7636698755192898,1.3362635723195235,1.2736869076488173,0.75343412369045 | ||
89,3.6722765982051024,0.0004965679082964922,2.0604519461822886,0.7662919184322347,1.3405033760080352,1.2741566049117041,0.7534863166592237 | ||
90,3.6653336802803453,0.0004937860124786321,2.05101285224263,0.7653170765978623,1.3373038387500185,1.2745056950110105,0.7535270191706446 | ||
91,3.668098697401907,0.0004910503539114016,2.08317401507053,0.7663760342466969,1.3403219451728705,1.2769238625556305,0.7522618106984095 | ||
92,3.6699733951906897,0.0004883596657813364,2.0264513960214425,0.7611791260291094,1.3373937090769905,1.274494820513035,0.7534017255711969 | ||
93,3.6677283933819758,0.0004857127293405453,2.0401880402870707,0.7649393848088545,1.337613813050143,1.2733763550940957,0.7535563865804833 | ||
94,3.665543499435497,0.00048310837158704845,2.05391259995566,0.7682618185224647,1.34243581222608,1.2712898571726683,0.7540985717383809 | ||
95,3.663732008304565,0.0004805454630805032,2.0341231750040043,0.764708075514689,1.3391495328355825,1.2744927962973092,0.7533479525076279 | ||
96,3.66345698646294,0.0004780229158840507,2.071742994847143,0.768103684594208,1.338785864192772,1.2733639867104838,0.7533030768378186 | ||
97,3.6569189551809806,0.0004755396816237931,2.05041911147769,0.7643333404578809,1.339491959128433,1.2727547076071553,0.7540891961681523 | ||
98,3.661707510742569,0.00047309474965801337,2.073534300643597,0.7642273457109177,1.3333704930613164,1.2714290554762884,0.7542138446521336 | ||
99,3.6613128041252527,0.0004706871453488936,2.037688012972204,0.7621832658186745,1.334718887871471,1.273594506474309,0.753979969649102 | ||
100,3.6568468123533324,0.00046831592843001697,2.0355441350520063,0.7641465336939267,1.339444780993077,1.2735859129550062,0.7539072484471787 |
Git LFS file not shown
Oops, something went wrong.