Skip to content

Commit

Permalink
refactor(core.py): Improve SSH config parsing and error handling
Browse files Browse the repository at this point in the history
Signed-off-by: longhao <hal.long@outlook.com>
  • Loading branch information
loonghao committed Dec 28, 2024
1 parent 0345ada commit 7f921fe
Showing 1 changed file with 90 additions and 58 deletions.
148 changes: 90 additions & 58 deletions persistent_ssh_agent/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,12 @@ def _start_ssh_agent(self, identity_file: str) -> bool:
logger.debug("Key already loaded: %s", identity_file)
return True

# Start SSH agent based on platform
command = ["ssh-agent", "-s"] if os.name == "nt" else ["ssh-agent"]
result = subprocess.run(command, capture_output=True, text=True, check=False) if os.name == "nt" else self.run_command(command)
# Start SSH agent with platform-specific command
command = ["ssh-agent"]
if os.name == "nt":
command.append("-s")

result = self.run_command(command)
if not result or result.returncode != 0:
logger.error("Failed to start SSH agent")
return False
Expand Down Expand Up @@ -592,15 +594,15 @@ def _get_available_keys(self) -> List[str]:
available_keys = set() # Use set to avoid duplicates
for key_type in self.SSH_KEY_TYPES:
# Check for base key type (e.g., id_rsa)
key_path = os.path.join(self._ssh_dir, key_type)
pub_key_path = f"{key_path}.pub"
key_path = os.path.join(str(self._ssh_dir), key_type)
pub_key_path = key_path + ".pub"
if os.path.exists(key_path) and os.path.exists(pub_key_path):
available_keys.add(str(Path(key_path)).replace("\\", "/"))

# Check for keys with numeric suffixes (e.g., id_rsa2)
pattern = os.path.join(self._ssh_dir, f"{key_type}[0-9]*")
pattern = os.path.join(str(self._ssh_dir), f"{key_type}[0-9]*")
for numbered_key_path in glob.glob(pattern):
pub_key_path = f"{numbered_key_path}.pub"
pub_key_path = numbered_key_path + ".pub"
if os.path.exists(numbered_key_path) and os.path.exists(pub_key_path):
available_keys.add(str(Path(numbered_key_path)).replace("\\", "/"))

Expand Down Expand Up @@ -633,53 +635,77 @@ def _get_identity_file(self, hostname: str) -> Optional[str]:
return str(Path(os.path.join(self._ssh_dir, "id_rsa")))

def _parse_ssh_config(self) -> Dict[str, Dict[str, str]]:
"""Parse SSH config file to get host-specific configurations.
Returns:
dict: SSH configuration mapping hostnames to their settings
"""
if self._ssh_config_cache:
return self._ssh_config_cache

ssh_config_path = self._ssh_dir / "config"
if not ssh_config_path.exists():
logger.debug("No SSH config file found at: %s", ssh_config_path)
return {}

logger.debug("Parsing SSH config file: %s", ssh_config_path)
"""Parse SSH config file to get host-specific configurations."""
config = {}
current_host = None
current_match = None
ssh_config_path = self._ssh_dir / "config"

# Valid SSH config keys and their validation functions
if not ssh_config_path.exists():
return config

# Define valid keys and their validation functions
valid_keys = {
"host": lambda x: True, # Host patterns are handled separately
"hostname": lambda x: self.is_valid_hostname(x),
"port": lambda x: str(x).isdigit() and 1 <= int(x) <= 65535,
"user": lambda x: bool(x and not any(c in x for c in " \t\n\r")),
"identityfile": lambda x: True, # Path validation handled elsewhere
"identitiesonly": lambda x: x.lower() in ("yes", "no"),
"forwardagent": lambda x: x.lower() in ("yes", "no"),
"proxycommand": lambda x: bool(x),
"proxyhost": lambda x: self.is_valid_hostname(x),
"proxyport": lambda x: str(x).isdigit() and 1 <= int(x) <= 65535,
"proxyuser": lambda x: bool(x and not any(c in x for c in " \t\n\r")),
"stricthostkeychecking": lambda x: x.lower() in ("yes", "no", "accept-new", "off", "ask"),
"userknownhostsfile": lambda x: True, # Path validation handled elsewhere
"batchmode": lambda x: x.lower() in ("yes", "no"),
"compression": lambda x: x.lower() in ("yes", "no"),
"hostname": lambda x: True, # Any string is valid for hostname
"user": lambda x: True, # Any string is valid for user
"port": lambda x: x.isdigit(), # Must be numeric
"identityfile": lambda x: True, # Any path is valid
"stricthostkeychecking": lambda x: x.lower() in ["yes", "no"],
# Add more keys as needed
}

def is_valid_host_pattern(pattern: str) -> bool:
"""Check if a host pattern is valid."""
"""
Check if a host pattern is valid.
A valid host pattern can contain:
- Wildcards (* and ?)
- Negation (! at the start)
- Multiple patterns separated by spaces
- Most printable characters except control characters
- IPv6 addresses in square brackets
Args:
pattern (str): The host pattern to validate.
Returns:
bool: True if the pattern is valid, False otherwise.
"""
if not pattern:
return False
# Allow wildcards and negation
if pattern == "*" or pattern.startswith("!"):

# Special cases
if pattern == "*":
return True
# Check for invalid characters that shouldn't be in a hostname
invalid_chars = "|[]{}\\;"
return not any(c in pattern for c in invalid_chars)

# Split multiple patterns
patterns = pattern.split()
for p in patterns:
# Skip empty patterns
if not p:
continue

# Allow negation prefix
if p.startswith("!"):
p = p[1:]

# Skip empty patterns after removing prefix
if not p:
continue

# Check for control characters
if any(c in p for c in "\0\n\r\t"):
return False

# Allow IPv6 addresses in square brackets
if p.startswith("[") and p.endswith("]"):
# Basic IPv6 validation
p = p[1:-1]
if not all(c in "0123456789abcdefABCDEF:" for c in p):
return False
continue

return True

def process_config_line(line: str, config_file: str = str(ssh_config_path)) -> None:
"""Process a single line from SSH config file."""
Expand Down Expand Up @@ -713,19 +739,24 @@ def process_config_line(line: str, config_file: str = str(ssh_config_path)) -> N

# Handle Match blocks
if line.lower().startswith("match "):
current_match = line.split(None, 1)[1].lower()
parts = line.split(None, 2)
if len(parts) >= 3 and parts[1].lower() == "host":
current_match = parts[2]
current_host = current_match
if current_host not in config:
config[current_host] = {}
return

# Handle Host blocks
if line.lower().startswith("host "):
current_host = line.split(None, 1)[1]
if is_valid_host_pattern(current_host):
if is_valid_host_pattern(current_host) and current_host not in config:
config[current_host] = {}
current_match = None
return

# Parse key-value pairs
if current_host:
if current_host is not None:
# Split line into key and value
if "=" in line:
key, value = [x.strip() for x in line.split("=", 1)]
Expand All @@ -748,24 +779,25 @@ def process_config_line(line: str, config_file: str = str(ssh_config_path)) -> N
logger.debug(f"Skipping invalid config value in {config_file}: {key}={value}")
return

# Apply match block settings
if current_match:
if current_match == "all" or current_host in current_match:
config[current_host][key] = value
else:
config[current_host][key] = value
# Apply settings
if current_host not in config:
config[current_host] = {}
config[current_host][key] = value

try:
with open(ssh_config_path) as f:
# Reset config for each parse attempt
config.clear()
current_host = None
current_match = None

for line in f:
process_config_line(line)

self._ssh_config_cache = config
return config

except Exception as e:
logger.error(f"Failed to parse SSH config: {e}")
return {}
logger.debug(f"Failed to parse SSH config: {e}")
config.clear() # Clear config on error

return config

def _extract_hostname(self, url: str) -> Optional[str]:
"""Extract hostname from SSH URL.
Expand Down

0 comments on commit 7f921fe

Please sign in to comment.