Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpp backend logging and load model #1814

Closed
wants to merge 12 commits into from
6 changes: 4 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ set(CMAKE_CXX_STANDARD_REQUIRED True)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -W -Wall -Wextra")

# set(CMAKE_BUILD_TYPE Debug)
find_package(Boost REQUIRED)
find_package(folly REQUIRED)
find_package(fmt REQUIRED)
find_package(gflags REQUIRED)
find_package(glog REQUIRED)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

Expand All @@ -27,4 +27,6 @@ set(FOLLY_LIBRARIES Folly::folly)
add_subdirectory(src/utils)
add_subdirectory(src/backends)
add_subdirectory(src/examples)
add_subdirectory(test)
add_subdirectory(test)

FILE(COPY src/resources/logging.config DESTINATION "${CMAKE_INSTALL_PREFIX}/resources")
4 changes: 1 addition & 3 deletions cpp/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function install_dependencies_linux() {
flex \
bison \
libgflags-dev \
libgoogle-glog-dev \
google-mock \
libkrb5-dev \
libsasl2-dev \
libnuma-dev \
Expand Down Expand Up @@ -61,7 +61,6 @@ function install_dependencies_mac() {
boost \
double-conversion \
gflags \
glog \
gperf \
libevent \
lz4 \
Expand All @@ -75,7 +74,6 @@ function install_dependencies_mac() {
boost \
double-conversion \
gflags \
glog \
gperf \
libevent \
lz4 \
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ set(TS_BACKENDS_TORCH_DEPLOY_SRC_DIR "${TS_BACKENDS_SRC_DIR}/torch_deploy")
# build library TS_BACKENDS_protocol
set(TS_BACKENDS_PROTOCOL_SOURCE_FILES "")
list(APPEND TS_BACKENDS_PROTOCOL_SOURCE_FILES ${TS_BACKENDS_PROTOCOL_SRC_DIR}/otf_message.cc)
list(APPEND TS_BACKENDS_PROTOCOL_SOURCE_FILES ${TS_BACKENDS_PROTOCOL_SRC_DIR}/socket.cc)
add_library(ts_backends_protocol STATIC ${TS_BACKENDS_PROTOCOL_SOURCE_FILES})
target_include_directories(ts_backends_protocol PUBLIC ${TS_BACKENDS_PROTOCOL_SRC_DIR})
target_link_libraries(ts_backends_protocol PRIVATE ts_utils ${FOLLY_LIBRARIES})
Expand Down Expand Up @@ -48,4 +49,4 @@ target_include_directories(model_worker_socket PRIVATE
)
target_link_libraries(model_worker_socket
PRIVATE ts_backends_core ts_backends_protocol ts_backends_torch_scripted ${FOLLY_LIBRARIES})
install(TARGETS model_worker_socket DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/bin)
install(TARGETS model_worker_socket DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/bin)
64 changes: 35 additions & 29 deletions cpp/src/backends/process/model_worker.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "src/backends/process/model_worker.hh"
#include "src/backends/torch_scripted/torch_scripted_backend.hh"

namespace fs = std::experimental::filesystem;

namespace torchserve {
void SocketServer::Initialize(
const std::string& socket_type,
Expand All @@ -15,56 +17,57 @@ namespace torchserve {
if (socket_type == "unix") {
socket_family = AF_UNIX;
if (socket_name.empty()) {
LOG(FATAL) << "Wrong arguments passed. No socket name given.";
TS_LOG(FATAL, "Wrong arguments passed. No socket name given.");
}
std::filesystem::path s_name_path(socket_name);
if (std::remove(socket_name.c_str()) != 0 && std::filesystem::exists(s_name_path)) {
LOG(FATAL) << "socket already in use: " << socket_name;

fs::path s_name_path(socket_name);
if (std::remove(socket_name.c_str()) != 0 && fs::exists(s_name_path)) {
TS_LOGF(FATAL, "socket already in use: {}", socket_name);
}
socket_name_ = socket_name;
} else if (socket_type == "tcp") {
if (host_addr.empty()) {
socket_name_ = "127.0.0.1";
} else {
socket_name_ = host_addr;
if (port_num.empty())
LOG(FATAL) << "Wrong arguments passed. No socket port given.";
if (port_num.empty()) {
TS_LOG(FATAL, "Wrong arguments passed. No socket port given.");
}
port_ = htons(stoi(port_num));
}
} else {
LOG(FATAL) << "Incomplete data provided";
TS_LOG(FATAL, "Incomplete data provided");
}

LOG(INFO) << "Listening on port: " << socket_name;
TS_LOGF(INFO, "Listening on port: {}", socket_name);
server_socket_ = socket(socket_family, SOCK_STREAM, 0);
if (server_socket_ == -1) {
LOG(FATAL) << "Failed to create socket descriptor. errno: " << errno;
TS_LOGF(FATAL, "Failed to create socket descriptor. errno: {}", errno);
}

if (!CreateBackend(runtime_type, model_dir)) {
LOG(FATAL) << "Failed to create backend, model_dir: " << model_dir;
TS_LOGF(FATAL, "Failed to create backend, model_dir: {}", model_dir);
}
}

void SocketServer::Run() {
// TODO: Add sock accept timeout
int on = 1;
if (setsockopt(server_socket_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
LOG(FATAL) << "Failed to setsockopt. errno: " << errno;
TS_LOGF(FATAL, "Failed to setsockopt. errno: {}", errno);
}

sockaddr* srv_sock_address, client_sock_address{};
if (socket_type_ == "unix") {
LOG(INFO) << "Binding to unix socket";
TS_LOG(INFO, "Binding to unix socket");
sockaddr_un sock_addr{};
std::memset(&sock_addr, 0, sizeof(sock_addr));
sock_addr.sun_family = AF_UNIX;
std::strcpy(sock_addr.sun_path, socket_name_.c_str());
// TODO: Fix truncation of socket name to 14 chars when casting
srv_sock_address = reinterpret_cast<sockaddr*>(&sock_addr);
} else {
LOG(INFO) << "Binding to udp socket";
TS_LOG(INFO, "Binding to tcp socket");
sockaddr_in sock_addr{};
std::memset(&sock_addr, 0, sizeof(sock_addr));
sock_addr.sin_family = AF_INET;
Expand All @@ -74,22 +77,22 @@ namespace torchserve {
}

if (bind(server_socket_, srv_sock_address, sizeof(*srv_sock_address)) < 0) {
LOG(FATAL) << "Could not bind socket. errno: " << errno;
TS_LOGF(FATAL, "Could not bind socket. errno: {}", errno);
}
if (listen(server_socket_, 1) == -1) {
LOG(FATAL) << "Failed to listen on socket. errno: " << errno;
TS_LOGF(FATAL, "Failed to listen on socket. errno: {}", errno);
}
LOG(INFO) << "Socket bind successful";
LOG(INFO) << "[PID]" << getpid();
LOG(INFO) << "Torchserve worker started.";
TS_LOG(INFO, "Socket bind successful");
TS_LOGF(INFO, "[PID] {}", getpid());
TS_LOG(INFO, "Torchserve worker started.");

while (true) {
socklen_t len = sizeof(client_sock_address);
auto client_sock = accept(server_socket_, (sockaddr *)&client_sock_address, &len);
if (client_sock < 0) {
LOG(FATAL) << "Failed to accept client. errno: " << errno;
TS_LOGF(FATAL, "Failed to accept client. errno: {}", errno);
}
LOG(INFO) << "Connection accepted: " << socket_name_;
TS_LOGF(INFO, "Connection accepted: {}", socket_name_);
auto model_worker = std::make_unique<torchserve::SocketModelWorker>(client_sock, backend_);
model_worker->Run();
}
Expand All @@ -98,32 +101,35 @@ namespace torchserve {
bool SocketServer::CreateBackend(
const torchserve::Manifest::RuntimeType& runtime_type,
const std::string& model_dir) {
if (runtime_type == "LDP") {
if (runtime_type == "LSP") {
backend_ = std::make_shared<torchserve::torchscripted::Backend>();
return backend_->Initialize(model_dir);
}
return false;
}

[[noreturn]] void SocketModelWorker::Run() {
LOG(INFO) << "Handle connection";
TS_LOG(INFO, "Handle connection");
while (true) {
char cmd = torchserve::OTFMessage::RetrieveCmd(client_socket_);

if (cmd == 'I') {
LOG(INFO) << "INFER request received";
TS_LOG(INFO, "INFER request received");
auto model_instance = backend_->GetModelInstance();
if (!model_instance) {
LOG(ERROR) << "Model is not loaded yet, not able to process this inference request.";
TS_LOG(ERROR, "Model is not loaded yet, not able to process this inference request.");
} else {
//auto response = model_instance->Predict(torchserve::OTFMessage::RetrieveInferenceMsg(client_socket_));
}
} else if (cmd == 'L') {
LOG(INFO) << "LOAD request received";
TS_LOG(INFO, "LOAD request received");
// TODO: error handling
auto response = backend_->LoadModel(torchserve::OTFMessage::RetrieveLoadMsg(client_socket_));
auto backend_response = backend_->LoadModel(torchserve::OTFMessage::RetrieveLoadMsg(client_socket_));
if (!torchserve::OTFMessage::SendLoadModelResponse(client_socket_, std::move(backend_response))) {
TS_LOG(ERROR, "Error writing response to socket");
}
} else {
LOG(ERROR) << "Received unknown command: " << cmd;
TS_LOGF(ERROR, "Received unknown command: {}", cmd);
}
}
}
Expand Down
14 changes: 5 additions & 9 deletions cpp/src/backends/process/model_worker.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

#include <arpa/inet.h>
#include <cstdio>
#include <filesystem>
#include <glog/logging.h>
#include <experimental/filesystem>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
Expand All @@ -16,6 +15,7 @@
#include "src/backends/core/backend.hh"
#include "src/backends/protocol/otf_message.hh"
#include "src/utils/config.hh"
#include "src/utils/logging.hh"
#include "src/utils/model_archive.hh"

namespace torchserve {
Expand Down Expand Up @@ -51,7 +51,7 @@ namespace torchserve {
// TODO; impl.
//short MAX_FAILURE_THRESHOLD = 5;
//float SOCKET_ACCEPT_TIMEOUT = 30.0f;
Socket server_socket_;
int server_socket_;
std::string socket_type_;
std::string socket_name_;
ushort port_;
Expand All @@ -60,13 +60,9 @@ namespace torchserve {

class SocketModelWorker {
public:
SocketModelWorker(Socket client_socket, std::shared_ptr<torchserve::Backend> backend) :
SocketModelWorker(int client_socket, std::shared_ptr<torchserve::Backend> backend) :
client_socket_(client_socket), backend_(backend) {};
~SocketModelWorker() {
if (client_socket_ >= 0) {
close(client_socket_);
}
};
~SocketModelWorker() = default;

[[noreturn]] void Run();

Expand Down
9 changes: 3 additions & 6 deletions cpp/src/backends/process/model_worker_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@ DEFINE_string(runtime_type, "LSP", "model runtime type");
DEFINE_string(device_type, "cpu", "cpu, or gpu");
// TODO: discuss multiple backends support
DEFINE_string(model_dir, "", "model path");
// TODO: change to file based config
DEFINE_string(logger_config_path, "./_build/resources/logging.config", "Logging config file path");

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

// Init logging
google::InitGoogleLogging("ts_cpp_backend");
FLAGS_logtostderr = 1;
// TODO: Set logging format same as python worker
LOG(INFO) << "Initializing Libtorch backend worker...";
torchserve::Logger::InitLogger(FLAGS_logger_config_path);

torchserve::SocketServer server = torchserve::SocketServer::GetInstance();

Expand Down
16 changes: 16 additions & 0 deletions cpp/src/backends/protocol/isocket.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef TS_CPP_BACKENDS_PROTOCOL_ISOCKET_HH_
#define TS_CPP_BACKENDS_PROTOCOL_ISOCKET_HH_

#include <cstddef>

namespace torchserve {
// Abstract socket interface for the backend's on-the-wire (OTF) protocol.
// Concrete implementations (see socket.cc in this directory) provide the
// actual send/receive primitives; consumers program against this interface.
class ISocket {
public:
// Virtual destructor so implementations can be deleted through ISocket*.
virtual ~ISocket() {}
// Send `length` bytes starting at `data`. The bool return presumably
// signals whether the full write succeeded — confirm in the implementation.
virtual bool SendAll(size_t length, char *data) const = 0;
// Read exactly `length` bytes from the socket into `data`.
virtual void RetrieveBuffer(size_t length, char *data) const = 0;
// Read an integer value from the socket.
virtual int RetrieveInt() const = 0;
// Read a boolean value from the socket.
virtual bool RetrieveBool() const = 0;
};
} // namespace torchserve
#endif // TS_CPP_BACKENDS_PROTOCOL_ISOCKET_HH_
Loading