Skip to content

Commit

Permalink
Function to change the remote command in an ssh kitten cmdline
Browse files Browse the repository at this point in the history
  • Loading branch information
kovidgoyal committed Feb 8, 2023
1 parent 237a5d1 commit 2445073
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 26 deletions.
26 changes: 2 additions & 24 deletions kittens/ssh/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from contextlib import contextmanager, suppress
from getpass import getuser
from select import select
from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Set, Tuple, Union, cast
from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Sequence, Tuple, Union, cast

from kitty.constants import cache_dir, runtime_dir, shell_integration_dir, ssh_control_master_template, str_version, terminfo_dir
from kitty.shell_integration import as_str_literal
Expand All @@ -37,7 +37,7 @@
from .copy import CopyInstruction
from .options.types import Options as SSHOptions
from .options.utils import DELETE_ENV_VAR
from .utils import create_shared_memory, ssh_options
from .utils import create_shared_memory, get_ssh_cli, is_extra_arg, passthrough_args


@run_once
Expand Down Expand Up @@ -291,25 +291,6 @@ def bootstrap_script(
return prepare_script(ans, sd, script_type), replacements, shm_name


def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
other_ssh_args: Set[str] = set()
boolean_ssh_args: Set[str] = set()
for k, v in ssh_options().items():
k = f'-{k}'
if v:
other_ssh_args.add(k)
else:
boolean_ssh_args.add(k)
return boolean_ssh_args, other_ssh_args


def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str:
for x in extra_args:
if arg == x or arg.startswith(f'{x}='):
return x
return ''


def get_connection_data(args: List[str], cwd: str = '', extra_args: Tuple[str, ...] = ()) -> Optional[SSHConnectionData]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
port: Optional[int] = None
Expand Down Expand Up @@ -405,9 +386,6 @@ def system_exit(self) -> None:
os.execlp(ssh_exe(), 'ssh')


passthrough_args = {f'-{x}' for x in 'NnfGT'}


def parse_ssh_args(args: List[str], extra_args: Tuple[str, ...] = ()) -> Tuple[List[str], List[str], bool, Tuple[str, ...]]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
ssh_args = []
Expand Down
88 changes: 86 additions & 2 deletions kittens/ssh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
import subprocess
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Sequence, Set, Tuple

from kitty.types import run_once

Expand Down Expand Up @@ -51,9 +51,14 @@ def ssh_options() -> Dict[str, str]:


def is_kitten_cmdline(q: Sequence[str]) -> bool:
if not q:
return False
exe_name = os.path.basename(q[0]).lower()
if exe_name == 'kitten' and q[1:2] == ['ssh']:
return True
if len(q) < 4:
return False
if os.path.basename(q[0]).lower() != 'kitty':
if exe_name != 'kitty':
return False
if q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh']:
return True
Expand Down Expand Up @@ -91,3 +96,82 @@ def create_shared_memory(data: Any, prefix: str) -> str:

def set_env_in_cmdline(env: Dict[str, str], argv: List[str]) -> None:
patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)



def get_ssh_cli() -> Tuple[Set[str], Set[str]]:
other_ssh_args: Set[str] = set()
boolean_ssh_args: Set[str] = set()
for k, v in ssh_options().items():
k = f'-{k}'
if v:
other_ssh_args.add(k)
else:
boolean_ssh_args.add(k)
return boolean_ssh_args, other_ssh_args


def is_extra_arg(arg: str, extra_args: Tuple[str, ...]) -> str:
for x in extra_args:
if arg == x or arg.startswith(f'{x}='):
return x
return ''


passthrough_args = {f'-{x}' for x in 'NnfGT'}


def set_server_args_in_cmdline(server_args: List[str], argv: List[str], extra_args: Tuple[str, ...] = ('--kitten',)) -> None:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
ssh_args = []
expecting_option_val = False
found_extra_args: List[str] = []
expecting_extra_val = ''
ans = list(argv)
found_ssh = False
for i, argument in enumerate(argv):
if not found_ssh:
found_ssh = argument == 'ssh'
continue
if argument.startswith('-') and not expecting_option_val:
if argument == '--':
del ans[i+2:]
break
if extra_args:
matching_ex = is_extra_arg(argument, extra_args)
if matching_ex:
if '=' in argument:
exval = argument.partition('=')[-1]
found_extra_args.extend((matching_ex, exval))
else:
expecting_extra_val = matching_ex
expecting_option_val = True
continue
# could be a multi-character option
all_args = argument[1:]
for i, arg in enumerate(all_args):
arg = f'-{arg}'
if arg in boolean_ssh_args:
ssh_args.append(arg)
continue
if arg in other_ssh_args:
ssh_args.append(arg)
rest = all_args[i+1:]
if rest:
ssh_args.append(rest)
else:
expecting_option_val = True
break
raise KeyError(f'unknown option -- {arg[1:]}')
continue
if expecting_option_val:
if expecting_extra_val:
found_extra_args.extend((expecting_extra_val, argument))
expecting_extra_val = ''
else:
ssh_args.append(argument)
expecting_option_val = False
continue
del ans[i+1:]
break
argv[:] = ans + server_args

0 comments on commit 2445073

Please sign in to comment.