Add tests for the subseq matcher
This commit is contained in:
parent
edb25314c5
commit
60b64dadfe
@ -34,7 +34,6 @@ typedef uint8_t len_t;
|
||||
typedef uint32_t text_t;
|
||||
|
||||
#define LEN_MAX UINT8_MAX
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define UTF8_ACCEPT 0
|
||||
#define UTF8_REJECT 1
|
||||
/* True when x is an ASCII lowercase letter ('a'..'z'). The whole expansion is
 * parenthesized so it composes safely with !, ?:, || and other operators.
 * Note: x is evaluated twice, so avoid arguments with side effects. */
#define IS_LOWERCASE(x) ((x) >= 'a' && (x) <= 'z')
|
||||
@ -66,8 +65,8 @@ typedef struct {
|
||||
bool output_positions;
|
||||
size_t limit;
|
||||
int num_threads;
|
||||
text_t mark_before[128], mark_after[128];
|
||||
size_t mark_before_sz, mark_after_sz;
|
||||
text_t mark_before[128], mark_after[128], delimiter[128];
|
||||
size_t mark_before_sz, mark_after_sz, delimiter_sz;
|
||||
} Options;
|
||||
|
||||
VECTOR_OF(len_t, Positions)
|
||||
@ -75,7 +74,7 @@ VECTOR_OF(text_t, Chars)
|
||||
VECTOR_OF(Candidate, Candidates)
|
||||
|
||||
|
||||
void output_results(GlobalData *, Candidate *haystack, size_t count, Options *opts, len_t needle_len, text_t delim);
|
||||
void output_results(GlobalData *, Candidate *haystack, size_t count, Options *opts, len_t needle_len);
|
||||
void* alloc_workspace(len_t max_haystack_len, GlobalData*);
|
||||
void* free_workspace(void *v);
|
||||
double score_item(void *v, text_t *haystack, len_t haystack_len, len_t *match_positions);
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
* Distributed under terms of the GPL3 license.
|
||||
*/
|
||||
|
||||
#include "data-types.h"
|
||||
#include "choose-data-types.h"
|
||||
#include "charsets.h"
|
||||
|
||||
@ -160,7 +161,7 @@ run_search(Options *opts, GlobalData *global, const char * const *lines, const s
|
||||
global->haystack = haystack;
|
||||
global->haystack_count = SIZE(candidates);
|
||||
ret = run_threaded(opts->num_threads, global);
|
||||
if (ret == 0) output_results(global, haystack, SIZE(candidates), opts, global->needle_len, '\n');
|
||||
if (ret == 0) output_results(global, haystack, SIZE(candidates), opts, global->needle_len);
|
||||
else { REPORT_OOM; }
|
||||
} else { ret = 1; REPORT_OOM; }
|
||||
|
||||
@ -185,13 +186,13 @@ match(PyObject *self, PyObject *args) {
|
||||
(void)(self);
|
||||
int output_positions;
|
||||
unsigned long limit;
|
||||
PyObject *lines, *levels, *needle, *mark_before, *mark_after;
|
||||
PyObject *lines, *levels, *needle, *mark_before, *mark_after, *delimiter;
|
||||
Options opts = {0};
|
||||
GlobalData global = {0};
|
||||
if (!PyArg_ParseTuple(args, "O!O!O!pkiO!O!",
|
||||
&lines, &PyList_Type, &levels, &PyTuple_Type, &needle, &PyUnicode_Type,
|
||||
if (!PyArg_ParseTuple(args, "O!O!UpkiUUU",
|
||||
&PyList_Type, &lines, &PyTuple_Type, &levels, &needle,
|
||||
&output_positions, &limit, &opts.num_threads,
|
||||
&mark_before, &PyUnicode_Type, &mark_after, &PyUnicode_Type
|
||||
&mark_before, &mark_after, &delimiter
|
||||
)) return NULL;
|
||||
opts.output_positions = output_positions ? true : false;
|
||||
opts.limit = limit;
|
||||
@ -201,13 +202,14 @@ match(PyObject *self, PyObject *args) {
|
||||
global.needle_len = copy_unicode_object(needle, global.needle, arraysz(global.needle));
|
||||
opts.mark_before_sz = copy_unicode_object(mark_before, opts.mark_before, arraysz(opts.mark_before));
|
||||
opts.mark_after_sz = copy_unicode_object(mark_after, opts.mark_after, arraysz(opts.mark_after));
|
||||
opts.delimiter_sz = copy_unicode_object(delimiter, opts.delimiter, arraysz(opts.delimiter));
|
||||
size_t num_lines = PyList_GET_SIZE(lines);
|
||||
char **clines = malloc(sizeof(char*) * num_lines);
|
||||
size_t *sizes = malloc(sizeof(size_t) * num_lines);
|
||||
if (!lines || !sizes) { PyErr_NoMemory(); return NULL; }
|
||||
if (!lines || !sizes) { return PyErr_NoMemory(); }
|
||||
for (size_t i = 0; i < num_lines; i++) {
|
||||
clines[i] = PyBytes_AS_STRING(PyTuple_GET_ITEM(lines, i));
|
||||
sizes[i] = PyBytes_GET_SIZE(PyTuple_GET_ITEM(lines, i));
|
||||
clines[i] = PyBytes_AS_STRING(PyList_GET_ITEM(lines, i));
|
||||
sizes[i] = PyBytes_GET_SIZE(PyList_GET_ITEM(lines, i));
|
||||
}
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
run_search(&opts, &global, (const char* const *)clines, sizes, num_lines);
|
||||
@ -235,11 +237,7 @@ static struct PyModuleDef module = {
|
||||
.m_methods = module_methods
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC
|
||||
EXPORTED PyMODINIT_FUNC
|
||||
PyInit_subseq_matcher(void) {
|
||||
PyObject *m;
|
||||
|
||||
m = PyModule_Create(&module);
|
||||
if (m == NULL) return NULL;
|
||||
return m;
|
||||
return PyModule_Create(&module);
|
||||
}
|
||||
|
||||
@ -8,6 +8,40 @@ import sys
|
||||
from ..tui.handler import Handler
|
||||
from ..tui.loop import Loop
|
||||
|
||||
from . import subseq_matcher
|
||||
|
||||
|
||||
def match(
    input_data,
    query,
    threads=0,
    positions=False,
    level1='/',
    level2='-_0123456789',
    level3='.',
    limit=0,
    mark_before='',
    mark_after='',
    delimiter='\n'
):
    """Run the native subsequence matcher over ``input_data``.

    :param input_data: ``str`` or ``bytes`` (split on ``delimiter``), or an
        iterable of items; ``str`` items are encoded as UTF-8.
    :param query: the search string; it is lower-cased before matching.
    :param threads: forwarded unchanged to ``subseq_matcher.match``.
    :param positions: forwarded unchanged to ``subseq_matcher.match``.
    :param level1, level2, level3: lower-cased and passed on as one tuple.
    :param limit: forwarded unchanged to ``subseq_matcher.match``.
    :param mark_before, mark_after: forwarded unchanged.
    :param delimiter: separator used to split string input and the matcher
        output. An empty value falls back to ``'\\n'``.
    :return: list of the non-empty pieces of the matcher output, or ``[]``
        when the matcher returns ``None``.
    """
    # Normalise once so input splitting, the native call and output splitting
    # all agree. An empty separator would make bytes.split() raise ValueError
    # and would leave the native output without any separator to split on.
    delimiter = delimiter or '\n'
    if isinstance(input_data, str):
        input_data = input_data.encode('utf-8')
    if isinstance(input_data, bytes):
        input_data = input_data.split(delimiter.encode('utf-8'))
    else:
        input_data = [x.encode('utf-8') if isinstance(x, str) else x for x in input_data]
    query = query.lower()
    level1 = level1.lower()
    level2 = level2.lower()
    level3 = level3.lower()
    data = subseq_matcher.match(
        input_data, (level1, level2, level3), query,
        positions, limit, threads,
        mark_before, mark_after, delimiter)
    if data is None:
        return []
    return list(filter(None, data.split(delimiter)))
|
||||
|
||||
|
||||
class ChooseHandler(Handler):
|
||||
|
||||
|
||||
@ -75,31 +75,31 @@ output_positions(GlobalData *global, len_t *positions, len_t num) {
|
||||
int num = swprintf(buf, sizeof(buf)/sizeof(buf[0]), L"%u", positions[i]);
|
||||
if (num > 0 && ensure_space(global, num + 1)) {
|
||||
for (int i = 0; i < num; i++) global->output[global->output_pos++] = buf[i];
|
||||
global->output[global->output_pos++] = (i == num - 1) ? ':' : ',';
|
||||
global->output[global->output_pos++] = (i == num - 1) ? ',' : ':';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void
|
||||
output_result(GlobalData *global, Candidate *c, Options *opts, len_t needle_len, text_t delim) {
|
||||
output_result(GlobalData *global, Candidate *c, Options *opts, len_t needle_len) {
|
||||
if (opts->output_positions) output_positions(global, c->positions, needle_len);
|
||||
if (opts->mark_before_sz > 0 || opts->mark_after_sz > 0) {
|
||||
output_with_marks(global, opts, c->src, c->src_sz, c->positions, needle_len);
|
||||
} else {
|
||||
output_text(global, c->src, c->src_sz);
|
||||
}
|
||||
output_text(global, &delim, 1);
|
||||
output_text(global, opts->delimiter, opts->delimiter_sz);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
output_results(GlobalData *global, Candidate *haystack, size_t count, Options *opts, len_t needle_len, text_t delim) {
|
||||
output_results(GlobalData *global, Candidate *haystack, size_t count, Options *opts, len_t needle_len) {
|
||||
Candidate *c;
|
||||
qsort(haystack, count, sizeof(*haystack), cmpscore);
|
||||
size_t left = opts->limit > 0 ? opts->limit : count;
|
||||
for (size_t i = 0; i < left; i++) {
|
||||
c = haystack + i;
|
||||
if (c->score > 0) output_result(global, c, opts, needle_len, delim);
|
||||
if (c->score > 0) output_result(global, c, opts, needle_len);
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,10 +6,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
#include "data-types.h"
|
||||
|
||||
/* Flag an out-of-memory condition by setting oom on the `global` pointer that
 * must be in scope where the macro is expanded. The expansion already ends
 * with a semicolon. */
#define REPORT_OOM global->oom = 1;
|
||||
|
||||
|
||||
100
kitty_tests/choose.py
Normal file
100
kitty_tests/choose.py
Normal file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python
|
||||
# vim:fileencoding=utf-8
|
||||
# License: GPLv3 Copyright: 2019, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
from . import BaseTest
|
||||
|
||||
|
||||
def run(input_data, query, **kw):
    """Call ``kittens.choose.main.match`` with test-friendly defaults.

    ``threads`` defaults to 1 when not supplied. The ``mark`` keyword is
    consumed here and translated into ``mark_before`` / ``mark_after``:
    ``True`` selects a pair of ANSI colour escapes, any other truthy value is
    used for both marks, and a falsy value yields empty marks.
    """
    kw.setdefault('threads', 1)
    mark = kw.pop('mark', False)
    from kittens.choose.main import match

    if not mark:
        before = after = ''
    elif mark is True:
        before, after = '\033[32m', '\033[39m'
    else:
        before = after = mark

    kw['mark_before'] = before
    kw['mark_after'] = after
    return match(input_data, query, **kw)
|
||||
|
||||
|
||||
class TestMatcher(BaseTest):
    # Exercises the matcher through the module-level run() helper.

    def run_matcher(self, *args, **kwargs):
        # Thin indirection over run(); arguments are forwarded untouched.
        return run(*args, **kwargs)

    def basic_test(self, inp, query, out, **k):
        # Runs the matcher and, unless `out` is None, compares its result with
        # `out`. A string `out` is first split on the delimiter keyword
        # (default newline) with empty pieces dropped. Returns `out`.
        actual = self.run_matcher(inp, query, **k)
        if out is None:
            return out
        if hasattr(out, 'splitlines'):
            separator = k.get('delimiter', '\n')
            out = [piece for piece in out.split(separator) if piece]
        self.assertEqual(list(out), actual)
        return out

    def test_filtering(self):
        ' Non matching entries must be removed '
        cases = (
            ('test\nxyz', 'te', 'test'),
            ('abc\nxyz', 'ba', ''),
            ('abc\n123', 'abc', 'abc'),
        )
        for inp, query, expected in cases:
            self.basic_test(inp, query, expected)

    def test_case_insensitive(self):
        # Query and candidates in differing case must still match.
        cases = (
            ('test\nxyz', 'Te', 'test'),
            ('test\nxyz', 'XY', 'xyz'),
            ('test\nXYZ', 'xy', 'XYZ'),
            ('test\nXYZ', 'mn', ''),
        )
        for inp, query, expected in cases:
            self.basic_test(inp, query, expected)

    def test_marking(self):
        ' Marking of matched characters '
        green, reset = '\x1b[32m', '\x1b[39m'
        expected = green + 't' + reset + 'e' + green + 's' + reset + 't'
        self.basic_test('test\nxyz', 'ts', expected, mark=True)

    def test_positions(self):
        ' Output of positions '
        self.basic_test('abc\nac', 'ac', '0,1:ac\n0,2:abc', positions=True)

    def test_delimiter(self):
        ' Test using a custom line delimiter '
        self.basic_test('abc\n21ac', 'ac', 'ac1abc\n2', delimiter='1')

    def test_scoring(self):
        ' Scoring algorithm '
        ordered_cases = (
            # Match at start
            ('archer\nelementary', 'e', 'elementary\narcher'),
            # Match at level factor
            ('xxxy\nxx/y', 'y', 'xx/y\nxxxy'),
            # CamelCase
            ('xxxy\nxxxY', 'y', 'xxxY\nxxxy'),
            # Total length
            ('xxxya\nxxxy', 'y', 'xxxy\nxxxya'),
            # Distance
            ('abbc\nabc', 'ac', 'abc\nabbc'),
            # Extreme chars
            ('xxa\naxx', 'a', 'axx\nxxa'),
        )
        for inp, query, expected in ordered_cases:
            self.basic_test(inp, query, expected)
        # Highest score
        self.basic_test('xa/a', 'a', 'xa/|a|', mark='|')

    def test_threading(self):
        ' Test matching on a large data set with different number of threads '
        alphabet = string.ascii_letters + string.digits

        def random_word():
            length = random.randint(2, 10)
            return ''.join(random.choice(alphabet) for _ in range(length))

        words = [random_word() for _ in range(400)]

        def random_item():
            parts = random.randint(2, 7)
            return '/'.join(random.choice(words) for _ in range(parts))

        data = '\n'.join(random_item() for _ in range(25123))

        # Only checks that every thread count runs; no expected output given.
        for threads in range(4):
            self.basic_test(data, 'foo', None, threads=threads)
|
||||
Loading…
x
Reference in New Issue
Block a user