This commit is contained in:
Kovid Goyal 2021-07-22 17:56:21 +05:30
parent 21a2768ec3
commit 075fb2eaf2
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -175,6 +175,18 @@ def get_connection_data(args: List[str]) -> Optional[SSHConnectionData]:
return SSHConnectionData(found_ssh, arg, port) return SSHConnectionData(found_ssh, arg, port)
class InvalidSSHArgs(ValueError):
def __init__(self, msg: str = ''):
super().__init__(msg)
self.err_msg = msg
def system_exit(self) -> None:
if self.err_msg:
print(self.err_msg, file=sys.stderr)
os.execlp('ssh', 'ssh')
def parse_ssh_args(args: List[str]) -> Tuple[List[str], List[str], bool]: def parse_ssh_args(args: List[str]) -> Tuple[List[str], List[str], bool]:
boolean_ssh_args, other_ssh_args = get_ssh_cli() boolean_ssh_args, other_ssh_args = get_ssh_cli()
passthrough_args = {'-' + x for x in 'Nnf'} passthrough_args = {'-' + x for x in 'Nnf'}
@ -191,6 +203,7 @@ def parse_ssh_args(args: List[str]) -> Tuple[List[str], List[str], bool]:
if arg == '--': if arg == '--':
stop_option_processing = True stop_option_processing = True
continue continue
# could be a multi-character option
all_args = arg[1:] all_args = arg[1:]
for i, arg in enumerate(all_args): for i, arg in enumerate(all_args):
arg = '-' + arg arg = '-' + arg
@ -207,9 +220,7 @@ def parse_ssh_args(args: List[str]) -> Tuple[List[str], List[str], bool]:
else: else:
expecting_option_val = True expecting_option_val = True
break break
print('unknown option -- {}'.format(arg[1:]), file=sys.stderr) raise InvalidSSHArgs('unknown option -- {}'.format(arg[1:]))
subprocess.Popen(['ssh']).wait()
raise SystemExit(255)
continue continue
if expecting_option_val: if expecting_option_val:
ssh_args.append(arg) ssh_args.append(arg)
@ -217,7 +228,7 @@ def parse_ssh_args(args: List[str]) -> Tuple[List[str], List[str], bool]:
continue continue
server_args.append(arg) server_args.append(arg)
if not server_args: if not server_args:
raise SystemExit('Must specify server to connect to') raise InvalidSSHArgs()
return ssh_args, server_args, passthrough return ssh_args, server_args, passthrough
@ -261,7 +272,10 @@ def main(args: List[str]) -> NoReturn:
if args and args[0] == 'use-python': if args and args[0] == 'use-python':
args = args[1:] args = args[1:]
use_posix = False use_posix = False
ssh_args, server_args, passthrough = parse_ssh_args(args) try:
ssh_args, server_args, passthrough = parse_ssh_args(args)
except InvalidSSHArgs as e:
e.system_exit()
cmd = ['ssh'] + ssh_args cmd = ['ssh'] + ssh_args
if passthrough: if passthrough:
cmd += server_args cmd += server_args