Verify uid/gid of connection from a prewarm client
This commit is contained in:
parent
7b7f1ecc54
commit
de9263a117
@ -13,6 +13,9 @@
|
||||
#endif
|
||||
|
||||
#include "data-types.h"
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include "cleanup.h"
|
||||
#include "safe-wrappers.h"
|
||||
#include "control-codes.h"
|
||||
@ -189,8 +192,28 @@ locale_is_valid(PyObject *self UNUSED, PyObject *args) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
py_getpeereid(PyObject *self UNUSED, PyObject *args) {
|
||||
int fd;
|
||||
if (!PyArg_ParseTuple(args, "i", &fd)) return NULL;
|
||||
uid_t euid = 0; gid_t egid = 0;
|
||||
#ifdef __linux__
|
||||
struct ucred cr;
|
||||
socklen_t sz = sizeof(cr);
|
||||
if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cr, &sz) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; }
|
||||
euid = cr.uid; egid = cr.gid;
|
||||
#else
|
||||
if (getpeereid(fd, &euid, &egid) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; }
|
||||
#endif
|
||||
int u = euid, g = egid;
|
||||
return Py_BuildValue("ii", u, g);
|
||||
}
|
||||
|
||||
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"wcwidth", (PyCFunction)wcwidth_wrap, METH_O, ""},
|
||||
{"getpeereid", (PyCFunction)py_getpeereid, METH_VARARGS, ""},
|
||||
{"wcswidth", (PyCFunction)wcswidth_std, METH_O, ""},
|
||||
{"open_tty", open_tty, METH_VARARGS, ""},
|
||||
{"normal_tty", normal_tty, METH_VARARGS, ""},
|
||||
@ -269,7 +292,6 @@ shift_to_first_set_bit(CellAttrs x) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
EXPORTED PyMODINIT_FUNC
|
||||
PyInit_fast_data_types(void) {
|
||||
PyObject *m;
|
||||
|
||||
@ -1413,3 +1413,7 @@ def install_signal_handlers(*signals: int) -> Tuple[int, int]:
|
||||
|
||||
def remove_signal_handlers() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def getpeereid(fd: int) -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
@ -22,8 +22,9 @@ from typing import (
|
||||
from kitty.constants import kitty_exe, running_in_kitty
|
||||
from kitty.entry_points import main as main_entry_point
|
||||
from kitty.fast_data_types import (
|
||||
CLD_EXITED, CLD_KILLED, get_options, install_signal_handlers, read_signals,
|
||||
remove_signal_handlers, safe_pipe, set_options
|
||||
CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options,
|
||||
install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe,
|
||||
set_options, getpeereid
|
||||
)
|
||||
from kitty.options.types import Options
|
||||
from kitty.shm import SharedMemory
|
||||
@ -338,7 +339,6 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl
|
||||
os.setsid()
|
||||
tty_name = cmd.get('tty_name')
|
||||
if tty_name:
|
||||
from kitty.fast_data_types import establish_controlling_tty
|
||||
sys.__stdout__.flush()
|
||||
sys.__stderr__.flush()
|
||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||
@ -361,6 +361,11 @@ class SocketClosed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def verify_socket_creds(conn: socket.socket) -> bool:
|
||||
uid, gid = getpeereid(conn.fileno())
|
||||
return uid == os.geteuid() and gid == os.getegid()
|
||||
|
||||
|
||||
class SocketChild:
|
||||
|
||||
def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
|
||||
@ -389,8 +394,7 @@ class SocketChild:
|
||||
def read(self) -> bool:
|
||||
import array
|
||||
fds = array.array("i") # Array of ints
|
||||
maxfds = 3
|
||||
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, socket.CMSG_LEN(maxfds * fds.itemsize))
|
||||
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024)
|
||||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||||
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
|
||||
# Append data, ignoring any truncated integers at the end.
|
||||
@ -670,6 +674,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
||||
def handle_socket_client(event: int) -> None:
|
||||
check_event(event, 'UNIX socket fd listener failed')
|
||||
conn, addr = unix_socket.accept()
|
||||
if not verify_socket_creds(conn):
|
||||
print_error('Connection attempted with invalid credentials ignoring')
|
||||
conn.close()
|
||||
return
|
||||
sc = SocketChild(conn, addr, poll)
|
||||
socket_children[sc.conn.fileno()] = sc
|
||||
|
||||
|
||||
@ -24,6 +24,8 @@ from . import BaseTest
|
||||
def socket_child_main(exit_code=0):
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from kitty.fast_data_types import get_options
|
||||
from kitty.utils import read_screen_size
|
||||
output = {
|
||||
@ -34,7 +36,7 @@ def socket_child_main(exit_code=0):
|
||||
|
||||
'done': 'hello',
|
||||
}
|
||||
print(json.dumps(output, indent=2))
|
||||
print(json.dumps(output, indent=2), file=sys.stderr, flush=True)
|
||||
raise SystemExit(exit_code)
|
||||
|
||||
# END_socket_child_main
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user