/*
 * Copyright (c) Yann Collet, Meta Platforms, Inc.
 * All rights reserved.
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "zstd_errors.h"
#include "sequence_producer.h" // simpleSequenceProducer

/* CHECK(res): if `res` is a zstd error code, print its name and `return 1`
 * from the enclosing function (intended for use inside main()).
 * `res` is copied into a local first so an argument with side effects is
 * evaluated exactly once, not once for the test and again for the message.
 * The last line deliberately has no trailing backslash, so the macro cannot
 * swallow whatever line follows it. */
#define CHECK(res)                                            \
do {                                                          \
    size_t const check_res_ = (res);                          \
    if (ZSTD_isError(check_res_)) {                           \
        printf("ERROR: %s\n", ZSTD_getErrorName(check_res_)); \
        return 1;                                             \
    }                                                         \
} while (0)

int main(int argc, char *argv[]) {
    if (argc != 2) {
        printf("Usage: externalSequenceProducer <file>\n");
        return 1;
    }

    ZSTD_CCtx* const zc = ZSTD_createCCtx();

    int simpleSequenceProducerState = 0xdeadbeef;

    // Here is the crucial bit of code!
    ZSTD_registerSequenceProducer(
        zc,
        &simpleSequenceProducerState,
        simpleSequenceProducer
    );

    {
        size_t const res = ZSTD_CCtx_setParameter(zc, ZSTD_c_enableSeqProducerFallback, 1);
        CHECK(res);
    }

    FILE *f = fopen(argv[1], "rb");
    assert(f);
    {
        int const ret = fseek(f, 0, SEEK_END);
        assert(ret == 0);
    }
    size_t const srcSize = ftell(f);
    {
        int const ret = fseek(f, 0, SEEK_SET);
        assert(ret == 0);
    }

    char* const src = malloc(srcSize + 1);
    assert(src);
    {
        size_t const ret = fread(src, srcSize, 1, f);
        assert(ret == 1);
        int const ret2 = fclose(f);
        assert(ret2 == 0);
    }

    size_t const dstSize = ZSTD_compressBound(srcSize);
    char* const dst = malloc(dstSize);
    assert(dst);

    size_t const cSize = ZSTD_compress2(zc, dst, dstSize, src, srcSize);
    CHECK(cSize);

    char* const val = malloc(srcSize);
    assert(val);

    {
        size_t const res = ZSTD_decompress(val, srcSize, dst, cSize);
        CHECK(res);
    }

    if (memcmp(src, val, srcSize) == 0) {
        printf("Compression and decompression were successful!\n");
        printf("Original size: %lu\n", srcSize);
        printf("Compressed size: %lu\n", cSize);
    } else {
        printf("ERROR: input and validation buffers don't match!\n");
        for (size_t i = 0; i < srcSize; i++) {
            if (src[i] != val[i]) {
                printf("First bad index: %zu\n", i);
                break;
            }
        }
        return 1;
    }

    ZSTD_freeCCtx(zc);
    free(src);
    free(dst);
    free(val);
    return 0;
}