Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Llava model #594

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions python/sglang/backend/runtime_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import json
from typing import Callable, List, Optional, Union
from typing import List, Optional

import numpy as np
import requests

from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglArgument, SglSamplingParams
from sglang.utils import encode_image_base64, find_printable_text, http_request
from sglang.lang.ir import SglSamplingParams
from sglang.utils import find_printable_text, http_request


class RuntimeEndpoint(BaseBackend):
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,9 +523,9 @@ def _execute_gen(self, expr: SglGen):
self, sampling_params=sampling_params
)

self.variables[name] = ""
self.stream_var_event[name].set()

self.variables[name] = ""
for comp, meta_info in generator:
self.text_ += comp
self.variables[name] += comp
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/controller/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
from typing import List, Union

import numpy as np
import torch
Expand Down Expand Up @@ -31,7 +31,7 @@ def __str__(self):


class FINISH_MATCHED_TOKEN(BaseFinishReason):
def __init__(self, matched: int | List[int]):
def __init__(self, matched: Union[int, List[int]]):
super().__init__()
self.matched = matched

Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
"""
class_name = config.architectures[0]
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
# We support non-hf version of llava models, so we do not want to
# read the wrong values from the unused default text_config.
return config

if hasattr(config, "text_config"):
# The code operates under the assumption that text_config should have
# `num_attention_heads` (among others). Assert here to fail early
Expand Down
12 changes: 6 additions & 6 deletions python/sglang/srt/models/gemma2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import Gemma2Config
from transformers import PretrainedConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size

Expand Down Expand Up @@ -131,7 +131,7 @@ class Gemma2Attention(nn.Module):
def __init__(
self,
layer_idx: int,
config: Gemma2Config,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
Expand Down Expand Up @@ -222,7 +222,7 @@ class Gemma2DecoderLayer(nn.Module):
def __init__(
self,
layer_idx: int,
config: Gemma2Config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
Expand Down Expand Up @@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):

def __init__(
self,
config: Gemma2Config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
Expand Down Expand Up @@ -369,7 +369,7 @@ class Gemma2ForCausalLM(nn.Module):

def __init__(
self,
config: Gemma2Config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
Expand Down