diff --git a/gen-go-code.py b/gen-go-code.py index 402b5fe6a..3af7ad7cf 100755 --- a/gen-go-code.py +++ b/gen-go-code.py @@ -1,15 +1,17 @@ #!./kitty/launcher/kitty +launch # License: GPLv3 Copyright: 2022, Kovid Goyal +import bz2 import io import json import os import struct import subprocess import sys -import zlib +import tarfile from contextlib import contextmanager, suppress from functools import lru_cache +from itertools import chain from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Set, TextIO, Tuple, Union import kitty.constants as kc @@ -40,7 +42,7 @@ def newer(dest: str, *sources: str) -> bool: dtime = os.path.getmtime(dest) except OSError: return True - for s in sources: + for s in chain(sources, (__file__,)): with suppress(FileNotFoundError): if os.path.getmtime(s) >= dtime: return True @@ -442,6 +444,7 @@ def load_ref_map() -> Dict[str, Dict[str, str]]: def generate_constants() -> str: from kitty.options.types import Options + from kitty.options.utils import allowed_shell_integration_values ref_map = load_ref_map() dp = ", ".join(map(lambda x: f'"{serialize_as_go_string(x)}"', kc.default_pager_for_help)) return f'''\ @@ -465,6 +468,7 @@ var CharacterKeyNameAliases = map[string]string{serialize_go_dict(character_key_ var ConfigModMap = map[string]uint16{serialize_go_dict(config_mod_map)} var RefMap = map[string]string{serialize_go_dict(ref_map['ref'])} var DocTitleMap = map[string]string{serialize_go_dict(ref_map['doc'])} +var AllowedShellIntegrationValues = []string{{ {str(list(allowed_shell_integration_values))[1:-1].replace("'", '"')} }} var KittyConfigDefaults = struct {{ Term, Shell_integration string }}{{ @@ -649,7 +653,7 @@ def generate_textual_mimetypes() -> str: def write_compressed_data(data: bytes, d: BinaryIO) -> None: d.write(struct.pack(' None: @@ -678,20 +682,19 @@ def generate_ssh_kitten_data() -> None: path = os.path.join(dirpath, f) files.add(path.replace(os.sep, '/')) dest = 'tools/cmd/ssh/data_generated.bin' + + def normalize(t: tarfile.TarInfo) -> tarfile.TarInfo: + t.uid = t.gid = 0 + t.uname = t.gname = '' + return t + if newer(dest, *files): buf = io.BytesIO() - fmap = dict.fromkeys(files, (0, 0)) - for f in fmap: - with open(f, 'rb') as src: - data = src.read() - pos = buf.tell() - buf.write(data) - size = len(data) - fmap[f] = pos, size - mapping = ','.join(f'{name} {pos[0]} {pos[1]}' for name, pos in sorted(fmap.items())).encode('ascii') - data = struct.pack(' None: diff --git a/kittens/ssh/copy.py b/kittens/ssh/copy.py index 0ced93899..79c96e51d 100644 --- a/kittens/ssh/copy.py +++ b/kittens/ssh/copy.py @@ -36,7 +36,7 @@ type=list A glob pattern. Files with names matching this pattern are excluded from being transferred. Useful when adding directories. Can be specified multiple times, if any of the patterns match the file will be -excluded. To exclude a directory use a pattern like */directory_name/*. +excluded. To exclude a directory use a pattern like :code:`*/directory_name/*`. --symlink-strategy diff --git a/kittens/ssh/main.py b/kittens/ssh/main.py index 7f08d7dca..5185b6790 100644 --- a/kittens/ssh/main.py +++ b/kittens/ssh/main.py @@ -209,8 +209,6 @@ def get_ssh_data(msg: str, request_id: str) -> Iterator[bytes]: yield f'{e}\n'.encode('utf-8') else: yield b'OK\n' - ssh_opts = SSHOptions(env_data['opts']) - ssh_opts.copy = {k: CopyInstruction(*v) for k, v in ssh_opts.copy.items()} encoded_data = memoryview(env_data['tarfile'].encode('ascii')) # macOS has a 255 byte limit on its input queue as per man stty. # Not clear if that applies to canonical mode input as well, but diff --git a/kitty/options/utils.py b/kitty/options/utils.py index 50adaf0cc..2254f08b9 100644 --- a/kitty/options/utils.py +++ b/kitty/options/utils.py @@ -854,12 +854,14 @@ def store_multiple(val: str, current_val: Container[str]) -> Iterable[Tuple[str, yield val, val +allowed_shell_integration_values = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'}) + + def shell_integration(x: str) -> FrozenSet[str]: - s = frozenset({'enabled', 'disabled', 'no-rc', 'no-cursor', 'no-title', 'no-prompt-mark', 'no-complete', 'no-cwd'}) q = frozenset(x.lower().split()) - if not q.issubset(s): - log_error(f'Invalid shell integration options: {q - s}, ignoring') - return q & s or frozenset({'invalid'}) + if not q.issubset(allowed_shell_integration_values): + log_error(f'Invalid shell integration options: {q - allowed_shell_integration_values}, ignoring') + return q & allowed_shell_integration_values or frozenset({'invalid'}) return q diff --git a/tools/cmd/ssh/config.go b/tools/cmd/ssh/config.go index 10fb2114a..d1624ecbc 100644 --- a/tools/cmd/ssh/config.go +++ b/tools/cmd/ssh/config.go @@ -3,6 +3,7 @@ package ssh import ( + "archive/tar" "encoding/json" "errors" "fmt" @@ -10,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "time" "kitty/tools/config" "kitty/tools/utils" @@ -86,10 +88,10 @@ func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string return export(self.val) } -func (self *Config) final_env_instructions(for_python bool, get_local_env func(string) (string, bool)) string { - seen := make(map[string]int, len(self.Env)) - ans := make([]string, 0, len(self.Env)) - for _, ei := range self.Env { +func final_env_instructions(for_python bool, get_local_env func(string) (string, bool), env ...*EnvInstruction) string { + seen := make(map[string]int, len(env)) + ans := make([]string, 0, len(env)) + for _, ei := range env { q := ei.Serialize(for_python, get_local_env) if q != "" { if pos, found := seen[ei.key]; found { @@ -222,6 +224,100 @@ func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) { return } +type file_unique_id struct { + dev, inode uint64 +} + +func get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string, local_path, arcname string, exclude_patterns []string, recurse bool) error { + s, err := os.Lstat(local_path) + if err != nil { + return err + } + u, ok := s.Sys().(unix.Stat_t) + cb := func(h *tar.Header, data []byte) error { + h.Name = arcname + h.Size = int64(len(data)) + h.Mode = int64(s.Mode()) + + h.ModTime = s.ModTime() + h.Uid, h.Gid = 0, 0 + h.Uname, h.Gname = "", "" + h.Format = tar.FormatPAX + if ok { + h.AccessTime = time.Unix(0, u.Atim.Nano()) + h.ChangeTime = time.Unix(0, u.Ctim.Nano()) + } + return callback(h, data) + } + // we only copy regular files, directories and symlinks + switch s.Mode().Type() { + case fs.ModeSymlink: + target, err := os.Readlink(local_path) + if err != nil { + return err + } + err = cb(&tar.Header{ + Typeflag: tar.TypeSymlink, + Linkname: target, + }, nil) + if err != nil { + return err + } + case fs.ModeDir: + err = cb(&tar.Header{Typeflag: tar.TypeDir}, nil) + if err != nil { + return err + } + if recurse { + local_path = filepath.Clean(local_path) + return filepath.WalkDir(local_path, func(path string, d fs.DirEntry, werr error) error { + if filepath.Clean(path) == local_path { + return nil + } + for _, pat := range exclude_patterns { + if matched, err := filepath.Match(pat, path); matched && err == nil { + return nil + } + } + if werr == nil { + rel, err := filepath.Rel(local_path, path) + if err != nil { + aname := filepath.Join(arcname, rel) + return get_file_data(callback, seen, path, aname, nil, false) + } + } + return nil + }) + } + case 0: // Regular file + fid := file_unique_id{dev: u.Dev, inode: u.Ino} + if prev, ok := seen[fid]; ok { // Hard link + err = cb(&tar.Header{Typeflag: tar.TypeLink, Linkname: prev}, nil) + if err != nil { + return err + } + } + seen[fid] = arcname + data, err := os.ReadFile(local_path) + if err != nil { + return err + } + err = cb(&tar.Header{Typeflag: tar.TypeReg}, data) + if err != nil { + return err + } + } + return nil +} + +func (ci *CopyInstruction) get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string) (err error) { + ep := ci.exclude_patterns + for _, folder_name := range []string{"__pycache__", ".DS_Store"} { + ep = append(ep, "*/"+folder_name, "*/"+folder_name+"/*") + } + return get_file_data(callback, seen, ci.local_path, ci.arcname, ep, true) +} + type ConfigSet struct { all_configs []*Config } diff --git a/tools/cmd/ssh/data.go b/tools/cmd/ssh/data.go index fa3278bc8..8b79793ea 100644 --- a/tools/cmd/ssh/data.go +++ b/tools/cmd/ssh/data.go @@ -3,13 +3,13 @@ package ssh import ( - "bytes" + "archive/tar" _ "embed" - "encoding/binary" + "errors" "fmt" + "io" "kitty/tools/utils" - "strconv" - "strings" + "path/filepath" ) var _ = fmt.Print @@ -17,21 +17,48 @@ var _ = fmt.Print //go:embed data_generated.bin var embedded_data string -type Container = map[string][]byte +type Entry struct { + metadata *tar.Header + data []byte +} + +type Container map[string]Entry var Data = (&utils.Once[Container]{Run: func() Container { - raw := utils.ReadCompressedEmbeddedData(embedded_data) - num_of_entries := binary.LittleEndian.Uint32(raw) - raw = raw[4:] - ans := make(Container, num_of_entries) - idx := bytes.IndexByte(raw, '\n') - text := utils.UnsafeBytesToString(raw[:idx]) - raw = raw[idx+1:] - for _, record := range strings.Split(text, ",") { - parts := strings.Split(record, " ") - offset, _ := strconv.Atoi(parts[1]) - size, _ := strconv.Atoi(parts[2]) - ans[parts[0]] = raw[offset : offset+size] + tr := tar.NewReader(utils.ReaderForCompressedEmbeddedData(embedded_data)) + ans := make(Container, 64) + for { + hdr, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + panic(err) + } + data, err := utils.ReadAll(tr, int(hdr.Size)) + if err != nil { + panic(err) + } + ans[hdr.Name] = Entry{hdr, data} } return ans }}).Get + +func (self Container) files_matching(include_pattern string, exclude_patterns ...string) []string { + ans := make([]string, 0, len(self)) + for name := range self { + if matched, err := filepath.Match(include_pattern, name); matched && err == nil { + excluded := false + for _, pat := range exclude_patterns { + if matched, err := filepath.Match(pat, name); matched && err == nil { + excluded = true + break + } + } + if !excluded { + ans = append(ans, name) + } + } + } + return ans +} diff --git a/tools/cmd/ssh/main.go b/tools/cmd/ssh/main.go index e7abaf5e9..81cc00710 100644 --- a/tools/cmd/ssh/main.go +++ b/tools/cmd/ssh/main.go @@ -3,6 +3,10 @@ package ssh import ( + "archive/tar" + "bytes" + "compress/gzip" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -12,14 +16,18 @@ import ( "os" "os/exec" "os/user" + "path" "path/filepath" + "regexp" "strconv" "strings" + "time" "kitty/tools/cli" "kitty/tools/tty" "kitty/tools/tui/loop" "kitty/tools/utils" + "kitty/tools/utils/secrets" "kitty/tools/utils/shm" "golang.org/x/exp/maps" @@ -167,11 +175,339 @@ type connection_data struct { echo_on bool request_data bool literal_env map[string]string + test_script string + + shm_name string + script_type string + rcmd []string + replacements map[string]string + request_id string + bootstrap_script string +} + +func get_effective_ksi_env_var(x string) string { + parts := strings.Split(strings.TrimSpace(strings.ToLower(x)), " ") + current := utils.NewSetWithItems(parts...) + if current.Has("disabled") { + return "" + } + allowed := utils.NewSetWithItems(kitty.AllowedShellIntegrationValues...) + if !current.IsSubsetOf(allowed) { + return RelevantKittyOpts().Shell_integration + } + return x +} + +func serialize_env(cd *connection_data, get_local_env func(string) (string, bool)) (string, string) { + ksi := "" + if cd.host_opts.Shell_integration == "inherited" { + ksi = get_effective_ksi_env_var(RelevantKittyOpts().Shell_integration) + } else { + ksi = get_effective_ksi_env_var(cd.host_opts.Shell_integration) + } + env := make([]*EnvInstruction, 0, 8) + add_env := func(key, val string, fallback ...string) *EnvInstruction { + if val == "" && len(fallback) > 0 { + val = fallback[0] + } + if val != "" { + env = append(env, &EnvInstruction{key: key, val: val, literal_quote: true}) + return env[len(env)-1] + } + return nil + } + for k, v := range cd.literal_env { + add_env(k, v) + } + add_env("TERM", os.Getenv("TERM"), RelevantKittyOpts().Term) + add_env("COLORTERM", "truecolor") + env = append(env, cd.host_opts.Env...) + add_env("KITTY_WINDOW_ID", os.Getenv("KITTY_WINDOW_ID")) + add_env("WINDOWID", os.Getenv("WINDOWID")) + if ksi != "" { + add_env("KITTY_SHELL_INTEGRATION", ksi) + } else { + env = append(env, &EnvInstruction{key: "KITTY_SHELL_INTEGRATION", delete_on_remote: true}) + } + add_env("KITTY_SSH_KITTEN_DATA_DIR", cd.host_opts.Remote_dir) + add_env("KITTY_LOGIN_SHELL", cd.host_opts.Login_shell) + add_env("KITTY_LOGIN_CWD", cd.host_opts.Cwd) + if cd.host_opts.Remote_kitty != Remote_kitty_no { + add_env("KITTY_REMOTE", cd.host_opts.Remote_kitty.String()) + } + add_env("KITTY_PUBLIC_KEY", os.Getenv("KITTY_PUBLIC_KEY")) + return final_env_instructions(cd.script_type == "py", get_local_env), ksi +} + +func make_tarfile(cd *connection_data, get_local_env func(string) (string, bool)) ([]byte, error) { + env_script, ksi := serialize_env(cd, get_local_env) + w := bytes.Buffer{} + w.Grow(64 * 1024) + gw, err := gzip.NewWriterLevel(&w, gzip.BestCompression) + if err != nil { + return nil, err + } + tw := tar.NewWriter(gw) + rd := strings.TrimRight(cd.host_opts.Remote_dir, "/") + seen := make(map[file_unique_id]string, 32) + add := func(h *tar.Header, data []byte) (err error) { + // some distro's like nix mess with installed file permissions so ensure + // files are at least readable and writable by owning user + h.Mode |= 0o600 + err = tw.WriteHeader(h) + if err != nil { + return + } + if data != nil { + _, err := tw.Write(data) + if err != nil { + return err + } + } + return + } + for _, ci := range cd.host_opts.Copy { + get_file_data(add, seen, ci.local_path, ci.arcname, ci.exclude_patterns, true) + } + type fe struct { + arcname string + data []byte + } + now := time.Now() + add_data := func(items ...fe) error { + for _, item := range items { + err := add( + &tar.Header{ + Typeflag: tar.TypeReg, Name: item.arcname, Format: tar.FormatPAX, Size: int64(len(item.data)), + Mode: 0o644, ModTime: now, ChangeTime: now, AccessTime: now, + }, item.data) + if err != nil { + return err + } + } + return nil + } + add_entries := func(prefix string, items ...Entry) error { + for _, item := range items { + err := add( + &tar.Header{ + Typeflag: item.metadata.Typeflag, Name: path.Join(prefix, path.Base(item.metadata.Name)), Format: tar.FormatPAX, + Size: int64(len(item.data)), Mode: item.metadata.Mode, ModTime: item.metadata.ModTime, + AccessTime: item.metadata.AccessTime, ChangeTime: item.metadata.ChangeTime, + }, item.data) + if err != nil { + return err + } + } + return nil + + } + add_data(fe{"data.sh", utils.UnsafeStringToBytes(env_script)}) + if ksi != "" { + for _, fname := range Data().files_matching( + "shell-integration/*", + "shell-integration/ssh/*", // bootstrap files are sent as command line args + "shell_integration/zsh/kitty.zsh", // backward compat file not needed by ssh kitten + ) { + arcname := path.Join("home/", rd, "/", path.Dir(fname)) + err = add_entries(arcname, Data()[fname]) + if err != nil { + return nil, err + } + } + } + if cd.host_opts.Remote_kitty != Remote_kitty_no { + arcname := path.Join("home/", rd, "/kitty") + err = add_data(fe{arcname + "/version", utils.UnsafeStringToBytes(kitty.VersionString)}) + if err != nil { + return nil, err + } + for _, x := range []string{"kitty", "kitten"} { + err = add_entries(path.Join(arcname, "bin"), Data()[path.Join("shell-integration", "ssh", x)]) + if err != nil { + return nil, err + } + } + } + err = add_entries(path.Join("home", ".terminfo"), Data()["terminfo/kitty.terminfo"]) + if err == nil { + err = add_entries(path.Join("home", ".terminfo", "x"), Data()["terminfo/x/xterm-kitty"]) + } + if err == nil { + err = tw.Close() + if err == nil { + err = gw.Close() + } + } + return w.Bytes(), err +} + +func prepare_home_command(cd *connection_data) string { + is_python := cd.script_type == "py" + homevar := "" + for _, ei := range cd.host_opts.Env { + if ei.key == "HOME" && !ei.delete_on_remote { + if ei.copy_from_local { + homevar = os.Getenv("HOME") + } else { + homevar = ei.val + } + } + } + export_home_cmd := "" + if homevar != "" { + if is_python { + export_home_cmd = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(homevar)) + } else { + export_home_cmd = fmt.Sprintf("export HOME=%s; cd \"$HOME\"", utils.QuoteStringForSH(homevar)) + } + } + return export_home_cmd +} + +func prepare_exec_cmd(cd *connection_data) string { + // ssh simply concatenates multiple commands using a space see + // line 1129 of ssh.c and on the remote side sshd.c runs the + // concatenated command as shell -c cmd + if cd.script_type == "py" { + return base64.RawStdEncoding.EncodeToString(utils.UnsafeStringToBytes(strings.Join(cd.remote_args, " "))) + } + args := make([]string, len(cd.remote_args)) + for i, arg := range cd.remote_args { + args[i] = strings.ReplaceAll(arg, "'", "'\"'\"'") + } + return "unset KITTY_SHELL_INTEGRATION; exec \"$login_shell\" -c '" + strings.Join(args, " ") + "'" +} + +var data_shm shm.MMap + +func prepare_script(script string, replacements map[string]string) string { + if _, found := replacements["EXEC_CMD"]; !found { + replacements["EXEC_CMD"] = "" + } + if _, found := replacements["EXPORT_HOME_CMD"]; !found { + replacements["EXPORT_HOME_CMD"] = "" + } + keys := maps.Keys(replacements) + for i, key := range keys { + keys[i] = "\\b" + key + "\\b" + } + pat := regexp.MustCompile(strings.Join(keys, "|")) + return pat.ReplaceAllStringFunc(script, func(key string) string { return replacements[key] }) +} + +func bootstrap_script(cd *connection_data) (err error) { + if cd.request_id == "" { + cd.request_id = os.Getenv("KITTY_PID") + "-" + os.Getenv("KITTY_WINDOW_ID") + } + export_home_cmd := prepare_home_command(cd) + exec_cmd := "" + if len(cd.remote_args) > 0 { + exec_cmd = prepare_exec_cmd(cd) + } + pw, err := secrets.TokenHex() + if err != nil { + return err + } + tfd, err := make_tarfile(cd, os.LookupEnv) + if err != nil { + return err + } + data := map[string]string{ + "tarfile": base64.StdEncoding.EncodeToString(tfd), + "pw": pw, + "hostname": cd.hostname_for_match, "username": cd.username, + } + encoded_data, err := json.Marshal(data) + if err == nil { + data_shm, err = shm.CreateTemp(fmt.Sprintf("kssh-%d-", os.Getpid()), uint64(len(encoded_data)+8)) + if err == nil { + err = data_shm.WriteWithSize(encoded_data) + if err == nil { + err = data_shm.Flush() + } + } + } + if err != nil { + return err + } + cd.shm_name = data_shm.Name() + sensitive_data := map[string]string{"REQUEST_ID": cd.request_id, "DATA_PASSWORD": pw, "PASSWORD_FILENAME": cd.shm_name} + replacements := map[string]string{ + "EXPORT_HOME_CMD": export_home_cmd, + "EXEC_CMD": exec_cmd, + "TEST_SCRIPT": cd.test_script, + } + add_bool := func(ok bool, key string) { + if ok { + replacements[key] = "1" + } else { + replacements[key] = "0" + } + } + add_bool(cd.request_data, "REQUEST_DATA") + add_bool(cd.echo_on, "ECHO_ON") + sd := maps.Clone(replacements) + if cd.request_data { + maps.Copy(sd, sensitive_data) + } + maps.Copy(replacements, sensitive_data) + cd.replacements = replacements + cd.bootstrap_script = utils.UnsafeBytesToString(Data()["shell-integration/ssh/bootstrap."+cd.script_type].data) + cd.bootstrap_script = prepare_script(cd.bootstrap_script, sd) + return err +} + +func wrap_bootstrap_script(cd *connection_data) { + // sshd will execute the command we pass it by join all command line + // arguments with a space and passing it as a single argument to the users + // login shell with -c. If the user has a non POSIX login shell it might + // have different escaping semantics and syntax, so the command it should + // execute has to be as simple as possible, basically of the form + // interpreter -c unwrap_script escaped_bootstrap_script + // The unwrap_script is responsible for unescaping the bootstrap script and + // executing it. + encoded_script := "" + unwrap_script := "" + if cd.script_type == "py" { + encoded_script = base64.StdEncoding.EncodeToString(utils.UnsafeStringToBytes(cd.bootstrap_script)) + unwrap_script = `"import base64, sys; eval(compile(base64.standard_b64decode(sys.argv[-1]), 'bootstrap.py', 'exec'))"` + } else { + // We cant rely on base64 being available on the remote system, so instead + // we quote the bootstrap script by replacing ' and \ with \v and \f + // also replacing \n and ! with \r and \b for tcsh + // finally surrounding with ' + encoded_script = "'" + strings.NewReplacer("'", "\v", "\\", "\f", "\n", "\r", "!", "\b").Replace(cd.bootstrap_script) + "'" + unwrap_script = `'eval "$(echo "$0" | tr \\\v\\\f\\\r\\\b \\\047\\\134\\\n\\\041)"' ` + } + cd.rcmd = []string{"exec", cd.host_opts.Interpreter, "-c", unwrap_script, encoded_script} +} + +func get_remote_command(cd *connection_data) error { + interpreter := cd.host_opts.Interpreter + q := strings.ToLower(path.Base(interpreter)) + is_python := strings.Contains(q, "python") + cd.script_type = "sh" + if is_python { + cd.script_type = "py" + } + err := bootstrap_script(cd) + if err != nil { + return err + } + wrap_bootstrap_script(cd) + return nil } func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err error) { go Data() go RelevantKittyOpts() + defer func() { + if data_shm != nil { + data_shm.Close() + data_shm.Unlink() + } + }() cmd := append([]string{SSHExe()}, ssh_args...) cd := connection_data{remote_args: server_args[1:]} hostname := server_args[0] @@ -224,6 +560,10 @@ func run_ssh(ssh_args, server_args, found_extra_args []string) (rc int, err erro term.WriteString(loop.HANDLE_TERMIOS_SIGNALS.EscapeCodeToSet()) defer term.WriteString(loop.RESTORE_PRIVATE_MODE_VALUES) defer term.RestoreAndClose() + err = get_remote_command(&cd) + if err != nil { + return 1, err + } return 0, nil } diff --git a/tools/utils/embed.go b/tools/utils/embed.go index ad280f4f3..c485477ba 100644 --- a/tools/utils/embed.go +++ b/tools/utils/embed.go @@ -4,7 +4,7 @@ package utils import ( "bytes" - "compress/zlib" + "compress/bzip2" "encoding/binary" "fmt" "io" @@ -33,11 +33,14 @@ func ReadAll(r io.Reader, expected_size int) ([]byte, error) { func ReadCompressedEmbeddedData(raw string) []byte { compressed := UnsafeStringToBytes(raw) uncompressed_size := binary.LittleEndian.Uint32(compressed) - r, _ := zlib.NewReader(bytes.NewReader(compressed[4:])) - defer r.Close() + r := bzip2.NewReader(bytes.NewReader(compressed[4:])) ans, err := ReadAll(r, int(uncompressed_size)) if err != nil { panic(err) } return ans } + +func ReaderForCompressedEmbeddedData(raw string) io.Reader { + return bzip2.NewReader(bytes.NewReader(UnsafeStringToBytes(raw)[4:])) +} diff --git a/tools/utils/paths.go b/tools/utils/paths.go index b94995370..bb12d3dda 100644 --- a/tools/utils/paths.go +++ b/tools/utils/paths.go @@ -130,7 +130,7 @@ var CacheDir = (&Once[string]{Run: func() (cache_dir string) { }}).Get func macos_user_cache_dir() string { - // Sadly Go does not provide confstr() so we use this hack. We could + // Sadly Go does not provide confstr() so we use this hack. // Note that given a user generateduid and uid we can derive this by using // the algorithm at https://github.com/ydkhatri/MacForensics/blob/master/darwin_path_generator.py // but I cant find a good way to get the generateduid. Requires calling dscl in which case we might as well call getconf