Verify uid/gid of connection from a prewarm client

This commit is contained in:
Kovid Goyal 2022-07-05 18:20:32 +05:30
parent 7b7f1ecc54
commit de9263a117
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 43 additions and 7 deletions

View File

@ -13,6 +13,9 @@
#endif
#include "data-types.h"
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "cleanup.h"
#include "safe-wrappers.h"
#include "control-codes.h"
@ -189,8 +192,28 @@ locale_is_valid(PyObject *self UNUSED, PyObject *args) {
Py_RETURN_TRUE;
}
static PyObject*
py_getpeereid(PyObject *self UNUSED, PyObject *args) {
int fd;
if (!PyArg_ParseTuple(args, "i", &fd)) return NULL;
uid_t euid = 0; gid_t egid = 0;
#ifdef __linux__
struct ucred cr;
socklen_t sz = sizeof(cr);
if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cr, &sz) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; }
euid = cr.uid; egid = cr.gid;
#else
if (getpeereid(fd, &euid, &egid) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; }
#endif
int u = euid, g = egid;
return Py_BuildValue("ii", u, g);
}
static PyMethodDef module_methods[] = {
{"wcwidth", (PyCFunction)wcwidth_wrap, METH_O, ""},
{"getpeereid", (PyCFunction)py_getpeereid, METH_VARARGS, ""},
{"wcswidth", (PyCFunction)wcswidth_std, METH_O, ""},
{"open_tty", open_tty, METH_VARARGS, ""},
{"normal_tty", normal_tty, METH_VARARGS, ""},
@ -269,7 +292,6 @@ shift_to_first_set_bit(CellAttrs x) {
return ans;
}
EXPORTED PyMODINIT_FUNC
PyInit_fast_data_types(void) {
PyObject *m;

View File

@ -1413,3 +1413,7 @@ def install_signal_handlers(*signals: int) -> Tuple[int, int]:
def remove_signal_handlers() -> None:
pass
def getpeereid(fd: int) -> Tuple[int, int]:
pass

View File

@ -22,8 +22,9 @@ from typing import (
from kitty.constants import kitty_exe, running_in_kitty
from kitty.entry_points import main as main_entry_point
from kitty.fast_data_types import (
CLD_EXITED, CLD_KILLED, get_options, install_signal_handlers, read_signals,
remove_signal_handlers, safe_pipe, set_options
CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options,
install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe,
set_options, getpeereid
)
from kitty.options.types import Options
from kitty.shm import SharedMemory
@ -338,7 +339,6 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl
os.setsid()
tty_name = cmd.get('tty_name')
if tty_name:
from kitty.fast_data_types import establish_controlling_tty
sys.__stdout__.flush()
sys.__stderr__.flush()
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
@ -361,6 +361,11 @@ class SocketClosed(Exception):
pass
def verify_socket_creds(conn: socket.socket) -> bool:
uid, gid = getpeereid(conn.fileno())
return uid == os.geteuid() and gid == os.getegid()
class SocketChild:
def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
@ -389,8 +394,7 @@ class SocketChild:
def read(self) -> bool:
import array
fds = array.array("i") # Array of ints
maxfds = 3
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, socket.CMSG_LEN(maxfds * fds.itemsize))
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024)
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
# Append data, ignoring any truncated integers at the end.
@ -670,6 +674,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
def handle_socket_client(event: int) -> None:
check_event(event, 'UNIX socket fd listener failed')
conn, addr = unix_socket.accept()
if not verify_socket_creds(conn):
print_error('Connection attempted with invalid credentials ignoring')
conn.close()
return
sc = SocketChild(conn, addr, poll)
socket_children[sc.conn.fileno()] = sc

View File

@ -24,6 +24,8 @@ from . import BaseTest
def socket_child_main(exit_code=0):
import json
import os
import sys
from kitty.fast_data_types import get_options
from kitty.utils import read_screen_size
output = {
@ -34,7 +36,7 @@ def socket_child_main(exit_code=0):
'done': 'hello',
}
print(json.dumps(output, indent=2))
print(json.dumps(output, indent=2), file=sys.stderr, flush=True)
raise SystemExit(exit_code)
# END_socket_child_main