Skip to content

Commit

Permalink
feat: Add support for SSH key passphrase
Browse files Browse the repository at this point in the history
This commit adds support for handling password-protected SSH keys in both
interactive and CI environments. Key features include:

1. New SSHConfig class for managing SSH configuration
2. Support for loading SSH key passphrase from config or environment variables
3. Enhanced _start_ssh_agent method to handle protected keys
4. Added comprehensive test coverage for the new functionality

Signed-off-by: longhao <hal.long@outlook.com>
  • Loading branch information
loonghao committed Dec 27, 2024
1 parent 189c2d6 commit e51717d
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 36 deletions.
21 changes: 21 additions & 0 deletions persistent_ssh_agent/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Configuration management for SSH agent."""
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SSHConfig:
"""SSH configuration class."""
# Path to identity file
identity_file: Optional[str] = None
# Identity file content (for CI environments)
identity_content: Optional[str] = None
# Identity file passphrase
identity_passphrase: Optional[str] = None
# Additional SSH options
ssh_options: Dict[str, str] = None

def __post_init__(self):
"""Initialize default values."""
if self.ssh_options is None:
self.ssh_options = {}
153 changes: 117 additions & 36 deletions persistent_ssh_agent/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
import json
import logging
import os
from pathlib import Path
import re
import subprocess
import time
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from .config import SSHConfig


logger = logging.getLogger(__name__)

Expand All @@ -31,14 +34,20 @@ class PersistentSSHAgent:
authentication for various operations including Git.
"""

def __init__(self, expiration_time: int = 86400):
"""Initialize SSH manager."""
def __init__(self, expiration_time: int = 86400, config: Optional[SSHConfig] = None):
"""Initialize SSH manager.
Args:
expiration_time: Time in seconds before agent info expires
config: Optional SSH configuration
"""
self._ensure_home_env()
self._ssh_dir = Path.home() / ".ssh"
self._agent_info_file = self._ssh_dir / "agent_info.json"
self._ssh_config_cache: Dict[str, Dict[str, Any]] = {}
self._ssh_agent_started = False
self._expiration_time = expiration_time
self._config = config

def _ensure_home_env(self) -> None:
"""Ensure HOME environment variable is set correctly.
Expand Down Expand Up @@ -153,6 +162,7 @@ def _start_ssh_agent(self, identity_file: str) -> bool:
if self._load_agent_info():
logger.debug("Reused existing SSH agent")
return True

# Kill any existing SSH agents
if os.name == "nt":
self._run_command(["taskkill", "/F", "/IM", "ssh-agent.exe"])
Expand Down Expand Up @@ -182,10 +192,25 @@ def _start_ssh_agent(self, identity_file: str) -> bool:
if auth_sock and agent_pid:
self._save_agent_info(auth_sock, agent_pid)

# Get passphrase from config or environment
passphrase = None
if self._config and self._config.identity_passphrase:
passphrase = self._config.identity_passphrase
elif "SSH_KEY_PASSPHRASE" in os.environ:
passphrase = os.environ["SSH_KEY_PASSPHRASE"]

# Add the key
result = self._run_command(
["ssh-add", identity_file],
)
if passphrase:
# Use sshpass for keys with passphrase
result = self._run_command(
["ssh-add", identity_file],
input=passphrase.encode()
)
else:
# Try without passphrase
result = self._run_command(
["ssh-add", identity_file],
)

if result is None or result.returncode != 0:
logger.error("Failed to add key: %s", result.stderr if result else "Command failed")
Expand Down Expand Up @@ -292,11 +317,12 @@ def setup_ssh(self, hostname: str) -> bool:
return False

def _run_command(
self,
cmd: Union[str, List[str]],
check_output: bool = True,
shell: bool = False,
env: Optional[Dict[str, str]] = None
self,
cmd: Union[str, List[str]],
check_output: bool = True,
shell: bool = False,
env: Optional[Dict[str, str]] = None,
input: Optional[bytes] = None
) -> Optional[subprocess.CompletedProcess]:
"""Run a command and handle its output.
Expand All @@ -305,6 +331,7 @@ def _run_command(
check_output: Whether to check command output
shell: Whether to run command through shell
env: Environment variables to use
input: Input to pass to the command
Returns:
CompletedProcess instance or None if command failed
Expand All @@ -314,22 +341,37 @@ def _run_command(
If check_output is False, command result is returned regardless of exit code
"""
try:
logger.debug("Running command: %s", cmd)
if isinstance(cmd, str):
cmd = cmd.split()

logger.debug("Running command: %s", " ".join(cmd))

# Merge current environment with provided environment
merged_env = os.environ.copy()
if env:
merged_env.update(env)

result = subprocess.run(
cmd,
capture_output=True,
text=True,
env= env or os.environ.copy(),
shell=shell
shell=shell,
env=merged_env,
input=input
)

if check_output and result.returncode != 0:
logger.debug("Command failed with code %d", result.returncode)
logger.debug("stdout: %s", result.stdout)
logger.debug("stderr: %s", result.stderr)
logger.error(
"Command failed with exit code %d: %s",
result.returncode,
result.stderr.strip()
)
return None

return result

except Exception as e:
logger.error("Command execution failed: %s", e)
logger.error("Failed to run command: %s", e)
return None

def _get_identity_file(self, hostname: str) -> Optional[str]:
Expand All @@ -343,32 +385,71 @@ def _get_identity_file(self, hostname: str) -> Optional[str]:
Note:
The search order is:
1. Exact match in SSH config
2. Pattern match in SSH config (e.g. *.example.com)
3. Default key files (id_ed25519, id_rsa)
1. Config provided identity file/content
2. Environment variable SSH_IDENTITY_FILE
3. Environment variable SSH_IDENTITY_CONTENT
4. Exact match in SSH config
5. Pattern match in SSH config (e.g. *.example.com)
6. Default key files (id_ed25519, id_rsa)
"""
# Check config first
if self._config and (self._config.identity_file or self._config.identity_content):
if self._config.identity_file and os.path.exists(self._config.identity_file):
return self._config.identity_file
elif self._config.identity_content:
return self._write_temp_key(self._config.identity_content)

# Check environment variables
env_file = os.environ.get("SSH_IDENTITY_FILE")
if env_file and os.path.exists(env_file):
return env_file

env_content = os.environ.get("SSH_IDENTITY_CONTENT")
if env_content:
return self._write_temp_key(env_content)

# Check SSH config
config = self._parse_ssh_config()

# Try exact hostname match first
if hostname in config and "identityfile" in config[hostname]:
identity_file = os.path.expanduser(config[hostname]["identityfile"])
return str(Path(identity_file))

# Try exact hostname match
if hostname in config and "IdentityFile" in config[hostname]:
identity_file = os.path.expanduser(config[hostname]["IdentityFile"])
if os.path.exists(identity_file):
return identity_file

# Try pattern matching
for pattern, settings in config.items():
if "*" in pattern and "identityfile" in settings and fnmatch.fnmatch(hostname, pattern):
# Convert SSH config pattern to fnmatch pattern
# SSH uses * and ? as wildcards, which is compatible with fnmatch
identity_file = os.path.expanduser(settings["identityfile"])
return str(Path(identity_file))
if "IdentityFile" in settings and fnmatch.fnmatch(hostname, pattern):
identity_file = os.path.expanduser(settings["IdentityFile"])
if os.path.exists(identity_file):
return identity_file

# Default to standard key files if no match in config
# Try default key files
for key_name in ["id_ed25519", "id_rsa"]:
key_path = self._ssh_dir / key_name
if key_path.exists():
return str(key_path)
identity_file = os.path.join(str(self._ssh_dir), key_name)
if os.path.exists(identity_file):
return identity_file

return None

return str(self._ssh_dir / "id_rsa") # Return default path even if it doesn't exist
def _write_temp_key(self, key_content: str) -> str:
"""Write key content to a temporary file.
Args:
key_content: SSH key content to write
Returns:
str: Path to temporary key file
"""
try:
with NamedTemporaryFile(mode='w', delete=False) as temp_file:
temp_file.write(key_content)
# Set correct permissions
os.chmod(temp_file.name, 0o600)
return temp_file.name
except Exception as e:
logger.error("Failed to write temporary key file: %s", e)
return None

def _parse_ssh_config(self) -> Dict[str, Dict[str, str]]:
"""Parse SSH config file to get host-specific configurations.
Expand Down
125 changes: 125 additions & 0 deletions tests/test_ssh_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Test SSH configuration functionality."""
import os
import tempfile
from pathlib import Path

import pytest

from persistent_ssh_agent.config import SSHConfig
from persistent_ssh_agent.core import PersistentSSHAgent


@pytest.fixture
def ssh_key_content():
"""Sample SSH key content for testing."""
return """-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
-----END OPENSSH PRIVATE KEY-----"""


@pytest.fixture
def protected_ssh_key_content():
"""Sample password-protected SSH key content for testing."""
return """-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABD+UwXz5w
-----END OPENSSH PRIVATE KEY-----"""


def test_config_identity_file():
"""Test loading identity file from config."""
with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file:
temp_file.write("test key content")
temp_file.flush()

config = SSHConfig(identity_file=temp_file.name)
agent = PersistentSSHAgent(config=config)

identity_file = agent._get_identity_file("github.com")
assert identity_file == temp_file.name

os.unlink(temp_file.name)


def test_config_identity_content(ssh_key_content):
"""Test loading identity content from config."""
config = SSHConfig(identity_content=ssh_key_content)
agent = PersistentSSHAgent(config=config)

identity_file = agent._get_identity_file("github.com")
assert identity_file is not None
assert os.path.exists(identity_file)

with open(identity_file, 'r') as f:
content = f.read()
assert content == ssh_key_content

os.unlink(identity_file)


def test_config_with_passphrase(protected_ssh_key_content):
"""Test loading identity with passphrase from config."""
config = SSHConfig(
identity_content=protected_ssh_key_content,
identity_passphrase="testpass"
)
agent = PersistentSSHAgent(config=config)

identity_file = agent._get_identity_file("github.com")
assert identity_file is not None
assert os.path.exists(identity_file)

# Start SSH agent with the protected key
assert agent._start_ssh_agent(identity_file)

os.unlink(identity_file)


def test_env_identity_file():
"""Test loading identity file from environment variable."""
with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file:
temp_file.write("test key content")
temp_file.flush()

os.environ["SSH_IDENTITY_FILE"] = temp_file.name
agent = PersistentSSHAgent()

identity_file = agent._get_identity_file("github.com")
assert identity_file == temp_file.name

del os.environ["SSH_IDENTITY_FILE"]
os.unlink(temp_file.name)


def test_env_identity_content(ssh_key_content):
"""Test loading identity content from environment variable."""
os.environ["SSH_IDENTITY_CONTENT"] = ssh_key_content
agent = PersistentSSHAgent()

identity_file = agent._get_identity_file("github.com")
assert identity_file is not None
assert os.path.exists(identity_file)

with open(identity_file, 'r') as f:
content = f.read()
assert content == ssh_key_content

del os.environ["SSH_IDENTITY_CONTENT"]
os.unlink(identity_file)


def test_env_with_passphrase(protected_ssh_key_content):
"""Test loading identity with passphrase from environment."""
os.environ["SSH_IDENTITY_CONTENT"] = protected_ssh_key_content
os.environ["SSH_KEY_PASSPHRASE"] = "testpass"
agent = PersistentSSHAgent()

identity_file = agent._get_identity_file("github.com")
assert identity_file is not None
assert os.path.exists(identity_file)

# Start SSH agent with the protected key
assert agent._start_ssh_agent(identity_file)

del os.environ["SSH_IDENTITY_CONTENT"]
del os.environ["SSH_KEY_PASSPHRASE"]
os.unlink(identity_file)

0 comments on commit e51717d

Please sign in to comment.