From 6e77345263e153dca240e53e1abeafe640f8ac10 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 21 Jun 2021 04:35:10 +0530 Subject: [PATCH] Ensure sys.path is preserved even if there are errors importing a custom kitten --- kittens/runner.py | 37 +++++++++++++++++++++++-------------- kitty/conf/generate.py | 6 ++++++ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/kittens/runner.py b/kittens/runner.py index 12c9b6e9e..1df26c473 100644 --- a/kittens/runner.py +++ b/kittens/runner.py @@ -6,8 +6,9 @@ import importlib import os import sys +from contextlib import contextmanager from functools import partial -from typing import Any, Dict, FrozenSet, List, TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, cast from kitty.types import run_once @@ -30,21 +31,29 @@ def path_to_custom_kitten(config_dir: str, kitten: str) -> str: return path +@contextmanager +def preserve_sys_path() -> Generator[None, None, None]: + orig = sys.path[:] + try: + yield + finally: + if sys.path != orig: + del sys.path[:] + sys.path.extend(orig) + + def import_kitten_main_module(config_dir: str, kitten: str) -> Dict[str, Any]: if kitten.endswith('.py'): - path_modified = False - path = path_to_custom_kitten(config_dir, kitten) - if os.path.dirname(path): - sys.path.insert(0, os.path.dirname(path)) - path_modified = True - with open(path) as f: - src = f.read() - code = compile(src, path, 'exec') - g = {'__name__': 'kitten'} - exec(code, g) - hr = g.get('handle_result', lambda *a, **kw: None) - if path_modified: - del sys.path[0] + with preserve_sys_path(): + path = path_to_custom_kitten(config_dir, kitten) + if os.path.dirname(path): + sys.path.insert(0, os.path.dirname(path)) + with open(path) as f: + src = f.read() + code = compile(src, path, 'exec') + g = {'__name__': 'kitten'} + exec(code, g) + hr = g.get('handle_result', lambda *a, **kw: None) return {'start': g['main'], 'end': hr} kitten = resolved_kitten(kitten) diff --git a/kitty/conf/generate.py b/kitty/conf/generate.py index fcbb4d53a..029ce0237 100644 --- a/kitty/conf/generate.py +++ b/kitty/conf/generate.py @@ -423,3 +423,9 @@ def write_output(loc: str, defn: Definition) -> None: c = generate_c_conversion(loc, ctypes) with open(os.path.join(*loc.split('.'), 'options', 'to-c-generated.h'), 'w') as f: f.write(c + '\n') + + +def main() -> None: + import sys + kitten = sys.argv[-1] + if not kitten.endswith(