Skip to content

Commit

Permalink
modify code based on the code review feedback and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jenrryyou committed Oct 1, 2024
1 parent 69f5380 commit c881856
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 26 deletions.
29 changes: 28 additions & 1 deletion example/streaming_batch_echo_c++/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,30 @@ DEFINE_string(server, "0.0.0.0:8001", "IP Address of server");
DEFINE_int32(timeout_ms, 100, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");

class StreamClientReceiver : public brpc::StreamInputHandler {
public:
virtual int on_received_messages(brpc::StreamId id,
butil::IOBuf *const messages[],
size_t size) {
std::ostringstream os;
for (size_t i = 0; i < size; ++i) {
os << "msg[" << i << "]=" << *messages[i];
}
LOG(INFO) << "Received from Stream=" << id << ": " << os.str();
return 0;
}
virtual void on_idle_timeout(brpc::StreamId id) {
LOG(INFO) << "Stream=" << id << " has no data transmission for a while";
}
virtual void on_closed(brpc::StreamId id) {
LOG(INFO) << "Stream=" << id << " is closed";
}

virtual void on_finished(brpc::StreamId id, int32_t finish_code) {
LOG(INFO) << "Stream=" << id << " is finished, code " << finish_code;
}
};

int main(int argc, char* argv[]) {
// Parse gflags. We recommend you to use gflags as well.
GFLAGS_NS::ParseCommandLineFlags(&argc, &argv, true);
Expand All @@ -51,9 +75,12 @@ int main(int argc, char* argv[]) {
// Normally, you should not call a Channel directly, but instead construct
// a stub Service wrapping it. stub can be shared by all threads as well.
example::EchoService_Stub stub(&channel);
StreamClientReceiver receiver;
brpc::Controller cntl;
brpc::StreamIds streams;
if (brpc::StreamCreate(streams, 3, cntl, NULL) != 0) {
brpc::StreamOptions stream_options;
stream_options.handler = &receiver;
if (brpc::StreamCreate(streams, 3, cntl, &stream_options) != 0) {
LOG(ERROR) << "Fail to create stream";
return -1;
}
Expand Down
15 changes: 10 additions & 5 deletions example/streaming_batch_echo_c++/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class StreamReceiver : public brpc::StreamInputHandler {
for (size_t i = 0; i < size; ++i) {
os << "msg[" << i << "]=" << *messages[i];
}
LOG(INFO) << "Received from Stream=" << id << ": " << os.str();
auto res = brpc::StreamWrite(id, *messages[0]);
LOG(INFO) << "Received from Stream=" << id << ": " << os.str() << " and write back result: " << res;
return 0;
}
virtual void on_idle_timeout(brpc::StreamId id) {
Expand All @@ -56,9 +57,7 @@ class StreamReceiver : public brpc::StreamInputHandler {
class StreamingBatchEchoService : public example::EchoService {
public:
virtual ~StreamingBatchEchoService() {
brpc::StreamClose(_sds[0]);
brpc::StreamClose(_sds[1]);
brpc::StreamClose(_sds[2]);
closeStreams();
};
virtual void Echo(google::protobuf::RpcController* controller,
const example::EchoRequest* /*request*/,
Expand All @@ -67,7 +66,7 @@ class StreamingBatchEchoService : public example::EchoService {
// This object helps you to call done->Run() in RAII style. If you need
// to process the request asynchronously, pass done_guard.release().
brpc::ClosureGuard done_guard(done);

closeStreams();
brpc::Controller* cntl =
static_cast<brpc::Controller*>(controller);
brpc::StreamOptions stream_options;
Expand All @@ -80,6 +79,12 @@ class StreamingBatchEchoService : public example::EchoService {
}

private:
void closeStreams() {
for(auto i = 0; i < _sds.size(); ++i) {
brpc::StreamClose(_sds[i]);
}
_sds.clear();
}
StreamReceiver _receiver;
brpc::StreamIds _sds;
};
Expand Down
30 changes: 21 additions & 9 deletions src/brpc/policy/baidu_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,10 @@ void SendRpcResponse(int64_t correlation_id,
if (attached_size > 0) {
meta.set_attachment_size(attached_size);
}
StreamId response_stream_id = INVALID_STREAM_ID;
SocketUniquePtr stream_ptr;
if (!response_stream_ids.empty()) {
StreamId response_stream_id = response_stream_ids[0];
response_stream_id = response_stream_ids[0];
if (Socket::Address(response_stream_id, &stream_ptr) == 0) {
Stream* s = (Stream *) stream_ptr->conn();
StreamSettings *stream_settings = meta.mutable_stream_settings();
Expand Down Expand Up @@ -309,24 +310,35 @@ void SendRpcResponse(int64_t correlation_id,
// Response_stream can be INVALID_STREAM_ID when error occurs.
if (SendStreamData(sock, &res_buf,
accessor.remote_stream_settings()->stream_id(),
accessor.response_streams()[0]) != 0) {
response_stream_id) != 0) {
const int errcode = errno;
std::string error_text = butil::string_printf(64, "Fail to write into %s",
sock->description().c_str());
PLOG_IF(WARNING, errcode != EPIPE) << error_text;
cntl->SetFailed(errcode, "%s", error_text.c_str());
if(stream_ptr) {
((Stream*)stream_ptr->conn())->Close(errcode, "%s",
error_text.c_str());
}
Stream::SetFailed(response_stream_ids, errcode, "%s",
error_text.c_str());
return;
}

// Now it's ok the mark these server-side streams as connected as all the
// written user data would follower the RPC response.
// Reuse stream_ptr to avoid address first stream id again
if(stream_ptr) {
// Now it's ok the mark this server-side stream as connected as all the
// written user data would follower the RPC response.
((Stream*)stream_ptr->conn())->SetConnected();
}
for (size_t i = 1; i < response_stream_ids.size(); ++i) {
StreamId extra_stream_id = response_stream_ids[i];
SocketUniquePtr extra_stream_ptr;
if (Socket::Address(extra_stream_id, &extra_stream_ptr) == 0) {
Stream* extra_stream = (Stream *) extra_stream_ptr->conn();
extra_stream->SetHostSocket(sock);
extra_stream->SetConnected();
} else {
LOG(WARNING) << "Stream=" << extra_stream_id
<< " was closed before sending response";
}
}
} else{
// Have the risk of unlimited pending responses, in which case, tell
// users to set max_concurrency.
Expand Down Expand Up @@ -722,7 +734,7 @@ void ProcessRpcResponse(InputMessageBase* msg_base) {
LOG_IF(ERROR, rc != EINVAL && rc != EPERM)
<< "Fail to lock correlation_id=" << cid << ": " << berror(rc);
if (remote_stream_id != INVALID_STREAM_ID) {
SendStreamRst(msg->socket(), meta.stream_settings().stream_id());
SendStreamRst(msg->socket(), remote_stream_id);
const auto & extra_stream_ids = meta.stream_settings().extra_stream_ids();
for (int i = 0; i < extra_stream_ids.size(); ++i) {
policy::SendStreamRst(msg->socket(), extra_stream_ids[i]);
Expand Down
43 changes: 32 additions & 11 deletions src/brpc/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,8 @@ int StreamCreate(StreamIds& request_streams, int request_stream_size, Controller
}
for (auto i = 0; i < request_stream_size; ++i) {
StreamId stream_id;
if (Stream::Create(opt, NULL, &stream_id) != 0) {
bool parse_rpc_response = i == 0; // Only the first stream need parse rpc response
if (Stream::Create(opt, NULL, &stream_id, parse_rpc_response) != 0) {
// Close already created streams
Stream::SetFailed(request_streams, 0 , "Fail to create stream at %d index", i);
LOG(ERROR) << "Fail to create stream";
Expand All @@ -821,8 +822,10 @@ int StreamAccept(StreamId* response_stream, Controller &cntl,
return res;
}
if(response_streams.size() != 1) {
Stream::SetFailed(response_streams, 0, "Logic error");
LOG(ERROR) << "accept more than one response_stream";
Stream::SetFailed(response_streams, EINVAL,
"misusing StreamAccept for single stream to accept multiple streams");
cntl._response_streams.clear();
LOG(ERROR) << "misusing StreamAccept for single stream to accept multiple streams";
return -1;
}
*response_stream = response_streams[0];
Expand All @@ -848,15 +851,33 @@ int StreamAccept(StreamIds& response_streams, Controller& cntl,
if (options != NULL) {
opt = *options;
}
for (auto i = 0; i <= cntl._remote_stream_settings->extra_stream_ids_size(); ++i) {
StreamId stream_id;
if (Stream::Create(opt, cntl._remote_stream_settings, &stream_id, false) != 0) {
Stream::SetFailed(response_streams, 0, "Fail to accept stream at %d index", i);
LOG(ERROR) << "Fail to accept stream";
return -1;
StreamId stream_id;
if (Stream::Create(opt, cntl._remote_stream_settings, &stream_id, false) != 0) {
Stream::SetFailed(response_streams, 0, "Fail to accept stream");
LOG(ERROR) << "Fail to accept stream";
return -1;
}

cntl._response_streams.push_back(stream_id);
response_streams.push_back(stream_id);
if(!cntl._remote_stream_settings->extra_stream_ids().empty()) {
StreamSettings stream_remote_settings;
stream_remote_settings.MergeFrom(*cntl._remote_stream_settings);
//Only the first stream needs extra_stream_ids settings
stream_remote_settings.clear_extra_stream_ids();
for (auto i = 0; i < cntl._remote_stream_settings->extra_stream_ids_size(); ++i) {
stream_remote_settings.set_stream_id(cntl._remote_stream_settings->extra_stream_ids()[i]);
StreamId extra_stream_id;
if (Stream::Create(opt, &stream_remote_settings, &extra_stream_id, false) != 0) {
Stream::SetFailed(response_streams, 0, "Fail to accept stream at %d index", i);
cntl._response_streams.clear();
response_streams.clear();
LOG(ERROR) << "Fail to accept stream";
return -1;
}
cntl._response_streams.push_back(extra_stream_id);
response_streams.push_back(extra_stream_id);
}
cntl._response_streams.push_back(stream_id);
response_streams.push_back(stream_id);
}

return 0;
Expand Down

0 comments on commit c881856

Please sign in to comment.