[WIP] [Formatter] Replace black with ruff #335

Closed

wants to merge 2 commits into from
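For context, a switch from black to ruff is normally driven by a small pyproject.toml section plus one command. The sketch below is a hypothetical illustration only: the PR's actual configuration is not visible in this diff, and every option value here is an assumption.

# pyproject.toml -- hypothetical sketch, not taken from this PR
[tool.ruff]
line-length = 88
# "I" enables isort-compatible import sorting, which re-sorts and re-wraps
# import blocks like those seen throughout the diff below (the exact wrap
# style depends on the isort settings in use)
select = ["E", "F", "I"]

Running something like ruff check --fix python/ would then rewrite the imports across the tree in a single pass.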
16 changes: 2 additions & 14 deletions python/sglang/api.py
@@ -3,22 +3,10 @@
import re
from typing import Callable, List, Optional, Union

from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI
from sglang.global_config import global_config
from sglang.lang.ir import (
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
)
from sglang.lang.ir import (SglExpr, SglExprList, SglFunction, SglGen,
SglImage, SglRoleBegin, SglRoleEnd, SglSelect)


def function(
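The change above is the pattern repeated in nearly every file of this PR: black's one-name-per-line parenthesized imports are collapsed into a hanging-indent style aligned under the opening parenthesis, with a backslash continuation for single long imports. The three forms side by side, using names taken from this diff:

# black style: one name per line, magic trailing comma
from sglang.lang.ir import (
    SglExpr,
    SglExprList,
    SglFunction,
)

# hanging-indent style, as produced in this PR
from sglang.lang.ir import (SglExpr, SglExprList, SglFunction, SglGen,
                            SglImage, SglRoleBegin, SglRoleEnd, SglSelect)

# a single import too long for one line gets a backslash continuation
from sglang.srt.layers.context_flashattention_nopad import \
    context_attention_fwd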
2 changes: 0 additions & 2 deletions python/sglang/backend/anthropic.py
@@ -1,6 +1,4 @@
from typing import List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
2 changes: 1 addition & 1 deletion python/sglang/backend/base_backend.py
@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union

from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
8 changes: 4 additions & 4 deletions python/sglang/backend/openai.py
@@ -1,17 +1,17 @@
import logging
import time
from typing import Callable, List, Optional, Union
from typing import List, Optional

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.chat_template import (ChatTemplate,
get_chat_template_by_model_path)
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
import tiktoken

import openai
import tiktoken
except ImportError as e:
openai = tiktoken = e

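The try/except around import openai and import tiktoken above is an optional-dependency guard: on ImportError the module names are rebound to the caught exception, so the failure surfaces only when the OpenAI backend is actually used. A minimal sketch of how such a guard is typically consumed; the ensure_openai helper is illustrative, not part of this PR:

try:
    import openai
    import tiktoken
except ImportError as e:
    openai = tiktoken = e

def ensure_openai():
    # Hypothetical helper: re-raise the stored ImportError on first use,
    # giving a clear error only to users who actually need this backend
    if isinstance(openai, Exception):
        raise openai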
7 changes: 3 additions & 4 deletions python/sglang/backend/runtime_endpoint.py
@@ -1,14 +1,13 @@
import json
from typing import Callable, List, Optional, Union
from typing import List, Optional

import numpy as np
import requests
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglArgument, SglSamplingParams
from sglang.utils import encode_image_base64, find_printable_text, http_request
from sglang.lang.ir import SglSamplingParams
from sglang.utils import find_printable_text, http_request


class RuntimeEndpoint(BaseBackend):
9 changes: 2 additions & 7 deletions python/sglang/backend/vertexai.py
@@ -1,20 +1,15 @@
import os
import warnings
from typing import List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
import vertexai
from vertexai.preview.generative_models import (
GenerationConfig,
GenerativeModel,
Image,
)
from vertexai.preview.generative_models import (GenerationConfig,
GenerativeModel, Image)
except ImportError as e:
GenerativeModel = e

4 changes: 2 additions & 2 deletions python/sglang/lang/chat_template.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple


class ChatTemplateStyle(Enum):
8 changes: 1 addition & 7 deletions python/sglang/lang/compiler.py
@@ -5,13 +5,7 @@

from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


def compile_func(function, backend):
22 changes: 5 additions & 17 deletions python/sglang/lang/interpreter.py
@@ -7,26 +7,14 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
from sglang.lang.ir import (SglCommitLazy, SglConcateAndAppend,
SglConstantText, SglExpr, SglExprList, SglGen,
SglImage, SglRoleBegin, SglRoleEnd, SglSelect,
SglVariable, SglVarScopeBegin, SglVarScopeEnd)
from sglang.utils import encode_image_base64


2 changes: 1 addition & 1 deletion python/sglang/lang/ir.py
@@ -472,4 +472,4 @@ def __init__(self):
super().__init__()

def __repr__(self):
return f"CommitLazy()"
return "CommitLazy()"
25 changes: 5 additions & 20 deletions python/sglang/lang/tracer.py
@@ -1,29 +1,14 @@
"""Tracing a program."""

import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglFunction,
SglGen,
SglGetForkItem,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
from sglang.lang.ir import (SglArgument, SglConstantText, SglExpr, SglExprList,
SglFork, SglGen, SglGetForkItem, SglRoleBegin,
SglRoleEnd, SglSelect, SglVariable,
SglVarScopeBegin, SglVarScopeEnd)


class StopTracing(Exception):
11 changes: 3 additions & 8 deletions python/sglang/srt/hf_transformers_utils.py
@@ -3,17 +3,12 @@
import json
import os
import warnings
from typing import List, Optional, Tuple, Union
from typing import Optional, Union

from huggingface_hub import snapshot_download
from sglang.srt.utils import is_multimodal_model
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
PreTrainedTokenizer, PreTrainedTokenizerFast)


def download_from_hf(model_path: str):
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/extend_attention.py
@@ -1,7 +1,8 @@
import torch
import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.context_flashattention_nopad import \
context_attention_fwd
from sglang.srt.utils import wrap_kernel_launcher

CUDA_CAPABILITY = torch.cuda.get_device_capability()
6 changes: 2 additions & 4 deletions python/sglang/srt/layers/logits_processor.py
@@ -1,10 +1,8 @@
import torch
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from sglang.srt.managers.router.model_runner import ForwardMode
from torch import nn
from vllm.model_executor.parallel_utils.communication_op import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather)


class LogitsProcessor(nn.Module):
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/radix_attention.py
@@ -1,5 +1,6 @@
import torch
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.context_flashattention_nopad import \
context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
@@ -15,7 +16,8 @@ def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
self.head_dim = head_dim
self.layer_id = layer_id

from sglang.srt.managers.router.model_runner import global_server_args_dict
from sglang.srt.managers.router.model_runner import \
global_server_args_dict

if global_server_args_dict.get("enable_flashinfer", False):
self.prefill_forward = self.prefill_forward_flashinfer
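Note the import of global_server_args_dict placed inside __init__ rather than at module top. A function-local import like this is resolved at call time, which is commonly done either to break an import cycle or to read a module-level name that is only (re)assigned at runtime; which reason applies here is not stated in the PR. A stripped-down sketch of the pattern:

def _flashinfer_enabled():
    # Local import: the name is looked up when this function runs, not when
    # the module is first imported, so it sees the dict's current binding
    from sglang.srt.managers.router.model_runner import \
        global_server_args_dict
    return global_server_args_dict.get("enable_flashinfer", False)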
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/detokenizer_manager.py
@@ -83,7 +83,7 @@ def start_detokenizer_process(
):
try:
manager = DetokenizerManager(server_args, port_args)
except Exception as e:
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
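The one-line change above drops an unused "as e" binding. An unused exception variable is exactly what ruff's pyflakes rules flag (F841, local variable assigned but never used; the rule code is my reading, not stated in the PR). A standalone illustration of the corrected shape:

import traceback

def start():
    raise RuntimeError("init failed")

try:
    start()
except Exception:
    # No "as e": the exception object itself is never referenced; the
    # traceback text is recovered via the traceback module instead
    print(traceback.format_exc())
    raise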
16 changes: 5 additions & 11 deletions python/sglang/srt/managers/router/model_rpc.py
@@ -13,23 +13,17 @@
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.io_struct import (BatchTokenIDOut, FlushCacheReq,
TokenizedGenerateReqInput)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_exception_traceback,
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
)
from sglang.srt.utils import (get_exception_traceback,
get_int_token_logit_bias, is_multimodal_model,
set_random_seed)
from vllm.logger import _default_handler as vllm_default_handler

logger = logging.getLogger("model_rpc")
9 changes: 4 additions & 5 deletions python/sglang/srt/managers/router/model_runner.py
@@ -16,7 +16,8 @@
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
from vllm.model_executor.parallel_utils.parallel_state import \
initialize_model_parallel

QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}

@@ -92,10 +93,8 @@ class InputMetadata:
decode_wrapper = None

def init_flashinfer_args(self, tp_size):
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper)

self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
2 changes: 0 additions & 2 deletions python/sglang/srt/managers/router/radix_cache.py
@@ -1,8 +1,6 @@
import heapq
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Tuple

import torch

21 changes: 7 additions & 14 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -10,23 +10,16 @@
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.managers.io_struct import (
BatchStrOut,
DetokenizeReqInput,
FlushCacheReq,
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.hf_transformers_utils import (get_config, get_context_length,
get_processor, get_tokenizer)
from sglang.srt.managers.io_struct import (BatchStrOut, DetokenizeReqInput,
FlushCacheReq, GenerateReqInput,
TokenizedGenerateReqInput)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
from sglang.srt.utils import (get_exception_traceback, is_multimodal_model,
load_image)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

24 changes: 10 additions & 14 deletions python/sglang/srt/models/gemma.py
@@ -12,21 +12,17 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import \
get_tensor_model_parallel_world_size
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)


class GemmaMLP(nn.Module):