Ensure sys.path is preserved even if there are errors importing a custom kitten

This commit is contained in:
Kovid Goyal 2021-06-21 04:35:10 +05:30
parent 1438c64b9e
commit 6e77345263
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 29 additions and 14 deletions

View File

@ -6,8 +6,9 @@
import importlib import importlib
import os import os
import sys import sys
from contextlib import contextmanager
from functools import partial from functools import partial
from typing import Any, Dict, FrozenSet, List, TYPE_CHECKING, cast from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Generator, List, cast
from kitty.types import run_once from kitty.types import run_once
@ -30,21 +31,29 @@ def path_to_custom_kitten(config_dir: str, kitten: str) -> str:
return path return path
@contextmanager
def preserve_sys_path() -> Generator[None, None, None]:
orig = sys.path[:]
try:
yield
finally:
if sys.path != orig:
del sys.path[:]
sys.path.extend(orig)
def import_kitten_main_module(config_dir: str, kitten: str) -> Dict[str, Any]: def import_kitten_main_module(config_dir: str, kitten: str) -> Dict[str, Any]:
if kitten.endswith('.py'): if kitten.endswith('.py'):
path_modified = False with preserve_sys_path():
path = path_to_custom_kitten(config_dir, kitten) path = path_to_custom_kitten(config_dir, kitten)
if os.path.dirname(path): if os.path.dirname(path):
sys.path.insert(0, os.path.dirname(path)) sys.path.insert(0, os.path.dirname(path))
path_modified = True with open(path) as f:
with open(path) as f: src = f.read()
src = f.read() code = compile(src, path, 'exec')
code = compile(src, path, 'exec') g = {'__name__': 'kitten'}
g = {'__name__': 'kitten'} exec(code, g)
exec(code, g) hr = g.get('handle_result', lambda *a, **kw: None)
hr = g.get('handle_result', lambda *a, **kw: None)
if path_modified:
del sys.path[0]
return {'start': g['main'], 'end': hr} return {'start': g['main'], 'end': hr}
kitten = resolved_kitten(kitten) kitten = resolved_kitten(kitten)

View File

@ -423,3 +423,9 @@ def write_output(loc: str, defn: Definition) -> None:
c = generate_c_conversion(loc, ctypes) c = generate_c_conversion(loc, ctypes)
with open(os.path.join(*loc.split('.'), 'options', 'to-c-generated.h'), 'w') as f: with open(os.path.join(*loc.split('.'), 'options', 'to-c-generated.h'), 'w') as f:
f.write(c + '\n') f.write(c + '\n')
def main() -> None:
import sys
kitten = sys.argv[-1]
if not kitten.endswith(