Skip to content

Commit

Permalink
support ssh protocol config choose if reuse connection (#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsun28 authored Jul 28, 2023
1 parent 0b47919 commit 6102d02
Show file tree
Hide file tree
Showing 22 changed files with 369 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ private void cleanTimeoutCache() {
* @param timeDiff 缓存对象保存时间 millis
*/
public void addCache(Object key, Object value, Long timeDiff) {
removeCache(key);
if (timeDiff == null) {
timeDiff = DEFAULT_CACHE_TIMEOUT;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,25 @@
@Slf4j
public class CommonSshClient {

private static SshClient sshClient;
private static final SshClient SSH_CLIENT;


static {
sshClient = SshClient.setUpDefaultClient();
SSH_CLIENT = SshClient.setUpDefaultClient();
// 接受所有服务端公钥校验,会打印warn日志 Server at {} presented unverified {} key: {}
AcceptAllServerKeyVerifier verifier = AcceptAllServerKeyVerifier.INSTANCE;
sshClient.setServerKeyVerifier(verifier);
// 设置链接保活心跳2000毫秒一次, 客户端等待保活心跳响应超时时间300_0000毫秒
SSH_CLIENT.setServerKeyVerifier(verifier);
// 设置链接保活心跳2000毫秒一次, 客户端等待保活心跳响应超时时间300_000毫秒
PropertyResolverUtils.updateProperty(
sshClient, CoreModuleProperties.HEARTBEAT_INTERVAL.getName(), 2000);
SSH_CLIENT, CoreModuleProperties.HEARTBEAT_INTERVAL.getName(), 2000);
PropertyResolverUtils.updateProperty(
sshClient, CoreModuleProperties.HEARTBEAT_REPLY_WAIT.getName(), 300_000);
sshClient.start();
SSH_CLIENT, CoreModuleProperties.HEARTBEAT_REPLY_WAIT.getName(), 300_000);
PropertyResolverUtils.updateProperty(
SSH_CLIENT, CoreModuleProperties.SOCKET_KEEPALIVE.getName(), true);
SSH_CLIENT.start();
}

public static SshClient getSshClient() {
return sshClient;
return SSH_CLIENT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.dromara.hertzbeat.collector.collect.mongodb;

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Optional;

Expand Down Expand Up @@ -47,20 +47,18 @@
* Mongodb 单机指标收集器
*
* @author <a href="mailto:liudonghua123@gmail.com">liudonghua</a>
* @version 1.0
* Created by liudonghua on 2023/01/01
* see also https://www.mongodb.com/languages/java,
* https://www.mongodb.com/docs/manual/reference/command/serverStatus/#metrics
* see also https://www.mongodb.com/languages/java,
* https://www.mongodb.com/docs/manual/reference/command/serverStatus/#metrics
*/
@Slf4j
public class MongodbSingleCollectImpl extends AbstractCollect {

/**
* 支持的 mongodb diagnostic 命令,排除internal/deprecated相关的命令
* 可参考 https://www.mongodb.com/docs/manual/reference/command/nav-diagnostic/,
* https://www.mongodb.com/docs/mongodb-shell/run-commands/
* 可参考 <a href="https://www.mongodb.com/docs/manual/reference/command/nav-diagnostic/">...</a>,
* <a href="https://www.mongodb.com/docs/mongodb-shell/run-commands/">...</a>
* 注意:一些命令需要相应的权限才能执行,否则执行虽然不会报错,但是返回的结果是空的,
* 详见 https://www.mongodb.com/docs/manual/reference/built-in-roles/
* 详见 <a href="https://www.mongodb.com/docs/manual/reference/built-in-roles/">...</a>
*/
private static final String[] SUPPORTED_MONGODB_DIAGNOSTIC_COMMANDS = {
"buildInfo",
Expand Down Expand Up @@ -199,14 +197,10 @@ private MongoClient getClient(Metrics metrics) {
}
// 复用失败则新建连接 connect to mongodb
String url;
try {
// 密码可能包含特殊字符,需要使用类似js的encodeURIComponent进行编码,这里使用java的URLEncoder
url = String.format("mongodb://%s:%s@%s:%s/%s?authSource=%s", mongodbProtocol.getUsername(),
URLEncoder.encode(mongodbProtocol.getPassword(), "UTF-8"), mongodbProtocol.getHost(), mongodbProtocol.getPort(),
mongodbProtocol.getDatabase(), mongodbProtocol.getAuthenticationDatabase());
} catch (UnsupportedEncodingException e) {
throw new RuntimeException(e);
}
// 密码可能包含特殊字符,需要使用类似js的encodeURIComponent进行编码,这里使用java的URLEncoder
url = String.format("mongodb://%s:%s@%s:%s/%s?authSource=%s", mongodbProtocol.getUsername(),
URLEncoder.encode(mongodbProtocol.getPassword(), StandardCharsets.UTF_8), mongodbProtocol.getHost(), mongodbProtocol.getPort(),
mongodbProtocol.getDatabase(), mongodbProtocol.getAuthenticationDatabase());
mongoClient = MongoClients.create(url);
MongodbConnect mongodbConnect = new MongodbConnect(mongoClient);
CommonCache.getInstance().addCache(identifier, mongodbConnect);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.dromara.hertzbeat.collector.collect.ssh;

import org.apache.sshd.common.SshException;
import org.apache.sshd.common.channel.exception.SshChannelOpenException;
import org.apache.sshd.common.util.io.output.NoCloseOutputStream;
import org.apache.sshd.common.util.security.SecurityUtils;
import org.dromara.hertzbeat.collector.collect.AbstractCollect;
import org.dromara.hertzbeat.collector.collect.common.cache.CacheIdentifier;
Expand All @@ -43,9 +46,11 @@
import java.io.FileInputStream;
import java.io.IOException;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -83,24 +88,24 @@ public void collect(CollectRep.MetricsData.Builder builder, long appId, String a
return;
}
SshProtocol sshProtocol = metrics.getSsh();
boolean reuseConnection = Boolean.parseBoolean(sshProtocol.getReuseConnection());
int timeout = CollectUtil.getTimeout(sshProtocol.getTimeout(), DEFAULT_TIMEOUT);
ClientChannel channel = null;
ClientSession clientSession = null;
try {
ClientSession clientSession = getConnectSession(sshProtocol, timeout);
clientSession = getConnectSession(sshProtocol, timeout, reuseConnection);
channel = clientSession.createExecChannel(sshProtocol.getScript());
ByteArrayOutputStream response = new ByteArrayOutputStream();
channel.setOut(response);
if (!channel.open().verify(timeout).isOpened()) {
removeConnectSessionCache(sshProtocol);
channel.close();
clientSession.close();
throw new Exception("ssh channel open failed");
}
channel.setErr(new NoCloseOutputStream(System.err));
channel.open().verify(timeout);
List<ClientChannelEvent> list = new ArrayList<>();
list.add(ClientChannelEvent.CLOSED);
channel.waitFor(list, timeout);
Collection<ClientChannelEvent> waitEvents = channel.waitFor(list, timeout);
if (waitEvents.contains(ClientChannelEvent.TIMEOUT)) {
throw new SocketTimeoutException("Failed to retrieve command result in time: " + sshProtocol.getScript());
}
Long responseTime = System.currentTimeMillis() - startTime;
channel.close();
String result = response.toString();
if (!StringUtils.hasText(result)) {
builder.setCode(CollectRep.Code.FAIL);
Expand All @@ -126,11 +131,19 @@ public void collect(CollectRep.MetricsData.Builder builder, long appId, String a
log.info(errorMsg);
builder.setCode(CollectRep.Code.UN_CONNECTABLE);
builder.setMsg("The peer refused to connect: service port does not listening or firewall: " + errorMsg);
} catch (SshException sshException) {
Throwable throwable = sshException.getCause();
if (throwable instanceof SshChannelOpenException) {
log.warn("Remote ssh server no more session channel, please increase sshd_config MaxSessions.");
}
String errorMsg = CommonUtil.getMessageFromThrowable(sshException);
builder.setCode(CollectRep.Code.UN_CONNECTABLE);
builder.setMsg("Peer ssh connection failed: " + errorMsg);
} catch (IOException ioException) {
String errorMsg = CommonUtil.getMessageFromThrowable(ioException);
log.info(errorMsg);
builder.setCode(CollectRep.Code.UN_CONNECTABLE);
builder.setMsg("Peer connection failed: " + errorMsg);
builder.setMsg("Peer io connection failed: " + errorMsg);
} catch (Exception exception) {
String errorMsg = CommonUtil.getMessageFromThrowable(exception);
log.warn(errorMsg, exception);
Expand All @@ -144,6 +157,13 @@ public void collect(CollectRep.MetricsData.Builder builder, long appId, String a
log.error(e.getMessage(), e);
}
}
if (clientSession != null && !reuseConnection) {
try {
clientSession.close();
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
}
}

Expand Down Expand Up @@ -247,28 +267,31 @@ private void removeConnectSessionCache(SshProtocol sshProtocol) {
CommonCache.getInstance().removeCache(identifier);
}

private ClientSession getConnectSession(SshProtocol sshProtocol, int timeout) throws IOException, GeneralSecurityException {
private ClientSession getConnectSession(SshProtocol sshProtocol, int timeout, boolean reuseConnection)
throws IOException, GeneralSecurityException {
CacheIdentifier identifier = CacheIdentifier.builder()
.ip(sshProtocol.getHost()).port(sshProtocol.getPort())
.username(sshProtocol.getUsername()).password(sshProtocol.getPassword())
.build();
Optional<Object> cacheOption = CommonCache.getInstance().getCache(identifier, true);
.ip(sshProtocol.getHost()).port(sshProtocol.getPort())
.username(sshProtocol.getUsername()).password(sshProtocol.getPassword())
.build();
ClientSession clientSession = null;
if (cacheOption.isPresent()) {
clientSession = ((SshConnect) cacheOption.get()).getConnection();
try {
if (clientSession == null || clientSession.isClosed() || clientSession.isClosing()) {
if (reuseConnection) {
Optional<Object> cacheOption = CommonCache.getInstance().getCache(identifier, true);
if (cacheOption.isPresent()) {
clientSession = ((SshConnect) cacheOption.get()).getConnection();
try {
if (clientSession == null || clientSession.isClosed() || clientSession.isClosing()) {
clientSession = null;
CommonCache.getInstance().removeCache(identifier);
}
} catch (Exception e) {
log.warn(e.getMessage());
clientSession = null;
CommonCache.getInstance().removeCache(identifier);
}
} catch (Exception e) {
log.warn(e.getMessage());
clientSession = null;
CommonCache.getInstance().removeCache(identifier);
}
}
if (clientSession != null) {
return clientSession;
if (clientSession != null) {
return clientSession;
}
}
SshClient sshClient = CommonSshClient.getSshClient();
clientSession = sshClient.connect(sshProtocol.getUsername(), sshProtocol.getHost(), Integer.parseInt(sshProtocol.getPort()))
Expand All @@ -286,8 +309,10 @@ private ClientSession getConnectSession(SshProtocol sshProtocol, int timeout) th
clientSession.close();
throw new IllegalArgumentException("ssh auth failed.");
}
SshConnect sshConnect = new SshConnect(clientSession);
CommonCache.getInstance().addCache(identifier, sshConnect);
if (reuseConnection) {
SshConnect sshConnect = new SshConnect(clientSession);
CommonCache.getInstance().addCache(identifier, sshConnect);
}
return clientSession;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import org.dromara.hertzbeat.collector.collect.AbstractCollect;
import org.springframework.boot.CommandLineRunner;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;

import java.util.ServiceLoader;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -14,6 +16,7 @@
*
*/
@Configuration
@Order(value = Ordered.HIGHEST_PRECEDENCE)
public class CollectStrategyFactory implements CommandLineRunner {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ public class SshProtocol {
* 私钥(可选)
*/
private String privateKey;

/**
* reuse connection session
*/
private String reuseConnection = "true";

/**
* SSH执行脚本
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.dromara.hertzbeat.common.entity.job.Job;
import org.dromara.hertzbeat.common.entity.manager.Monitor;
import org.dromara.hertzbeat.common.entity.manager.Param;
import org.dromara.hertzbeat.common.entity.manager.ParamDefine;
import org.dromara.hertzbeat.common.util.JsonUtil;
import org.dromara.hertzbeat.manager.dao.MonitorDao;
import org.dromara.hertzbeat.manager.dao.ParamDao;
Expand All @@ -30,6 +31,7 @@
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -65,8 +67,6 @@ public void run(String... args) throws Exception {
try {
// 构造采集任务Job实体
Job appDefine = appService.getAppDefine(monitor.getApp());
// todo 这里暂时是深拷贝处理
appDefine = JsonUtil.fromJson(JsonUtil.toJson(appDefine), Job.class);
appDefine.setId(monitor.getJobId());
appDefine.setMonitorId(monitor.getId());
appDefine.setInterval(monitor.getIntervals());
Expand All @@ -75,6 +75,16 @@ public void run(String... args) throws Exception {
List<Param> params = paramDao.findParamsByMonitorId(monitor.getId());
List<Configmap> configmaps = params.stream().map(param ->
new Configmap(param.getField(), param.getValue(), param.getType())).collect(Collectors.toList());
List<ParamDefine> paramDefaultValue = appDefine.getParams().stream()
.filter(item -> StringUtils.hasText(item.getDefaultValue()))
.collect(Collectors.toList());
paramDefaultValue.forEach(defaultVar -> {
if (configmaps.stream().noneMatch(item -> item.getKey().equals(defaultVar.getField()))) {
// todo type
Configmap configmap = new Configmap(defaultVar.getField(), defaultVar.getDefaultValue(), (byte) 1);
configmaps.add(configmap);
}
});
appDefine.setConfigmap(configmaps);
// 下发采集任务
long jobId = collectJobService.addAsyncCollectJob(appDefine);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,16 @@ public void enableManageMonitors(HashSet<Long> ids) {
List<Param> params = paramDao.findParamsByMonitorId(monitor.getId());
List<Configmap> configmaps = params.stream().map(param ->
new Configmap(param.getField(), param.getValue(), param.getType())).collect(Collectors.toList());
List<ParamDefine> paramDefaultValue = appDefine.getParams().stream()
.filter(item -> StringUtils.hasText(item.getDefaultValue()))
.collect(Collectors.toList());
paramDefaultValue.forEach(defaultVar -> {
if (configmaps.stream().noneMatch(item -> item.getKey().equals(defaultVar.getField()))) {
// todo type
Configmap configmap = new Configmap(defaultVar.getField(), defaultVar.getDefaultValue(), (byte) 1);
configmaps.add(configmap);
}
});
appDefine.setConfigmap(configmaps);
// Issue collection tasks 下发采集任务
long newJobId = collectJobService.addAsyncCollectJob(appDefine);
Expand Down
Loading

0 comments on commit 6102d02

Please sign in to comment.