diff --git a/setup.py b/setup.py index 292d9305b..84c4db9bb 100755 --- a/setup.py +++ b/setup.py @@ -19,8 +19,8 @@ from contextlib import suppress from functools import lru_cache, partial from pathlib import Path from typing import ( - Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional, Sequence, - Set, Tuple, Union + Callable, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, + Union ) from glfw import glfw @@ -500,20 +500,14 @@ def get_vcs_rev_defines(env: Env, src: str) -> List[str]: return ans -def get_library_defines(env: Env, src: str) -> Optional[List[str]]: - try: - return env.library_paths[src] - except KeyError: - return None - - -SPECIAL_SOURCES: Dict[str, Tuple[str, Union[List[str], Callable[[Env, str], Union[Optional[List[str]], Iterator[str]]]]]] = { - 'glfw/egl_context.c': ('glfw/egl_context.c', get_library_defines), - 'kitty/desktop.c': ('kitty/desktop.c', get_library_defines), - 'kitty/fontconfig.c': ('kitty/fontconfig.c', get_library_defines), - 'kitty/parser_dump.c': ('kitty/parser.c', ['DUMP_COMMANDS']), - 'kitty/data-types.c': ('kitty/data-types.c', get_vcs_rev_defines), -} +def get_source_specific_defines(env: Env, src: str) -> Tuple[str, Optional[List[str]]]: + if src == 'kitty/parser_dump.c': + return 'kitty/parser.c', ['DUMP_COMMANDS'] + if src == 'kitty/data-types.c': + return src, get_vcs_rev_defines(env, src) + with suppress(KeyError): + return src, env.library_paths[src] + return src, None def newer(dest: str, *sources: str) -> bool: @@ -699,13 +693,9 @@ def compile_c_extension( for original_src, dest in zip(sources, objects): src = original_src cppflags = kenv.cppflags[:] - is_special = src in SPECIAL_SOURCES - if is_special: - src, defines_ = SPECIAL_SOURCES[src] - defines = defines_(kenv, src) if callable(defines_) else defines_ - if defines is not None: - cppflags.extend(map(define, defines)) - + src, defines = get_source_specific_defines(kenv, src) + if defines is not None: + cppflags.extend(map(define, defines)) cmd = kenv.cc + ['-MMD'] + cppflags + kenv.cflags cmd += ['-c', src] + ['-o', dest] key = CompileKey(original_src, os.path.basename(dest))