Work on supporting streaming remote commands with passwords

This commit is contained in:
Kovid Goyal 2022-08-31 20:03:51 +05:30
parent d7985689c9
commit 364533b1ed
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
14 changed files with 206 additions and 85 deletions

View File

@ -94,7 +94,7 @@ def render_alias_map(alias_map: Dict[str, Tuple[str, ...]]) -> str:
def build_go_code(name: str, cmd: RemoteCommand, seq: OptionSpecSeq, template: str) -> str:
template = '\n' + template[len('//go:build exclude'):]
NO_RESPONSE_BASE = 'true' if cmd.no_response else 'false'
NO_RESPONSE_BASE = 'false'
af: List[str] = []
a = af.append
alias_map = {}
@ -169,6 +169,7 @@ def build_go_code(name: str, cmd: RemoteCommand, seq: OptionSpecSeq, template: s
JSON_DECLARATION_CODE='\n'.join(jd),
JSON_INIT_CODE='\n'.join(jc), ARGSPEC=argspec,
STRING_RESPONSE_IS_ERROR='true' if cmd.string_return_is_error else 'false',
STREAM_WANTED='true' if cmd.reads_streaming_data else 'false',
)
return ans

View File

@ -1,11 +1,12 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2020, Kovid Goyal <kovid at kovidgoyal.net>
import tempfile
from contextlib import suppress
from dataclasses import dataclass
from typing import (
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Iterable, Iterator, List,
NoReturn, Optional, Set, Tuple, Type, Union, cast
TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Iterable, Iterator,
List, NoReturn, Optional, Set, Tuple, Type, Union, cast
)
from kitty.cli import get_defaults_from_seq, parse_args, parse_option_spec
@ -29,6 +30,16 @@ class NoResponse:
pass
class NamedTemporaryFile:
name: str = ''
def __enter__(self) -> None: ...
def __exit__(self, exc: Any, value: Any, tb: Any) -> None: ...
def close(self) -> None: ...
def write(self, data: bytes) -> None: ...
def flush(self) -> None: ...
class RemoteControlError(Exception):
pass
@ -51,6 +62,11 @@ class UnknownLayout(ValueError):
hide_traceback = True
class StreamError(ValueError):
hide_traceback = True
class PayloadGetter:
def __init__(self, cmd: 'RemoteCommand', payload: Dict[str, Any]):
@ -245,6 +261,35 @@ class ArgsHandling:
raise TypeError(f'Unknown args handling for cmd: {cmd_name}')
class StreamInFlight:
def __init__(self) -> None:
self.stream_id = ''
self.tempfile: Optional[NamedTemporaryFile] = None
def handle_data(self, stream_id: str, data: bytes) -> Union[AsyncResponse, NamedTemporaryFile]:
from ..remote_control import close_active_stream
if stream_id != self.stream_id:
close_active_stream(self.stream_id)
if self.tempfile is not None:
self.tempfile.close()
self.tempfile = None
self.stream_id = stream_id
if self.tempfile is None:
t: NamedTemporaryFile = cast(NamedTemporaryFile, tempfile.NamedTemporaryFile(suffix='.png'))
self.tempfile = t
else:
t = self.tempfile
if data:
t.write(data)
return AsyncResponse()
close_active_stream(self.stream_id)
self.stream_id = ''
self.tempfile = None
t.flush()
return t
class RemoteCommand:
Args = ArgsHandling
@ -253,7 +298,6 @@ class RemoteCommand:
desc: str = ''
args: ArgsHandling = ArgsHandling()
options_spec: Optional[str] = None
no_response: bool = False
response_timeout: float = 10. # seconds
string_return_is_error: bool = False
defaults: Optional[Dict[str, Any]] = None
@ -262,10 +306,12 @@ class RemoteCommand:
protocol_spec: str = ''
argspec = args_count = args_completion = ArgsHandling()
field_to_option_map: Optional[Dict[str, str]] = None
reads_streaming_data: bool = False
def __init__(self) -> None:
self.desc = self.desc or self.short_desc
self.name = self.__class__.__module__.split('.')[-1].replace('_', '-')
self.stream_in_flight = StreamInFlight()
def fatal(self, msg: str) -> NoReturn:
if running_in_kitty():
@ -342,6 +388,12 @@ class RemoteCommand:
def cancel_async_request(self, boss: 'Boss', window: Optional['Window'], payload_get: PayloadGetType) -> None:
pass
def handle_streamed_data(self, data: bytes, payload_get: PayloadGetType) -> Union[NamedTemporaryFile, AsyncResponse]:
stream_id = payload_get('stream_id')
if not stream_id or not isinstance(stream_id, str):
raise StreamError('No stream_id in rc payload')
return self.stream_in_flight.handle_data(stream_id, data)
def cli_params_for(command: RemoteCommand) -> Tuple[Callable[[], str], str, str, str]:
return (command.options_spec or '\n').format, command.args.spec, command.desc, f'{appname} @ {command.name}'

View File

@ -100,8 +100,9 @@ not interpreted for escapes. If stdin is a terminal, you can press :kbd:`Ctrl+D`
Path to a file whose contents you wish to send. Note that in this case the file contents
are sent as is, not interpreted for escapes.
'''
no_response = True
args = RemoteCommand.Args(spec='[TEXT TO SEND]', json_field='data', special_parse='+session_id:parse_send_text(io_data, args)')
is_asynchronous = True
reads_streaming_data = True
def message_to_kitty(self, global_opts: RCOptions, opts: 'CLIOptions', args: ArgsType) -> PayloadType:
limit = 1024

View File

@ -3,15 +3,14 @@
import imghdr
import os
import tempfile
from base64 import standard_b64decode, standard_b64encode
from typing import IO, TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional
from kitty.types import AsyncResponse
from .base import (
MATCH_WINDOW_OPTION, ArgsType, Boss, CmdGenerator, PayloadGetType,
PayloadType, RCOptions, RemoteCommand, ResponseType, Window
MATCH_WINDOW_OPTION, ArgsType, Boss, CmdGenerator, NamedTemporaryFile,
PayloadGetType, PayloadType, RCOptions, RemoteCommand, ResponseType, Window
)
if TYPE_CHECKING:
@ -65,8 +64,7 @@ failed, the command will exit with a success code.
''' + '\n\n' + MATCH_WINDOW_OPTION
args = RemoteCommand.Args(spec='PATH_TO_PNG_IMAGE', count=1, json_field='data', special_parse='!read_window_logo(args[0])', completion={
'files': ('PNG Images', ('*.png',))})
images_in_flight: Dict[str, IO[bytes]] = {}
is_asynchronous = True
reads_streaming_data = True
def message_to_kitty(self, global_opts: RCOptions, opts: 'CLIOptions', args: ArgsType) -> PayloadType:
if len(args) != 1:
@ -98,34 +96,26 @@ failed, the command will exit with a success code.
def response_from_kitty(self, boss: Boss, window: Optional[Window], payload_get: PayloadGetType) -> ResponseType:
data = payload_get('data')
img_id = payload_get('async_id')
if data != '-':
if img_id not in self.images_in_flight:
self.images_in_flight[img_id] = tempfile.NamedTemporaryFile(suffix='.png')
if data:
self.images_in_flight[img_id].write(standard_b64decode(data))
return AsyncResponse()
windows = self.windows_for_payload(boss, window, payload_get)
os_windows = tuple({w.os_window_id for w in windows if w})
layout = payload_get('layout')
if data == '-':
path = None
tfile = NamedTemporaryFile()
else:
f = self.images_in_flight.pop(img_id)
path = f.name
f.flush()
q = self.handle_streamed_data(standard_b64decode(data) if data else b'', payload_get)
if isinstance(q, AsyncResponse):
return q
path = q.name
tfile = q
try:
boss.set_background_image(path, os_windows, payload_get('configured'), layout)
with tfile:
boss.set_background_image(path, os_windows, payload_get('configured'), layout)
except ValueError as err:
err.hide_traceback = True # type: ignore
raise
return None
def cancel_async_request(self, boss: 'Boss', window: Optional['Window'], payload_get: PayloadGetType) -> None:
async_id = payload_get('async_id')
self.images_in_flight.pop(async_id, None)
set_background_image = SetBackgroundImage()

View File

@ -4,15 +4,14 @@
import imghdr
import os
import tempfile
from base64 import standard_b64decode, standard_b64encode
from typing import IO, TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional
from kitty.types import AsyncResponse
from .base import (
MATCH_WINDOW_OPTION, ArgsType, Boss, CmdGenerator, PayloadGetType,
PayloadType, RCOptions, RemoteCommand, ResponseType, Window
MATCH_WINDOW_OPTION, ArgsType, Boss, CmdGenerator, NamedTemporaryFile,
PayloadGetType, PayloadType, RCOptions, RemoteCommand, ResponseType, Window
)
if TYPE_CHECKING:
@ -61,8 +60,7 @@ failed, the command will exit with a success code.
'''
args = RemoteCommand.Args(spec='PATH_TO_PNG_IMAGE', count=1, json_field='data', special_parse='!read_window_logo(args[0])', completion={
'files': ('PNG Images', ('*.png',))})
images_in_flight: Dict[str, IO[bytes]] = {}
is_asynchronous = True
reads_streaming_data = True
def message_to_kitty(self, global_opts: RCOptions, opts: 'CLIOptions', args: ArgsType) -> PayloadType:
if len(args) != 1:
@ -94,26 +92,22 @@ failed, the command will exit with a success code.
def response_from_kitty(self, boss: Boss, window: Optional[Window], payload_get: PayloadGetType) -> ResponseType:
data = payload_get('data')
img_id = payload_get('async_id')
if data != '-':
if img_id not in self.images_in_flight:
self.images_in_flight[img_id] = tempfile.NamedTemporaryFile(suffix='.png')
if data:
self.images_in_flight[img_id].write(standard_b64decode(data))
return AsyncResponse()
if data == '-':
path = ''
else:
f = self.images_in_flight.pop(img_id)
path = f.name
f.flush()
alpha = float(payload_get('alpha', '-1'))
position = payload_get('position') or ''
for window in self.windows_for_match_payload(boss, window, payload_get):
if window:
window.set_logo(path, position, alpha)
if data == '-':
path = ''
tfile = NamedTemporaryFile()
else:
q = self.handle_streamed_data(standard_b64decode(data) if data else b'', payload_get)
if isinstance(q, AsyncResponse):
return q
path = q.name
tfile = q
with tfile:
for window in self.windows_for_match_payload(boss, window, payload_get):
if window:
window.set_logo(path, position, alpha)
return None

View File

@ -31,6 +31,7 @@ from .typing import BossType, WindowType
from .utils import TTYIO, log_error, parse_address_spec, resolve_custom_file
active_async_requests: Dict[str, float] = {}
active_streams: Dict[str, str] = {}
if TYPE_CHECKING:
from .window import Window
@ -136,6 +137,9 @@ user_password_allowed: Dict[str, bool] = {}
def is_cmd_allowed(pcmd: Dict[str, Any], window: Optional['Window'], from_socket: bool, extra_data: Dict[str, Any]) -> Optional[bool]:
sid = pcmd.get('stream_id', '')
if sid and active_streams.get(sid, '') == pcmd['cmd']:
return True
pw = pcmd.get('password', '')
if not pw:
auth_items = get_options().remote_control_password.get('')
@ -157,6 +161,10 @@ def set_user_password_allowed(pwd: str, allowed: bool = True) -> None:
user_password_allowed[pwd] = allowed
def close_active_stream(stream_id: str) -> None:
active_streams.pop(stream_id, None)
def handle_cmd(boss: BossType, window: Optional[WindowType], cmd: Dict[str, Any], peer_id: int) -> Union[Dict[str, Any], None, AsyncResponse]:
v = cmd['version']
no_response = cmd.get('no_response', False)
@ -168,6 +176,16 @@ def handle_cmd(boss: BossType, window: Optional[WindowType], cmd: Dict[str, Any]
payload = cmd.get('payload') or {}
payload['peer_id'] = peer_id
async_id = str(cmd.get('async', ''))
stream_id = str(cmd.get('stream_id', ''))
stream = bool(cmd.get('stream', False))
if (stream or stream_id) and not c.reads_streaming_data:
return {'ok': False, 'error': 'Streaming send of data is not supported for this command'}
if stream_id:
payload['stream_id'] = stream_id
active_streams[stream_id] = cmd['cmd']
if len(active_streams) > 32:
oldest = next(iter(active_streams))
del active_streams[oldest]
if async_id:
payload['async_id'] = async_id
if 'cancel_async' in cmd:
@ -187,11 +205,13 @@ def handle_cmd(boss: BossType, window: Optional[WindowType], cmd: Dict[str, Any]
if isinstance(ans, NoResponse):
return None
if isinstance(ans, AsyncResponse):
if stream:
return {'ok': True, 'stream': True}
return ans
response: Dict[str, Any] = {'ok': True}
if ans is not None:
response['data'] = ans
if not c.no_response and not no_response:
if not no_response:
return response
return None
@ -396,8 +416,7 @@ def create_basic_command(name: str, payload: Any = None, no_response: bool = Fal
def send_response_to_client(data: Any = None, error: str = '', peer_id: int = 0, window_id: int = 0, async_id: str = '') -> None:
ts = active_async_requests.pop(async_id, None)
if ts is None:
if active_async_requests.pop(async_id, None) is None:
return
if error:
response: Dict[str, Union[bool, int, str]] = {'ok': False, 'error': error}
@ -481,7 +500,7 @@ def main(args: List[str]) -> None:
payload = c.message_to_kitty(global_opts, opts, items)
except ParsingOfArgsFailed as err:
exit(str(err))
no_response = c.no_response
no_response = False
if hasattr(opts, 'no_response'):
no_response = opts.no_response
response_timeout = c.response_timeout

View File

@ -9,6 +9,7 @@ import (
"fmt"
"io"
"os"
"reflect"
"strings"
"time"
@ -41,6 +42,14 @@ type GlobalOptions struct {
var global_options GlobalOptions
func set_payload_string_field(io_data *rc_io_data, field, data string) {
payload_interface := reflect.ValueOf(&io_data.rc.Payload).Elem()
struct_in_interface := reflect.New(payload_interface.Elem().Type()).Elem()
struct_in_interface.Set(payload_interface.Elem()) // copies the payload to struct_in_interface
struct_in_interface.FieldByName(field).SetString(data)
payload_interface.Set(struct_in_interface) // copies struct_in_interface back to payload
}
func get_pubkey(encoded_key string) (encryption_version string, pubkey []byte, err error) {
if encoded_key == "" {
encoded_key = os.Getenv("KITTY_PUBLIC_KEY")
@ -171,7 +180,8 @@ type rc_io_data struct {
multiple_payload_generator func(io_data *rc_io_data) (bool, error)
}
func (self *rc_io_data) next_chunk() (chunk []byte, err error) {
func (self *rc_io_data) next_chunk() (chunk []byte, one_escape_code_done bool, err error) {
one_escape_code_done = self.serializer.state == 2
block, err := self.serializer.next(self)
if err != nil && !errors.Is(err, io.EOF) {
return

View File

@ -18,9 +18,7 @@ func parse_send_text(io_data *rc_io_data, args []string) error {
if len(args) > 0 {
text := strings.Join(args, " ")
text_gen := func(io_data *rc_io_data) (bool, error) {
payload := io_data.rc.Payload.(send_text_json_type)
payload.Data = "text:" + text[:2048]
io_data.rc.Payload = payload
set_payload_data(io_data, "text:"+text[:2048])
text = text[2048:]
return len(text) == 0, nil
}
@ -38,9 +36,7 @@ func parse_send_text(io_data *rc_io_data, args []string) error {
if err != nil && !errors.Is(err, io.EOF) {
return false, err
}
payload := io_data.rc.Payload.(send_text_json_type)
payload.Data = "base64:" + base64.StdEncoding.EncodeToString(chunk[:n])
io_data.rc.Payload = payload
set_payload_data(io_data, "base64:"+base64.StdEncoding.EncodeToString(chunk[:n]))
return n == 0, nil
}
generators = append(generators, file_gen)
@ -48,9 +44,7 @@ func parse_send_text(io_data *rc_io_data, args []string) error {
io_data.multiple_payload_generator = func(io_data *rc_io_data) (bool, error) {
if len(generators) == 0 {
payload := io_data.rc.Payload.(send_text_json_type)
payload.Data = "text:"
io_data.rc.Payload = payload
set_payload_data(io_data, "text:")
return true, nil
}
finished, err := generators[0](io_data)

View File

@ -16,5 +16,6 @@ func parse_set_font_size(arg string, io_data *rc_io_data) error {
return err
}
payload.Size = val
io_data.rc.Payload = payload
return nil
}

View File

@ -11,10 +11,18 @@ import (
"strings"
)
type struct_with_data interface {
SetData(data string)
}
func set_payload_data(io_data *rc_io_data, data string) {
set_payload_string_field(io_data, "Data", data)
}
func read_window_logo(path string) (func(io_data *rc_io_data) (bool, error), error) {
if strings.ToLower(path) == "none" {
return func(io_data *rc_io_data) (bool, error) {
io_data.rc.Payload = "-"
set_payload_data(io_data, "-")
return true, nil
}, nil
}
@ -35,16 +43,20 @@ func read_window_logo(path string) (func(io_data *rc_io_data) (bool, error), err
f.Close()
return nil, fmt.Errorf("%s is not a PNG image", path)
}
is_first_call := true
return func(io_data *rc_io_data) (bool, error) {
payload := io_data.rc.Payload.(set_window_logo_json_type)
if is_first_call {
is_first_call = false
} else {
io_data.rc.Stream = false
}
if len(buf) == 0 {
payload.Data = ""
io_data.rc.Payload = payload
set_payload_data(io_data, "")
io_data.rc.Stream = false
return true, nil
}
payload.Data = base64.StdEncoding.EncodeToString(buf)
io_data.rc.Payload = payload
set_payload_data(io_data, base64.StdEncoding.EncodeToString(buf))
buf = buf[:cap(buf)]
n, err := f.Read(buf)
if err != nil && err != io.EOF {

View File

@ -57,7 +57,7 @@ func read_response_from_conn(conn *net.Conn, timeout time.Duration) (serialized_
func simple_socket_io(conn *net.Conn, io_data *rc_io_data) (serialized_response []byte, err error) {
for {
var chunk []byte
chunk, err = io_data.next_chunk()
chunk, _, err = io_data.next_chunk()
if err != nil {
return
}

View File

@ -43,6 +43,14 @@ func create_rc_CMD_NAME(args []string) (*utils.RemoteControlCmd, error) {
Cmd: "CLI_NAME",
Version: ProtocolVersion,
NoResponse: NO_RESPONSE_BASE,
Stream: STREAM_WANTED,
}
if rc.Stream {
stream_id, err := utils.HumanRandomId(128)
if err != nil {
return nil, err
}
rc.StreamId = stream_id
}
if IS_ASYNC {
async_id, err := utils.HumanRandomId(128)

View File

@ -3,12 +3,27 @@
package at
import (
"encoding/json"
"os"
"time"
"kitty/tools/tui/loop"
)
type stream_response struct {
Ok bool `json:"ok"`
Stream bool `json:"stream"`
}
func is_stream_response(serialized_response []byte) bool {
var response stream_response
if len(serialized_response) > 32 {
return false
}
err := json.Unmarshal(serialized_response, &response)
return err == nil && response.Stream
}
func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error) {
serialized_response = make([]byte, 0)
lp, err := loop.New()
@ -17,11 +32,21 @@ func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error)
return
}
const (
BEFORE_FIRST_ESCAPE_CODE_SENT = iota
WAITING_FOR_STREAMING_RESPONSE
SENDING
WAITING_FOR_RESPONSE
)
state := BEFORE_FIRST_ESCAPE_CODE_SENT
var last_received_data_at time.Time
var final_write_id loop.IdType
var check_for_timeout func(timer_id loop.IdType) error
wants_streaming := false
check_for_timeout = func(timer_id loop.IdType) error {
if state != WAITING_FOR_RESPONSE && state != WAITING_FOR_STREAMING_RESPONSE {
return nil
}
time_since_last_received_data := time.Now().Sub(last_received_data_at)
if time_since_last_received_data >= io_data.timeout {
return os.ErrDeadlineExceeded
@ -31,7 +56,7 @@ func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error)
}
transition_to_read := func() {
if io_data.rc.NoResponse {
if state == WAITING_FOR_RESPONSE && io_data.rc.NoResponse {
lp.Quit(0)
}
last_received_data_at = time.Now()
@ -44,36 +69,48 @@ func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error)
}
lp.OnInitialize = func() (string, error) {
chunk, err := io_data.next_chunk()
chunk, _, err := io_data.next_chunk()
wants_streaming = io_data.rc.Stream
if err != nil {
return "", err
}
write_id := lp.QueueWriteBytesDangerous(chunk)
lp.QueueWriteBytesDangerous(chunk)
if len(chunk) == 0 {
final_write_id = write_id
state = WAITING_FOR_RESPONSE
transition_to_read()
}
return "", nil
}
lp.OnWriteComplete = func(completed_write_id loop.IdType) error {
if final_write_id > 0 {
if completed_write_id == final_write_id {
transition_to_read()
}
if state == WAITING_FOR_STREAMING_RESPONSE || state == WAITING_FOR_RESPONSE {
return nil
}
chunk, err := io_data.next_chunk()
chunk, one_escape_code_done, err := io_data.next_chunk()
if err != nil {
return err
}
write_id := lp.QueueWriteBytesDangerous(chunk)
lp.QueueWriteBytesDangerous(chunk)
if len(chunk) == 0 {
final_write_id = write_id
state = WAITING_FOR_RESPONSE
transition_to_read()
}
if one_escape_code_done && state == BEFORE_FIRST_ESCAPE_CODE_SENT {
if wants_streaming {
state = WAITING_FOR_STREAMING_RESPONSE
transition_to_read()
} else {
state = SENDING
}
}
return nil
}
lp.OnRCResponse = func(raw []byte) error {
if state == WAITING_FOR_STREAMING_RESPONSE && is_stream_response(raw) {
state = SENDING
return lp.OnWriteComplete(0)
}
serialized_response = raw
lp.Quit(0)
return nil

View File

@ -6,11 +6,13 @@ type RemoteControlCmd struct {
Cmd string `json:"cmd"`
Version [3]int `json:"version"`
NoResponse bool `json:"no_response,omitempty"`
Payload interface{} `json:"payload,omitempty"`
Timestamp int64 `json:"timestamp,omitempty"`
Password string `json:"password,omitempty"`
Async string `json:"async,omitempty"`
CancelAsync bool `json:"cancel_async,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamId string `json:"stream_id,omitempty"`
Payload interface{} `json:"payload,omitempty"`
}
type EncryptedRemoteControlCmd struct {