#include <array>      // std::size
#include <cassert>    // assert
#include <cstddef>    // std::size_t
#include <exception>  // std::exception
#include <iostream>   // std::cout/endl
#include <limits>     // std::numeric_limits
#include <new>        // std::bad_alloc
#include <stdexcept>  // std::invalid_argument/out_of_range
#include <string>     // std::string/getline
#include <ctype.h>    // toupper
#include <getopt.h>   // getopt_long
#include <limits.h>   // UINT_MAX/ULONG_MAX/ULLONG_MAX
#include <math.h>     // log/pow
#include <stdlib.h>   // strtoul/strtoull
#include <string.h>   // memset

/// Converts between raw byte counts and human-readable binary (IEC) sizes.
///
/// - `output<Int>` wraps a byte count so that streaming it prints the value
///   scaled to the largest fitting unit, e.g. 1536 -> "1.5 KiB".
/// - `to<Int>` parses text such as "128M" into a byte count; it is only
///   declared here and is explicitly specialized per integer type.
class byte_suffix_converter {
public:
    /// Unit suffixes in ascending order; entry i stands for scale^i bytes.
    static constexpr const char* suffixes[] = {
        "B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"};
    static constexpr auto suffix_count = std::size(suffixes);
    /// Ratio between consecutive units (binary prefixes).
    static constexpr auto scale = 1024U;

    /// Stream adapter that prints a byte count together with a unit suffix.
    template <typename Int>
    class output {
    public:
        output(Int value) : value_(value) {}

        /// Writes the value divided by the largest power of `scale` that does
        /// not exceed it, followed by the matching suffix
        /// (e.g. "0 B", "1023 B", "1 KiB", "1.5 GiB").
        /// @throws std::out_of_range if the value needs a unit beyond "YiB".
        friend std::ostream& operator<<(std::ostream& os, output out_bytes)
        {
            const Int bytes = out_bytes.value_;
            if (bytes == 0) {
                os << "0 B";
                return os;
            }
            // Choose the unit with exact integer arithmetic: a floating-point
            // log(bytes) / log(scale) can round an exact power such as 2^30
            // down to 2.999..., printing "1024 MiB" instead of "1 GiB".
            // `bytes / divisor >= scale` is equivalent to
            // `bytes >= divisor * scale` but can never overflow.
            std::size_t exponent = 0;
            Int divisor = 1;
            while (bytes / divisor >= scale) {
                divisor *= scale;
                ++exponent;
            }
            if (exponent >= suffix_count) {
                throw std::out_of_range("Value is too big");
            }
            const double converted_value =
                static_cast<double>(bytes) / static_cast<double>(divisor);
            os << converted_value << " " << suffixes[exponent];
            return os;
        }

    private:
        Int value_;
    };

    template <typename Int>
    static Int to(const char* input);

private:
    /// Multiplies `value` by the unit selected by `suffix_ch`.
    ///
    /// `suffix_ch` is compared case-insensitively against the first letter of
    /// each entry in `suffixes` ('B', 'K', 'M', 'G', ...); '\0' means
    /// "no suffix" and returns `value` unchanged.
    /// @throws std::out_of_range     if the scaled value does not fit in Int.
    /// @throws std::invalid_argument if `suffix_ch` matches no known unit.
    template <typename Int>
    static Int process_suffix(Int value, char suffix_ch)
    {
        if (suffix_ch == '\0') {
            return value;
        }
        const int wanted = toupper(static_cast<unsigned char>(suffix_ch));
        Int multiplier = 1;
        // Becomes true once scale^i no longer fits in Int. Later suffixes are
        // still recognised (so the error is "too big", not "invalid suffix"),
        // but their multiplier cannot be represented.
        bool multiplier_overflowed = false;
        for (const char* suffix : suffixes) {
            if (wanted == *suffix) {
                if (multiplier_overflowed ||
                    value > std::numeric_limits<Int>::max() / multiplier) {
                    throw std::out_of_range("Value is too big");
                }
                return value * multiplier;
            }
            if (multiplier_overflowed ||
                multiplier > std::numeric_limits<Int>::max() / scale) {
                multiplier_overflowed = true;
            } else {
                multiplier *= scale;
            }
        }
        throw std::invalid_argument("Invalid suffix");
    }
};

/// Parses a decimal byte count with an optional unit suffix into an
/// `unsigned int` (e.g. "4096", "64k", "2G").
///
/// NOTE(review): only the first character after the digits is passed to
/// process_suffix; any characters after it are ignored.
///
/// @throws std::invalid_argument if there are no digits, the number is
///         negative, or it is not below UINT_MAX.
/// @throws std::out_of_range / std::invalid_argument from process_suffix for
///         an overflowing product or an unknown suffix.
template <>
unsigned int
byte_suffix_converter::to<unsigned int>(const char* input)
{
    // strtoul accepts a leading '-' and negates in unsigned arithmetic, so
    // "-2" would turn into a huge value; reject the sign explicitly (after
    // the same leading whitespace strtoul skips).
    const char* first = input;
    while (isspace(static_cast<unsigned char>(*first))) {
        ++first;
    }
    if (*first == '-') {
        throw std::invalid_argument("Invalid value specified");
    }
    char* ptr{};
    // Parse into the wider type first: a direct cast to unsigned int would
    // silently truncate values above UINT_MAX where unsigned long is wider.
    const unsigned long wide = strtoul(input, &ptr, 10);
    if (ptr == input || wide >= UINT_MAX) {
        throw std::invalid_argument("Invalid value specified");
    }
    return process_suffix(static_cast<unsigned int>(wide), *ptr);
}

/// Parses a decimal byte count with an optional unit suffix into an
/// `unsigned long` (e.g. "4096", "64k", "2G").
///
/// NOTE(review): only the first character after the digits is passed to
/// process_suffix; any characters after it are ignored.
///
/// @throws std::invalid_argument if there are no digits, the number is
///         negative, or it equals ULONG_MAX (strtoul's overflow result).
/// @throws std::out_of_range / std::invalid_argument from process_suffix for
///         an overflowing product or an unknown suffix.
template <>
unsigned long
byte_suffix_converter::to<unsigned long>(const char* input)
{
    // strtoul accepts a leading '-' and negates in unsigned arithmetic, so
    // "-2" would become ULONG_MAX - 1; reject the sign explicitly (after the
    // same leading whitespace strtoul skips).
    const char* first = input;
    while (isspace(static_cast<unsigned char>(*first))) {
        ++first;
    }
    if (*first == '-') {
        throw std::invalid_argument("Invalid value specified");
    }
    char* ptr{};
    const unsigned long value = strtoul(input, &ptr, 10);
    if (ptr == input || value == ULONG_MAX) {
        throw std::invalid_argument("Invalid value specified");
    }
    return process_suffix(value, *ptr);
}

/// Parses a decimal byte count with an optional unit suffix into an
/// `unsigned long long` (e.g. "4096", "64k", "2G").
///
/// NOTE(review): only the first character after the digits is passed to
/// process_suffix; any characters after it are ignored.
///
/// @throws std::invalid_argument if there are no digits, the number is
///         negative, or it equals ULLONG_MAX (strtoull's overflow result).
/// @throws std::out_of_range / std::invalid_argument from process_suffix for
///         an overflowing product or an unknown suffix.
template <>
unsigned long long
byte_suffix_converter::to<unsigned long long>(const char* input)
{
    // strtoull accepts a leading '-' and negates in unsigned arithmetic, so
    // "-2" would become ULLONG_MAX - 1; reject the sign explicitly (after the
    // same leading whitespace strtoull skips).
    const char* first = input;
    while (isspace(static_cast<unsigned char>(*first))) {
        ++first;
    }
    if (*first == '-') {
        throw std::invalid_argument("Invalid value specified");
    }
    char* ptr{};
    const unsigned long long value = strtoull(input, &ptr, 10);
    if (ptr == input || value == ULLONG_MAX) {
        throw std::invalid_argument("Invalid value specified");
    }
    return process_suffix(value, *ptr);
}

/// User-defined literal for mebibytes: `N_MB` is N * 1024 * 1024 bytes
/// (binary units, despite the "MB" spelling).
///
/// @throws std::out_of_range if the byte count does not fit in std::size_t;
///         in a constant expression this is a compile-time error instead.
constexpr std::size_t operator""_MB(unsigned long long value)
{
    constexpr auto mb = std::size_t(1024 * 1024);
    // A real check rather than assert(): assert compiles away under NDEBUG,
    // which would let an oversized literal wrap around silently.
    if (value > std::numeric_limits<std::size_t>::max() / mb) {
        throw std::out_of_range("Value is too big");
    }
    return static_cast<std::size_t>(value) * mb;
}

/// Allocates `chunk_size`-byte blocks in an endless loop, never freeing them,
/// and prints the running total to std::cout after each allocation.
///
/// The function only leaves the loop via an exception. std::bad_alloc is
/// thrown once `new` fails, and the caller is expected to catch it.
///
/// @param chunk_size  Bytes per allocation; must be non-zero.
/// @param zero_mem    When true, every byte of each chunk is written (with the
///                    value 7, not 0) right after it is allocated.
/// @throws std::invalid_argument if chunk_size is 0. Zero-byte allocations
///         never increase the total, so the loop would run (printing "0 B")
///         for an impractically long time.
void repeated_alloc(std::size_t chunk_size, bool zero_mem)
{
    if (chunk_size == 0) {
        throw std::invalid_argument("Chunk size must be greater than zero");
    }
    std::size_t total_alloc = 0;
    for (;;) {
        // Intentionally leaked: the tool's purpose is to exhaust memory.
        char* ptr = new char[chunk_size];
        if (zero_mem) {
            // NOTE(review): fills with 7 rather than 0 despite the flag name;
            // presumably any write suffices to make the pages resident.
            memset(ptr, 7, chunk_size);
        }
        total_alloc += chunk_size;
        std::cout << "Allocated " << (zero_mem ? "and initialized " : "")
                  << byte_suffix_converter::output(total_alloc) << "\n";
    }
}

/// Prints the command-line help for the tool, invoked as `progname`, to `os`.
void usage(std::ostream& os, const char* progname)
{
    // Everything after the usage line is static, so keep it in one literal.
    static const char* const option_help =
        "Options:\n"
        "  -h, --help     Show this usage help\n"
        "  -n, --no-zero  Skip the zeroing of allocated memory\n"
        "  -q, --quiet    Quit quietly when bad_alloc is caught\n"
        "  -s CHUNK_SIZE, --allocation-size=CHUNK_SIZE\n"
        "                 Specify chunk size of each allocation (common suffixes are\n"
        "                 allowed; default value is 128M)\n";

    os << "Memory allocation test tool\n\n"
       << "Usage: " << progname << " [options]...\n\n"
       << option_help
       << "\n";
}

/// Entry point of the memory allocation test tool.
///
/// Parses the options listed by usage(), then calls repeated_alloc() until it
/// throws. Exit status is EXIT_SUCCESS after --help or once bad_alloc has
/// been caught. It is EXIT_FAILURE for an unrecognised option (the default
/// switch case), extra positional arguments, or any other std::exception
/// (e.g. an unparsable --allocation-size).
int main(int argc, char* argv[])
{
    // getopt_long requires the long-option table to end with an all-zero
    // entry; without it, matching a long option reads past the array.
    static const option long_opts[] = {
        {"help", no_argument, nullptr, 'h'},
        {"no-zero", no_argument, nullptr, 'n'},
        {"quiet", no_argument, nullptr, 'q'},
        {"allocation-size", required_argument, nullptr, 's'},
        {nullptr, 0, nullptr, 0},
    };
    bool zero_mem = true;
    bool quiet = false;
    auto chunk_size = 128_MB;

    try {
        int optch{};
        while ((optch = getopt_long(argc, argv, "hnqs:", long_opts,
                                    nullptr)) != -1) {
            switch (optch) {
            case 'h':
                usage(std::cout, argv[0]);
                exit(EXIT_SUCCESS);
            case 'n':
                zero_mem = false;
                break;
            case 'q':
                quiet = true;
                break;
            case 's':
                chunk_size = byte_suffix_converter::to<std::size_t>(optarg);
                break;
            default:
                usage(std::cerr, argv[0]);
                exit(EXIT_FAILURE);
            }
        }

        if (optind != argc) {
            throw std::invalid_argument("Extra arguments");
        }
        repeated_alloc(chunk_size, zero_mem);
    }
    catch (const std::bad_alloc&) {
        std::cout << "Successfully caught bad_alloc exception\n";
        if (!quiet) {
            // The chunks allocated so far are never freed, so waiting here
            // keeps them held; presumably this lets the user inspect the
            // process before it exits.
            std::string line;
            std::cout << "Press ENTER to quit ";
            std::getline(std::cin, line);
        }
    }
    catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << '\n';
        exit(EXIT_FAILURE);
    }
}