Skip to content

Commit

Permalink
Add mbart support in triton fastertransformer (#21)
Browse files Browse the repository at this point in the history
* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit
  • Loading branch information
sfc-gh-zhwang authored Oct 4, 2023
1 parent bf3fa27 commit 3336e68
Show file tree
Hide file tree
Showing 14 changed files with 422 additions and 92 deletions.
8 changes: 5 additions & 3 deletions examples/cpp/bart/bart_triton_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std
ft::FT_CHECK(false);
}

const size_t request_batch_size = reader.GetInteger("request", "request_batch_size");
const size_t request_batch_size = 1; //reader.GetInteger("request", "request_batch_size");

const int start_id = reader.GetInteger("decoder", "start_id");
const int end_id = reader.GetInteger("decoder", "end_id");
Expand All @@ -251,6 +251,7 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std

RequestParam param;
param.beam_width = reader.GetInteger("request", "beam_width");
// param.beam_width = 5;
param.request_output_len = reader.GetInteger("request", "request_output_len");
param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate");
param.runtime_top_k = reader.GetInteger("request", "top_k");
Expand All @@ -261,7 +262,7 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std
param.presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f);
param.min_length = reader.GetInteger("request", "min_length", 0);
param.random_seed = (unsigned long long int)0;
param.start_id = start_id;
param.start_id = 250025;
param.end_id = end_id;

auto request_list =
Expand Down Expand Up @@ -381,10 +382,11 @@ int main(int argc, char* argv[])
}

const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data;
const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0];
const int batch_size = 1; // output_tensors_lists[0].get()->at("output_ids").shape[0];
const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1];
const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2];
const int* d_input_lengths = (const int*)output_tensors_lists[0].get()->at("input_sequence_lengths").data;
printf("batch_size: %d beam_width: %d seq_len: %d\n", batch_size, beam_width, seq_len);
// step 6: check results
if (node_id == 0) {

Expand Down
2 changes: 1 addition & 1 deletion examples/cpp/bart/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repetition_penalty=1.0 ; Use for sampling
presence_penalty=0.0 ; Only one of repetition_penalty and presence_penalty are allowed.
len_penalty=0.0
beam_search_diversity_rate=0.0
request_batch_size=8 # determine by the request
request_batch_size=1 # determine by the request
request_output_len=32 # determine by the request

[encoder]
Expand Down
3 changes: 1 addition & 2 deletions examples/cpp/bart/start_ids.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
0, 4154, 1231, 15674, 345, 1534, 440, 50264, 11, 1854, 2
0, 4154, 1231, 15674, 345, 1534, 440, 50264, 11, 1854, 2
250004, 35378, 4, 765, 398, 49782, 111, 76935, 13034, 350, 32, 2
2 changes: 2 additions & 0 deletions examples/cpp/multi_gpu_gpt/gpt_example_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ int read_start_ids(size_t batch_size,
int i1 = 0;
std::vector<int> tmp_vec;
while (std::getline(lineStream, vals, ',')) {
printf("vals: %s\n", vals.c_str());
tmp_vec.push_back(std::stoi(vals));
i1++;
}
Expand Down Expand Up @@ -88,6 +89,7 @@ int read_start_ids(size_t batch_size,
for (int j = 0; j < (int)tmp_start_ids[i].size(); j++) {
v_start_ids->push_back(tmp_start_ids[i][j]);
}
printf("tmp_start_lengths[i]: %d\n", tmp_start_lengths[i]);
v_start_lengths->push_back(tmp_start_lengths[i]);
}
}
Expand Down
3 changes: 2 additions & 1 deletion examples/pytorch/bart/utils/ft_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch.nn as nn
import torch.distributed as dist
import numpy as np
from transformers import MBartForConditionalGeneration, BartModel

class FTBartEncoderWeight(object):
def __init__(
Expand Down Expand Up @@ -246,4 +247,4 @@ def __init__(self, encoder_weight_list, lib_path, head_num, head_size, inter_siz

def forward(self, input, seq_len, inputs_embeds=None):
output = self.encoder.forward(input, seq_len, inputs_embeds)
return output
return output
Loading

0 comments on commit 3336e68

Please sign in to comment.