-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Coll] Implement get host address in libxgboost.
- Port `xgboost.tracker.get_host_ip` in C++.
- Loading branch information
1 parent
db8d117
commit e6e30ad
Showing
6 changed files
with
199 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#if defined(__unix__) || defined(__APPLE__) | ||
#include <netdb.h> // gethostbyname | ||
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname | ||
#endif // defined(__unix__) || defined(__APPLE__) | ||
|
||
#if !defined(NOMINMAX) && defined(_WIN32) | ||
#define NOMINMAX | ||
#endif // !defined(NOMINMAX) | ||
|
||
#if defined(_WIN32) | ||
#include <winsock2.h> | ||
#include <ws2tcpip.h> | ||
#endif // defined(_WIN32) | ||
|
||
#include <string> // for string | ||
|
||
#include "xgboost/collective/result.h" // for Result, Fail, Success | ||
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ... | ||
|
||
namespace xgboost::collective { | ||
[[nodiscard]] Result GetHostAddress(std::string* out) { | ||
auto rc = GetHostName(out); | ||
if (!rc.OK()) { | ||
return rc; | ||
} | ||
auto host = gethostbyname(out->c_str()); | ||
|
||
// get ip address from host | ||
std::string ip; | ||
rc = INetNToP(host, &ip); | ||
if (!rc.OK()) { | ||
return rc; | ||
} | ||
|
||
if (!(ip.size() >= 4 && ip.substr(0, 4) == "127.")) { | ||
// return if this is a public IP address. | ||
// not entirely accurate, we have other reserved IPs | ||
*out = ip; | ||
return Success(); | ||
} | ||
|
||
// Create an UDP socket to prob the public IP address, it's fine even if it's | ||
// unreachable. | ||
auto sock = socket(AF_INET, SOCK_DGRAM, 0); | ||
if (sock == -1) { | ||
return Fail("Failed to create socket."); | ||
} | ||
|
||
auto paddr = MakeSockAddress(StringView{"10.255.255.255"}, 1); | ||
sockaddr const* addr_handle = reinterpret_cast<const sockaddr*>(&paddr.V4().Handle()); | ||
socklen_t addr_len{sizeof(paddr.V4().Handle())}; | ||
auto err = connect(sock, addr_handle, addr_len); | ||
if (err != 0) { | ||
return system::FailWithCode("Failed to find IP address."); | ||
} | ||
|
||
// get the IP address from socket desrciptor | ||
struct sockaddr_in addr; | ||
socklen_t len = sizeof(addr); | ||
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &len) == -1) { | ||
return Fail("Failed to get sock name."); | ||
} | ||
ip = inet_ntoa(addr.sin_addr); | ||
|
||
err = system::CloseSocket(sock); | ||
if (err != 0) { | ||
return system::FailWithCode("Failed to close socket."); | ||
} | ||
|
||
*out = ip; | ||
return Success(); | ||
} | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#pragma once | ||
#include <string> // for string | ||
|
||
#include "xgboost/collective/result.h" // for Result | ||
|
||
namespace xgboost::collective { | ||
// Prob the public IP address of the host, need a better method. | ||
// | ||
// This is directly translated from the previous Python implementation, we should find a | ||
// more riguous approach, can use some expertise in network programming. | ||
[[nodiscard]] Result GetHostAddress(std::string* out); | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/** | ||
* Copyright 2022-2023, XGBoost Contributors | ||
*/ | ||
#pragma once | ||
|
||
#include <gtest/gtest.h> | ||
#include <xgboost/collective/socket.h> | ||
|
||
#include <fstream> // ifstream | ||
|
||
#include "../helpers.h" // for FileExists | ||
|
||
namespace xgboost::collective { | ||
class SocketTest : public ::testing::Test { | ||
protected: | ||
std::string skip_msg_{"Skipping IPv6 test"}; | ||
|
||
bool SkipTest() { | ||
std::string path{"/sys/module/ipv6/parameters/disable"}; | ||
if (FileExists(path)) { | ||
std::ifstream fin(path); | ||
if (!fin) { | ||
return true; | ||
} | ||
std::string s_value; | ||
fin >> s_value; | ||
auto value = std::stoi(s_value); | ||
if (value != 0) { | ||
return true; | ||
} | ||
} else { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
protected: | ||
void SetUp() override { system::SocketStartup(); } | ||
void TearDown() override { system::SocketFinalize(); } | ||
}; | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#include "../../../src/collective/tracker.h" // for GetHostAddress | ||
#include "net_test.h" // for SocketTest | ||
|
||
namespace xgboost::collective { | ||
namespace { | ||
class TrackerTest : public SocketTest {}; | ||
} // namespace | ||
|
||
TEST_F(TrackerTest, GetHostAddress) { | ||
std::string host; | ||
auto rc = GetHostAddress(&host); | ||
ASSERT_TRUE(rc.OK()); | ||
ASSERT_TRUE(host.find("127.") == std::string::npos); | ||
} | ||
} // namespace xgboost::collective |