From d31174a4e1ff7ac1efbdb5d89a24f0e477f95cc8 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 13 Sep 2024 00:21:51 +0200
Subject: [PATCH] [Hotfix][Pixtral] Fix multiple images bugs (#8415)

---
 tests/conftest.py                          |   2 +-
 tests/models/fixtures/pixtral_chat.pickle  | Bin 0 -> 20865 bytes
 .../fixtures/pixtral_chat_engine.pickle    | Bin 0 -> 20858 bytes
 tests/models/test_pixtral.py               | 188 ++++++++++++++----
 vllm/model_executor/models/pixtral.py      |  83 ++++----
 5 files changed, 196 insertions(+), 77 deletions(-)
 create mode 100644 tests/models/fixtures/pixtral_chat.pickle
 create mode 100644 tests/models/fixtures/pixtral_chat_engine.pickle

diff --git a/tests/conftest.py b/tests/conftest.py
index c850e60a9ca6c..620f8b4983517 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -658,8 +658,8 @@ def generate(
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs

+    @staticmethod
     def _final_steps_generate_w_logprobs(
-            self,
             req_outputs: List[RequestOutput],
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
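The `@staticmethod` change above exists so the new Pixtral tests can post-process `RequestOutput`s that were collected outside the runner's own `generate` path. A minimal sketch of the resulting call pattern, assuming `VllmRunner` is the class behind the `vllm_runner` fixture in vLLM's test suite; the `finished` list and the import path are illustrative:

    # `finished` holds vllm.RequestOutput objects gathered from LLM.chat()
    # or LLMEngine.step(); no VllmRunner instance is required any more.
    from tests.conftest import VllmRunner

    logprobs = VllmRunner._final_steps_generate_w_logprobs(finished)
    # -> List[Tuple[token_ids, output_str, Optional[SampleLogprobs]]]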
diff --git a/tests/models/fixtures/pixtral_chat.pickle b/tests/models/fixtures/pixtral_chat.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..43d4c883c3a49c314e1b5f529f6505f36b511181
GIT binary patch
literal 20865
[20865 bytes of base85-encoded binary pickle data omitted]
literal 0
HcmV?d00001
zIx-w^Et8fl2pJpb8agfGIoIM%CZSS}K{*#so0pttAiz~E6w9OJlA?J|E`d#S%DN0a3~|L?yjI%Cc-~xQC!51FFCek)d}+@*+LO2!|58JdtttN zUtZQto7hS~6au%O3ptgFDsSTaZhC5hhN6yuA0N49!7YLeq6#+hx2|<5PQU3X4MZUc zQGolnD$TE)<-(e(cNfVXDuEO-mr^_F<8B5ln_3J?ITc>}F^kdY-uCt%{mL1wWDZ5Z zEeX6EgtNTl$L^?!=bqOvS_}A^S)n@Jj9S)M6iSGSHR@&j&e5)=+RR&m2GsoFor2?C zSnZkJG9M1*pr4VSo6Jk>&wI$BE6txKW;0(WFZEPNrQuuL(lO1kM7?r=sE4BCSBxvRpOH6MM9wj*a7V&eeb(0X&!5VC;)bF zI6F~O!Ws)QfYDNzjtedYzdyWvWLVP3xDxhn-0<9RzCB9=5#{a=+g3T!*H;7h2ZeAs dSkO2~JgPq^{|7y`q4EF# literal 0 HcmV?d00001 diff --git a/tests/models/fixtures/pixtral_chat_engine.pickle b/tests/models/fixtures/pixtral_chat_engine.pickle new file mode 100644 index 0000000000000000000000000000000000000000..19dbeaecc8dfffcddda1d66f83f24a1ed1ec16fc GIT binary patch literal 20858 zcmb_kdw5mVmA~)ghDQk%jC4e= zV(sS)ZAH-;AMM;u?WnC)tE1nHC}%_cs?}OS>(J_m5d~3H6jAX-X05%>-e<3HL%*4C z{Q8IW{Ib^XxA$86ti9Jhw-~uK`|W`G=dumc;eXD-n3Jk@-feK+k2vQ{bk3dOjBj;{ zQT)F*UE?es<D_T@!NEbB_OtrYFq4w26k3YqQ>k?n^6 z+?UO^r@KYIH@7_1mTuVqPSz+t>qcDB)wQZ6pT4{|-QAYnu&&~qY)4NnyKKWP8)jU) z0YK|2y1f5#ru*c(RsFzoT_l(4Ub*4y;JOj*={B(3zO;~CneN`OE)0w|^e%HI)^=9V zpLbTBDMS0tOTwRZ4m&&4IecB@{LV}sLE|DeXw+x-3Ic@z6r2l)*0=u<4A__kgNXp9 z;*4~zkjiwUtL)*NlaCKwnBs32xIXiCMq z4Hnk@;GJ!f{E`Y43TX1ZXKmPLg*_fUR8c9~(>0RUqv(A*6x>YRj1t*JMIn(wHS40b z!Exb6&o!zEi@F!eVV6Z z`1SwtV5Fj_x2FdOueE6#2i9cwCdeLz)H$hTxm0^LhXeD}-`W%>&DKa%5)ptz(1|QN zChgN8($SW}bW}RVw_4WRcW=HN+~T;Pj!L~IqQwX7aB@FCYkrY{s3a=-GKEe=^<-^x z0|tdMVxns`%m>2qEXRX)-<2R^sG(AHWwR?i;1^76WDyEv?J(}k(LZuZ9piaQlvUC71fG%f5`@p+J3wAg;a=C zp^)xYI5v8+Zi&XB4l{a61JCR?zEw()MbzMB_D$DVPTltn1SM2>uP#(^%3KR4$K3ZP zu!M7qD!?hc;t<6bWgkWH|=G(k8Dsa!7&@gNoBHHxb^dxULkjIo6(qN0$#QcVMwG2s<$tE`O0qb!r|+llcd zUk0QGl?FJ1yJp(Jy#|On;Oz(DtP=M+`>i1(TVtL*6f#UKgF2)mm+hqhoPW7xQr7JW zC&@R-fg)lhIzE@Z#OnoIi113RN?W3waGeIBfS^D#H?yp7y7saJnL+k3G?P>Bh&{~EF+mMO`XPaq zdecAHa2Zrj<04RaIt!y-P5MC8X^6MBPD!r#ds1 zI``GKevN-IwH{uv&J@>K9mleE?CDM_;xsMxD%&5=Z4nEsVit!QBM;s$BAv?RaV3Xx zSFJ~xE7D!7%a*xS^U%r_seD1Kf=I-QLqO$<<#HjHF0^%~!A4syo6qA5U36x1neRcl z3^AL|LD38^<=*zyBENFADVSNw>b}~!te=Jyrmu4}4EY#)EZu{zjFt3U#oa6wcQ!(X zR}D`RLQUHOL>y#UP;=r9f)F@uUVYuVu$I4IYai1noVCbesMm1bu;gXQvT6Uo0ZHbN z@{{1MkrA3ELrHqBD4&dkVl0lP7c~MYhguZp8O%tt&=BCvgb3631l0bp{Xvk z}`FYl00j&8LfGTd-=zt1)#8jC_hat&P&1HO%Z~iYKa=E z6bvr=2B2`VP)keeQuU&m6SjD} zJJ)XPCxWQM)F`lQSySR3Sg^zMi87*UnZZEFv#_fa`#coI6m4g}T4Gs}!%G^2?4zP9 zLpvx|pvn;jh4k$LRY_>Dw<_*|3-lU~62h<^A~jxMeQGc$z$>se*wv1E$UrET^4VTA z)cW}iJV>I=uO(>2P{A;fzOn~)fq7VN(e1J0Y_)NNMIA71+2y($bi3tR%ZC zqHs+|_ZxR}_v5ZxQz95DiS#aCuJ&brU^8*F0il)cPDvVN|JLe|u zuIXA0Qz-$gtU`Qq(X!}Pz3`GCW)w>JLc$^W#y2dKJf;+rgi`$RwS}OzdK)yYs6jLr z>D3e^4Xpf|9?&S15EHAl$0c{$#+nTbB~%D_zol22wB~{S=1&Lu17>mc1O12ghS42( zp$H&9tFe|*QhD_h?vNe~Rha6DCr#9VP()Zz((rVsG?=pKz;-9@_UU>Th|^&2PT0Ol zhIUO#`Xao2Q|mD$KL6pgMxh26yj0gJIte&1+aBbIu|{)|>ViWkoY?QMgs6M?M>r<1 z9+HxtR3SrqCzaTX?UN(7e%h82fmJ}Mf(A|2u{T*$V3;5j22|L~!vW@Yn$#Gj0-R!T zgtq?9_9%}vwopYBm=S8kMqBHwqPB)2BI$MoQUcfzyT9J@Zyq%iptq)c*BS4tD*5!; zu5%>I>Cqd%pCBM=fIw0ETwY{1oT4$vpKY;O=`S0MF_6$Hy_$%xna{rarRNaYm)(?p zdD+q=al)~;H+_5yR=<5oKjDzOy_0nK?r`i94O9~#%I^**7;`uQW{<08fbU8l-Fdbk z$CM0o;ENN_{YD=gl_VSri5kCL#;EN#de#n;!MCps21qjWCY*0yeH6*?wy%^4Y*)Jb zt7c19A-n+>?gB?sWZ=!c(|oX z`QQmtAJI0Dw4E@D^LU$(JLQ(!Jcl?7!FfoT?-CmmbtR&N!%w?Jhwn(IN)18QFM){> zjPZ2iY)C|*euEB zq=sO?d$Ytxc~((F>`~^sr&;|P1AII8du;i(X6J9GO2Zb0PHk)E+4~Xqv((KN;fQcj zD{u;?t+yBF;6uMJk{zTR6^pyGeSm5`X(JEU>Ny&Ryd&Z}z3Rx4Z*8iLF&t=p#-PT! 
z`c%uS`&MsLiM(QHMAhaRx&7r93nc@S7WET~vU6DuPSg-~#|JhnY+xt>HVvt6I$GLo zS0OTPV5k6Ul{RQKfw9%>v_Jk6?UN+r{R6xiOD(VN?lq4AH@&Zno1+07yx)J!2G1Qh z?NtwpLWUbGZPWNMlMMXJsv;T2h{rs>kGH|hbJP)|qX-nTc9q3Xh7AmbREaJ)#Lc&* zdd!ZVS#=+-{kb5oI3q(^jo$P_4~9y(Rp%5od*be<1X;vssen1|>Po93;?JxUYJ5xM zXkL2KDRqojeY`EFsa>N`PE5_7mZFY=EU^30e?M0w*C?jSKc~TA;v2gC28x3G4c%(C zXZqTpD8$s+4jSXdX{0+X$uCOL#p!V;THIvC2O5JyV7s2Z0#4BI%vt5hRJ*WuUyhGX z`a+;MRmV)|cz%U#H1M-V4~HruBHOF-go%(L0~#daPL=^x21fId=`N^-S7VdyadeH{ z%BqAmWnU`SZng=m@8$mdU|T+w|mzUP!x=e|)S5NasfuA6kvhms(uK)36-%Er+L zlZ2u4FjK|mT{h-545z`&yPAV3uBFu77X7%N3}P@<>hnFEgr>z7PzJ@OA^~X<;KUC> ztXV884TvT|_1zK1vv;_Gp^8Cy_SP`lZEcNxULS^oMlh#zuz_*kXp#w{M;UGghLyVR z!S(&Zs8Guf8#wnX^Be`3fpg|8vq_Cy^dT6*8Atw=Vp+NyyCEo))%>zy%9S1v=fSJ_ z6?!=$vhK64q&LXQ%>}I$;+Q2A?+Pb`b z_eCsKM0k=Qj}%0h%6jmbUw7B-x-&?FU*50#=^yS&_!44pcI^56@Jw4pZ1=Z3zo;T2 zI#b_64OgtNK@|ps8YMdPi9d<1drvB;J#){ryLi0B|#um5rqyyuR5(_nIez6PNP7E z%9KYtY%7x&tP*4e6+}fxE~R`$zHg%%H)|BdP)AO^BvlHq0@b~D?QN3m0d-ZNaz628 zBa`$TSRbCbvVLxbpyxR3e&+H5b>eR{3R&xI0N-0TUbMDEmXNczw?@XSFSTj!z{iqc zYOHl%iWm6D6oVcJS!03E_Xu+Wx)@f$LiPPL7Ef~%8iJHVVmi*~ZhQYE=M^;s8HdCr z02!hmb$ADA*G%Vs-+_X^n4o%Aw~5w5J+!K}O>}Zm=c)^xkx}efEn4Bc6K$~IIoRx8jg1^sKTCBb<}xygtL97vtYK@*4i_ut#!mMmTI)N{Qvc^s*Y!+n!(-qbfKl{cTE{-#r*n zZ^DW(xX;z_$cdf>j49pca^9F_C+0W-N9BVcuCPaSc;0T>dHNTU_~^Xl z+r8<5prGf-7q)wYT4l{rKUWz*8&BQ}n^n{h;M(~GnBX*LC-rliCf`wRnfY9Qz!Yg^ z8)c!!CnQ0DDP)mql<_U#sM_x)38IG9>uwF|l=&bW+L$n%GTw9>vtHYxWkxmK)By8F z*>9_>i{um68NX3>SpRF-oR0wKMFy;Zy8p#*RaHNvQ78dl!rrYa-gm1Vd!^?UMc92e zhxgNJ|5*c32E1UsYgI}KZ5YIZ>so-fAoOH&|G-8R`ET+`)MPLF;rmsX-Wd`3%nV8ut;uEg4GuoeCF>SC~lX8 zL>WV$`N!nSe!@{)nKOT05&JOIqEEqjUto0qiC_r>8D&)BYC3dZV0`x85`ooZn6v!B zeSr}PLZOQCeSxV`V^BfGkbQyXfX1POq51+diyEb(Y3RPd{qm`#FF=n7?+c8`hlNj{ z^#xea$=Z{yEfNql_(8`joVx$iD4K+F-Gea(LzBR-=42gCZGrcYt3-d+X+7`^hoNunztC z)wU1MsxJ}@Rj?)kzq1yd@{b7uqL3qZ2ZkFgX&JIRFlu3mY@(XK)*X2J?@#m4D!9sY z2i&~|0&2=^4yty%;F;iQ!O#DP8x+#F&;JMAgsaJR^asoWrF*$5rdX3;)2`D$2OO?s zKz+sDU?2c%Hb1_fRtzQy0@PQ0N)WzDI5wam$lhxtuv=DgRzgF7x~dxaKE!_xBaZtJ zMV9TrHa$J?wKQbQwFvqHK2v?E=~BBjC0nG1D5fnAzN2^-!H3`=dtHRR?96Ajv4Y_d zYFVsb!@Q*1_2eO+FAPhWCc&W{0~X~WngoG;*LZd)IJ_o-a}_&qh!_;|jtj@TEC<{U zYbV4E76s^bf*Nw(DR_V7?>&AP+J2{ipY4Y(-fyY!&-U3Fw+y|aQ7FSd<8Hm=HGZNH zHs(-C1uUE$9qR01Q0F0L45A9Z^PqI_cEQeDS`z^s=P)?>CLs>a;C8{+XZ=DD7FV9b z+6DZiFnsrmMFOL!K_`V=zYj%OzWzZCh^VoMnV_59%#?fI!U#C^($+S`$yKK7JQyk& zL~c;K;4ce3DC)4=1-uc^I8Nh81h#S1;mi-%{lcc(H4LW);wF>Ez822RVeNulm;EN` ziy71|;D@MRJ@98qP!trSOP*kNP$M9;q|392Vsy>Q3z6t{jX@!PA#ydl5M5;~p^i$p zJ_VQO{d0ESE_ki>k3de(irp^YDn9$B?okg?!Mh6Ke?3W(IaFb~3cOuVvu$EOfk=jK z7x44oQR_8|WccU7Dpz>B;Q9PU&m`u2`F4T(`2&4L@{1bE9|*7fiD!pqfxhlxLm?;b z)ffsVb1Kd!Ysv+G^h{AWd6%{odEPE~={|ewGjzMaJ@oFS!WUppB6+(Y-rM4VP=fVs z8PBs328JpIC_T@f+QQGg%i{FOJ- zppd`!%H!K2A40scynOS zI{;%J)G&&EuLCfsoiNUzkh_1Zujn(s?z5xf9yc0xARQsBEWTa&e9%5ULoe zQ_v_i3S_7x;hlmh<~c=C`A)$x#uu{3euCdAVCm$F2Oqbnm%T026sRc~;))V0b>&p^ z94X@sf4(42mHO|Avi^#NZ=;)w8ikatMgcGJH$6M>wIDVaFJ6g1(e6FT8AV+<<46aO z$CobKO30RgMgV2KmEfCzIUfcR*HHlLndl~io95Y z?lzumaguQZLGl&iymYoleUpm(Rl~$WyDLw7)AOuy5t|Qa@>?uI8rF972h@_SB4jxt zHnA?yQ3inm!e9n5@*~TYoPO;io)s0?%eyjEUY^N1zs*67Qe(fo;5UA^*_A8t7A5BjUb!XDsPUh+xA96>yjP3rKs5v2H>xW7ue#8tzvlJf39^VY5Tjld3SdR2tp7ou z=a{AdI#~!}#f5CB&>&Rc6&F+Pu-r*uz7RnHacGLQWmkFM?ehEXO$K?G0zS4c-8;*p zh6>7$ZLl|L(bs3B@7NP5!v;g+PnYSQVJ8bC3j|pGt zAyG#RKj)w;^PJ(BGc*jPK%SnMynKM= z%dVOI zfh73@@~Z5bWP6ct9`8Lz*1;2S#`x*%oQHHd1dRWWh1tx2wxO08D$ID*HK#|HR4Op2 z>f#&Tx#wvJa8(n<>L@uksh?9w;F+rm2?ld7zcud`Sd)RF5-1_FDYcS5 zV^_e^N2#E~i$8iY8tg0H?lTL2S0t|}0`5uR?I0ZGxp&!KW06af1f#itOB@Q-VOP|$ z+MrNESgcYnNLQlk%Ir?MABy+NUx 
zFx=bLLfD(lUH|-JKo3Izg?KgDc2<6jUjMxGCBWz3cySAPRWpxuwHk&3LP7=7br;)1 zJlS!rN_eS}zHBB;>h`TzsQ`x{gx*sIqu)yN3JTmU?C!(%AcSg(|Ffd94Lu>*2 zgdu!WRgxUi(BtJb9Eej{*;!>!NFNp^xG=8E>o6hERw;`M-1-U-@U7yOAmnE?B$>4K=fGjAu)IIa1LFXSA-4w9`}0ougmUNIwcI zzE7SH?)L9V;I3oG3JmrpbwAmms}SI?-lTlIhV8q^GlSHv@q(L>cJ$Qr-S6S}=(DMm z4F-ku{TNe&&kVY6JnSTV{tB^fDQfUf=lNO{yYOX6D3oA- z>*4$ZE1HKJAPT^% List[Dict[str, Any]]: + return [{ + "role": + "user", + "content": [{ + "type": "text", + "text": PROMPT, + }] + [{ + "type": "image_url", + "image_url": { + "url": url + } + } for url in urls], + }] + + +def _create_engine_inputs(urls: List[str]) -> TokensPrompt: + msg = _create_msg_format(urls) + + tokenizer = MistralTokenizer.from_model("pixtral") + + request = ChatCompletionRequest(messages=msg) # type: ignore[type-var] + tokenized = tokenizer.encode_chat_completion(request) + + engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens) + + images = [] + for chunk in request.messages[0].content: + if isinstance(chunk, ImageURLChunk): + images.append(image_from_chunk(chunk)) + + mm_data = MultiModalDataBuiltins(image=images) + engine_inputs["multi_modal_data"] = mm_data + + return engine_inputs + + +MSGS = [ + _create_msg_format(IMG_URLS[:1]), + _create_msg_format(IMG_URLS[:2]), + _create_msg_format(IMG_URLS), +] +ENGINE_INPUTS = [ + _create_engine_inputs(IMG_URLS[:1]), + _create_engine_inputs(IMG_URLS[:2]), + _create_engine_inputs(IMG_URLS), +] + +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) +LIMIT_MM_PER_PROMPT = dict(image=4) + +MAX_MODEL_LEN = [8192, 65536] +FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle" +FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle" + + +def load_logprobs(filename: str) -> Any: + with open(filename, 'rb') as f: + return pickle.load(f) @pytest.mark.skip( @@ -16,49 +95,74 @@ "Model is too big, test passed on A100 locally but will OOM on CI machine." ) @pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models( +def test_chat( vllm_runner, - example_prompts, + max_model_len: int, model: str, dtype: str, - max_tokens: int, - num_logprobs: int, ) -> None: - image_urls = [ - "https://picsum.photos/id/237/200/300", - "https://picsum.photos/seed/picsum/200/300" - ] - expected = [ - "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa - "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa - ] - prompt = "Describe the image in one short sentence." 
-
-    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
-
-    with vllm_runner(model, dtype=dtype,
-                     tokenizer_mode="mistral") as vllm_model:
-
-        for i, image_url in enumerate(image_urls):
-            messages = [
-                {
-                    "role":
-                    "user",
-                    "content": [{
-                        "type": "text",
-                        "text": prompt
-                    }, {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": image_url
-                        }
-                    }]
-                },
-            ]
-
-            outputs = vllm_model.model.chat(messages,
-                                            sampling_params=sampling_params)
-            assert outputs[0].outputs[0].text == expected[i]
+    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            enable_chunked_prefill=False,
+            max_model_len=max_model_len,
+            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+    ) as vllm_model:
+        outputs = []
+        for msg in MSGS:
+            output = vllm_model.model.chat(msg,
+                                           sampling_params=SAMPLING_PARAMS)
+
+            outputs.extend(output)
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
+
+
+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on A100 locally but will OOM on CI machine."
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
+    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    args = EngineArgs(
+        model=model,
+        tokenizer_mode="mistral",
+        enable_chunked_prefill=False,
+        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+        dtype=dtype,
+    )
+    engine = LLMEngine.from_engine_args(args)
+
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
+
+    outputs = []
+    count = 0
+    while True:
+        out = engine.step()
+        count += 1
+        for request_output in out:
+            if request_output.finished:
+                outputs.append(request_output)
+
+        if count == 2:
+            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
+                               SAMPLING_PARAMS)
+        if not engine.has_unfinished_requests():
+            break
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
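The two tests above exercise both call paths touched by this fix: `test_chat` goes through the high-level `LLM.chat()` API, while `test_model_engine` drives `LLMEngine.step()` directly and injects a third multi-image request mid-flight. For orientation, a condensed usage sketch of the chat path with more than one image per prompt; the URLs come from the removed test above, and everything else is an illustration rather than code from this patch:

    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Pixtral-12B-2409",
              tokenizer_mode="mistral",
              limit_mm_per_prompt={"image": 4})  # mirrors LIMIT_MM_PER_PROMPT

    urls = [
        "https://picsum.photos/id/237/200/300",
        "https://picsum.photos/seed/picsum/200/300",
    ]
    messages = [{
        "role": "user",
        "content": [{"type": "text", "text": "Describe each image briefly."}]
        + [{"type": "image_url", "image_url": {"url": u}} for u in urls],
    }]
    out = llm.chat(messages, sampling_params=SamplingParams(max_tokens=64))
    print(out[0].outputs[0].text)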
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 010cf85f45e07..b26fd558fa1ea 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -1,4 +1,3 @@
-import math
 from array import array
 from dataclasses import dataclass, fields
 from itertools import tee
@@ -15,11 +14,12 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         tokenizer_mode=ctx.model_config.tokenizer_mode)
-    mm_encoder = tokenizer.instruct.mm_encoder
-    mm_config = ctx.model_config.multimodal_config
-    max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
+    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+    patch_size = mm_encoder.mm_config.image_patch_size
+    image_token_id = mm_encoder.special_ids.img

-    # approximate image size
-    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
+    mm_config = ctx.model_config.multimodal_config
+    num_images = mm_config.limit_per_prompt.get("image", 1)

+    # dummy size
+    size = 256
     image = Image.new("RGB", (size, size), color=0)

-    img_chunk = ImageChunk(image=image)
-    tokens = mm_encoder(img_chunk).tokens
-    token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                                   tokens)
+    image_feature_size = (size**2) // (patch_size**2)
+
+    num_image_tokens = image_feature_size * num_images
+
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * num_image_tokens
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - num_image_tokens)

     seq_data = SequenceData(token_ids)
-    mm_data = {"image": max_num_images_per_request * [image]}
+    mm_data = {"image": num_images * [image]}
     return seq_data, mm_data


@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})


-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                image_features: Optional[List[torch.Tensor]],
-                                image_id: int) -> torch.Tensor:
-    text_locations = input_ids != image_id
-    image_locations = input_ids == image_id
-
-    seq_len = input_ids.shape[0]
+def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is not None and "image" in multi_modal_data:
+        tokenizer = cached_get_tokenizer(
+            ctx.model_config.tokenizer,
+            tokenizer_mode=ctx.model_config.tokenizer_mode)

-    N_txt = text_locations.sum().item()
-    _, D_txt = inputs_embeds.shape
-    N_img, D_img = image_features.shape
+        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+        image_token_id = mm_encoder.special_ids.img

-    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
-                              "to image features dim {D_img}")
-    assert (seq_len == N_txt +
-            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
-                     f"{(N_txt, N_img, image_locations.sum().item())}")
+        if image_token_id not in llm_inputs['prompt_token_ids']:
+            raise ValueError(
+                (f"You've passed {llm_inputs=} without {image_token_id=}"
+                 " Make sure to process your input via mistral_common's"
+                 " tokenizer or pass a chat completion request. For more"
+                 " info, see: "
+                 "https://github.com/vllm-project/vllm/issues/8411."))

-    inputs_embeds[image_locations, :] = image_features
-    return inputs_embeds
+    return llm_inputs


 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

     def __init__(self,
@@ -201,11 +206,21 @@ def _parse_and_validate_image_input(
             return None

         if isinstance(images, torch.Tensor):
-            # always take last images
-            images = [images[-1][i] for i in range(images.size(1))]
+            # if passed as a batch, take all images
+            N, B, C, W, H = images.shape
+            images = images.reshape(N * B, C, W, H)
+            images = [images[i] for i in range(images.size(0))]
         elif isinstance(images, list):
-            # always take last images
-            images = [images[-1][i] for i in range(len(images[0]))]
+            # if passed as a list, flatten the lists of tensors
+            flatten_images = []
+            for imgs_per_req in images:
+                imgs_per_req = [
+                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
+                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
+
+                flatten_images.extend(imgs_per_req)
+
+            images = flatten_images

         return images
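The final hunk is the core of the multiple-images fix: the old code kept only the images of the last request in the batch (`images[-1]`), silently dropping every other request's images, while the new code flattens all images from all requests into a single list. A standalone PyTorch sketch of the two flattening branches, with illustrative shapes rather than code from the patch:

    import torch

    # Batched case: N requests with B images each, C x W x H pixels.
    batched = torch.zeros(2, 3, 3, 64, 64)        # N=2, B=3
    n, b, c, w, h = batched.shape
    flat = list(batched.reshape(n * b, c, w, h))  # 6 separate image tensors
    assert len(flat) == 6

    # List case: per-request entries may be stacked tensors or lists.
    per_request = [torch.zeros(2, 3, 64, 64), [torch.zeros(3, 64, 64)]]
    flat = []
    for imgs in per_request:
        flat.extend(list(imgs) if isinstance(imgs, torch.Tensor) else imgs)
    assert len(flat) == 3  # 2 images + 1 image across both requests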