diff --git a/src/shell/command_utils.cpp b/src/shell/command_utils.cpp new file mode 100644 index 0000000000..733d01fde4 --- /dev/null +++ b/src/shell/command_utils.cpp @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "command_utils.h" + +#include "command_executor.h" + +bool validate_ip(shell_context *sc, + const std::string &ip_str, + dsn::rpc_address &target_address, + std::string &err_info) +{ + if (!target_address.from_string_ipv4(ip_str.c_str())) { + err_info = fmt::format("invalid ip:port={}, can't transform it into rpc_address", ip_str); + return false; + } + + std::map nodes; + auto error = sc->ddl_client->list_nodes(dsn::replication::node_status::NS_INVALID, nodes); + if (error != dsn::ERR_OK) { + err_info = fmt::format("list nodes failed, error={}", error.to_string()); + return false; + } + + for (const auto &node : nodes) { + if (target_address == node.first) { + return true; + } + } + + err_info = fmt::format("invalid ip:port={}, can't find it in the cluster", ip_str); + return false; +} diff --git a/src/shell/command_utils.h b/src/shell/command_utils.h index 2aa480b9f3..1775169d49 100644 --- a/src/shell/command_utils.h +++ b/src/shell/command_utils.h @@ -11,6 +11,11 @@ #include "shell/argh.h" #include +namespace dsn { +class rpc_address; +} +class shell_context; + inline bool validate_cmd(const argh::parser &cmd, const std::set ¶ms, const std::set &flags) @@ -42,6 +47,11 @@ inline bool validate_cmd(const argh::parser &cmd, return true; } +bool validate_ip(shell_context *sc, + const std::string &ip_str, + /*out*/ dsn::rpc_address &target_address, + /*out*/ std::string &err_info); + #define verify_logged(exp, ...) \ do { \ if (!(exp)) { \ diff --git a/src/shell/commands/detect_hotkey.cpp b/src/shell/commands/detect_hotkey.cpp index 900e70f7cc..d2ffed7417 100644 --- a/src/shell/commands/detect_hotkey.cpp +++ b/src/shell/commands/detect_hotkey.cpp @@ -85,13 +85,13 @@ bool detect_hotkey(command_executor *e, shell_context *sc, arguments args) } dsn::rpc_address target_address; + std::string err_info; std::string ip_str = cmd({"-d", "--address"}).str(); - if (!target_address.from_string_ipv4(ip_str.c_str())) { - fmt::print("invalid ip, error={}\n", ip_str); + if (!validate_ip(sc, ip_str, target_address, err_info)) { + fmt::print(stderr, "{}\n", err_info); return false; } - std::string err_info; std::string hotkey_action = cmd({"-c", "--hotkey_action"}).str(); std::string hotkey_type = cmd({"-t", "--hotkey_type"}).str(); dsn::replication::detect_hotkey_request req;