Skip to content

Commit

Permalink
[Coll] Implement get host address in libxgboost.
Browse files Browse the repository at this point in the history
- Port `xgboost.tracker.get_host_ip` in C++.
  • Loading branch information
trivialfis committed Oct 9, 2023
1 parent db8d117 commit e6e30ad
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 31 deletions.
28 changes: 28 additions & 0 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,34 @@ class TCPSocket {
* @brief Get the local host name.
*/
[[nodiscard]] Result GetHostName(std::string *p_out);

/**
* @brief inet_ntop
*/
template <typename H>
Result INetNToP(H const &host, std::string *p_out) {
std::string &ip = *p_out;
switch (host->h_addrtype) {
case AF_INET: {
auto addr = reinterpret_cast<struct in_addr *>(host->h_addr_list[0]);
char str[INET_ADDRSTRLEN];
inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
ip = str;
break;
}
case AF_INET6: {
auto addr = reinterpret_cast<struct in6_addr *>(host->h_addr_list[0]);
char str[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
ip = str;
break;
}
default: {
return Fail("Invalid address type.");
}
}
return Success();
}
} // namespace collective
} // namespace xgboost

Expand Down
76 changes: 76 additions & 0 deletions src/collective/tracker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // gethostbyname
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
#endif // defined(__unix__) || defined(__APPLE__)

#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
#endif // !defined(NOMINMAX)

#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#endif // defined(_WIN32)

#include <string> // for string

#include "xgboost/collective/result.h" // for Result, Fail, Success
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...

namespace xgboost::collective {
[[nodiscard]] Result GetHostAddress(std::string* out) {
auto rc = GetHostName(out);
if (!rc.OK()) {
return rc;
}
auto host = gethostbyname(out->c_str());

// get ip address from host
std::string ip;
rc = INetNToP(host, &ip);
if (!rc.OK()) {
return rc;
}

if (!(ip.size() >= 4 && ip.substr(0, 4) == "127.")) {
// return if this is a public IP address.
// not entirely accurate, we have other reserved IPs
*out = ip;
return Success();
}

// Create an UDP socket to prob the public IP address, it's fine even if it's
// unreachable.
auto sock = socket(AF_INET, SOCK_DGRAM, 0);
if (sock == -1) {
return Fail("Failed to create socket.");
}

auto paddr = MakeSockAddress(StringView{"10.255.255.255"}, 1);
sockaddr const* addr_handle = reinterpret_cast<const sockaddr*>(&paddr.V4().Handle());
socklen_t addr_len{sizeof(paddr.V4().Handle())};
auto err = connect(sock, addr_handle, addr_len);
if (err != 0) {
return system::FailWithCode("Failed to find IP address.");
}

// get the IP address from socket desrciptor
struct sockaddr_in addr;
socklen_t len = sizeof(addr);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &len) == -1) {
return Fail("Failed to get sock name.");
}
ip = inet_ntoa(addr.sin_addr);

err = system::CloseSocket(sock);
if (err != 0) {
return system::FailWithCode("Failed to close socket.");
}

*out = ip;
return Success();
}
} // namespace xgboost::collective
15 changes: 15 additions & 0 deletions src/collective/tracker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <string> // for string

#include "xgboost/collective/result.h" // for Result

namespace xgboost::collective {
// Prob the public IP address of the host, need a better method.
//
// This is directly translated from the previous Python implementation, we should find a
// more riguous approach, can use some expertise in network programming.
[[nodiscard]] Result GetHostAddress(std::string* out);
} // namespace xgboost::collective
41 changes: 41 additions & 0 deletions tests/cpp/collective/net_test.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* Copyright 2022-2023, XGBoost Contributors
*/
#pragma once

#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>

#include <fstream> // ifstream

#include "../helpers.h" // for FileExists

namespace xgboost::collective {
class SocketTest : public ::testing::Test {
protected:
std::string skip_msg_{"Skipping IPv6 test"};

bool SkipTest() {
std::string path{"/sys/module/ipv6/parameters/disable"};
if (FileExists(path)) {
std::ifstream fin(path);
if (!fin) {
return true;
}
std::string s_value;
fin >> s_value;
auto value = std::stoi(s_value);
if (value != 0) {
return true;
}
} else {
return true;
}
return false;
}

protected:
void SetUp() override { system::SocketStartup(); }
void TearDown() override { system::SocketFinalize(); }
};
} // namespace xgboost::collective
52 changes: 21 additions & 31 deletions tests/cpp/collective/test_socket.cc
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>

#include <cerrno> // EADDRNOTAVAIL
#include <fstream> // ifstream
#include <system_error> // std::error_code, std::system_category

#include "../helpers.h"
#include "net_test.h" // for SocketTest

namespace xgboost::collective {
TEST(Socket, Basic) {
system::SocketStartup();

TEST_F(SocketTest, Basic) {
SockAddress addr{SockAddrV6::Loopback()};
ASSERT_TRUE(addr.IsV6());
addr = SockAddress{SockAddrV4::Loopback()};
Expand Down Expand Up @@ -54,34 +51,27 @@ TEST(Socket, Basic) {

run_test(SockDomain::kV4);

std::string path{"/sys/module/ipv6/parameters/disable"};
if (FileExists(path)) {
std::ifstream fin(path);
if (!fin) {
GTEST_SKIP_(msg.c_str());
}
std::string s_value;
fin >> s_value;
auto value = std::stoi(s_value);
if (value != 0) {
GTEST_SKIP_(msg.c_str());
}
} else {
GTEST_SKIP_(msg.c_str());
if (SkipTest()) {
GTEST_SKIP_(skip_msg_.c_str());
}
run_test(SockDomain::kV6);

system::SocketFinalize();
}

TEST(Socket, Bind) {
system::SocketStartup();
auto any = SockAddrV4::InaddrAny().Addr();
auto sock = TCPSocket::Create(SockDomain::kV4);
std::int32_t port{0};
auto rc = sock.Bind(any, &port);
ASSERT_TRUE(rc.OK());
ASSERT_NE(port, 0);
system::SocketFinalize();
TEST_F(SocketTest, Bind) {
auto run = [](SockDomain domain) {
auto any =
domain == SockDomain::kV4 ? SockAddrV4::InaddrAny().Addr() : SockAddrV6::InaddrAny().Addr();
auto sock = TCPSocket::Create(domain);
std::int32_t port{0};
auto rc = sock.Bind(any, &port);
ASSERT_TRUE(rc.OK());
ASSERT_NE(port, 0);
};

run(SockDomain::kV4);
if (SkipTest()) {
GTEST_SKIP_(skip_msg_.c_str());
}
run(SockDomain::kV6);
}
} // namespace xgboost::collective
18 changes: 18 additions & 0 deletions tests/cpp/collective/test_tracker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "../../../src/collective/tracker.h" // for GetHostAddress
#include "net_test.h" // for SocketTest

namespace xgboost::collective {
namespace {
class TrackerTest : public SocketTest {};
} // namespace

TEST_F(TrackerTest, GetHostAddress) {
std::string host;
auto rc = GetHostAddress(&host);
ASSERT_TRUE(rc.OK());
ASSERT_TRUE(host.find("127.") == std::string::npos);
}
} // namespace xgboost::collective

0 comments on commit e6e30ad

Please sign in to comment.