From 90bc3ab770a892e66dc4a6625bac832048fdf2d7 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 7 Jun 2022 14:07:39 +0530 Subject: [PATCH] Function to create a randomly named UNIX domain socket --- kitty/child-monitor.c | 17 ++++++++++++++- kitty/fast_data_types.pyi | 4 ++++ kitty/utils.py | 44 ++++++++++++++++++++++++++++++++++----- 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/kitty/child-monitor.c b/kitty/child-monitor.c index d9faef67e..e38c02d85 100644 --- a/kitty/child-monitor.c +++ b/kitty/child-monitor.c @@ -20,6 +20,7 @@ #include #include #include +#include #include extern PyTypeObject Screen_Type; @@ -1756,7 +1757,6 @@ cocoa_set_menubar_title(PyObject *self UNUSED, PyObject *args UNUSED) { } static PyObject* - send_data_to_peer(PyObject *self UNUSED, PyObject *args) { char * msg; Py_ssize_t sz; unsigned long long peer_id; @@ -1765,8 +1765,23 @@ send_data_to_peer(PyObject *self UNUSED, PyObject *args) { Py_RETURN_NONE; } +static PyObject * +random_unix_socket(PyObject *self UNUSED, PyObject *args UNUSED) { + int fd, optval = 1; + struct sockaddr_un bind_addr = {.sun_family=AF_UNIX}; + fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd < 0) return PyErr_SetFromErrno(PyExc_OSError); + if (setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &optval, sizeof optval) != 0) goto fail; + if (bind(fd, (struct sockaddr *)&bind_addr, sizeof(sa_family_t)) != 0) goto fail; + return PyLong_FromLong((long)fd); +fail: + safe_close(fd, __FILE__, __LINE__); + return PyErr_SetFromErrno(PyExc_OSError); +} + static PyMethodDef module_methods[] = { METHODB(safe_pipe, METH_VARARGS), + METHODB(random_unix_socket, METH_NOARGS), {"add_timer", (PyCFunction)add_python_timer, METH_VARARGS, ""}, {"remove_timer", (PyCFunction)remove_python_timer, METH_VARARGS, ""}, METHODB(monitor_pid, METH_VARARGS), diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index 6418a44fe..93c229951 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -1387,3 +1387,7 @@ def sigqueue(pid: int, signal: int, value: int) -> None: def establish_controlling_tty(ttyname: str, stdin: int, stdout: int, stderr: int) -> None: pass + + +def random_unix_socket() -> int: + pass diff --git a/kitty/utils.py b/kitty/utils.py index 2cbb34568..7280d91be 100644 --- a/kitty/utils.py +++ b/kitty/utils.py @@ -372,15 +372,18 @@ class startup_notification_handler: end_startup_notification(self.ctx) -def remove_socket_file(s: 'Socket', path: Optional[str] = None) -> None: +def remove_socket_file(s: 'Socket', path: Optional[str] = None, is_dir: Optional[Callable[[str], None]] = None) -> None: with suppress(OSError): s.close() if path: with suppress(OSError): - os.unlink(path) + if is_dir: + is_dir(path) + else: + os.unlink(path) -def unix_socket_paths(name: str, ext: str = '.lock') -> Generator[str, None, None]: +def unix_socket_directories() -> Iterator[str]: import tempfile home = os.path.expanduser('~') candidates = [tempfile.gettempdir(), home] @@ -389,8 +392,39 @@ def unix_socket_paths(name: str, ext: str = '.lock') -> Generator[str, None, Non candidates = [user_cache_dir(), '/Library/Caches'] for loc in candidates: if os.access(loc, os.W_OK | os.R_OK | os.X_OK): - filename = ('.' if loc == home else '') + name + ext - yield os.path.join(loc, filename) + yield loc + + +def unix_socket_paths(name: str, ext: str = '.lock') -> Generator[str, None, None]: + home = os.path.expanduser('~') + for loc in unix_socket_directories(): + filename = ('.' if loc == home else '') + name + ext + yield os.path.join(loc, filename) + + +def random_unix_socket() -> 'Socket': + import shutil + import socket + import stat + import tempfile + + from kitty.fast_data_types import random_unix_socket as rus + try: + fd = rus() + except OSError: + for path in unix_socket_directories(): + ans = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0) + tdir = tempfile.mkdtemp(prefix='.kitty-', dir=path) + atexit.register(remove_socket_file, ans, tdir, shutil.rmtree) + path = os.path.join(tdir, 's') + ans.bind(path) + os.chmod(path, stat.S_IRUSR | stat.S_IWUSR) + break + else: + ans = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, fileno=fd) + ans.set_inheritable(False) + ans.setblocking(False) + return ans def single_instance_unix(name: str) -> bool: