diff --git a/src/lfc/collections/map.c b/src/lfc/collections/map.c index 497b80c..501d5c0 100644 --- a/src/lfc/collections/map.c +++ b/src/lfc/collections/map.c @@ -6,6 +6,8 @@ #include "internals/collections/__mapbucket.h" #include "lfc/collections/array.h" +#include "lfc/collections/set.h" + #include "lfc/utils/pair.h" #include "lfc/utils/panic.h" @@ -31,21 +33,19 @@ void hashmap_free(hashmap_t* map, free_fn_t key_free, free_fn_t val_free) { array_free(&map->buckets, NULL); } -size_t hashmap_bucket(hashmap_t* map, void* key) { +size_t __hashmap_bucket(hashmap_t* map, void* key) { return (map->hash_fn)(key) % map->buckets.len; } void* hashmap_get(hashmap_t* map, void* key) { - size_t bucket_index = hashmap_bucket(map, key); + size_t bucket_index = __hashmap_bucket(map, key); struct __mapbucket* bucket = array_at(&map->buckets, bucket_index); return __mapbucket_find(bucket, key, map->key_eq); } -// TODO: rehash - uint8_t hashmap_set(hashmap_t* map, void* target_key, void* new_value, free_fn_t val_free) { - size_t bucket_index = hashmap_bucket(map, target_key); + size_t bucket_index = __hashmap_bucket(map, target_key); struct __mapbucket* bucket = array_at(&map->buckets, bucket_index); for (struct __mapbucket_node* node = bucket->head; node != NULL; node = node->next) { @@ -73,9 +73,30 @@ uint8_t hashmap_contains(hashmap_t* map, void* key) { return hashmap_get(map, key) != NULL; } +void __hashmap_rehash(hashmap_t* map) { + array_t old_buckets = map->buckets; + hashmap_init(map, old_buckets.len * 2, map->hash_fn, map->key_eq); + + for (size_t i = 0; i < old_buckets.len; i++) { + struct __mapbucket* bucket = array_at(&old_buckets, i); + + for (struct __mapbucket_node* node = bucket->head; node != NULL; node = node->next) { + hashmap_insert(map, node->data.first, node->data.second); + } + + __mapbucket_free(bucket, NULL, NULL); + } + + array_free(&old_buckets, NULL); +} + uint8_t hashmap_insert(hashmap_t* map, void* key, void* val) { if (!hashmap_contains(map, key)) { - size_t bucket_index = hashmap_bucket(map, key); + if (hashmap_load_factor(map) > MAX_LOAD_FACTOR) { + __hashmap_rehash(map); + } + + size_t bucket_index = __hashmap_bucket(map, key); struct __mapbucket* bucket = array_at(&map->buckets, bucket_index); __mapbucket_prepend(bucket, key, val); @@ -88,7 +109,7 @@ uint8_t hashmap_insert(hashmap_t* map, void* key, void* val) { } void hashmap_remove(hashmap_t* map, void* key, free_fn_t key_free, free_fn_t val_free) { - size_t bucket_index = hashmap_bucket(map, key); + size_t bucket_index = __hashmap_bucket(map, key); struct __mapbucket* bucket = array_at(&map->buckets, bucket_index); if (__mapbucket_remove(bucket, key, map->key_eq, key_free, val_free)) { diff --git a/src/lfc/collections/tests/map_tests.c b/src/lfc/collections/tests/map_tests.c index 1eedb80..6a43c81 100644 --- a/src/lfc/collections/tests/map_tests.c +++ b/src/lfc/collections/tests/map_tests.c @@ -139,7 +139,7 @@ void test_hashmap_multiple_distinct_values_inserted_correctly() { assert(hashmap_contains(&map, keys + i % str_len)); assert_eq(map.size, i + 1); - assert_eq(hashmap_load_factor(&map), (i + 1.0) / DEFAULT_BUCKETS); // TODO: update for rehash + assert_eq(hashmap_load_factor(&map), (i + 1.0) / map.buckets.len); assert_false(hashmap_is_empty(&map)); }