From de9263a117fe72d495be8d8e975c27f4c82d63e3 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 5 Jul 2022 18:20:32 +0530 Subject: [PATCH] Verify uid/gid of connection from a prewarm client --- kitty/data-types.c | 24 +++++++++++++++++++++++- kitty/fast_data_types.pyi | 4 ++++ kitty/prewarm.py | 18 +++++++++++++----- kitty_tests/prewarm.py | 4 +++- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/kitty/data-types.c b/kitty/data-types.c index 702c21109..202aafaaa 100644 --- a/kitty/data-types.c +++ b/kitty/data-types.c @@ -13,6 +13,9 @@ #endif #include "data-types.h" +#include +#include +#include #include "cleanup.h" #include "safe-wrappers.h" #include "control-codes.h" @@ -189,8 +192,28 @@ locale_is_valid(PyObject *self UNUSED, PyObject *args) { Py_RETURN_TRUE; } +static PyObject* +py_getpeereid(PyObject *self UNUSED, PyObject *args) { + int fd; + if (!PyArg_ParseTuple(args, "i", &fd)) return NULL; + uid_t euid = 0; gid_t egid = 0; +#ifdef __linux__ + struct ucred cr; + socklen_t sz = sizeof(cr); + if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cr, &sz) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; } + euid = cr.uid; egid = cr.gid; +#else + if (getpeereid(fd, &euid, &egid) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; } +#endif + int u = euid, g = egid; + return Py_BuildValue("ii", u, g); +} + + + static PyMethodDef module_methods[] = { {"wcwidth", (PyCFunction)wcwidth_wrap, METH_O, ""}, + {"getpeereid", (PyCFunction)py_getpeereid, METH_VARARGS, ""}, {"wcswidth", (PyCFunction)wcswidth_std, METH_O, ""}, {"open_tty", open_tty, METH_VARARGS, ""}, {"normal_tty", normal_tty, METH_VARARGS, ""}, @@ -269,7 +292,6 @@ shift_to_first_set_bit(CellAttrs x) { return ans; } - EXPORTED PyMODINIT_FUNC PyInit_fast_data_types(void) { PyObject *m; diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index 863f4fa4b..d096283e8 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -1413,3 +1413,7 @@ def install_signal_handlers(*signals: int) -> Tuple[int, int]: def remove_signal_handlers() -> None: pass + + +def getpeereid(fd: int) -> Tuple[int, int]: + pass diff --git a/kitty/prewarm.py b/kitty/prewarm.py index fdaa11388..5ccde0e2a 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -22,8 +22,9 @@ from typing import ( from kitty.constants import kitty_exe, running_in_kitty from kitty.entry_points import main as main_entry_point from kitty.fast_data_types import ( - CLD_EXITED, CLD_KILLED, get_options, install_signal_handlers, read_signals, - remove_signal_handlers, safe_pipe, set_options + CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options, + install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe, + set_options, getpeereid ) from kitty.options.types import Options from kitty.shm import SharedMemory @@ -338,7 +339,6 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl os.setsid() tty_name = cmd.get('tty_name') if tty_name: - from kitty.fast_data_types import establish_controlling_tty sys.__stdout__.flush() sys.__stderr__.flush() establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno()) @@ -361,6 +361,11 @@ class SocketClosed(Exception): pass +def verify_socket_creds(conn: socket.socket) -> bool: + uid, gid = getpeereid(conn.fileno()) + return uid == os.geteuid() and gid == os.getegid() + + class SocketChild: def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll): @@ -389,8 +394,7 @@ class SocketChild: def read(self) -> bool: import array fds = array.array("i") # Array of ints - maxfds = 3 - msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, socket.CMSG_LEN(maxfds * fds.itemsize)) + msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024) for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: # Append data, ignoring any truncated integers at the end. @@ -670,6 +674,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket: def handle_socket_client(event: int) -> None: check_event(event, 'UNIX socket fd listener failed') conn, addr = unix_socket.accept() + if not verify_socket_creds(conn): + print_error('Connection attempted with invalid credentials ignoring') + conn.close() + return sc = SocketChild(conn, addr, poll) socket_children[sc.conn.fileno()] = sc diff --git a/kitty_tests/prewarm.py b/kitty_tests/prewarm.py index 75fcb7c8e..ef31531de 100644 --- a/kitty_tests/prewarm.py +++ b/kitty_tests/prewarm.py @@ -24,6 +24,8 @@ from . import BaseTest def socket_child_main(exit_code=0): import json import os + import sys + from kitty.fast_data_types import get_options from kitty.utils import read_screen_size output = { @@ -34,7 +36,7 @@ def socket_child_main(exit_code=0): 'done': 'hello', } - print(json.dumps(output, indent=2)) + print(json.dumps(output, indent=2), file=sys.stderr, flush=True) raise SystemExit(exit_code) # END_socket_child_main