Verify uid/gid of connection from a prewarm client
This commit is contained in:
parent
7b7f1ecc54
commit
de9263a117
@ -13,6 +13,9 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "data-types.h"
|
#include "data-types.h"
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <unistd.h>
|
||||||
#include "cleanup.h"
|
#include "cleanup.h"
|
||||||
#include "safe-wrappers.h"
|
#include "safe-wrappers.h"
|
||||||
#include "control-codes.h"
|
#include "control-codes.h"
|
||||||
@ -189,8 +192,28 @@ locale_is_valid(PyObject *self UNUSED, PyObject *args) {
|
|||||||
Py_RETURN_TRUE;
|
Py_RETURN_TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject*
|
||||||
|
py_getpeereid(PyObject *self UNUSED, PyObject *args) {
|
||||||
|
int fd;
|
||||||
|
if (!PyArg_ParseTuple(args, "i", &fd)) return NULL;
|
||||||
|
uid_t euid = 0; gid_t egid = 0;
|
||||||
|
#ifdef __linux__
|
||||||
|
struct ucred cr;
|
||||||
|
socklen_t sz = sizeof(cr);
|
||||||
|
if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cr, &sz) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; }
|
||||||
|
euid = cr.uid; egid = cr.gid;
|
||||||
|
#else
|
||||||
|
if (getpeereid(fd, &euid, &egid) != 0) { PyErr_SetFromErrno(PyExc_OSError); return NULL; }
|
||||||
|
#endif
|
||||||
|
int u = euid, g = egid;
|
||||||
|
return Py_BuildValue("ii", u, g);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static PyMethodDef module_methods[] = {
|
static PyMethodDef module_methods[] = {
|
||||||
{"wcwidth", (PyCFunction)wcwidth_wrap, METH_O, ""},
|
{"wcwidth", (PyCFunction)wcwidth_wrap, METH_O, ""},
|
||||||
|
{"getpeereid", (PyCFunction)py_getpeereid, METH_VARARGS, ""},
|
||||||
{"wcswidth", (PyCFunction)wcswidth_std, METH_O, ""},
|
{"wcswidth", (PyCFunction)wcswidth_std, METH_O, ""},
|
||||||
{"open_tty", open_tty, METH_VARARGS, ""},
|
{"open_tty", open_tty, METH_VARARGS, ""},
|
||||||
{"normal_tty", normal_tty, METH_VARARGS, ""},
|
{"normal_tty", normal_tty, METH_VARARGS, ""},
|
||||||
@ -269,7 +292,6 @@ shift_to_first_set_bit(CellAttrs x) {
|
|||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
EXPORTED PyMODINIT_FUNC
|
EXPORTED PyMODINIT_FUNC
|
||||||
PyInit_fast_data_types(void) {
|
PyInit_fast_data_types(void) {
|
||||||
PyObject *m;
|
PyObject *m;
|
||||||
|
|||||||
@ -1413,3 +1413,7 @@ def install_signal_handlers(*signals: int) -> Tuple[int, int]:
|
|||||||
|
|
||||||
def remove_signal_handlers() -> None:
|
def remove_signal_handlers() -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def getpeereid(fd: int) -> Tuple[int, int]:
|
||||||
|
pass
|
||||||
|
|||||||
@ -22,8 +22,9 @@ from typing import (
|
|||||||
from kitty.constants import kitty_exe, running_in_kitty
|
from kitty.constants import kitty_exe, running_in_kitty
|
||||||
from kitty.entry_points import main as main_entry_point
|
from kitty.entry_points import main as main_entry_point
|
||||||
from kitty.fast_data_types import (
|
from kitty.fast_data_types import (
|
||||||
CLD_EXITED, CLD_KILLED, get_options, install_signal_handlers, read_signals,
|
CLD_EXITED, CLD_KILLED, establish_controlling_tty, get_options,
|
||||||
remove_signal_handlers, safe_pipe, set_options
|
install_signal_handlers, read_signals, remove_signal_handlers, safe_pipe,
|
||||||
|
set_options, getpeereid
|
||||||
)
|
)
|
||||||
from kitty.options.types import Options
|
from kitty.options.types import Options
|
||||||
from kitty.shm import SharedMemory
|
from kitty.shm import SharedMemory
|
||||||
@ -338,7 +339,6 @@ def fork(shm_address: str, free_non_child_resources: Callable[[], None]) -> Tupl
|
|||||||
os.setsid()
|
os.setsid()
|
||||||
tty_name = cmd.get('tty_name')
|
tty_name = cmd.get('tty_name')
|
||||||
if tty_name:
|
if tty_name:
|
||||||
from kitty.fast_data_types import establish_controlling_tty
|
|
||||||
sys.__stdout__.flush()
|
sys.__stdout__.flush()
|
||||||
sys.__stderr__.flush()
|
sys.__stderr__.flush()
|
||||||
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
establish_controlling_tty(tty_name, sys.__stdin__.fileno(), sys.__stdout__.fileno(), sys.__stderr__.fileno())
|
||||||
@ -361,6 +361,11 @@ class SocketClosed(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def verify_socket_creds(conn: socket.socket) -> bool:
|
||||||
|
uid, gid = getpeereid(conn.fileno())
|
||||||
|
return uid == os.geteuid() and gid == os.getegid()
|
||||||
|
|
||||||
|
|
||||||
class SocketChild:
|
class SocketChild:
|
||||||
|
|
||||||
def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
|
def __init__(self, conn: socket.socket, addr: bytes, poll: select.poll):
|
||||||
@ -389,8 +394,7 @@ class SocketChild:
|
|||||||
def read(self) -> bool:
|
def read(self) -> bool:
|
||||||
import array
|
import array
|
||||||
fds = array.array("i") # Array of ints
|
fds = array.array("i") # Array of ints
|
||||||
maxfds = 3
|
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, 1024)
|
||||||
msg, ancdata, flags, addr = self.conn.recvmsg(io.DEFAULT_BUFFER_SIZE, socket.CMSG_LEN(maxfds * fds.itemsize))
|
|
||||||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||||||
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
|
if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
|
||||||
# Append data, ignoring any truncated integers at the end.
|
# Append data, ignoring any truncated integers at the end.
|
||||||
@ -670,6 +674,10 @@ def main(stdin_fd: int, stdout_fd: int, notify_child_death_fd: int, unix_socket:
|
|||||||
def handle_socket_client(event: int) -> None:
|
def handle_socket_client(event: int) -> None:
|
||||||
check_event(event, 'UNIX socket fd listener failed')
|
check_event(event, 'UNIX socket fd listener failed')
|
||||||
conn, addr = unix_socket.accept()
|
conn, addr = unix_socket.accept()
|
||||||
|
if not verify_socket_creds(conn):
|
||||||
|
print_error('Connection attempted with invalid credentials ignoring')
|
||||||
|
conn.close()
|
||||||
|
return
|
||||||
sc = SocketChild(conn, addr, poll)
|
sc = SocketChild(conn, addr, poll)
|
||||||
socket_children[sc.conn.fileno()] = sc
|
socket_children[sc.conn.fileno()] = sc
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,8 @@ from . import BaseTest
|
|||||||
def socket_child_main(exit_code=0):
|
def socket_child_main(exit_code=0):
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
from kitty.fast_data_types import get_options
|
from kitty.fast_data_types import get_options
|
||||||
from kitty.utils import read_screen_size
|
from kitty.utils import read_screen_size
|
||||||
output = {
|
output = {
|
||||||
@ -34,7 +36,7 @@ def socket_child_main(exit_code=0):
|
|||||||
|
|
||||||
'done': 'hello',
|
'done': 'hello',
|
||||||
}
|
}
|
||||||
print(json.dumps(output, indent=2))
|
print(json.dumps(output, indent=2), file=sys.stderr, flush=True)
|
||||||
raise SystemExit(exit_code)
|
raise SystemExit(exit_code)
|
||||||
|
|
||||||
# END_socket_child_main
|
# END_socket_child_main
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user