proposal: Move token decoding and stopping evaluation to router #138

Closed · wants to merge 4 commits
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

40 changes: 20 additions & 20 deletions launcher/tests/bloom_560m.json
@@ -19,121 +19,121 @@
"tokens": [
{
"id": 17,
"text": ".",
"text": "",
"logprob": -1.8267672,
"special": false
},
{
"id": 1587,
"text": "get",
"text": ".ge",
"logprob": -2.4674969,
"special": false
},
{
"id": 11,
"text": "(",
"text": "t",
"logprob": -1.906001,
"special": false
},
{
"id": 5,
"text": "\"",
"text": "(",
"logprob": -1.2279545,
"special": false
},
{
"id": 4899,
"text": "action",
"text": "\"actio",
"logprob": -4.170299,
"special": false
},
{
"id": 5,
"text": "\"",
"text": "n",
"logprob": -0.32478866,
"special": false
},
{
"id": 12,
"text": ")",
"text": "\"",
"logprob": -1.0773665,
"special": false
},
{
"id": 30,
"text": ";",
"text": ")",
"logprob": -0.27640742,
"special": false
},
{
"id": 837,
"text": "\n ",
"text": ";\n ",
"logprob": -1.6970354,
"special": false
},
{
"id": 1320,
"text": " if",
"text": " i",
"logprob": -1.4495516,
"special": false
},
{
"id": 375,
"text": " (",
"text": "f ",
"logprob": -0.23609057,
"special": false
},
{
"id": 4899,
"text": "action",
"text": "(actio",
"logprob": -1.1916996,
"special": false
},
{
"id": 3535,
"text": " ==",
"text": "n =",
"logprob": -0.8918753,
"special": false
},
{
"id": 5109,
"text": " null",
"text": "= nul",
"logprob": -0.3933342,
"special": false
},
{
"id": 12,
"text": ")",
"text": "l",
"logprob": -0.43212673,
"special": false
},
{
"id": 731,
"text": " {",
"text": ") ",
"logprob": -0.17702064,
"special": false
},
{
"id": 1260,
"text": "\n ",
"text": "{\n ",
"logprob": -0.07027565,
"special": false
},
{
"id": 10519,
"text": " throw",
"text": " thro",
"logprob": -1.3915029,
"special": false
},
{
"id": 2084,
"text": " new",
"text": "w ne",
"logprob": -0.04201372,
"special": false
},
{
"id": 150858,
"text": " RuntimeException",
"text": "w RuntimeException",
"logprob": -1.7329919,
"special": false
}
2 changes: 1 addition & 1 deletion launcher/tests/mt0_base.json
@@ -14,7 +14,7 @@
"tokens": [
{
"id": 259,
"text": " ",
"text": "",
"logprob": -1.3656927,
"special": false
},
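The fixture changes above replace the expected per-token "text" values: the new values no longer line up with token boundaries, which is consistent with the router deriving text from the accumulated token ids rather than receiving per-token text from the shard. A common way to do that is to decode the whole id prefix and attribute only the newly appended suffix to the latest token. The sketch below illustrates that idea only; the decode stand-in and all names are hypothetical, not this PR's implementation.

// Incremental detokenization sketch: each token's text is the difference
// between decoding all ids up to and including it and the text already
// emitted. `decode` is a stand-in for a real tokenizer decode call.
fn decode(ids: &[u32]) -> String {
    ids.iter()
        .map(|&id| match id {
            17 => ".",
            1587 => "get",
            11 => "(",
            5 => "\"",
            4899 => "action",
            _ => "?",
        })
        .collect()
}

fn main() {
    let ids = [17u32, 1587, 11, 5, 4899];
    let mut emitted = String::new();
    for i in 0..ids.len() {
        let full = decode(&ids[..=i]);
        // Whatever the full decode added is attributed to token i. Real code
        // must also respect UTF-8 / grapheme boundaries before slicing.
        let token_text = full[emitted.len()..].to_string();
        println!("token {} -> {:?}", ids[i], token_text);
        emitted = full;
    }
}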
67 changes: 32 additions & 35 deletions proto/generate.proto
@@ -7,6 +7,8 @@ service TextGenerationService {
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Returns some model metadata
rpc ModelInfo (ModelInfoRequest) returns (ModelInfoResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
@@ -27,6 +29,20 @@ message ClearCacheRequest {}
/// Empty response
message ClearCacheResponse {}

/// Empty request
message ModelInfoRequest {}

message ModelInfoResponse {
enum ModelType {
CAUSAL_LM = 0;
SEQ2SEQ_LM = 1;
}

ModelType model_type = 1;
uint32 eos_token = 2;
bool skip_special_tokens = 3;
}

message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
@@ -46,22 +62,15 @@ message NextTokenChooserParameters {
bool watermark = 8;
}

message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
}

message Request {
/// Request ID
uint64 id = 1;
/// The generation context
string inputs = 2;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 3;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 4;
/// Maximum number of generated tokens - for preallocation sizing
uint32 max_new_tokens = 4;
}

message Batch {
@@ -73,23 +82,6 @@ message Batch {
uint32 size = 3;
}

enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}

message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}

message PrefillTokens {
/// Prefill Token IDs
repeated uint32 ids = 1;
@@ -108,12 +100,8 @@ message Generation {
uint32 token_id = 3;
/// Logprob
float token_logprob = 4;
/// Text
string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
/// Complete generated text
GeneratedText generated_text = 7;
}

message PrefillRequest {
@@ -124,18 +112,27 @@ message PrefillRequest {
message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional Batch batch = 2;
}

message RequestsStatus {
/// Ids of finished requests, if any
repeated uint64 completed_ids = 3;
}

message CachedBatch {
uint64 batch_id = 1;
/// If absent, batch is finished
optional RequestsStatus status = 2;
}

message DecodeRequest {
/// Cached batches
repeated Batch batches = 1;
repeated CachedBatch batches = 1;
}

message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional Batch batch = 2;
/// Next batch (cached) - unset if batch is completed
optional uint64 batch_id = 2;
}
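
With StoppingCriteriaParameters, FinishReason and GeneratedText removed from the protocol, stopping evaluation moves to the router, which reports finished requests back to the shards through CachedBatch.status and receives the next batch_id (or nothing once the batch is done). The following is a rough sketch of that bookkeeping on the router side; the ActiveRequest struct and helper are hypothetical, only CachedBatch and RequestsStatus come from the generated proto types (crate name assumed).

use text_generation_client::{CachedBatch, RequestsStatus}; // re-exports assumed, as in router/client below

// Hypothetical per-request state kept by the router.
struct ActiveRequest {
    id: u64,
    generated_tokens: u32,
    max_new_tokens: u32,
}

// After each decode step, decide which requests are finished and build the
// CachedBatch the router sends back with the next DecodeRequest.
fn next_cached_batch(
    batch_id: u64,
    requests: &mut Vec<ActiveRequest>,
    last_tokens: &[(u64, u32)], // (request id, token id) from the latest step
    eos_token: u32,
) -> CachedBatch {
    let mut completed_ids = Vec::new();
    for (req_id, token_id) in last_tokens {
        if let Some(req) = requests.iter_mut().find(|r| r.id == *req_id) {
            req.generated_tokens += 1;
            // Stop on EOS or an exhausted token budget; stop-sequence matching
            // on the decoded text would be evaluated here as well.
            if *token_id == eos_token || req.generated_tokens >= req.max_new_tokens {
                completed_ids.push(req.id);
            }
        }
    }
    requests.retain(|r| !completed_ids.contains(&r.id));
    CachedBatch {
        batch_id,
        status: Some(RequestsStatus { completed_ids }),
    }
}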
1 change: 1 addition & 0 deletions router/Cargo.toml
@@ -37,6 +37,7 @@ tower-http = { version = "0.3.5", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
unicode-segmentation = "1.10.1"
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
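
The new unicode-segmentation dependency fits the router now producing the output text itself: when streaming decoded text, a common precaution is to emit only complete grapheme clusters and hold back the last one, since later tokens may still extend it. A small illustration of that idea (not this PR's code):

use unicode_segmentation::UnicodeSegmentation;

// Split a decoded buffer into the part that is safe to stream now and a tail
// (the final grapheme cluster) to hold back until more tokens arrive.
fn split_for_streaming(buffer: &str) -> (&str, &str) {
    match buffer.grapheme_indices(true).last() {
        Some((idx, _)) => buffer.split_at(idx),
        None => ("", ""),
    }
}

fn main() {
    // "i" followed by a combining diaeresis forms a single grapheme cluster.
    let (emit, hold) = split_for_streaming("nai\u{308}");
    assert_eq!((emit, hold), ("na", "i\u{308}"));
}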

26 changes: 19 additions & 7 deletions router/client/src/client.rs
@@ -1,10 +1,11 @@
/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use crate::{ClientError, Result};
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
use crate::pb::generate::v1::model_info_response::ModelType;

/// Text Generation Inference gRPC client
#[derive(Clone)]
@@ -67,23 +68,34 @@ impl Client {
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
pub async fn prefill(&mut self, batch: Batch) -> Result<Vec<Generation>> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
Ok(response.generations)
}

/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
#[instrument(skip_all, fields(size = _size))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
batches: Vec<CachedBatch>,
_size: u32,
) -> Result<(Vec<Generation>, Option<u64>)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
Ok((response.generations, response.batch_id))
}

/// Get shard model info
#[instrument(skip(self))]
pub async fn model_info(&mut self) -> Result<(ModelType, u32, bool)> {
let request = tonic::Request::new(ModelInfoRequest {}).inject_context();
let response = self.stub.model_info(request).await?.into_inner();
ModelType::from_i32(response.model_type)
.map(|mt| (mt, response.eos_token, response.skip_special_tokens))
.ok_or(ClientError::Generation("Unrecognized model type".to_string()))
}
}
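
Put together, one round trip with the new client surface might look roughly like the sketch below. It assumes a Client already connected to a shard and the proto types imported in this module; the prompt, parameters, and the plain budget check (standing in for EOS and stop-sequence evaluation) are illustrative only, not part of this PR.

// Rough usage sketch for the new API (illustrative only).
async fn generate_once(client: &mut Client) -> Result<()> {
    // 1. Fetch model metadata once, at router startup.
    let (_model_type, _eos_token, _skip_special_tokens) = client.model_info().await?;

    // 2. Prefill a single-request batch; per-token text no longer comes back,
    //    decoding now happens on the router side.
    let batch = Batch {
        id: 0,
        requests: vec![Request {
            id: 0,
            inputs: "Hello".to_string(),
            // Defaults for brevity; a real request would set these explicitly.
            parameters: Some(NextTokenChooserParameters::default()),
            max_new_tokens: 16,
        }],
        size: 1,
    };
    let _generations = client.prefill(batch).await?;

    // 3. Keep decoding; the router reports finished requests via
    //    CachedBatch.status and stops once the shard returns no batch_id.
    let mut batch_id = 0;
    let mut generated = 1u32;
    loop {
        let completed_ids = if generated >= 16 { vec![0] } else { vec![] };
        let cached = CachedBatch {
            batch_id,
            status: Some(RequestsStatus { completed_ids }),
        };
        let (_step_generations, next_id) = client.decode(vec![cached], 1).await?;
        generated += 1;
        match next_id {
            Some(id) => batch_id = id,
            None => break, // shard reports the batch as completed
        }
    }
    Ok(())
}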
3 changes: 1 addition & 2 deletions router/client/src/lib.rs
@@ -7,8 +7,7 @@ mod sharded_client;

pub use client::Client;
pub use pb::generate::v1::{
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
Request, StoppingCriteriaParameters,
Batch, CachedBatch, Generation, NextTokenChooserParameters, PrefillTokens, Request, RequestsStatus
};
pub use sharded_client::ShardedClient;
use thiserror::Error;