Patch summary: replace the hard-coded '\n' output delimiter of the choose
kitten's subseq_matcher with a configurable Options.delimiter, add a Python
match() wrapper in kittens/choose/main.py and add kitty_tests/choose.py.

diff --git a/kittens/choose/choose-data-types.h b/kittens/choose/choose-data-types.h
index 4f67f5941..590245a73 100644
--- a/kittens/choose/choose-data-types.h
+++ b/kittens/choose/choose-data-types.h
@@ -34,7 +34,6 @@ typedef uint8_t len_t;
 typedef uint32_t text_t;
 
 #define LEN_MAX UINT8_MAX
-#define UNUSED(x) (void)(x)
 #define UTF8_ACCEPT 0
 #define UTF8_REJECT 1
 #define IS_LOWERCASE(x) (x) >= 'a' && (x) <= 'z'
@@ -66,8 +65,8 @@ typedef struct {
     bool output_positions;
     size_t limit;
     int num_threads;
-    text_t mark_before[128], mark_after[128];
-    size_t mark_before_sz, mark_after_sz;
+    text_t mark_before[128], mark_after[128], delimiter[128];
+    size_t mark_before_sz, mark_after_sz, delimiter_sz;
 } Options;
 
 VECTOR_OF(len_t, Positions)
@@ -75,7 +74,7 @@ VECTOR_OF(text_t, Chars)
 VECTOR_OF(Candidate, Candidates)
 
 
-void output_results(GlobalData *, Candidate *haystack, size_t count, Options *opts, len_t needle_len, text_t delim);
+void output_results(GlobalData *, Candidate *haystack, size_t count, Options *opts, len_t needle_len);
 void* alloc_workspace(len_t max_haystack_len, GlobalData*);
 void* free_workspace(void *v);
 double score_item(void *v, text_t *haystack, len_t haystack_len, len_t *match_positions);
diff --git a/kittens/choose/main.c b/kittens/choose/main.c
index ca4907a00..ce7de2556 100644
--- a/kittens/choose/main.c
+++ b/kittens/choose/main.c
@@ -5,6 +5,7 @@
  * Distributed under terms of the GPL3 license.
  */
 
+#include "data-types.h"
 #include "choose-data-types.h"
 #include "charsets.h"
 
@@ -160,7 +161,7 @@ run_search(Options *opts, GlobalData *global, const char * const *lines, const s
         global->haystack = haystack;
         global->haystack_count = SIZE(candidates);
         ret = run_threaded(opts->num_threads, global);
-        if (ret == 0) output_results(global, haystack, SIZE(candidates), opts, global->needle_len, '\n');
+        if (ret == 0) output_results(global, haystack, SIZE(candidates), opts, global->needle_len);
         else { REPORT_OOM; }
     } else { ret = 1; REPORT_OOM; }
 
@@ -185,13 +186,13 @@ match(PyObject *self, PyObject *args) {
     (void)(self);
     int output_positions;
     unsigned long limit;
-    PyObject *lines, *levels, *needle, *mark_before, *mark_after;
+    PyObject *lines, *levels, *needle, *mark_before, *mark_after, *delimiter;
     Options opts = {0};
     GlobalData global = {0};
-    if (!PyArg_ParseTuple(args, "O!O!O!pkiO!O!",
-            &lines, &PyList_Type, &levels, &PyTuple_Type, &needle, &PyUnicode_Type,
+    if (!PyArg_ParseTuple(args, "O!O!UpkiUUU",
+            &PyList_Type, &lines, &PyTuple_Type, &levels, &needle,
             &output_positions, &limit, &opts.num_threads,
-            &mark_before, &PyUnicode_Type, &mark_after, &PyUnicode_Type
+            &mark_before, &mark_after, &delimiter
             )) return NULL;
     opts.output_positions = output_positions ? true : false;
     opts.limit = limit;
@@ -201,13 +202,14 @@ match(PyObject *self, PyObject *args) {
     global.needle_len = copy_unicode_object(needle, global.needle, arraysz(global.needle));
     opts.mark_before_sz = copy_unicode_object(mark_before, opts.mark_before, arraysz(opts.mark_before));
     opts.mark_after_sz = copy_unicode_object(mark_after, opts.mark_after, arraysz(opts.mark_after));
+    opts.delimiter_sz = copy_unicode_object(delimiter, opts.delimiter, arraysz(opts.delimiter));
     size_t num_lines = PyList_GET_SIZE(lines);
     char **clines = malloc(sizeof(char*) * num_lines);
     size_t *sizes = malloc(sizeof(size_t) * num_lines);
-    if (!lines || !sizes) { PyErr_NoMemory(); return NULL; }
+    if (!clines || !sizes) { free(clines); free(sizes); return PyErr_NoMemory(); }  // test the malloc results, not the list; release whichever allocation succeeded
     for (size_t i = 0; i < num_lines; i++) {
-        clines[i] = PyBytes_AS_STRING(PyTuple_GET_ITEM(lines, i));
-        sizes[i] = PyBytes_GET_SIZE(PyTuple_GET_ITEM(lines, i));
+        clines[i] = PyBytes_AS_STRING(PyList_GET_ITEM(lines, i));
+        sizes[i] = PyBytes_GET_SIZE(PyList_GET_ITEM(lines, i));
     }
     Py_BEGIN_ALLOW_THREADS;
     run_search(&opts, &global, (const char* const *)clines, sizes, num_lines);
@@ -235,11 +237,7 @@ static struct PyModuleDef module = {
     .m_methods = module_methods
 };
 
-PyMODINIT_FUNC
+EXPORTED PyMODINIT_FUNC
 PyInit_subseq_matcher(void) {
-    PyObject *m;
-
-    m = PyModule_Create(&module);
-    if (m == NULL) return NULL;
-    return m;
+    return PyModule_Create(&module);
 }
diff --git a/kittens/choose/main.py b/kittens/choose/main.py
index d6067d728..5a104c76f 100644
--- a/kittens/choose/main.py
+++ b/kittens/choose/main.py
@@ -8,6 +8,40 @@ import sys
 from ..tui.handler import Handler
 from ..tui.loop import Loop
 
+from . import subseq_matcher
+
+
+def match(
+    input_data,
+    query,
+    threads=0,
+    positions=False,
+    level1='/',
+    level2='-_0123456789',
+    level3='.',
+    limit=0,
+    mark_before='',
+    mark_after='',
+    delimiter='\n'
+):
+    if isinstance(input_data, str):
+        input_data = input_data.encode('utf-8')
+    if isinstance(input_data, bytes):
+        input_data = input_data.split(delimiter.encode('utf-8'))
+    else:
+        input_data = [x.encode('utf-8') if isinstance(x, str) else x for x in input_data]
+    query = query.lower()
+    level1 = level1.lower()
+    level2 = level2.lower()
+    level3 = level3.lower()
+    data = subseq_matcher.match(
+        input_data, (level1, level2, level3), query,
+        positions, limit, threads,
+        mark_before, mark_after, delimiter)
+    if data is None:
+        return []
+    return list(filter(None, data.split(delimiter or '\n')))
+
 
 class ChooseHandler(Handler):
 
diff --git a/kittens/choose/output.c b/kittens/choose/output.c
index 2f8fed166..eac45383b 100644
--- a/kittens/choose/output.c
+++ b/kittens/choose/output.c
@@ -75,31 +75,31 @@ output_positions(GlobalData *global, len_t *positions, len_t num) {
         int num = swprintf(buf, sizeof(buf)/sizeof(buf[0]), L"%u", positions[i]);
         if (num > 0 && ensure_space(global, num + 1)) {
             for (int i = 0; i < num; i++) global->output[global->output_pos++] = buf[i];
-            global->output[global->output_pos++] = (i == num - 1) ? ':' : ',';
+            global->output[global->output_pos++] = (i == num - 1) ? ',' : ':';
         }
     }
 }
 
 
 static void
-output_result(GlobalData *global, Candidate *c, Options *opts, len_t needle_len, text_t delim) {
+output_result(GlobalData *global, Candidate *c, Options *opts, len_t needle_len) {
     if (opts->output_positions) output_positions(global, c->positions, needle_len);
     if (opts->mark_before_sz > 0 || opts->mark_after_sz > 0) {
         output_with_marks(global, opts, c->src, c->src_sz, c->positions, needle_len);
     } else {
         output_text(global, c->src, c->src_sz);
     }
-    output_text(global, &delim, 1);
+    output_text(global, opts->delimiter, opts->delimiter_sz);
 }
 
 
 void
-output_results(GlobalData *global, Candidate *haystack, size_t count, Options *opts, len_t needle_len, text_t delim) {
+output_results(GlobalData *global, Candidate *haystack, size_t count, Options *opts, len_t needle_len) {
     Candidate *c;
     qsort(haystack, count, sizeof(*haystack), cmpscore);
     size_t left = opts->limit > 0 ? opts->limit : count;
     for (size_t i = 0; i < left; i++) {
         c = haystack + i;
-        if (c->score > 0) output_result(global, c, opts, needle_len, delim);
+        if (c->score > 0) output_result(global, c, opts, needle_len);
     }
 }
diff --git a/kittens/choose/vector.h b/kittens/choose/vector.h
index 33e13b5f1..72413c731 100644
--- a/kittens/choose/vector.h
+++ b/kittens/choose/vector.h
@@ -6,10 +6,7 @@
 
 #pragma once
 
-#include
-
-#define MAX(x, y) ((x) > (y) ? (x) : (y))
-#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#include "data-types.h"
 
 #define REPORT_OOM global->oom = 1;
 
diff --git a/kitty_tests/choose.py b/kitty_tests/choose.py
new file mode 100644
index 000000000..5571b7777
--- /dev/null
+++ b/kitty_tests/choose.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+# vim:fileencoding=utf-8
+# License: GPLv3 Copyright: 2019, Kovid Goyal
+
+import random
+import string
+
+from . import BaseTest
+
+
+def run(input_data, query, **kw):
+    kw['threads'] = kw.get('threads', 1)
+    mark = kw.pop('mark', False)
+    from kittens.choose.main import match
+    mark_before = mark_after = ''
+    if mark:
+        if mark is True:
+            mark_before, mark_after = '\033[32m', '\033[39m'
+        else:
+            mark_before = mark_after = mark
+    kw['mark_before'], kw['mark_after'] = mark_before, mark_after
+    return match(input_data, query, **kw)
+
+
+class TestMatcher(BaseTest):
+
+    def run_matcher(self, *args, **kwargs):
+        result = run(*args, **kwargs)
+        return result
+
+    def basic_test(self, inp, query, out, **k):
+        result = self.run_matcher(inp, query, **k)
+        if out is not None:
+            if hasattr(out, 'splitlines'):
+                out = list(filter(None, out.split(k.get('delimiter', '\n'))))
+            self.assertEqual(list(out), result)
+        return out
+
+    def test_filtering(self):
+        ' Non matching entries must be removed '
+        self.basic_test('test\nxyz', 'te', 'test')
+        self.basic_test('abc\nxyz', 'ba', '')
+        self.basic_test('abc\n123', 'abc', 'abc')
+
+    def test_case_insensitive(self):
+        self.basic_test('test\nxyz', 'Te', 'test')
+        self.basic_test('test\nxyz', 'XY', 'xyz')
+        self.basic_test('test\nXYZ', 'xy', 'XYZ')
+        self.basic_test('test\nXYZ', 'mn', '')
+
+    def test_marking(self):
+        ' Marking of matched characters '
+        self.basic_test(
+            'test\nxyz',
+            'ts',
+            '\x1b[32mt\x1b[39me\x1b[32ms\x1b[39mt',
+            mark=True)
+
+    def test_positions(self):
+        ' Output of positions '
+        self.basic_test('abc\nac', 'ac', '0,1:ac\n0,2:abc', positions=True)
+
+    def test_delimiter(self):
+        ' Test using a custom line delimiter '
+        self.basic_test('abc\n21ac', 'ac', 'ac1abc\n2', delimiter='1')
+
+    def test_scoring(self):
+        ' Scoring algorithm '
+        # Match at start
+        self.basic_test('archer\nelementary', 'e', 'elementary\narcher')
+        # Match at level factor
+        self.basic_test('xxxy\nxx/y', 'y', 'xx/y\nxxxy')
+        # CamelCase
+        self.basic_test('xxxy\nxxxY', 'y', 'xxxY\nxxxy')
+        # Total length
+        self.basic_test('xxxya\nxxxy', 'y', 'xxxy\nxxxya')
+        # Distance
+        self.basic_test('abbc\nabc', 'ac', 'abc\nabbc')
+        # Extreme chars
+        self.basic_test('xxa\naxx', 'a', 'axx\nxxa')
+        # Highest score
+        self.basic_test('xa/a', 'a', 'xa/|a|', mark='|')
+
+    def test_threading(self):
+        ' Test matching on a large data set with different number of threads '
+        alphabet = string.ascii_lowercase + string.ascii_uppercase + string.digits
+
+        def random_word():
+            sz = random.randint(2, 10)
+            return ''.join(random.choice(alphabet) for x in range(sz))
+        words = [random_word() for i in range(400)]
+
+        def random_item():
+            num = random.randint(2, 7)
+            return '/'.join(random.choice(words) for w in range(num))
+
+        data = '\n'.join(random_item() for x in range(25123))
+
+        for threads in range(4):
+            self.basic_test(data, 'foo', None, threads=threads)