Skip to content

Commit

Permalink
[#674] Support env var in ssh config (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
yohamta authored Sep 11, 2024
1 parent 87a33e6 commit 4065af3
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 8 deletions.
48 changes: 40 additions & 8 deletions internal/dag/executor/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"errors"
"fmt"
"io"
"net"
"os"
"reflect"
"strings"

"github.com/mitchellh/mapstructure"
Expand All @@ -37,15 +39,23 @@ type sshExec struct {
session *ssh.Session
}

type sshExecConfig struct {
type sshExecConfigDefinition struct {
User string
IP string
Port int
Port any
Key string
Password string
StrictHostKeyChecking bool
}

type sshExecConfig struct {
User string
IP string
Port string
Key string
Password string
}

// selectSSHAuthMethod selects the authentication method based on the configuration.
// If the key is provided, it will use the public key authentication method.
// Otherwise, it will use the password authentication method.
Expand All @@ -67,10 +77,21 @@ func selectSSHAuthMethod(cfg *sshExecConfig) (ssh.AuthMethod, error) {
return ssh.Password(cfg.Password), nil
}

// expandEnvHook is a mapstructure decode hook that expands environment variables in string fields
func expandEnvHook(f reflect.Type, t reflect.Type, data any) (any, error) {
if f.Kind() != reflect.String || t.Kind() != reflect.String {
return data, nil
}
return os.ExpandEnv(data.(string)), nil
}

func newSSHExec(_ context.Context, step dag.Step) (Executor, error) {
cfg := new(sshExecConfig)
def := new(sshExecConfigDefinition)
md, err := mapstructure.NewDecoder(
&mapstructure.DecoderConfig{Result: cfg},
&mapstructure.DecoderConfig{
Result: def,
DecodeHook: expandEnvHook,
},
)

if err != nil {
Expand All @@ -81,11 +102,22 @@ func newSSHExec(_ context.Context, step dag.Step) (Executor, error) {
return nil, err
}

if cfg.Port == 0 {
cfg.Port = 22
cfg := &sshExecConfig{
User: def.User,
IP: def.IP,
Key: def.Key,
Password: def.Password,
}

// Handle Port as either string or int
port := os.ExpandEnv(fmt.Sprintf("%v", def.Port))
if port == "" {
port = "22"
}
cfg.Port = port

if cfg.StrictHostKeyChecking {
// StrictHostKeyChecking is not supported yet.
if def.StrictHostKeyChecking {
return nil, errStrictHostKey
}

Expand Down Expand Up @@ -130,7 +162,7 @@ func (e *sshExec) Kill(_ os.Signal) error {
}

func (e *sshExec) Run() error {
addr := fmt.Sprintf("%s:%d", e.config.IP, e.config.Port)
addr := net.JoinHostPort(e.config.IP, e.config.Port)
conn, err := ssh.Dial("tcp", addr, e.sshConfig)
if err != nil {
return err
Expand Down
87 changes: 87 additions & 0 deletions internal/dag/executor/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (C) 2024 The Dagu Authors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package executor

import (
"context"
"os"
"testing"

"github.com/dagu-org/dagu/internal/dag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSSHExecutor(t *testing.T) {
t.Parallel()

t.Run("Basic", func(t *testing.T) {
step := dag.Step{
Name: "ssh-exec",
ExecutorConfig: dag.ExecutorConfig{
Type: "ssh",
Config: map[string]any{
"User": "testuser",
"IP": "testip",
"Port": 25,
"Password": "testpassword",
},
},
}
ctx := context.Background()
exec, err := newSSHExec(ctx, step)
require.NoError(t, err)

sshExec, ok := exec.(*sshExec)
require.True(t, ok)

assert.Equal(t, "testuser", sshExec.config.User)
assert.Equal(t, "testip", sshExec.config.IP)
assert.Equal(t, "25", sshExec.config.Port)
assert.Equal(t, "testpassword", sshExec.config.Password)
})

t.Run("ExpandEnv", func(t *testing.T) {
os.Setenv("TEST_SSH_EXEC_USER", "testuser")
os.Setenv("TEST_SSH_EXEC_IP", "testip")
os.Setenv("TEST_SSH_EXEC_PORT", "23")
os.Setenv("TEST_SSH_EXEC_PASSWORD", "testpassword")

step := dag.Step{
Name: "ssh-exec",
ExecutorConfig: dag.ExecutorConfig{
Type: "ssh",
Config: map[string]any{
"User": "${TEST_SSH_EXEC_USER}",
"IP": "${TEST_SSH_EXEC_IP}",
"Port": "${TEST_SSH_EXEC_PORT}",
"Password": "${TEST_SSH_EXEC_PASSWORD}",
},
},
}
ctx := context.Background()
exec, err := newSSHExec(ctx, step)
require.NoError(t, err)

sshExec, ok := exec.(*sshExec)
require.True(t, ok)

assert.Equal(t, "testuser", sshExec.config.User)
assert.Equal(t, "testip", sshExec.config.IP)
assert.Equal(t, "23", sshExec.config.Port)
assert.Equal(t, "testpassword", sshExec.config.Password)
})
}

0 comments on commit 4065af3

Please sign in to comment.