#!/usr/bin/env python3
# ################################################################
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under both the BSD-style license (found in the
# LICENSE file in the root directory of this source tree) and the GPLv2 (found
# in the COPYING file in the root directory of this source tree).
# You may select, at your option, one of the above-listed licenses.
# ##########################################################################

import argparse
import contextlib
import os
import re
import shutil
import sys
from typing import Optional


# Sub-directories of the source library that are copied to the output.
INCLUDED_SUBDIRS = ["common", "compress", "decompress"]

# Library-relative paths that are never copied by Freestanding._copy_file().
SKIPPED_FILES = [
    "common/mem.h",
    "common/zstd_deps.h",
    "common/pool.c",
    "common/pool.h",
    "common/threading.c",
    "common/threading.h",
    "common/zstd_trace.h",
    "compress/zstdmt_compress.h",
    "compress/zstdmt_compress.c",
]

# Library-relative paths skipped only when an external xxhash is requested.
XXHASH_FILES = [
    "common/xxhash.c",
    "common/xxhash.h",
]


class FileLines(object):
    """
    In-memory copy of a text file as a list of lines.

    The file is read in the constructor; `lines` can then be edited and
    written back to the same path with write().
    """
    def __init__(self, filename):
        self.filename = filename
        with open(self.filename, "r") as f:
            self.lines = f.readlines()

    def write(self):
        """Overwrite the file with the current contents of `lines`."""
        with open(self.filename, "w") as f:
            f.write("".join(self.lines))


class PartialPreprocessor(object):
    """
    Looks for simple ifdefs and ifndefs and replaces them.
    Handles && and ||.
    Has fancy logic to handle translating elifs to ifs.
    Only looks for macros in the first part of the expression with no parens.
    Does not handle multi-line macros (only looks in first line).
    """
    def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]):
        """
        defs:     (macro, value) pairs treated as defined; value may be None.
        replaces: (macro, value) pairs treated as defined; additionally the
                  first block resolved for the macro gets a
                  `#define macro value` line prepended.
        undefs:   macro names treated as undefined.
        """
        MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)"
        ELIF_GROUP = r"(?P<elif>el)?"
        OP_GROUP = r"(?P<op>&&|\|\|)?"

        self._defs = {macro: value for macro, value in defs}
        self._replaces = {macro: value for macro, value in replaces}
        self._defs.update(self._replaces)
        self._undefs = set(undefs)

        self._define = re.compile(r"\s*#\s*define")
        # NOTE: matches #if, #ifdef and #ifndef (prefix match), but not #elif.
        self._if = re.compile(r"\s*#\s*if")
        self._elif = re.compile(r"\s*#\s*(?P<elif>el)if")
        self._else = re.compile(r"\s*#\s*(?P<else>else)")
        self._endif = re.compile(r"\s*#\s*endif")

        self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def {MACRO_GROUP}\s*")
        self._if_defined = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}"
        )
        self._if_defined_value = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*"
            fr"(?P<op>&&)\s*"
            fr"(?P<openp>\()?\s*"
            fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*"
            fr"(?P<cmp>[=><!]+)\s*"
            fr"(?P<value>[0-9]*)\s*"
            fr"(?P<closep>\))?\s*"
        )
        self._if_true = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+{MACRO_GROUP}\s*{OP_GROUP}"
        )

        self._c_comment = re.compile(r"/\*.*?\*/")
        self._cpp_comment = re.compile(r"//")

    def _log(self, *args, **kwargs):
        print(*args, **kwargs)

    def _strip_comments(self, line):
        """Return `line` without /* ... */ spans and without any // tail."""
        # First strip c-style comments (may include //)
        while True:
            m = self._c_comment.search(line)
            if m is None:
                break
            line = line[:m.start()] + line[m.end():]

        # Then strip cpp-style comments
        m = self._cpp_comment.search(line)
        if m is not None:
            line = line[:m.start()]

        return line

    def _fixup_indentation(self, macro, replace: [str]):
        """
        De-indent preprocessor lines that were pulled out of a removed block.

        Returns `replace` unchanged when it is empty, when it is a single
        line that is not a #define, or when any line's first non-space
        character is not '#'. Otherwise the common leading-space count
        (measured after the '#' when every line starts with '#') is removed
        from every line.
        """
        if len(replace) == 0:
            return replace
        if len(replace) == 1 and self._define.match(replace[0]) is None:
            # If there is only one line, only replace defines
            return replace

        all_pound = all(line.startswith('#') for line in replace)
        if all_pound:
            replace = [line[1:] for line in replace]

        min_spaces = len(replace[0])
        for line in replace:
            spaces = 0
            for i, c in enumerate(line):
                if c != ' ':
                    # Non-preprocessor line ==> skip the fixup
                    if not all_pound and c != '#':
                        return replace
                    spaces = i
                    break
            min_spaces = min(min_spaces, spaces)

        replace = [line[min_spaces:] for line in replace]

        if all_pound:
            replace = ["#" + line for line in replace]

        return replace

    def _handle_if_block(self, macro, idx, is_true, prepend):
        """
        Remove the #if or #elif block starting on this line.

        macro:   name of the macro being hardwired (used for logging).
        idx:     index into self._inlines of the #if / #elif line.
        is_true: whether the condition on that line evaluated to true.
        prepend: lines placed in front of the kept body.

        Returns (next_idx, replace): the index of the first input line that
        was not consumed, and the lines that replace the consumed range.
        Raises RuntimeError if the block has no matching #endif.
        """
        REMOVE_ONE = 0   # Condition is false: drop this branch only
        KEEP_ONE = 1     # Condition is true: keep this branch
        REMOVE_REST = 2  # Branch already kept: drop every later branch

        state = KEEP_ONE if is_true else REMOVE_ONE

        line = self._inlines[idx]
        is_if = self._if.match(line) is not None
        elif_match = self._elif.match(line)
        assert is_if or elif_match is not None
        depth = 0

        start_idx = idx

        idx += 1
        replace = list(prepend)
        if is_true and not is_if:
            # A true #elif can only be reached while the preceding #if is
            # still in the output, so it becomes the final #else branch and
            # the closing #endif has to stay in place.
            replace.insert(0, line[:elif_match.start('elif')] + "else\n")

        finished = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            # Nested if statement
            if self._if.match(line):
                depth += 1
                idx += 1
                continue
            # We're inside a nested statement
            if depth > 0:
                if self._endif.match(line):
                    depth -= 1
                idx += 1
                continue

            # We're at the original depth

            # Looking only for an endif.
            # We've found a true statement, but haven't
            # completely elided the if block, so we just
            # remove the remainder.
            if state == REMOVE_REST:
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because we took the first if
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue

            if state == KEEP_ONE:
                if self._endif.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    if is_if:
                        # The whole conditional is gone, endif included
                        idx += 1
                    finished = True
                    break
                if self._elif.match(line) or self._else.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    state = REMOVE_REST
                idx += 1
                continue

            if state == REMOVE_ONE:
                m = self._elif.match(line)
                if m is not None:
                    if is_if:
                        # The removed #if is replaced by the following
                        # #elif, rewritten as an #if.
                        idx += 1
                        b = m.start('elif')
                        e = m.end('elif')
                        assert e - b == 2
                        replace.append(line[:b] + line[e:])
                    finished = True
                    break
                m = self._else.match(line)
                if m is not None:
                    if is_if:
                        # Keep the #else body. Track nesting so that the
                        # #endif of a conditional inside the body is not
                        # mistaken for the end of this block.
                        idx += 1
                        nested = 0
                        terminated = False
                        while idx < len(self._inlines):
                            inner = self._inlines[idx]
                            if self._if.match(inner):
                                nested += 1
                            elif self._endif.match(inner):
                                if nested == 0:
                                    terminated = True
                                    break
                                nested -= 1
                            replace.append(inner)
                            idx += 1
                        if not terminated:
                            raise RuntimeError("Unterminated if block!")
                        idx += 1
                    finished = True
                    break
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because no other elifs
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue
        if not finished:
            raise RuntimeError("Unterminated if block!")

        replace = self._fixup_indentation(macro, replace)

        self._log(f"\tHardwiring {macro}")
        if start_idx > 0:
            self._log(f"\t\t  {self._inlines[start_idx - 1][:-1]}")
        for x in range(start_idx, idx):
            self._log(f"\t\t- {self._inlines[x][:-1]}")
        for line in replace:
            self._log(f"\t\t+ {line[:-1]}")
        if idx < len(self._inlines):
            self._log(f"\t\t  {self._inlines[idx][:-1]}")

        return idx, replace

    def _preprocess_once(self):
        """
        Run one pass over self._inlines.

        Returns (changed, outlines). `changed` is True when any conditional
        was resolved or partially resolved, in which case the caller runs
        another pass over `outlines`.
        """
        outlines = []
        idx = 0
        changed = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            sline = self._strip_comments(line)
            m = self._ifdef.fullmatch(sline)
            if_true = False
            if m is None:
                m = self._if_defined_value.fullmatch(sline)
            if m is None:
                m = self._if_defined.match(sline)
            if m is None:
                m = self._if_true.match(sline)
                if_true = (m is not None)
            if m is None:
                outlines.append(line)
                idx += 1
                continue

            groups = m.groupdict()
            macro = groups['macro']
            op = groups.get('op')

            if not (macro in self._defs or macro in self._undefs):
                outlines.append(line)
                idx += 1
                continue

            defined = macro in self._defs

            # Needed variables set:
            # resolved: Is the statement fully resolved?
            # is_true: If resolved, is the statement true?
            ifdef = False
            if if_true:
                if not defined:
                    outlines.append(line)
                    idx += 1
                    continue

                try:
                    defined_value = int(self._defs[macro])
                except (TypeError, ValueError):
                    # The value is not an integer, so `#if MACRO ...` cannot
                    # be evaluated. Leave the line untouched rather than
                    # stripping an operand we could not evaluate.
                    outlines.append(line)
                    idx += 1
                    continue

                resolved = True
                is_true = (defined_value != 0)

                if op is not None:
                    if op == '&&':
                        resolved = not is_true
                    else:
                        assert op == '||'
                        resolved = is_true

            else:
                ifdef = groups.get('not') is None

                macro2 = groups.get('macro2')
                cmp = groups.get('cmp')
                value = groups.get('value')
                openp = groups.get('openp')
                closep = groups.get('closep')

                is_true = (ifdef == defined)
                resolved = True
                if op is not None:
                    if op == '&&':
                        resolved = not is_true
                    else:
                        assert op == '||'
                        resolved = is_true

                if macro2 is not None and not resolved:
                    assert ifdef and defined and op == '&&' and cmp is not None
                    # If the statement is true, but we have a single value check, then
                    # check the value.
                    defined_value = self._defs[macro]
                    are_ints = True
                    try:
                        defined_value = int(defined_value)
                        value = int(value)
                    except (TypeError, ValueError):
                        are_ints = False
                    if (
                            macro == macro2 and
                            ((openp is None) == (closep is None)) and
                            are_ints
                    ):
                        resolved = True
                        if cmp == '<':
                            is_true = defined_value < value
                        elif cmp == '<=':
                            is_true = defined_value <= value
                        elif cmp == '==':
                            is_true = defined_value == value
                        elif cmp == '!=':
                            is_true = defined_value != value
                        elif cmp == '>=':
                            is_true = defined_value >= value
                        elif cmp == '>':
                            is_true = defined_value > value
                        else:
                            resolved = False

            if op is not None and not resolved:
                # Remove the first op in the line + spaces
                if op == '&&':
                    opre = op
                else:
                    assert op == '||'
                    opre = r'\|\|'
                needle = re.compile(fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{opre}\s*)")
                match = needle.match(line)
                assert match is not None
                newline = line[:match.end('if')] + line[match.end('op'):]

                self._log(f"\tHardwiring partially resolved {macro}")
                self._log(f"\t\t- {line[:-1]}")
                self._log(f"\t\t+ {newline[:-1]}")

                outlines.append(newline)
                idx += 1
                # The shortened condition may itself be resolvable, so
                # request another pass. Each partial resolution removes an
                # operator from the line, so this terminates.
                changed = True
                continue

            # Skip any statements we cannot fully compute
            if not resolved:
                outlines.append(line)
                idx += 1
                continue

            prepend = []
            if macro in self._replaces:
                assert not ifdef
                assert op is None
                value = self._replaces.pop(macro)
                prepend = [f"#define {macro} {value}\n"]

            idx, replace = self._handle_if_block(macro, idx, is_true, prepend)
            outlines += replace
            changed = True

        return changed, outlines

    def preprocess(self, filename):
        """Rewrite `filename` in place, repeating passes until nothing changes."""
        with open(filename, 'r') as f:
            self._inlines = f.readlines()
        changed = True
        while changed:
            changed, outlines = self._preprocess_once()
            self._inlines = outlines

        with open(filename, 'w') as f:
            f.write(''.join(self._inlines))


class Freestanding(object):
    """
    Builds a freestanding copy of the library in `output_lib`.

    go() copies the sources, installs the replacement zstd_deps.h and mem.h,
    then applies macro hardwiring, section exclusion, include rewriting,
    XXH64 renaming, sed replacements and SPDX headers, in that order.
    """
    def __init__(
            self, zstd_deps: str, mem: str, source_lib: str, output_lib: str,
            external_xxhash: bool, xxh64_state: Optional[str],
            xxh64_prefix: Optional[str], rewritten_includes: [(str, str)],
            defs: [(str, Optional[str])], replaces: [(str, str)],
            undefs: [str], excludes: [str], seds: [str], spdx: bool,
    ):
        self._zstd_deps = zstd_deps
        self._mem = mem
        self._src_lib = source_lib
        self._dst_lib = output_lib
        self._external_xxhash = external_xxhash
        self._xxh64_state = xxh64_state
        self._xxh64_prefix = xxh64_prefix
        self._rewritten_includes = rewritten_includes
        self._defs = defs
        self._replaces = replaces
        self._undefs = undefs
        self._excludes = excludes
        self._seds = seds
        self._spdx = spdx

    def _dst_lib_file_paths(self):
        """
        Yields all the file paths in the dst_lib.
        """
        for root, dirname, filenames in os.walk(self._dst_lib):
            for filename in filenames:
                filepath = os.path.join(root, filename)
                yield filepath

    def _log(self, *args, **kwargs):
        print(*args, **kwargs)

    def _copy_file(self, lib_path):
        """
        Copy one library-relative file from the source to the output library.

        Files that do not end in .c, .h or .S are ignored silently; files
        listed in SKIPPED_FILES (and XXHASH_FILES when an external xxhash is
        used) are skipped with a log line.
        """
        suffixes = (".c", ".h", ".S")
        if not lib_path.endswith(suffixes):
            return
        if lib_path in SKIPPED_FILES:
            self._log(f"\tSkipping file: {lib_path}")
            return
        if self._external_xxhash and lib_path in XXHASH_FILES:
            self._log(f"\tSkipping xxhash file: {lib_path}")
            return

        src_path = os.path.join(self._src_lib, lib_path)
        dst_path = os.path.join(self._dst_lib, lib_path)
        self._log(f"\tCopying: {src_path} -> {dst_path}")
        shutil.copyfile(src_path, dst_path)

    def _copy_source_lib(self):
        """Copy zstd.h, zstd_errors.h and every INCLUDED_SUBDIRS file."""
        self._log("Copying source library into output library")

        assert os.path.exists(self._src_lib)
        os.makedirs(self._dst_lib, exist_ok=True)
        self._copy_file("zstd.h")
        self._copy_file("zstd_errors.h")
        for subdir in INCLUDED_SUBDIRS:
            src_dir = os.path.join(self._src_lib, subdir)
            dst_dir = os.path.join(self._dst_lib, subdir)

            assert os.path.exists(src_dir)
            os.makedirs(dst_dir, exist_ok=True)

            for filename in os.listdir(src_dir):
                lib_path = os.path.join(subdir, filename)
                self._copy_file(lib_path)

    def _copy_zstd_deps(self):
        """Install the user-supplied zstd_deps.h as common/zstd_deps.h."""
        dst_zstd_deps = os.path.join(self._dst_lib, "common", "zstd_deps.h")
        self._log(f"Copying zstd_deps: {self._zstd_deps} -> {dst_zstd_deps}")
        shutil.copyfile(self._zstd_deps, dst_zstd_deps)

    def _copy_mem(self):
        """Install the user-supplied mem.h as common/mem.h."""
        dst_mem = os.path.join(self._dst_lib, "common", "mem.h")
        self._log(f"Copying mem: {self._mem} -> {dst_mem}")
        shutil.copyfile(self._mem, dst_mem)

    def _hardwire_preprocessor(self, name: str, value: Optional[str] = None, undef=False):
        """
        If value=None then hardwire that it is defined, but not what the value is.
        If undef=True then value must be None.
        If value='' then the macro is defined to '' exactly.
        """
        # NOTE(review): this method only loads each file and modifies
        # nothing, and go() does not call it; the hardwiring is done by
        # _hardwire_defines().
        assert not (undef and value is not None)
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)

    def _hardwire_defines(self):
        """Run the PartialPreprocessor over every file in the output library."""
        self._log("Hardwiring macros")
        partial_preprocessor = PartialPreprocessor(self._defs, self._replaces, self._undefs)
        for filepath in self._dst_lib_file_paths():
            partial_preprocessor.preprocess(filepath)

    def _remove_excludes(self):
        """
        Drop every line from a `BEGIN <exclude>` line through the matching
        `END <exclude>` line, inclusive, in every output file.

        Raises RuntimeError if a file ends inside an excluded section.
        """
        self._log("Removing excluded sections")
        for exclude in self._excludes:
            self._log(f"\tRemoving excluded sections for: {exclude}")
            begin_re = re.compile(f"BEGIN {exclude}")
            end_re = re.compile(f"END {exclude}")
            for filepath in self._dst_lib_file_paths():
                file = FileLines(filepath)
                outlines = []
                skipped = []
                emit = True
                for line in file.lines:
                    if emit and begin_re.search(line) is not None:
                        assert end_re.search(line) is None
                        emit = False
                    if emit:
                        outlines.append(line)
                    else:
                        skipped.append(line)
                        if end_re.search(line) is not None:
                            assert begin_re.search(line) is None
                            self._log(f"\t\tRemoving excluded section: {exclude}")
                            for s in skipped:
                                self._log(f"\t\t\t- {s}")
                            emit = True
                            skipped = []
                if not emit:
                    raise RuntimeError("Excluded section unfinished!")
                file.lines = outlines
                file.write()

    def _rewrite_include(self, original, rewritten):
        """
        Replace the target of every `#include` whose target matches the
        regex `original` with the literal text `rewritten`.
        """
        self._log(f"\tRewriting include: {original} -> {rewritten}")
        regex = re.compile(f"\\s*#\\s*include\\s*(?P<include>{original})")
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)
            for i, line in enumerate(file.lines):
                match = regex.match(line)
                if match is None:
                    continue
                s = match.start('include')
                e = match.end('include')
                file.lines[i] = line[:s] + rewritten + line[e:]
            file.write()

    def _rewrite_includes(self):
        self._log("Rewriting includes")
        for original, rewritten in self._rewritten_includes:
            self._rewrite_include(original, rewritten)

    def _replace_xxh64_prefix(self):
        """
        Rename `XXH64_state_t` to the configured state type and the `XXH64`
        prefix (when followed by `(` or `_`) to the configured prefix.

        Either option may be given without the other; nothing is done when
        both are None.
        """
        if self._xxh64_prefix is None and self._xxh64_state is None:
            return
        if self._xxh64_prefix is not None:
            self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}")
        replacements = []
        # The state type is replaced first: once the prefix is rewritten the
        # state pattern would no longer match.
        if self._xxh64_state is not None:
            replacements.append(
                (re.compile(r"([^\w]|^)(?P<orig>XXH64_state_t)([^\w]|$)"), self._xxh64_state)
            )
        if self._xxh64_prefix is not None:
            replacements.append(
                (re.compile(r"([^\w]|^)(?P<orig>XXH64)[\(_]"), self._xxh64_prefix)
            )
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)
            for i, line in enumerate(file.lines):
                modified = False
                for regex, replacement in replacements:
                    # NOTE: the line is re-searched from the start after each
                    # substitution, so a replacement that itself matches the
                    # pattern would never terminate.
                    match = regex.search(line)
                    while match is not None:
                        modified = True
                        b = match.start('orig')
                        e = match.end('orig')
                        line = line[:b] + replacement + line[e:]
                        match = regex.search(line)
                if modified:
                    self._log(f"\t- {file.lines[i][:-1]}")
                    self._log(f"\t+ {line[:-1]}")
                file.lines[i] = line
            file.write()

    def _parse_sed(self, sed):
        """
        Parse `s<d>REGEX<d>FORMAT<d>[g]`, where <d> is the character that
        follows the leading `s`.

        Returns (compiled regex, format string, is_global).
        """
        assert sed[0] == 's'
        # Escape the delimiter so that characters such as `|` or `.` are
        # matched literally instead of being interpreted as regex syntax.
        delim = re.escape(sed[1])
        match = re.fullmatch(f's{delim}(.+){delim}(.*){delim}(.*)', sed)
        assert match is not None
        regex = re.compile(match.group(1))
        format_str = match.group(2)
        is_global = match.group(3) == 'g'
        return regex, format_str, is_global

    def _process_sed(self, sed):
        """
        Apply one sed expression to every line of every output file.

        The format string is formatted with two positional arguments: the
        tuple of groups and the dict of named groups, both defaulting
        unmatched groups to ''.
        """
        self._log(f"Processing sed: {sed}")
        regex, format_str, is_global = self._parse_sed(sed)

        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)
            for i, line in enumerate(file.lines):
                modified = False
                # NOTE: in global mode the line is re-searched from the
                # start after each substitution, so a replacement that
                # itself matches the regex would never terminate.
                while True:
                    match = regex.search(line)
                    if match is None:
                        break
                    replacement = format_str.format(match.groups(''), match.groupdict(''))
                    b = match.start()
                    e = match.end()
                    line = line[:b] + replacement + line[e:]
                    modified = True
                    if not is_global:
                        break
                if modified:
                    self._log(f"\t- {file.lines[i][:-1]}")
                    self._log(f"\t+ {line[:-1]}")
                file.lines[i] = line
            file.write()

    def _process_seds(self):
        self._log("Processing seds")
        for sed in self._seds:
            self._process_sed(sed)

    def _process_spdx(self):
        """
        Insert an SPDX identifier as the first line of every output file
        that does not already start with one.

        Raises RuntimeError if a file contains an SPDX identifier anywhere
        else, or has an extension other than .c, .h or .S.
        """
        if not self._spdx:
            return
        self._log("Processing spdx")
        SPDX_C = "// SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause\n"
        SPDX_H_S = "/* SPDX-License-Identifier: GPL-2.0+ OR BSD-3-Clause */\n"
        for filepath in self._dst_lib_file_paths():
            file = FileLines(filepath)
            # An empty file has no first line to inspect.
            if file.lines and file.lines[0] in (SPDX_C, SPDX_H_S):
                continue
            for line in file.lines:
                if "SPDX-License-Identifier" in line:
                    raise RuntimeError(f"Unexpected SPDX license identifier: {file.filename} {repr(line)}")
            if file.filename.endswith(".c"):
                file.lines.insert(0, SPDX_C)
            elif file.filename.endswith(".h") or file.filename.endswith(".S"):
                file.lines.insert(0, SPDX_H_S)
            else:
                raise RuntimeError(f"Unexpected file extension: {file.filename}")
            file.write()

    def go(self):
        """Run every step of the conversion, in order."""
        self._copy_source_lib()
        self._copy_zstd_deps()
        self._copy_mem()
        self._hardwire_defines()
        self._remove_excludes()
        self._rewrite_includes()
        self._replace_xxh64_prefix()
        self._process_seds()
        self._process_spdx()


def parse_optional_pair(defines: [str]) -> [(str, Optional[str])]:
    """
    Parse `NAME` or `NAME=VALUE` strings into (name, value) tuples; value is
    None when there is no `=`. Raises RuntimeError on more than one `=`.
    """
    output = []
    for define in defines:
        parsed = define.split('=')
        if len(parsed) == 1:
            output.append((parsed[0], None))
        elif len(parsed) == 2:
            output.append((parsed[0], parsed[1]))
        else:
            raise RuntimeError(f"Bad define: {define}")
    return output


def parse_pair(rewritten_includes: [str]) -> [(str, str)]:
    """
    Parse `LEFT=RIGHT` strings into (left, right) tuples. Raises
    RuntimeError unless the string contains exactly one `=`.
    """
    output = []
    for rewritten_include in rewritten_includes:
        parsed = rewritten_include.split('=')
        if len(parsed) == 2:
            output.append((parsed[0], parsed[1]))
        else:
            raise RuntimeError(f"Bad rewritten include: {rewritten_include}")
    return output


def main(name, args):
    """
    Parse the command line `args` (program name `name`) and run the
    conversion. Raises RuntimeError on contradictory options.
    """
    parser = argparse.ArgumentParser(prog=name)
    parser.add_argument("--zstd-deps", default="zstd_deps.h", help="Zstd dependencies file")
    parser.add_argument("--mem", default="mem.h", help="Memory module")
    parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library")
    parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library")
    parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.")
    parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'")
    parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64")
    parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<zstd.h>=<linux/zstd.h>')")
    parser.add_argument("--sed", default=[], dest="seds", action="append", help="Apply a sed replacement. Format: `s/REGEX/FORMAT/[g]`. REGEX is a Python regex. FORMAT is a Python format string formatted by the regex dict.")
    parser.add_argument("--spdx", action="store_true", help="Add SPDX License Identifiers")
    parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)")
    parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed multiple times)")
    parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition")
    parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'")
    args = parser.parse_args(args)

    # Always remove threading
    if "ZSTD_MULTITHREAD" not in args.undefs:
        args.undefs.append("ZSTD_MULTITHREAD")

    args.defs = parse_optional_pair(args.defs)
    for macro, _ in args.defs:
        if macro in args.undefs:
            raise RuntimeError(f"{macro} is both defined and undefined!")

    # Always set tracing to 0
    if "ZSTD_NO_TRACE" not in (arg[0] for arg in args.defs):
        args.defs.append(("ZSTD_NO_TRACE", None))
        args.defs.append(("ZSTD_TRACE", "0"))

    args.replaces = parse_pair(args.replaces)
    # args.defs holds (name, value) tuples, so compare against the names.
    defined_names = {macro for macro, _ in args.defs}
    for macro, _ in args.replaces:
        if macro in args.undefs or macro in defined_names:
            raise RuntimeError(f"{macro} is both replaced and (un)defined!")

    args.rewritten_includes = parse_pair(args.rewritten_includes)

    external_xxhash = False
    if args.xxhash is not None:
        external_xxhash = True
        args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash))

    if args.xxh64_prefix is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided")

    if args.xxh64_state is not None:
        if not external_xxhash:
            raise RuntimeError("--xxh64-state may only be used with --xxhash provided")

    Freestanding(
        args.zstd_deps,
        args.mem,
        args.source_lib,
        args.output_lib,
        external_xxhash,
        args.xxh64_state,
        args.xxh64_prefix,
        args.rewritten_includes,
        args.defs,
        args.replaces,
        args.undefs,
        args.excludes,
        args.seds,
        args.spdx,
    ).go()


if __name__ == "__main__":
    main(sys.argv[0], sys.argv[1:])