diff --git a/README.md b/README.md index 2a3fd61..8b9091c 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,7 @@ This `Sceptr` object will then have the methods: `calc_pdist_vector`, `calc_cdis |`sceptr.variant.average_pooling`|variant using the average-pooling method to generate the TCR representation vector| |`sceptr.variant.unpaired`|variant trained on the Tanno et al. dataset with randomised alpha/beta pairing| |`sceptr.variant.olga`|variant trained using synthetic TCR sequences generated by OLGA| +|`sceptr.variant.dropout_noise_only`|variant trained without residue/chain dropping during autocontrastive learning| |`sceptr.variant.finetuned`|variant fine-tuned using supervised contrastive learning for six pMHCs with peptides GILGFVFTL, NLVPMVATV, SPRWYFYYL, TFEYVSQPFLMDLE, TTDPSFLGRY and YLQPRTFLL (from [VDJdb](https://vdjdb.cdr3.net/))| #### Single-chain variants diff --git a/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/config.json b/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/config.json new file mode 100644 index 0000000..7c77dea --- /dev/null +++ b/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/config.json @@ -0,0 +1,103 @@ +{ + "training_delegate": { + "class": "ClTrainingDelegate", + "initargs": {} + }, + "model": { + "name": "SCEPTR (dropout noise only)", + "path_to_pretrained_state_dict": null, + "token_embedder": { + "class": "CdrSimpleEmbedder", + "initargs": {} + }, + "self_attention_stack": { + "class": "SelfAttentionStackWithInitialProjection", + "initargs": { + "num_layers": 3, + "embedding_dim": 29, + "d_model": 64, + "nhead": 8 + } + }, + "mlm_token_prediction_projector": { + "class": "AminoAcidTokenProjector", + "initargs": { + "d_model": 64 + } + }, + "vector_representation_delegate": { + "class": "ClsVectorRepresentationDelegate", + "initargs": {} + }, + "trainable_model": { + "class": "ClTrainableModel", + "initargs": {} + } + }, + "data": { + "training_data": { + "dataset": { + "class": "TcrDataset", + "initargs": {} + }, + "dataloader": { + "class": "SingleDatasetDataLoader", + "initargs": { + "batch_size": 1024, + "num_workers": 4 + } + }, + "csv_paths": [ + "tcr_data/preprocessed/tanno/train.csv" + ] + }, + "validation_data": { + "dataset": { + "class": "TcrDataset", + "initargs": {} + }, + "dataloader": { + "class": "SingleDatasetDataLoader", + "initargs": { + "batch_size": 1024, + "num_workers": 4 + } + }, + "csv_paths": [ + "tcr_data/preprocessed/tanno/test.csv" + ] + }, + "tokeniser": { + "class": "CdrTokeniser", + "initargs": {} + }, + "batch_collator": { + "class": "ClBatchCollator", + "initargs": { + "frac_dropped_tokens": 0, + "prob_drop_chain": 0 + } + } + }, + "loss": { + "cross_entropy_loss": { + "class": "AdjustedCrossEntropyLoss", + "initargs": { + "label_smoothing": 0.1 + } + }, + "contrastive_loss": { + "class": "DotProductLoss", + "initargs": { + "temp": 0.05 + } + } + }, + "optimiser": { + "initargs": { + "n_warmup_steps": 10000, + "d_model": 64 + } + }, + "num_epochs": 200 +} \ No newline at end of file diff --git a/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/log.csv b/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/log.csv new file mode 100644 index 0000000..1e4ea51 --- /dev/null +++ b/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/log.csv @@ -0,0 +1,202 @@ +epoch,loss,lr,valid_cont_loss,valid_positive_distance,valid_negative_distance,valid_mlm_loss,valid_mlm_acc +0,,,7.498027629133681,0.00015237881834970687,0.10975030459804733,3.04809786580713,0.06912250065779857 +1,5.920741786808741,4.493055062954425e-05,0.035865706390709365,0.0001660916724894503,1.386665392373939,2.564889996758053,0.28018854137777416 +2,2.548739247412042,0.00013443055062954424,0.009694411915730897,0.0001645709809453627,1.397023825940967,2.2302052390883587,0.3935517453075181 +3,2.196317719991706,0.00022393055062954424,0.005090244160973123,0.000161823038892819,1.3995820239090522,1.9330423470869509,0.5066318881795733 +4,1.8300549171630718,0.0003134305506295442,0.0034799991864900323,0.0001632344199444303,1.4001201507846033,1.4895693847108453,0.6795650767227318 +5,1.5132187177801653,0.0004029305506295442,0.002539796827865367,0.00016294473192396232,1.4026292590363734,1.2873766294657503,0.7597034292867573 +6,1.3519530623399536,0.0004924305506295441,0.002103828529536188,0.0001614527620979248,1.401912845204582,1.1793127595774324,0.8009738605333503 +7,1.25479015516009,0.0005819305506295443,0.001785655289001306,0.00016351977969205627,1.4024429413442112,1.1129819174030562,0.8267581637524095 +8,1.1922219232876268,0.0006714305506295441,0.0015079263170800117,0.0001609705061259504,1.4035697604947317,1.0737590697256865,0.8399588944684011 +9,1.1472624949992507,0.0007609305506295441,0.001225322316369514,0.00016033410238880968,1.4057343626073064,1.044727747393977,0.8495674055206269 +10,1.1142254288606,0.0008504305506295442,0.001186105246629399,0.00016261171876712045,1.4008711383573487,1.0222655506893417,0.8558513031457946 +11,1.0895754895348881,0.0009399305506295442,0.0009914499336769814,0.00016072635164934168,1.4046067336979182,1.0065952455139087,0.8600586538435251 +12,1.0704047033337853,0.0010294305506295441,0.00099170191935035,0.00015949813587198624,1.401550751729768,0.9943637230554149,0.8626603778546567 +13,1.0554646643016596,0.0011189305506295443,0.0009327477763914865,0.00015892750562861968,1.402628368959762,0.9846632774943181,0.8648484503999012 +14,1.0447235770250047,0.0012083461937734932,0.0009000918423255129,0.00016042374446282784,1.4028836858293292,0.9761257460913432,0.8672133559305052 +15,1.0346703708788973,0.001226884932938532,0.0009003339275127778,0.00016009002775981216,1.3995296291960722,0.9688726946200338,0.8689243624210613 +16,1.0267084341831574,0.0011866314992028263,0.0008321055924001475,0.00015908587218668114,1.4024909173436313,0.9677657134821459,0.8700809793863281 +17,1.019625038383313,0.0011500979307215047,0.0007948264390597074,0.00015916730220005064,1.401834858441051,0.9605678781059857,0.8714311099663955 +18,1.0141310281945557,0.0011167443205193882,0.0007536613771535193,0.00015973220985384526,1.4037068425139203,0.9564912307117488,0.8732361011806404 +19,1.0094367362399579,0.0010861344888781837,0.0007046259527304415,0.0001595024139295528,1.404741957044628,0.9553548818481805,0.8735756948303715 +20,1.005795917356891,0.0010579116864878114,0.0007339583069559079,0.00015996828925281823,1.4018888469518656,0.9526368404547936,0.8746256770320996 +21,1.0024126872290828,0.0010317809114094956,0.0006802121978360684,0.0001600159983779228,1.404977890330475,0.9484000474991382,0.8756068277617928 +22,0.9997961211819545,0.0010074958119580654,0.0006588868321435337,0.00015996415421119518,1.4049558926376131,0.9486158177256406,0.8757613797162765 +23,0.9970575364904718,0.000984848832905516,0.0006645727059913642,0.00015943095777001232,1.4039062504825914,0.9472452921074658,0.8767357903089608 +24,0.9947260548356288,0.0009636636964983663,0.0006421841079280862,0.0001597148172697442,1.404769234292557,0.9467836295654576,0.8762329863271803 +25,0.9929426263094108,0.0009437895913492682,0.0006533913021071087,0.0001596590424712781,1.4036721070695497,0.9445289415770293,0.8769205856938459 +26,0.9902852152007833,0.0009250966288375309,0.0006406157603416679,0.00015863965303857344,1.4041152124188467,0.9438422452250971,0.8774224249447224 +27,0.9894310529841479,0.0009074722526913594,0.0006084443628530309,0.00016121161432747393,1.405645957371014,0.9427177184526265,0.8779914805392556 +28,0.9878986067256548,0.0008908183740754731,0.0006736192840911935,0.0001614488041405594,1.401401751794166,0.9450262602408502,0.876688380209966 +29,0.9861230346583366,0.0008750490650422323,0.0006381294100635538,0.00015970218715205076,1.4024826165842899,0.9397985968016129,0.8788702473203913 +30,0.984904093780553,0.0008600886861199685,0.0006372685065990919,0.0001613774281835083,1.4030536978537802,0.9410312402148875,0.878090406010259 +31,0.9840356792328665,0.0008458703546512784,0.0006021346357318911,0.0001605706860698069,1.4053780508650686,0.93935085472747,0.8791055225524248 +32,0.9826730397055682,0.0008323346829334197,0.0005823304500587645,0.0001606794189978958,1.4064917031702433,0.938491849330124,0.8794179758408484 +33,0.98122826053346,0.0008194287317307121,0.0006064169964513815,0.00016173489653437746,1.4047890704427903,0.9380129518409431,0.879240197747934 +34,0.9802193669765984,0.0008071051370188623,0.0005808260713740117,0.0001612021335770036,1.4060689621373341,0.9379100499113504,0.879441696739919 +35,0.9787922980850189,0.0007953213770578917,0.0005708499037352394,0.00016190691951809444,1.4067688519261177,0.9363963151943642,0.8798011448580104 +36,0.9782811491333006,0.0007840391538973042,0.0005769014314266837,0.00016085882788564467,1.4056476814971461,0.936760644701411,0.8795431334644276 +37,0.9772736436168701,0.0007732238687796371,0.0005832973052742414,0.00016045829807295758,1.4046969093787196,0.9376555224954546,0.8796276308461369 +38,0.9770215416888984,0.0007628441750461128,0.0005725108943584981,0.00016099308862733738,1.4057384599092764,0.9360900715020625,0.879673561838046 +39,0.9756368397423991,0.000752871595365571,0.0005559707684505406,0.0001603277607140926,1.406648222123834,0.9351901621459349,0.8802561870379751 +40,0.9751413950042691,0.0007432801926278408,0.0005547695420265696,0.00016000324316563854,1.4061540657012115,0.9340151400344601,0.88064655782419 +41,0.97440304141653,0.000734046285830236,0.0005643252025123361,0.00016169014372271616,1.4049013018702816,0.9342123068918176,0.8805653325500321 +42,0.9729020965537261,0.0007251482038634115,0.0005479739601498313,0.00016005783036396806,1.4063897297218668,0.931759201259109,0.8813146457038278 +43,0.9728501401977324,0.0007165660713628845,0.000556986277434747,0.00016188689391465576,1.4050576444558718,0.9324328991353797,0.8812305481450136 +44,0.9723529832859141,0.0007082816218046972,0.0005377808274945819,0.00016097323726913017,1.4072484159938206,0.9333999570023115,0.8809000429491162 +45,0.9719886419994488,0.0007002780338413631,0.0005383633963898524,0.0001607128213829118,1.406921492927382,0.9327568371917477,0.8815874234240247 +46,0.9714041228881873,0.0006925397875382193,0.0005625571440446523,0.00015980314363256965,1.405380014529551,0.9322668443652623,0.8814411515880348 +47,0.970610348001491,0.0006850525377121659,0.0005353043098314523,0.0001608413526149063,1.40642426170651,0.9328312055661379,0.8807152208646093 +48,0.9706983154843254,0.0006778030020191236,0.0005498443127791896,0.00015907193310048216,1.4042835558390903,0.9317374556359211,0.881439056282384 +49,0.9698482067851976,0.0006707788618025577,0.0005272698702980088,0.0001613644073448376,1.4065773929811722,0.9313596790572674,0.8817533461389689 +50,0.9694334085008423,0.0006639686740182302,0.0005441815864932084,0.0001605530094056493,1.4058948871701928,0.9331414569995484,0.8811511659481944 +51,0.9680341096141764,0.0006573617928019408,0.0005327422058059375,0.0001611970888049308,1.4055969615198605,0.931983706883828,0.8813677350411745 +52,0.9684330063989111,0.0006509482994567963,0.0005118473010676549,0.0001620338557496963,1.407656766488711,0.9303352337313147,0.881850661201475 +53,0.9674536738273793,0.00064471893981229,0.0005155931469441904,0.00016122026750181443,1.4070789203752996,0.9302679546210479,0.8814698635185422 +54,0.9670179014755121,0.0006386650680550355,0.000527732219503655,0.00016069333093590702,1.4060577862323278,0.9301213279926669,0.8818343084426494 +55,0.9665915346107252,0.0006327785962555395,0.0005411819122186439,0.0001618977847739391,1.4047343972358477,0.9298134822300878,0.8818645129216117 +56,0.9666591414705111,0.0006270519489206681,0.0005414311199074891,0.00016193343903902883,1.404715604887145,0.9304777703521254,0.8818602656206352 +57,0.9656796214831193,0.0006214780219909308,0.0005216608262620026,0.00016204013674022894,1.4055733990824628,0.9312712929584922,0.8816835705416821 +58,0.965396811238094,0.0006160501457777644,0.0005124046054663675,0.00016154121816875858,1.406454270010917,0.930074129010756,0.8822118673382615 +59,0.9647270455071215,0.0006107620514010679,0.0005156795736419375,0.0001629159564418541,1.4058346141076135,0.9272424344865784,0.8828849918774548 +60,0.9642003416180853,0.0006056078403428885,0.0005105201950090907,0.00015998582077176267,1.4066643899968077,0.9313646973603894,0.8812136183463107 +61,0.9642803896024262,0.0006005819567810015,0.0005086998668883783,0.00016053642345986005,1.4061261584382632,0.9289101153361238,0.8824897152781422 +62,0.9641022525830909,0.0005956791624072976,0.0005113302202976864,0.0001617686816556096,1.4063416488056804,0.9274695628611676,0.8831529028232932 +63,0.9639581878782693,0.0005908945134714726,0.0005063432820852787,0.0001606705261032827,1.406265670008377,0.9279763614485925,0.8827143268531888 +64,0.9633166362405163,0.0005862233398212821,0.0005103574893629557,0.00016167610722243154,1.4065365139521677,0.9281330930128997,0.8824612619332365 +65,0.9631607768965149,0.0005816612257373163,0.00048567149529904924,0.0001618550718025841,1.4085134918354654,0.9287472707778189,0.8826711446717195 +66,0.9624534650818913,0.0005772039923835236,0.0005266781845426047,0.00016065272079575157,1.4048806386562436,0.9279180096078679,0.8827938340899475 +67,0.9625179550652191,0.0005728476817148925,0.0005076845136342186,0.00016065300345629506,1.4066427832034403,0.9264803933950229,0.8831861637048568 +68,0.9626061551631744,0.000568588541701453,0.0005345516451454644,0.00016316358570906114,1.4038998737378527,0.9295619810379449,0.8821495450995587 +69,0.9619359595279592,0.0005644230127431929,0.0004890685033358201,0.00016004730502930468,1.4075415874332786,0.9281228732426234,0.882484765713522 +70,0.9619113687562704,0.000560347715164093,0.0004984748885404098,0.00016140608779786078,1.4065150771565662,0.9255331505936072,0.8835150098214116 +71,0.9609221718297878,0.0005563594376854221,0.0005017482284455109,0.00016133459473098002,1.4061128352707863,0.9265686588337662,0.8827577439640741 +72,0.9614163512744984,0.0005524551267889306,0.0005034327160068696,0.0001621562440388059,1.40553183802065,0.9257575495937299,0.883506084200634 +73,0.9615185281976273,0.0005486318768898732,0.0005066209782167629,0.00016057121909697115,1.405091339015289,0.9266013417541885,0.8831764292540393 +74,0.9611055893760567,0.0005448869212480005,0.0004923117734933015,0.00016036055004689404,1.4064319445576339,0.9260851400473347,0.8834177135035759 +75,0.9603052829616983,0.0005412176235518742,0.0004868813875215316,0.00016140201539225274,1.4081303513662586,0.923556973447025,0.884197972231468 +76,0.9605594956388787,0.0005376214701183853,0.0004985408705117827,0.0001608517885202452,1.406549091966662,0.9258895358496031,0.8835346236021012 +77,0.959772528797847,0.0005340960626549686,0.0004878692174383533,0.00016336637362247806,1.4069850278134644,0.926153135590115,0.8832060138763511 +78,0.9594906002400587,0.0005306391115372118,0.00048478779382161915,0.00016143601042398263,1.407210510332656,0.9272868177095538,0.8826093057168936 +79,0.9596738638291409,0.0005272484295589966,0.0004872439449234872,0.00016182809271428773,1.406374662174527,0.924371275076887,0.8838990080826314 +80,0.9594241714491198,0.0005239219261164537,0.000490582854068192,0.00016091357974749228,1.40617825722658,0.9255202852909855,0.8834132037914101 +81,0.9590088436465225,0.0005206576017905405,0.0004930388193300735,0.00016136517672093087,1.405927735573928,0.9262465255789861,0.8831508287464604 +82,0.9584710146436775,0.0005174535432964054,0.00048583404732192403,0.0001619984121048609,1.4073941562382832,0.9266610748343788,0.8828725477670782 +83,0.9583638625110028,0.0005143079187704905,0.0004876392208131134,0.00016014523404328683,1.4070821552026664,0.9258803154477695,0.8834713107398662 +84,0.9583963374740496,0.0005112189733690399,0.00047796502834217947,0.0001621950510424623,1.407824487199762,0.9266112016409898,0.8833126225314952 +85,0.9582789597977416,0.0005081850251539766,0.00048113546215863124,0.00015993235055279366,1.4075269542121722,0.9258751610012765,0.8832282858024698 +86,0.9579770721601767,0.0005052044612442199,0.0004815939937038694,0.00016198504880865792,1.407182078255493,0.9258178450786175,0.8832101122491121 +87,0.957293281308389,0.0005022757342124591,0.00046270400784121366,0.00016135346467332658,1.4082628742507206,0.9259440580193283,0.8833803513740275 +88,0.9572458179569743,0.0004993973587090738,0.0004769205974863142,0.00016243387524742854,1.4070656558592995,0.9253262337318272,0.883446551799848 +89,0.9578017566811509,0.0004965679082964922,0.0004673236424774068,0.00015935104379661994,1.4074245598446669,0.9241833540221301,0.8841845281529045 +90,0.9573275789139383,0.0004937860124786321,0.00048223035199441137,0.00016088709305279804,1.407315804442772,0.9260367002534197,0.8832554516014518 +91,0.9571121385635862,0.0004910503539114016,0.00046646010678797866,0.00016202235946808128,1.4075294101045905,0.9266329715492075,0.8832615000175407 +92,0.9567009212859927,0.0004883596657813364,0.000490694882841341,0.00016091351941818608,1.4057621759954326,0.9252251286112539,0.8836466002453506 +93,0.9568651943808989,0.0004857127293405453,0.0004720452460475542,0.0001605657218551018,1.4069797177770569,0.9248462126767452,0.883634721514261 +94,0.9567360336825723,0.00048310837158704845,0.00046939476021177996,0.0001624462885344979,1.4075247446342554,0.9255721140950327,0.8835512918441617 +95,0.9568219310337382,0.0004805454630805032,0.00045867466453639093,0.00016138251075011707,1.4083601579072047,0.9247205429758856,0.8836259406128427 +96,0.9567681763688523,0.0004780229158840507,0.00047166739121531383,0.0001618957678232813,1.4069841039370894,0.9240791889416559,0.8840402764756721 +97,0.9557935614839363,0.0004755396816237931,0.0004698630015360784,0.0001623824125748431,1.406891304330997,0.9256488771539327,0.8834573542738587 +98,0.9563671889442298,0.00047309474965801337,0.00046608671463209774,0.00016200740170380307,1.4076017961647926,0.9236526268367381,0.8840343457677377 +99,0.955638850525251,0.0004706871453488936,0.0004652841311331833,0.00016147983150274135,1.407085143973017,0.9231184553369839,0.8843178030541143 +100,0.9554437336872321,0.00046831592843001697,0.0004780153159682089,0.00016143548218767504,1.4063708057930175,0.9239073693308691,0.8841880832056118 +101,0.9554084360279976,0.00046598019146343084,0.00046149481717540786,0.00016288683850221357,1.4075245797835256,0.9242795119550579,0.8839252343250645 +102,0.9556602424986714,0.0004636790583805375,0.00045942149437606326,0.00016160468974462267,1.407743022342513,0.9242428847613762,0.8838529317886531 +103,0.9551032297805433,0.0004614116831014618,0.00045374950196888154,0.00016162178600536675,1.4075171733386838,0.9237546495504442,0.8839616625285626 +104,0.9553299100492001,0.0004591772482279871,0.0004587397295527341,0.00016205851162738632,1.4079648953990593,0.9240647583738919,0.8837451706592236 +105,0.9546899495488492,0.0004569749638054307,0.00045659488201553987,0.00016174980213154973,1.407376975580852,0.9235706569286001,0.8840627483011725 +106,0.9544802642890473,0.00045480406614924434,0.0004719503196233673,0.00016347574266986918,1.4059249994903562,0.9228893452028859,0.884366354885229 +107,0.9542432691326412,0.00045266381673234274,0.00045775318570868295,0.00016059631271700942,1.4073869926065246,0.9259136017968175,0.8836229540460269 +108,0.9537676089171658,0.0004505535011294995,0.00045376412435088305,0.00016068114477093157,1.4076858313807081,0.9250120085624536,0.8836238786437469 +109,0.9540436094659689,0.0004484724280153877,0.0004440063538444349,0.0001606197277628221,1.4087578404276995,0.9253760961901485,0.8836424365494364 +110,0.9541320233550449,0.00044641992821305214,0.0004616478865292083,0.0001612714075900508,1.4074965610078458,0.9222464604289112,0.8846883558953308 +111,0.9538292297917634,0.00044439535378986095,0.00045453056174149834,0.0001624034538398957,1.4069810818563153,0.9232811427374461,0.8841603960499924 +112,0.9536715056318043,0.00044239807719815183,0.0004573100685988113,0.00016166786641664288,1.4073570388718966,0.9230975994844722,0.8842747065459378 +113,0.9536340310694987,0.00044042749045798365,0.0004608032627503704,0.0001616518035614263,1.4071665085378027,0.9223201099845862,0.8844048689141598 +114,0.9533665384098139,0.00043848300437958907,0.0004546837482492609,0.0001614439190637089,1.4073116660370935,0.9227052892345605,0.8842502391689394 +115,0.9530202323025702,0.0004365640478232457,0.0004618193174917753,0.00016268604376261023,1.4070669935350784,0.9216014091415409,0.884780905367081 +116,0.9527986094283829,0.00043467006699447796,0.0004657242869943475,0.0001610531821169322,1.4065738824779719,0.923623617736844,0.8842568151711989 +117,0.9528068657157636,0.0004328005247726144,0.00045466219751326984,0.0001612622982196919,1.4075437911888597,0.9232674649949626,0.8843576953259186 +118,0.953157811020485,0.00043095490007082104,0.0004559852286456458,0.0001614546674393661,1.4071901883364628,0.9223717551653888,0.8843691065148216 +119,0.9523616642541651,0.0004291326872259206,0.0004578765316190044,0.00016235922340905045,1.4074416335975621,0.9239030172407268,0.884073860148124 +120,0.9522634249307752,0.00042733339541635044,0.00044376647891144555,0.00016120905654232992,1.4081832685985396,0.9229466442631489,0.8841774847284074 +121,0.9521266121681993,0.0004255565481067345,0.0004456636570484235,0.00016067676805720543,1.4081745891621855,0.9229012878901376,0.8842723270651229 +122,0.952887874096012,0.00042380168251766556,0.0004559332008234344,0.00016246140013850694,1.4075357855353678,0.9224544877005452,0.8845167496849148 +123,0.9521647338970807,0.00042206834911933045,0.0004437232522690949,0.00016195564607920427,1.4086228459212007,0.9235432399084504,0.8839649650335673 +124,0.9522419705283863,0.0004203561111477381,0.00044121460918971895,0.00016136937528576456,1.4081177906134705,0.9222030565819562,0.8846190112529904 +125,0.9525387810991112,0.00041866454414236616,0.0004392126620851222,0.0001604854059820666,1.4083203495731882,0.9208943570173624,0.8849948951118035 +126,0.9522093192612285,0.0004169932355041022,0.00045200141801617934,0.0001611474881786375,1.4075338702141014,0.9223405136043601,0.884262357645143 +127,0.9520150629239539,0.00041534178407245064,0.00043844089688624036,0.0001613290902141061,1.4085932994027495,0.924075324022727,0.8841470683459508 +128,0.9516388699420274,0.00041370979972101863,0.0004336480917987155,0.0001624520597423931,1.4090047192874962,0.9216666861882933,0.8846141279462727 +129,0.9520597924521966,0.00041209690297033433,0.00044541449137969777,0.00016213071498312812,1.4077301143044754,0.9209543418361329,0.8851785551459898 +130,0.95163916830203,0.00041050272461715684,0.000445462421553104,0.00016145863781747102,1.406889257786852,0.9220574689271427,0.8843452474297007 +131,0.9513483195466381,0.00040892690537943423,0.0004345294786853514,0.00016192529458269097,1.408352215760444,0.9215401842105863,0.8845204234841504 +132,0.9511493314606471,0.00040736909555614046,0.00044791207626259895,0.00016222211317227427,1.4076014266791426,0.9227861948677297,0.8842583427255565 +133,0.9515330251129602,0.0004058289547012542,0.0004442852225246813,0.00016003391192332685,1.4076118163791076,0.9216443021582579,0.8847823379034909 +134,0.9514896800191274,0.00040430615131120586,0.00044569561028190847,0.0001609678951090656,1.4070411985184201,0.9209758569900149,0.8848444815035693 +135,0.9515107162409485,0.0004028003625251184,0.0004448162410323374,0.00016244545545775487,1.40786553122418,0.921757862102273,0.8844379481299873 +136,0.9509154285709659,0.00040131127383724607,0.0004563567847377651,0.0001616843375593112,1.406860915788486,0.9219839296356954,0.8845701570450052 +137,0.9510137477579288,0.00039983857882101734,0.00044028814652150207,0.000161006061025124,1.4075929411951087,0.9231362695051157,0.8841410672110038 +138,0.9508700191814463,0.0003983819788641418,0.0004432381775915718,0.00016278079147032478,1.407582213476635,0.9231362043061699,0.8842660672728889 +139,0.9501712817535745,0.0003969411829142449,0.00044016345543917816,0.00016012294946220898,1.4073467665729618,0.9199148954216527,0.8850252930689801 +140,0.9510353979519042,0.0003955159072345627,0.00043089783428670127,0.00016033986507962635,1.4083270627408528,0.9199420701976355,0.8855117210026935 +141,0.9501059199380584,0.00039410587516920865,0.0004462111305191232,0.00016031636965388514,1.407080387003908,0.9217490378439567,0.8847285053630333 +142,0.9504181525545832,0.0003927108169175847,0.00043305693640657096,0.00016101536185231051,1.4078362762371088,0.9228202503905903,0.8844685692999189 +143,0.9502992693316371,0.00039133046931751157,0.0004366892727366043,0.00016122413212619436,1.4078733443413476,0.9227400302048031,0.8845641813718644 +144,0.9500974087359539,0.0003899645756366926,0.00042704942003660227,0.0001606718808510853,1.4085335862709152,0.9220280227016787,0.8844568505586007 +145,0.9493387907372282,0.00038861288537211675,0.000430457503136876,0.0001615469776656708,1.408459189469198,0.9217117118183041,0.8846401362714444 +146,0.9495934436922842,0.00038727515405707186,0.0004293898440528869,0.0001606869986654641,1.408379723403191,0.9219652991318481,0.8845702335737948 +147,0.9497423278203543,0.0003859511430754054,0.0004359801628621543,0.0001611694700485495,1.4075776755544278,0.9219960816294421,0.8846723337943349 +148,0.9489851430123729,0.00038464061948272635,0.0004411429925507107,0.00016193486618199887,1.4070092947189994,0.9188239233458232,0.8856537693379729 +149,0.9502747952849454,0.0003833433558342441,0.00044304324127926735,0.00016204694614450828,1.406634821295982,0.9203018863301728,0.8852403578570978 +150,0.9493786627829651,0.0003820591300189433,0.0004277212152202603,0.00016192297616294128,1.4087760017443385,0.9199064982536584,0.8848862195760482 +151,0.9494340366196902,0.0003807877250998383,0.000422105789380145,0.00016028999545560305,1.4086404923063618,0.9203878428591943,0.8851115653374586 +152,0.9495026330874894,0.0003795289291600299,0.00042589800713465373,0.00016222613074918935,1.408269987810401,0.9209813520160244,0.8848484039291026 +153,0.9487772502402799,0.0003782825351543267,0.0004338765630482906,0.0001602500079493755,1.4082486405549342,0.9205631821541003,0.8844670814074668 +154,0.9491505002545181,0.00037704834076619174,0.0004258030415331531,0.0001618308549093163,1.408642297967042,0.9215466784928968,0.8847807673329185 +155,0.9492397768438893,0.00037582614826978713,0.0004327517160104929,0.00016173089066844537,1.4074347031177739,0.9196274628811436,0.8851225717280731 +156,0.9488183015559496,0.0003746157643969125,0.0004207720205989703,0.00016184289628395684,1.4089057083055985,0.919783033728053,0.8852185947724793 +157,0.9488364781019367,0.00037341700020861624,0.00042877197374269216,0.00016034629973245068,1.408737668429561,0.9222874349311054,0.8845017836349387 +158,0.948765985746229,0.0003722296709713117,0.00042736026226639984,0.00016151669820944667,1.4078423732803356,0.9206917558834775,0.8850017211861347 +159,0.9488702257353385,0.00037105359603718525,0.0004363256610234033,0.00016073459174537378,1.4072753910489257,0.920128534327328,0.8849556710005507 +160,0.9487191037841476,0.00036988859872874423,0.00043233499130028804,0.00016061883577628592,1.4075491719524806,0.9210704502518666,0.8844702257836818 +161,0.9487022539131352,0.00036873450622732445,0.000429547165877106,0.00016157477545814021,1.4073260986504867,0.9210248951583792,0.8846819157468615 +162,0.9484681208714857,0.0003675911494653992,0.00042876908452619465,0.00016243050763006827,1.4078080728079085,0.920798666114177,0.8848861701542805 +163,0.9478366150235084,0.00036645836302254015,0.00044245197075443334,0.00016083906968298377,1.4065555013936075,0.9195279854076347,0.8853983714837497 +164,0.9485861345086279,0.00036533598502488425,0.00044215450497286484,0.00016202423251559985,1.4063392834057558,0.9190273576863366,0.8851719521644645 +165,0.9485360630050609,0.00036422385704796363,0.00043116811396109236,0.000161164774299257,1.4078143921290411,0.9223458298852816,0.8842073548836474 +166,0.9482355562973575,0.000363121824022779,0.00043102376579626033,0.00016213684639246917,1.4076709895709014,0.9216788738884721,0.884820295330694 +167,0.9480750055418291,0.0003620297341449645,0.00042540401918150355,0.0001602737824868775,1.4082063174847048,0.9210259349176473,0.8847742850866595 +168,0.9476546432218161,0.00036094743878695547,0.00042954080683317255,0.00016122876311020225,1.4079801256261542,0.9198383264336137,0.884880487984634 +169,0.9483333728139024,0.00035987479241301855,0.00042245407796601006,0.00016114180409335804,1.4080759462289092,0.9204119238276962,0.8849585214736787 +170,0.9484461045586203,0.00035881165249704555,0.00043557994397348533,0.00016102325540971022,1.4062636667216721,0.9205123466223968,0.8848800624607446 +171,0.9476891941357634,0.0003577578794430048,0.00042707847582822877,0.0001606926227763156,1.4080553648801895,0.9192463098671143,0.8852535461692127 +172,0.9479468331183856,0.0003567133365079405,0.000421894062343935,0.00016076703223283408,1.4078418355023912,0.9199114028374579,0.8850811272151591 +173,0.9481340206274174,0.00035567788972744275,0.0004306586556367725,0.0001605153184941577,1.4072308598712338,0.9203150730155257,0.8847576568741583 +174,0.9471227717383011,0.00035465140784347155,0.0004184443911153265,0.0001600829864425822,1.4082549543282135,0.9205152986858178,0.884966219608841 +175,0.9473668059938863,0.00035363376223446517,0.0004272187453974001,0.00016052597052036412,1.407504396987057,0.9207990737145845,0.8848677728771432 +176,0.9481894257463565,0.00035262482684764136,0.0004202225054265429,0.00016054869994135566,1.4074300906451174,0.9200694994069691,0.8849332420876408 +177,0.9473076463464432,0.0003516244781334122,0.00040890670467816793,0.00016180456144584106,1.408769031229465,0.919518162762469,0.8856281027563185 +178,0.9474003710536599,0.0003506325949818298,0.00040929805057024027,0.0001615975135736494,1.4090048340710513,0.9193939855816743,0.8855484495545778 +179,0.9470681672429161,0.0003496490586609978,0.0004207969018815671,0.0001609098116500023,1.4081515483588676,0.9196909071611906,0.8850347641901835 +180,0.9470686982764189,0.0003486737527573712,0.00041050545221438413,0.00016081934075778034,1.4081024006932135,0.9182524521211549,0.8855510993376553 +181,0.9467371571969282,0.00034770656311787436,0.00041749873562984445,0.00016116841676983877,1.4075702164167636,0.9178552068899906,0.8857209452127308 +182,0.9468788618331264,0.00034674737779378097,0.0004189781237541094,0.00015989472316451273,1.408039539401974,0.919424387296302,0.8856779322535459 +183,0.9463434660436512,0.00034579608698628736,0.0004124174387551863,0.00015987332968277538,1.4084173182172701,0.9181760862264021,0.8859685677003434 +184,0.9469233092583023,0.0003448525829937104,0.0004216960239227502,0.00016050993658772668,1.4079379484335812,0.9205931830670786,0.8847166950141477 +185,0.9470092530620144,0.0003439167601602754,0.0004140815371197182,0.0001606511605734296,1.4083964097758273,0.9218805851565733,0.8848695810365893 +186,0.9472513089078259,0.0003429885148264131,0.0004149341133317968,0.00016042780923855308,1.4081722704091337,0.9190765690523316,0.885574250440624 +187,0.9466261012492342,0.00034206774528053,0.0004161609367902696,0.00016112018703830948,1.4076779615826576,0.9178216331626474,0.8858536069183552 +188,0.946829375286211,0.00034115435171219423,0.00041048457046588007,0.00016136282316823227,1.4086018488747527,0.9186974200493885,0.8855210181732143 +189,0.9469002616518181,0.00034024823616668985,0.00041098129503026816,0.00015987225280465962,1.4081128343010882,0.9188492052820164,0.885608663405397 +190,0.9460525275016639,0.00033934930250088965,0.0004157813064663829,0.0001612916968680454,1.4078050159648166,0.9181218980361673,0.8854800467371684 +191,0.9460670108660649,0.0003384574563404049,0.0004136175331117413,0.00016277053122973044,1.408114074746304,0.9197219282888711,0.8855077011552672 +192,0.9464843952452684,0.0003375726050379705,0.00041227649921263824,0.00016092051619819294,1.4086313879697085,0.9192645804683233,0.8852140129320145 +193,0.9469640569598297,0.00033669465763301376,0.0004148582932990523,0.00016015017448114773,1.4083161627823246,0.9201574893239002,0.8848585957150978 +194,0.9466143487700124,0.0003358235248123809,0.00041335585651397516,0.00016081731883882673,1.4080380011828468,0.9177287485850232,0.8856796621034232 +195,0.9462016477508669,0.00033495911887217194,0.0004060755159160657,0.0001608767062980998,1.4085357260786826,0.9200365213357076,0.8850416669366751 +196,0.945958733273588,0.00033410135368064944,0.0004195622555394295,0.0001596503664072891,1.4079291068814441,0.9194489121529341,0.885345352222086 +197,0.9468209378531378,0.00033325014464219253,0.0004211609420473047,0.00016098481535018686,1.4068208260218749,0.9170885639761943,0.8861566808847771 +198,0.9462318950419663,0.0003324054086622485,0.00041093361073041924,0.00016105081099775914,1.4086777403258088,0.9196333745079832,0.8854984454437553 +199,0.9458133719653257,0.0003315670641132623,0.00041785837915752245,0.00016009119477700896,1.4079056559651788,0.9207885903620207,0.8847241807106594 +200,0.9460254597296691,0.0003307350308015427,0.0004096829547122694,0.0001600259104829325,1.4085450778111057,0.9191823085963929,0.8855784943553521 diff --git a/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/state_dict.pt b/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/state_dict.pt new file mode 100644 index 0000000..90331b2 --- /dev/null +++ b/src/sceptr/_model_saves/SCEPTR_dropout_noise_only/state_dict.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01adf95aabd30be0571df47badb47222fe0e3a55e8fdc006b1d50898f72ee94d +size 633058 diff --git a/src/sceptr/variant.py b/src/sceptr/variant.py index 0fad2e0..ef37a19 100644 --- a/src/sceptr/variant.py +++ b/src/sceptr/variant.py @@ -53,6 +53,10 @@ def olga(): return load_variant("SCEPTR_OLGA") +def dropout_noise_only(): + return load_variant("SCEPTR_dropout_noise_only") + + def finetuned(): return load_variant("SCEPTR_finetuned") diff --git a/tests/test_variants.py b/tests/test_variants.py index 325bc42..8de890e 100644 --- a/tests/test_variants.py +++ b/tests/test_variants.py @@ -27,6 +27,7 @@ def dummy_data(): variant.average_pooling(), variant.unpaired(), variant.olga(), + variant.dropout_noise_only(), variant.finetuned(), variant.a_sceptr(), variant.b_sceptr(),