Use an enum for state as well

This commit is contained in:
Kovid Goyal 2022-01-18 13:58:55 +05:30
parent f1fbfe297d
commit f9a4b6bb0d
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -61,7 +61,12 @@ except ImportError:
yield -1 yield -1
INITIALIZING, COLLECTED, DIFFED, COMMAND, MESSAGE = range(5) class State(Enum):
initializing = auto()
collected = auto()
diffed = auto()
command = auto()
message = auto()
class BackgroundWork(Enum): class BackgroundWork(Enum):
@ -89,7 +94,7 @@ class DiffHandler(Handler):
image_manager_class = ImageManager image_manager_class = ImageManager
def __init__(self, args: DiffCLIOptions, opts: DiffOptions, left: str, right: str) -> None: def __init__(self, args: DiffCLIOptions, opts: DiffOptions, left: str, right: str) -> None:
self.state = INITIALIZING self.state = State.initializing
self.message = '' self.message = ''
self.current_search_is_regex = True self.current_search_is_regex = True
self.current_search: Optional[Search] = None self.current_search: Optional[Search] = None
@ -116,7 +121,7 @@ class DiffHandler(Handler):
if func == 'quit': if func == 'quit':
self.terminate() self.terminate()
return return
if self.state <= DIFFED: if self.state.value <= State.diffed.value:
if func == 'scroll_by': if func == 'scroll_by':
return self.scroll_lines(int(args[0] or 0)) return self.scroll_lines(int(args[0] or 0))
if func == 'scroll_to': if func == 'scroll_to':
@ -149,7 +154,7 @@ class DiffHandler(Handler):
def collect_done(collection: Collection) -> None: def collect_done(collection: Collection) -> None:
self.doing_background_work = BackgroundWork.none self.doing_background_work = BackgroundWork.none
self.collection = collection self.collection = collection
self.state = COLLECTED self.state = State.collected
self.generate_diff() self.generate_diff()
def collect(left: str, right: str) -> None: def collect(left: str, right: str) -> None:
@ -167,7 +172,7 @@ class DiffHandler(Handler):
self.report_traceback_on_exit = diff_map self.report_traceback_on_exit = diff_map
self.terminate(1) self.terminate(1)
return return
self.state = DIFFED self.state = State.diffed
self.diff_map = diff_map self.diff_map = diff_map
self.calculate_statistics() self.calculate_statistics()
self.render_diff() self.render_diff()
@ -327,7 +332,7 @@ class DiffHandler(Handler):
self.create_collection() self.create_collection()
def enforce_cursor_state(self) -> None: def enforce_cursor_state(self) -> None:
self.cmd.set_cursor_visible(self.state == COMMAND) self.cmd.set_cursor_visible(self.state is State.command)
def draw_lines(self, num: int, offset: int = 0) -> None: def draw_lines(self, num: int, offset: int = 0) -> None:
offset += self.scroll_pos offset += self.scroll_pos
@ -430,7 +435,7 @@ class DiffHandler(Handler):
def draw_screen(self) -> None: def draw_screen(self) -> None:
self.enforce_cursor_state() self.enforce_cursor_state()
if self.state < DIFFED: if self.state.value < State.diffed.value:
self.cmd.clear_screen() self.cmd.clear_screen()
self.write(_('Calculating diff, please wait...')) self.write(_('Calculating diff, please wait...'))
return return
@ -440,14 +445,14 @@ class DiffHandler(Handler):
self.draw_status_line() self.draw_status_line()
def draw_status_line(self) -> None: def draw_status_line(self) -> None:
if self.state < DIFFED: if self.state.value < State.diffed.value:
return return
self.enforce_cursor_state() self.enforce_cursor_state()
self.cmd.set_cursor_position(0, self.num_lines) self.cmd.set_cursor_position(0, self.num_lines)
self.cmd.clear_to_eol() self.cmd.clear_to_eol()
if self.state is COMMAND: if self.state is State.command:
self.line_edit.write(self.write) self.line_edit.write(self.write)
elif self.state is MESSAGE: elif self.state is State.message:
self.cmd.styled(self.message, reverse=True) self.cmd.styled(self.message, reverse=True)
else: else:
sp = f'{self.scroll_pos/self.max_scroll_pos:.0%}' if self.scroll_pos and self.max_scroll_pos else '0%' sp = f'{self.scroll_pos/self.max_scroll_pos:.0%}' if self.scroll_pos and self.max_scroll_pos else '0%'
@ -470,16 +475,16 @@ class DiffHandler(Handler):
new_ctx = max(0, new_ctx) new_ctx = max(0, new_ctx)
if new_ctx != self.current_context_count: if new_ctx != self.current_context_count:
self.current_context_count = new_ctx self.current_context_count = new_ctx
self.state = COLLECTED self.state = State.collected
self.generate_diff() self.generate_diff()
self.restore_position = self.current_position self.restore_position = self.current_position
self.draw_screen() self.draw_screen()
def start_search(self, is_regex: bool, is_backward: bool) -> None: def start_search(self, is_regex: bool, is_backward: bool) -> None:
if self.state != DIFFED: if self.state is not State.diffed:
self.cmd.bell() self.cmd.bell()
return return
self.state = COMMAND self.state = State.command
self.line_edit.clear() self.line_edit.clear()
self.line_edit.add_text('?' if is_backward else '/') self.line_edit.add_text('?' if is_backward else '/')
self.current_search_is_regex = is_regex self.current_search_is_regex = is_regex
@ -493,50 +498,50 @@ class DiffHandler(Handler):
try: try:
self.current_search = Search(self.opts, query[1:], self.current_search_is_regex, query[0] == '?') self.current_search = Search(self.opts, query[1:], self.current_search_is_regex, query[0] == '?')
except BadRegex: except BadRegex:
self.state = MESSAGE self.state = State.message
self.message = sanitize(_('Bad regex: {}').format(query[1:])) self.message = sanitize(_('Bad regex: {}').format(query[1:]))
self.cmd.bell() self.cmd.bell()
else: else:
if self.current_search(self.diff_lines, self.margin_size, self.screen_size.cols): if self.current_search(self.diff_lines, self.margin_size, self.screen_size.cols):
self.scroll_to_next_match(include_current=True) self.scroll_to_next_match(include_current=True)
else: else:
self.state = MESSAGE self.state = State.message
self.message = sanitize(_('No matches found')) self.message = sanitize(_('No matches found'))
self.cmd.bell() self.cmd.bell()
def on_key_event(self, key_event: KeyEvent, in_bracketed_paste: bool = False) -> None: def on_key_event(self, key_event: KeyEvent, in_bracketed_paste: bool = False) -> None:
if key_event.text: if key_event.text:
if self.state is COMMAND: if self.state is State.command:
self.line_edit.on_text(key_event.text, in_bracketed_paste) self.line_edit.on_text(key_event.text, in_bracketed_paste)
self.draw_status_line() self.draw_status_line()
return return
if self.state is MESSAGE: if self.state is State.message:
self.state = DIFFED self.state = State.diffed
self.draw_status_line() self.draw_status_line()
return return
else: else:
if self.state is MESSAGE: if self.state is State.message:
if key_event.type is not EventType.RELEASE: if key_event.type is not EventType.RELEASE:
self.state = DIFFED self.state = State.diffed
self.draw_status_line() self.draw_status_line()
return return
if self.state is COMMAND: if self.state is State.command:
if self.line_edit.on_key(key_event): if self.line_edit.on_key(key_event):
if not self.line_edit.current_input: if not self.line_edit.current_input:
self.state = DIFFED self.state = State.diffed
self.draw_status_line() self.draw_status_line()
return return
if key_event.matches('enter'): if key_event.matches('enter'):
self.state = DIFFED self.state = State.diffed
self.do_search() self.do_search()
self.line_edit.clear() self.line_edit.clear()
self.draw_screen() self.draw_screen()
return return
if key_event.matches('esc'): if key_event.matches('esc'):
self.state = DIFFED self.state = State.diffed
self.draw_status_line() self.draw_status_line()
return return
if self.state >= DIFFED and self.current_search is not None and key_event.matches('esc'): if self.state.value >= State.diffed.value and self.current_search is not None and key_event.matches('esc'):
self.current_search = None self.current_search = None
self.draw_screen() self.draw_screen()
return return
@ -549,7 +554,7 @@ class DiffHandler(Handler):
def on_resize(self, screen_size: ScreenSize) -> None: def on_resize(self, screen_size: ScreenSize) -> None:
self.screen_size = screen_size self.screen_size = screen_size
self.set_scrolling_region() self.set_scrolling_region()
if self.state > COLLECTED: if self.state.value > State.collected.value:
self.image_manager.delete_all_sent_images() self.image_manager.delete_all_sent_images()
self.render_diff() self.render_diff()
self.draw_screen() self.draw_screen()