#include "tree.h"

#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <errno.h>
#include <inttypes.h>
#include <netdb.h>
#include <stdbool.h>
#include <stdio.h>

/* LOCAL tags helper functions that are meant to be private to this file.
 * It expands to nothing, so those functions still have external linkage.
 * NOTE(review): presumably left empty on purpose (e.g. so tests or other
 * translation units can reach the helpers) - confirm before making it
 * `static`. */
#define LOCAL

/* UNUSED(x) renames a parameter to UNUSED_x and, on GCC-compatible
 * compilers, marks it as intentionally unused to silence warnings. */
#ifdef __GNUC__
#  define UNUSED(x) UNUSED_ ## x __attribute__((__unused__))
#else
#  define UNUSED(x) UNUSED_ ## x
#endif

/* This is also defined in MaxMind::DB::Common but we don't want to have to
 * fetch it every time we need it. */
#define DATA_SECTION_SEPARATOR_SIZE (16)

/* Number of bytes of a data-table key that are hashed and compared (every
 * HASH_FIND/HASH_ADD_KEYPTR below uses this length).
 * NOTE(review): 27 matches an unpadded base64 SHA-1 digest, which is
 * presumably how keys are produced - confirm against key_for_data. */
#define SHA1_KEY_LENGTH (27)

/* State handed to the freeze_* helpers (freeze_search_tree, freeze_node,
 * freeze_data_record, freeze_to_fd) while serializing a tree. */
typedef struct freeze_args_s {
    int fd;             /* presumably the descriptor the frozen tree is written to - confirm */
    char *filename;     /* NOTE(review): presumably the output path, used in error messages - confirm */
    HV *data_hash;      /* NOTE(review): looks like a Perl hash of data records gathered while freezing - confirm */
} freeze_args_s;

/* A network together with its record, as rebuilt by thaw_network() from a
 * frozen buffer. */
typedef struct thawed_network_s {
    MMDBW_network_s *network;
    MMDBW_record_s *record;
} thawed_network_s;

/* State used while encoding the tree; record_value_as_number() takes it
 * directly, and encode_node() presumably receives it through its
 * void *void_args parameter.
 * NOTE(review): the field roles noted below are inferred from their names
 * only - confirm against encode_node. */
typedef struct encode_args_s {
    PerlIO *output_io;          /* presumably the stream the encoded tree is written to */
    SV *root_data_type;
    SV *serializer;             /* presumably the Perl object that serializes data records */
    HV *data_pointer_cache;     /* presumably caches where each data record was written */
} encode_args_s;

/* Forward declarations for this file's helper functions. The lines between
 * the "--prototypes" marker comments are produced by
 * dev-bin/regen-prototypes.pl, so change a helper's definition and
 * regenerate rather than hand-editing its declaration here. */
/* *INDENT-OFF* */
/* --prototypes automatically generated by dev-bin/regen-prototypes.pl - don't remove this comment */
LOCAL void verify_ip(MMDBW_tree_s *tree, const char *ipstr);
LOCAL int128_t ip_string_to_integer(const char *ipstr, int family);
LOCAL int128_t ip_bytes_to_integer(uint8_t *bytes, int family);
LOCAL void integer_to_ip_bytes(int tree_ip_version, uint128_t ip,
                               uint8_t *bytes);
LOCAL void integer_to_ip_string(int tree_ip_version, uint128_t ip,
                                char *dst, int dst_length);
LOCAL int prefix_length_for_largest_subnet(uint128_t start_ip,
                                           uint128_t end_ip, int family,
                                           uint128_t *reverse_mask);
LOCAL const char *store_data_in_tree(MMDBW_tree_s *tree,
                                     const char *const key,
                                     SV *data_sv);
LOCAL const char *increment_data_reference_count(MMDBW_tree_s *tree,
                                                 const char *const key);
LOCAL void set_stored_data_in_tree(MMDBW_tree_s *tree,
                                   const char *const key,
                                   SV *data_sv);
LOCAL void decrement_data_reference_count(MMDBW_tree_s *tree,
                                          const char *const key);
LOCAL MMDBW_network_s resolve_network(MMDBW_tree_s *tree,
                                      const char *const ipstr,
                                      uint8_t prefix_length);
LOCAL void resolve_ip(int tree_ip_version, const char *const ipstr,
                      uint8_t *bytes);
LOCAL void free_network(MMDBW_network_s *network);
LOCAL void alias_ipv4_networks(MMDBW_tree_s *tree);
LOCAL MMDBW_status insert_record_for_network(MMDBW_tree_s *tree,
                                             MMDBW_network_s *network,
                                             MMDBW_record_s *new_record,
                                             bool merge_record_collisions,
                                             bool is_internal_insert);
LOCAL bool merge_records(MMDBW_tree_s *tree,
                         MMDBW_network_s *network,
                         MMDBW_record_s *new_record,
                         MMDBW_record_s *record_to_set);
LOCAL int network_bit_value(MMDBW_tree_s *tree, MMDBW_network_s *network,
                            uint8_t current_bit);
LOCAL int tree_depth0(MMDBW_tree_s *tree);
LOCAL SV * merge_hashes(MMDBW_tree_s *tree, SV *from, SV *into);
LOCAL void merge_new_from_hash_into_hash(MMDBW_tree_s *tree, HV *from, HV *to);
LOCAL SV * merge_values(MMDBW_tree_s *tree, SV *from, SV *into);
LOCAL SV * merge_arrays(MMDBW_tree_s *tree, SV *from, SV *into);
LOCAL MMDBW_status find_record_for_network(MMDBW_tree_s *tree,
                                           MMDBW_network_s *network,
                                           bool follow_aliases,
                                           MMDBW_node_s *(if_not_node)(
                                               MMDBW_tree_s *tree,
                                               MMDBW_record_s *record),
                                           MMDBW_record_s **record,
                                           MMDBW_record_s **sibling_record
                                           );
LOCAL MMDBW_node_s *return_null(
    MMDBW_tree_s *UNUSED(tree), MMDBW_record_s *UNUSED(record));
LOCAL MMDBW_node_s *new_node_from_record(MMDBW_tree_s *tree,
                                         MMDBW_record_s *record);
LOCAL void free_node_and_subnodes(MMDBW_tree_s *tree, MMDBW_node_s *node);
LOCAL void free_record_value(MMDBW_tree_s *tree, MMDBW_record_s *record);
LOCAL void assign_node_number(MMDBW_tree_s *tree, MMDBW_node_s *node,
                              uint128_t UNUSED(network),
                              uint8_t UNUSED(depth), void *UNUSED(args));
LOCAL void freeze_search_tree(MMDBW_tree_s *tree, freeze_args_s *args);
LOCAL void freeze_node(MMDBW_tree_s *tree, MMDBW_node_s *node,
                       uint128_t network, uint8_t depth, void *void_args);
LOCAL void freeze_data_record(MMDBW_tree_s *UNUSED(tree),
                              uint128_t network, uint8_t depth,
                              const char *key,
                              freeze_args_s *args);
LOCAL void freeze_to_fd(freeze_args_s *args, void *data, size_t size);
LOCAL void freeze_data_to_fd(int fd, MMDBW_tree_s *tree);
LOCAL SV *freeze_hash(HV *hash);
LOCAL uint8_t thaw_uint8(uint8_t **buffer);
LOCAL uint32_t thaw_uint32(uint8_t **buffer);
LOCAL thawed_network_s *thaw_network(MMDBW_tree_s *tree, uint8_t **buffer);
LOCAL uint8_t *thaw_bytes(uint8_t **buffer, size_t size);
LOCAL uint128_t thaw_uint128(uint8_t **buffer);
LOCAL STRLEN thaw_strlen(uint8_t **buffer);
LOCAL const char *thaw_data_key(uint8_t **buffer);
LOCAL HV *thaw_data_hash(SV *data_to_decode);
LOCAL void encode_node(MMDBW_tree_s *tree, MMDBW_node_s *node,
                       uint128_t UNUSED(network),
                       uint8_t UNUSED(depth), void *void_args);
LOCAL void check_record_sanity(MMDBW_node_s *node, MMDBW_record_s *record,
                               char *side);
LOCAL uint32_t record_value_as_number(MMDBW_tree_s *tree,
                                      MMDBW_record_s *record,
                                      encode_args_s * args);
LOCAL void iterate_tree(MMDBW_tree_s *tree,
                        MMDBW_record_s *record,
                        uint128_t network,
                        const uint8_t depth,
                        bool depth_first,
                        void *args,
                        MMDBW_iterator_callback callback);
LOCAL SV *key_for_data(SV * data);
LOCAL void dwarn(SV *thing);
LOCAL void *checked_malloc(size_t size);
LOCAL void checked_write(int fd, char *filename, void *buffer,
                         ssize_t count);
LOCAL void checked_perlio_read(PerlIO * io, void *buffer,
                               SSize_t size);
LOCAL void check_perlio_result(SSize_t result, SSize_t expected,
                               char *op);
LOCAL char *status_error_message(MMDBW_status status);
/* --prototypes end - don't remove this comment-- */
/* *INDENT-ON* */

/* Allocate and initialize an empty search tree.
 *
 * ip_version must be 4 or 6 and record_size must be 24, 28, or 32; any other
 * value croaks. When alias_ipv6 is true, alias_ipv4_networks() is called,
 * which returns immediately for IPv4 trees.
 *
 * Both arguments are validated before anything is allocated: croak() does
 * not return, so validating after checked_malloc() leaked the tree struct
 * on every bad call. */
MMDBW_tree_s *new_tree(const uint8_t ip_version, uint8_t record_size,
                       MMDBW_merge_strategy merge_strategy,
                       const bool alias_ipv6)
{
    if (ip_version != 4 && ip_version != 6) {
        croak("Unexpected IP version of %u", ip_version);
    }

    if (record_size != 24 && record_size != 28 && record_size != 32) {
        croak("Only record sizes of 24, 28, and 32 are supported. Received %u.",
              record_size);
    }

    MMDBW_tree_s *tree = checked_malloc(sizeof *tree);

    tree->ip_version = ip_version;
    tree->record_size = record_size;
    tree->merge_strategy = merge_strategy;
    tree->data_table = NULL;
    tree->is_aliased = false;
    tree->root_record = (MMDBW_record_s) {
        .type = MMDBW_RECORD_TYPE_EMPTY,
    };
    tree->node_count = 0;

    if (alias_ipv6) {
        alias_ipv4_networks(tree);
    }

    return tree;
}

/* Insert the network ipstr/prefix_length, pointing it at the data stored
 * under key_sv. The data is only stored if key_sv has no data yet.
 *
 * A collision with an existing data record is merged unless the tree's merge
 * strategy is MMDBW_MERGE_STRATEGY_NONE or force_overwrite is set. Croaks on
 * an invalid address or prefix, or when the insert fails. */
void insert_network(MMDBW_tree_s *tree, const char *ipstr,
                    const uint8_t prefix_length, SV *key_sv, SV *data,
                    bool force_overwrite)
{
    verify_ip(tree, ipstr);

    MMDBW_network_s target = resolve_network(tree, ipstr, prefix_length);

    const char *const stored_key =
        store_data_in_tree(tree, SvPVbyte_nolen(key_sv), data);

    MMDBW_record_s record = {
        .type  = MMDBW_RECORD_TYPE_DATA,
        .value = { .key = stored_key },
    };

    const bool should_merge = !force_overwrite
                              && MMDBW_MERGE_STRATEGY_NONE !=
                              tree->merge_strategy;

    MMDBW_status outcome =
        insert_record_for_network(tree, &target, &record, should_merge, false);

    /* The network bytes are heap-allocated by resolve_network(). */
    free_network(&target);

    if (outcome != MMDBW_SUCCESS) {
        croak("%s (%s)", status_error_message(outcome), ipstr);
    }
}

/* Croak when ipstr looks like an IPv6 address (contains a ':') but the tree
 * only holds IPv4 networks. Any other combination is accepted. */
LOCAL void verify_ip(MMDBW_tree_s *tree, const char *ipstr)
{
    const bool looks_like_ipv6 = NULL != strchr(ipstr, ':');

    if (!looks_like_ipv6 || tree->ip_version != 4) {
        return;
    }

    croak("You cannot insert an IPv6 address (%s) into an IPv4 tree.",
          ipstr);
}

/* Insert every address from start_ipstr through end_ipstr (inclusive),
 * pointing them all at the data stored under key_sv.
 *
 * Each iteration inserts the largest CIDR network that starts at the current
 * address and does not extend past end_ipstr, exactly as insert_network()
 * would insert it. Croaks if the first address is after the last one or if
 * any insert fails. Networks inserted before a failure are left in place. */
void insert_range(MMDBW_tree_s *tree, const char *start_ipstr,
                  const char *end_ipstr, SV *key_sv, SV *data_sv,
                  bool force_overwrite)
{
    verify_ip(tree, start_ipstr);
    verify_ip(tree, end_ipstr);

    uint128_t start_ip = ip_string_to_integer(start_ipstr, tree->ip_version);
    uint128_t end_ip = ip_string_to_integer(end_ipstr, tree->ip_version);

    if (end_ip < start_ip) {
        /* This fires when the first address is *after* the last one. */
        croak("First IP (%s) in range comes after last IP (%s)", start_ipstr,
              end_ipstr);
    }

    const char *const key =
        store_data_in_tree(tree, SvPVbyte_nolen(key_sv), data_sv);

    /* The same for every sub-network, so decide it once. */
    const bool merge_record_collisions =
        tree->merge_strategy != MMDBW_MERGE_STRATEGY_NONE
        && !force_overwrite;

    uint8_t bytes[tree->ip_version == 6 ? 16 : 4];

    MMDBW_status status = MMDBW_SUCCESS;

    // Eventually we could change the code to walk the tree and break up the
    // range at the same time, saving some unnecessary computation. However,
    // that would require more significant refactoring of the insertion and
    // merging code.
    while (start_ip <= end_ip) {
        uint128_t reverse_mask;
        int prefix_length =
            prefix_length_for_largest_subnet(start_ip, end_ip, tree->ip_version,
                                             &reverse_mask);

        integer_to_ip_bytes(tree->ip_version, start_ip, bytes);

        MMDBW_network_s network = {
            .bytes         = bytes,
            .prefix_length = prefix_length,
        };

        /* Each inserted record owns one reference to the data. */
        const char *const new_key = increment_data_reference_count(tree, key);

        MMDBW_record_s new_record = {
            .type    = MMDBW_RECORD_TYPE_DATA,
            .value   = {
                .key = new_key
            }
        };

        status = insert_record_for_network(tree, &network, &new_record,
                                           merge_record_collisions, false);
        if (MMDBW_SUCCESS != status) {
            break;
        }

        /* Move to the first address after the network just inserted. */
        start_ip = (start_ip | reverse_mask) + 1;

        // The +1 caused an overflow and we are done.
        if (start_ip == 0) {
            break;
        }
    }
    // store_data_in_tree starts at a reference count of 1, so we need to
    // decrement in order to account for that.
    decrement_data_reference_count(tree, key);

    if (MMDBW_SUCCESS != status) {
        croak("%s (%s - %s)", status_error_message(status), start_ipstr,
              end_ipstr);
    }
}

/* Parse ipstr for a tree of the given family (4 or 6) and return it as an
 * integer whose most significant byte is the first address byte. Croaks
 * (via resolve_ip) on an invalid address.
 *
 * The scratch buffer is always 16 bytes. resolve_ip() runs
 * inet_pton(AF_INET6) for any string containing ':', which writes 16 bytes.
 * A 4-byte buffer sized for IPv4 trees could therefore overflow. */
LOCAL int128_t ip_string_to_integer(const char *ipstr, int family)
{
    uint8_t bytes[16];
    resolve_ip(family, ipstr, bytes);
    return ip_bytes_to_integer(bytes, family);
}

/* Convert network-order address bytes into an integer with the first byte
 * most significant. The conversion reads 16 bytes when family == 6 and
 * 4 bytes otherwise.
 *
 * The value is built in uint128_t: shifting a signed int128_t left past its
 * sign bit, which happens for every IPv6 address >= 8000::, is undefined
 * behavior. The declared int128_t return type is kept for the existing
 * prototype. Callers in this file (e.g. insert_range) store the result back
 * into uint128_t. The uint128_t -> int128_t conversion of large values is
 * implementation-defined (modular on GCC/Clang), so that round trip
 * preserves the bits. */
LOCAL int128_t ip_bytes_to_integer(uint8_t *bytes, int family)
{
    int length = family == 6 ? 16 : 4;

    uint128_t ipint = 0;
    for (int i = 0; i < length; i++) {
        ipint = (ipint << 8) | bytes[i];
    }
    return (int128_t)ipint;
}

/* Write ip into bytes in network (big-endian) order. For an IPv6 tree this
 * writes all 16 bytes; otherwise it writes the low 4 bytes of ip. bytes must
 * be at least that long. */
LOCAL void integer_to_ip_bytes(int tree_ip_version, uint128_t ip,
                               uint8_t *bytes)
{
    int pos = tree_ip_version == 6 ? 16 : 4;

    /* Fill from the last byte backwards, peeling off the least significant
       byte of ip each time. */
    while (pos-- > 0) {
        bytes[pos] = (uint8_t)(ip & 0xFF);
        ip >>= 8;
    }
}

/* Format the integer address ip as text into dst (dst_length bytes). IPv6
 * trees use IPv6 notation and IPv4 trees use dotted quad. Croaks if
 * inet_ntop() fails, e.g. because dst is too small. */
LOCAL void integer_to_ip_string(int tree_ip_version, uint128_t ip,
                                char *dst, int dst_length)
{
    const int af = tree_ip_version == 6 ? AF_INET6 : AF_INET;
    uint8_t packed[16];

    integer_to_ip_bytes(tree_ip_version, ip, packed);

    const char *formatted = inet_ntop(af, packed, dst, dst_length);
    if (formatted == NULL) {
        croak("Error converting IP integer to string");
    }
}

/* Find the largest CIDR network that starts exactly at start_ip and does not
 * extend past end_ip.
 *
 * Returns its prefix length (at most 32 for family 4, 128 otherwise). It also
 * stores the network's host-bit mask in *reverse_mask: all ones below the
 * prefix, so (start_ip | *reverse_mask) is the network's last address.
 * Croaks if start_ip > end_ip.
 *
 * The mask is only widened after a candidate has passed both checks. The
 * result is therefore exact even for an IPv6 /0, whose all-ones 128-bit
 * mask cannot survive a "shift one too far, then shift back" approach. */
LOCAL int prefix_length_for_largest_subnet(uint128_t start_ip,
                                           uint128_t end_ip, int family,
                                           uint128_t *reverse_mask)
{
    if (start_ip > end_ip) {
        croak("Start IP of the range must be less than or equal to end IP");
    }

    int prefix_length = family == 6 ? 128 : 32;
    uint128_t host_mask = 0;

    while (prefix_length > 0) {
        uint128_t wider_mask = (host_mask << 1) | 1;

        // The wider network must start at start_ip (its host bits are all
        // zero in start_ip) and its last address must be <= end_ip.
        if ((start_ip & wider_mask) != 0
            || (start_ip | wider_mask) > end_ip) {
            break;
        }

        host_mask = wider_mask;
        prefix_length--;
    }

    *reverse_mask = host_mask;
    return prefix_length;
}

/* Remove the network ipstr/prefix_length by storing an empty record for it.
 * Croaks on an invalid address or prefix, or if the removal fails (e.g.
 * insert_record_for_network refuses to overwrite an alias record).
 *
 * The status message is passed as an argument to a literal "%s" format,
 * never as the format string itself. The message also names the network,
 * matching insert_network. */
void remove_network(MMDBW_tree_s *tree, const char *ipstr,
                    const uint8_t prefix_length)
{
    verify_ip(tree, ipstr);

    MMDBW_network_s network = resolve_network(tree, ipstr, prefix_length);

    MMDBW_record_s new_record = {
        .type = MMDBW_RECORD_TYPE_EMPTY
    };

    MMDBW_status status =
        insert_record_for_network(tree, &network, &new_record, false,
                                  false);

    free_network(&network);
    if (MMDBW_SUCCESS != status) {
        croak("%s (%s)", status_error_message(status), ipstr);
    }
}

/* Take a reference on the data-table entry for key (creating it if needed)
 * and attach data_sv to it, then return the entry's own copy of the key.
 * If the entry already holds data, data_sv is left unused. */
LOCAL const char *store_data_in_tree(MMDBW_tree_s *tree,
                                     const char *const key,
                                     SV *data_sv)
{
    const char *const interned_key =
        increment_data_reference_count(tree, key);

    set_stored_data_in_tree(tree, key, data_sv);

    return interned_key;
}

/* Take one reference on the data-table entry for key and return the entry's
 * own NUL-terminated copy of the key. If no entry exists, an empty one (no
 * data_sv yet) is created first. Only the first SHA1_KEY_LENGTH bytes of key
 * identify the entry.
 *
 * The key is copied with a bounded memcpy. strcpy() into the
 * SHA1_KEY_LENGTH + 1 byte buffer overflowed it whenever a caller passed a
 * longer string. NOTE(review): key is assumed to have at least
 * SHA1_KEY_LENGTH readable bytes, which HASH_FIND already relies on. */
LOCAL const char *increment_data_reference_count(MMDBW_tree_s *tree,
                                                 const char *const key)
{
    MMDBW_data_hash_s *data = NULL;
    HASH_FIND(hh, tree->data_table, key, SHA1_KEY_LENGTH, data);

    /* We allow this possibility as we need to create the record separately
       from updating the data when thawing */
    if (NULL == data) {
        data = checked_malloc(sizeof *data);
        data->reference_count = 0;

        data->data_sv = NULL;

        char *key_copy = checked_malloc(SHA1_KEY_LENGTH + 1);
        memcpy(key_copy, key, SHA1_KEY_LENGTH);
        key_copy[SHA1_KEY_LENGTH] = '\0';
        data->key = key_copy;

        HASH_ADD_KEYPTR(hh, tree->data_table, data->key, SHA1_KEY_LENGTH, data);
    }
    data->reference_count++;

    return data->key;
}


/* Attach data_sv to the existing data-table entry for key, taking a Perl
 * reference on it. Croaks if key has no entry. Does nothing if the entry
 * already holds data. */
LOCAL void set_stored_data_in_tree(MMDBW_tree_s *tree,
                                   const char *const key,
                                   SV *data_sv)
{
    MMDBW_data_hash_s *entry = NULL;
    HASH_FIND(hh, tree->data_table, key, SHA1_KEY_LENGTH, entry);

    if (entry == NULL) {
        croak("Attempt to set unknown data record in tree");
    }

    if (entry->data_sv == NULL) {
        SvREFCNT_inc_simple_void_NN(data_sv);
        entry->data_sv = data_sv;
    }
}

/* Drop one reference on the data-table entry for key. When the last
 * reference goes away, the entry is unlinked from the table and its data
 * SV, key copy and the entry itself are released. Croaks if key has no
 * entry. */
LOCAL void decrement_data_reference_count(MMDBW_tree_s *tree,
                                          const char *const key)
{
    MMDBW_data_hash_s *entry = NULL;
    HASH_FIND(hh, tree->data_table, key, SHA1_KEY_LENGTH, entry);

    if (entry == NULL) {
        croak("Attempt to remove data that does not exist from tree");
    }

    if (--entry->reference_count != 0) {
        return;
    }

    HASH_DEL(tree->data_table, entry);
    SvREFCNT_dec(entry->data_sv);
    free((char *)entry->key);
    free(entry);
}

/* Parse ipstr/prefix_length into a network for this tree. The network's
 * bytes are heap-allocated (16 for IPv6 trees, 4 for IPv4 trees); release
 * them with free_network().
 *
 * An IPv4 address in an IPv6 tree is placed at ::a.b.c.d (by resolve_ip),
 * and its prefix length is increased by 96. Croaks on an invalid address or
 * an out-of-range prefix length.
 *
 * The address is parsed into a stack buffer and validated before anything is
 * allocated. croak() does not return, so an invalid address can no longer
 * leak the heap bytes. The buffer is 16 bytes so resolve_ip() can never
 * write past it. */
LOCAL MMDBW_network_s resolve_network(MMDBW_tree_s *tree,
                                      const char *const ipstr,
                                      uint8_t prefix_length)
{
    size_t bytes_length = tree->ip_version == 6 ? 16 : 4;
    uint8_t parsed[16];

    resolve_ip(tree->ip_version, ipstr, parsed);

    if (NULL == strchr(ipstr, ':')) {
        // IPv4
        if (prefix_length > 32) {
            croak("Prefix length greater than 32 on an IPv4 network (%s/%d)",
                  ipstr, prefix_length);
        }
        if (tree->ip_version == 6) {
            // Inserting IPv4 network into an IPv6 tree
            prefix_length += 96;
        }
    } else if (prefix_length > 128) {
        croak("Prefix length greater than 128 on an IPv6 network (%s/%d)",
              ipstr, prefix_length);
    }

    uint8_t *bytes = checked_malloc(bytes_length);
    memcpy(bytes, parsed, bytes_length);

    MMDBW_network_s network = {
        .bytes         = bytes,
        .prefix_length = prefix_length,
    };

    return network;
}

/* Parse ipstr into network-order bytes for a tree of tree_ip_version.
 *
 * bytes must hold 16 bytes for an IPv6 tree and 4 for an IPv4 tree. An IPv4
 * address in an IPv6 tree is stored as ::a.b.c.d. Croaks if ipstr is not a
 * valid address. It also croaks if ipstr is IPv6 and the tree is IPv4: the
 * 16 parsed bytes would not fit in that tree's 4-byte buffers. */
LOCAL void resolve_ip(int tree_ip_version, const char *const ipstr,
                      uint8_t *bytes)
{
    bool is_ipv4_address = NULL == strchr(ipstr, ':');
    int family = is_ipv4_address ? AF_INET : AF_INET6;

    if (tree_ip_version == 4 && !is_ipv4_address) {
        // inet_pton(AF_INET6) writes 16 bytes, but callers size the buffer
        // at 4 bytes for an IPv4 tree.
        croak("Cannot use an IPv6 address (%s) with an IPv4 tree", ipstr);
    }

    if (tree_ip_version == 6 && is_ipv4_address) {
        // We are inserting/looking up an IPv4 address in an IPv6 tree.
        // The canonical location for this in our database is ::a.b.c.d.
        // To get this address, we zero out the first 12 bytes of bytes
        // and then put the IPv4 address in the remaining 4. The reason to
        // not use getaddrinfo with AI_V4MAPPED is that it gives us
        // ::FFFF:a.b.c.d and AI_V4MAPPED doesn't work on all platforms.
        // See GitHub #7 and #51.
        memset(bytes, 0, 12);
        bytes += 12;
    }

    // inet_pton returns 1 on success, 0 for an unparsable string and -1 for
    // an unsupported family, so anything other than 1 is a failure.
    if (inet_pton(family, ipstr, bytes) != 1) {
        croak("Invalid IP address: %s", ipstr);
    }
}

/* Release the heap-allocated address bytes owned by network (as produced
 * by resolve_network()). The network struct itself is not freed and its
 * bytes pointer is left dangling. */
LOCAL void free_network(MMDBW_network_s *network)
{
    void *owned_bytes = (void *)network->bytes;
    free(owned_bytes);
}

/* A network written as text: an address string plus its prefix length.
 * Used for the static ipv4_aliases table below. */
struct network {
    const char *const ipstr;
    const uint8_t prefix_length;
};

/* IPv6 networks that alias_ipv4_networks() turns into aliases of the IPv4
 * subtree rooted at ::0.0.0.0/96: the IPv4-mapped range (::ffff:0:0/96),
 * Teredo (2001::/32) and 6to4 (2002::/16). */
static struct network ipv4_aliases[] = {
    { .ipstr = "::ffff:0:0", .prefix_length = 96 },
    { .ipstr = "2001::",     .prefix_length = 32 },
    { .ipstr = "2002::",     .prefix_length = 16 },
};

/* Make every network in ipv4_aliases share the IPv4 subtree rooted at
 * ::0.0.0.0/96. Each alias network's record is turned into an
 * MMDBW_RECORD_TYPE_ALIAS record that points at that subtree's root node.
 *
 * Does nothing for IPv4 trees or when tree->is_aliased is already set. The
 * flag is set once aliasing succeeds, so the early return actually guards
 * against aliasing twice. Croaks if the IPv4 root or an alias location
 * cannot be found. */
LOCAL void alias_ipv4_networks(MMDBW_tree_s *tree)
{
    if (tree->ip_version == 4) {
        return;
    }
    if (tree->is_aliased) {
        return;
    }

    // We create an empty record for the aliases to point to initially. We do
    // not simply use /96, as all of the alias code requires that the node be
    // aliased to a MMDBW_RECORD_TYPE_NODE node. This is gross and confusing
    // and should be fixed at some point.
    remove_network(tree, "::0.0.0.0", 97);

    // Resolved only after remove_network(), which may croak, so these
    // heap-allocated bytes cannot leak.
    MMDBW_network_s ipv4_root_network = resolve_network(tree, "::0.0.0.0", 96);

    MMDBW_record_s *ipv4_root_record;
    MMDBW_status status = find_record_for_network(tree, &ipv4_root_network,
                                                  false,
                                                  &return_null,
                                                  &ipv4_root_record,
                                                  NULL);
    free_network(&ipv4_root_network);
    if (status != MMDBW_SUCCESS) {
        croak("Unable to find IPv4 root node when setting up aliases");
    }

    if (MMDBW_RECORD_TYPE_NODE != ipv4_root_record->type) {
        croak("Unexpected type for IPv4 root record: %s",
              record_type_name(ipv4_root_record->type));
    }

    MMDBW_node_s *ipv4_root_node =
        ipv4_root_record->value.node;

    const size_t alias_count = sizeof(ipv4_aliases) / sizeof(ipv4_aliases[0]);
    for (size_t i = 0; i < alias_count; i++) {
        MMDBW_network_s alias_network =
            resolve_network(tree, ipv4_aliases[i].ipstr,
                            ipv4_aliases[i].prefix_length);

        MMDBW_record_s *record_for_alias;
        MMDBW_status alias_status = find_record_for_network(
            tree, &alias_network, true,
            &new_node_from_record, &record_for_alias, NULL);

        free_network(&alias_network);

        if (MMDBW_SUCCESS != alias_status) {
            croak("Unexpected NULL when searching for last node for alias");
        }
        record_for_alias->type = MMDBW_RECORD_TYPE_ALIAS;
        record_for_alias->value.node = ipv4_root_node;
    }

    tree->is_aliased = true;
}

/* Store new_record as the record for network, creating intermediate nodes as
 * needed (via new_node_from_record). new_record's value is either stored or
 * freed here.
 *
 * - If the location cannot be found, new_record's value is freed and
 *   MMDBW_FINDING_NODE_ERROR is returned.
 * - If the target record is an alias, new_record's value is freed and
 *   MMDBW_ALIAS_OVERWRITE_ATTEMPT_ERROR is returned. The exception is an
 *   internal insert of a data record, which is dropped with MMDBW_SUCCESS.
 * - If merge_record_collisions is set and new_record is a data record,
 *   merge_records() runs first. If it fully handles the insert, the function
 *   returns MMDBW_SUCCESS.
 * - If the sibling record already holds the same data key, new_record is
 *   inserted for the parent network instead, collapsing the two records into
 *   one.
 * Otherwise the old value of the target record is freed and replaced. */
LOCAL MMDBW_status insert_record_for_network(MMDBW_tree_s *tree,
                                             MMDBW_network_s *network,
                                             MMDBW_record_s *new_record,
                                             bool merge_record_collisions,
                                             bool is_internal_insert)
{
    MMDBW_record_s *record_to_set, *other_record;
    MMDBW_status status =
        find_record_for_network(tree, network,
                                false,
                                &new_node_from_record,
                                &record_to_set, &other_record);
    if (MMDBW_SUCCESS != status) {
        free_record_value(tree, new_record);
        return MMDBW_FINDING_NODE_ERROR;
    }

    if (record_to_set->type == MMDBW_RECORD_TYPE_ALIAS) {
        MMDBW_record_type type = new_record->type;
        free_record_value(tree, new_record);
        if (type == MMDBW_RECORD_TYPE_DATA && is_internal_insert) {
            // Possibly change return value in future
            return MMDBW_SUCCESS;
        }
        return MMDBW_ALIAS_OVERWRITE_ATTEMPT_ERROR;
    }

    if (merge_record_collisions &&
        MMDBW_RECORD_TYPE_DATA == new_record->type) {

        if (merge_records(tree, network, new_record, record_to_set)) {
            return MMDBW_SUCCESS;
        }
    }

    /* If this record we're about to insert is a data record, and the other
     * record in the node also has the same data, then we instead want to
     * insert a single data record in this node's parent. We do this by
     * inserting the new record for the parent network, which we can calculate
     * quite easily by subtracting 1 from this network's prefix length.
     * (strcmp() alone decides equality; equal strings have equal lengths.) */
    if (MMDBW_RECORD_TYPE_DATA == new_record->type
        && NULL != other_record
        && MMDBW_RECORD_TYPE_DATA == other_record->type
        && 0 == strcmp(new_record->value.key, other_record->value.key)) {

        size_t bytes_length = tree->ip_version == 6 ? 16 : 4;
        uint8_t *bytes = checked_malloc(bytes_length);
        memcpy(bytes, network->bytes, bytes_length);

        uint8_t parent_prefix_length = network->prefix_length - 1;
        MMDBW_network_s parent_network = {
            .bytes         = bytes,
            .prefix_length = parent_prefix_length,
        };

        /* We don't need to merge record collisions in this insert as
         * we have already merged the new record with the existing
         * record
         */
        MMDBW_status parent_status =
            insert_record_for_network(tree, &parent_network, new_record,
                                      false,
                                      true);
        free_network(&parent_network);
        return parent_status;
    }

    free_record_value(tree, record_to_set);

    record_to_set->type = new_record->type;
    if (MMDBW_RECORD_TYPE_DATA == new_record->type) {
        record_to_set->value.key = new_record->value.key;
    } else if (MMDBW_RECORD_TYPE_NODE == new_record->type ||
               MMDBW_RECORD_TYPE_ALIAS == new_record->type) {
        record_to_set->value.node = new_record->value.node;
    }

    return MMDBW_SUCCESS;
}

/* Merge the data record new_record with the existing record_to_set at
 * network. Returns true when the insert has been fully handled here and the
 * caller must not store new_record itself.
 *
 * - record_to_set is a node: new_record's data is pushed down by inserting
 *   it (with merging) into both child networks at prefix_length + 1. One
 *   extra data reference is taken because one record becomes two. Returns
 *   true.
 * - record_to_set is a data record: the two data hashes are merged and the
 *   result is stored in the data table. new_record is then updated in place
 *   to point at the merged data, and its reference to the old key is
 *   dropped. Returns false so the caller stores it.
 * - any other record type: returns false without changes.
 *
 * A failed recursive insert croaks. The bool return type cannot carry an
 * MMDBW_status, and returning one implicitly turned every failure into
 * `true`, which the caller then reported as success. */
LOCAL bool merge_records(MMDBW_tree_s *tree,
                         MMDBW_network_s *network,
                         MMDBW_record_s *new_record,
                         MMDBW_record_s *record_to_set)
{
    int max_depth0 = tree_depth0(tree);

    if (MMDBW_RECORD_TYPE_NODE == record_to_set->type) {
        if (network->prefix_length > max_depth0) {
            croak("Something is very wrong. Prefix length is too long.");
        }

        /* We increment the count as we are turning one record into two */
        increment_data_reference_count(tree, new_record->value.key);

        uint8_t new_prefix_length = network->prefix_length + 1;

        MMDBW_network_s left = {
            .bytes         = network->bytes,
            .prefix_length = new_prefix_length,
        };

        MMDBW_record_s new_left_record = {
            .type    = new_record->type,
            .value   = {
                .key = new_record->value.key
            }
        };

        MMDBW_status status =
            insert_record_for_network(tree, &left, &new_left_record, true,
                                      true);
        if (MMDBW_SUCCESS != status) {
            croak("%s", status_error_message(status));
        }

        size_t bytes_length = tree->ip_version == 6 ? 16 : 4;
        uint8_t right_bytes[bytes_length];

        memcpy(right_bytes, network->bytes, bytes_length);

        /* The right child has bit number new_prefix_length (1-based, counting
           from the most significant bit) set. max_depth0 + 1 is a multiple of
           8, so this selects bit 7 - (new_prefix_length - 1) % 8 of the
           byte. */
        right_bytes[ (new_prefix_length - 1) / 8]
            |= 1 << ((max_depth0 + 1 - new_prefix_length) % 8);

        MMDBW_network_s right = {
            .bytes         = right_bytes,
            .prefix_length = new_prefix_length,
        };

        MMDBW_record_s new_right_record = {
            .type    = new_record->type,
            .value   = {
                .key = new_record->value.key
            }
        };

        status =
            insert_record_for_network(tree, &right, &new_right_record, true,
                                      true);
        if (MMDBW_SUCCESS != status) {
            croak("%s", status_error_message(status));
        }

        /* There's no need continuing with the original record as the relevant
         * data has already been inserted further down the tree by the code
         * above. */
        return true;
    }

    if (MMDBW_RECORD_TYPE_DATA == record_to_set->type) {
        /* This must come before the node pruning code in
           insert_record_for_network, as we only want to prune nodes where the
           merged record matches. */
        SV *merged = merge_hashes_for_keys(tree,
                                           new_record->value.key,
                                           record_to_set->value.key,
                                           network);

        SV *key_sv = key_for_data(merged);
        const char *const new_key =
            store_data_in_tree(tree, SvPVbyte_nolen(key_sv), merged);
        SvREFCNT_dec(key_sv);

        /* The ref count was incremented in store_data_in_tree */
        SvREFCNT_dec(merged);

        decrement_data_reference_count(tree, new_record->value.key);
        new_record->value.key = new_key;
    }

    return false;
}

/* Return non-zero iff the address bit selected by current_bit is set in
 * network. current_bit == tree_depth0(tree) selects the most significant bit
 * (bit 7 of byte 0), and current_bit == 0 selects the least significant bit
 * of the last byte. */
LOCAL int network_bit_value(MMDBW_tree_s *tree, MMDBW_network_s *network,
                            uint8_t current_bit)
{
    const int offset = tree_depth0(tree) - current_bit;
    const unsigned bit_in_byte = 1U << (~offset & 7);

    return network->bytes[offset >> 3] & bit_in_byte;
}

/* Zero-based index of the last address bit, i.e. the address width in bits
 * minus one: 127 for IPv6 trees, 31 otherwise. */
LOCAL int tree_depth0(MMDBW_tree_s *tree)
{
    if (tree->ip_version == 6) {
        return 127;
    }
    return 31;
}

/* Merge the data stored under key_from (the record being inserted) with the
 * data stored under key_into (the existing record) and return the merged
 * hashref from merge_hashes().
 *
 * Both values must be hashrefs. Otherwise the reference held for key_from
 * is released and the function croaks with the network being inserted. */
SV *merge_hashes_for_keys(MMDBW_tree_s *tree, const char *const key_from,
                          const char *const key_into, MMDBW_network_s *network)
{
    SV *data_from = data_for_key(tree, key_from);
    SV *data_into = data_for_key(tree, key_into);

    if (!(SvROK(data_from) && SvROK(data_into)
          && SvTYPE(SvRV(data_from)) == SVt_PVHV
          && SvTYPE(SvRV(data_into)) == SVt_PVHV)) {
        /* The record being inserted already holds a reference on key_from
           (taken before insert_record_for_network was called). We are about
           to croak instead of storing it, so release that reference here.
           It might be nicer to not insert anything into the tree until
           we're sure we really want to. */
        decrement_data_reference_count(tree, key_from);

        bool is_ipv6 = tree->ip_version == 6;
        /* INET6_ADDRSTRLEN is large enough for either family. */
        char address_string[INET6_ADDRSTRLEN];
        const char *shown_address =
            inet_ntop(is_ipv6 ? AF_INET6 : AF_INET,
                      network->bytes,
                      address_string,
                      sizeof(address_string));
        if (NULL == shown_address) {
            /* Never print the uninitialized buffer. */
            shown_address = "(unprintable address)";
        }

        croak(
            "Cannot merge data records unless both records are hashes - inserting %s/%"
            PRIu8,
            shown_address, network->prefix_length);
    }

    return merge_hashes(tree, data_from, data_into);
}

/* Build and return a new hashref (owned by the caller) containing the
 * entries of both hashrefs. Neither input hash is modified.
 *
 * Entries of `from` are copied first and entries of `into` second.
 * merge_new_from_hash_into_hash() decides what happens to keys present in
 * both. NOTE(review): with MMDBW_MERGE_STRATEGY_RECURSE, a key in both ends
 * up as merge_values(into's value, from's value) at this level, which for
 * scalars keeps `into`'s value. Under any other strategy the value from
 * `from` is kept. Confirm this precedence is intended. */
LOCAL SV * merge_hashes(MMDBW_tree_s *tree, SV *from, SV *into)
{
    HV *merged = newHV();

    merge_new_from_hash_into_hash(tree, (HV *)SvRV(from), merged);
    merge_new_from_hash_into_hash(tree, (HV *)SvRV(into), merged);

    return newRV_noinc((SV *)merged);
}

// Note: unlike the other merge functions, this does _not_ replace existing
// values.
//
// Copy every entry of `from` into `to`:
//  - keys missing from `to` are stored with a new reference to the value;
//  - keys already in `to` are left alone, unless the tree uses
//    MMDBW_MERGE_STRATEGY_RECURSE, in which case the two values are combined
//    with merge_values() and the result replaces the entry in `to`.
LOCAL void merge_new_from_hash_into_hash(MMDBW_tree_s *tree, HV *from, HV *to)
{
    (void)hv_iterinit(from);
    HE *he;
    while (NULL != (he = hv_iternext(from))) {
        STRLEN key_length;
        const char *const key = HePV(he, key_length);
        /* The hv_* functions take a signed length whose sign carries the
           key's UTF-8 flag. Passing the bare length would store a UTF-8 key
           as its raw bytes, i.e. as a different key. */
        const I32 klen = HeUTF8(he) ? -(I32)key_length : (I32)key_length;
        SV *value = HeVAL(he);
        U32 hash = 0;

        SV **existing_value = hv_fetch(to, key, klen, 0);
        if (NULL != existing_value) {
            if (tree->merge_strategy != MMDBW_MERGE_STRATEGY_RECURSE) {
                continue;
            }
            /* merge_values() returns a new reference, which hv_store()
               takes over; the replaced value is released by hv_store(). */
            value = merge_values(tree, value, *existing_value);
        } else {
            /* Reuse the precomputed hash and share the value with `from`. */
            hash = HeHASH(he);
            SvREFCNT_inc_simple_void_NN(value);
        }

        (void)hv_store(to, key, klen, value, hash);
    }
}

/* Combine two values found under the same key and return a new reference:
 *  - two non-references: `from` is returned (with its refcount bumped);
 *  - two hashrefs: merged with merge_hashes();
 *  - two arrayrefs: merged element-wise with merge_arrays().
 * Croaks when exactly one side is a reference, or for any other pair of
 * reference types. */
LOCAL SV * merge_values(MMDBW_tree_s *tree, SV *from, SV *into)
{
    bool from_is_ref = SvROK(from) ? true : false;
    bool into_is_ref = SvROK(into) ? true : false;

    if (from_is_ref != into_is_ref) {
        croak("Attempt to merge a reference value and non-refrence value");
    }

    if (!from_is_ref) {
        // Plain scalars: the value from the hash being inserted wins.
        SvREFCNT_inc_simple_void_NN(from);
        return from;
    }

    bool both_hashes = SvTYPE(SvRV(from)) == SVt_PVHV
                       && SvTYPE(SvRV(into)) == SVt_PVHV;
    if (both_hashes) {
        return merge_hashes(tree, from, into);
    }

    bool both_arrays = SvTYPE(SvRV(from)) == SVt_PVAV
                       && SvTYPE(SvRV(into)) == SVt_PVAV;
    if (both_arrays) {
        return merge_arrays(tree, from, into);
    }

    croak("Only arrayrefs, hashrefs, and scalars can be merged.");
}

/* Merge two arrayrefs element by element into a new arrayref, as long as the
 * longer input. Where both arrays have an element, the pair is combined with
 * merge_values(). Otherwise the one present element is shared (refcount
 * bumped). Croaks if neither array has an element at some index (a hole in
 * both). */
LOCAL SV * merge_arrays(MMDBW_tree_s *tree, SV *from, SV *into)
{
    AV *from_av = (AV *)SvRV(from);
    AV *into_av = (AV *)SvRV(into);

    // av_len() gives the index of the last element (-1 for an empty array),
    // not the element count; newer Perls also call it av_top_index().
    SSize_t last_index = av_len(from_av);
    SSize_t into_last_index = av_len(into_av);
    if (into_last_index > last_index) {
        last_index = into_last_index;
    }

    AV *merged = newAV();
    for (SSize_t idx = 0; idx <= last_index; ++idx) {
        SV **from_slot = av_fetch(from_av, idx, 0);
        SV **into_slot = av_fetch(into_av, idx, 0);
        SV *element;

        if (NULL == from_slot && NULL == into_slot) {
            croak("Received unexpected NULLs when merging arrays");
        }

        if (NULL != from_slot && NULL != into_slot) {
            element = merge_values(tree, *from_slot, *into_slot);
        } else {
            element = NULL != from_slot ? *from_slot : *into_slot;
            SvREFCNT_inc_simple_void_NN(element);
        }

        av_push(merged, element);
    }

    return newRV_noinc((SV *)merged);
}

/* Return a new copy of the data stored for the address `ipstr`.
 *
 * Returns &PL_sv_undef when the address has no data, or when it is an IPv6
 * address and the tree is IPv4-only. Croaks if the tree walk fails or ends
 * on a node/alias record. */
SV *lookup_ip_address(MMDBW_tree_s *tree, const char *const ipstr)
{
    bool is_ipv6_address = NULL != strchr(ipstr, ':');
    if (tree->ip_version == 4 && is_ipv6_address) {
        return &PL_sv_undef;
    }
    MMDBW_network_s network =
        resolve_network(tree, ipstr, is_ipv6_address ? 128 : 32);

    /* return_null makes the walk stop at the first non-node record instead
       of splitting it, so a lookup does not modify the tree. */
    MMDBW_record_s *record_for_address;
    MMDBW_status status =
        find_record_for_network(tree, &network, true, &return_null,
                                &record_for_address,
                                NULL);

    free_network(&network);

    if (MMDBW_SUCCESS != status) {
        croak("Received an unexpected NULL when looking up %s: %s", ipstr,
              status_error_message(status));
    }

    if (MMDBW_RECORD_TYPE_NODE == record_for_address->type ||
        MMDBW_RECORD_TYPE_ALIAS == record_for_address->type) {
        croak(
            "WTF - found a node or alias record for an address lookup - %s",
            ipstr);
    }

    if (MMDBW_RECORD_TYPE_EMPTY == record_for_address->type) {
        return &PL_sv_undef;
    }

    return newSVsv(data_for_key(tree, record_for_address->value.key));
}

/* Walk from the root record toward `network`, descending one level per bit
 * for the first network->prefix_length bits (most significant bit first).
 *
 * On return *record points at the record reached. That is either the record
 * after prefix_length steps, or an earlier non-node record for which
 * if_not_node() returned NULL. If sibling_record is non-NULL it is set to
 * the other child of the last node descended through, or NULL if the walk
 * never descended.
 *
 * if_not_node is called for each non-node, non-alias record met during the
 * walk. A non-NULL result is installed in place of that record as a node
 * record, and the walk continues into it. Presumably this is how inserts
 * split records; return_null() gives a read-only lookup.
 *
 * Returns MMDBW_FINDING_NODE_ERROR if the walk reaches an alias record while
 * follow_aliases is false, and MMDBW_SUCCESS otherwise. An alias reached as
 * the final record is not checked. */
LOCAL MMDBW_status find_record_for_network(MMDBW_tree_s *tree,
                                           MMDBW_network_s *network,
                                           bool follow_aliases,
                                           MMDBW_node_s *(if_not_node)(
                                               MMDBW_tree_s *tree,
                                               MMDBW_record_s *record),
                                           MMDBW_record_s **record,
                                           MMDBW_record_s **sibling_record
                                           )
{
    if (NULL != sibling_record) {
        *sibling_record = NULL;
    }

    *record = &(tree->root_record);

    /* Bits are numbered from max_depth0 (most significant) downward, and
       the walk covers prefix_length of them. For a prefix length of 0,
       last_bit is max_depth0 + 1 and the loop body never runs. */
    int max_depth0 = tree_depth0(tree);
    uint8_t last_bit = max_depth0 - (network->prefix_length - 1);

    for (int current_bit = max_depth0; current_bit >= last_bit;
         current_bit--) {

        if (MMDBW_RECORD_TYPE_ALIAS == (*record)->type && !follow_aliases) {
            return MMDBW_FINDING_NODE_ERROR;
        }

        MMDBW_node_s *node;
        if (MMDBW_RECORD_TYPE_NODE == (*record)->type ||
            MMDBW_RECORD_TYPE_ALIAS == (*record)->type) {
            /* Alias records also point at a node and are followed here. */
            node = (*record)->value.node;
        } else {
            node = if_not_node(tree, *record);
            if (NULL == node) {
                /* Stop at this non-node record; *record stays pointing at
                   it. */
                break;
            }

            /* Replace the record in place with a pointer to the new node. */
            (*record)->type = MMDBW_RECORD_TYPE_NODE;
            (*record)->value.node = node;
        }


        /* A set bit goes right, a clear bit goes left. */
        if (network_bit_value(tree, network, current_bit)) {
            *record = &(node->right_record);
            if (NULL != sibling_record) {
                *sibling_record = &(node->left_record);
            }
        } else {
            *record = &(node->left_record);
            if (NULL != sibling_record) {
                *sibling_record = &(node->right_record);
            }
        }
    }

    return MMDBW_SUCCESS;
}

/* if_not_node callback for find_record_for_network() that never creates a
 * node. Walks using it stop at the first non-node record. */
LOCAL MMDBW_node_s *return_null(
    MMDBW_tree_s *UNUSED(tree), MMDBW_record_s *UNUSED(record))
{
    MMDBW_node_s *no_node = NULL;
    return no_node;
}

/* if_not_node callback that splits `record` into a new node. A data record's
 * key is copied to both children. Any other record yields a node with two
 * empty records. */
LOCAL MMDBW_node_s *new_node_from_record(MMDBW_tree_s *tree,
                                         MMDBW_record_s *record)
{
    MMDBW_node_s *node = new_node();
    if (MMDBW_RECORD_TYPE_DATA != record->type) {
        return node;
    }

    /* The parent record is being replaced by this node, so its existing
       reference covers one child; only one extra reference is needed. */
    const char *shared_key = record->value.key;
    increment_data_reference_count(tree, shared_key);

    node->left_record.type = MMDBW_RECORD_TYPE_DATA;
    node->left_record.value.key = shared_key;
    node->right_record.type = MMDBW_RECORD_TYPE_DATA;
    node->right_record.value.key = shared_key;

    return node;
}

/* Allocate a node numbered 0 whose two records are empty. Aborts if memory
 * cannot be allocated (see checked_malloc()). */
MMDBW_node_s *new_node()
{
    MMDBW_node_s *node = checked_malloc(sizeof(*node));

    node->left_record.type = MMDBW_RECORD_TYPE_EMPTY;
    node->right_record.type = MMDBW_RECORD_TYPE_EMPTY;
    node->number = 0;

    return node;
}

/* Release whatever the node's two records own (see free_record_value()),
 * then free the node itself. */
LOCAL void free_node_and_subnodes(MMDBW_tree_s *tree, MMDBW_node_s *node)
{
    MMDBW_record_s *children[] = { &(node->left_record),
                                   &(node->right_record) };

    for (size_t i = 0; i < sizeof(children) / sizeof(children[0]); i++) {
        free_record_value(tree, children[i]);
    }

    free(node);
}

/* Release what a record owns. A node record frees its whole subtree, and a
 * data record drops one reference to its data key. The record itself is not
 * modified. */
LOCAL void free_record_value(MMDBW_tree_s *tree, MMDBW_record_s *record)
{
    switch (record->type) {
    case MMDBW_RECORD_TYPE_NODE:
        free_node_and_subnodes(tree, record->value.node);
        break;
    case MMDBW_RECORD_TYPE_DATA:
        decrement_data_reference_count(tree, record->value.key);
        break;
    default:
        /* Empty records own nothing. MMDBW_RECORD_TYPE_ALIAS records are
           deliberately not followed. */
        break;
    }
}

/* Number every node consecutively from 0 in pre-order (each node before its
 * children), leaving the total in tree->node_count. */
void assign_node_numbers(MMDBW_tree_s *tree)
{
    void *no_args = NULL;

    tree->node_count = 0;
    start_iteration(tree, false, no_args, &assign_node_number);
}

/* start_iteration() callback: give `node` the next number and advance the
 * tree's counter. */
LOCAL void assign_node_number(MMDBW_tree_s *tree, MMDBW_node_s *node,
                              uint128_t UNUSED(network),
                              uint8_t UNUSED(depth), void *UNUSED(args))
{
    node->number = tree->node_count;
    tree->node_count++;
}

/* One frozen data record: 16 bytes for the start IP address, 1 byte for the
   prefix length, and SHA1_KEY_LENGTH bytes for the data key. */
#define FROZEN_RECORD_MAX_SIZE (16 + 1 + SHA1_KEY_LENGTH)
/* A node freezes at most two records (left and right). */
#define FROZEN_NODE_MAX_SIZE (FROZEN_RECORD_MAX_SIZE * 2)

/* The list of frozen networks ends with 17 NUL bytes (an all-zero start IP
   and prefix length) followed by FREEZE_SEPARATOR. The separator sits where
   a real record would have its SHA1 key, and cannot be one. */
#define SEVENTEEN_NULLS "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
#define FREEZE_SEPARATOR "not an SHA1 key"
/* We subtract 1 because the separator is written as a sequence of bytes, not
   as a NUL-terminated string. */
#define FREEZE_SEPARATOR_LENGTH (sizeof(FREEZE_SEPARATOR) - 1)

/* Serialize the tree to `filename`, creating or truncating it.
 *
 * File layout:
 *   - a 4-byte (host byte order) length of frozen_params;
 *   - frozen_params itself;
 *   - one frozen record per data record in the tree (freeze_data_record());
 *   - SEVENTEEN_NULLS followed by FREEZE_SEPARATOR;
 *   - the data table (freeze_data_to_fd()).
 * Croaks on any I/O error or if frozen_params does not fit the 4-byte
 * length. */
void freeze_tree(MMDBW_tree_s *tree, char *filename, char *frozen_params,
                 size_t frozen_params_size)
{
    /* Reject oversized params before creating or truncating the file. */
    if (frozen_params_size > UINT32_MAX) {
        croak("Frozen params for %s are too large (%zu bytes)", filename,
              frozen_params_size);
    }

    int fd = open(filename, O_CREAT | O_TRUNC | O_RDWR, (mode_t)0644);
    if (-1 == fd) {
        croak("Could not open file %s: %s", filename, strerror(errno));
    }

    freeze_args_s args = {
        .fd       = fd,
        .filename = filename,
    };

    /* Write a real 32-bit value. Writing the first 4 bytes of the size_t
       would emit its high-order (zero) bytes on big-endian hosts. */
    uint32_t params_size = (uint32_t)frozen_params_size;
    freeze_to_fd(&args, &params_size, 4);
    freeze_to_fd(&args, frozen_params, frozen_params_size);

    freeze_search_tree(tree, &args);

    freeze_to_fd(&args, SEVENTEEN_NULLS, 17);
    freeze_to_fd(&args, FREEZE_SEPARATOR, FREEZE_SEPARATOR_LENGTH);

    freeze_data_to_fd(fd, tree);

    if (-1 == close(fd)) {
        croak("Could not close file %s: %s", filename, strerror(errno));
    }
}

/* Write a frozen record for every data record in the tree, via freeze_node().
 * Only trees whose root is a node can be frozen. A root that is a single /0
 * data record, or any other root type, makes us croak. */
LOCAL void freeze_search_tree(MMDBW_tree_s *tree, freeze_args_s *args)
{
    if (MMDBW_RECORD_TYPE_NODE == tree->root_record.type) {
        start_iteration(tree, false, (void *)args, &freeze_node);
        return;
    }

    if (MMDBW_RECORD_TYPE_DATA == tree->root_record.type) {
        croak("A tree that only contains a data record for /0 cannot be "
              "frozen");
    }

    croak("Unexected root record type when freezing tree: %s",
          record_type_name(tree->root_record.type));
}


/* start_iteration() callback: freeze whichever of the node's records hold
 * data. The left child starts at the node's own network. The right child's
 * network has the bit for `depth` set. Both children sit at depth + 1. */
LOCAL void freeze_node(MMDBW_tree_s *tree, MMDBW_node_s *node,
                       uint128_t network, uint8_t depth, void *void_args)
{
    freeze_args_s *freeze_args = (freeze_args_s *)void_args;
    const uint8_t child_depth = depth + 1;
    MMDBW_record_s *left = &(node->left_record);
    MMDBW_record_s *right = &(node->right_record);

    if (MMDBW_RECORD_TYPE_DATA == left->type) {
        freeze_data_record(tree, network, child_depth, left->value.key,
                           freeze_args);
    }

    if (MMDBW_RECORD_TYPE_DATA == right->type) {
        freeze_data_record(tree, flip_network_bit(tree, network, depth),
                           child_depth, right->value.key, freeze_args);
    }
}

/* Append one frozen record to the freeze file:
 *   16 bytes  start IP (the uint128_t's in-memory, host-order bytes)
 *    1 byte   prefix length (`depth`)
 *   SHA1_KEY_LENGTH bytes of data key */
LOCAL void freeze_data_record(MMDBW_tree_s *UNUSED(tree),
                              uint128_t network, uint8_t depth,
                              const char *key,
                              freeze_args_s *args)
{
    uint8_t prefix_length = depth;

    /* IPv4-only trees could use 4 bytes here, but a fixed 16 keeps thawing
     * simple. */
    freeze_to_fd(args, &network, 16);
    freeze_to_fd(args, &prefix_length, 1);
    freeze_to_fd(args, (char *)key, SHA1_KEY_LENGTH);
}

/* Write `size` bytes of `data` to the freeze file described by `args`.
 * On failure checked_write() closes the descriptor and croaks. */
LOCAL void freeze_to_fd(freeze_args_s *args, void *data, size_t size)
{
    int fd = args->fd;
    char *filename = args->filename;

    checked_write(fd, filename, data, size);
}

/* Write `size` bytes from `buffer` to `fd`, retrying after short writes and
 * EINTR. A single write() may transfer less than requested; Linux, for
 * example, caps one call at about 2 GiB. Returns 0 on success, or -1 with
 * errno set. */
static int write_fully(int fd, const void *buffer, size_t size)
{
    const char *cursor = buffer;
    while (size > 0) {
        ssize_t written = write(fd, cursor, size);
        if (written < 0) {
            if (EINTR == errno) {
                continue;
            }
            return -1;
        }
        if (0 == written) {
            /* No progress and no error: report it rather than spin. */
            errno = EIO;
            return -1;
        }
        cursor += written;
        size -= (size_t)written;
    }
    return 0;
}

/* Append the tree's data table to fd. The output is a STRLEN (native size
 * and byte order) byte count, followed by the Sereal encoding of a hash that
 * maps each SHA1 data key to its data. On a write error the descriptor is
 * closed, matching checked_write(), and we croak. */
LOCAL void freeze_data_to_fd(int fd, MMDBW_tree_s *tree)
{
    HV *data_hash = newHV();

    MMDBW_data_hash_s *item, *tmp;
    HASH_ITER(hh, tree->data_table, item, tmp) {
        /* The hash takes its own reference; the data table keeps its own. */
        SvREFCNT_inc_simple_void_NN(item->data_sv);
        (void)hv_store(data_hash, item->key, SHA1_KEY_LENGTH, item->data_sv, 0);
    }

    SV *frozen_data = freeze_hash(data_hash);
    /* The encoded string does not depend on the hash, so release the hash
       now. Freeing it drops the references taken above, and a croak below
       cannot leak it. */
    SvREFCNT_dec((SV *)data_hash);

    STRLEN frozen_data_size;
    char *frozen_data_chars = SvPV(frozen_data, frozen_data_size);

    if (-1 == write_fully(fd, &frozen_data_size, sizeof(STRLEN))) {
        int saved_errno = errno;
        SvREFCNT_dec(frozen_data);
        close(fd);
        croak("Could not write frozen data size to file: %s",
              strerror(saved_errno));
    }

    if (-1 == write_fully(fd, frozen_data_chars, frozen_data_size)) {
        int saved_errno = errno;
        SvREFCNT_dec(frozen_data);
        close(fd);
        croak("Could not write frozen data to file: %s",
              strerror(saved_errno));
    }

    SvREFCNT_dec(frozen_data);
}

/* Encode `hash` by calling Sereal::Encoder::encode_sereal(\%hash) in scalar
 * context, and return the encoded string.
 *
 * The caller owns the returned SV, whose refcount was bumped so it survives
 * FREETMPS, and must SvREFCNT_dec it. `hash` is not consumed: the reference
 * created to pass it to Perl is mortal and is released by FREETMPS. Croaks
 * unless exactly one value comes back and it is a string (SvPOK). */
LOCAL SV *freeze_hash(HV *hash)
{
    /* Standard setup for calling a Perl sub from C (see perlcall). */
    dSP;
    ENTER;
    SAVETMPS;

    SV *hashref = sv_2mortal(newRV_inc((SV *)hash));

    PUSHMARK(SP);
    EXTEND(SP, 1);
    PUSHs(hashref);
    PUTBACK;

    int count = call_pv("Sereal::Encoder::encode_sereal", G_SCALAR);

    SPAGAIN;

    if (count != 1) {
        croak("Expected 1 item back from Sereal::Encoder::encode_sereal call");
    }

    SV *frozen = POPs;
    if (!SvPOK(frozen)) {
        croak(
            "The Sereal::Encoder::encode_sereal sub returned an SV which is not SvPOK!");
    }

    /* The SV will be mortal so it's about to lose a ref with the FREETMPS
       call below. */
    SvREFCNT_inc_simple_void_NN(frozen);

    PUTBACK;
    FREETMPS;
    LEAVE;

    return frozen;
}

/* Rebuild a tree from a file written by freeze_tree().
 *
 * initial_offset is the byte offset of the first frozen record, past any
 * header the caller has already consumed. Every frozen record is inserted
 * into a new tree. The data table that follows the terminator is then
 * decoded and attached to the tree.
 *
 * The file is mmap()ed only while it is parsed and is unmapped before
 * returning. Croaks on I/O errors and on files too short for their declared
 * contents. */
MMDBW_tree_s *thaw_tree(char *filename, uint32_t initial_offset,
                        uint8_t ip_version, uint8_t record_size,
                        MMDBW_merge_strategy merge_strategy,
                        const bool alias_ipv6)
{
    int fd = open(filename, O_RDONLY, 0);
    if (-1 == fd) {
        croak("Could not open file %s: %s", filename, strerror(errno));
    }

    struct stat fileinfo;
    if (-1 == fstat(fd, &fileinfo)) {
        /* Capture errno before close() can overwrite it. */
        int saved_errno = errno;
        close(fd);
        croak("Could not stat file: %s: %s", filename, strerror(saved_errno));
    }

    size_t mapped_size = (size_t)fileinfo.st_size;
    if (initial_offset > mapped_size) {
        close(fd);
        croak("File %s is too small (%zu bytes) for a frozen tree starting "
              "at offset %" PRIu32, filename, mapped_size, initial_offset);
    }

    void *mapping = mmap(NULL, mapped_size, PROT_READ, MAP_SHARED, fd, 0);
    int mmap_errno = errno;
    /* The mapping stays valid after the descriptor is closed. */
    close(fd);
    if (MAP_FAILED == mapping) {
        croak("Could not mmap file %s: %s", filename, strerror(mmap_errno));
    }

    uint8_t *const mapping_end = (uint8_t *)mapping + mapped_size;
    uint8_t *buffer = (uint8_t *)mapping + initial_offset;

    MMDBW_tree_s *tree = new_tree(ip_version, record_size, merge_strategy,
                                  alias_ipv6);

    thawed_network_s *thawed;
    while (NULL != (thawed = thaw_network(tree, &buffer))) {
        if (MMDBW_RECORD_TYPE_DATA == thawed->record->type) {
            const char *key = increment_data_reference_count(
                tree, thawed->record->value.key);

            /* insert_record_for_network reuses the key. We want it to use
               the same copy as used in the data hash. */
            free((char *)thawed->record->value.key);
            thawed->record->value.key = key;
        }
        // We should never need to merge when thawing a tree.
        MMDBW_status status = insert_record_for_network(
            tree, thawed->network, thawed->record,
            false, true);
        free_network(thawed->network);
        free(thawed->network);
        free(thawed->record);
        free(thawed);
        if (MMDBW_SUCCESS != status) {
            munmap(mapping, mapped_size);
            croak("%s", status_error_message(status));
        }
    }

    /* The data table is a STRLEN byte count followed by that many bytes. */
    if (buffer > mapping_end
        || (size_t)(mapping_end - buffer) < sizeof(STRLEN)) {
        munmap(mapping, mapped_size);
        croak("Frozen tree file %s is truncated", filename);
    }
    STRLEN frozen_data_size = thaw_strlen(&buffer);
    if (frozen_data_size > (size_t)(mapping_end - buffer)) {
        munmap(mapping, mapped_size);
        croak("Frozen tree file %s is truncated", filename);
    }

    /* per perlapi newSVpvn copies the string, so the mapping can be released
       once it is made. */
    SV *data_to_decode =
        sv_2mortal(newSVpvn((char *)buffer, frozen_data_size));
    munmap(mapping, mapped_size);

    HV *data_hash = thaw_data_hash(data_to_decode);

    hv_iterinit(data_hash);
    char *key;
    I32 keylen;
    SV *value;
    while (NULL != (value = hv_iternextsv(data_hash, &key, &keylen))) {
        set_stored_data_in_tree(tree, key, value);
    }

    SvREFCNT_dec((SV *)data_hash);

    return tree;
}

/* Read one byte at the cursor *buffer and advance the cursor past it. */
LOCAL uint8_t thaw_uint8(uint8_t **buffer)
{
    uint8_t *cursor = *buffer;
    *buffer = cursor + 1;
    return cursor[0];
}

/* Read a 4-byte integer in host byte order at the cursor *buffer and advance
 * past it. memcpy() keeps this safe for unaligned cursors. */
LOCAL uint32_t thaw_uint32(uint8_t **buffer)
{
    uint32_t result;
    uint8_t *source = *buffer;

    memcpy(&result, source, sizeof(result));
    *buffer = source + sizeof(result);

    return result;
}

/* Read the next frozen record at *buffer and advance past it.
 *
 * Returns NULL at the end-of-records marker: an all-zero start IP and prefix
 * length, followed by FREEZE_SEPARATOR. Otherwise returns a malloc()ed
 * thawed_network_s. Its network, network->bytes (4 bytes for IPv4 trees,
 * 16 otherwise), record and record->value.key are all malloc()ed and owned
 * by the caller. Croaks on a zero network that is not followed by the
 * separator. */
LOCAL thawed_network_s *thaw_network(MMDBW_tree_s *tree, uint8_t **buffer)
{
    uint128_t start_ip = thaw_uint128(buffer);
    uint8_t prefix_length = thaw_uint8(buffer);

    if (0 == start_ip && 0 == prefix_length) {
        uint8_t *maybe_separator = thaw_bytes(buffer, FREEZE_SEPARATOR_LENGTH);
        bool is_separator = 0 == memcmp(maybe_separator, FREEZE_SEPARATOR,
                                        FREEZE_SEPARATOR_LENGTH);
        free(maybe_separator);
        if (is_separator) {
            return NULL;
        }

        croak("Found a ::0/0 network but that should never happen!");
    }

    /* Extract the address in network (big-endian) byte order with shifts.
       Reversing the integer's in-memory bytes only works on little-endian
       hosts. */
    uint8_t start_ip_bytes[16];
    for (int i = 0; i < 16; i++) {
        start_ip_bytes[i] = (uint8_t)(start_ip >> (8 * (15 - i)));
    }

    /* IPv4 trees keep only the low 32 bits of the address. */
    size_t address_length = 4 == tree->ip_version ? 4 : 16;
    uint8_t *bytes = checked_malloc(address_length);
    memcpy(bytes, start_ip_bytes + (16 - address_length), address_length);

    thawed_network_s *thawed = checked_malloc(sizeof(thawed_network_s));

    thawed->network = checked_malloc(sizeof(MMDBW_network_s));
    *thawed->network = (MMDBW_network_s){
        .bytes         = bytes,
        .prefix_length = prefix_length,
    };

    MMDBW_record_s *record = checked_malloc(sizeof(MMDBW_record_s));
    record->type = MMDBW_RECORD_TYPE_DATA;
    record->value.key = thaw_data_key(buffer);

    thawed->record = record;

    return thawed;
}

/* Return a malloc()ed copy of the next `size` bytes at the cursor *buffer
 * and advance past them. The caller frees the copy. */
LOCAL uint8_t *thaw_bytes(uint8_t **buffer, size_t size)
{
    uint8_t *source = *buffer;
    uint8_t *copy = checked_malloc(size);

    *buffer = source + size;
    return memcpy(copy, source, size);
}

/* Read a 16-byte integer in host byte order at the cursor *buffer and
 * advance past it. */
LOCAL uint128_t thaw_uint128(uint8_t **buffer)
{
    uint128_t result;
    uint8_t *source = *buffer;

    memcpy(&result, source, 16);
    *buffer = source + 16;

    return result;
}

/* Read a native STRLEN (as written by freeze_data_to_fd()) at the cursor
 * *buffer and advance past it. */
LOCAL STRLEN thaw_strlen(uint8_t **buffer)
{
    STRLEN result;
    uint8_t *source = *buffer;

    memcpy(&result, source, sizeof(result));
    *buffer = source + sizeof(result);

    return result;
}

/* Return a malloc()ed, NUL-terminated copy of the SHA1_KEY_LENGTH-byte data
 * key at the cursor *buffer and advance past it. The caller frees it. */
LOCAL const char *thaw_data_key(uint8_t **buffer)
{
    char *key = checked_malloc(SHA1_KEY_LENGTH + 1);

    key[SHA1_KEY_LENGTH] = '\0';
    memcpy(key, *buffer, SHA1_KEY_LENGTH);
    *buffer += SHA1_KEY_LENGTH;

    return key;
}

/* Decode a Sereal-encoded data table by calling
 * Sereal::Decoder::decode_sereal(data_to_decode) in scalar context.
 *
 * Returns the decoded hash with its refcount bumped; the caller owns that
 * reference and must SvREFCNT_dec it. Croaks unless exactly one value comes
 * back and it is a hash reference. */
LOCAL HV *thaw_data_hash(SV *data_to_decode)
{
    dSP;
    ENTER;
    SAVETMPS;

    PUSHMARK(SP);
    EXTEND(SP, 1);
    PUSHs(data_to_decode);
    PUTBACK;

    int count = call_pv("Sereal::Decoder::decode_sereal", G_SCALAR);

    SPAGAIN;

    if (count != 1) {
        croak("Expected 1 item back from Sereal::Decoder::decode_sereal call");
    }

    SV *thawed = POPs;
    if (!SvROK(thawed)) {
        croak(
            "The Sereal::Decoder::decode_sereal sub returned an SV which is not SvROK!");
    }
    /* The caller iterates the result as a hash. Any other reference type
       means a corrupt or foreign file, so it must not be cast blindly. */
    if (SvTYPE(SvRV(thawed)) != SVt_PVHV) {
        croak(
            "The Sereal::Decoder::decode_sereal sub returned a reference which is not a hash reference!");
    }

    /* Take our own reference so the hash survives FREETMPS releasing the
       mortal return value. */
    SV *data_hash = SvREFCNT_inc_simple_NN(SvRV(thawed));

    PUTBACK;
    FREETMPS;
    LEAVE;

    return (HV *)data_hash;
}

/* Number the nodes, then write the binary search tree to the Perl handle
 * `output`, one node at a time via encode_node(). Data records are handed to
 * `serializer` as they are first encountered (see record_value_as_number()).
 * Croaks if `output` is not open for writing. */
void write_search_tree(MMDBW_tree_s *tree, SV *output,
                       SV *root_data_type, SV *serializer)
{
    assign_node_numbers(tree);

    /* A handle that is not open for writing has no output PerlIO. Catch that
       here instead of handing NULL to PerlIO_printf() for every node. */
    PerlIO *output_io = IoOFP(sv_2io(output));
    if (NULL == output_io) {
        croak("The output handle passed to write_search_tree is not open "
              "for writing");
    }

    /* This is a gross way to get around the fact that with C function
     * pointers we can't easily pass different params to different
     * callbacks. */
    encode_args_s args = {
        .output_io          = output_io,
        .root_data_type     = root_data_type,
        .serializer         = serializer,
        .data_pointer_cache = newHV()
    };

    start_iteration(tree, false, (void *)&args, &encode_node);

    /* When the hash is _freed_, Perl decrements the ref count for each value
     * so we don't need to mess with them. */
    SvREFCNT_dec((SV *)args.data_pointer_cache);
}

/* start_iteration() callback: write one node to args->output_io as its left
 * and right record values, big-endian, each packed into tree->record_size
 * bits:
 *   24 -> 6 bytes: left[1..3] right[1..3]
 *   28 -> 7 bytes: left[1..3], (left top nibble << 4 | right top nibble),
 *         right[1..3]
 *   anything else -> 8 bytes: all four bytes of each value. */
LOCAL void encode_node(MMDBW_tree_s *tree, MMDBW_node_s *node,
                       uint128_t UNUSED(network),
                       uint8_t UNUSED(depth), void *void_args)
{
    encode_args_s *encode_args = (encode_args_s *)void_args;
    PerlIO *out_io = encode_args->output_io;

    check_record_sanity(node, &(node->left_record), "left");
    check_record_sanity(node, &(node->right_record), "right");

    /* Left is evaluated before right: record_value_as_number() may call the
       serializer, so the order is significant. */
    uint32_t left_be =
        htonl(record_value_as_number(tree, &(node->left_record),
                                     encode_args));
    uint32_t right_be =
        htonl(record_value_as_number(tree, &(node->right_record),
                                     encode_args));

    /* After htonl(), index 0 is the most significant byte. */
    const uint8_t *lbytes = (const uint8_t *)&left_be;
    const uint8_t *rbytes = (const uint8_t *)&right_be;

    switch (tree->record_size) {
    case 24:
        check_perlio_result(
            PerlIO_printf(out_io, "%c%c%c%c%c%c",
                          lbytes[1], lbytes[2], lbytes[3],
                          rbytes[1], rbytes[2], rbytes[3]),
            6, "PerlIO_printf");
        break;
    case 28:
        check_perlio_result(
            PerlIO_printf(out_io, "%c%c%c%c%c%c%c",
                          lbytes[1], lbytes[2], lbytes[3],
                          (lbytes[0] << 4) | (rbytes[0] & 15),
                          rbytes[1], rbytes[2], rbytes[3]),
            7, "PerlIO_printf");
        break;
    default:
        check_perlio_result(
            PerlIO_printf(out_io, "%c%c%c%c%c%c%c%c",
                          lbytes[0], lbytes[1], lbytes[2], lbytes[3],
                          rbytes[0], rbytes[1], rbytes[2], rbytes[3]),
            8, "PerlIO_printf");
        break;
    }
}

/* Croak if `record` (the `side` child of `node`) would make an invalid
 * search tree:
 *  - a node record pointing at its own node, or at a node with a lower
 *    number. Nodes are numbered in pre-order by assign_node_numbers(), so a
 *    child must be numbered after its parent;
 *  - an alias record pointing at node 0.
 * Data records are checked later, in record_value_as_number(), which croaks
 * when a key has no data. */
LOCAL void check_record_sanity(MMDBW_node_s *node, MMDBW_record_s *record,
                               char *side)
{
    if (MMDBW_RECORD_TYPE_NODE == record->type) {
        if (record->value.node->number == node->number) {
            croak("%s record of node %" PRIu32 " points to the same node",
                  side, node->number);
        }

        if (record->value.node->number < node->number) {
            croak(
                "%s record of node %" PRIu32
                " points to a lower node number (%" PRIu32 ")",
                side, node->number, record->value.node->number);
        }
    }

    if (MMDBW_RECORD_TYPE_ALIAS == record->type) {
        if (0 == record->value.node->number) {
            croak("%s record of node %" PRIu32 " is an alias to node 0",
                  side, node->number);
        }
    }
}

/* Return the number to encode for `record` in the search tree:
 *  - empty records: tree->node_count;
 *  - node and alias records: the target node's number;
 *  - data records: the offset returned by the serializer's store_data()
 *    method, plus node_count, plus DATA_SECTION_SEPARATOR_SIZE. Each key is
 *    serialized once; repeats are served from args->data_pointer_cache.
 * Croaks if a data key has no data, if store_data() misbehaves, or if the
 * value does not fit in tree->record_size bits. */
LOCAL uint32_t record_value_as_number(MMDBW_tree_s *tree,
                                      MMDBW_record_s *record,
                                      encode_args_s * args)
{
    /* Computed in 64 bits so an oversized data offset is caught by the
       range check below. In 32 bits it could wrap around, and a wrapped
       value always passes the check for 32-bit records. */
    uint64_t record_value;
    bool is_new_data_record = false;

    if (MMDBW_RECORD_TYPE_EMPTY == record->type) {
        record_value = tree->node_count;
    } else if (MMDBW_RECORD_TYPE_NODE == record->type ||
               MMDBW_RECORD_TYPE_ALIAS == record->type) {
        record_value = record->value.node->number;
    } else {
        SV **cache_record =
            hv_fetch(args->data_pointer_cache, record->value.key,
                     SHA1_KEY_LENGTH, 0);
        if (cache_record) {
            /* Values are only cached after passing the size check below. */
            return (uint32_t)SvUV(*cache_record);
        }

        SV *data = newSVsv(data_for_key(tree, record->value.key));
        if (!SvOK(data)) {
            SvREFCNT_dec(data);
            croak("No data associated with key - %s", record->value.key);
        }

        dSP;
        ENTER;
        SAVETMPS;

        PUSHMARK(SP);
        EXTEND(SP, 5);
        PUSHs(args->serializer);
        PUSHs(args->root_data_type);
        mPUSHs(data);
        PUSHs(&PL_sv_undef);
        mPUSHp(record->value.key, strlen(record->value.key));
        PUTBACK;

        int count = call_method("store_data", G_SCALAR);

        SPAGAIN;

        if (count != 1) {
            croak("Expected 1 item back from ->store_data() call");
        }

        SV *rval = POPs;
        if (!(SvIOK(rval) || SvUOK(rval))) {
            croak(
                "The serializer's store_data() method returned an SV which is not SvIOK or SvUOK!");
        }
        /* Keep the full width; truncating here would hide an overflow. */
        UV position = SvUV(rval);

        PUTBACK;
        FREETMPS;
        LEAVE;

        record_value = (uint64_t)position + tree->node_count +
                       DATA_SECTION_SEPARATOR_SIZE;
        is_new_data_record = true;
    }

    if (record_value > max_record_value(tree)) {
        croak(
            "Node value of %" PRIu64 " exceeds the record size of %" PRIu8
            " bits",
            record_value, tree->record_size);
    }

    if (is_new_data_record) {
        (void)hv_store(args->data_pointer_cache, record->value.key,
                       SHA1_KEY_LENGTH, newSVuv((UV)record_value), 0);
    }

    return (uint32_t)record_value;
}

/* Largest value that fits in a record of tree->record_size bits. 32-bit
 * records are special-cased because shifting by the type width is
 * undefined. */
uint32_t max_record_value(MMDBW_tree_s *tree)
{
    uint8_t bits = tree->record_size;

    if (32 == bits) {
        return UINT32_MAX;
    }

    return ((uint32_t)1 << bits) - 1;
}

/* Visit every node in the tree with `callback` (see iterate_tree() for the
 * visiting order), starting from the root at network 0 and depth 0. Croaks
 * if the root record is not a node. */
void start_iteration(MMDBW_tree_s *tree,
                     bool depth_first,
                     void *args,
                     MMDBW_iterator_callback callback)
{
    // Callbacks receive nodes rather than records, so a tree without a root
    // node has nothing to visit. Supporting that would mean reworking the
    // callback interface, which has been tried and abandoned.
    if (MMDBW_RECORD_TYPE_NODE != tree->root_record.type) {
        croak("Iteration is not currently allowed in trees with no nodes.");
    }

    const uint128_t root_network = 0;
    const uint8_t root_depth = 0;

    iterate_tree(tree, &tree->root_record, root_network, root_depth,
                 depth_first, args, callback);
}

/* Recursively call `callback` for every node reachable from `record`,
 * passing the node's network (start address) and depth.
 *
 * With depth_first false the callback runs before either subtree is visited
 * (pre-order). With depth_first true it runs after the left subtree and
 * before the right one. Only node records are descended into; empty, data
 * and alias records are not visited, so aliases are never followed. Croaks
 * if the depth exceeds the number of address bits, which means the tree is
 * malformed. */
LOCAL void iterate_tree(MMDBW_tree_s *tree,
                        MMDBW_record_s *record,
                        uint128_t network,
                        const uint8_t depth,
                        bool depth_first,
                        void *args,
                        MMDBW_iterator_callback callback)
{
    /* Records of the deepest nodes sit at depth tree_depth0 + 1, which is
       128 for IPv6 trees and 32 for IPv4 trees. */
    const int max_depth = tree_depth0(tree) + 1;
    if (depth > max_depth) {
        char ip[INET6_ADDRSTRLEN];
        integer_to_ip_string(tree->ip_version, network, ip, sizeof(ip));
        croak(
            "Depth during iteration is greater than %d (depth: %u, "
            "start IP: %s)! The tree is wonky.\n", max_depth,
            (unsigned int)depth, ip);
    }

    if (MMDBW_RECORD_TYPE_NODE == record->type) {
        MMDBW_node_s *node = record->value.node;

        if (!depth_first) {
            callback(tree, node, network, depth, args);
        }

        iterate_tree(tree,
                     &node->left_record,
                     network,
                     depth + 1,
                     depth_first,
                     args,
                     callback);

        if (depth_first) {
            callback(tree, node, network, depth, args);
        }

        iterate_tree(tree,
                     &node->right_record,
                     flip_network_bit(tree, network, depth),
                     depth + 1,
                     depth_first,
                     args,
                     callback);
    }
}

/* Return `network` with the address bit for `depth` set (not toggled),
 * where depth 0 is the most significant bit. This gives the start of the
 * right child's network for a node at that depth. */
uint128_t flip_network_bit(MMDBW_tree_s *tree, uint128_t network, uint8_t depth)
{
    int bit_index = tree_depth0(tree) - depth;
    uint128_t mask = (uint128_t)1 << bit_index;

    return network | mask;
}

/* Call MaxMind::DB::Writer::Util::key_for_data(data) in scalar context and
 * return its result.
 *
 * The returned SV's refcount is incremented so it outlives the FREETMPS
 * below. The caller owns that reference and must SvREFCNT_dec it. Croaks
 * unless exactly one value comes back. */
LOCAL SV *key_for_data(SV * data)
{
    /* Standard setup for calling a Perl sub from C (see perlcall). */
    dSP;
    ENTER;
    SAVETMPS;

    PUSHMARK(SP);
    EXTEND(SP, 1);
    PUSHs(data);
    PUTBACK;

    const char *const sub = "MaxMind::DB::Writer::Util::key_for_data";
    int count = call_pv(sub, G_SCALAR);

    SPAGAIN;

    if (count != 1) {
        croak("Expected 1 item back from %s() call", sub);
    }

    SV *key = POPs;
    /* Keep the result alive past FREETMPS; the caller owns this reference. */
    SvREFCNT_inc_simple_void_NN(key);

    PUTBACK;
    FREETMPS;
    LEAVE;

    return key;
}

/* Look up `key` in the tree's data table. Returns the stored SV itself
 * (borrowed, no refcount change), or &PL_sv_undef when the key is unknown. */
SV *data_for_key(MMDBW_tree_s *tree, const char *const key)
{
    MMDBW_data_hash_s *entry = NULL;
    HASH_FIND(hh, tree->data_table, key, strlen(key), entry);

    return NULL == entry ? &PL_sv_undef : entry->data_sv;
}

/* Free every node of the tree, then the tree itself. Croaks (without freeing
 * the tree struct) if the data table is not empty once all nodes are gone.
 * Presumably that means the data reference counts are unbalanced. */
void free_tree(MMDBW_tree_s *tree)
{
    free_record_value(tree, &tree->root_record);

    int leftover_entries = HASH_COUNT(tree->data_table);
    if (leftover_entries) {
        croak("%d elements left in data table after freeing all nodes!",
              leftover_entries);
    }

    free(tree);
}

/* Human-readable name of a record type, for error messages. Any value other
 * than empty/node/alias is reported as "data". */
const char *record_type_name(int record_type)
{
    switch (record_type) {
    case MMDBW_RECORD_TYPE_EMPTY:
        return "empty";
    case MMDBW_RECORD_TYPE_NODE:
        return "node";
    case MMDBW_RECORD_TYPE_ALIAS:
        return "alias";
    default:
        return "data";
    }
}

/* Set to the module-name SV the first time dwarn() runs. After that it is
 * only compared against NULL, so Devel::Dwarn is loaded once.
 * NOTE(review): perlapi documents that load_module() decrements the refcount
 * of the name SV, so this pointer may dangle after the first call. It must
 * not be dereferenced. */
static SV *module;
/* Debugging helper: pass `thing` to Devel::Dwarn::Dwarn, loading Devel::Dwarn
 * on first use. */
LOCAL void dwarn(SV *thing)
{
    if (NULL == module) {
        module = newSVpv("Devel::Dwarn", 0);
        load_module(PERL_LOADMOD_NOIMPORT, module, NULL);
    }

    /* Standard setup for calling a Perl sub from C (see perlcall). */
    dSP;
    ENTER;
    SAVETMPS;

    PUSHMARK(SP);
    EXTEND(SP, 1);
    PUSHs(thing);
    PUTBACK;

    (void)call_pv("Devel::Dwarn::Dwarn", G_VOID);

    SPAGAIN;

    PUTBACK;
    FREETMPS;
    LEAVE;
}

/* Debugging helper: print the 16-byte `digest` as lowercase hex on stderr,
 * tagged with `where`. */
void warn_hex(uint8_t digest[16], char *where)
{
    char *hex_digest = md5_as_hex(digest);

    (void)fprintf(stderr, "MD5 = %s (%s)\n", hex_digest, where);

    free(hex_digest);
}

/* Return a malloc()ed, NUL-terminated 32-character lowercase hex rendering
 * of the 16-byte `digest`. The caller frees it. */
char *md5_as_hex(uint8_t digest[16])
{
    static const char hex_digits[] = "0123456789abcdef";
    char *hex = checked_malloc(33);

    for (int i = 0; i < 16; ++i) {
        hex[2 * i] = hex_digits[digest[i] >> 4];
        hex[2 * i + 1] = hex_digits[digest[i] & 0x0f];
    }
    hex[32] = '\0';

    return hex;
}

/* malloc() that abort()s the process instead of returning NULL. */
LOCAL void *checked_malloc(size_t size)
{
    void *allocation = malloc(size);
    if (allocation) {
        return allocation;
    }

    abort();
}

/* Write all `count` bytes of `buffer` to `fd`. Short writes are continued
 * and EINTR is retried. On failure the descriptor is closed and we croak,
 * naming `filename`. */
LOCAL void checked_write(int fd, char *filename, void *buffer,
                         ssize_t count)
{
    const char *cursor = buffer;
    ssize_t remaining = count;

    while (remaining > 0) {
        ssize_t result = write(fd, cursor, (size_t)remaining);
        if (-1 == result) {
            if (EINTR == errno) {
                continue;
            }
            /* Save errno before close() can overwrite it. */
            int saved_errno = errno;
            close(fd);
            croak("Could not write to the file %s: %s", filename,
                  strerror(saved_errno));
        }
        if (0 == result) {
            /* No progress and no error: give up rather than spin. */
            close(fd);
            croak(
                "Write to %s did not write the expected amount of data (wrote %zd instead of %zd)",
                filename, count - remaining, count);
        }
        cursor += result;
        remaining -= result;
    }
}

/* Read exactly `size` bytes from `io` into `buffer`, croaking on an error or
 * a short read (see check_perlio_result()). */
LOCAL void checked_perlio_read(PerlIO * io, void *buffer,
                               SSize_t size)
{
    check_perlio_result(PerlIO_read(io, buffer, size), size, "PerlIO_read");
}

/* Croak unless a PerlIO operation named `op` returned exactly `expected`.
 * A negative result is reported with the errno message; any other mismatch
 * is reported with both byte counts. */
LOCAL void check_perlio_result(SSize_t result, SSize_t expected,
                               char *op)
{
    if (result < 0) {
        /* result is just -1 here; the reason for the failure is in errno. */
        croak("%s operation failed: %s\n", op, strerror(errno));
    } else if (result != expected) {
        croak(
            "%s operation wrote %zd bytes when we expected to write %zd",
            op, result, expected);
    }
}


/* Human-readable message for an MMDBW_status value. */
LOCAL char *status_error_message(MMDBW_status status)
{
    char *message = "Unknown error";

    /* No default label on purpose: the compiler can then warn about enum
     * values that are not handled. */
    switch (status) {
    case MMDBW_SUCCESS:
        message = "Success";
        break;
    case MMDBW_FINDING_NODE_ERROR:
        message =
            "Error finding node. Did you try inserting into an aliased network?";
        break;
    case MMDBW_ALIAS_OVERWRITE_ATTEMPT_ERROR:
        message = "Attempted to overwrite an alised network.";
        break;
    }

    return message;
}
