/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
 */

#include "method.h"

#include <stddef.h> /* offsetof */
#include <stdio.h>
#include <stdlib.h>

#define ZSTD_STATIC_LINKING_ONLY
#include <zstd.h>

#define MIN(x, y) ((x) < (y) ? (x) : (y))

/** Path of the zstd command line binary used by cli_compress(). */
static char const* g_zstdcli = NULL;

/** Records the zstd CLI path that cli_compress() will invoke. */
void method_set_zstdcli(char const* zstdcli) {
    g_zstdcli = zstdcli;
}

/**
 * Macro to get a pointer of type, given ptr, which is a member variable with
 * the given name, member. A NULL ptr yields NULL.
 *
 *     method_state_t* base = ...;
 *     buffer_state_t* state = container_of(base, buffer_state_t, base);
 */
#define container_of(ptr, type, member) \
    ((type*)((ptr) == NULL ? NULL : (char*)(ptr)-offsetof(type, member)))

/** State to reuse the same buffers between compression calls. */
typedef struct {
    method_state_t base;
    data_buffers_t inputs;      /**< The input buffer for each file. */
    data_buffer_t dictionary;   /**< The dictionary. */
    data_buffer_t compressed;   /**< The compressed data buffer. */
    data_buffer_t decompressed; /**< The decompressed data buffer. */
} buffer_state_t;

/** Returns the size of the largest buffer in buffers, or 0 if there is none. */
static size_t buffers_max_size(data_buffers_t buffers) {
    size_t max = 0;
    for (size_t i = 0; i < buffers.size; ++i) {
        if (buffers.buffers[i].size > max) {
            max = buffers.buffers[i].size;
        }
    }
    return max;
}

/**
 * Allocates a buffer_state_t for data: fetches the input buffers and the
 * dictionary, and creates compressed/decompressed scratch buffers sized for
 * the largest input.
 *
 * @return Pointer to the embedded base state, or NULL if the state struct
 *         could not be allocated. Failures of the individual buffers are not
 *         checked here; buffer_state_bad() detects them later.
 */
static method_state_t* buffer_state_create(data_t const* data) {
    buffer_state_t* state = (buffer_state_t*)calloc(1, sizeof(buffer_state_t));
    if (state == NULL) {
        return NULL;
    }
    state->base.data = data;
    state->inputs = data_buffers_get(data);
    state->dictionary = data_buffer_get_dict(data);
    size_t const max_size = buffers_max_size(state->inputs);
    state->compressed = data_buffer_create(ZSTD_compressBound(max_size));
    state->decompressed = data_buffer_create(max_size);
    return &state->base;
}

/**
 * Frees a state created by buffer_state_create(). NULL is accepted.
 *
 * NOTE(review): only the state struct itself is freed; the input, dictionary,
 * compressed and decompressed buffers obtained in buffer_state_create() are
 * not released by this function - confirm whether they are owned elsewhere.
 */
static void buffer_state_destroy(method_state_t* base) {
    if (base == NULL) {
        return;
    }
    buffer_state_t* state = container_of(base, buffer_state_t, base);
    free(state);
}

/**
 * Validates that state is usable with config.
 *
 * @return 0 if usable; 1 (after printing a reason to stderr) if state is NULL,
 *         has no inputs, is missing a scratch buffer, or is missing the
 *         dictionary while config asks for one.
 */
static int buffer_state_bad(
        buffer_state_t const* state,
        config_t const* config) {
    if (state == NULL) {
        fprintf(stderr, "buffer_state_t is NULL\n");
        return 1;
    }
    if (state->inputs.size == 0 || state->compressed.data == NULL
        || state->decompressed.data == NULL) {
        fprintf(stderr, "buffer state allocation failure\n");
        return 1;
    }
    if (config->use_dictionary && state->dictionary.data == NULL) {
        fprintf(stderr, "dictionary loading failed\n");
        return 1;
    }
    return 0;
}

/**
 * Round-trips the first input through ZSTD_compress() / ZSTD_decompress().
 *
 * Skips anything that is not a single file, advanced-API-only configs,
 * dictionary configs, configs without a pledged source size, and configs
 * without a compression level.
 *
 * @return The compressed size on success, otherwise an error/skip result.
 */
static result_t simple_compress(method_state_t* base, config_t const* config) {
    buffer_state_t* state = container_of(base, buffer_state_t, base);

    if (buffer_state_bad(state, config)) {
        return result_error(result_error_system_error);
    }

    /* Keep the tests short by skipping directories, since behavior shouldn't
     * change. */
    if (base->data->type != data_type_file) {
        return result_error(result_error_skip);
    }

    if (config->advanced_api_only) {
        return result_error(result_error_skip);
    }

    if (config->use_dictionary || config->no_pledged_src_size) {
        return result_error(result_error_skip);
    }

    /* If the config doesn't specify a level, skip. */
    int const level = config_get_level(config);
    if (level == CONFIG_NO_LEVEL) {
        return result_error(result_error_skip);
    }

    data_buffer_t const input = state->inputs.buffers[0];

    /* Compress, decompress, and check the result. */
    state->compressed.size = ZSTD_compress(
            state->compressed.data,
            state->compressed.capacity,
            input.data,
            input.size,
            level);
    if (ZSTD_isError(state->compressed.size)) {
        return result_error(result_error_compression_error);
    }

    state->decompressed.size = ZSTD_decompress(
            state->decompressed.data,
            state->decompressed.capacity,
            state->compressed.data,
            state->compressed.size);
    if (ZSTD_isError(state->decompressed.size)) {
        return result_error(result_error_decompression_error);
    }
    if (data_buffer_compare(input, state->decompressed)) {
        return result_error(result_error_round_trip_error);
    }

    result_data_t data;
    data.total_size = state->compressed.size;
    return result_data(data);
}

/**
 * Round-trips every input of a directory data set through an explicit
 * ZSTD_CCtx / ZSTD_DCtx pair and sums the compressed sizes.
 *
 * The compression entry point depends on the config: ZSTD_compress_advanced()
 * when no level is set, ZSTD_compress_usingDict() when a dictionary is used,
 * ZSTD_compressCCtx() otherwise.
 *
 * @return The total compressed size on success, otherwise an error/skip
 *         result. Both contexts are freed on every path past their creation.
 */
static result_t compress_cctx_compress(
        method_state_t* base,
        config_t const* config) {
    buffer_state_t* state = container_of(base, buffer_state_t, base);

    if (buffer_state_bad(state, config)) {
        return result_error(result_error_system_error);
    }

    if (config->no_pledged_src_size) {
        return result_error(result_error_skip);
    }

    if (base->data->type != data_type_dir) {
        return result_error(result_error_skip);
    }

    if (config->advanced_api_only) {
        return result_error(result_error_skip);
    }

    int const level = config_get_level(config);

    ZSTD_CCtx* cctx = ZSTD_createCCtx();
    ZSTD_DCtx* dctx = ZSTD_createDCtx();
    result_t result;
    result_data_t data = {.total_size = 0};

    if (cctx == NULL || dctx == NULL) {
        fprintf(stderr, "context creation failed\n");
        /* Go through the common cleanup so that whichever context was
         * successfully created is still released. */
        result = result_error(result_error_system_error);
        goto out;
    }

    for (size_t i = 0; i < state->inputs.size; ++i) {
        data_buffer_t const input = state->inputs.buffers[i];
        ZSTD_parameters const params =
                config_get_zstd_params(config, input.size, state->dictionary.size);

        if (level == CONFIG_NO_LEVEL) {
            state->compressed.size = ZSTD_compress_advanced(
                    cctx,
                    state->compressed.data,
                    state->compressed.capacity,
                    input.data,
                    input.size,
                    config->use_dictionary ? state->dictionary.data : NULL,
                    config->use_dictionary ? state->dictionary.size : 0,
                    params);
        } else if (config->use_dictionary) {
            state->compressed.size = ZSTD_compress_usingDict(
                    cctx,
                    state->compressed.data,
                    state->compressed.capacity,
                    input.data,
                    input.size,
                    state->dictionary.data,
                    state->dictionary.size,
                    level);
        } else {
            state->compressed.size = ZSTD_compressCCtx(
                    cctx,
                    state->compressed.data,
                    state->compressed.capacity,
                    input.data,
                    input.size,
                    level);
        }

        if (ZSTD_isError(state->compressed.size)) {
            result = result_error(result_error_compression_error);
            goto out;
        }

        if (config->use_dictionary) {
            state->decompressed.size = ZSTD_decompress_usingDict(
                    dctx,
                    state->decompressed.data,
                    state->decompressed.capacity,
                    state->compressed.data,
                    state->compressed.size,
                    state->dictionary.data,
                    state->dictionary.size);
        } else {
            state->decompressed.size = ZSTD_decompressDCtx(
                    dctx,
                    state->decompressed.data,
                    state->decompressed.capacity,
                    state->compressed.data,
                    state->compressed.size);
        }
        if (ZSTD_isError(state->decompressed.size)) {
            result = result_error(result_error_decompression_error);
            goto out;
        }
        if (data_buffer_compare(input, state->decompressed)) {
            result = result_error(result_error_round_trip_error);
            goto out;
        }

        data.total_size += state->compressed.size;
    }

    result = result_data(data);
out:
    ZSTD_freeCCtx(cctx);
    ZSTD_freeDCtx(dctx);
    return result;
}

/** Generic state creation function. */
static method_state_t* method_state_create(data_t const* data) {
    method_state_t* state = (method_state_t*)malloc(sizeof(method_state_t));
    if (state == NULL) {
        return NULL;
    }
    state->data = data;
    return state;
}

/** Generic state destruction function. */
static void method_state_destroy(method_state_t* state) {
    free(state);
}

/**
 * Compresses the data set by running the zstd CLI through popen() and counting
 * the bytes it writes to stdout.
 *
 * Skips configs without CLI arguments, advanced-API-only configs, and
 * directories combined with no pledged source size.
 *
 * @return The number of bytes read from the command's stdout on success;
 *         result_error_system_error if the CLI path is unset, the command
 *         does not fit in the buffer, or popen() fails;
 *         result_error_compression_error if reading fails or the command
 *         exits with a non-zero status.
 */
static result_t cli_compress(method_state_t* state, config_t const* config) {
    if (config->cli_args == NULL) {
        return result_error(result_error_skip);
    }

    if (config->advanced_api_only) {
        return result_error(result_error_skip);
    }

    /* We don't support no pledged source size with directories. Too slow. */
    if (state->data->type == data_type_dir && config->no_pledged_src_size) {
        return result_error(result_error_skip);
    }

    if (g_zstdcli == NULL) {
        return result_error(result_error_system_error);
    }

    /* '<zstdcli>' -cqr <cli_args> [-D '<dict path>'] [<] '<data path>' */
    char cmd[1024];
    int const cmd_size = snprintf(
            cmd,
            sizeof(cmd),
            "'%s' -cqr %s %s%s%s %s '%s'",
            g_zstdcli,
            config->cli_args,
            config->use_dictionary ? "-D '" : "",
            config->use_dictionary ? state->data->dict.path : "",
            config->use_dictionary ? "'" : "",
            config->no_pledged_src_size ? "<" : "",
            state->data->data.path);
    /* A negative return is an encoding error; a return >= the buffer size
     * means the command was truncated. Both are rejected. */
    if (cmd_size < 0 || (size_t)cmd_size >= sizeof(cmd)) {
        fprintf(stderr, "command too large: %s\n", cmd);
        return result_error(result_error_system_error);
    }

    FILE* zstd = popen(cmd, "r");
    if (zstd == NULL) {
        fprintf(stderr, "failed to popen command: %s\n", cmd);
        return result_error(result_error_system_error);
    }

    /* The compressed bytes themselves are discarded; only their count is
     * kept. A short read signals end-of-stream or an error. */
    char out[4096];
    size_t total_size = 0;
    for (;;) {
        size_t const size = fread(out, 1, sizeof(out), zstd);
        total_size += size;
        if (size != sizeof(out)) {
            break;
        }
    }

    /* Always close the pipe, even after a read error, so that the stream and
     * the child process are not leaked. */
    int const read_failed = ferror(zstd);
    int const exit_status = pclose(zstd);
    if (read_failed || exit_status != 0) {
        fprintf(stderr, "zstd failed with command: %s\n", cmd);
        return result_error(result_error_compression_error);
    }

    result_data_t const data = {.total_size = total_size};
    return result_data(data);
}

/**
 * Resets cctx (session and parameters) and applies every parameter from
 * config, then loads the dictionary if config uses one.
 *
 * @return 0 on success, 1 if setting a parameter or loading the dictionary
 *         reports a zstd error.
 */
static int advanced_config(
        ZSTD_CCtx* cctx,
        buffer_state_t* state,
        config_t const* config) {
    ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
    for (size_t p = 0; p < config->param_values.size; ++p) {
        param_value_t const pv = config->param_values.data[p];
        if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, pv.param, pv.value))) {
            return 1;
        }
    }
    if (config->use_dictionary) {
        if (ZSTD_isError(ZSTD_CCtx_loadDictionary(
                    cctx, state->dictionary.data, state->dictionary.size))) {
            return 1;
        }
    }
    return 0;
}

/**
 * Compresses every input in one ZSTD_compress2() call each and sums the
 * compressed sizes.
 *
 * @param subtract Number of bytes removed from ZSTD_compressBound(input.size)
 *                 to obtain the destination capacity passed to zstd.
 * @return The total compressed size on success, otherwise an error result.
 */
static result_t advanced_one_pass_compress_output_adjustment(
        method_state_t* base,
        config_t const* config,
        size_t const subtract) {
    buffer_state_t* state = container_of(base, buffer_state_t, base);

    if (buffer_state_bad(state, config)) {
        return result_error(result_error_system_error);
    }

    ZSTD_CCtx* cctx = ZSTD_createCCtx();
    result_t result;
    result_data_t data = {.total_size = 0};

    if (!cctx || advanced_config(cctx, state, config)) {
        result = result_error(result_error_compression_error);
        goto out;
    }

    for (size_t i = 0; i < state->inputs.size; ++i) {
        data_buffer_t const input = state->inputs.buffers[i];

        if (!config->no_pledged_src_size) {
            if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
                result = result_error(result_error_compression_error);
                goto out;
            }
        }
        size_t const size = ZSTD_compress2(
                cctx,
                state->compressed.data,
                ZSTD_compressBound(input.size) - subtract,
                input.data,
                input.size);
        if (ZSTD_isError(size)) {
            result = result_error(result_error_compression_error);
            goto out;
        }
        data.total_size += size;
    }

    result = result_data(data);
out:
    ZSTD_freeCCtx(cctx);
    return result;
}

/** One-pass advanced compression with a ZSTD_compressBound()-sized output. */
static result_t advanced_one_pass_compress(
        method_state_t* base,
        config_t const* config) {
    return advanced_one_pass_compress_output_adjustment(base, config, 0);
}

/** One-pass advanced compression with an output one byte below the bound. */
static result_t advanced_one_pass_compress_small_output(
        method_state_t* base,
        config_t const* config) {
    return advanced_one_pass_compress_output_adjustment(base, config, 1);
}

/**
 * Compresses every input with ZSTD_compressStream2(), feeding at most 4096
 * input bytes at a time and draining into an output window of at most 1024
 * bytes.
 *
 * The output window always starts at the beginning of the compressed buffer,
 * so the compressed bytes are overwritten on every call; only their count is
 * accumulated.
 *
 * @return The total compressed size on success, otherwise an error result.
 */
static result_t advanced_streaming_compress(
        method_state_t* base,
        config_t const* config) {
    buffer_state_t* state = container_of(base, buffer_state_t, base);

    if (buffer_state_bad(state, config)) {
        return result_error(result_error_system_error);
    }

    ZSTD_CCtx* cctx = ZSTD_createCCtx();
    result_t result;
    result_data_t data = {.total_size = 0};

    if (!cctx || advanced_config(cctx, state, config)) {
        result = result_error(result_error_compression_error);
        goto out;
    }

    for (size_t i = 0; i < state->inputs.size; ++i) {
        data_buffer_t input = state->inputs.buffers[i];

        if (!config->no_pledged_src_size) {
            if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
                result = result_error(result_error_compression_error);
                goto out;
            }
        }

        while (input.size > 0) {
            ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
            input.data += in.size;
            input.size -= in.size;
            /* The last chunk of an input ends the frame. */
            ZSTD_EndDirective const op =
                    input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;

            size_t ret = 0;
            /* Loop until the chunk is consumed and, when ending the frame,
             * until zstd reports nothing left to flush (ret == 0). */
            while (in.pos < in.size || (op == ZSTD_e_end && ret != 0)) {
                ZSTD_outBuffer out = {state->compressed.data,
                                      MIN(state->compressed.capacity, 1024)};
                ret = ZSTD_compressStream2(cctx, &out, &in, op);
                if (ZSTD_isError(ret)) {
                    result = result_error(result_error_compression_error);
                    goto out;
                }
                data.total_size += out.pos;
            }
        }
    }

    result = result_data(data);
out:
    ZSTD_freeCCtx(cctx);
    return result;
}

/**
 * Initializes zcs with one of the legacy ZSTD_initCStream*() entry points.
 *
 * @param advanced Non-zero selects the *_advanced variants, driven by
 *                 config_get_zstd_params(config, 0, 0); zero selects the
 *                 level-based variants and requires config to have a level.
 * @param cdict    If non-NULL, a ZSTD_CDict is created from the state's
 *                 dictionary and stored in *cdict; the caller is responsible
 *                 for freeing it. Requires config->use_dictionary.
 * @return 0 on success, 1 on any failure or unsupported combination.
 */
static int init_cstream(
        buffer_state_t* state,
        ZSTD_CStream* zcs,
        config_t const* config,
        int const advanced,
        ZSTD_CDict** cdict) {
    size_t zret;
    if (advanced) {
        ZSTD_parameters const params = config_get_zstd_params(config, 0, 0);
        if (cdict) {
            if (!config->use_dictionary) {
                return 1;
            }
            *cdict = ZSTD_createCDict_advanced(
                    state->dictionary.data,
                    state->dictionary.size,
                    ZSTD_dlm_byRef,
                    ZSTD_dct_auto,
                    params.cParams,
                    ZSTD_defaultCMem);
            if (!*cdict) {
                return 1;
            }
            zret = ZSTD_initCStream_usingCDict_advanced(
                    zcs, *cdict, params.fParams, ZSTD_CONTENTSIZE_UNKNOWN);
        } else {
            zret = ZSTD_initCStream_advanced(
                    zcs,
                    config->use_dictionary ? state->dictionary.data : NULL,
                    config->use_dictionary ? state->dictionary.size : 0,
                    params,
                    ZSTD_CONTENTSIZE_UNKNOWN);
        }
    } else {
        int const level = config_get_level(config);
        if (level == CONFIG_NO_LEVEL) {
            return 1;
        }
        if (cdict) {
            if (!config->use_dictionary) {
                return 1;
            }
            *cdict = ZSTD_createCDict(
                    state->dictionary.data, state->dictionary.size, level);
            if (!*cdict) {
                return 1;
            }
            zret = ZSTD_initCStream_usingCDict(zcs, *cdict);
        } else if (config->use_dictionary) {
            zret = ZSTD_initCStream_usingDict(
                    zcs, state->dictionary.data, state->dictionary.size, level);
        } else {
            zret = ZSTD_initCStream(zcs, level);
        }
    }
    if (ZSTD_isError(zret)) {
        return 1;
    }
    return 0;
}

/**
 * Compresses every input with the legacy streaming API
 * (ZSTD_resetCStream / ZSTD_compressStream / ZSTD_endStream), using the same
 * 4096-byte input chunks and 1024-byte output window as
 * advanced_streaming_compress().
 *
 * Skips when a level is required but missing, when a cdict is requested
 * without a dictionary, and for advanced-API-only configs.
 *
 * @param advanced Forwarded to init_cstream().
 * @param cdict    Non-zero to initialize the stream from a ZSTD_CDict.
 * @return The total compressed size on success, otherwise an error/skip
 *         result. The stream and any created cdict are always freed.
 */
static result_t old_streaming_compress_internal(
        method_state_t* base,
        config_t const* config,
        int const advanced,
        int const cdict) {
    buffer_state_t* state = container_of(base, buffer_state_t, base);

    if (buffer_state_bad(state, config)) {
        return result_error(result_error_system_error);
    }

    ZSTD_CStream* zcs = ZSTD_createCStream();
    ZSTD_CDict* cd = NULL;
    result_t result;
    result_data_t data = {.total_size = 0};

    if (zcs == NULL) {
        result = result_error(result_error_compression_error);
        goto out;
    }
    if (!advanced && config_get_level(config) == CONFIG_NO_LEVEL) {
        result = result_error(result_error_skip);
        goto out;
    }
    if (cdict && !config->use_dictionary) {
        result = result_error(result_error_skip);
        goto out;
    }
    if (config->advanced_api_only) {
        result = result_error(result_error_skip);
        goto out;
    }
    if (init_cstream(state, zcs, config, advanced, cdict ? &cd : NULL)) {
        result = result_error(result_error_compression_error);
        goto out;
    }

    for (size_t i = 0; i < state->inputs.size; ++i) {
        data_buffer_t input = state->inputs.buffers[i];
        size_t zret = ZSTD_resetCStream(
                zcs,
                config->no_pledged_src_size ? ZSTD_CONTENTSIZE_UNKNOWN
                                            : input.size);
        if (ZSTD_isError(zret)) {
            result = result_error(result_error_compression_error);
            goto out;
        }

        while (input.size > 0) {
            ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
            input.data += in.size;
            input.size -= in.size;
            /* The last chunk of an input ends the frame. */
            ZSTD_EndDirective const op =
                    input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;

            zret = 0;
            while (in.pos < in.size || (op == ZSTD_e_end && zret != 0)) {
                ZSTD_outBuffer out = {state->compressed.data,
                                      MIN(state->compressed.capacity, 1024)};
                /* Keep compressing while input remains; once the final chunk
                 * is consumed, switch to ending the stream. */
                if (op == ZSTD_e_continue || in.pos < in.size) {
                    zret = ZSTD_compressStream(zcs, &out, &in);
                } else {
                    zret = ZSTD_endStream(zcs, &out);
                }
                if (ZSTD_isError(zret)) {
                    result = result_error(result_error_compression_error);
                    goto out;
                }
                data.total_size += out.pos;
            }
        }
    }

    result = result_data(data);
out:
    ZSTD_freeCStream(zcs);
    ZSTD_freeCDict(cd);
    return result;
}

/** Legacy streaming: level-based init, no cdict. */
static result_t old_streaming_compress(
        method_state_t* base,
        config_t const* config) {
    return old_streaming_compress_internal(
            base, config, /* advanced */ 0, /* cdict */ 0);
}

/** Legacy streaming: advanced init, no cdict. */
static result_t old_streaming_compress_advanced(
        method_state_t* base,
        config_t const* config) {
    return old_streaming_compress_internal(
            base, config, /* advanced */ 1, /* cdict */ 0);
}

/** Legacy streaming: level-based init, with cdict. */
static result_t old_streaming_compress_cdict(
        method_state_t* base,
        config_t const* config) {
    return old_streaming_compress_internal(
            base, config, /* advanced */ 0, /* cdict */ 1);
}

/** Legacy streaming: advanced init, with cdict. */
static result_t old_streaming_compress_cdict_advanced(
        method_state_t* base,
        config_t const* config) {
    return old_streaming_compress_internal(
            base, config, /* advanced */ 1, /* cdict */ 1);
}

method_t const simple = {
    .name = "compress simple",
    .create = buffer_state_create,
    .compress = simple_compress,
    .destroy = buffer_state_destroy,
};

method_t const compress_cctx = {
    .name = "compress cctx",
    .create = buffer_state_create,
    .compress = compress_cctx_compress,
    .destroy = buffer_state_destroy,
};

method_t const advanced_one_pass = {
    .name = "advanced one pass",
    .create = buffer_state_create,
    .compress = advanced_one_pass_compress,
    .destroy = buffer_state_destroy,
};

/* Uses the small-output variant so that the destination capacity is one byte
 * below ZSTD_compressBound(); otherwise this method would be an exact
 * duplicate of advanced_one_pass. */
method_t const advanced_one_pass_small_out = {
    .name = "advanced one pass small out",
    .create = buffer_state_create,
    .compress = advanced_one_pass_compress_small_output,
    .destroy = buffer_state_destroy,
};

method_t const advanced_streaming = {
    .name = "advanced streaming",
    .create = buffer_state_create,
    .compress = advanced_streaming_compress,
    .destroy = buffer_state_destroy,
};

method_t const old_streaming = {
    .name = "old streaming",
    .create = buffer_state_create,
    .compress = old_streaming_compress,
    .destroy = buffer_state_destroy,
};

method_t const old_streaming_advanced = {
    .name = "old streaming advanced",
    .create = buffer_state_create,
    .compress = old_streaming_compress_advanced,
    .destroy = buffer_state_destroy,
};

method_t const old_streaming_cdict = {
    .name = "old streaming cdict",
    .create = buffer_state_create,
    .compress = old_streaming_compress_cdict,
    .destroy = buffer_state_destroy,
};

method_t const old_streaming_advanced_cdict = {
    .name = "old streaming advanced cdict",
    .create = buffer_state_create,
    .compress = old_streaming_compress_cdict_advanced,
    .destroy = buffer_state_destroy,
};

method_t const cli = {
    .name = "zstdcli",
    .create = method_state_create,
    .compress = cli_compress,
    .destroy = method_state_destroy,
};

/** NULL-terminated list of every method, exposed through `methods`. */
static method_t const* g_methods[] = {
    &simple,
    &compress_cctx,
    &cli,
    &advanced_one_pass,
    &advanced_one_pass_small_out,
    &advanced_streaming,
    &old_streaming,
    &old_streaming_advanced,
    &old_streaming_cdict,
    &old_streaming_advanced_cdict,
    NULL,
};

method_t const* const* methods = g_methods;