diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml index 2c04d646..48b47146 100644 --- a/.github/workflows/check.yaml +++ b/.github/workflows/check.yaml @@ -25,17 +25,23 @@ jobs: run: docker build . macos: - runs-on: macos-13 + runs-on: macos-latest steps: - name: depends - run: brew install spdlog poco + run: brew install openssl fmt spdlog - name: checkout uses: actions/checkout@v4 - name: build run: | - export CPATH=/usr/local/include - export LIBRARY_PATH=/usr/local/lib - cmake -B build -DCANDY_STATIC_POCO=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build + if [ "$RUNNER_ARCH" == "ARM64" ]; then + export CPATH=/opt/homebrew/include + export LIBRARY_PATH=/opt/homebrew/lib + else + export CPATH=/usr/local/include + export LIBRARY_PATH=/usr/local/lib + fi + cmake -B build -DCANDY_STATIC_POCO=1 -DCMAKE_BUILD_TYPE=Release + cmake --build build windows: runs-on: windows-latest diff --git a/CMakeLists.txt b/CMakeLists.txt index 16581429..ce324874 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,9 @@ set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) if (${CANDY_STATIC}) set(CMAKE_SKIP_BUILD_RPATH TRUE) +if (${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -static") +endif() set(CANDY_STATIC_OPENSSL 1) set(CANDY_STATIC_FMT 1) set(CANDY_STATIC_SPDLOG 1) diff --git a/src/cffi/CMakeLists.txt b/src/cffi/CMakeLists.txt index 7f9f7159..70277217 100644 --- a/src/cffi/CMakeLists.txt +++ b/src/cffi/CMakeLists.txt @@ -14,14 +14,16 @@ if (${CANDY_STATIC_SPDLOG}) target_link_libraries(${CANDY_LIBRARY_NAME} PRIVATE spdlog::spdlog) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED spdlog) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(${CANDY_LIBRARY_NAME} PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(SPDLOG REQUIRED spdlog) + add_definitions(${SPDLOG_CFLAGS}) + include_directories(${SPDLOG_INCLUDEDIR}) + target_link_libraries(${CANDY_LIBRARY_NAME} PRIVATE ${SPDLOG_LIBRARIES}) endif() if (${CANDY_STATIC_OPENSSL}) target_link_libraries(${CANDY_LIBRARY_NAME} PRIVATE ${OPENSSL_LIB_CRYPTO} ${OPENSSL_LIB_SSL}) +else() + find_package(OpenSSL REQUIRED) endif() if (${CANDY_STATIC_POCO}) diff --git a/src/cffi/candy.cc b/src/cffi/candy.cc index bbd4b202..49e778bb 100644 --- a/src/cffi/candy.cc +++ b/src/cffi/candy.cc @@ -18,82 +18,82 @@ void candy_client_release(void *candy) { delete c; } -int candy_client_set_name(void *candy, const char *name) { +void candy_client_set_name(void *candy, const char *name) { Candy::Client *c = static_cast(candy); - return c->setName(name); + c->setName(name); } -int candy_client_set_password(void *candy, const char *password) { +void candy_client_set_password(void *candy, const char *password) { Candy::Client *c = static_cast(candy); - return c->setPassword(password); + c->setPassword(password); } -int candy_client_set_websocket_server(void *candy, const char *server) { +void candy_client_set_websocket(void *candy, const char *server) { Candy::Client *c = static_cast(candy); - return c->setWebSocketServer(server); + c->setWebSocket(server); } -int candy_client_set_tun_address(void *candy, const char *cidr) { +void candy_client_set_tun_address(void *candy, const char *cidr) { Candy::Client *c = static_cast(candy); - return c->setTunAddress(cidr); + c->setTunAddress(cidr); } -int candy_client_set_expected_address(void *candy, const char *cidr) { +void candy_client_set_expt_tun_address(void *candy, const char *cidr) { Candy::Client *c = static_cast(candy); - return c->setExpectedAddress(cidr); + c->setExptTunAddress(cidr); } -int candy_client_set_virtual_mac(void *candy, const char *vmac) { +void candy_client_set_virtual_mac(void *candy, const char *vmac) { Candy::Client *c = static_cast(candy); - return c->setVirtualMac(vmac); + c->setVirtualMac(vmac); } -int candy_client_set_stun(void *candy, const char *stun) { +void candy_client_set_stun(void *candy, const char *stun) { Candy::Client *c = static_cast(candy); - return c->setStun(stun); + c->setStun(stun); } -int candy_client_set_discovery_interval(void *candy, int interval) { +void candy_client_set_discovery_interval(void *candy, int interval) { Candy::Client *c = static_cast(candy); - return c->setDiscoveryInterval(interval); + c->setDiscoveryInterval(interval); } -int candy_client_set_route_cost(void *candy, int cost) { +void candy_client_set_route_cost(void *candy, int cost) { Candy::Client *c = static_cast(candy); - return c->setRouteCost(cost); + c->setRouteCost(cost); } -int candy_client_set_mtu(void *candy, int mtu) { +void candy_client_set_mtu(void *candy, int mtu) { Candy::Client *c = static_cast(candy); - return c->setMtu(mtu); + c->setMtu(mtu); } -int candy_client_set_address_update_callback(void *candy, void (*callback)(const char *, const char *)) { +void candy_client_set_tun_update_callback(void *candy, void (*callback)(const char *, const char *)) { Candy::Client *c = static_cast(candy); - return c->setAddressUpdateCallback([=](const std::string &address) { + return c->setTunUpdateCallback([=](const std::string &address) { callback(c->getName().c_str(), address.c_str()); return 0; }); } -int candy_client_set_udp_bind_port(void *candy, int port) { +void candy_client_set_port(void *candy, int port) { Candy::Client *c = static_cast(candy); - return c->setUdpBindPort(port); + c->setPort(port); } -int candy_client_set_localhost(void *candy, const char *ip) { +void candy_client_set_localhost(void *candy, const char *ip) { Candy::Client *c = static_cast(candy); - return c->setLocalhost(ip); + c->setLocalhost(ip); } -int candy_client_run(void *candy) { +void candy_client_run(void *candy) { Candy::Client *c = static_cast(candy); - return c->run(); + c->run(); } -int candy_client_shutdown(void *candy) { +void candy_client_shutdown(void *candy) { Candy::Client *c = static_cast(candy); - return c->shutdown(); + c->shutdown(); } namespace { @@ -110,13 +110,12 @@ void shutdown(Client *c) { } } // namespace Candy -int candy_client_set_error_cb(void (*callback)(void *)) { +void candy_client_set_error_cb(void (*callback)(void *)) { client_error_cb = callback; - return 0; } void candy_use_system_time() { - Candy::Time::useSystemTime = true; + Candy::useSystemTime = true; } void candy_set_log_path(const char *path) { diff --git a/src/cffi/candy.h b/src/cffi/candy.h index 162abcd6..ec965f89 100644 --- a/src/cffi/candy.h +++ b/src/cffi/candy.h @@ -8,22 +8,22 @@ extern "C" { void candy_init(); void *candy_client_create(); -int candy_client_set_name(void *candy, const char *name); -int candy_client_set_password(void *candy, const char *password); -int candy_client_set_websocket_server(void *candy, const char *server); -int candy_client_set_tun_address(void *candy, const char *cidr); -int candy_client_set_expected_address(void *candy, const char *cidr); -int candy_client_set_virtual_mac(void *candy, const char *vmac); -int candy_client_set_stun(void *candy, const char *stun); -int candy_client_set_discovery_interval(void *candy, int interval); -int candy_client_set_route_cost(void *candy, int cost); -int candy_client_set_mtu(void *candy, int mtu); -int candy_client_set_udp_bind_port(void *candy, int port); -int candy_client_set_localhost(void *candy, const char *ip); -int candy_client_set_address_update_callback(void *candy, void (*callback)(const char *, const char *)); -int candy_client_set_error_cb(void (*callback)(void *)); -int candy_client_run(void *candy); -int candy_client_shutdown(void *candy); +void candy_client_set_name(void *candy, const char *name); +void candy_client_set_password(void *candy, const char *password); +void candy_client_set_websocket(void *candy, const char *server); +void candy_client_set_tun_address(void *candy, const char *cidr); +void candy_client_set_expt_tun_address(void *candy, const char *cidr); +void candy_client_set_virtual_mac(void *candy, const char *vmac); +void candy_client_set_stun(void *candy, const char *stun); +void candy_client_set_discovery_interval(void *candy, int interval); +void candy_client_set_route_cost(void *candy, int cost); +void candy_client_set_mtu(void *candy, int mtu); +void candy_client_set_port(void *candy, int port); +void candy_client_set_localhost(void *candy, const char *ip); +void candy_client_set_address_update_callback(void *candy, void (*callback)(const char *, const char *)); +void candy_client_set_error_cb(void (*callback)(void *)); +void candy_client_run(void *candy); +void candy_client_shutdown(void *candy); void candy_client_release(void *candy); void candy_use_system_time(); void candy_set_log_path(const char *path); diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 02885fca..616b90bf 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -7,24 +7,26 @@ if (${CANDY_STATIC_FMT}) target_link_libraries(core PRIVATE fmt::fmt) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED fmt) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(core PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(FMT REQUIRED fmt) + add_definitions(${FMT_CFLAGS}) + include_directories(${FMT_INCLUDEDIR}) + target_link_libraries(core PRIVATE ${FMT_LIBRARIES}) endif() if (${CANDY_STATIC_SPDLOG}) target_link_libraries(core PRIVATE spdlog::spdlog) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED spdlog) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(core PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(SPDLOG REQUIRED spdlog) + add_definitions(${SPDLOG_CFLAGS}) + include_directories(${SPDLOG_INCLUDEDIR}) + target_link_libraries(core PRIVATE ${SPDLOG_LIBRARIES}) endif() if (${CANDY_STATIC_OPENSSL}) target_link_libraries(core PRIVATE ${OPENSSL_LIB_CRYPTO} ${OPENSSL_LIB_SSL}) +else() + find_package(OpenSSL REQUIRED) endif() if (${CANDY_STATIC_POCO}) diff --git a/src/core/client.cc b/src/core/client.cc index 332cc2c1..fdf0827c 100644 --- a/src/core/client.cc +++ b/src/core/client.cc @@ -1,1598 +1,99 @@ // SPDX-License-Identifier: MIT #include "core/client.h" -#include "core/common.h" #include "core/message.h" -#include "utility/address.h" -#include "utility/time.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace { - -constexpr std::size_t AES_256_GCM_IV_LEN = 12; -constexpr std::size_t AES_256_GCM_TAG_LEN = 16; -constexpr std::size_t AES_256_GCM_KEY_LEN = 32; - -} // namespace +#include namespace Candy { -// Public -int Client::setName(const std::string &name) { - this->tunName = name; - return 0; -} - -std::string Client::getName() const { - return this->tunName; -} - -int Client::setWorkers(int number) { - number = std::min(number, int(Poco::Environment::processorCount())); - number = std::max(number, 0); - this->workers = number; - if (this->workers) { - spdlog::debug("workers: {}", this->workers); - } - return 0; -} - -int Client::setPassword(const std::string &password) { - this->password = password; - return 0; -} - -int Client::setWebSocketServer(const std::string &uri) { - try { - Poco::URI parser(uri); - if (parser.getScheme() != "ws" && parser.getScheme() != "wss") { - spdlog::critical("invalid websocket scheme {}", parser.getScheme()); - return -1; - } - this->wsUri = uri; - return 0; - } catch (std::exception &e) { - spdlog::critical("client websocket server parser failed: {}", e.what()); - return -1; - } -} - -int Client::setTunAddress(const std::string &cidr) { - this->tunAddress = cidr; - this->realAddress = cidr; - return 0; -} - -int Client::setExpectedAddress(const std::string &cidr) { - this->expectedAddress = cidr; - return 0; -} - -int Client::setVirtualMac(const std::string &vmac) { - if (vmac.length() != 16) { - Candy::shutdown(this); - return -1; - } - this->virtualMac = vmac; - return 0; -} -int Client::setStun(const std::string &stun) { - this->stun.uri = stun; - return 0; -} - -int Client::setDiscoveryInterval(int interval) { - this->discoveryInterval = interval; - return 0; -} - -int Client::setRouteCost(int cost) { - if (cost < 0) { - this->routeCost = 0; - } else if (cost > 1000) { - this->routeCost = 1000; - } else { - this->routeCost = cost; - } - return 0; -} - -int Client::setAddressUpdateCallback(std::function callback) { - this->addressUpdateCallback = callback; - return 0; -} - -int Client::setUdpBindPort(int port) { - if (port > 0 && port < UINT16_MAX) { - this->udpHolder.setPort(port); - } - return 0; -} - -int Client::setLocalhost(std::string ip) { - if (ip.empty()) { - return 0; +Msg MsgQueue::read() { + std::unique_lock lock(msgMutex); + if (!msgCondition.wait_for(lock, std::chrono::seconds(1), [this] { return !msgQueue.empty(); })) { + return Msg(MsgKind::TIMEOUT); } - Address addr; - if (addr.ipStrUpdate(ip)) { - return 0; - } - this->udpHolder.setIP(addr.getIp()); - return 0; -} - -int Client::setMtu(int mtu) { - this->mtu = mtu; - return 0; -} - -int Client::run() { - std::lock_guard lock(this->runningMutex); - this->running = true; - this->localP2PDisabled = false; - this->sysRtTable.clear(); - if (startWsThread()) { - spdlog::critical("start websocket client thread failed"); - Candy::shutdown(this); - return -1; - } - if (startTickThread()) { - spdlog::critical("start tick thread failed"); - Candy::shutdown(this); - return -1; - } - if (startWorkerThreads()) { - spdlog::critical("start worker threads failed"); - Candy::shutdown(this); - return -1; - } - return 0; + Msg msg = std::move(msgQueue.front()); + msgQueue.pop(); + return msg; } -int Client::shutdown() { - std::lock_guard lock(this->runningMutex); - - if (!this->running) { - return 0; - } - - this->running = false; - - if (this->wsThread.joinable()) { - this->wsThread.join(); - } - if (this->tunThread.joinable()) { - this->tunThread.join(); - } - if (this->udpThread.joinable()) { - this->udpThread.join(); - } - if (this->tickThread.joinable()) { - this->tickThread.join(); - } - stopWorkerThreads(); - - this->tun.down(); - this->ws.disconnect(); - return 0; -} - -// Common -int Client::startWorkerThreads() { - for (int i = 0; i < this->workers; ++i) { - this->udpMsgWorkerThreads.emplace_back([&] { - while (this->running) { - UdpMessage message; - { - static const auto timeout = std::chrono::seconds(1); - std::unique_lock lock(this->udpMsgQueueMutex); - if (!this->udpMsgQueueCondition.wait_for(lock, timeout, [this] { return !this->udpMsgQueue.empty(); })) { - continue; - } - message = std::move(this->udpMsgQueue.front()); - this->udpMsgQueue.pop(); - } - handleUdpMessage(std::move(message)); - } - }); - - this->tunMsgWorkerThreads.emplace_back([&] { - while (this->running) { - std::string message; - { - static const auto timeout = std::chrono::seconds(1); - std::unique_lock lock(this->tunMsgQueueMutex); - if (!this->tunMsgQueueCondition.wait_for(lock, timeout, [this] { return !this->tunMsgQueue.empty(); })) { - continue; - } - message = std::move(this->tunMsgQueue.front()); - this->tunMsgQueue.pop(); - } - handleTunMessage(std::move(message)); - } - }); - } - return 0; -} - -int Client::stopWorkerThreads() { - for (std::thread &t : this->udpMsgWorkerThreads) { - if (t.joinable()) { - t.join(); - } - } - { - std::unique_lock lock(this->udpMsgQueueMutex); - while (!this->udpMsgQueue.empty()) { - this->udpMsgQueue.pop(); - } - } - this->udpMsgWorkerThreads.clear(); - - for (std::thread &t : this->tunMsgWorkerThreads) { - if (t.joinable()) { - t.join(); - } - } +void MsgQueue::write(Msg msg) { { - std::unique_lock lock(this->tunMsgQueueMutex); - while (!this->tunMsgQueue.empty()) { - this->tunMsgQueue.pop(); - } - } - this->tunMsgWorkerThreads.clear(); - return 0; -} - -// WebSocket -std::string Client::hostName() { - char hostname[64] = {0}; - if (!gethostname(hostname, sizeof(hostname))) { - return std::string(hostname, strnlen(hostname, sizeof(hostname))); - } - return ""; -} - -int Client::startWsThread() { - if (this->ws.setTimeout(1)) { - spdlog::critical("websocket clinet set read write timeout failed"); - return -1; - } - - if (this->ws.connect(this->wsUri)) { - spdlog::critical("websocket client connect failed"); - return -1; - } - - ws.setPingMessage(fmt::format("candy::{}::{}::{}", CANDY_SYSTEM, CANDY_VERSION, hostName())); - - // 只需要开 wsThread, 执行过程中会设置 tun 并开 tunThread. - this->wsThread = std::thread([&] { - this->handleWebSocketMessage(); - spdlog::debug("websocket client thread exit"); - }); - - sendVirtualMacMessage(); - - if (!this->tunAddress.empty()) { - if (startTunThread()) { - spdlog::critical("start tun thread with static address failed"); - return -1; - } - if (startUdpThread()) { - spdlog::critical("start udp thread failed"); - return -1; - } - } else { - Address address; - if (this->expectedAddress.empty() || address.cidrUpdate(this->expectedAddress)) { - this->expectedAddress = "0.0.0.0/0"; - spdlog::debug("set default expected address"); - } - sendDynamicAddressMessage(); - } - return 0; -} - -void Client::handleWebSocketMessage() { - int error; - WebSocketMessage message; - - while (this->running) { - error = this->ws.read(message); - - if (error == 0) { - continue; - } - if (error < 0) { - spdlog::critical("webSocket client read failed: error {}", error); - Candy::shutdown(this); - break; - } - if (message.type == WebSocketMessageType::Message) { - uint8_t msgType = message.buffer.front(); - switch (msgType) { - // FORWARD, 拆包后转发给 TUN 设备 - case MessageType::FORWARD: - handleForwardMessage(message); - break; - - // 动态地址响应包,启动 TUN 设备并发送 Auth 包 - case MessageType::EXPECTED: - handleExpectedAddressMessage(message); - break; - - // 对端连接请求包 - case MessageType::PEER: - handlePeerConnMessage(message); - break; - - // 主动发现报文 - case MessageType::DISCOVERY: - handleDiscoveryMessage(message); - break; - - // 路由表 - case MessageType::ROUTE: - handleSysRtMessage(message); - break; - - // 通用报文 - case MessageType::GENERAL: - handleGeneralMessage(message); - break; - - default: - spdlog::debug("unknown websocket message: type {}", msgType); - break; - } - } - // 连接断开,可能是地址冲突,触发正常退出进程的流程 - if (message.type == WebSocketMessageType::Close) { - spdlog::info("client websocket close: {}", message.buffer); - Candy::shutdown(this); - break; - } - // 通信出现错误,触发正常退出进程的流程 - if (message.type == WebSocketMessageType::Error) { - spdlog::warn("client websocket error: {}", message.buffer); - Candy::shutdown(this); - break; - } - } - return; -} - -void Client::recvUdpMessage() { - while (this->running) { - UdpMessage message; - int error = this->udpHolder.read(message); - if (error == 0) { - continue; - } - if (error < 0) { - spdlog::critical("udp read failed: error {}", error); - Candy::shutdown(this); - break; - } - if (!this->workers) { - handleUdpMessage(std::move(message)); - continue; - } - { - std::unique_lock lock(this->udpMsgQueueMutex); - this->udpMsgQueue.emplace(std::move(message)); - } - this->udpMsgQueueCondition.notify_one(); - } - - this->udpHolder.reset(); -} - -void Client::handleUdpMessage(UdpMessage message) { - if (isStunResponse(message)) { - handleStunResponse(message.buffer); - return; - } - - message.buffer = decrypt(selfInfo.getKey(), message.buffer); - if (message.buffer.empty()) { - spdlog::debug("invalid peer message: ip {} port {}", Address::ipToStr(message.ip), message.port); - return; - } - - if (isHeartbeatMessage(message)) { - handleHeartbeatMessage(message); - return; - } - if (isPeerForwardMessage(message)) { - handlePeerForwardMessage(message); - return; - } - if (isDelayMessage(message)) { - if (routeCost) { - handleDelayMessage(message); - } - return; - } - if (isCandyRtMessage(message)) { - if (routeCost) { - handleCandyRtMessage(message); - } - return; - } - spdlog::debug("unknown peer message: type {}", int(message.buffer.front())); -} - -void Client::sendForwardMessage(const std::string &buffer) { - WebSocketMessage message; - message.buffer.push_back(MessageType::FORWARD); - message.buffer.append(buffer); - if (this->ws.write(message)) { - spdlog::critical("send forward message failed"); - Candy::shutdown(this); - } -} - -void Client::sendVirtualMacMessage() { - VMacMessage buffer(this->virtualMac); - buffer.updateHash(this->password); - - WebSocketMessage message; - message.buffer.assign((char *)(&buffer), sizeof(buffer)); - if (this->ws.write(message)) { - spdlog::critical("send virtual mac message failed"); - Candy::shutdown(this); - } - return; -} - -void Client::sendDynamicAddressMessage() { - Address address; - if (address.cidrUpdate(this->expectedAddress)) { - spdlog::critical("cannot send invalid expected address"); - Candy::shutdown(this); - return; - } - - ExpectedAddressMessage header(address.getCidr()); - header.updateHash(this->password); - - WebSocketMessage message; - message.buffer.assign((char *)(&header), sizeof(header)); - if (this->ws.write(message)) { - spdlog::critical("send expected address message failed"); - Candy::shutdown(this); - } - return; -} - -void Client::sendAuthMessage() { - Address address; - if (address.cidrUpdate(this->realAddress)) { - spdlog::critical("cannot send invalid auth address"); - Candy::shutdown(this); - return; - } - - AuthHeader header(address.getIp()); - header.updateHash(this->password); - - WebSocketMessage message; - message.buffer.assign((char *)(&header), sizeof(AuthHeader)); - if (this->ws.write(message)) { - spdlog::critical("send auth message failed"); - Candy::shutdown(this); - } - - this->ws.sendPingMessage(); - return; -} - -void Client::sendPeerConnMessage(const PeerInfo &peer) { - PeerConnMessage header; - header.src = Address::hostToNet(this->tun.getIP()); - header.dst = Address::hostToNet(peer.getTun()); - header.ip = Address::hostToNet(this->selfInfo.wide.ip); - header.port = Address::hostToNet(this->selfInfo.wide.port); - - WebSocketMessage message; - message.buffer.assign((char *)(&header), sizeof(PeerConnMessage)); - if (this->ws.write(message)) { - spdlog::critical("send peer conn message failed"); - Candy::shutdown(this); - } - return; -} - -void Client::sendDiscoveryMessage(uint32_t dst) { - DiscoveryMessage header; - - header.src = Address::hostToNet(this->tun.getIP()); - header.dst = Address::hostToNet(dst); - - WebSocketMessage message; - message.buffer.assign((char *)(&header), sizeof(DiscoveryMessage)); - if (this->ws.write(message)) { - spdlog::critical("send discovery conn message failed"); - Candy::shutdown(this); - } - return; -} - -void Client::sendLocalPeerConnMessage(const PeerInfo &peer) { - if (this->localP2PDisabled) { - return; - } - - LocalPeerConnMessage header; - header.ge.subtype = GeSubType::LOCAL_PEER_CONN; - header.ge.extra = 0; - header.ge.src = Address::hostToNet(this->tun.getIP()); - header.ge.dst = Address::hostToNet(peer.getTun()); - header.ip = Address::hostToNet(this->selfInfo.local.ip); - header.port = Address::hostToNet(this->selfInfo.local.port); - - WebSocketMessage message; - message.buffer.assign((char *)(&header), sizeof(LocalPeerConnMessage)); - if (this->ws.write(message)) { - spdlog::critical("send peer conn message failed"); - Candy::shutdown(this); + std::unique_lock lock(this->msgMutex); + msgQueue.push(std::move(msg)); } - return; + msgCondition.notify_one(); } -void Client::handleForwardMessage(WebSocketMessage &message) { - if (message.buffer.size() < sizeof(ForwardHeader)) { - spdlog::warn("invalid forward message: {:n}", spdlog::to_hex(message.buffer)); - return; - } - - const char *src = message.buffer.c_str() + sizeof(ForwardHeader::type); - const size_t len = message.buffer.length() - sizeof(ForwardHeader::type); - - const IPv4Header *header = (const IPv4Header *)src; - if (header->protocol == 0x04) { - this->tun.write(std::string(src + sizeof(IPv4Header), len - sizeof(IPv4Header))); - } else { - this->tun.write(std::string(src, len)); - } - - tryDirectConnection(Address::netToHost(header->saddr)); -} - -void Client::handleExpectedAddressMessage(WebSocketMessage &message) { - if (message.buffer.size() < sizeof(ExpectedAddressMessage)) { - spdlog::warn("invalid expected address message: len {}", message.buffer.length()); - spdlog::debug("expected address buffer: {:n}", spdlog::to_hex(message.buffer)); - return; - } - - ExpectedAddressMessage *header = (ExpectedAddressMessage *)message.buffer.c_str(); - - Address address; - if (address.cidrUpdate(header->cidr)) { - spdlog::warn("invalid expected address ip: cidr {}", header->cidr); - return; - } - - this->realAddress = address.getCidr(); - if (startTunThread()) { - spdlog::critical("start tun thread with expected address failed"); - Candy::shutdown(this); - return; - } - if (startUdpThread()) { - spdlog::critical("start udp thread failed"); - Candy::shutdown(this); - return; - } -} - -void Client::handlePeerConnMessage(WebSocketMessage &message) { - if (message.buffer.size() < sizeof(PeerConnMessage)) { - spdlog::warn("invalid peer conn message: {:n}", spdlog::to_hex(message.buffer)); - return; - } - PeerConnMessage *header = (PeerConnMessage *)message.buffer.c_str(); - - uint32_t src = Address::netToHost(header->src); - uint32_t dst = Address::netToHost(header->dst); - uint32_t ip = Address::netToHost(header->ip); - uint16_t port = Address::netToHost(header->port); - - if (dst != this->tun.getIP()) { - spdlog::warn("peer conn message dest not match: {:n}", spdlog::to_hex(message.buffer)); - return; - } - - if (src == this->tun.getIP()) { - spdlog::warn("peer conn message connect to self"); - return; - } - - std::unique_lock lock(this->ipPeerMutex); - PeerInfo &peer = this->ipPeerMap[src]; - - peer.wide.ip = ip; - peer.wide.port = port; - peer.count = 0; - peer.setTun(src, this->password); - - if (this->stun.uri.empty()) { - peer.updateState(PeerState::FAILED); - return; - } - - if (peer.getState() == PeerState::CONNECTED) { - return; - } - - if (peer.getState() == PeerState::SYNCHRONIZING) { - peer.updateState(PeerState::CONNECTING); - return; - } - - if (peer.getState() != PeerState::CONNECTING) { - peer.updateState(PeerState::PREPARING); - sendLocalPeerConnMessage(peer); - return; - } -} - -void Client::handleDiscoveryMessage(WebSocketMessage &message) { - if (message.buffer.size() < sizeof(DiscoveryMessage)) { - spdlog::warn("invalid discovery message: {:n}", spdlog::to_hex(message.buffer)); - return; - } - - DiscoveryMessage *header = (DiscoveryMessage *)message.buffer.c_str(); - - uint32_t src = Address::netToHost(header->src); - uint32_t dst = Address::netToHost(header->dst); - - // 收到广播后向发送方回包 - if (dst == BROADCAST_IP) { - sendDiscoveryMessage(src); - } - - // 接收方收到广播或发送方收到回包,同时尝试开始直连 - tryDirectConnection(src); -} - -void Client::handleSysRtMessage(WebSocketMessage &message) { - if (message.buffer.size() < sizeof(SysRouteMessage)) { - spdlog::warn("invalid system route message: {:n}", spdlog::to_hex(message.buffer)); - return; - } - this->localP2PDisabled = true; - SysRouteMessage *header = (SysRouteMessage *)message.buffer.c_str(); - SysRouteItem *rt = header->rtTable; - std::unique_lock lock(this->sysRtTableMutex); - for (uint8_t idx = 0; idx < header->size; ++idx) { - SysRouteEntry entry; - entry.dst = Address::netToHost(rt[idx].dest); - entry.mask = Address::netToHost(rt[idx].mask); - entry.next = Address::netToHost(rt[idx].nexthop); - - if (entry.next == this->tun.getIP()) { - continue; - } - - std::string dstStr = Address::ipToStr(entry.dst); - std::string maskStr = Address::ipToStr(entry.mask); - std::string nextStr = Address::ipToStr(entry.next); - spdlog::info("system route: dst={} mask={} next={}", dstStr, maskStr, nextStr); - - this->tun.setSysRtTable(entry.dst, entry.mask, entry.next); - sysRtTable.push_back(entry); - } -} - -void Client::handleGeneralMessage(WebSocketMessage &message) { - if (message.buffer.size() < sizeof(GeneralHeader)) { - spdlog::warn("invalid general message: {:n}", spdlog::to_hex(message.buffer)); - return; - } - GeneralHeader *header = (GeneralHeader *)message.buffer.c_str(); - switch (header->subtype) { - case GeSubType::LOCAL_PEER_CONN: - handleLocalPeerConnMessage(message); - break; - } -} - -void Client::handleLocalPeerConnMessage(WebSocketMessage &message) { - if (this->localP2PDisabled) { - return; - } - if (message.buffer.size() < sizeof(LocalPeerConnMessage)) { - spdlog::warn("invalid local peer conn message: {:n}", spdlog::to_hex(message.buffer)); - return; - } - LocalPeerConnMessage *header = (LocalPeerConnMessage *)message.buffer.c_str(); - - uint32_t src = Address::netToHost(header->ge.src); - uint32_t dst = Address::netToHost(header->ge.dst); - uint32_t ip = Address::netToHost(header->ip); - uint16_t port = Address::netToHost(header->port); - - if (dst != this->tun.getIP()) { - spdlog::warn("local peer conn message dest not match: {:n}", spdlog::to_hex(message.buffer)); - return; - } - - if (src == this->tun.getIP()) { - spdlog::warn("local peer conn message connect to self"); - return; - } - - std::unique_lock lock(this->ipPeerMutex); - PeerInfo &peer = this->ipPeerMap[src]; - - peer.local.ip = ip; - peer.local.port = port; - peer.setTun(src, this->password); - - if (this->stun.uri.empty()) { - peer.updateState(PeerState::FAILED); - return; - } - - if (peer.getState() == PeerState::INIT) { - peer.updateState(PeerState::PREPARING); - sendLocalPeerConnMessage(peer); - return; - } -} - -// 需要确保双方同时调用. -// 1. 收到对方报文时,一般会回包,此时调用 -// 2. 收到主动发现报文时,这时一定会回包 -void Client::tryDirectConnection(uint32_t ip) { - std::unique_lock lock(this->ipPeerMutex); - PeerInfo &peer = this->ipPeerMap[ip]; - peer.setTun(ip, this->password); - if (this->stun.uri.empty()) { - peer.updateState(PeerState::FAILED); - return; - } - if (peer.getState() == PeerState::INIT) { - peer.updateState(PeerState::PREPARING); - sendLocalPeerConnMessage(peer); - } -} - -// TUN -int Client::startTunThread() { - if (this->tun.setName(this->tunName)) { - return -1; - } - if (this->tun.setAddress(this->realAddress)) { - return -1; - } - if (this->tun.setMTU(this->mtu)) { - return -1; - } - if (this->tun.setTimeout(1)) { - return -1; - } - if (this->tun.up()) { - return -1; - } - - this->tunThread = std::thread([&] { - this->recvTunMessage(); - spdlog::debug("tun thread exit"); - }); - - sendAuthMessage(); - - if (addressUpdateCallback) { - int error = addressUpdateCallback(this->realAddress); - if (error) { - spdlog::critical("address update callback failed: {}", error); - Candy::shutdown(this); - } - } - - return 0; -} - -void Client::recvTunMessage() { - while (this->running) { - std::string buffer; - int error = this->tun.read(buffer); - if (error == 0) { - continue; - } - if (error < 0) { - spdlog::critical("tun read failed. error {}", error); - Candy::shutdown(this); - break; - } - if (!this->workers) { - handleTunMessage(std::move(buffer)); - continue; - } - { - std::unique_lock lock(this->tunMsgQueueMutex); - this->tunMsgQueue.emplace(std::move(buffer)); - } - this->tunMsgQueueCondition.notify_one(); - } - return; -} - -void Client::handleTunMessage(std::string buffer) { - if (buffer.length() < sizeof(IPv4Header)) { - return; - } - - // 仅处理 IPv4 - IPv4Header *header = (IPv4Header *)buffer.data(); - if ((header->version_ihl >> 4) != 4) { - return; - } - // 存在路由表项时封装成简单的 IPIP 协议 - uint32_t nextHop = [&]() { - uint32_t daddr = Address::netToHost(header->daddr); - std::shared_lock lock(this->sysRtTableMutex); - for (auto const &rt : sysRtTable) { - if ((daddr & rt.mask) == rt.dst) { - return rt.next; - } - } - if (Address::netToHost(header->saddr) != this->tun.getIP()) { - return daddr; - } - return uint32_t(0); - }(); - if (nextHop) { - buffer = std::string(sizeof(IPv4Header), 0) + buffer; - header = (IPv4Header *)buffer.data(); - header->protocol = 0x04; - header->saddr = Address::hostToNet(this->tun.getIP()); - header->daddr = Address::hostToNet(nextHop); - } - // 目的地址是本机,直接回写,在 macos 中遇到了这种情况 - if (Address::netToHost(header->daddr) == this->tun.getIP()) { - this->tun.write(buffer); - return; - } - - // 尝试通过路由或直连发送 - if (!sendPeerForwardMessage(buffer)) { - return; - } - - // 通过 WebSocket 转发 - sendForwardMessage(buffer); -} - -// P2P -int Client::startUdpThread() { - if (this->stun.uri.empty()) { - return 0; - } - if (this->udpHolder.init()) { - spdlog::critical("udpHolder init failed"); - Candy::shutdown(this); - return -1; - } - if (this->selfInfo.setTun(this->tun.getIP(), this->password)) { - return -1; - } - sendStunRequest(); - this->selfInfo.local.ip = udpHolder.IP(); - this->selfInfo.local.port = udpHolder.Port(); - spdlog::debug("localhost: {}", Address::ipToStr(this->selfInfo.local.ip)); - this->udpThread = std::thread([&] { - recvUdpMessage(); - spdlog::debug("udp thread exit"); - }); - return 0; -} - -int Client::startTickThread() { - if (this->stun.uri.empty()) { - return 0; - } - this->tickThread = std::thread([&] { - while (this->running) { - std::this_thread::sleep_for(std::chrono::seconds(1)); - this->tick(); - } - spdlog::debug("tick thread exit"); - }); - return 0; -} - -void Client::tick() { - if (discoveryInterval) { - if (tickTick % discoveryInterval == 0) { - sendDiscoveryMessage(BROADCAST_IP); - } - } - - std::unique_lock lock(this->ipPeerMutex); - bool needSendStunRequest = false; - for (auto &[ip, peer] : this->ipPeerMap) { - switch (peer.getState()) { - case PeerState::INIT: - // 收到对方通过服务器转发的数据的时候,会切换为 PREPARING, 这里不做处理 - break; - - case PeerState::PREPARING: - // 长时间处于 PREPARING 状态,无法获取本机的公网信息,进入失败状态 - if (peer.count > 10) { - peer.updateState(PeerState::FAILED); - } else { - sendHeartbeatMessage(peer); - needSendStunRequest = true; - } - break; - - case PeerState::SYNCHRONIZING: - // 1.对方版本不支持 2.没有启用对等连接 3.对方无法获取到自己在公网中的信息 - if (peer.count > 10) { - peer.updateState(PeerState::FAILED); - } else { - sendHeartbeatMessage(peer); - } - break; - - case PeerState::CONNECTING: - // 进行超时检测,超时后进入 WAITING 状态,否则发送心跳 - if (peer.count > 10) { - peer.updateState(PeerState::WAITING); - } else { - sendHeartbeatMessage(peer); - } - break; - - case PeerState::CONNECTED: - // 进行超时检测,超时后清空对端信息,否则发送心跳 - if (peer.count > 3) { - peer.updateState(PeerState::INIT); - if (routeCost) { - updateCandyRtTable(CandyRouteEntry(peer.getTun(), peer.getTun(), DELAY_LIMIT)); - } - } else { - sendHeartbeatMessage(peer); - if (routeCost && peer.tick % 60 == 0) { - sendDelayMessage(peer); - } - } - break; - - case PeerState::WAITING: - // 达到等待时长,重新进入初始状态 - if (peer.count > peer.retry) { - peer.updateState(PeerState::INIT); - } - break; - - case PeerState::FAILED: - // 两端任意一方不支持或者未启用对等连接功能,进入失败状态,不再主动重连 - break; - } - ++peer.count; - ++peer.tick; - } - if (needSendStunRequest) { - sendStunRequest(); - } - ++tickTick; -} - -std::string Client::encrypt(const std::string &key, const std::string &plaintext) { - using lock = std::unique_lock; - auto guard = this->workers ? lock() : lock(cryptMutex); - return encryptHelper(key, plaintext); -} -std::string Client::decrypt(const std::string &key, const std::string &ciphertext) { - using lock = std::unique_lock; - auto guard = this->workers ? lock() : lock(cryptMutex); - return decryptHelper(key, ciphertext); -} - -std::string Client::encryptHelper(const std::string &key, const std::string &plaintext) { - int len = 0; - int ciphertextLen = 0; - EVP_CIPHER_CTX *ctx = NULL; - unsigned char ciphertext[1500] = {0}; - unsigned char iv[AES_256_GCM_IV_LEN] = {0}; - unsigned char tag[AES_256_GCM_TAG_LEN] = {0}; - - if (key.size() != AES_256_GCM_KEY_LEN) { - spdlog::debug("invalid key size: {}", key.size()); - return ""; - } - ctx = EVP_CIPHER_CTX_new(); - if (!ctx) { - spdlog::debug("create cipher context failed"); - return ""; - } - if (!RAND_bytes(iv, AES_256_GCM_IV_LEN)) { - spdlog::debug("generate random iv failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - if (!EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, (unsigned char *)key.data(), iv)) { - spdlog::debug("initialize cipher context failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, AES_256_GCM_IV_LEN, NULL)) { - spdlog::debug("set iv length failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - if (!EVP_EncryptUpdate(ctx, ciphertext, &len, (unsigned char *)plaintext.data(), plaintext.size())) { - spdlog::debug("encrypt update failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - ciphertextLen = len; - if (!EVP_EncryptFinal_ex(ctx, ciphertext + len, &len)) { - spdlog::debug("encrypt final failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - ciphertextLen += len; - if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, AES_256_GCM_TAG_LEN, tag)) { - spdlog::debug("get tag failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - EVP_CIPHER_CTX_free(ctx); - - std::string result; - result.append((char *)iv, AES_256_GCM_IV_LEN); - result.append((char *)tag, AES_256_GCM_TAG_LEN); - result.append((char *)ciphertext, ciphertextLen); - return result; -} - -std::string Client::decryptHelper(const std::string &key, const std::string &ciphertext) { - int len = 0; - int plaintextLen = 0; - unsigned char *enc = NULL; - EVP_CIPHER_CTX *ctx = NULL; - unsigned char plaintext[1500] = {0}; - unsigned char iv[AES_256_GCM_IV_LEN] = {0}; - unsigned char tag[AES_256_GCM_TAG_LEN] = {0}; - - if (key.size() != AES_256_GCM_KEY_LEN) { - spdlog::debug("invalid key length: {}", key.size()); - return ""; - } - if (ciphertext.size() < AES_256_GCM_IV_LEN + AES_256_GCM_TAG_LEN) { - spdlog::debug("invalid ciphertext length: {}", ciphertext.size()); - return ""; - } - ctx = EVP_CIPHER_CTX_new(); - if (!ctx) { - spdlog::debug("create cipher context failed"); - return ""; - } - - enc = (unsigned char *)ciphertext.data(); - memcpy(iv, enc, AES_256_GCM_IV_LEN); - memcpy(tag, enc + AES_256_GCM_IV_LEN, AES_256_GCM_TAG_LEN); - enc += AES_256_GCM_IV_LEN + AES_256_GCM_TAG_LEN; - - if (!EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, (unsigned char *)key.data(), iv)) { - spdlog::debug("initialize cipher context failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, AES_256_GCM_IV_LEN, NULL)) { - spdlog::debug("set iv length failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - if (!EVP_DecryptUpdate(ctx, plaintext, &len, enc, ciphertext.size() - AES_256_GCM_IV_LEN - AES_256_GCM_TAG_LEN)) { - spdlog::debug("decrypt update failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, AES_256_GCM_TAG_LEN, tag)) { - spdlog::debug("set tag failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - plaintextLen = len; - if (!EVP_DecryptFinal_ex(ctx, plaintext + len, &len)) { - spdlog::debug("decrypt final failed"); - EVP_CIPHER_CTX_free(ctx); - return ""; - } - - plaintextLen += len; - EVP_CIPHER_CTX_free(ctx); - - std::string result; - result.append((char *)plaintext, plaintextLen); - return result; -} - -int Client::sendStunRequest() { - struct addrinfo hints = {}, *info = NULL; - - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_INET; - hints.ai_socktype = SOCK_DGRAM; - - try { - Poco::URI uri(this->stun.uri); - std::string strPort = uri.getPort() == 0 ? "3478" : std::to_string(uri.getPort()); - if (getaddrinfo(uri.getHost().c_str(), strPort.c_str(), &hints, &info)) { - spdlog::warn("resolve stun server domain name failed: {}:{}", uri.getHost(), strPort); - return -1; - } - } catch (std::exception &e) { - spdlog::error("invalid stun uri: {}: {}", this->stun.uri, e.what()); - } - - this->stun.ip = Address::netToHost((uint32_t)((struct sockaddr_in *)info->ai_addr)->sin_addr.s_addr); - this->stun.port = Address::netToHost((uint16_t)((struct sockaddr_in *)info->ai_addr)->sin_port); - - freeaddrinfo(info); - - UdpMessage message; - StunRequest request; - message.ip = this->stun.ip; - message.port = this->stun.port; - message.buffer.assign((char *)&request, sizeof(request)); - if (this->udpHolder.write(message) != message.buffer.size()) { - spdlog::warn("send stun request failed"); - } - return 0; -} - -int Client::sendHeartbeatMessage(const PeerInfo &peer) { - UdpMessage message; - PeerHeartbeatMessage heartbeat; - heartbeat.type = PeerMessageType::HEARTBEAT; - heartbeat.tun = Address::hostToNet(this->tun.getIP()); - heartbeat.ack = peer.ack; - - if ((peer.getState() == PeerState::CONNECTED) && (peer.real.ip && peer.real.port)) { - heartbeat.ip = Address::hostToNet(this->selfInfo.real.ip); - heartbeat.port = Address::hostToNet(this->selfInfo.real.port); - message.ip = peer.real.ip; - message.port = peer.real.port; - message.buffer = encrypt(peer.getKey(), std::string((char *)&heartbeat, sizeof(heartbeat))); - this->udpHolder.write(message); - } - - if ((peer.getState() == PeerState::CONNECTING) && (peer.wide.ip && peer.wide.port)) { - heartbeat.ip = Address::hostToNet(this->selfInfo.wide.ip); - heartbeat.port = Address::hostToNet(this->selfInfo.wide.port); - message.ip = peer.wide.ip; - message.port = peer.wide.port; - message.buffer = encrypt(peer.getKey(), std::string((char *)&heartbeat, sizeof(heartbeat))); - this->udpHolder.write(message); - } - - if ((peer.getState() == PeerState::PREPARING || peer.getState() == PeerState::SYNCHRONIZING || - peer.getState() == PeerState::CONNECTING) && - (!this->localP2PDisabled && peer.local.ip && peer.local.port)) { - heartbeat.ip = Address::hostToNet(this->selfInfo.local.ip); - heartbeat.port = Address::hostToNet(this->selfInfo.local.port); - message.ip = peer.local.ip; - message.port = peer.local.port; - message.buffer = encrypt(peer.getKey(), std::string((char *)&heartbeat, sizeof(heartbeat))); - this->udpHolder.write(message); - } - - return 0; -} - -int Client::sendPeerForwardMessage(const std::string &buffer) { - std::shared_lock ipPeerLock(this->ipPeerMutex); - std::shared_lock rtTableLock(this->candyRtTableMutex); - - IPv4Header *header = (IPv4Header *)buffer.data(); - uint32_t dst = Address::netToHost(header->daddr); - - // 优先尝试最快的路由转发 - if (routeCost) { - auto route = this->candyRtTable.find(dst); - if (route != this->candyRtTable.end()) { - if (!sendPeerForwardMessage(buffer, route->second.next)) { - return 0; - } - } - } - - // 尝试直连 - return sendPeerForwardMessage(buffer, dst); -} - -int Client::sendPeerForwardMessage(const std::string &buffer, uint32_t nextHop) { - auto it = this->ipPeerMap.find(nextHop); - if (it == this->ipPeerMap.end()) { - return -1; - } - - const auto &peer = it->second; - if (peer.getState() != PeerState::CONNECTED) { - return -1; - } - - UdpMessage message; - message.ip = peer.real.ip; - message.port = peer.real.port; - message.buffer.push_back(PeerMessageType::FORWARD); - message.buffer.append(buffer); - message.buffer = encrypt(peer.getKey(), message.buffer); - this->udpHolder.write(message); - return 0; +void Client::setName(const std::string &name) { + this->tunName = name; + tun.setName(name); } -bool Client::isStunResponse(const UdpMessage &message) { - return message.ip == this->stun.ip && message.port == this->stun.port; +std::string Client::getName() const { + return this->tunName; } -bool Client::isHeartbeatMessage(const UdpMessage &message) { - return message.buffer.front() == PeerMessageType::HEARTBEAT; +void Client::setPassword(const std::string &password) { + ws.setPassword(password); + peer.setPassword(password); } -bool Client::isPeerForwardMessage(const UdpMessage &message) { - return message.buffer.front() == PeerMessageType::FORWARD; +void Client::setWebSocket(const std::string &uri) { + ws.setWsServerUri(uri); } -int Client::handleStunResponse(const std::string &buffer) { - if (buffer.length() < sizeof(StunResponse)) { - spdlog::debug("invalid stun response length: {}", buffer.length()); - return -1; - } - StunResponse *response = (StunResponse *)buffer.c_str(); - if (Address::netToHost(response->type) != 0x0101) { - spdlog::debug("stun not success response"); - return -1; - } - - int pos = 0; - uint32_t ip = 0; - uint16_t port = 0; - uint8_t *attr = response->attr; - while (pos < Address::netToHost(response->length)) { - // mapped address - if (Address::netToHost(*(uint16_t *)(attr + pos)) == 0x0001) { - pos += 6; // 跳过 2 字节类型, 2 字节长度, 1 字节保留, 1 字节IP版本号,指向端口号 - port = Address::netToHost(*(uint16_t *)(attr + pos)); - pos += 2; // 跳过2字节端口号,指向地址 - ip = Address::netToHost(*(uint32_t *)(attr + pos)); - break; - } - // xor mapped address - if (Address::netToHost(*(uint16_t *)(attr + pos)) == 0x0020) { - pos += 6; // 跳过 2 字节类型, 2 字节长度, 1 字节保留, 1 字节IP版本号,指向端口号 - port = Address::netToHost(*(uint16_t *)(attr + pos)) ^ 0x2112; - pos += 2; // 跳过2字节端口号,指向地址 - ip = Address::netToHost(*(uint32_t *)(attr + pos)) ^ 0x2112a442; - break; - } - // 跳过 2 字节类型,指向属性长度 - pos += 2; - // 跳过 2 字节长度和用该属性其他内容 - pos += 2 + Address::netToHost(*(uint16_t *)(attr + pos)); - } - if (!ip || !port) { - spdlog::warn("stun response parse failed: {:n}", spdlog::to_hex(buffer)); - return -1; - } - - this->selfInfo.wide.ip = ip; - this->selfInfo.wide.port = port; - - // 收到 STUN 响应后,向所有 PREPARING 状态的对端发送自己的公网信息,如果当前持有对端公网信息,就将状态调整为 CONNECTING, - // 否则调整为 SYNCHRONIZING - std::unique_lock lock(this->ipPeerMutex); - for (auto &[tun, peer] : this->ipPeerMap) { - if (peer.getState() == PeerState::PREPARING) { - if (peer.wide.ip && peer.wide.port) { - peer.updateState(PeerState::CONNECTING); - } else { - peer.updateState(PeerState::SYNCHRONIZING); - } - sendPeerConnMessage(peer); - } - } - - return 0; -} - -int Client::handleHeartbeatMessage(const UdpMessage &message) { - if (message.buffer.length() < sizeof(PeerHeartbeatMessage)) { - spdlog::debug("invalid heartbeat length: {}", message.buffer.length()); - return -1; - } - - // 收到对端的心跳,更新地址和端口 - PeerHeartbeatMessage *heartbeat = (PeerHeartbeatMessage *)message.buffer.c_str(); - std::unique_lock lock(this->ipPeerMutex); - uint32_t tun = Address::netToHost(heartbeat->tun); - PeerInfo &peer = this->ipPeerMap[tun]; - if (peer.getState() == PeerState::INIT || peer.getState() == PeerState::WAITING || peer.getState() == PeerState::FAILED) { - spdlog::debug("heartbeat peer state invalid: {} {}", Address::ipToStr(tun), peer.getStateStr()); - return -1; - } - - if (!isLocalIp(message.ip)) { - peer.wide.ip = message.ip; - peer.wide.port = message.port; - } else if (!this->localP2PDisabled) { - peer.local.ip = message.ip; - peer.local.port = message.port; - } else { - return 0; - } - - if (isLocalIp(message.ip) || !isLocalIp(peer.real.ip)) { - peer.real.ip = message.ip; - peer.real.port = message.port; - } - - // 设置确认标识,下次向对方发送的心跳将携带确认标识 - if (!peer.ack) { - peer.ack = 1; - } - - // 对方发来的心跳中包含确认标识,状态更新为 CONNECTED - if (heartbeat->ack) { - if (peer.getState() == PeerState::CONNECTED) { - peer.count = 0; - return 0; - } - peer.updateState(PeerState::CONNECTED); - if (routeCost) { - sendDelayMessage(peer); - } - } - return 0; +void Client::setTunAddress(const std::string &cidr) { + ws.setAddress(cidr); } -int Client::handlePeerForwardMessage(const UdpMessage &message) { - if (message.buffer.length() < sizeof(PeerForwardMessage)) { - spdlog::debug("invalid raw ipv4 length: {}", message.buffer.length()); - return -1; - } - - PeerForwardMessage *ipv4Message = (PeerForwardMessage *)message.buffer.c_str(); - if (Address::netToHost(ipv4Message->iph.daddr) == this->tun.getIP()) { - const char *src = message.buffer.c_str() + sizeof(ForwardHeader::type); - const size_t len = message.buffer.length() - sizeof(ForwardHeader::type); - if (ipv4Message->iph.protocol == 0x04) { - this->tun.write(std::string(src + sizeof(IPv4Header), len - sizeof(IPv4Header))); - } else { - this->tun.write(std::string(src, len)); - } - - // 可能是转发来的,尝试跟源地址建立直连 - tryDirectConnection(Address::netToHost(ipv4Message->iph.saddr)); - return 0; - } - - std::shared_lock ipPeerLock(this->ipPeerMutex); - std::shared_lock rtTableLock(this->candyRtTableMutex); - auto route = this->candyRtTable.find(Address::netToHost(ipv4Message->iph.daddr)); - if (route == this->candyRtTable.end()) { - return 0; - } - - auto peer = this->ipPeerMap.find(route->second.next); - if (peer == this->ipPeerMap.end() || peer->second.getState() != PeerState::CONNECTED) { - return 0; - } - - UdpMessage forward; - forward.ip = peer->second.real.ip; - forward.port = peer->second.real.port; - forward.buffer = encrypt(peer->second.getKey(), message.buffer); - this->udpHolder.write(forward); - return 0; +void Client::setExptTunAddress(const std::string &cidr) { + ws.setExptTunAddress(cidr); } -bool Client::isLocalIp(uint32_t ip) { - // 10.0.0.0/8 - if ((ip & 0xFF000000) == 0x0A000000) { - return true; - } - // 172.16.0.0/12 - if ((ip & 0xFFF00000) == 0xAC000000) { - return true; - } - // 192.168.0.0/16 - if ((ip & 0xFFFF0000) == 0xC0A80000) { - return true; - } - return false; +void Client::setVirtualMac(const std::string &vmac) { + ws.setVirtualMac(vmac); } -// Route -void Client::showCandyRtChange(const CandyRouteEntry &entry) { - std::string dstStr = Address::ipToStr(entry.dst); - std::string nextStr = Address::ipToStr(entry.next); - std::string delayStr = (entry.delay == DELAY_LIMIT) ? "[deleted]" : std::to_string(entry.delay); - spdlog::debug("candy route: dst={} next={} delay={}", dstStr, nextStr, delayStr); +void Client::setStun(const std::string &stun) { + peer.setStun(stun); } -int Client::updateCandyRtTable(CandyRouteEntry entry) { - bool isDirect = (entry.dst == entry.next); - bool isDelete = (entry.delay < 0 || entry.delay > 1000); - - std::unique_lock lock(this->candyRtTableMutex); - - // 到达此目的地址的历史路由,下一跳可能不同 - auto oldEntry = this->candyRtTable.find(entry.dst); - - // 本机检测到连接断开,删除所有以断联设备作为下一跳的路由并广播 - if (isDirect && isDelete) { - for (auto it = this->candyRtTable.begin(); it != this->candyRtTable.end();) { - if (it->second.next == entry.next) { - it->second.delay = DELAY_LIMIT; - sendCandyRtMessage(it->second.dst, it->second.delay); - showCandyRtChange(it->second); - it = this->candyRtTable.erase(it); - continue; - } - ++it; - } - return 0; - } - - // 本机检测到直连设备时延有更新,下一跳相同或者延迟更低时更新并广播 - if (isDirect && !isDelete) { - if (oldEntry == this->candyRtTable.end() || oldEntry->second.next == entry.next || oldEntry->second.delay > entry.delay) { - this->candyRtTable[entry.dst] = entry; - sendCandyRtMessage(entry.dst, entry.delay); - showCandyRtChange(entry); - } - return 0; - } - - // 收到设备断连广播,删除本机相同的路由并广播 - if (!isDirect && isDelete) { - if (oldEntry != this->candyRtTable.end() && oldEntry->second.next == entry.next) { - oldEntry->second.delay = DELAY_LIMIT; - sendCandyRtMessage(oldEntry->second.dst, oldEntry->second.delay); - showCandyRtChange(oldEntry->second); - this->candyRtTable.erase(oldEntry); - } - return 0; - } - - // 收到设备时延更新广播,更新本机相同路由并广播 - if (!isDirect && !isDelete) { - auto directEntry = this->candyRtTable.find(entry.next); - if (directEntry == this->candyRtTable.end()) { - return 0; - } - int32_t nowDelay = directEntry->second.delay + entry.delay; - if (oldEntry == this->candyRtTable.end() || oldEntry->second.next == entry.next || oldEntry->second.delay > nowDelay) { - entry.delay = nowDelay; - this->candyRtTable[entry.dst] = entry; - sendCandyRtMessage(entry.dst, entry.delay); - showCandyRtChange(entry); - return 0; - } - return 0; - } - return 0; +void Client::setDiscoveryInterval(int interval) { + peer.setDiscoveryInterval(interval); } -int Client::sendDelayMessage(const PeerInfo &peer) { - PeerDelayMessage delayMessage; - delayMessage.type = PeerMessageType::DELAY; - delayMessage.src = Address::hostToNet(this->tun.getIP()); - delayMessage.dst = Address::hostToNet(peer.getTun()); - delayMessage.timestamp = Time::hostToNet(Time::bootTime()); - return sendDelayMessage(peer, delayMessage); +void Client::setRouteCost(int cost) { + peer.setForwardCost(cost); } -int Client::sendDelayMessage(const PeerInfo &peer, const PeerDelayMessage &delay) { - UdpMessage message; - message.ip = peer.real.ip; - message.port = peer.real.port; - message.buffer = encrypt(peer.getKey(), std::string((char *)&delay, sizeof(delay))); - this->udpHolder.write(message); - return 0; +void Client::setPort(int port) { + peer.setPort(port); } -int Client::sendCandyRtMessage(uint32_t dst, int32_t delay) { - PeerRouteMessage routeMessage; - routeMessage.type = PeerMessageType::ROUTE; - routeMessage.dst = Address::hostToNet(dst); - routeMessage.next = Address::hostToNet(this->tun.getIP()); - routeMessage.delay = Time::hostToNet(delay == DELAY_LIMIT ? DELAY_LIMIT : delay + routeCost); - - for (auto &[_, peer] : this->ipPeerMap) { - if (peer.getState() == PeerState::CONNECTED) { - UdpMessage message; - message.ip = peer.real.ip; - message.port = peer.real.port; - message.buffer = encrypt(peer.getKey(), std::string((char *)&routeMessage, sizeof(routeMessage))); - this->udpHolder.write(message); - } - } - - return 0; +void Client::setLocalhost(std::string ip) { + peer.setLocalhost(ip); } -bool Client::isDelayMessage(const UdpMessage &message) { - return message.buffer.front() == PeerMessageType::DELAY; +void Client::setMtu(int mtu) { + tun.setMTU(mtu); } -bool Client::isCandyRtMessage(const UdpMessage &message) { - return message.buffer.front() == PeerMessageType::ROUTE; +void Client::setTunUpdateCallback(std::function callback) { + this->ws.setTunUpdateCallback(callback); } -int Client::handleDelayMessage(const UdpMessage &message) { - if (message.buffer.length() < sizeof(PeerDelayMessage)) { - spdlog::debug("invalid delay message length: {}", message.buffer.length()); - return -1; - } - - PeerDelayMessage *delayMessage = (PeerDelayMessage *)message.buffer.c_str(); - uint32_t src = Address::netToHost(delayMessage->src); - uint32_t dst = Address::netToHost(delayMessage->dst); - int64_t timestamp = Time::netToHost(delayMessage->timestamp); - - std::shared_lock lock(this->ipPeerMutex); - - if (src == this->tun.getIP()) { - auto it = this->ipPeerMap.find(dst); - if (it != this->ipPeerMap.end()) { - int32_t delay = Time::bootTime() - timestamp; - it->second.delay = delay; - updateCandyRtTable(CandyRouteEntry(dst, dst, delay)); - } - return 0; - } - - if (dst == this->tun.getIP()) { - auto it = this->ipPeerMap.find(src); - if (it != this->ipPeerMap.end()) { - sendDelayMessage(it->second, *delayMessage); - } - return 0; - } - - return 0; +void Client::run() { + this->running = true; + ws.run(this); + tun.run(this); + peer.run(this); } -int Client::handleCandyRtMessage(const UdpMessage &message) { - if (message.buffer.length() < sizeof(PeerRouteMessage)) { - spdlog::debug("invalid route message length: {}", message.buffer.length()); - return -1; - } - - PeerRouteMessage *routeMessage = (PeerRouteMessage *)message.buffer.c_str(); - uint32_t dst = Address::netToHost(routeMessage->dst); - uint32_t next = Address::netToHost(routeMessage->next); - int32_t delay = Time::netToHost(routeMessage->delay); - - if (dst != this->tun.getIP()) { - updateCandyRtTable(CandyRouteEntry(dst, next, delay)); - } - - return 0; +void Client::shutdown() { + this->running = false; + ws.shutdown(); + tun.shutdown(); + peer.shutdown(); } } // namespace Candy diff --git a/src/core/client.h b/src/core/client.h index 63c59bce..7b7e11b8 100644 --- a/src/core/client.h +++ b/src/core/client.h @@ -3,199 +3,76 @@ #define CANDY_CORE_CLIENT_H #include "core/message.h" -#include "peer/udp.h" +#include "peer/peer.h" #include "tun/tun.h" -#include "utility/random.h" #include "websocket/client.h" -#include #include -#include -#include #include -#include #include -#include -#include namespace Candy { -struct StunCache { - uint32_t ip; - uint16_t port; - std::string uri; -}; - -struct CandyRouteEntry { - uint32_t dst; - uint32_t next; - int32_t delay; - - CandyRouteEntry(uint32_t dst = 0, uint32_t next = 0, int32_t delay = DELAY_LIMIT) : dst(dst), next(next), delay(delay) {} -}; +void shutdown(Client *client); -struct SysRouteEntry { - uint32_t dst; - uint32_t mask; - uint32_t next; +/* 各模块之间通过消息队列通信 */ +class MsgQueue { +public: + // 阻塞读 + Msg read(); + // 向队列中写入消息 + void write(Msg msg); - SysRouteEntry(uint32_t dst = 0, uint32_t mask = 0, uint32_t next = 0) : dst(dst), mask(mask), next(next) {} +private: + std::queue msgQueue; + std::mutex msgMutex; + std::condition_variable msgCondition; }; +/* 客户端只负责维护模块间通信的消息队列,以及进程启动时的参数透传,本身不提供实质性的功能 */ class Client { public: - // 设置客户端名称,用于设置 TUN 设备的名称,格式为 candy-name, 如果 name 为空将被命名为 candy. - int setName(const std::string &name); - std::string getName() const; - - // 客户端工作线程数量 - int setWorkers(int number); - - // 连接 websocket 服务端时身份认证的口令 - int setPassword(const std::string &password); - - // 用于数据转发和对等连接控制的服务端地址 - int setWebSocketServer(const std::string &server); - - // TUN 地址,向服务端要求强制使用这个地址,使用相同地址的前一个设备会被踢出网络 - int setTunAddress(const std::string &cidr); - - // 向服务端请求时期望获得的地址,地址不可用时服务端返回新地址 - int setExpectedAddress(const std::string &cidr); - - // 虚拟 Mac 地址 - int setVirtualMac(const std::string &vmac); - - // STUN 服务端,用于开启对等连接 - int setStun(const std::string &stun); - - // 主动发现时间间隔 - int setDiscoveryInterval(int interval); - - // 通过本节点路由的代价 - int setRouteCost(int cost); + // 通过配置文件或命令行设置的参数 + void setName(const std::string &name); + void setPassword(const std::string &password); + void setWebSocket(const std::string &uri); + void setTunAddress(const std::string &cidr); + void setStun(const std::string &stun); + void setDiscoveryInterval(int interval); + void setRouteCost(int cost); + void setPort(int port); + void setLocalhost(std::string ip); + void setMtu(int mtu); + void setTunUpdateCallback(std::function callback); + + // 期望使用的地址,当地址可用时服务端优先分配这个地址 + void setExptTunAddress(const std::string &cidr); + // 虚拟的硬件地址,在程序第一次启动时随机生成的 16 位随机字符串,用于和最近一次使用的地址绑定 + // 当相同虚拟硬件地址的设备登录时,判定为前一个客户端已断开,踢出前一个客户端并分配为与前一个客户端相同的 IP + void setVirtualMac(const std::string &vmac); + + // 启动客户端,非阻塞 + void run(); + // 关闭客户端,阻塞,直到所有子模块退出 + void shutdown(); - // 本地地址更新时执行的回调函数 - int setAddressUpdateCallback(std::function callback); - - // 绑定用于 P2P 连接的 UDP 端口, 0 表示由操作系统分配 - int setUdpBindPort(int port); - - // 用于局域网连接的地址 - int setLocalhost(std::string ip); + bool running = false; - // 设置最大传输单元 - int setMtu(int mtu); + std::string getName() const; - // 启停客户端用于处理任务的线程 - int run(); - int shutdown(); +public: + // 三个消息队列,子模块使用这些队列通信 + MsgQueue tunMsgQueue, peerMsgQueue, wsMsgQueue; private: - // Common - int workers = 0; - int mtu = 1400; - bool running = false; - std::string password; - std::mutex runningMutex; - std::function addressUpdateCallback; - int startWorkerThreads(); - int stopWorkerThreads(); - - // WebSocket - int startWsThread(); - void handleWebSocketMessage(); - void sendForwardMessage(const std::string &buffer); - void sendVirtualMacMessage(); - void sendDynamicAddressMessage(); - void sendAuthMessage(); - void sendPeerConnMessage(const PeerInfo &peer); - void sendDiscoveryMessage(uint32_t dst); - void sendLocalPeerConnMessage(const PeerInfo &peer); - void handleForwardMessage(WebSocketMessage &message); - void handleExpectedAddressMessage(WebSocketMessage &message); - void handlePeerConnMessage(WebSocketMessage &message); - void handleDiscoveryMessage(WebSocketMessage &message); - void handleSysRtMessage(WebSocketMessage &message); - void handleGeneralMessage(WebSocketMessage &message); - void handleLocalPeerConnMessage(WebSocketMessage &message); - void tryDirectConnection(uint32_t ip); - std::string hostName(); - + // TUN 模块,与本机通信 + Tun tun; + // PEER 模块,用于建立 P2P 连接,以及在 P2P 连接之上的客户端中继功能 + Peer peer; + // WS 模块,主要处理与服务端之间的控制信息,在 P2P 无法使用时提供服务端中继 WebSocketClient ws; - std::string wsUri; - std::thread wsThread; - - // TUN - int startTunThread(); - void recvTunMessage(); - void handleTunMessage(std::string message); - Tun tun; +private: std::string tunName; - std::string tunAddress; - std::string expectedAddress; - std::string realAddress; - std::string virtualMac; - std::thread tunThread; - std::vector tunMsgWorkerThreads; - std::mutex tunMsgQueueMutex; - std::queue tunMsgQueue; - std::condition_variable tunMsgQueueCondition; - - // P2P - int startUdpThread(); - int startTickThread(); - void recvUdpMessage(); - void handleUdpMessage(UdpMessage message); - void tick(); - std::string encrypt(const std::string &key, const std::string &plaintext); - std::string decrypt(const std::string &key, const std::string &ciphertext); - std::string encryptHelper(const std::string &key, const std::string &plaintext); - std::string decryptHelper(const std::string &key, const std::string &ciphertext); - int sendStunRequest(); - int sendHeartbeatMessage(const PeerInfo &peer); - int sendPeerForwardMessage(const std::string &buffer); - int sendPeerForwardMessage(const std::string &buffer, uint32_t nextHop); - bool isStunResponse(const UdpMessage &message); - bool isHeartbeatMessage(const UdpMessage &message); - bool isPeerForwardMessage(const UdpMessage &message); - int handleStunResponse(const std::string &buffer); - int handleHeartbeatMessage(const UdpMessage &message); - int handlePeerForwardMessage(const UdpMessage &message); - static bool isLocalIp(uint32_t ip); - - UdpHolder udpHolder; - StunCache stun; - PeerInfo selfInfo; - std::shared_mutex ipPeerMutex; - std::unordered_map ipPeerMap; - std::thread udpThread; - std::thread tickThread; - uint64_t tickTick = randomUint32(); - uint32_t discoveryInterval; - std::mutex cryptMutex; - std::atomic localP2PDisabled; - std::vector udpMsgWorkerThreads; - std::mutex udpMsgQueueMutex; - std::queue udpMsgQueue; - std::condition_variable udpMsgQueueCondition; - - // Route - void showCandyRtChange(const CandyRouteEntry &entry); - int updateCandyRtTable(CandyRouteEntry entry); - int sendDelayMessage(const PeerInfo &peer); - int sendDelayMessage(const PeerInfo &peer, const PeerDelayMessage &delay); - int sendCandyRtMessage(uint32_t dst, int32_t delay); - bool isDelayMessage(const UdpMessage &message); - bool isCandyRtMessage(const UdpMessage &message); - int handleDelayMessage(const UdpMessage &message); - int handleCandyRtMessage(const UdpMessage &message); - - std::shared_mutex sysRtTableMutex; - std::list sysRtTable; - std::shared_mutex candyRtTableMutex; - std::unordered_map candyRtTable; - int32_t routeCost; }; } // namespace Candy diff --git a/src/core/message.cc b/src/core/message.cc index 3b56f0da..57030a88 100644 --- a/src/core/message.cc +++ b/src/core/message.cc @@ -1,134 +1,22 @@ // SPDX-License-Identifier: MIT #include "core/message.h" -#include "utility/address.h" -#include "utility/time.h" -#include namespace Candy { -AuthHeader::AuthHeader(uint32_t ip) { - this->type = MessageType::AUTH; - this->ip = Address::hostToNet(ip); - this->timestamp = Time::hostToNet(Time::unixTime()); +Msg::Msg(MsgKind kind, std::string data) { + this->kind = kind; + this->data = std::move(data); } -void AuthHeader::updateHash(const std::string &password) { - std::string data; - data.append(password); - data.append((char *)&ip, sizeof(ip)); - data.append((char *)×tamp, sizeof(timestamp)); - SHA256((unsigned char *)data.data(), data.size(), this->hash); +Msg::Msg(Msg &&packet) { + kind = packet.kind; + data = std::move(packet.data); } -bool AuthHeader::check(const std::string &password) { - // 检查时间 - int64_t localTime = Time::unixTime(); - int64_t remoteTime = Time::netToHost(this->timestamp); - if (std::abs(localTime - remoteTime) > 30) { - spdlog::warn("auth header timestamp check failed: server {} client {}", localTime, remoteTime); - return false; - } - - // 备份上报的数据 - uint8_t reported[SHA256_DIGEST_LENGTH]; - std::memcpy(reported, this->hash, SHA256_DIGEST_LENGTH); - - // 用口令计算正确的哈希并填充 - updateHash(password); - - // 检查上报的哈希和填充的哈希是否相等 - if (std::memcmp(reported, this->hash, SHA256_DIGEST_LENGTH)) { - spdlog::warn("auth header hash check failed"); - return false; - } - return true; -} - -ForwardHeader::ForwardHeader() { - this->type = MessageType::FORWARD; -} - -ExpectedAddressMessage::ExpectedAddressMessage(const std::string &cidr) { - this->type = MessageType::EXPECTED; - this->timestamp = Time::hostToNet(Time::unixTime()); - std::strcpy(this->cidr, cidr.c_str()); -} - -void ExpectedAddressMessage::updateHash(const std::string &password) { - std::string data; - data.append(password); - data.append((char *)&this->timestamp, sizeof(this->timestamp)); - SHA256((unsigned char *)data.data(), data.size(), this->hash); -} - -bool ExpectedAddressMessage::check(const std::string &password) { - int64_t localTime = Time::unixTime(); - int64_t remoteTime = Time::netToHost(this->timestamp); - if (std::abs(localTime - remoteTime) > 30) { - spdlog::warn("expected address header timestamp check failed: server {} client {}", localTime, remoteTime); - return false; - } - - uint8_t reported[SHA256_DIGEST_LENGTH]; - std::memcpy(reported, this->hash, SHA256_DIGEST_LENGTH); - - updateHash(password); - - if (std::memcmp(reported, this->hash, SHA256_DIGEST_LENGTH)) { - spdlog::warn("expected address header hash check failed"); - return false; - } - return true; -} - -VMacMessage::VMacMessage(const std::string &vmac) { - this->type = MessageType::VMAC; - this->timestamp = Time::hostToNet(Time::unixTime()); - if (vmac.length() >= sizeof(this->vmac)) { - memcpy(this->vmac, vmac.c_str(), sizeof(this->vmac)); - } else { - memset(this->vmac, 0, sizeof(this->vmac)); - } -} - -DiscoveryMessage::DiscoveryMessage() { - this->type = MessageType::DISCOVERY; -} - -GeneralHeader::GeneralHeader() { - this->type = MessageType::GENERAL; -} - -void VMacMessage::updateHash(const std::string &password) { - std::string data; - data.append(password); - data.append((char *)&this->vmac, sizeof(this->vmac)); - data.append((char *)&this->timestamp, sizeof(this->timestamp)); - SHA256((unsigned char *)data.data(), data.size(), this->hash); -} - -bool VMacMessage::check(const std::string &password) { - int64_t localTime = Time::unixTime(); - int64_t remoteTime = Time::netToHost(this->timestamp); - if (std::abs(localTime - remoteTime) > 30) { - spdlog::warn("vmac message timestamp check failed: server {} client {}", localTime, remoteTime); - return false; - } - - uint8_t reported[SHA256_DIGEST_LENGTH]; - std::memcpy(reported, this->hash, SHA256_DIGEST_LENGTH); - - updateHash(password); - - if (std::memcmp(reported, this->hash, SHA256_DIGEST_LENGTH)) { - spdlog::warn("vmac message hash check failed"); - return false; - } - return true; -} - -PeerConnMessage::PeerConnMessage() { - this->type = MessageType::PEER; +Msg &Msg::operator=(Msg &&packet) { + kind = packet.kind; + data = std::move(packet.data); + return *this; } } // namespace Candy diff --git a/src/core/message.h b/src/core/message.h index 01b872c2..edcf0f6f 100644 --- a/src/core/message.h +++ b/src/core/message.h @@ -2,178 +2,34 @@ #ifndef CANDY_CORE_MESSAGE_H #define CANDY_CORE_MESSAGE_H -#include "utility/address.h" -#include #include -#include #include namespace Candy { -namespace MessageType { - -constexpr uint8_t AUTH = 0; -constexpr uint8_t FORWARD = 1; -constexpr uint8_t EXPECTED = 2; -constexpr uint8_t PEER = 3; -constexpr uint8_t VMAC = 4; -constexpr uint8_t DISCOVERY = 5; -constexpr uint8_t ROUTE = 6; -constexpr uint8_t GENERAL = 255; - -} // namespace MessageType - -struct AuthHeader { - uint8_t type; - uint32_t ip; - int64_t timestamp; - uint8_t hash[SHA256_DIGEST_LENGTH]; - - AuthHeader(uint32_t ip); - void updateHash(const std::string &password); - bool check(const std::string &password); -} __attribute__((packed)); - -struct ForwardHeader { - uint8_t type; - IPv4Header iph; - - ForwardHeader(); -} __attribute__((packed)); - -struct ExpectedAddressMessage { - uint8_t type; - int64_t timestamp; - char cidr[32]; - uint8_t hash[SHA256_DIGEST_LENGTH]; - - ExpectedAddressMessage(const std::string &cidr); - void updateHash(const std::string &password); - bool check(const std::string &password); -} __attribute__((packed)); - -struct PeerConnMessage { - uint8_t type; - uint32_t src; - uint32_t dst; - uint32_t ip; - uint16_t port; - - PeerConnMessage(); -} __attribute__((packed)); - -struct VMacMessage { - uint8_t type; - uint8_t vmac[16]; - int64_t timestamp; - uint8_t hash[SHA256_DIGEST_LENGTH]; - - VMacMessage(const std::string &vmac); - void updateHash(const std::string &password); - bool check(const std::string &password); -} __attribute__((packed)); - -struct DiscoveryMessage { - uint8_t type; - uint32_t src; - uint32_t dst; - - DiscoveryMessage(); -} __attribute__((packed)); - -struct SysRouteItem { - uint32_t dest; - uint32_t mask; - uint32_t nexthop; -} __attribute__((packed)); - -struct SysRouteMessage { - uint8_t type; - uint8_t size; - uint16_t reserved; - SysRouteItem rtTable[0]; -} __attribute__((packed)); - -struct GeneralHeader { - uint8_t type; - uint8_t subtype; - uint16_t extra; - uint32_t src; - uint32_t dst; - - GeneralHeader(); -} __attribute__((packed)); - -namespace GeSubType { - -constexpr uint8_t LOCAL_PEER_CONN = 0; - -} // namespace GeSubType - -struct LocalPeerConnMessage { - GeneralHeader ge; - uint32_t ip; - uint16_t port; -} __attribute__((packed)); - -struct StunRequest { - uint8_t type[2] = {0x00, 0x01}; - uint8_t length[2] = {0x00, 0x08}; - uint8_t cookie[4] = {0x21, 0x12, 0xa4, 0x42}; - uint8_t id[12] = {0x00}; - struct { - uint8_t type[2] = {0x00, 0x03}; - uint8_t length[2] = {0x00, 0x04}; - uint8_t notset[4] = {0x00}; - } attr; +/* 内部模块之间通过消息通信.消息类型在这里定义 */ +enum class MsgKind { + TIMEOUT, // 读操作超时 + PACKET, // 模块间转发 IP 报文 + TUNADDR, // 通知 TUN 模块设置地址 + SYSRT, // 设置系统路由 + TRYP2P, // 尝试建立对等连接 }; -struct StunResponse { - uint16_t type; - uint16_t length; - uint32_t cookie; - uint8_t id[12]; - uint8_t attr[0]; -}; +/* 内部模块间的消息只包含类型和可选的数据,模块之间传输信息只允许移动,不允许复制 */ +struct Msg { + MsgKind kind; + std::string data; -namespace PeerMessageType { + // 禁用拷贝构造和拷贝赋值 + Msg(const Msg &) = delete; + Msg &operator=(const Msg &) = delete; -constexpr uint8_t HEARTBEAT = 0; -constexpr uint8_t FORWARD = 1; -constexpr uint8_t DELAY = 2; -// TODO: 遗漏了 3, 新功能时使用 -constexpr uint8_t ROUTE = 4; - -} // namespace PeerMessageType - -struct PeerHeartbeatMessage { - uint8_t type; - uint32_t tun; - uint32_t ip; - uint16_t port; - uint8_t ack; -} __attribute__((packed)); - -struct PeerForwardMessage { - uint8_t type; - IPv4Header iph; -} __attribute__((packed)); - -struct PeerDelayMessage { - uint8_t type; - uint32_t src; - uint32_t dst; - int64_t timestamp; -} __attribute__((packed)); - -struct PeerRouteMessage { - uint8_t type; - uint32_t dst; - uint32_t next; - int32_t delay; -} __attribute__((packed)); - -constexpr uint32_t BROADCAST_IP = UINT32_MAX; + // 默认构造,移动构造和移动赋值 + Msg(MsgKind kind = MsgKind::TIMEOUT, std::string = ""); + Msg(Msg &&packet); + Msg &operator=(Msg &&packet); +}; } // namespace Candy diff --git a/src/core/net.cc b/src/core/net.cc new file mode 100644 index 00000000..87eac062 --- /dev/null +++ b/src/core/net.cc @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +#include "core/net.h" +#include +#include +#include + +namespace Candy { + +IP4::IP4(const std::string &ip) { + fromString(ip); +} + +IP4 IP4::operator=(const std::string &ip) { + fromString(ip); + return *this; +} + +IP4::operator std::string() const { + return toString(); +} + +IP4::operator uint32_t() const { + uint32_t val = 0; + std::memcpy(&val, raw.data(), sizeof(val)); + return val; +} + +IP4 IP4::operator|(IP4 another) const { + for (int i = 0; i < raw.size(); ++i) { + another.raw[i] |= raw[i]; + } + return another; +} + +IP4 IP4::operator^(IP4 another) const { + for (int i = 0; i < raw.size(); ++i) { + another.raw[i] ^= raw[i]; + } + return another; +} + +IP4 IP4::operator~() const { + IP4 retval; + for (int i = 0; i < raw.size(); ++i) { + retval.raw[i] |= ~raw[i]; + } + return retval; +} + +bool IP4::operator==(IP4 another) const { + return raw == another.raw; +} + +IP4 IP4::operator&(IP4 another) const { + for (int i = 0; i < raw.size(); ++i) { + another.raw[i] &= raw[i]; + } + return another; +} + +IP4 IP4::next() const { + IP4 ip; + uint32_t t = hton(ntoh(uint32_t(*this)) + 1); + std::memcpy(&ip, &t, sizeof(ip)); + return ip; +} + +int IP4::fromString(const std::string &ip) { + memcpy(raw.data(), Poco::Net::IPAddress(ip).addr(), 4); + return 0; +} + +std::string IP4::toString() const { + return Poco::Net::IPAddress(raw.data(), sizeof(raw)).toString(); +} + +int IP4::fromPrefix(int prefix) { + std::memset(raw.data(), 0, sizeof(raw)); + for (int i = 0; i < prefix; ++i) { + raw[i / 8] |= (0x80 >> (i % 8)); + } + return 0; +} + +int IP4::toPrefix() { + int i; + for (i = 0; i < 32; ++i) { + if (!(raw[i / 8] & (0x80 >> (i % 8)))) { + break; + } + } + return i; +} + +bool IP4::empty() const { + return raw[0] == 0 && raw[1] == 0 && raw[2] == 0 && raw[3] == 0; +} + +bool IP4Header::isIPv4() { + return true; +} + +bool IP4Header::isIPIP() { + return false; +} + +Address::Address() {} + +Address::Address(const std::string &cidr) { + fromCidr(cidr); +} + +IP4 &Address::Host() { + return this->host; +} + +IP4 &Address::Mask() { + return this->mask; +} + +IP4 Address::Net() { + return Host() & Mask(); +} + +Address Address::Next() { + Address next; + next.mask = this->mask; + next.host = (Net() | (~Mask() & this->host.next())); + return next; +} + +bool Address::isValid() { + // 主机号全为 0 + if ((~mask & host) == 0) { + return false; + } + // 主机号全为 1 + if (~(mask | host) == 0) { + return false; + } + return true; +} + +int Address::fromCidr(const std::string &cidr) { + try { + std::size_t pos = cidr.find('/'); + host.fromString(cidr.substr(0UL, pos)); + mask.fromPrefix(std::stoi(cidr.substr(pos + 1))); + } catch (std::exception &e) { + spdlog::warn("address parse cidr failed: {}: {}", e.what(), cidr); + return -1; + } + return 0; +} + +std::string Address::toCidr() { + return host.toString() + "/" + std::to_string(mask.toPrefix()); +} + +} // namespace Candy diff --git a/src/core/net.h b/src/core/net.h new file mode 100644 index 00000000..b51ccc7d --- /dev/null +++ b/src/core/net.h @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +#ifndef CANDY_CORE_NET_H +#define CANDY_CORE_NET_H + +#include "utility/byteswap.h" +#include +#include +#include +#include +#include + +namespace Candy { + +// 统一函数处理网络序与主机序之间的转换 +template T ntoh(T v) { + if (std::endian::native == std::endian::little) { + return byteswap(v); + } + return v; +} + +template T hton(T v) { + return ntoh(v); +} + +/* 用于表示 IPv4 地址,数据在内部以网络序的形式存储,并提供与之对应的字符串操作 */ +class __attribute__((packed)) IP4 { +public: + IP4(const std::string &ip = "0.0.0.0"); + IP4 operator=(const std::string &ip); + IP4 operator&(IP4 another) const; + IP4 operator|(IP4 another) const; + IP4 operator^(IP4 another) const; + IP4 operator~() const; + bool operator==(IP4 another) const; + operator std::string() const; + operator uint32_t() const; + IP4 next() const; + int fromString(const std::string &ip); + std::string toString() const; + int fromPrefix(int prefix); + int toPrefix(); + bool empty() const; + +private: + std::array raw; +}; + +/* IPv4 头,分装于 IPv4 相关的操作 */ +struct __attribute__((packed)) IP4Header { + uint8_t version_ihl; + uint8_t tos; + uint16_t tot_len; + uint16_t id; + uint16_t frag_off; + uint8_t ttl; + uint8_t protocol; + uint16_t check; + IP4 saddr; + IP4 daddr; + + // 判断是否为 IPv4, 对于 TUN 设备需要丢弃所有非 IPv4 的报文 + bool isIPv4(); + // 判断是否为模拟的 IPIP 协议 + bool isIPIP(); +}; + +struct __attribute__((packed)) SysRouteEntry { + IP4 dst; + IP4 mask; + IP4 nexthop; +}; + +/* 用于表示地址和掩码的组合,用于判断主机是否属于某个网络 */ +class Address { +public: + Address(); + Address(const std::string &cidr); + + IP4 &Host(); + IP4 &Mask(); + IP4 Net(); + + // 当前网络内的下一个地址 + Address Next(); + + // 判断是否是有效的主机地址 + bool isValid(); + + int fromCidr(const std::string &cidr); + std::string toCidr(); + + bool empty() const { + return host.empty() && mask.empty(); + } + +private: + IP4 host; + IP4 mask; +}; + +} // namespace Candy + +namespace std { +template <> struct hash { + size_t operator()(const Candy::IP4 &ip) const noexcept { + return hash{}(ip); + } +}; +} // namespace std + +#endif diff --git a/src/core/server.cc b/src/core/server.cc index f9634bc9..3b15317b 100644 --- a/src/core/server.cc +++ b/src/core/server.cc @@ -1,531 +1,30 @@ // SPDX-License-Identifier: MIT #include "core/server.h" -#include "core/common.h" -#include "core/message.h" -#include "utility/address.h" -#include "utility/random.h" -#include "utility/time.h" -#include -#include -#include -#include namespace Candy { -int Server::setWebSocketServer(const std::string &uri) { - try { - Poco::URI parser(uri); - if (parser.getScheme() != "ws") { - spdlog::critical("websocket server only support ws"); - return -1; - } - this->host = parser.getHost(); - this->port = parser.getPort(); - return 0; - } catch (std::exception &e) { - spdlog::critical("invalid websocket uri: {}: {}", uri, e.what()); - return -1; - } +void Server::setWebSocket(const std::string &uri) { + ws.setWebSocket(uri); } -int Server::setPassword(const std::string &password) { - this->password = password; - return 0; +void Server::setPassword(const std::string &password) { + ws.setPassword(password); } -int Server::setDynamicAddressRange(const std::string &cidr) { - if (cidr.empty()) { - return 0; - } - if (this->dynamic.cidrUpdate(cidr)) { - spdlog::critical("dynamic address generator init failed"); - return -1; - } - uint32_t randomHost = (~this->dynamic.getMask()) & randomUint32(); - if (this->dynamic.ipMaskUpdate(this->dynamic.getNet() | randomHost, this->dynamic.getMask())) { - return -1; - } - this->dynamicAddrEnabled = true; - return 0; +void Server::setDHCP(const std::string &cidr) { + ws.setDHCP(cidr); } -int Server::setSdwan(const std::string &sdwan) { - if (sdwan.empty()) { - return 0; - } - std::string route; - std::stringstream stream(sdwan); - while (std::getline(stream, route, ';')) { - std::string addr; - SysRoute rt; - std::stringstream ss(route); - if (!std::getline(ss, addr, ',') || rt.dev.cidrUpdate(addr) || rt.dev.getIp() != rt.dev.getNet()) { - spdlog::critical("invalid route device: {}", route); - return -1; - } - if (!std::getline(ss, addr, ',') || rt.dst.cidrUpdate(addr) || rt.dst.getIp() != rt.dst.getNet()) { - spdlog::critical("invalid route dest: {}", route); - return -1; - } - if (!std::getline(ss, addr, ',') || rt.next.ipStrUpdate(addr)) { - spdlog::critical("invalid route nexthop: {}", route); - return -1; - } - spdlog::info("route: dev={} dst={} next={}", rt.dev.getCidr(), rt.dst.getCidr(), rt.next.getIpStr()); - this->routes.push_back(rt); - } - return 0; +void Server::setSdwan(const std::string &sdwan) { + ws.setSdwan(sdwan); } -int Server::run() { - this->running = true; - if (startWsThread()) { - spdlog::critical("start websocket server thread failed"); - Candy::shutdown(this); - return -1; - } - return 0; +void Server::run() { + ws.run(); } -int Server::shutdown() { - if (!this->running) { - return 0; - } - - this->running = false; - this->dynamicAddrEnabled = false; - if (this->wsThread.joinable()) { - this->wsThread.join(); - } - - this->ws.stop(); - this->routes.clear(); - return 0; -} - -int Server::startWsThread() { - if (this->ws.listen(this->host, this->port)) { - spdlog::critical("websocket server listen failed"); - return -1; - } - - if (this->ws.setTimeout(1)) { - spdlog::critical("websocket server set read write timeout failed"); - return -1; - } - - this->wsThread = std::thread([&] { - this->handleWebSocketMessage(); - spdlog::debug("websocket server thread exit"); - }); - return 0; -} - -void Server::handleWebSocketMessage() { - int error; - WebSocketMessage message; - - spdlog::info("listen: {}:{}", this->host, this->port); - - while (this->running) { - error = this->ws.read(message); - if (error == 0) { - continue; - } - if (error < 0) { - spdlog::error("websocket server read failed: error {}", error); - Candy::shutdown(this); - break; - } - - if (message.type == WebSocketMessageType::Message) { - uint8_t msgType = message.buffer.front(); - switch (msgType) { - case MessageType::EXPECTED: - handleExpectedAddressMessage(message); - break; - case MessageType::VMAC: - handleVirtualMacMessage(message); - break; - case MessageType::AUTH: - handleAuthMessage(message); - break; - case MessageType::FORWARD: - handleForwardMessage(message); - break; - case MessageType::PEER: - handlePeerConnMessage(message); - break; - case MessageType::DISCOVERY: - handleDiscoveryMessage(message); - break; - case MessageType::GENERAL: - handleGeneralMessage(message); - break; - default: - spdlog::debug("unknown message: type {}", msgType); - break; - } - } - - if (message.type == WebSocketMessageType::Close) { - handleCloseMessage(message); - continue; - } - - if (message.type == WebSocketMessageType::Error) { - spdlog::critical("server websocket error: {}", message.buffer); - Candy::shutdown(this); - break; - } - } - return; -} - -void Server::handleAuthMessage(WebSocketMessage &message) { - if (message.buffer.length() < sizeof(AuthHeader)) { - spdlog::warn("invalid auth message: len {}", message.buffer.length()); - this->ws.close(message.conn); - return; - } - - AuthHeader *header = (AuthHeader *)message.buffer.data(); - if (!header->check(this->password)) { - spdlog::warn("auth header check failed: buffer {:n}", spdlog::to_hex(message.buffer)); - this->ws.close(message.conn); - return; - } - - Address address; - if (address.ipUpdate(Address::netToHost(header->ip))) { - spdlog::warn("invalid auth ip: buffer {:n}", spdlog::to_hex(message.buffer)); - this->ws.close(message.conn); - return; - } - if (this->ipWsMap.contains(address.getIp())) { - this->ws.close(this->ipWsMap[address.getIp()]); - spdlog::info("reconnect: {}", address.getIpStr()); - } else { - spdlog::info("connect: {}", address.getIpStr()); - } - - this->ipWsMap[address.getIp()] = message.conn; - this->wsIpMap[message.conn] = address.getIp(); - updateClientRoute(message, address.getIp()); -} - -void Server::handleForwardMessage(WebSocketMessage &message) { - if (!this->wsIpMap.contains(message.conn)) { - spdlog::debug("unauthorized forward websocket client"); - return; - } - - if (message.buffer.length() < sizeof(ForwardHeader)) { - spdlog::debug("invalid forawrd message: len {}", message.buffer.length()); - return; - } - - ForwardHeader *header = (ForwardHeader *)message.buffer.data(); - uint32_t saddr = Address::netToHost(header->iph.saddr); - uint32_t daddr = Address::netToHost(header->iph.daddr); - Address source; - source.ipUpdate(saddr); - - if (this->wsIpMap[message.conn] != saddr) { - Address auth; - auth.ipUpdate(this->wsIpMap[message.conn]); - spdlog::debug("forward source address does not match: auth {} source {}", auth.getIpStr(), source.getIpStr()); - return; - } - - if (this->ipWsMap.contains(daddr)) { - message.conn = this->ipWsMap[daddr]; - this->ws.write(message); - return; - } - - bool broadcast = [&] { - // 多播地址 - if ((daddr & 0xF0000000) == 0xE0000000) { - return true; - } - // 广播 - if (daddr == UINT32_MAX) { - return true; - } - // 服务端没有配置动态分配地址的范围,没法检查是否为定向广播 - if (!this->dynamicAddrEnabled) { - return false; - } - // 网络号不同,不是定向广播 - if ((this->dynamic.getMask() & daddr) != this->dynamic.getNet()) { - return false; - } - // 主机号部分不全为 1,不是定向广播 - if ((~this->dynamic.getMask()) & (daddr + 1)) { - return false; - } - return true; - }(); - - if (broadcast) { - for (auto conn : this->ipWsMap) { - if (conn.first != saddr) { - message.conn = conn.second; - this->ws.write(message); - } - } - return; - } - - Address destination; - destination.ipUpdate(daddr); - spdlog::debug("forward failed: source {} dest {}", source.getIpStr(), destination.getIpStr()); - return; -} - -void Server::handleExpectedAddressMessage(WebSocketMessage &message) { - if (message.buffer.length() < sizeof(ExpectedAddressMessage)) { - spdlog::warn("invalid dynamic address message: len {}", message.buffer.length()); - this->ws.close(message.conn); - return; - } - - ExpectedAddressMessage *header = (ExpectedAddressMessage *)message.buffer.data(); - if (!header->check(this->password)) { - spdlog::warn("dynamic address header check failed: buffer {:n}", spdlog::to_hex(message.buffer)); - this->ws.close(message.conn); - return; - } - - if (!this->dynamicAddrEnabled) { - spdlog::warn("the client requests a dynamic address, but the server does not enable this function"); - this->ws.close(message.conn); - return; - } - - Address address; - if (address.cidrUpdate(header->cidr)) { - spdlog::warn("dynamic address header cidr invalid: buffer {:n}", spdlog::to_hex(message.buffer)); - this->ws.close(message.conn); - return; - } - - bool needGenNewAddr = [&]() { - if (!dynamic.inSameNetwork(address)) { - return true; - } - auto oldWs = this->ipWsMap.find(address.getIp()); - if (oldWs == this->ipWsMap.end()) { - return false; - } - auto newMac = this->wsMacMap.find(message.conn); - if (newMac == this->wsMacMap.end()) { - return true; - } - auto oldMac = this->wsMacMap.find(oldWs->second); - if (oldMac == this->wsMacMap.end()) { - return true; - } - if (newMac->second == oldMac->second) { - return false; - } - return true; - }(); - - if (needGenNewAddr) { - uint32_t oldip = dynamic.getIp(); - uint32_t newip = 0; - do { - if (this->dynamic.next()) { - spdlog::error("unable to get next available address"); - this->ws.close(message.conn); - return; - } - newip = dynamic.getIp(); - if (oldip == newip) { - spdlog::warn("all addresses in the network are assigned"); - this->ws.close(message.conn); - return; - } - } while (this->ipWsMap.contains(newip)); - address.ipMaskUpdate(dynamic.getIp(), dynamic.getMask()); - } - - header->timestamp = Time::hostToNet(Time::unixTime()); - std::strcpy(header->cidr, address.getCidr().c_str()); - header->updateHash(this->password); - - this->ws.write(message); -} - -void Server::handlePeerConnMessage(WebSocketMessage &message) { - if (!this->wsIpMap.contains(message.conn)) { - spdlog::debug("unauthorized peer websocket client"); - return; - } - - if (message.buffer.length() < sizeof(PeerConnMessage)) { - spdlog::warn("invalid peer conn message: len {}", message.buffer.length()); - return; - } - - PeerConnMessage *header = (PeerConnMessage *)message.buffer.data(); - Address auth, source, destination; - auth.ipUpdate(this->wsIpMap[message.conn]); - source.ipUpdate(Address::netToHost(header->src)); - if (this->wsIpMap[message.conn] != Address::netToHost(header->src)) { - spdlog::debug("peer source address does not match: auth {} source {}", auth.getIpStr(), source.getIpStr()); - return; - } - if (!this->ipWsMap.contains(Address::netToHost(header->dst))) { - spdlog::debug("peer dest address not logged in: source {} dest {}", source.getIpStr(), destination.getIpStr()); - return; - } - message.conn = this->ipWsMap[Address::netToHost(header->dst)]; - this->ws.write(message); - return; -} - -void Server::handleVirtualMacMessage(WebSocketMessage &message) { - if (message.buffer.length() < sizeof(VMacMessage)) { - spdlog::warn("invalid vmac message: len {}", message.buffer.length()); - return; - } - - VMacMessage *header = (VMacMessage *)message.buffer.data(); - if (!header->check(this->password)) { - spdlog::warn("vmac message check failed: buffer {:n}", spdlog::to_hex(message.buffer)); - this->ws.close(message.conn); - return; - } - std::string vmac((char *)header->vmac, sizeof(header->vmac)); - - this->wsMacMap[message.conn] = vmac; - return; -} - -void Server::handleDiscoveryMessage(WebSocketMessage &message) { - if (!this->wsIpMap.contains(message.conn)) { - spdlog::debug("unauthorized discovery websocket client"); - return; - } - - if (message.buffer.length() < sizeof(DiscoveryMessage)) { - spdlog::debug("invalid discovery message: len {}", message.buffer.length()); - return; - } - - DiscoveryMessage *header = (DiscoveryMessage *)message.buffer.data(); - uint32_t saddr = Address::netToHost(header->src); - uint32_t daddr = Address::netToHost(header->dst); - - if (this->wsIpMap[message.conn] != saddr) { - Address auth, source; - auth.ipUpdate(this->wsIpMap[message.conn]); - source.ipUpdate(saddr); - spdlog::debug("discovery source address does not match: auth {} source {}", auth.getIpStr(), source.getIpStr()); - return; - } - - if (daddr == BROADCAST_IP) { - for (auto conn : this->ipWsMap) { - if (conn.first != saddr) { - message.conn = conn.second; - this->ws.write(message); - } - } - return; - } - - if (this->ipWsMap.contains(daddr)) { - message.conn = this->ipWsMap[daddr]; - this->ws.write(message); - return; - } -} - -void Server::handleGeneralMessage(WebSocketMessage &message) { - if (!this->wsIpMap.contains(message.conn)) { - spdlog::debug("unauthorized general websocket client"); - return; - } - - if (message.buffer.length() < sizeof(GeneralHeader)) { - spdlog::debug("invalid general message: len {}", message.buffer.length()); - return; - } - - GeneralHeader *header = (GeneralHeader *)message.buffer.data(); - uint32_t saddr = Address::netToHost(header->src); - uint32_t daddr = Address::netToHost(header->dst); - - if (this->wsIpMap[message.conn] != saddr) { - Address auth, source; - auth.ipUpdate(this->wsIpMap[message.conn]); - source.ipUpdate(saddr); - spdlog::debug("general source address does not match: auth {} source {}", auth.getIpStr(), source.getIpStr()); - return; - } - - if (daddr == BROADCAST_IP) { - for (auto conn : this->ipWsMap) { - if (conn.first != saddr) { - message.conn = conn.second; - this->ws.write(message); - } - } - return; - } - - if (this->ipWsMap.contains(daddr)) { - message.conn = this->ipWsMap[daddr]; - this->ws.write(message); - return; - } -} - -void Server::handleCloseMessage(WebSocketMessage &message) { - auto it = this->wsIpMap.find(message.conn); - if (it != this->wsIpMap.end()) { - if (this->ipWsMap[it->second] == message.conn) { - Address address; - if (!address.ipUpdate(it->second)) { - spdlog::info("disconnect: {}", address.getIpStr()); - } - this->ipWsMap.erase(it->second); - } - this->wsIpMap.erase(it); - } - this->wsMacMap.erase(message.conn); -} - -void Server::updateClientRoute(WebSocketMessage &message, uint32_t client) { - message.buffer.resize(sizeof(SysRouteMessage)); - SysRouteMessage *header = (SysRouteMessage *)message.buffer.data(); - memset(header, 0, sizeof(SysRouteMessage)); - header->type = MessageType::ROUTE; - - for (auto rt : this->routes) { - if ((rt.dev.getMask() & client) == rt.dev.getIp()) { - SysRouteItem item; - item.dest = Address::hostToNet(rt.dst.getNet()); - item.mask = Address::hostToNet(rt.dst.getMask()); - item.nexthop = Address::hostToNet(rt.next.getIp()); - message.buffer.append((char *)(&item), sizeof(item)); - header->size += 1; - } - // 100 条路由报文大小是 1204 字节,超过 100 条后分批发送 - if (header->size > 100) { - this->ws.write(message); - message.buffer.resize(sizeof(SysRouteMessage)); - header->size = 0; - } - } - - if (header->size > 0) { - this->ws.write(message); - } +void Server::shutdown() { + ws.shutdown(); } } // namespace Candy diff --git a/src/core/server.h b/src/core/server.h index 2696c50f..9ca234d3 100644 --- a/src/core/server.h +++ b/src/core/server.h @@ -2,61 +2,29 @@ #ifndef CANDY_CORE_SERVER_H #define CANDY_CORE_SERVER_H -#include "utility/address.h" #include "websocket/server.h" -#include -#include #include -#include -#include namespace Candy { -struct SysRoute { - Address dev; - Address dst; - Address next; -}; - class Server { public: - int setWebSocketServer(const std::string &uri); - int setPassword(const std::string &password); - int setDynamicAddressRange(const std::string &cidr); - int setSdwan(const std::string &sdwan); + // 通过配置文件或命令行设置的参数 + void setWebSocket(const std::string &uri); + void setPassword(const std::string &password); + void setDHCP(const std::string &cidr); + void setSdwan(const std::string &sdwan); - int run(); - int shutdown(); + // 启动服务端,非阻塞 + void run(); + // 关闭客户端,阻塞,直到所有子模块退出 + void shutdown(); private: - int startWsThread(); - void handleWebSocketMessage(); - - void handleAuthMessage(WebSocketMessage &message); - void handleForwardMessage(WebSocketMessage &message); - void handleExpectedAddressMessage(WebSocketMessage &message); - void handlePeerConnMessage(WebSocketMessage &message); - void handleVirtualMacMessage(WebSocketMessage &message); - void handleDiscoveryMessage(WebSocketMessage &message); - void handleGeneralMessage(WebSocketMessage &message); - void handleCloseMessage(WebSocketMessage &message); - - void updateClientRoute(WebSocketMessage &message, uint32_t client); - - bool running = false; - uint16_t port; - std::string host; - std::string password; - std::thread wsThread; + // 目前只有一个 WebSocket 服务端的子模块 WebSocketServer ws; - Address dynamic; - bool dynamicAddrEnabled = false; - - std::unordered_map ipWsMap; - std::map wsIpMap; - std::map wsMacMap; - std::list routes; + // TODO: 添加 STUN 服务端的支持 }; } // namespace Candy diff --git a/src/core/common.h b/src/core/version.h similarity index 55% rename from src/core/common.h rename to src/core/version.h index 79f76294..703fbbe9 100644 --- a/src/core/common.h +++ b/src/core/version.h @@ -1,21 +1,16 @@ // SPDX-License-Identifier: MIT -#ifndef CANDY_CORE_COMMON_H -#define CANDY_CORE_COMMON_H +#ifndef CANDY_CORE_VERSION_H +#define CANDY_CORE_VERSION_H #include #if POCO_OS == POCO_OS_LINUX -#include #define CANDY_SYSTEM "linux" #elif POCO_OS == POCO_OS_MAC_OS_X -#include #define CANDY_SYSTEM "macos" #elif POCO_OS == POCO_OS_ANDROID -#include #define CANDY_SYSTEM "android" #elif POCO_OS == POCO_OS_WINDOWS_NT -#include -#include #define CANDY_SYSTEM "windows" #else #define CANDY_SYSTEM "unknown" @@ -25,15 +20,4 @@ #define CANDY_VERSION "unknown" #endif -#include "core/client.h" -#include "core/server.h" - -namespace Candy { - -// 出现内部异常时调用 -void shutdown(Client *client); -void shutdown(Server *client); - -} // namespace Candy - #endif diff --git a/src/main/CMakeLists.txt b/src/main/CMakeLists.txt index 4c273761..ca80ead5 100644 --- a/src/main/CMakeLists.txt +++ b/src/main/CMakeLists.txt @@ -7,14 +7,16 @@ if (${CANDY_STATIC_SPDLOG}) target_link_libraries(${CANDY_EXECUTE_NAME} PRIVATE spdlog::spdlog) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED spdlog) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(${CANDY_EXECUTE_NAME} PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(SPDLOG REQUIRED spdlog) + add_definitions(${SPDLOG_CFLAGS}) + include_directories(${SPDLOG_INCLUDEDIR}) + target_link_libraries(${CANDY_EXECUTE_NAME} PRIVATE ${SPDLOG_LIBRARIES}) endif() if (${CANDY_STATIC_OPENSSL}) target_link_libraries(${CANDY_EXECUTE_NAME} PRIVATE ${OPENSSL_LIB_CRYPTO} ${OPENSSL_LIB_SSL}) +else() + find_package(OpenSSL REQUIRED) endif() if (${CANDY_STATIC_POCO}) diff --git a/src/main/config.cc b/src/main/config.cc new file mode 100644 index 00000000..6915eeb2 --- /dev/null +++ b/src/main/config.cc @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: MIT +#include "main/config.h" +#include "core/version.h" +#include "main/config.h" +#include "utility/argparse.h" +#include "utility/random.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void arguments::dump(const std::string &key, const std::string &value) { + if (!value.empty()) { + spdlog::debug("--{}={}", key, value); + } +} + +void arguments::dump(const std::string &key, int value) { + if (value) { + spdlog::debug("--{}={}", key, value); + } +} + +void arguments::dump() { + spdlog::debug("================================"); + dump("mode", this->mode); + dump("websocket", this->websocket); + dump("password", this->password); + dump("ntp", this->ntp); + dump("restart", this->restart); + dump("dhcp", this->dhcp); + dump("sdwan", this->sdwan); + dump("name", this->name); + dump("tun", this->tun); + dump("stun", this->stun); + dump("localhost", this->localhost); + dump("discovery", this->discovery); + dump("route", this->routeCost); + dump("mtu", this->mtu); + dump("workers", this->workers); + dump("port", this->udpPort); + spdlog::debug("================================"); +} + +int arguments::parse(int argc, char *argv[]) { + argparse::ArgumentParser program("candy", CANDY_VERSION); + + program.add_argument("-m", "--mode").help("working mode").metavar("TEXT"); + program.add_argument("-w", "--websocket").help("websocket address").metavar("URI"); + program.add_argument("-p", "--password").help("authorization password").metavar("TEXT"); + program.add_argument("--ntp").help("ntp server").metavar("HOST"); + program.add_argument("--restart").help("restart interval").scan<'i', int>().metavar("SECONDS"); + program.add_argument("-d", "--dhcp").help("dhcp address range").metavar("CIDR"); + program.add_argument("--sdwan").help("software-defined wide area network").metavar("ROUTES"); + program.add_argument("-n", "--name").help("network interface name").metavar("TEXT"); + program.add_argument("--workers").help("workers number").scan<'i', int>().metavar("NUM"); + program.add_argument("-t", "--tun").help("static address").metavar("CIDR"); + program.add_argument("-s", "--stun").help("stun address").metavar("URI"); + program.add_argument("--port").help("udp port").scan<'i', int>().metavar("NUMBER"); + program.add_argument("--mtu").help("maximum transmission unit").scan<'i', int>().metavar("NUMBER"); + program.add_argument("-r", "--route").help("routing cost").scan<'i', int>().metavar("COST"); + program.add_argument("--discovery").help("discovery interval").scan<'i', int>().metavar("SECONDS"); + program.add_argument("--localhost").help("local ip").metavar("IP"); + program.add_argument("-c", "--config").help("config file path").metavar("PATH"); + program.add_argument("--no-timestamp").implicit_value(true).help("disable log time"); + program.add_argument("--debug").implicit_value(true).help("show debug log"); + + try { + program.parse_args(argc, argv); + if (program.is_used("--config")) { + parseFile(program.get("--config")); + } + + this->mode = program.is_used("--mode") ? program.get("--mode") : this->mode; + this->websocket = program.is_used("--websocket") ? program.get("--websocket") : this->websocket; + this->password = program.is_used("--password") ? program.get("--password") : this->password; + this->ntp = program.is_used("--ntp") ? program.get("--ntp") : this->ntp; + this->restart = program.is_used("--restart") ? program.get("--restart") : this->restart; + this->noTimestamp = program.is_used("--no-timestamp") ? program.get("--no-timestamp") : this->noTimestamp; + this->debug = program.is_used("--debug") ? program.get("--debug") : this->debug; + this->dhcp = program.is_used("--dhcp") ? program.get("--dhcp") : this->dhcp; + this->sdwan = program.is_used("--sdwan") ? program.get("--sdwan") : this->sdwan; + this->name = program.is_used("--name") ? program.get("--name") : this->name; + this->workers = program.is_used("--workers") ? program.get("--workers") : this->workers; + this->tun = program.is_used("--tun") ? program.get("--tun") : this->tun; + this->stun = program.is_used("--stun") ? program.get("--stun") : this->stun; + this->localhost = program.is_used("--localhost") ? program.get("--localhost") : this->localhost; + this->udpPort = program.is_used("--port") ? program.get("--port") : this->udpPort; + this->mtu = program.is_used("--mtu") ? program.get("--mtu") : this->mtu; + this->discovery = program.is_used("--discovery") ? program.get("--discovery") : this->discovery; + this->routeCost = program.is_used("--route") ? program.get("--route") : this->routeCost; + + bool needShowUsage = [&]() { + if (this->mode != "client" && this->mode != "server") + return true; + if (this->websocket.empty()) + return true; + + return false; + }(); + + if (needShowUsage) { + std::cout << program.usage() << std::endl; + exit(1); + } + + if (this->noTimestamp) { + spdlog::set_pattern("[%^%l%$] %v"); + } + if (this->debug) { + spdlog::set_level(spdlog::level::debug); + this->dump(); + } + return 0; + } catch (const std::exception &e) { + std::cout << program.usage() << std::endl; + exit(1); + } +} + +void arguments::parseFile(std::string cfgFile) { + try { + std::map> cfgHandlers = { + {"mode", [&](const std::string &value) { this->mode = value; }}, + {"websocket", [&](const std::string &value) { this->websocket = value; }}, + {"password", [&](const std::string &value) { this->password = value; }}, + {"ntp", [&](const std::string &value) { this->ntp = value; }}, + {"debug", [&](const std::string &value) { this->debug = (value == "true"); }}, + {"restart", [&](const std::string &value) { this->restart = std::stoi(value); }}, + {"dhcp", [&](const std::string &value) { this->dhcp = value; }}, + {"sdwan", [&](const std::string &value) { this->sdwan = value; }}, + {"tun", [&](const std::string &value) { this->tun = value; }}, + {"stun", [&](const std::string &value) { this->stun = value; }}, + {"name", [&](const std::string &value) { this->name = value; }}, + {"workers", [&](const std::string &value) { this->workers = std::stoi(value); }}, + {"discovery", [&](const std::string &value) { this->discovery = std::stoi(value); }}, + {"route", [&](const std::string &value) { this->routeCost = std::stoi(value); }}, + {"port", [&](const std::string &value) { this->udpPort = std::stoi(value); }}, + {"mtu", [&](const std::string &value) { this->mtu = std::stoi(value); }}, + {"localhost", [&](const std::string &value) { this->localhost = value; }}, + }; + auto trim = [](std::string str) { + if (str.length() >= 2 && str.front() == '\"' && str.back() == '\"') { + return str.substr(1, str.length() - 2); + } + return str; + }; + auto configs = fileToKvMap(cfgFile); + for (auto cfg : configs) { + auto handler = cfgHandlers.find(cfg.first); + if (handler != cfgHandlers.end()) { + handler->second(trim(cfg.second)); + } else { + spdlog::warn("unknown config: {}={}", cfg.first, cfg.second); + } + } + } catch (std::exception &e) { + spdlog::error("parse config file failed: {}", e.what()); + exit(1); + } +} + +std::map arguments::fileToKvMap(const std::string &filename) { + std::map config; + std::ifstream file(filename); + std::string line; + + while (std::getline(file, line)) { + line = Poco::trimLeft(line); + if (line.empty() || line.front() == '#') + continue; + line.erase(line.find_last_not_of(" \t;") + 1); + std::size_t delimiterPos = line.find('='); + if (delimiterPos != std::string::npos) { + std::string key = Poco::trim(line.substr(0, delimiterPos)); + std::string value = Poco::trim(line.substr(delimiterPos + 1)); + config[key] = value; + } + } + return config; +} + +int saveTunAddress(const std::string &name, const std::string &cidr) { + try { + std::string cache = storageDirectory("address/"); + cache += name.empty() ? "__noname__" : name; + std::filesystem::create_directories(std::filesystem::path(cache).parent_path()); + std::ofstream ofs(cache); + if (ofs.is_open()) { + ofs << cidr; + ofs.close(); + } + return 0; + } catch (std::exception &e) { + spdlog::critical("save latest address failed: {}", e.what()); + return -1; + } +} + +std::string loadTunAddress(const std::string &name) { + std::string cache = storageDirectory("address/"); + cache += name.empty() ? "__noname__" : name; + std::ifstream ifs(cache); + if (ifs.is_open()) { + std::stringstream ss; + ss << ifs.rdbuf(); + ifs.close(); + return ss.str(); + } + return "0.0.0.0/0"; +} + +std::string virtualMac(const std::string &name) { + try { + std::string cache = storageDirectory("vmac/"); + cache += name.empty() ? "__noname__" : name; + std::filesystem::create_directories(std::filesystem::path(cache).parent_path()); + + char buffer[16]; + std::stringstream ss; + + std::ifstream ifs(cache); + if (ifs.is_open()) { + ifs.read(buffer, sizeof(buffer)); + if (ifs) { + for (int i = 0; i < (int)sizeof(buffer); i++) { + ss << std::hex << buffer[i]; + } + } + ifs.close(); + } else { + ss << Candy::randomHexString(sizeof(buffer)); + std::ofstream ofs(cache); + if (ofs.is_open()) { + ofs << ss.str(); + ofs.close(); + } + } + return ss.str(); + } catch (std::exception &e) { + spdlog::critical("vmac failed: {}", e.what()); + return ""; + } +} + +bool hasContainerVolume(const arguments &args) { + if (args.mode != "client") { + return true; + } + if (!std::filesystem::exists(storageDirectory("lost"))) { + return true; + } + if (args.websocket.starts_with("wss://canets.org")) { + return false; + } + if (!args.tun.empty()) { + return true; + } + return false; +} + +#if POCO_OS == POCO_OS_WINDOWS_NT +std::string storageDirectory(std::string subdir) { + return "C:/ProgramData/Candy/" + subdir; +} +#else +std::string storageDirectory(std::string subdir) { + return "/var/lib/candy/" + subdir; +} +#endif diff --git a/src/main/config.h b/src/main/config.h new file mode 100644 index 00000000..79a4a3cf --- /dev/null +++ b/src/main/config.h @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +#ifndef CANDY_MAIN_CONFIG_H +#define CANDY_MAIN_CONFIG_H + +#include +#include +#include + +struct arguments { + // 通用配置 + std::string mode; + std::string websocket; + std::string password; + std::string ntp; + int restart = 0; + bool noTimestamp = false; + bool debug = false; + + // 服务端配置 + std::string dhcp; + std::string sdwan; + + // 客户端配置 + std::string name; + std::string tun; + std::string stun; + std::string localhost; + int workers = 0; + int udpPort = 0; + int discovery = 0; + int routeCost = 0; + int mtu = 1400; + + int parse(int argc, char *argv[]); + +private: + void dump(const std::string &key, const std::string &value); + void dump(const std::string &key, int value); + void dump(); + void parseFile(std::string cfgFile); + std::map fileToKvMap(const std::string &filename); +}; + +// 保存虚拟地址 +int saveTunAddress(const std::string &name, const std::string &cidr); + +// 获取虚拟地址 +std::string loadTunAddress(const std::string &name); + +// 获取或生成虚拟硬件地址 +std::string virtualMac(const std::string &name); + +// 检查是否能成功保存虚拟硬件地址,虚拟硬件地址不能持久化会导致 +// 1. 址申请动态 IP 时会重复获取地址资源 +// 2. 使用静态地址时可能会冲突 +bool hasContainerVolume(const arguments &args); + +// 获取数据存储目录,默认参数非空是追加为子目录或目录下的文件 +std::string storageDirectory(std::string subdir = ""); + +#endif diff --git a/src/main/main.cc b/src/main/main.cc index 93a09c20..fd7d7970 100644 --- a/src/main/main.cc +++ b/src/main/main.cc @@ -1,247 +1,15 @@ // SPDX-License-Identifier: MIT #include "core/client.h" -#include "core/common.h" #include "core/server.h" -#include "utility/argparse.h" -#include "utility/random.h" +#include "main/config.h" #include "utility/time.h" -#include -#include #include -#include -#include -#include -#include -#include #include #include -#include #include -namespace { - -struct arguments { - // 通用配置 - std::string mode; - std::string websocket; - std::string password; - std::string ntp; - int restart = 0; - bool noTimestamp = false; - bool debug = false; - - // 服务端配置 - std::string dhcp; - std::string sdwan; - - // 客户端配置 - std::string name; - std::string tun; - std::string stun; - std::string localhost; - int workers = 0; - int udpPort = 0; - int discovery = 0; - int routeCost = 0; - int mtu = 1400; - - void dump(const std::string &key, const std::string &value) { - if (!value.empty()) { - spdlog::debug("--{}={}", key, value); - } - } - void dump(const std::string &key, int value) { - if (value) { - spdlog::debug("--{}={}", key, value); - } - } - void dump() { - spdlog::debug("================================"); - dump("mode", this->mode); - dump("websocket", this->websocket); - dump("password", this->password); - dump("ntp", this->ntp); - dump("restart", this->restart); - dump("dhcp", this->dhcp); - dump("sdwan", this->sdwan); - dump("name", this->name); - dump("tun", this->tun); - dump("stun", this->stun); - dump("localhost", this->localhost); - dump("discovery", this->discovery); - dump("route", this->routeCost); - dump("mtu", this->mtu); - dump("workers", this->workers); - dump("port", this->udpPort); - spdlog::debug("================================"); - } -}; - -int disableLogTimestamp() { - spdlog::set_pattern("[%^%l%$] %v"); - return 0; -} - -int setLogLevelDebug() { - spdlog::set_level(spdlog::level::debug); - return 0; -} - -std::map parseConfig(const std::string &filename) { - std::map config; - std::ifstream file(filename); - std::string line; - - while (std::getline(file, line)) { - line = Poco::trimLeft(line); - if (line.empty() || line.front() == '#') - continue; - line.erase(line.find_last_not_of(" \t;") + 1); - std::size_t delimiterPos = line.find('='); - if (delimiterPos != std::string::npos) { - std::string key = Poco::trim(line.substr(0, delimiterPos)); - std::string value = Poco::trim(line.substr(delimiterPos + 1)); - config[key] = value; - } - } - return config; -} - -void parseConfig(std::string cfgFile, arguments &args) { - try { - std::map> cfgHandlers = { - {"mode", [&](const std::string &value) { args.mode = value; }}, - {"websocket", [&](const std::string &value) { args.websocket = value; }}, - {"password", [&](const std::string &value) { args.password = value; }}, - {"ntp", [&](const std::string &value) { args.ntp = value; }}, - {"debug", [&](const std::string &value) { args.debug = (value == "true"); }}, - {"restart", [&](const std::string &value) { args.restart = std::stoi(value); }}, - {"dhcp", [&](const std::string &value) { args.dhcp = value; }}, - {"sdwan", [&](const std::string &value) { args.sdwan = value; }}, - {"tun", [&](const std::string &value) { args.tun = value; }}, - {"stun", [&](const std::string &value) { args.stun = value; }}, - {"name", [&](const std::string &value) { args.name = value; }}, - {"workers", [&](const std::string &value) { args.workers = std::stoi(value); }}, - {"discovery", [&](const std::string &value) { args.discovery = std::stoi(value); }}, - {"route", [&](const std::string &value) { args.routeCost = std::stoi(value); }}, - {"port", [&](const std::string &value) { args.udpPort = std::stoi(value); }}, - {"mtu", [&](const std::string &value) { args.mtu = std::stoi(value); }}, - {"localhost", [&](const std::string &value) { args.localhost = value; }}, - }; - auto trim = [](std::string str) { - if (str.length() >= 2 && str.front() == '\"' && str.back() == '\"') { - return str.substr(1, str.length() - 2); - } - return str; - }; - auto configs = parseConfig(cfgFile); - for (auto cfg : configs) { - auto handler = cfgHandlers.find(cfg.first); - if (handler != cfgHandlers.end()) { - handler->second(trim(cfg.second)); - } else { - spdlog::warn("unknown config: {}={}", cfg.first, cfg.second); - } - } - } catch (std::exception &e) { - spdlog::error("parse config file failed: {}", e.what()); - exit(1); - } -} - -#if POCO_OS == POCO_OS_WINDOWS_NT - -std::string storageDirectory = "C:/ProgramData/Candy/"; - -#else - -std::string storageDirectory = "/var/lib/candy/"; - -#endif - -int saveLatestAddress(const std::string &name, const std::string &cidr) { - try { - std::string cache = storageDirectory + "address/"; - cache += name.empty() ? "__noname__" : name; - std::filesystem::create_directories(std::filesystem::path(cache).parent_path()); - std::ofstream ofs(cache); - if (ofs.is_open()) { - ofs << cidr; - ofs.close(); - } - return 0; - } catch (std::exception &e) { - spdlog::critical("save latest address failed: {}", e.what()); - return -1; - } -} - -std::string getLastestAddress(const std::string &name) { - std::string cache = storageDirectory + "address/"; - cache += name.empty() ? "__noname__" : name; - std::ifstream ifs(cache); - if (ifs.is_open()) { - std::stringstream ss; - ss << ifs.rdbuf(); - ifs.close(); - return ss.str(); - } - return ""; -} - -std::string virtualMac(const std::string &name) { - try { - std::string cache = storageDirectory + "vmac/"; - cache += name.empty() ? "__noname__" : name; - std::filesystem::create_directories(std::filesystem::path(cache).parent_path()); - - char buffer[16]; - std::stringstream ss; - - std::ifstream ifs(cache); - if (ifs.is_open()) { - ifs.read(buffer, sizeof(buffer)); - if (ifs) { - for (int i = 0; i < (int)sizeof(buffer); i++) { - ss << std::hex << buffer[i]; - } - } - ifs.close(); - } else { - ss << Candy::randomHexString(sizeof(buffer)); - std::ofstream ofs(cache); - if (ofs.is_open()) { - ofs << ss.str(); - ofs.close(); - } - } - return ss.str(); - } catch (std::exception &e) { - spdlog::critical("vmac failed: {}", e.what()); - return ""; - } -} - -bool checkStorageDirectory(const arguments &args) { - if (args.mode != "client") { - return true; - } - if (!std::filesystem::exists(storageDirectory + "lost")) { - return true; - } - if (args.websocket.starts_with("wss://canets.org")) { - return false; - } - if (!args.tun.empty()) { - return true; - } - return false; -} - std::atomic running = true; -} // namespace - namespace Candy { void shutdown(Client *client) { @@ -256,54 +24,43 @@ void shutdown(Server *server) { } // namespace Candy -namespace { - -std::atomic exitCode = 1; - -void signalHandler(int signal) { - exitCode = 0; - running = false; - running.notify_one(); -} +int exitCode = 1; int serve(const arguments &args) { Poco::Net::initializeNetwork(); - Candy::Server server; - Candy::Client client; - if (args.mode == "server") { + Candy::Server server; server.setPassword(args.password); - server.setWebSocketServer(args.websocket); - server.setDynamicAddressRange(args.dhcp); + server.setWebSocket(args.websocket); + server.setDHCP(args.dhcp); server.setSdwan(args.sdwan); server.run(); + running.wait(true); + server.shutdown(); } if (args.mode == "client") { - client.setAddressUpdateCallback([&](const std::string &cidr) { return saveLatestAddress(args.name, cidr); }); + Candy::Client client; client.setDiscoveryInterval(args.discovery); client.setRouteCost(args.routeCost); - client.setUdpBindPort(args.udpPort); + client.setPort(args.udpPort); client.setLocalhost(args.localhost); client.setPassword(args.password); - client.setWebSocketServer(args.websocket); + client.setWebSocket(args.websocket); client.setStun(args.stun); client.setTunAddress(args.tun); - client.setExpectedAddress(getLastestAddress(args.name)); + client.setExptTunAddress(loadTunAddress(args.name)); client.setVirtualMac(virtualMac(args.name)); client.setMtu(args.mtu); - client.setWorkers(args.workers); client.setName(args.name); + client.setTunUpdateCallback([&](auto tunCidr) { return saveTunAddress(args.name, tunCidr); }); client.run(); + running.wait(true); + client.shutdown(); } - running.wait(true); - - server.shutdown(); - client.shutdown(); - if (exitCode == 0) { spdlog::info("service exit: normal"); } else { @@ -313,104 +70,35 @@ int serve(const arguments &args) { Poco::Net::uninitializeNetwork(); return exitCode; } -} // namespace - -int parseConfig(int argc, char *argv[], arguments &args) { - argparse::ArgumentParser program("candy", CANDY_VERSION); - - program.add_argument("-m", "--mode").help("working mode").metavar("TEXT"); - program.add_argument("-w", "--websocket").help("websocket address").metavar("URI"); - program.add_argument("-p", "--password").help("authorization password").metavar("TEXT"); - program.add_argument("--ntp").help("ntp server").metavar("HOST"); - program.add_argument("--restart").help("restart interval").scan<'i', int>().metavar("SECONDS"); - program.add_argument("-d", "--dhcp").help("dhcp address range").metavar("CIDR"); - program.add_argument("--sdwan").help("software-defined wide area network").metavar("ROUTES"); - program.add_argument("-n", "--name").help("network interface name").metavar("TEXT"); - program.add_argument("--workers").help("workers number").scan<'i', int>().metavar("NUM"); - program.add_argument("-t", "--tun").help("static address").metavar("CIDR"); - program.add_argument("-s", "--stun").help("stun address").metavar("URI"); - program.add_argument("--port").help("udp port").scan<'i', int>().metavar("NUMBER"); - program.add_argument("--mtu").help("maximum transmission unit").scan<'i', int>().metavar("NUMBER"); - program.add_argument("-r", "--route").help("routing cost").scan<'i', int>().metavar("COST"); - program.add_argument("--discovery").help("discovery interval").scan<'i', int>().metavar("SECONDS"); - program.add_argument("--localhost").help("local ip").metavar("IP"); - program.add_argument("-c", "--config").help("config file path").metavar("PATH"); - program.add_argument("--no-timestamp").implicit_value(true).help("disable log time"); - program.add_argument("--debug").implicit_value(true).help("show debug log"); - - try { - program.parse_args(argc, argv); - if (program.is_used("--config")) { - parseConfig(program.get("--config"), args); - } - - args.mode = program.is_used("--mode") ? program.get("--mode") : args.mode; - args.websocket = program.is_used("--websocket") ? program.get("--websocket") : args.websocket; - args.password = program.is_used("--password") ? program.get("--password") : args.password; - args.ntp = program.is_used("--ntp") ? program.get("--ntp") : args.ntp; - args.restart = program.is_used("--restart") ? program.get("--restart") : args.restart; - args.noTimestamp = program.is_used("--no-timestamp") ? program.get("--no-timestamp") : args.noTimestamp; - args.debug = program.is_used("--debug") ? program.get("--debug") : args.debug; - args.dhcp = program.is_used("--dhcp") ? program.get("--dhcp") : args.dhcp; - args.sdwan = program.is_used("--sdwan") ? program.get("--sdwan") : args.sdwan; - args.name = program.is_used("--name") ? program.get("--name") : args.name; - args.workers = program.is_used("--workers") ? program.get("--workers") : args.workers; - args.tun = program.is_used("--tun") ? program.get("--tun") : args.tun; - args.stun = program.is_used("--stun") ? program.get("--stun") : args.stun; - args.localhost = program.is_used("--localhost") ? program.get("--localhost") : args.localhost; - args.udpPort = program.is_used("--port") ? program.get("--port") : args.udpPort; - args.mtu = program.is_used("--mtu") ? program.get("--mtu") : args.mtu; - args.discovery = program.is_used("--discovery") ? program.get("--discovery") : args.discovery; - args.routeCost = program.is_used("--route") ? program.get("--route") : args.routeCost; - bool needShowUsage = [&]() { - if (args.mode != "client" && args.mode != "server") - return true; - if (args.websocket.empty()) - return true; - - return false; - }(); - - if (needShowUsage) { - std::cout << program.usage() << std::endl; - exit(1); - } - - if (args.noTimestamp) { - disableLogTimestamp(); - } - if (args.debug) { - setLogLevelDebug(); - args.dump(); - } - return 0; - } catch (const std::exception &e) { - std::cout << program.usage() << std::endl; - exit(1); - } +void signalHandler(int signal) { + exitCode = 0; + running = false; + running.notify_one(); } int main(int argc, char *argv[]) { arguments args; - parseConfig(argc, argv, args); + args.parse(argc, argv); signal(SIGINT, signalHandler); signal(SIGTERM, signalHandler); - if (!checkStorageDirectory(args)) { - spdlog::critical("the container needs to add a storage volume: {}", storageDirectory); + if (!hasContainerVolume(args)) { + spdlog::critical("the container needs to add a storage volume: {}", storageDirectory()); running = false; } - Candy::Time::ntpServer = args.ntp; + Candy::ntpServer = args.ntp; while (running && serve(args) && args.restart) { running = true; - Candy::Time::useSystemTime = false; + Candy::useSystemTime = false; spdlog::info("service will restart in {} seconds", args.restart); std::this_thread::sleep_for(std::chrono::seconds(args.restart)); } + spdlog::drop_all(); + spdlog::shutdown(); return exitCode; } diff --git a/src/peer/CMakeLists.txt b/src/peer/CMakeLists.txt index 34b6d6a1..1260cfd5 100644 --- a/src/peer/CMakeLists.txt +++ b/src/peer/CMakeLists.txt @@ -11,14 +11,16 @@ if (${CANDY_STATIC_SPDLOG}) target_link_libraries(peer PRIVATE spdlog::spdlog) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED spdlog) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(peer PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(SPDLOG REQUIRED spdlog) + add_definitions(${SPDLOG_CFLAGS}) + include_directories(${SPDLOG_INCLUDEDIR}) + target_link_libraries(peer PRIVATE ${SPDLOG_LIBRARIES}) endif() if (${CANDY_STATIC_OPENSSL}) target_link_libraries(peer PRIVATE ${OPENSSL_LIB_CRYPTO} ${OPENSSL_LIB_SSL}) +else() + find_package(OpenSSL REQUIRED) endif() if (${CANDY_STATIC_POCO}) diff --git a/src/peer/info.cc b/src/peer/info.cc new file mode 100644 index 00000000..11f31b71 --- /dev/null +++ b/src/peer/info.cc @@ -0,0 +1,13 @@ +#include "peer/info.h" + +namespace Candy { + +bool PeerInfo::isConnected() const { + return this->state == PeerState::CONNECTED; +} + +PeerState PeerInfo::getState() const { + return this->state; +} + +} // namespace Candy diff --git a/src/peer/info.h b/src/peer/info.h new file mode 100644 index 00000000..53cc2206 --- /dev/null +++ b/src/peer/info.h @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +#ifndef CANDY_PEER_INFO_H +#define CANDY_PEER_INFO_H + +#include +namespace Candy { + +enum class PeerState { + INIT, // 默认状态 + PREPARING, // 开始尝试建立对等连接 + SYNCHRONIZING, // 本机已经将建立对等连接所需的信息发送给了对端,但还没有收到对方的信息 + CONNECTING, // 已经收到了对端的对等连接信息,且将自己的信息发送给了对方 + CONNECTED, // 连接成功,持续发送心跳 + WAITING, // 连接失败,一段时间后重新进入 INIT + FAILED, // 连接失败,且不会再主动进入其他状态,除非收到对端的对等连接信息 +}; + +constexpr int32_t DELAY_LIMIT = INT32_MAX; +constexpr int32_t RETRY_MIN = 30; +constexpr int32_t RETRY_MAX = 3600; + +class PeerInfo { +public: + bool isConnected() const; + PeerState getState() const; + +private: + PeerState state; +}; + +} // namespace Candy + +#endif diff --git a/src/peer/message.h b/src/peer/message.h new file mode 100644 index 00000000..5a1befba --- /dev/null +++ b/src/peer/message.h @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +#ifndef CANDY_PEER_MESSAGE_H +#define CANDY_PEER_MESSAGE_H + +#include + +namespace Candy { +namespace PeerMsgKind { +constexpr uint8_t HEARTBEAT = 0; +constexpr uint8_t FORWARD = 1; +constexpr uint8_t DELAY = 2; +// TODO: 遗漏了 3, 新功能时使用 +constexpr uint8_t ROUTE = 4; +} // namespace PeerMsgKind +} // namespace Candy + +#endif diff --git a/src/peer/peer.cc b/src/peer/peer.cc index 61fd16f0..a149a5ee 100644 --- a/src/peer/peer.cc +++ b/src/peer/peer.cc @@ -1,90 +1,112 @@ // SPDX-License-Identifier: MIT #include "peer/peer.h" -#include "utility/address.h" -#include "utility/byteswap.h" -#include -#include -#include +#include "core/client.h" +#include "core/message.h" +#include "core/net.h" +#include #include namespace Candy { -int PeerInfo::setTun(uint32_t tun, const std::string &password) { - this->tun = tun; +int Peer::setPassword(const std::string &password) { + return 0; +} - if (std::endian::native == std::endian::big) { - tun = byteswap(tun); - } +int Peer::setStun(const std::string &stun) { + return 0; +} - std::string data; - data.append(password); - data.append((char *)&tun, sizeof(tun)); +int Peer::setDiscoveryInterval(int interval) { + return 0; +} - this->key.resize(SHA256_DIGEST_LENGTH); - SHA256((unsigned char *)data.data(), data.size(), (unsigned char *)this->key.data()); +int Peer::setForwardCost(int cost) { return 0; } -std::string PeerInfo::getKey() const { - return this->key; +int Peer::setPort(int port) { + return 0; } -uint32_t PeerInfo::getTun() const { - return this->tun; +int Peer::setLocalhost(const std::string &ip) { + return 0; } -void PeerInfo::updateState(PeerState state) { - this->count = 0; - if (this->state == state) { - return; - } +int Peer::run(Client *client) { + this->client = client; + this->msgThread = std::thread([&] { + while (this->client->running) { + handlePeerQueue(); + } + }); + return 0; +} - spdlog::debug("conn state: {} {} => {}", Address::ipToStr(this->tun), getStateStr(this->state), getStateStr(state)); - if (state == PeerState::INIT || state == PeerState::WAITING || state == PeerState::FAILED) { - this->wide.ip = 0; - this->wide.port = 0; - this->local.ip = 0; - this->local.port = 0; - this->real.ip = 0; - this->real.port = 0; - this->ack = 0; - this->delay = DELAY_LIMIT; +int Peer::shutdown() { + if (this->msgThread.joinable()) { + this->msgThread.join(); } - if (this->state == PeerState::WAITING && state == PeerState::INIT) { - this->retry = std::min(this->retry * 2, RETRY_MAX); - } else if (state == PeerState::INIT || state == PeerState::FAILED) { - this->retry = RETRY_MIN; + return 0; +} + +void Peer::handlePeerQueue() { + Msg msg = this->client->peerMsgQueue.read(); + switch (msg.kind) { + case MsgKind::TIMEOUT: + break; + case MsgKind::PACKET: + handlePacket(std::move(msg)); + break; + case MsgKind::TRYP2P: + handleTryP2P(std::move(msg)); + break; + default: + spdlog::warn("unexcepted peer message type: {}", static_cast(msg.kind)); + break; } - this->state = state; } -PeerState PeerInfo::getState() const { - return this->state; +int Peer::sendTo(IP4 dst, const Msg &msg) { + // 这两个锁同时使用时先给 ipPeerMap 加锁,避免死锁 + std::shared_lock ipPeerLock(this->ipPeerMutex); + std::shared_lock rtTableLock(this->rtTableMutex); + + auto rt = this->rtTableMap.find(dst); + if (rt == this->rtTableMap.end()) { + return -1; + } + auto it = this->ipPeerMap.find(rt->second); + if (it == this->ipPeerMap.end()) { + return -1; + } + auto &info = it->second; + if (!info.isConnected()) { + return -1; + } + std::string x; + x.push_back(1); + x += msg.data; + return sendTo(info, x); } -std::string PeerInfo::getStateStr() const { - return getStateStr(this->state); +int Peer::sendTo(PeerInfo &info, const std::string &data) { + return -1; } -std::string PeerInfo::getStateStr(PeerState state) { - switch (state) { - case PeerState::INIT: - return "INIT"; - case PeerState::PREPARING: - return "PREPARING"; - case PeerState::SYNCHRONIZING: - return "SYNCHRONIZING"; - case PeerState::CONNECTING: - return "CONNECTING"; - case PeerState::CONNECTED: - return "CONNECTED"; - case PeerState::WAITING: - return "WAITING"; - case PeerState::FAILED: - return "FAILED"; - default: - return "UNKNOWN"; +void Peer::handlePacket(Msg msg) { + IP4Header *header = (IP4Header *)msg.data.data(); + // 尝试 P2P 转发流量 + if (!sendTo(header->daddr, msg)) { + return; } + // 无法通过 P2P 转发流量,交给 WS 模块通过服务端转发 + this->client->wsMsgQueue.write(std::move(msg)); +} + +void Peer::handleTryP2P(Msg msg) { + // TODO: 尝试与特定对端建立直连 + IP4 src(msg.data); + spdlog::debug("TRYP2P: {}", src.toString()); } } // namespace Candy diff --git a/src/peer/peer.h b/src/peer/peer.h index 66c291a1..93ab56a3 100644 --- a/src/peer/peer.h +++ b/src/peer/peer.h @@ -2,58 +2,52 @@ #ifndef CANDY_PEER_PEER_H #define CANDY_PEER_PEER_H -#include "utility/random.h" -#include +#include "core/message.h" +#include "core/net.h" +#include "peer/info.h" +#include #include +#include +#include namespace Candy { -enum class PeerState { - INIT, - PREPARING, - SYNCHRONIZING, - CONNECTING, - CONNECTED, - WAITING, - FAILED, -}; - -constexpr int32_t DELAY_LIMIT = INT32_MAX; -constexpr uint32_t RETRY_MIN = 30U; -constexpr uint32_t RETRY_MAX = 3600U; +class Client; -class PeerInfo { +class Peer { public: - struct { - uint32_t ip = 0; - uint16_t port = 0; - } wide, local, real; - uint8_t ack = 0; - uint32_t count = 0; - uint32_t tick = randomUint32(); - uint32_t retry = RETRY_MIN; - int32_t delay = DELAY_LIMIT; + int setPassword(const std::string &password); + int setStun(const std::string &stun); + int setDiscoveryInterval(int interval); + int setForwardCost(int cost); + int setPort(int port); + int setLocalhost(const std::string &ip); -public: - int setTun(uint32_t tun, const std::string &password); - std::string getKey() const; - uint32_t getTun() const; - void updateState(PeerState state); - PeerState getState() const; - std::string getStateStr() const; + int run(Client *client); + int shutdown(); private: - static std::string getStateStr(PeerState state); - PeerState state = PeerState::INIT; - uint32_t tun = 0; - std::string key; -}; + // 处理来自消息队列的数据 + void handlePeerQueue(); + void handlePacket(Msg msg); + void handleTryP2P(Msg msg); -class UdpMessage { -public: - uint32_t ip; - uint16_t port; - std::string buffer; + std::thread msgThread; + + // 处理 PACKET 报文,并判断目标是否可达 + int sendTo(IP4 dst, const Msg &msg); + // 通过指定的 info 向对端发送 data, 此时的 data 是明文 + int sendTo(PeerInfo &info, const std::string &data); + +private: + std::shared_mutex ipPeerMutex; + std::unordered_map ipPeerMap; + + std::shared_mutex rtTableMutex; + std::unordered_map rtTableMap; + +private: + Client *client; }; } // namespace Candy diff --git a/src/peer/tcp.h b/src/peer/tcp.h new file mode 100644 index 00000000..8d7647b0 --- /dev/null +++ b/src/peer/tcp.h @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +#ifndef CANDY_PEER_TCP_H +#define CANDY_PEER_TCP_H + +namespace Candy { + +class TCP {}; + +class TCP4 : public TCP {}; + +class TCP6 : public TCP {}; + +} // namespace Candy + +#endif diff --git a/src/peer/udp.cc b/src/peer/udp.cc deleted file mode 100644 index 911b02a0..00000000 --- a/src/peer/udp.cc +++ /dev/null @@ -1,97 +0,0 @@ -#include "peer/udp.h" -#include "utility/address.h" -#include -#include -#include -#include -#include - -namespace Candy { - -int UdpHolder::init() { - try { - this->socket = Poco::Net::DatagramSocket(Poco::Net::SocketAddress(this->port), true, true); - this->socket.setBlocking(false); - this->address = socket.address(); - } catch (std::exception &e) { - spdlog::critical("udp holder init failed: {}", e.what()); - return -1; - } - return 0; -} - -void UdpHolder::reset() { - try { - this->socket.close(); - this->port = 0; - this->ip = 0; - } catch (std::exception &e) { - spdlog::warn("udp holder reset failed: {}", e.what()); - } -} - -void UdpHolder::setPort(uint16_t port) { - this->port = port; -} - -void UdpHolder::setIP(uint32_t ip) { - this->ip = ip; -} - -uint16_t UdpHolder::Port() { - return this->address.port(); -} - -uint32_t UdpHolder::IP() { - if (!this->ip) { - try { - for (const auto &iface : Poco::Net::NetworkInterface::list()) { - if (iface.supportsIPv4() && !iface.isLoopback() && !iface.isPointToPoint() && - iface.type() != iface.NI_TYPE_OTHER) { - auto firstAddress = iface.firstAddress(Poco::Net::IPAddress::IPv4); - memcpy(&this->ip, firstAddress.addr(), sizeof(this->ip)); - this->ip = ntohl(this->ip); - break; - } - } - } catch (std::exception &e) { - spdlog::warn("local ip failed: {}", e.what()); - } - } - return this->ip; -} - -size_t UdpHolder::read(UdpMessage &message) { - try { - if (this->socket.available()) { - std::string buffer(1500, 0); - Poco::Net::SocketAddress address; - int size = this->socket.receiveFrom(buffer.data(), buffer.size(), address); - if (size >= 0) { - buffer.resize(size); - message.buffer = std::move(buffer); - message.port = address.port(); - memcpy(&message.ip, address.host().addr(), sizeof(message.ip)); - message.ip = Address::netToHost(message.ip); - return size; - } - } - - this->socket.poll(Poco::Timespan(1, 0), Poco::Net::Socket::SELECT_READ); - } catch (std::exception &e) { - spdlog::debug("udp holder read failed: {}", e.what()); - } - return 0; -} - -size_t UdpHolder::write(const UdpMessage &message) { - try { - Poco::Net::SocketAddress address(Address::ipToStr(message.ip), message.port); - return this->socket.sendTo(message.buffer.data(), message.buffer.size(), address); - } catch (std::exception &e) { - spdlog::debug("udp holder write failed: {}", e.what()); - } - return 0; -} - -} // namespace Candy diff --git a/src/peer/udp.h b/src/peer/udp.h index 65bae131..6cbcc99d 100644 --- a/src/peer/udp.h +++ b/src/peer/udp.h @@ -1,29 +1,15 @@ - -#include "peer/peer.h" -#include +// SPDX-License-Identifier: MIT +#ifndef CANDY_PEER_UDP_H +#define CANDY_PEER_UDP_H namespace Candy { -class UdpHolder { -public: - int init(); - void reset(); - - void setPort(uint16_t port); - void setIP(uint32_t ip); - - uint16_t Port(); - uint32_t IP(); +class UDP {}; - size_t read(UdpMessage &message); - size_t write(const UdpMessage &message); +class UDP4 : public UDP {}; -private: - Poco::Net::SocketAddress address; - Poco::Net::DatagramSocket socket; - - uint16_t port = 0; - uint32_t ip = 0; -}; +class UDP6 : public UDP {}; } // namespace Candy + +#endif diff --git a/src/tun/CMakeLists.txt b/src/tun/CMakeLists.txt index bb2dff86..4586b243 100644 --- a/src/tun/CMakeLists.txt +++ b/src/tun/CMakeLists.txt @@ -34,19 +34,21 @@ if (${CANDY_STATIC_SPDLOG}) target_link_libraries(tun PRIVATE spdlog::spdlog) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED spdlog) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(tun PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(SPDLOG REQUIRED spdlog) + add_definitions(${SPDLOG_CFLAGS}) + include_directories(${SPDLOG_INCLUDEDIR}) + target_link_libraries(tun PRIVATE ${SPDLOG_LIBRARIES}) endif() if (${CANDY_STATIC_OPENSSL}) target_link_libraries(tun PRIVATE ${OPENSSL_LIB_CRYPTO} ${OPENSSL_LIB_SSL}) +else() + find_package(OpenSSL REQUIRED) endif() if (${CANDY_STATIC_POCO}) - target_link_libraries(tun PRIVATE Poco::Foundation) + target_link_libraries(tun PRIVATE Poco::Foundation Poco::Net Poco::NetSSL) else() - find_package(Poco REQUIRED Net) - target_link_libraries(tun PRIVATE Poco::Foundation) + find_package(Poco REQUIRED Foundation Net NetSSL) + target_link_libraries(tun PRIVATE Poco::Foundation Poco::Net Poco::NetSSL) endif() diff --git a/src/tun/linux.cc b/src/tun/linux.cc index a95ff2ed..8b7abcee 100644 --- a/src/tun/linux.cc +++ b/src/tun/linux.cc @@ -2,8 +2,8 @@ #include #if POCO_OS == POCO_OS_LINUX +#include "core/net.h" #include "tun/tun.h" -#include "utility/address.h" #include #include #include @@ -16,6 +16,7 @@ #include namespace { + class LinuxTun { public: int setName(const std::string &name) { @@ -23,16 +24,16 @@ class LinuxTun { return 0; } - int setIP(uint32_t ip) { + int setIP(Candy::IP4 ip) { this->ip = ip; return 0; } - int getIP() { + Candy::IP4 getIP() { return this->ip; } - int setMask(uint32_t mask) { + int setMask(Candy::IP4 mask) { this->mask = mask; return 0; } @@ -42,11 +43,6 @@ class LinuxTun { return 0; } - int setTimeout(int timeout) { - this->timeout = timeout; - return 0; - } - // 配置网卡,设置路由 int up() { this->tunFd = open("/dev/net/tun", O_RDWR); @@ -86,17 +82,17 @@ class LinuxTun { } // 设置地址 - addr->sin_addr.s_addr = Candy::Address::hostToNet(this->ip); + addr->sin_addr.s_addr = this->ip; if (ioctl(sockfd, SIOCSIFADDR, (caddr_t)&ifr) == -1) { - spdlog::critical("set ip address failed: ip {:08x}", this->ip); + spdlog::critical("set ip address failed: ip {}", this->ip.toString()); close(sockfd); return -1; } // 设置掩码 - addr->sin_addr.s_addr = Candy::Address::hostToNet(this->mask); + addr->sin_addr.s_addr = this->mask; if (ioctl(sockfd, SIOCSIFNETMASK, (caddr_t)&ifr) == -1) { - spdlog::critical("set mask failed: mask {:08x}", this->mask); + spdlog::critical("set mask failed: mask {}", this->mask.toString()); close(sockfd); return -1; } @@ -128,11 +124,11 @@ class LinuxTun { addr = (struct sockaddr_in *)&route.rt_dst; addr->sin_family = AF_INET; - addr->sin_addr.s_addr = Candy::Address::hostToNet(this->ip); + addr->sin_addr.s_addr = this->ip; addr = (struct sockaddr_in *)&route.rt_genmask; addr->sin_family = AF_INET; - addr->sin_addr.s_addr = Candy::Address::hostToNet(this->mask); + addr->sin_addr.s_addr = this->mask; route.rt_dev = (char *)this->name.c_str(); route.rt_flags = RTF_UP | RTF_HOST; @@ -161,7 +157,7 @@ class LinuxTun { } if (errno == EAGAIN || errno == EWOULDBLOCK) { - struct timeval timeout = {.tv_sec = this->timeout}; + struct timeval timeout = {.tv_sec = 1}; fd_set set; FD_ZERO(&set); @@ -178,7 +174,7 @@ class LinuxTun { return ::write(this->tunFd, buffer.c_str(), buffer.size()); } - int setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop) { + int setSysRtTable(Candy::IP4 dst, Candy::IP4 mask, Candy::IP4 nexthop) { int sockfd = socket(AF_INET, SOCK_DGRAM, 0); if (sockfd == -1) { spdlog::error("set route failed: create socket failed"); @@ -191,15 +187,15 @@ class LinuxTun { addr = (struct sockaddr_in *)&route.rt_dst; addr->sin_family = AF_INET; - addr->sin_addr.s_addr = Candy::Address::hostToNet(dst); + addr->sin_addr.s_addr = dst; addr = (struct sockaddr_in *)&route.rt_genmask; addr->sin_family = AF_INET; - addr->sin_addr.s_addr = Candy::Address::hostToNet(mask); + addr->sin_addr.s_addr = mask; addr = (struct sockaddr_in *)&route.rt_gateway; addr->sin_family = AF_INET; - addr->sin_addr.s_addr = Candy::Address::hostToNet(nexthop); + addr->sin_addr.s_addr = nexthop; route.rt_flags = RTF_UP | RTF_GATEWAY; if (ioctl(sockfd, SIOCADDRT, &route) == -1) { @@ -214,12 +210,13 @@ class LinuxTun { private: std::string name; - uint32_t ip; - uint32_t mask; + Candy::IP4 ip; + Candy::IP4 mask; int mtu; int timeout; int tunFd; }; + } // namespace namespace Candy { @@ -244,21 +241,22 @@ int Tun::setAddress(const std::string &cidr) { std::shared_ptr tun; Address address; - if (address.cidrUpdate(cidr)) { + if (address.fromCidr(cidr)) { return -1; } - spdlog::info("client address: {}", address.getCidr()); + spdlog::info("client address: {}", address.toCidr()); tun = std::any_cast>(this->impl); - if (tun->setIP(address.getIp())) { + if (tun->setIP(address.Host())) { return -1; } - if (tun->setMask(address.getMask())) { + if (tun->setMask(address.Mask())) { return -1; } + this->tunAddress = cidr; return 0; } -uint32_t Tun::getIP() { +IP4 Tun::getIP() { std::shared_ptr tun; tun = std::any_cast>(this->impl); return tun->getIP(); @@ -273,15 +271,6 @@ int Tun::setMTU(int mtu) { return 0; } -int Tun::setTimeout(int timeout) { - std::shared_ptr tun; - tun = std::any_cast>(this->impl); - if (tun->setTimeout(timeout)) { - return -1; - } - return 0; -} - int Tun::up() { std::shared_ptr tun; tun = std::any_cast>(this->impl); @@ -306,7 +295,7 @@ int Tun::write(const std::string &buffer) { return tun->write(buffer); } -int Tun::setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop) { +int Tun::setSysRtTable(IP4 dst, IP4 mask, IP4 nexthop) { std::shared_ptr tun; tun = std::any_cast>(this->impl); return tun->setSysRtTable(dst, mask, nexthop); diff --git a/src/tun/macos.cc b/src/tun/macos.cc index ab634809..d88a85e9 100644 --- a/src/tun/macos.cc +++ b/src/tun/macos.cc @@ -2,8 +2,8 @@ #include #if POCO_OS == POCO_OS_MAC_OS_X +#include "core/net.h" #include "tun/tun.h" -#include "utility/address.h" #include #include #include @@ -23,6 +23,7 @@ #include namespace { + class MacTun { public: int setName(const std::string &name) { @@ -30,16 +31,16 @@ class MacTun { return 0; } - int setIP(uint32_t ip) { + int setIP(Candy::IP4 ip) { this->ip = ip; return 0; } - int getIP() { + Candy::IP4 getIP() { return this->ip; } - int setMask(uint32_t mask) { + int setMask(Candy::IP4 mask) { this->mask = mask; return 0; } @@ -49,11 +50,6 @@ class MacTun { return 0; } - int setTimeout(int timeout) { - this->timeout = timeout; - return 0; - } - int up() { // 创建设备,操作系统不允许自定义设备名,只能由内核分配 this->tunFd = socket(PF_SYSTEM, SOCK_DGRAM, SYSPROTO_CONTROL); @@ -120,18 +116,19 @@ class MacTun { strncpy(areq.ifra_name, ifname, IFNAMSIZ); ((struct sockaddr_in *)&areq.ifra_addr)->sin_family = AF_INET; ((struct sockaddr_in *)&areq.ifra_addr)->sin_len = sizeof(areq.ifra_addr); - ((struct sockaddr_in *)&areq.ifra_addr)->sin_addr.s_addr = Candy::Address::hostToNet(this->ip); + ((struct sockaddr_in *)&areq.ifra_addr)->sin_addr.s_addr = this->ip; ((struct sockaddr_in *)&areq.ifra_mask)->sin_family = AF_INET; ((struct sockaddr_in *)&areq.ifra_mask)->sin_len = sizeof(areq.ifra_mask); - ((struct sockaddr_in *)&areq.ifra_mask)->sin_addr.s_addr = Candy::Address::hostToNet(this->mask); + ((struct sockaddr_in *)&areq.ifra_mask)->sin_addr.s_addr = this->mask; ((struct sockaddr_in *)&areq.ifra_broadaddr)->sin_family = AF_INET; ((struct sockaddr_in *)&areq.ifra_broadaddr)->sin_len = sizeof(areq.ifra_broadaddr); - ((struct sockaddr_in *)&areq.ifra_broadaddr)->sin_addr.s_addr = Candy::Address::hostToNet((this->ip & this->mask)); + ((struct sockaddr_in *)&areq.ifra_broadaddr)->sin_addr.s_addr = (this->ip & this->mask); if (ioctl(sockfd, SIOCAIFADDR, (void *)&areq) == -1) { - spdlog::critical("set ip mask failed: {}: ip {:08x} mask {:08x}", strerror(errno), this->ip, this->mask); + spdlog::critical("set ip mask failed: {}: ip {} mask {}", strerror(errno), this->ip.toString(), + this->mask.toString()); close(sockfd); return -1; } @@ -185,7 +182,7 @@ class MacTun { } if (errno == EAGAIN || errno == EWOULDBLOCK) { - struct timeval timeout = {.tv_sec = this->timeout}; + struct timeval timeout = {.tv_sec = 1}; fd_set set; FD_ZERO(&set); @@ -208,7 +205,7 @@ class MacTun { return ::writev(this->tunFd, iov, sizeof(iov) / sizeof(iov[0])) - sizeof(sizeof(this->packetinfo)); } - int setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop) { + int setSysRtTable(Candy::IP4 dst, Candy::IP4 mask, Candy::IP4 nexthop) { struct { struct rt_msghdr msghdr; struct sockaddr_in addr[3]; @@ -224,9 +221,9 @@ class MacTun { msg.addr[idx].sin_len = sizeof(msg.addr[0]); msg.addr[idx].sin_family = AF_INET; } - msg.addr[0].sin_addr.s_addr = Candy::Address::hostToNet(dst); - msg.addr[1].sin_addr.s_addr = Candy::Address::hostToNet(nexthop); - msg.addr[2].sin_addr.s_addr = Candy::Address::hostToNet(mask); + msg.addr[0].sin_addr.s_addr = dst; + msg.addr[1].sin_addr.s_addr = nexthop; + msg.addr[2].sin_addr.s_addr = mask; int routefd = socket(AF_ROUTE, SOCK_RAW, 0); if (routefd < 0) { @@ -245,14 +242,15 @@ class MacTun { private: std::string name; char ifname[IFNAMSIZ] = {0}; - uint32_t ip; - uint32_t mask; + Candy::IP4 ip; + Candy::IP4 mask; int mtu; int timeout; int tunFd; uint8_t packetinfo[4] = {0x00, 0x00, 0x00, 0x02}; }; + } // namespace namespace Candy { @@ -277,21 +275,21 @@ int Tun::setAddress(const std::string &cidr) { std::shared_ptr tun; Address address; - if (address.cidrUpdate(cidr)) { + if (address.fromCidr(cidr)) { return -1; } - spdlog::info("client address: {}", address.getCidr()); + spdlog::info("client address: {}", address.toCidr()); tun = std::any_cast>(this->impl); - if (tun->setIP(address.getIp())) { + if (tun->setIP(address.Host())) { return -1; } - if (tun->setMask(address.getMask())) { + if (tun->setMask(address.Mask())) { return -1; } return 0; } -uint32_t Tun::getIP() { +IP4 Tun::getIP() { std::shared_ptr tun; tun = std::any_cast>(this->impl); return tun->getIP(); @@ -306,15 +304,6 @@ int Tun::setMTU(int mtu) { return 0; } -int Tun::setTimeout(int timeout) { - std::shared_ptr tun; - tun = std::any_cast>(this->impl); - if (tun->setTimeout(timeout)) { - return -1; - } - return 0; -} - int Tun::up() { std::shared_ptr tun; tun = std::any_cast>(this->impl); @@ -339,7 +328,7 @@ int Tun::write(const std::string &buffer) { return tun->write(buffer); } -int Tun::setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop) { +int Tun::setSysRtTable(IP4 dst, IP4 mask, IP4 nexthop) { std::shared_ptr tun; tun = std::any_cast>(this->impl); return tun->setSysRtTable(dst, mask, nexthop); diff --git a/src/tun/tun.cc b/src/tun/tun.cc index e7bf9529..eef2f88c 100644 --- a/src/tun/tun.cc +++ b/src/tun/tun.cc @@ -1,54 +1,143 @@ -#include -#if POCO_OS != POCO_OS_LINUX && POCO_OS != POCO_OS_MAC_OS_X && POCO_OS != POCO_OS_WINDOWS_NT - +// SPDX-License-Identifier: MIT #include "tun/tun.h" +#include "core/client.h" +#include "core/message.h" +#include "core/net.h" +#include +#include +#include namespace Candy { -Tun::Tun() {} - -Tun::~Tun() {} - -int Tun::setName(const std::string &name) { - return -1; -} - -int Tun::setAddress(const std::string &cidr) { - return -1; -} - -uint32_t Tun::getIP() { - return -1; +int Tun::run(Client *client) { + this->client = client; + this->msgThread = std::thread([&] { + while (this->client->running) { + handleTunQueue(); + } + }); + return 0; } -int Tun::setMTU(int mtu) { - return -1; +int Tun::shutdown() { + if (this->tunThread.joinable()) { + this->tunThread.join(); + } + if (this->msgThread.joinable()) { + this->msgThread.join(); + } + { + std::unique_lock lock(this->sysRtMutex); + this->sysRtTable.clear(); + } + return 0; } -int Tun::setTimeout(int timeout) { - return -1; +void Tun::handleTunDevice() { + std::string buffer; + int error = read(buffer); + if (error <= 0) { + return; + } + if (buffer.length() < sizeof(IP4Header)) { + return; + } + IP4Header *header = (IP4Header *)buffer.data(); + if ((header->version_ihl >> 4) != 4) { + return; + } + + IP4 nextHop = [&]() { + std::shared_lock lock(this->sysRtMutex); + for (auto const &rt : sysRtTable) { + if ((header->daddr & rt.mask) == rt.dst) { + return rt.nexthop; + } + } + return IP4(); + }(); + if (!nextHop.empty()) { + buffer.insert(0, sizeof(IP4Header), 0); + header = (IP4Header *)buffer.data(); + header->protocol = 0x04; + header->saddr = getIP(); + header->daddr = nextHop; + } + + if (header->daddr == getIP()) { + write(buffer); + return; + } + + // 流量给 P2P 模块,如果 P2P 模块无法处理,由 P2P 模块转发给 WS 模块 + this->client->peerMsgQueue.write(Msg(MsgKind::PACKET, std::move(buffer))); } -int Tun::up() { - return -1; +void Tun::handleTunQueue() { + Msg msg = this->client->tunMsgQueue.read(); + switch (msg.kind) { + case MsgKind::TIMEOUT: + break; + case MsgKind::PACKET: + handlePacket(std::move(msg)); + break; + case MsgKind::TUNADDR: + handleTunAddr(std::move(msg)); + break; + case MsgKind::SYSRT: + handleSysRt(std::move(msg)); + break; + default: + spdlog::warn("unexcepted tun message type: {}", static_cast(msg.kind)); + break; + } } -int Tun::down() { - return -1; +void Tun::handlePacket(Msg msg) { + if (msg.data.size() < sizeof(IP4Header)) { + spdlog::warn("invalid IPv4 packet: {:n}", spdlog::to_hex(msg.data)); + return; + } + IP4Header *header = (IP4Header *)msg.data.data(); + if (header->protocol == 0x04) { + msg.data.erase(0, sizeof(IP4Header)); + header = (IP4Header *)msg.data.data(); + } + write(msg.data); } -int Tun::read(std::string &buffer) { - return -1; +void Tun::handleTunAddr(Msg msg) { + if (setAddress(msg.data)) { + Candy::shutdown(this->client); + } + + this->tunThread = std::thread([&] { + if (up()) { + Candy::shutdown(this->client); + return; + } + while (this->client->running) { + handleTunDevice(); + } + if (down()) { + Candy::shutdown(this->client); + return; + } + }); } -int Tun::write(const std::string &buffer) { - return -1; +void Tun::handleSysRt(Msg msg) { + SysRouteEntry *rt = (SysRouteEntry *)msg.data.data(); + if (rt->nexthop != getIP()) { + spdlog::info("route: {}/{} via {}", rt->dst.toString(), rt->mask.toPrefix(), rt->nexthop.toString()); + setSysRtTable(*rt); + } } -int Tun::setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop) { - return -1; +int Tun::setSysRtTable(const SysRouteEntry &entry) { + std::unique_lock lock(this->sysRtMutex); + this->sysRtTable.push_back(entry); + return setSysRtTable(entry.dst, entry.mask, entry.nexthop); } } // namespace Candy - -#endif diff --git a/src/tun/tun.h b/src/tun/tun.h index 3b6be79f..fc5b7105 100644 --- a/src/tun/tun.h +++ b/src/tun/tun.h @@ -2,45 +2,64 @@ #ifndef CANDY_TUN_TUN_H #define CANDY_TUN_TUN_H +#include "core/message.h" +#include "core/net.h" #include -#include +#include +#include #include +#include namespace Candy { +class Client; + class Tun { public: Tun(); ~Tun(); - // 为了支持一台设备接入多个 VPN 网络.用名称区分 TUN 设备. int setName(const std::string &name); + int setMTU(int mtu); + + int run(Client *client); + int shutdown(); - // 设置 TUN 设备的地址和网络,以及由网络引入的路由.设置相同网络的流量路由到本设备. +private: + IP4 getIP(); int setAddress(const std::string &cidr); - // 获取 IP 地址,用于发包前校验源 IP 是否相同 - uint32_t getIP(); + // 处理来自 TUN 设备的数据 + void handleTunDevice(); - // 设置 MTU, 这个数值应该略小于网络实际 MTU, 这样即使添加了 VPN 的包头也能一次发包. - int setMTU(int mtu); + // 处理来自消息队列的数据 + void handleTunQueue(); + void handlePacket(Msg msg); + void handleTunAddr(Msg msg); + void handleSysRt(Msg msg); - // 设置读超时时间. - int setTimeout(int timeout); + std::string tunAddress; + std::thread tunThread; + std::thread msgThread; - // 网卡 up/down +private: int up(); int down(); - // 阻塞的从 TUN 设备读写数据.读操作返回 0 表示超时. int read(std::string &buffer); int write(const std::string &buffer); - // 设置系统路由表 - int setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop); + int setSysRtTable(const SysRouteEntry &entry); + int setSysRtTable(IP4 dst, IP4 mask, IP4 nexthop); + + std::shared_mutex sysRtMutex; + std::list sysRtTable; private: std::any impl; + +private: + Client *client; }; } // namespace Candy diff --git a/src/tun/unknown.cc b/src/tun/unknown.cc new file mode 100644 index 00000000..b493e65a --- /dev/null +++ b/src/tun/unknown.cc @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +#include + +#if POCO_OS != POCO_OS_LINUX && POCO_OS != POCO_OS_MAC_OS_X && POCO_OS != POCO_OS_WINDOWS_NT + +#include "tun/tun.h" + +namespace Candy { + +Tun::Tun() {} + +Tun::~Tun() {} + +int Tun::setName(const std::string &name) { + return -1; +} + +int Tun::setAddress(const std::string &cidr) { + return -1; +} + +int Tun::setMTU(int mtu) { + return -1; +} + +int Tun::up() { + return -1; +} + +int Tun::down() { + return -1; +} + +int Tun::read(std::string &buffer) { + return -1; +} + +int Tun::write(const std::string &buffer) { + return -1; +} + +int Tun::setSysRtTable(IP4 dst, IP4 mask, IP4 nexthop) { + return -1; +} + +} // namespace Candy + +#endif diff --git a/src/tun/windows.cc b/src/tun/windows.cc index 6056d238..4099e954 100644 --- a/src/tun/windows.cc +++ b/src/tun/windows.cc @@ -2,8 +2,8 @@ #include #if POCO_OS == POCO_OS_WINDOWS_NT +#include "core/net.h" #include "tun/tun.h" -#include "utility/address.h" #include #include #include @@ -95,12 +95,12 @@ class WindowsTun { return 0; } - int setIP(uint32_t ip) { + int setIP(Candy::IP4 ip) { this->ip = ip; return 0; } - int getIP() { + Candy::IP4 getIP() { return this->ip; } @@ -114,11 +114,6 @@ class WindowsTun { return 0; } - int setTimeout(int timeout) { - this->timeout = timeout; - return 0; - } - int up() { if (!Holder::Ok()) { spdlog::critical("init wintun failed"); @@ -141,7 +136,7 @@ class WindowsTun { InitializeUnicastIpAddressEntry(&AddressRow); WintunGetAdapterLUID(this->adapter, &AddressRow.InterfaceLuid); AddressRow.Address.Ipv4.sin_family = AF_INET; - AddressRow.Address.Ipv4.sin_addr.S_un.S_addr = Candy::Address::hostToNet(this->ip); + AddressRow.Address.Ipv4.sin_addr.S_un.S_addr = this->ip; AddressRow.OnLinkPrefixLength = this->prefix; AddressRow.DadState = IpDadStatePreferred; Error = CreateUnicastIpAddressEntry(&AddressRow); @@ -201,7 +196,7 @@ class WindowsTun { return size; } if (GetLastError() == ERROR_NO_MORE_ITEMS) { - WaitForSingleObject(WintunGetReadWaitEvent(this->session), this->timeout * 1000); + WaitForSingleObject(WintunGetReadWaitEvent(this->session), 1000); return 0; } spdlog::error("wintun read failed: {}", GetLastError()); @@ -222,12 +217,12 @@ class WindowsTun { return -1; } - int setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop) { + int setSysRtTable(Candy::IP4 dst, Candy::IP4 mask, Candy::IP4 nexthop) { MIB_IPFORWARDROW route; - route.dwForwardDest = Candy::Address::hostToNet(dst); - route.dwForwardMask = Candy::Address::hostToNet(mask); - route.dwForwardNextHop = Candy::Address::hostToNet(nexthop); + route.dwForwardDest = dst; + route.dwForwardMask = mask; + route.dwForwardNextHop = nexthop; route.dwForwardIfIndex = this->ifindex; route.dwForwardProto = MIB_IPPROTO_NETMGMT; @@ -252,7 +247,7 @@ class WindowsTun { private: std::string name; - uint32_t ip; + Candy::IP4 ip; uint32_t prefix; int mtu; int timeout; @@ -262,6 +257,7 @@ class WindowsTun { WINTUN_ADAPTER_HANDLE adapter = NULL; WINTUN_SESSION_HANDLE session = NULL; }; + } // namespace namespace Candy { @@ -286,21 +282,21 @@ int Tun::setAddress(const std::string &cidr) { std::shared_ptr tun; Address address; - if (address.cidrUpdate(cidr)) { + if (address.fromCidr(cidr)) { return -1; } - spdlog::info("client address: {}", address.getCidr()); + spdlog::info("client address: {}", address.toCidr()); tun = std::any_cast>(this->impl); - if (tun->setIP(address.getIp())) { + if (tun->setIP(address.Host())) { return -1; } - if (tun->setPrefix(address.getPrefix())) { + if (tun->setPrefix(address.Mask().toPrefix())) { return -1; } return 0; } -uint32_t Tun::getIP() { +IP4 Tun::getIP() { std::shared_ptr tun; tun = std::any_cast>(this->impl); return tun->getIP(); @@ -315,15 +311,6 @@ int Tun::setMTU(int mtu) { return 0; } -int Tun::setTimeout(int timeout) { - std::shared_ptr tun; - tun = std::any_cast>(this->impl); - if (tun->setTimeout(timeout)) { - return -1; - } - return 0; -} - int Tun::up() { std::shared_ptr tun; tun = std::any_cast>(this->impl); @@ -348,7 +335,7 @@ int Tun::write(const std::string &buffer) { return tun->write(buffer); } -int Tun::setSysRtTable(uint32_t dst, uint32_t mask, uint32_t nexthop) { +int Tun::setSysRtTable(IP4 dst, IP4 mask, IP4 nexthop) { std::shared_ptr tun; tun = std::any_cast>(this->impl); return tun->setSysRtTable(dst, mask, nexthop); diff --git a/src/utility/CMakeLists.txt b/src/utility/CMakeLists.txt index e33012ab..46f21526 100644 --- a/src/utility/CMakeLists.txt +++ b/src/utility/CMakeLists.txt @@ -11,10 +11,10 @@ if (${CANDY_STATIC_SPDLOG}) target_link_libraries(utility PRIVATE spdlog::spdlog) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED spdlog) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(utility PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(SPDLOG REQUIRED spdlog) + add_definitions(${SPDLOG_CFLAGS}) + include_directories(${SPDLOG_INCLUDEDIR}) + target_link_libraries(utility PRIVATE ${SPDLOG_LIBRARIES}) endif() if (${CANDY_STATIC_POCO}) diff --git a/src/utility/address.cc b/src/utility/address.cc deleted file mode 100644 index d6ee1df6..00000000 --- a/src/utility/address.cc +++ /dev/null @@ -1,235 +0,0 @@ -// SPDX-License-Identifier: MIT -#include "utility/address.h" -#include "utility/byteswap.h" -#include -#include - -#if defined(POCO_OS_FAMILY_UNIX) -#include -#include -#elif defined(POCO_OS_FAMILY_WINDOWS) -#include -#endif - -namespace Candy { - -int Address::cidrUpdate(const std::string &cidr) { - if (cidr.empty()) { - spdlog::error("cidr is empty"); - return -1; - } - std::size_t pos = cidr.find('/'); - if (pos == std::string::npos) { - spdlog::error("invalid cidr format: {}", cidr); - return -1; - } - - std::string ipStr = cidr.substr(0UL, pos); - std::string maskStr; - prefixStrToMaskStr(cidr.substr(pos + 1), maskStr); - return ipMaskStrUpdate(ipStr, maskStr); -} - -int Address::ipMaskStrUpdate(const std::string &ipStr, const std::string &maskStr) { - uint32_t ip, mask; - if (inet_pton(AF_INET, ipStr.c_str(), &ip) != 1) { - spdlog::error("invalid ip format: ip {}", ipStr); - return -1; - } - if (inet_pton(AF_INET, maskStr.c_str(), &mask) != 1) { - spdlog::error("invalid mask format: mask {}", maskStr); - return -1; - } - ip = netToHost(ip); - mask = netToHost(mask); - return this->ipMaskUpdate(ip, mask); -} - -int Address::ipMaskUpdate(uint32_t ip, uint32_t mask) { - this->ip = ip; - this->mask = mask; - ip = hostToNet(ip); - mask = hostToNet(mask); - char buffer[16]; - if (!inet_ntop(AF_INET, &ip, buffer, sizeof(buffer))) { - spdlog::error("invalid ip format: ip {:08x}", ip); - return -1; - } - this->ipStr = buffer; - if (!inet_ntop(AF_INET, &mask, buffer, sizeof(buffer))) { - spdlog::error("invalid mask format: mask {:08x}", mask); - return -1; - } - this->maskStr = buffer; - - if (maskToPrefix(this->mask, this->prefix)) { - spdlog::error("mask to prefix failed: mask {:08x}", mask); - return -1; - } - this->prefixStr = std::to_string(this->prefix); - this->net = this->ip & this->mask; - this->host = this->ip & (~this->mask); - this->cidr = this->ipStr + "/" + this->prefixStr; - return 0; -} - -int Address::ipStrUpdate(const std::string &ipStr) { - return ipMaskStrUpdate(ipStr, "255.255.255.255"); -} - -int Address::ipUpdate(uint32_t ip) { - return ipMaskUpdate(ip, UINT32_MAX); -} - -bool Address::inSameNetwork(const Address &address) { - if (getMask() != address.getMask()) { - return false; - } - if (getNet() != address.getNet()) { - return false; - } - if (address.getHost() == 0) { - return false; - } - if (address.getHost() == (~address.getMask())) { - return false; - } - return true; -} - -int Address::next() { - if (this->prefix >= 31) { - spdlog::error("unable to generate next available address: prefix {}", this->prefix); - return -1; - } - - do { - this->host = (this->host + 1) & (~this->mask); - } while (this->host == (~this->mask) || this->host == 0); - - uint32_t ip = this->net | this->host; - uint32_t mask = this->mask; - - return ipMaskUpdate(ip, mask); -} - -int Address::dump() const { - spdlog::debug("cidr={}", this->cidr); - spdlog::debug("ipStr={} ip=0x{:0>8x}", this->ipStr, this->ip); - spdlog::debug("maskStr={} mask=0x{:0>8x}", this->maskStr, this->mask); - spdlog::debug("prefixStr={} prefix={}", this->prefixStr, this->prefix); - spdlog::debug("net=0x{:0>8x} host=0x{:0>8x}", this->net, this->host); - return 0; -} - -uint32_t Address::getIp() const { - return this->ip; -} - -std::string Address::getIpStr() const { - return this->ipStr; -} - -uint32_t Address::getMask() const { - return this->mask; -} - -uint32_t Address::getPrefix() const { - return this->prefix; -} - -uint32_t Address::getNet() const { - return this->net; -} - -uint32_t Address::getHost() const { - return this->host; -} - -std::string Address::getMaskStr() const { - return this->maskStr; -} - -std::string Address::getCidr() const { - return this->cidr; -} - -uint32_t Address::netToHost(uint32_t address) { - if (std::endian::native == std::endian::little) { - return byteswap(address); - } - return address; -} - -uint32_t Address::hostToNet(uint32_t address) { - return netToHost(address); -} - -uint16_t Address::netToHost(uint16_t port) { - if (std::endian::native == std::endian::little) { - return byteswap(port); - } - return port; -} - -uint16_t Address::hostToNet(uint16_t port) { - return netToHost(port); -} - -std::string Address::ipToStr(uint32_t ip) { - Address address; - address.ipUpdate(ip); - return address.getIpStr(); -} - -int Address::prefixStrToMaskStr(const std::string &prefixStr, std::string &maskStr) { - uint32_t prefix = std::stoi(prefixStr); - uint32_t mask = 0; - - if (prefixToMask(prefix, mask) != 0) { - return -1; - } - - char buffer[16]; - mask = hostToNet(mask); - if (!inet_ntop(AF_INET, &mask, buffer, sizeof(buffer))) { - spdlog::error("invalid mask format: mask {:08x}", mask); - return -1; - } - maskStr = buffer; - return 0; -} - -int Address::prefixToMask(uint32_t prefix, uint32_t &mask) { - if (prefix > 32 || prefix < 0) { - spdlog::critical("cidr prefix exception: prefix {}", prefix); - return -1; - } - - mask = 0; - for (uint32_t idx = 0; idx < prefix; ++idx) { - mask |= 0x80000000 >> idx; - } - - return 0; -} - -int Address::maskToPrefix(uint32_t mask, uint32_t &prefix) { - prefix = 0; - for (uint32_t idx = 0; idx < 32; ++idx) { - if ((0x80000000 >> idx) & (mask)) { - ++prefix; - continue; - } - break; - } - for (uint32_t idx = prefix; idx < 32; ++idx) { - if ((0x80000000 >> idx) & (mask)) { - spdlog::error("invalid mask: mask {}", mask); - return -1; - } - } - return 0; -} - -} // namespace Candy diff --git a/src/utility/address.h b/src/utility/address.h deleted file mode 100644 index f656ecbb..00000000 --- a/src/utility/address.h +++ /dev/null @@ -1,83 +0,0 @@ -// SPDX-License-Identifier: MIT -#ifndef CANDY_UTILITY_ADDRESS_H -#define CANDY_UTILITY_ADDRESS_H - -#include -#include - -namespace Candy { - -struct IPv4Header { - uint8_t version_ihl; // 版本号和首部长度 - uint8_t tos; // 服务类型 - uint16_t tot_len; // 总长度 - uint16_t id; // 标识 - uint16_t frag_off; // 分片偏移 - uint8_t ttl; // 生存时间 - uint8_t protocol; // 协议类型 - uint16_t check; // 校验和 - uint32_t saddr; // 源地址 - uint32_t daddr; // 目的地址 -}; - -class Address { -public: - // 以不同的形式更新地址 - int cidrUpdate(const std::string &cidr); - int ipMaskStrUpdate(const std::string &ip, const std::string &mask); - int ipMaskUpdate(uint32_t ip, uint32_t mask); - // 以单个地址更新,掩码为 255.255.255.255 - int ipStrUpdate(const std::string &ip); - int ipUpdate(uint32_t ip); - - // 获取地址里的参数 - uint32_t getIp() const; - uint32_t getMask() const; - uint32_t getPrefix() const; - uint32_t getNet() const; - uint32_t getHost() const; - std::string getIpStr() const; - std::string getMaskStr() const; - std::string getCidr() const; - - // 地址在这个网络且主机地址有效 - bool inSameNetwork(const Address &address); - - // 地址更新为同网络的下一个地址,动态分配 IP 地址时使用 - int next(); - - // 显示地址信息,用于调试 - int dump() const; - - static uint32_t netToHost(uint32_t address); - static uint32_t hostToNet(uint32_t address); - static uint16_t netToHost(uint16_t port); - static uint16_t hostToNet(uint16_t port); - - static std::string ipToStr(uint32_t ip); - -private: - int prefixStrToMaskStr(const std::string &netPrefixStr, std::string &maskStr); - int prefixToMask(uint32_t prefix, uint32_t &mask); - int maskToPrefix(uint32_t mask, uint32_t &prefix); - - // 原始数据首先转换成地址和掩码 - uint32_t ip; - uint32_t mask; - - // 根据地址和掩码计算网络号和主机号 - uint32_t net; - uint32_t host; - // 根据掩码计算网络前缀 - uint32_t prefix; - // 把上面的数据转换为字符串格式 - std::string ipStr; - std::string maskStr; - std::string prefixStr; - // 根据地址和网络前缀获取 CIDR - std::string cidr; -}; - -} // namespace Candy - -#endif diff --git a/src/utility/argparse.h b/src/utility/argparse.h index e768e2ef..7d2fb1a8 100644 --- a/src/utility/argparse.h +++ b/src/utility/argparse.h @@ -589,8 +589,8 @@ class Argument { } template - auto action(F &&callable, Args &&...bound_args) - -> std::enable_if_t, Argument &> { + auto action(F &&callable, + Args &&...bound_args) -> std::enable_if_t, Argument &> { using action_type = std::conditional_t>, void_action, valued_action>; if constexpr (sizeof...(Args) == 0) { diff --git a/src/utility/random.cc b/src/utility/random.cc index 0f309d29..3dfe2cdf 100644 --- a/src/utility/random.cc +++ b/src/utility/random.cc @@ -4,7 +4,7 @@ #include #include -namespace Candy { +namespace { uint32_t randomUint32() { std::random_device device; @@ -19,7 +19,9 @@ int randomHex() { std::uniform_int_distribution distrib(0, 15); return distrib(engine); } +} // namespace +namespace Candy { std::string randomHexString(int length) { std::stringstream ss; for (int i = 0; i < length; i++) { diff --git a/src/utility/random.h b/src/utility/random.h index f7763738..8345ed2f 100644 --- a/src/utility/random.h +++ b/src/utility/random.h @@ -1,11 +1,13 @@ // SPDX-License-Identifier: MIT -#include +#ifndef CANDY_UTILITY_RANDOM_H +#define CANDY_UTILITY_RANDOM_H + #include namespace Candy { -uint32_t randomUint32(); -int randomHex(); std::string randomHexString(int length); } // namespace Candy + +#endif diff --git a/src/utility/time.cc b/src/utility/time.cc index 04795a00..510cdb18 100644 --- a/src/utility/time.cc +++ b/src/utility/time.cc @@ -1,10 +1,7 @@ // SPDX-License-Identifier: MIT #include "utility/time.h" -#include "utility/address.h" -#include "utility/byteswap.h" +#include "core/net.h" #include -#include -#include #include #include #include @@ -13,8 +10,8 @@ namespace Candy { -bool Time::useSystemTime = false; -std::string Time::ntpServer; +bool useSystemTime = false; +std::string ntpServer; struct ntp_packet { uint8_t li_vn_mode = 0x23; @@ -40,10 +37,10 @@ struct ntp_packet { uint32_t txTm_f; }; -static int64_t ntpTime() { +int64_t ntpTime() { try { Poco::Net::DatagramSocket socket; - socket.connect(Poco::Net::SocketAddress(Time::ntpServer, 123)); + socket.connect(Poco::Net::SocketAddress(ntpServer, 123)); struct ntp_packet packet = {}; socket.sendBytes(&packet, sizeof(packet)); @@ -56,7 +53,7 @@ static int64_t ntpTime() { return 0; } - int64_t retval = (int64_t)(Candy::Address::netToHost(packet.rxTm_s)); + int64_t retval = (int64_t)(ntoh(packet.rxTm_s)); if (retval == 0) { spdlog::warn("invalid ntp response buffer: {:n}", spdlog::to_hex(std::string((char *)(&packet), sizeof(packet)))); return 0; @@ -75,7 +72,7 @@ static int64_t ntpTime() { } } -int64_t Time::unixTime() { +int64_t unixTime() { using namespace std::chrono; int64_t sysTime; @@ -107,32 +104,10 @@ int64_t Time::unixTime() { return sysTime; } -int64_t Time::bootTime() { +int64_t bootTime() { using namespace std::chrono; auto now = steady_clock::now(); return duration_cast(now.time_since_epoch()).count(); } -int64_t Time::hostToNet(int64_t host) { - if (std::endian::native == std::endian::little) { - return byteswap(host); - } - return host; -} - -int64_t Time::netToHost(int64_t net) { - return Time::hostToNet(net); -} - -int32_t Time::hostToNet(int32_t host) { - if (std::endian::native == std::endian::little) { - return byteswap(host); - } - return host; -} - -int32_t Time::netToHost(int32_t net) { - return Time::hostToNet(net); -} - } // namespace Candy diff --git a/src/utility/time.h b/src/utility/time.h index 0ba88f17..20bd9edf 100644 --- a/src/utility/time.h +++ b/src/utility/time.h @@ -7,19 +7,11 @@ namespace Candy { -class Time { -public: - // 秒级的 Unix 时间戳,优先使用从互联网获取的时间 - static int64_t unixTime(); - // 毫秒级别的系统启动时间戳,不受时间回滚影响,用于计算网络延迟 - static int64_t bootTime(); - static int64_t hostToNet(int64_t host); - static int64_t netToHost(int64_t net); - static int32_t hostToNet(int32_t host); - static int32_t netToHost(int32_t net); - static bool useSystemTime; - static std::string ntpServer; -}; +extern bool useSystemTime; +extern std::string ntpServer; + +int64_t unixTime(); +int64_t bootTime(); } // namespace Candy diff --git a/src/websocket/CMakeLists.txt b/src/websocket/CMakeLists.txt index baa8ab1d..5475e7a6 100644 --- a/src/websocket/CMakeLists.txt +++ b/src/websocket/CMakeLists.txt @@ -7,14 +7,16 @@ if (${CANDY_STATIC_SPDLOG}) target_link_libraries(websocket PRIVATE spdlog::spdlog) else() find_package(PkgConfig REQUIRED) - pkg_check_modules(DEPS REQUIRED spdlog) - add_definitions(${DEPS_CFLAGS}) - include_directories(${DEPS_INCLUDEDIR}) - target_link_libraries(websocket PRIVATE ${DEPS_LIBRARIES}) + pkg_check_modules(SPDLOG REQUIRED spdlog) + add_definitions(${SPDLOG_CFLAGS}) + include_directories(${SPDLOG_INCLUDEDIR}) + target_link_libraries(websocket PRIVATE ${SPDLOG_LIBRARIES}) endif() if (${CANDY_STATIC_OPENSSL}) target_link_libraries(websocket PRIVATE ${OPENSSL_LIB_CRYPTO} ${OPENSSL_LIB_SSL}) +else() + find_package(OpenSSL REQUIRED) endif() if (${CANDY_STATIC_POCO}) diff --git a/src/websocket/client.cc b/src/websocket/client.cc index 32c0c136..e120dc8f 100644 --- a/src/websocket/client.cc +++ b/src/websocket/client.cc @@ -1,7 +1,11 @@ // SPDX-License-Identifier: MIT #include "websocket/client.h" +#include "core/client.h" +#include "core/message.h" +#include "core/net.h" +#include "core/version.h" #include "utility/time.h" -#include +#include "websocket/message.h" #include #include #include @@ -9,153 +13,329 @@ #include #include #include +#include #include namespace Candy { -int WebSocketClient::connect(const std::string &address) { - std::shared_ptr uri; - try { - uri = std::make_shared(address); - } catch (std::exception &e) { - spdlog::critical("invalid websocket server: {}: {}", address, e.what()); - return -1; - } +int WebSocketClient::setPassword(const std::string &password) { + this->password = password; + return 0; +} - try { - const std::string path = uri->getPath().empty() ? "/" : uri->getPath(); - Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_GET, path, Poco::Net::HTTPMessage::HTTP_1_1); - Poco::Net::HTTPResponse response; - if (uri->getScheme() == "wss") { - using Poco::Net::Context; - Context::Ptr context = new Context(Context::TLS_CLIENT_USE, "", "", "", Context::VERIFY_NONE); - Poco::Net::HTTPSClientSession cs(uri->getHost(), uri->getPort(), context); - this->ws = std::make_shared(cs, request, response); - } else if (uri->getScheme() == "ws") { - Poco::Net::HTTPClientSession cs(uri->getHost(), uri->getPort()); - this->ws = std::make_shared(cs, request, response); - } else { - spdlog::critical("invalid websocket scheme: {}", address); - return -1; - } - this->timestamp = Time::bootTime(); - return 0; - } catch (std::exception &e) { - spdlog::critical("websocket connect failed: {}", e.what()); - return -1; - } +int WebSocketClient::setWsServerUri(const std::string &uri) { + this->wsServerUri = uri; + return 0; } -int WebSocketClient::disconnect() { - try { - if (this->ws) { - this->ws->shutdown(); - this->ws->close(); - this->ws.reset(); +int WebSocketClient::setExptTunAddress(const std::string &cidr) { + this->exptTunCidr = cidr; + return 0; +} + +int WebSocketClient::setAddress(const std::string &cidr) { + this->tunCidr = cidr; + return 0; +} + +int WebSocketClient::setVirtualMac(const std::string &vmac) { + this->vmac = vmac; + return 0; +} + +int WebSocketClient::setTunUpdateCallback(std::function callback) { + this->addressUpdateCallback = callback; + return 0; +} + +int WebSocketClient::run(Client *client) { + this->client = client; + this->msgThread = std::thread([&] { + while (this->client->running) { + handleWsQueue(); } - } catch (std::exception &e) { - spdlog::debug("websocket disconnect failed: {}", e.what()); + }); + + if (connect()) { + spdlog::critical("websocket client connect failed"); + Candy::shutdown(this->client); + } + + sendVirtualMacMsg(); + if (this->tunCidr.empty()) { + sendExptTunMsg(); + } else { + sendAuthMsg(); } + + this->wsThread = std::thread([&] { + while (this->client->running) { + handleWsConn(); + } + spdlog::debug("websocket client thread exit"); + }); + return 0; } -int WebSocketClient::setTimeout(int timeout) { - this->timeout = timeout; +int WebSocketClient::shutdown() { + if (this->msgThread.joinable()) { + this->msgThread.join(); + } + if (this->wsThread.joinable()) { + this->wsThread.join(); + } return 0; } -int WebSocketClient::read(WebSocketMessage &message) { - if (!this->ws) { - spdlog::critical("websocket read before connected"); - return -1; +void WebSocketClient::handleWsQueue() { + Msg msg = this->client->wsMsgQueue.read(); + switch (msg.kind) { + case MsgKind::TIMEOUT: + break; + case MsgKind::PACKET: + handlePacket(std::move(msg)); + break; + default: + spdlog::warn("unexcepted websocket message type: {}", static_cast(msg.kind)); + break; } +} + +void WebSocketClient::handlePacket(Msg msg) { + IP4Header *header = (IP4Header *)msg.data.data(); + msg.data.insert(0, 1, WsMsgKind::FORWARD); + sendFrame(msg.data, Poco::Net::WebSocket::FRAME_BINARY); +} + +void WebSocketClient::handleWsConn() { try { - if (!this->ws->poll(Poco::Timespan(this->timeout, 0), Poco::Net::Socket::SELECT_READ)) { - if (Time::bootTime() - this->timestamp > 30000) { - message.type = WebSocketMessageType::Error; - message.buffer = "websocket pong timeout"; - return 1; + std::string buffer; + int flags = 0; + + // receiveFrame 会对 ws 加锁,影响写操作,需要先确定可读 + if (!this->ws->poll(Poco::Timespan(1, 0), Poco::Net::Socket::SELECT_READ)) { + if (bootTime() - this->timestamp > 30000) { + spdlog::warn("websocket pong timeout"); + Candy::shutdown(this->client); + return; } - if (Time::bootTime() - this->timestamp > 15000) { - return sendPingMessage(message); + if (bootTime() - this->timestamp > 15000) { + sendPingMessage(); } - return 0; + return; } - char buffer[1500] = {0}; - int flags = 0; - int length = this->ws->receiveFrame(buffer, sizeof(buffer), flags); + buffer.resize(1500); + int length = this->ws->receiveFrame(buffer.data(), buffer.size(), flags); if (length == 0 && flags == 0) { - message.type = WebSocketMessageType::Error; - message.buffer = "abnormal disconnect"; - return 1; + spdlog::info("abnormal disconnect"); + Candy::shutdown(this->client); + return; } if ((flags & Poco::Net::WebSocket::FRAME_OP_BITMASK) == Poco::Net::WebSocket::FRAME_OP_PING) { + this->timestamp = bootTime(); flags = (int)Poco::Net::WebSocket::FRAME_FLAG_FIN | (int)Poco::Net::WebSocket::FRAME_OP_PONG; - this->ws->sendFrame(buffer, length, flags); - this->timestamp = Time::bootTime(); - return 0; + sendFrame(buffer.data(), length, flags); + return; } if ((flags & Poco::Net::WebSocket::FRAME_OP_BITMASK) == Poco::Net::WebSocket::FRAME_OP_PONG) { - this->timestamp = Time::bootTime(); - return 0; + this->timestamp = bootTime(); + return; } if ((flags & Poco::Net::WebSocket::FRAME_OP_BITMASK) == Poco::Net::WebSocket::FRAME_OP_CLOSE) { - message.type = WebSocketMessageType::Close; - message.buffer.assign(buffer, length); - return 1; + spdlog::info("websocket close: {}", buffer); + Candy::shutdown(this->client); + return; } if (length > 0) { - message.type = WebSocketMessageType::Message; - message.buffer.assign(buffer, length); - this->timestamp = Time::bootTime(); - return 1; + this->timestamp = bootTime(); + buffer.resize(length); + handleWsMsg(std::move(buffer)); + return; } - return 0; } catch (std::exception &e) { - message.type = WebSocketMessageType::Error; - message.buffer = e.what(); - return 1; + spdlog::warn("handle ws conn failed: {}", e.what()); + Candy::shutdown(this->client); + return; } } -int WebSocketClient::write(const WebSocketMessage &message) { - if (!this->ws) { - spdlog::critical("websocket write before connected"); - return -1; +void WebSocketClient::handleWsMsg(std::string buffer) { + uint8_t msgKind = buffer.front(); + switch (msgKind) { + case WsMsgKind::FORWARD: + handleForwardMsg(std::move(buffer)); + break; + case WsMsgKind::EXPTTUN: + handleExptTunMsg(std::move(buffer)); + break; + case WsMsgKind::UDP4CONN: + break; + case WsMsgKind::DISCOVERY: + handleDiscoveryMsg(std::move(buffer)); + break; + case WsMsgKind::ROUTE: + handleRouteMsg(std::move(buffer)); + break; + case WsMsgKind::GENERAL: + break; + default: + spdlog::debug("unknown websocket message kind: {}", msgKind); + break; + } +} + +void WebSocketClient::handleForwardMsg(std::string buffer) { + if (buffer.size() < sizeof(WsMsg::Forward)) { + spdlog::warn("invalid forward message: {:n}", spdlog::to_hex(buffer)); + return; } + // 移除一个字节的类型 + buffer.erase(0, 1); + // 尝试与源地址建立对等连接 + IP4Header *header = (IP4Header *)buffer.data(); + this->client->peerMsgQueue.write(Msg(MsgKind::TRYP2P, header->saddr.toString())); + // 最后把报文移动到 TUN 模块,因为有移动操作所以必须在最后执行 + this->client->tunMsgQueue.write(Msg(MsgKind::PACKET, std::move(buffer))); +} - try { - this->ws->sendFrame(message.buffer.c_str(), message.buffer.length(), Poco::Net::WebSocket::FRAME_BINARY); - return 0; - } catch (std::exception &e) { - spdlog::critical("websocket write failed: {}", e.what()); - return -1; +void WebSocketClient::handleExptTunMsg(std::string buffer) { + if (buffer.size() < sizeof(WsMsg::ExptTun)) { + spdlog::warn("invalid expt tun message: {:n}", spdlog::to_hex(buffer)); + return; } + WsMsg::ExptTun *header = (WsMsg::ExptTun *)buffer.data(); + Address exptTun(header->cidr); + this->tunCidr = exptTun.toCidr(); + sendAuthMsg(); } -int WebSocketClient::setPingMessage(const std::string &message) { - this->pingMessage = message; - spdlog::debug("set ping message: {}", this->pingMessage); - return 0; +void WebSocketClient::handleDiscoveryMsg(std::string buffer) { + if (buffer.size() < sizeof(WsMsg::Discovery)) { + spdlog::warn("invalid discovery message: {:n}", spdlog::to_hex(buffer)); + return; + } + WsMsg::Discovery *header = (WsMsg::Discovery *)buffer.data(); + if (header->dst == IP4("255.255.255.255")) { + sendDiscoveryMsg(header->src); + } + this->client->peerMsgQueue.write(Msg(MsgKind::TRYP2P, header->src.toString())); +} + +void WebSocketClient::handleRouteMsg(std::string buffer) { + if (buffer.size() < sizeof(WsMsg::SysRoute)) { + spdlog::warn("invalid expt tun message: {:n}", spdlog::to_hex(buffer)); + return; + } + WsMsg::SysRoute *header = (WsMsg::SysRoute *)buffer.data(); + SysRouteEntry *rt = header->rtTable; + for (uint8_t idx = 0; idx < header->size; ++idx) { + this->client->tunMsgQueue.write(Msg(MsgKind::SYSRT, std::string((char *)(rt + idx), sizeof(SysRouteEntry)))); + } +} + +void WebSocketClient::sendFrame(const std::string &buffer, int flags) { + sendFrame(buffer.c_str(), buffer.size(), flags); +} + +void WebSocketClient::sendFrame(const void *buffer, int length, int flags) { + this->ws->sendFrame(buffer, length, flags); +} + +void WebSocketClient::sendVirtualMacMsg() { + WsMsg::VMac buffer(this->vmac); + buffer.updateHash(this->password); + sendFrame(&buffer, sizeof(buffer)); } -int WebSocketClient::sendPingMessage() { - WebSocketMessage wsMessage; - return sendPingMessage(wsMessage); +void WebSocketClient::sendExptTunMsg() { + Address exptTun(this->exptTunCidr); + WsMsg::ExptTun buffer(exptTun.toCidr()); + buffer.updateHash(this->password); + sendFrame(&buffer, sizeof(buffer)); } -int WebSocketClient::sendPingMessage(WebSocketMessage &message) { +void WebSocketClient::sendAuthMsg() { + Address address(this->tunCidr); + WsMsg::Auth buffer(address.Host()); + buffer.updateHash(this->password); + sendFrame(&buffer, sizeof(buffer)); + this->client->tunMsgQueue.write(Msg(MsgKind::TUNADDR, address.toCidr())); + if (addressUpdateCallback) { + addressUpdateCallback(address.toCidr()); + } +} + +void WebSocketClient::sendDiscoveryMsg(IP4 dst) { + Address address(this->tunCidr); + + WsMsg::Discovery buffer; + buffer.dst = dst; + buffer.src = address.Host(); + + sendFrame(&buffer, sizeof(buffer)); +} + +std::string WebSocketClient::hostName() { + char hostname[64] = {0}; + if (!gethostname(hostname, sizeof(hostname))) { + return std::string(hostname, strnlen(hostname, sizeof(hostname))); + } + return ""; +} + +void WebSocketClient::sendPingMessage() { + int flags = (int)Poco::Net::WebSocket::FRAME_FLAG_FIN | (int)Poco::Net::WebSocket::FRAME_OP_PING; + sendFrame(pingMessage, flags); +} + +int WebSocketClient::connect() { + std::shared_ptr uri; try { - int flags = (int)Poco::Net::WebSocket::FRAME_FLAG_FIN | (int)Poco::Net::WebSocket::FRAME_OP_PING; - this->ws->sendFrame(this->pingMessage.c_str(), this->pingMessage.size(), flags); + uri = std::make_shared(wsServerUri); + } catch (std::exception &e) { + spdlog::critical("invalid websocket server: {}: {}", wsServerUri, e.what()); + return -1; + } + + try { + const std::string path = uri->getPath().empty() ? "/" : uri->getPath(); + Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_GET, path, Poco::Net::HTTPMessage::HTTP_1_1); + Poco::Net::HTTPResponse response; + if (uri->getScheme() == "wss") { + using Poco::Net::Context; + Context::Ptr context = new Context(Context::TLS_CLIENT_USE, "", "", "", Context::VERIFY_NONE); + Poco::Net::HTTPSClientSession cs(uri->getHost(), uri->getPort(), context); + this->ws = std::make_shared(cs, request, response); + } else if (uri->getScheme() == "ws") { + Poco::Net::HTTPClientSession cs(uri->getHost(), uri->getPort()); + this->ws = std::make_shared(cs, request, response); + } else { + spdlog::critical("invalid websocket scheme: {}", wsServerUri); + return -1; + } + this->timestamp = bootTime(); + this->pingMessage = fmt::format("candy::{}::{}::{}", CANDY_SYSTEM, CANDY_VERSION, hostName()); return 0; } catch (std::exception &e) { - message.type = WebSocketMessageType::Error; - message.buffer = e.what(); - return 1; + spdlog::critical("websocket connect failed: {}", e.what()); + return -1; + } +} + +int WebSocketClient::disconnect() { + try { + if (this->ws) { + this->ws->shutdown(); + this->ws->close(); + this->ws.reset(); + } + } catch (std::exception &e) { + spdlog::debug("websocket disconnect failed: {}", e.what()); } + return 0; } } // namespace Candy diff --git a/src/websocket/client.h b/src/websocket/client.h index 7642962a..a962d777 100644 --- a/src/websocket/client.h +++ b/src/websocket/client.h @@ -2,36 +2,73 @@ #ifndef CANDY_WEBSOCKET_CLIENT_H #define CANDY_WEBSOCKET_CLIENT_H -#include "websocket/common.h" +#include "core/message.h" +#include "core/net.h" #include +#include +#include #include +#include namespace Candy { +class Client; + class WebSocketClient { public: - // 连接或断开与服务端的连接 - int connect(const std::string &address); - int disconnect(); + int setPassword(const std::string &password); + int setWsServerUri(const std::string &uri); + int setExptTunAddress(const std::string &cidr); + int setAddress(const std::string &cidr); + int setVirtualMac(const std::string &vmac); + int setTunUpdateCallback(std::function callback); + + int run(Client *client); + int shutdown(); + +private: + void handleWsQueue(); + void handlePacket(Msg msg); + + std::thread msgThread; + + void handleWsConn(); + void handleWsMsg(std::string buffer); + void handleForwardMsg(std::string buffer); + void handleExptTunMsg(std::string buffer); + void handleDiscoveryMsg(std::string buffer); + void handleRouteMsg(std::string buffer); + std::thread wsThread; - // 设置读超时时间 - int setTimeout(int timeout); + void sendFrame(const std::string &buffer, int flags = Poco::Net::WebSocket::FRAME_BINARY); + void sendFrame(const void *buffer, int length, int flags = Poco::Net::WebSocket::FRAME_BINARY); - // 读操作返回 0 表示超时.由于客户端只与一个服务端通信,事实上只需要操作的 buffer, - // 为了和服务端操作的数据结构保持一直,使用了相同的参数. - int read(WebSocketMessage &message); - int write(const WebSocketMessage &message); + void sendVirtualMacMsg(); + void sendExptTunMsg(); + void sendAuthMsg(); + void sendDiscoveryMsg(IP4 dst); - int setPingMessage(const std::string &message); - int sendPingMessage(); + std::function addressUpdateCallback; private: - int sendPingMessage(WebSocketMessage &message); + std::string hostName(); + void sendPingMessage(); + +private: + int connect(); + int disconnect(); - int timeout; std::shared_ptr ws; - int64_t timestamp; std::string pingMessage; + int64_t timestamp; + +private: + std::string wsServerUri; + std::string exptTunCidr; + std::string tunCidr; + std::string vmac; + std::string password; + Client *client; }; } // namespace Candy diff --git a/src/websocket/common.cc b/src/websocket/common.cc deleted file mode 100644 index d6339cd5..00000000 --- a/src/websocket/common.cc +++ /dev/null @@ -1,16 +0,0 @@ -// SPDX-License-Identifier: MIT -#include "websocket/common.h" -#include -#include - -namespace Candy { - -bool WebSocketConn::operator<(const WebSocketConn &other) const { - return std::owner_less>()(this->ws, other.ws); -} - -bool WebSocketConn::operator==(const WebSocketConn &other) const { - return this->ws.lock() == other.ws.lock(); -} - -} // namespace Candy diff --git a/src/websocket/common.h b/src/websocket/common.h deleted file mode 100644 index e3c5baa3..00000000 --- a/src/websocket/common.h +++ /dev/null @@ -1,34 +0,0 @@ -// SPDX-License-Identifier: MIT -#ifndef CANDY_WEBSOCKET_COMMON_H -#define CANDY_WEBSOCKET_COMMON_H - -#include -#include -#include - -namespace Candy { - -enum class WebSocketMessageType { Message, Open, Close, Error }; - -class WebSocketConn { -public: - // 重载小于号,用于作为 std::map 的 key - bool operator<(const WebSocketConn &other) const; - - // 重载等于号,用于判断是否是相同的连接 - bool operator==(const WebSocketConn &other) const; - - std::weak_ptr ws; -}; - -// 消息会被放到消息队列里,从消息队列里取出来的时候至少要包含消息的类型和来源, -// 对于客户端,消息的来源只能是服务端,所以客户端的消息队列可以不填充 conn 字段 -struct WebSocketMessage { - WebSocketMessageType type; - std::string buffer; - WebSocketConn conn; -}; - -} // namespace Candy - -#endif diff --git a/src/websocket/message.cc b/src/websocket/message.cc new file mode 100644 index 00000000..558d536b --- /dev/null +++ b/src/websocket/message.cc @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +#include "websocket/message.h" +#include "utility/time.h" + +namespace Candy { +namespace WsMsg { + +Auth::Auth(IP4 ip) { + this->type = WsMsgKind::AUTH; + this->ip = ip; + this->timestamp = hton(unixTime()); +} + +void Auth::updateHash(const std::string &password) { + std::string data; + data.append(password); + data.append((char *)&ip, sizeof(ip)); + data.append((char *)×tamp, sizeof(timestamp)); + SHA256((unsigned char *)data.data(), data.size(), this->hash); +} + +bool Auth::check(const std::string &password) { + // 检查时间 + int64_t localTime = unixTime(); + int64_t remoteTime = ntoh(this->timestamp); + if (std::abs(localTime - remoteTime) > 30) { + spdlog::warn("auth header timestamp check failed: server {} client {}", localTime, remoteTime); + } + + // 备份上报的数据 + uint8_t reported[SHA256_DIGEST_LENGTH]; + std::memcpy(reported, this->hash, SHA256_DIGEST_LENGTH); + + // 用口令计算正确的哈希并填充 + updateHash(password); + + // 检查上报的哈希和填充的哈希是否相等 + if (std::memcmp(reported, this->hash, SHA256_DIGEST_LENGTH)) { + spdlog::warn("auth header hash check failed"); + return false; + } + return true; +} + +Forward::Forward() { + this->type = WsMsgKind::FORWARD; +} + +ExptTun::ExptTun(const std::string &cidr) { + this->type = WsMsgKind::EXPTTUN; + this->timestamp = hton(unixTime()); + std::strcpy(this->cidr, cidr.c_str()); +} + +void ExptTun::updateHash(const std::string &password) { + std::string data; + data.append(password); + data.append((char *)&this->timestamp, sizeof(this->timestamp)); + SHA256((unsigned char *)data.data(), data.size(), this->hash); +} + +bool ExptTun::check(const std::string &password) { + int64_t localTime = unixTime(); + int64_t remoteTime = ntoh(this->timestamp); + if (std::abs(localTime - remoteTime) > 30) { + spdlog::warn("expected address header timestamp check failed: server {} client {}", localTime, remoteTime); + return false; + } + + uint8_t reported[SHA256_DIGEST_LENGTH]; + std::memcpy(reported, this->hash, SHA256_DIGEST_LENGTH); + + updateHash(password); + + if (std::memcmp(reported, this->hash, SHA256_DIGEST_LENGTH)) { + spdlog::warn("expected address header hash check failed"); + return false; + } + return true; +} + +Udp4Conn::Udp4Conn() { + this->type = WsMsgKind::UDP4CONN; +} + +VMac::VMac(const std::string &vmac) { + this->type = WsMsgKind::VMAC; + this->timestamp = hton(unixTime()); + if (vmac.length() >= sizeof(this->vmac)) { + memcpy(this->vmac, vmac.c_str(), sizeof(this->vmac)); + } else { + memset(this->vmac, 0, sizeof(this->vmac)); + } +} + +void VMac::updateHash(const std::string &password) { + std::string data; + data.append(password); + data.append((char *)&this->vmac, sizeof(this->vmac)); + data.append((char *)&this->timestamp, sizeof(this->timestamp)); + SHA256((unsigned char *)data.data(), data.size(), this->hash); +} + +bool VMac::check(const std::string &password) { + int64_t localTime = unixTime(); + int64_t remoteTime = ntoh(this->timestamp); + if (std::abs(localTime - remoteTime) > 30) { + spdlog::warn("vmac message timestamp check failed: server {} client {}", localTime, remoteTime); + return false; + } + + uint8_t reported[SHA256_DIGEST_LENGTH]; + std::memcpy(reported, this->hash, SHA256_DIGEST_LENGTH); + + updateHash(password); + + if (std::memcmp(reported, this->hash, SHA256_DIGEST_LENGTH)) { + spdlog::warn("vmac message hash check failed"); + return false; + } + return true; +} + +Discovery::Discovery() { + this->type = WsMsgKind::DISCOVERY; +} + +General::General() { + this->type = WsMsgKind::GENERAL; +} + +} // namespace WsMsg +} // namespace Candy diff --git a/src/websocket/message.h b/src/websocket/message.h new file mode 100644 index 00000000..4cc4903a --- /dev/null +++ b/src/websocket/message.h @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +#ifndef CANDY_WEBSOCKET_MESSAGE_H +#define CANDY_WEBSOCKET_MESSAGE_H + +#include "core/net.h" +#include + +namespace Candy { + +namespace WsMsgKind { +constexpr uint8_t AUTH = 0; +constexpr uint8_t FORWARD = 1; +constexpr uint8_t EXPTTUN = 2; +constexpr uint8_t UDP4CONN = 3; +constexpr uint8_t VMAC = 4; +constexpr uint8_t DISCOVERY = 5; +constexpr uint8_t ROUTE = 6; +constexpr uint8_t GENERAL = 255; +} // namespace WsMsgKind + +namespace GeSubType { +constexpr uint8_t LOCALUDP4CONN = 0; +} + +namespace WsMsg { + +struct __attribute__((packed)) Auth { + uint8_t type; + IP4 ip; + int64_t timestamp; + uint8_t hash[SHA256_DIGEST_LENGTH]; + + Auth(IP4 ip); + void updateHash(const std::string &password); + bool check(const std::string &password); +}; + +struct __attribute__((packed)) Forward { + uint8_t type; + IP4Header iph; + + Forward(); +}; + +struct __attribute__((packed)) ExptTun { + uint8_t type; + int64_t timestamp; + char cidr[32]; + uint8_t hash[SHA256_DIGEST_LENGTH]; + + ExptTun(const std::string &cidr); + void updateHash(const std::string &password); + bool check(const std::string &password); +}; + +struct __attribute__((packed)) Udp4Conn { + uint8_t type; + IP4 src; + IP4 dst; + IP4 ip; + uint16_t port; + + Udp4Conn(); +}; + +struct __attribute__((packed)) VMac { + uint8_t type; + uint8_t vmac[16]; + int64_t timestamp; + uint8_t hash[SHA256_DIGEST_LENGTH]; + + VMac(const std::string &vmac); + void updateHash(const std::string &password); + bool check(const std::string &password); +}; + +struct __attribute__((packed)) Discovery { + uint8_t type; + IP4 src; + IP4 dst; + + Discovery(); +}; + +struct __attribute__((packed)) SysRoute { + uint8_t type; + uint8_t size; + uint16_t reserved; + SysRouteEntry rtTable[0]; +}; + +struct __attribute__((packed)) General { + uint8_t type; + uint8_t subtype; + uint16_t extra; + IP4 src; + IP4 dst; + + General(); +}; + +struct __attribute__((packed)) LocalUDP4 { + General ge; + IP4 ip; + uint16_t port; +}; + +} // namespace WsMsg +} // namespace Candy + +#endif diff --git a/src/websocket/server.cc b/src/websocket/server.cc index fa25e74e..a6724523 100644 --- a/src/websocket/server.cc +++ b/src/websocket/server.cc @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT #include "websocket/server.h" -#include "websocket/common.h" +#include "core/net.h" +#include "utility/time.h" +#include "websocket/message.h" #include #include #include @@ -8,168 +10,539 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include #include +#include +/** + * Poco 的 WebSocket 服务端接口有点难用,简单封装一下,并对外提供一个回调函数,回调函数的参数表示独立的 + * WebSocket客户端,函数返回会释放连接 + */ namespace { -using namespace Candy; +using WebSocketHandler = std::function; -class WebSocketHandler : public Poco::Net::HTTPRequestHandler { +class HTTPRequestHandler : public Poco::Net::HTTPRequestHandler { public: - WebSocketHandler(WebSocketServer *server) { - this->server = server; - } + HTTPRequestHandler(WebSocketHandler wsHandler) : wsHandler(wsHandler) {} void handleRequest(Poco::Net::HTTPServerRequest &request, Poco::Net::HTTPServerResponse &response) { - std::shared_ptr ws = std::make_shared(request, response); - ws->setReceiveTimeout(Poco::Timespan(this->server->timeout, 0)); - - char buffer[1500] = {0}; - int length = 0; - int flags = 0; - - while (this->server->running) { - try { - length = ws->receiveFrame(buffer, sizeof(buffer), flags); - int frameOp = flags & Poco::Net::WebSocket::FRAME_OP_BITMASK; - - if (frameOp == Poco::Net::WebSocket::FRAME_OP_PING) { - flags = (int)Poco::Net::WebSocket::FRAME_FLAG_FIN | (int)Poco::Net::WebSocket::FRAME_OP_PONG; - ws->sendFrame(buffer, length, flags); - continue; - } - - if ((length == 0 && flags == 0) || frameOp == Poco::Net::WebSocket::FRAME_OP_CLOSE) { - WebSocketMessage msg; - msg.type = WebSocketMessageType::Close; - msg.buffer.assign(buffer, length); - msg.conn.ws = std::weak_ptr(ws); - this->server->push(msg); - break; - } - - if (frameOp == Poco::Net::WebSocket::FRAME_OP_BINARY && length > 0) { - WebSocketMessage msg; - msg.type = WebSocketMessageType::Message; - msg.buffer.assign(buffer, length); - msg.conn.ws = std::weak_ptr(ws); - this->server->push(msg); - continue; - } - } catch (Poco::TimeoutException const &e) { - continue; - } catch (std::exception &e) { - WebSocketMessage msg; - msg.type = WebSocketMessageType::Close; - msg.buffer = e.what(); - msg.conn.ws = std::weak_ptr(ws); - this->server->push(msg); - break; - } - spdlog::debug("unknown websocket request: length {} flags {}", length, flags); + try { + Poco::Net::WebSocket ws(request, response); + wsHandler(ws); + ws.close(); + } catch (const std::exception &e) { + response.setStatus(Poco::Net::HTTPResponse::HTTP_FORBIDDEN); + response.setReason("Forbidden"); + response.setContentLength(0); + response.send(); } - ws->close(); } private: - WebSocketServer *server = nullptr; + WebSocketHandler wsHandler; }; -class ForbiddenHandler : public Poco::Net::HTTPRequestHandler { +class HTTPRequestHandlerFactory : public Poco::Net::HTTPRequestHandlerFactory { public: - void handleRequest(Poco::Net::HTTPServerRequest &request, Poco::Net::HTTPServerResponse &response) { - response.setStatus(Poco::Net::HTTPResponse::HTTP_FORBIDDEN); - response.setReason("Forbidden"); - response.setContentLength(0); - response.send(); - } -}; + HTTPRequestHandlerFactory(WebSocketHandler wsHandler) : wsHandler(wsHandler) {} -class WebSocketHandlerFactory : public Poco::Net::HTTPRequestHandlerFactory { -public: - WebSocketHandlerFactory(WebSocketServer *server) { - this->server = server; - } Poco::Net::HTTPRequestHandler *createRequestHandler(const Poco::Net::HTTPServerRequest &request) { - if (request.get("Upgrade", "") == "websocket") { - return new WebSocketHandler(this->server); - } else { - return new ForbiddenHandler(); - } + return new HTTPRequestHandler(wsHandler); } private: - WebSocketServer *server = nullptr; + WebSocketHandler wsHandler; }; -} // namespace +}; // namespace namespace Candy { -int WebSocketServer::listen(const std::string &host, uint16_t port) { +void WsCtx::sendFrame(const std::string &frame, int flags) { + this->ws->sendFrame(frame.data(), frame.size(), flags); +} + +int WebSocketServer::setWebSocket(const std::string &uri) { try { - Poco::Net::ServerSocket socket(Poco::Net::SocketAddress(host, port)); - Poco::Net::HTTPServerParams *params = new Poco::Net::HTTPServerParams(); - params->setMaxThreads(0x00FFFFFF); - this->server = std::make_shared(new WebSocketHandlerFactory(this), socket, params); - this->running = true; - this->server->start(); + Poco::URI parser(uri); + if (parser.getScheme() != "ws") { + spdlog::critical("websocket server only support ws"); + return -1; + } + this->host = parser.getHost(); + this->port = parser.getPort(); return 0; } catch (std::exception &e) { - spdlog::critical("listen failed: {}", e.what()); + spdlog::critical("invalid websocket uri: {}: {}", uri, e.what()); return -1; } } -int WebSocketServer::stop() { - this->running = false; - if (this->server) { - this->server->stop(); - this->server->stopAll(); +int WebSocketServer::setPassword(const std::string &password) { + this->password = password; + return 0; +} + +int WebSocketServer::setDHCP(const std::string &cidr) { + if (cidr.empty()) { + return 0; + } + return this->dhcp.fromCidr(cidr); +} + +int WebSocketServer::setSdwan(const std::string &sdwan) { + if (sdwan.empty()) { + return 0; + } + std::string route; + std::stringstream stream(sdwan); + while (std::getline(stream, route, ';')) { + std::string addr; + SysRoute rt; + std::stringstream ss(route); + // dev + if (!std::getline(ss, addr, ',') || rt.dev.fromCidr(addr) || rt.dev.Host() != rt.dev.Net()) { + spdlog::critical("invalid route device: {}", route); + return -1; + } + // dst + if (!std::getline(ss, addr, ',') || rt.dst.fromCidr(addr) || rt.dst.Host() != rt.dst.Net()) { + spdlog::critical("invalid route dest: {}", route); + return -1; + } + // next + if (!std::getline(ss, addr, ',') || rt.next.fromString(addr)) { + spdlog::critical("invalid route nexthop: {}", route); + return -1; + } + spdlog::info("route: dev={} dst={} next={}", rt.dev.toCidr(), rt.dst.toCidr(), rt.next.toString()); + this->routes.push_back(rt); } return 0; } -int WebSocketServer::setTimeout(int timeout) { - this->timeout = timeout; +int WebSocketServer::run() { + listen(); return 0; } -int WebSocketServer::read(WebSocketMessage &message) { - std::unique_lock lock(this->mutex); - if (this->condition.wait_for(lock, std::chrono::seconds(this->timeout), [&] { return !this->queue.empty(); })) { - message = this->queue.front(); - this->queue.pop(); - return 1; +int WebSocketServer::shutdown() { + this->running = false; + if (this->httpServer) { + this->httpServer->stopAll(); } + this->routes.clear(); return 0; } -int WebSocketServer::write(const WebSocketMessage &message) { - auto ws = message.conn.ws.lock(); - if (ws) { - try { - ws->sendFrame(message.buffer.c_str(), message.buffer.size(), Poco::Net::WebSocket::FRAME_BINARY); - } catch (std::exception &e) { - spdlog::warn("websocket server write failed: {}", e.what()); +void WebSocketServer::handleMsg(WsCtx &ctx) { + uint8_t msgKind = ctx.buffer.front(); + switch (msgKind) { + case WsMsgKind::AUTH: + handleAuthMsg(ctx); + break; + case WsMsgKind::FORWARD: + handleForwardMsg(ctx); + break; + case WsMsgKind::EXPTTUN: + handleExptTunMsg(ctx); + break; + case WsMsgKind::UDP4CONN: + handleUdp4ConnMsg(ctx); + break; + case WsMsgKind::VMAC: + handleVMacMsg(ctx); + break; + case WsMsgKind::DISCOVERY: + handleDiscoveryMsg(ctx); + break; + case WsMsgKind::GENERAL: + HandleGeneralMsg(ctx); + break; + } +} + +void WebSocketServer::handleAuthMsg(WsCtx &ctx) { + if (ctx.buffer.length() < sizeof(WsMsg::Auth)) { + spdlog::warn("invalid auth message: len {}", ctx.buffer.length()); + ctx.status = -1; + return; + } + + WsMsg::Auth *header = (WsMsg::Auth *)ctx.buffer.data(); + if (!header->check(this->password)) { + spdlog::warn("auth header check failed: buffer {:n}", spdlog::to_hex(ctx.buffer)); + ctx.status = -1; + return; + } + + ctx.ip = header->ip; + + { + std::unique_lock lock(ipCtxMutex); + auto it = ipCtxMap.find(header->ip); + if (it != ipCtxMap.end()) { + it->second->status = -1; + spdlog::info("reconnect: {}", it->second->ip.toString()); + } else { + spdlog::info("connect: {}", ctx.ip.toString()); } + ipCtxMap[header->ip] = &ctx; } - return 0; + + updateSysRoute(ctx); } -int WebSocketServer::close(WebSocketConn conn) { - auto ws = conn.ws.lock(); - if (ws) { - ws->close(); +void WebSocketServer::handleForwardMsg(WsCtx &ctx) { + if (ctx.ip.empty()) { + spdlog::debug("unauthorized forward websocket client"); + ctx.status = -1; + return; } - return 0; + + if (ctx.buffer.length() < sizeof(WsMsg::Forward)) { + spdlog::debug("invalid forawrd message: len {}", ctx.buffer.length()); + ctx.status = -1; + return; + } + + WsMsg::Forward *header = (WsMsg::Forward *)ctx.buffer.data(); + if (ctx.ip != header->iph.saddr) { + spdlog::debug("forward failed: auth {} source {}", ctx.ip.toString(), header->iph.saddr.toString()); + ctx.status = -1; + return; + } + + { + std::shared_lock lock(this->ipCtxMutex); + auto it = this->ipCtxMap.find(header->iph.daddr); + if (it != this->ipCtxMap.end()) { + it->second->sendFrame(ctx.buffer); + return; + } + } + + bool broadcast = [&] { + // 多播地址 + if ((header->iph.daddr & IP4("240.0.0.0")) == IP4("224.0.0.0")) { + return true; + } + // 广播 + if (header->iph.daddr == IP4("255.255.255.255")) { + return true; + } + // 服务端没有配置动态分配地址的范围,没法检查是否为定向广播 + if (this->dhcp.empty()) { + return false; + } + // 网络号不同,不是定向广播 + if ((this->dhcp.Mask() & header->iph.daddr) != this->dhcp.Net()) { + return false; + } + // 主机号部分不全为 1,不是定向广播 + if (~((header->iph.daddr & ~this->dhcp.Mask()) ^ this->dhcp.Mask())) { + return false; + } + return true; + }(); + + if (broadcast) { + std::shared_lock lock(this->ipCtxMutex); + for (auto c : this->ipCtxMap) { + if (c.second->ip != ctx.ip) { + c.second->sendFrame(ctx.buffer); + } + } + return; + } + + spdlog::debug("forward failed: source {} dest {}", header->iph.saddr.toString(), header->iph.daddr.toString()); + return; +} + +void WebSocketServer::handleExptTunMsg(WsCtx &ctx) { + if (ctx.buffer.length() < sizeof(WsMsg::ExptTun)) { + spdlog::warn("invalid dynamic address message: len {}", ctx.buffer.length()); + ctx.status = -1; + return; + } + WsMsg::ExptTun *header = (WsMsg::ExptTun *)ctx.buffer.data(); + if (!header->check(this->password)) { + spdlog::warn("dynamic address header check failed: buffer {:n}", spdlog::to_hex(ctx.buffer)); + ctx.status = -1; + return; + } + if (this->dhcp.empty()) { + spdlog::warn("unable to allocate dynamic address"); + ctx.status = -1; + return; + } + Address exptTun; + if (exptTun.fromCidr(header->cidr)) { + spdlog::warn("dynamic address header cidr invalid: buffer {:n}", spdlog::to_hex(ctx.buffer)); + ctx.status = -1; + return; + } + // 判断能否直接使用申请的地址 + bool direct = [&]() { + if (dhcp.Net() != exptTun.Net()) { + return false; + } + std::shared_lock lock(this->ipCtxMutex); + auto oldCtx = this->ipCtxMap.find(exptTun.Host()); + if (oldCtx == this->ipCtxMap.end()) { + return true; + } + return ctx.vmac == oldCtx->second->vmac; + }(); + if (!direct) { + exptTun = this->dhcp; + std::shared_lock lock(this->ipCtxMutex); + do { + exptTun = exptTun.Next(); + if (exptTun.Host() == this->dhcp.Host()) { + spdlog::warn("all addresses in the network are assigned"); + ctx.status = -1; + return; + } + } while (!exptTun.isValid() && this->ipCtxMap.contains(exptTun.Host())); + this->dhcp = exptTun; + } + header->timestamp = hton(unixTime()); + std::strcpy(header->cidr, exptTun.toCidr().c_str()); + header->updateHash(this->password); + ctx.sendFrame(ctx.buffer.data()); } -void WebSocketServer::push(const WebSocketMessage &msg) { +void WebSocketServer::handleUdp4ConnMsg(WsCtx &ctx) { + if (ctx.ip.empty()) { + spdlog::debug("unauthorized peer websocket client"); + ctx.status = -1; + return; + } + + if (ctx.buffer.length() < sizeof(WsMsg::Udp4Conn)) { + spdlog::warn("invalid peer conn message: len {}", ctx.buffer.length()); + ctx.status = -1; + return; + } + + WsMsg::Udp4Conn *header = (WsMsg::Udp4Conn *)ctx.buffer.data(); + if (ctx.ip != header->src) { + spdlog::debug("peer source address does not match: auth {} source {}", ctx.ip.toString(), header->src.toString()); + ctx.status = -1; + return; + } + std::shared_lock lock(this->ipCtxMutex); + auto it = this->ipCtxMap.find(header->dst); + if (it == this->ipCtxMap.end()) { + spdlog::debug("peer dest address not logged in: source {} dst {}", header->src.toString(), header->dst.toString()); + return; + } + it->second->sendFrame(ctx.buffer); + return; +} + +void WebSocketServer::handleVMacMsg(WsCtx &ctx) { + if (ctx.buffer.length() < sizeof(WsMsg::VMac)) { + spdlog::warn("invalid vmac message: len {}", ctx.buffer.length()); + ctx.status = -1; + return; + } + + WsMsg::VMac *header = (WsMsg::VMac *)ctx.buffer.data(); + if (!header->check(this->password)) { + spdlog::warn("vmac message check failed: buffer {:n}", spdlog::to_hex(ctx.buffer)); + ctx.status = -1; + return; + } + + ctx.vmac.assign((char *)header->vmac, sizeof(header->vmac)); + return; +} + +void WebSocketServer::handleDiscoveryMsg(WsCtx &ctx) { + if (ctx.ip.empty()) { + spdlog::debug("unauthorized discovery websocket client"); + ctx.status = -1; + return; + } + + if (ctx.buffer.length() < sizeof(WsMsg::Discovery)) { + spdlog::debug("invalid discovery message: len {}", ctx.buffer.length()); + ctx.status = -1; + return; + } + + WsMsg::Discovery *header = (WsMsg::Discovery *)ctx.buffer.data(); + if (ctx.ip != header->src) { + spdlog::debug("discovery source address does not match: auth {} source {}", ctx.ip.toString(), header->src.toString()); + ctx.status = -1; + return; + } + + std::shared_lock lock(this->ipCtxMutex); + if (header->dst == IP4("255.255.255.255")) { + for (auto c : this->ipCtxMap) { + if (c.first != header->src) { + c.second->sendFrame(ctx.buffer); + } + } + return; + } + auto it = this->ipCtxMap.find(header->dst); + if (it != this->ipCtxMap.end()) { + it->second->sendFrame(ctx.buffer); + return; + } +} + +void WebSocketServer::HandleGeneralMsg(WsCtx &ctx) { + if (ctx.ip.empty()) { + spdlog::debug("unauthorized general websocket client"); + ctx.status = -1; + return; + } + + if (ctx.buffer.length() < sizeof(WsMsg::General)) { + spdlog::debug("invalid general message: len {}", ctx.buffer.length()); + ctx.status = -1; + return; + } + + WsMsg::General *header = (WsMsg::General *)ctx.buffer.data(); + + if (ctx.ip != header->src) { + spdlog::debug("general source address does not match: auth {} source {}", ctx.ip.toString(), header->src.toString()); + ctx.status = -1; + return; + } + + std::shared_lock lock(this->ipCtxMutex); + if (header->dst == IP4("255.255.255.255")) { + for (auto c : this->ipCtxMap) { + if (c.first != header->src) { + c.second->sendFrame(ctx.buffer); + } + } + return; + } + auto it = this->ipCtxMap.find(header->dst); + if (it != this->ipCtxMap.end()) { + it->second->sendFrame(ctx.buffer); + return; + } +} + +void WebSocketServer::updateSysRoute(WsCtx &ctx) { + ctx.buffer.resize(sizeof(WsMsg::SysRoute)); + WsMsg::SysRoute *header = (WsMsg::SysRoute *)ctx.buffer.data(); + memset(header, 0, sizeof(WsMsg::SysRoute)); + header->type = WsMsgKind::ROUTE; + + for (auto rt : this->routes) { + if ((rt.dev.Mask() & ctx.ip) == rt.dev.Host()) { + SysRouteEntry item; + item.dst = rt.dst.Net(); + item.mask = rt.dst.Mask(); + item.nexthop = rt.next; + ctx.buffer.append((char *)(&item), sizeof(item)); + header->size += 1; + } + // 100 条路由报文大小是 1204 字节,超过 100 条后分批发送 + if (header->size > 100) { + ctx.sendFrame(ctx.buffer); + ctx.buffer.resize(sizeof(WsMsg::SysRoute)); + header->size = 0; + } + } + + if (header->size > 0) { + ctx.sendFrame(ctx.buffer); + } +} + +int WebSocketServer::listen() { + try { + // 设置监听的地址和端口 + Poco::Net::ServerSocket socket(Poco::Net::SocketAddress(host, port)); + + // 设置最多同时可以处理的客户端数为局域网最大主机数 + Poco::Net::HTTPServerParams *params = new Poco::Net::HTTPServerParams(); + params->setMaxThreads(0x00FFFFFF); + + // 创建 HTTP 服务端并启动 + this->running = true; + WebSocketHandler wsHandler = [this](Poco::Net::WebSocket &ws) { handleWebsocket(ws); }; + this->httpServer = std::make_shared(new HTTPRequestHandlerFactory(wsHandler), socket, params); + this->httpServer->start(); + return 0; + } catch (std::exception &e) { + spdlog::critical("listen failed: {}", e.what()); + return -1; + } +} + +void WebSocketServer::handleWebsocket(Poco::Net::WebSocket &ws) { + ws.setReceiveTimeout(Poco::Timespan(1, 0)); + WsCtx ctx = {.ws = &ws}; + + int flags = 0; + int length = 0; + std::string buffer; + while (this->running && ctx.status == 0) { + try { + buffer.resize(1500); + length = ws.receiveFrame(buffer.data(), buffer.size(), flags); + int frameOp = flags & Poco::Net::WebSocket::FRAME_OP_BITMASK; + + // 响应 Ping 报文 + if (frameOp == Poco::Net::WebSocket::FRAME_OP_PING) { + flags = (int)Poco::Net::WebSocket::FRAME_FLAG_FIN | (int)Poco::Net::WebSocket::FRAME_OP_PONG; + ws.sendFrame(buffer.data(), buffer.size(), flags); + continue; + } + + // 客户端主动关闭连接 + if ((length == 0 && flags == 0) || frameOp == Poco::Net::WebSocket::FRAME_OP_CLOSE) { + break; + } + + if (frameOp == Poco::Net::WebSocket::FRAME_OP_BINARY && length > 0) { + // 调整 buffer 为真实大小并移动到 ctx + buffer.resize(length); + ctx.buffer = std::move(buffer); + + // 处理客户端请求 + handleMsg(ctx); + + // 重新初始化 buffer + buffer = std::string(); + } + } catch (Poco::TimeoutException const &e) { + // 超时异常,不做处理 + continue; + } catch (std::exception &e) { + // 未知异常,退出这个客户端 + spdlog::debug("handle websocket failed: {}", e.what()); + break; + } + } + { - std::lock_guard lock(this->mutex); - this->queue.push(msg); + std::unique_lock lock(ipCtxMutex); + auto it = ipCtxMap.find(ctx.ip); + if (it != ipCtxMap.end() && it->second == &ctx) { + ipCtxMap.erase(it); + spdlog::info("disconnect: {}", ctx.ip.toString()); + } } - this->condition.notify_all(); } } // namespace Candy diff --git a/src/websocket/server.h b/src/websocket/server.h index 1c965b06..cbbf4bee 100644 --- a/src/websocket/server.h +++ b/src/websocket/server.h @@ -2,42 +2,79 @@ #ifndef CANDY_WEBSOCKET_SERVER_H #define CANDY_WEBSOCKET_SERVER_H -#include "websocket/common.h" +#include "core/net.h" #include -#include -#include +#include +#include #include -#include +#include #include namespace Candy { +struct WsCtx { + Poco::Net::WebSocket *ws; + + std::string buffer; + int status; + + IP4 ip; + std::string vmac; + + void sendFrame(const std::string &frame, int flags = Poco::Net::WebSocket::FRAME_BINARY); +}; + +struct SysRoute { + // 通过地址和掩码确定策略下发给哪些客户端 + Address dev; + // 系统路由策略中的地址掩码和下一跳 + Address dst; + IP4 next; +}; + class WebSocketServer { public: - // 开始监听和停止监听 - int listen(const std::string &ipStr, uint16_t port); - int stop(); + int setWebSocket(const std::string &uri); + int setPassword(const std::string &password); + int setDHCP(const std::string &cidr); + int setSdwan(const std::string &sdwan); + int run(); + int shutdown(); - // 设置读操作超时时间 - int setTimeout(int timeout); +private: + std::string host; + uint16_t port; + std::string password; + Address dhcp; + std::list routes; - // 阻塞的读写操作 - int read(WebSocketMessage &message); - int write(const WebSocketMessage &message); +private: + void handleMsg(WsCtx &ctx); + void handleAuthMsg(WsCtx &ctx); + void handleForwardMsg(WsCtx &ctx); + void handleExptTunMsg(WsCtx &ctx); + void handleUdp4ConnMsg(WsCtx &ctx); + void handleVMacMsg(WsCtx &ctx); + void handleDiscoveryMsg(WsCtx &ctx); + void HandleGeneralMsg(WsCtx &ctx); - // 关闭单个客户端连接 - int close(WebSocketConn conn); + // 更新客户端系统路由 + void updateSysRoute(WsCtx &ctx); - void push(const WebSocketMessage &msg); + // 保存 IP 到对应连接指针的映射 + std::unordered_map ipCtxMap; + // 操作 map 时需要加锁,以确保操作时指针有效 + std::shared_mutex ipCtxMutex; bool running; - int timeout; private: - std::mutex mutex; - std::condition_variable condition; - std::queue queue; - std::shared_ptr server; + // 开始监听,新的请求将调用 handleWebsocket + int listen(); + // 同步的处理每个客户独的请求,函数返回后连接将断开 + void handleWebsocket(Poco::Net::WebSocket &ws); + + std::shared_ptr httpServer; }; } // namespace Candy