ska::flat_hash_map 源码分析

Posted by 111qqz on Saturday, August 21, 2021

TOC

背景

最近在调研各种hashmap.. 发现ska::flat hash map性能优秀。。于是来看看代码。。 发现最大的特点是,ska::flat_hash_map使用了带probe count上限的robin hood hashing

相关概念

Distance_from_desired

对于采用了open addressing的hash实现,当插入发生冲突时,会以一定方式(如线性探测、平方探测等)来探测下一个可以插入的slot. 因而实际插入的slot位置与理想的slot位置通常不相同,这段距离定义为distance_from_desired 在没有冲突的理想情况下,所有distance_from_desired的值应该都为0 distance_from_desired的一种更常见的说法叫做probe sequence lengths(PSL)

robin hood hashing

robin hood hashing的核心思想是"劫富济贫” distance_from_desired小的slot被认为更"富有”,distance_from_desired大的slot被认为更"贫穷” 具体来说,当去插入一个新的元素时,如果当前位置的元素的distance_from_desired要比待插入元素的distance_from_desired要小,那么就将待插入元素放入当前位置,将当前位置的元素取出,寻找一个新的位置。

这样做使得所有元素的distance_from_desired的分布更为平均,variance更小。 这样的分布对cache更友好(几乎全部元素distance_from_desired都小于一个cache line的长度,因此在find的时候只需要fetch一次cache line),从而拥有更好的性能。

一般的robin hashing 在find时,一般用一个全局的最大distance_from_desired作为没有找到该元素终止条件。 一种常见的改进是,不维护全局最大distance_from_desired,而是在看到当前位置元素的distance_from_desired比要插入的元素的distance_from_desired小时终止。


  iterator find(const FindKey& key) {
    size_t index =
        hash_policy.index_for_hash(hash_object(key), num_slots_minus_one);
    EntryPointer it = entries + ptrdiff_t(index);
    for (int8_t distance = 0; it->distance_from_desired >= distance;
         ++distance, ++it) {
      if (compares_equal(key, it->value)) return {it};
    }
    return end();
  }

带上限的robin hashing

一般的robin hashing在insert时,会不断进行寻找(包括了可能的swap过程),直到找到一个空的slot为止。该过程在hash table较满时可能接近线性的时间复杂度。 ska::flat_hash_map对这一点的改进是,限制了insert时尝试的上限次数,作者给出的经验值为log(N),其中N为slots的个数。 这样保证每个slot的最大distance_from_desired不会超过log(N)

关键实现

emplace

插入一个元素 分析见注释 其中emplace 函数主要是查找是否已经存在该元素+调整到合适的插入位置 emplace_new_key函数执行真正的emplace操作


  template <typename Key, typename... Args>
  std::pair<iterator, bool> emplace(Key&& key, Args&&... args) {
    size_t index =
        hash_policy.index_for_hash(hash_object(key), num_slots_minus_one);
    EntryPointer current_entry = entries + ptrdiff_t(index);
    int8_t distance_from_desired = 0;
    // 插入前先查找是否存在。。。
    // 只需要查找有限的距离

    // trick在于。。初始 current_entry->distance_from_desired为-1
    // 此时不会进入for loop,直接进行emplace_new_key。
    // 该for loop有两层意义: 1.在index位置不为空时找到合适的位置: 空的slot或者更富有的slot(也就是current_entry->distance_from_desired < distance_from_desired的slot)
    // 2.在该过程中找一下是否已经插入了该值
    for (; current_entry->distance_from_desired >= distance_from_desired;
         ++current_entry, ++distance_from_desired) {
      if (compares_equal(key, current_entry->value))
        return {{current_entry}, false};
    }
    return emplace_new_key(distance_from_desired, current_entry,
                           std::forward<Key>(key), std::forward<Args>(args)...);
  }


  template <typename Key, typename... Args>
  SKA_NOINLINE(std::pair<iterator, bool>)
  emplace_new_key(int8_t distance_from_desired, EntryPointer current_entry,
                  Key&& key, Args&&... args) {
    using std::swap;
    // num_slots_minus_one初始值为0,表示第一次进行插入。。需要先进行grow..很合理。。
    // 如果得到max_load_factor,或者查找次数达到max_lookups,就进行rehash
    if (num_slots_minus_one == 0 || distance_from_desired == max_lookups ||
        num_elements + 1 >
            (num_slots_minus_one + 1) * static_cast<double>(_max_load_factor)) {
      grow();
      return emplace(std::forward<Key>(key), std::forward<Args>(args)...);
    } else if (current_entry->is_empty()) {
      current_entry->emplace(distance_from_desired, std::forward<Key>(key),
                             std::forward<Args>(args)...);
      ++num_elements;
      return {{current_entry}, true};
    }

    // 执行到这里,说明有更富有的slot。于是进行swap,转而为被换出的pair<key,value>找一个新的slot
    // to_insert是当前要插入的,由于swap的发生,可能并不是最初要插入的那一对值
    value_type to_insert(std::forward<Key>(key), std::forward<Args>(args)...);
    swap(distance_from_desired, current_entry->distance_from_desired);
    swap(to_insert, current_entry->value);
    iterator result = {current_entry};
    for (++distance_from_desired, ++current_entry;; ++current_entry) {
      if (current_entry->is_empty()) {
        // 如果被换过的slot后面某个slot是空的。。就直接放置了
        current_entry->emplace(distance_from_desired, std::move(to_insert));
        ++num_elements;
        return {result, true};
      } else if (current_entry->distance_from_desired < distance_from_desired) {
        // 在找新slot的过程中,仍然进行劫富济贫的操作, 转而为被换出的pair<key,value>找一个新的slot
        swap(distance_from_desired, current_entry->distance_from_desired);
        swap(to_insert, current_entry->value);
        ++distance_from_desired;
      } else {
        // 如果没有空的slot,也没有更富有的slot,那就只能继续往前寻找了,直到达到上限
        ++distance_from_desired;
        if (distance_from_desired == max_lookups) {
          // 如果找了max_lookups个位置还没找到,就进行rehash
          swap(to_insert, result.current->value);
          grow();
          return emplace(std::move(to_insert));
        }
      }
    }
  }

rehash

  void rehash(size_t num_buckets) {
    num_buckets = std::max(
        num_buckets,
        static_cast<size_t>(
            std::ceil(num_elements / static_cast<double>(_max_load_factor))));
    if (num_buckets == 0) {
      reset_to_empty_state();
      return;
    }
    auto new_prime_index = hash_policy.next_size_over(num_buckets);
    if (num_buckets == bucket_count()) return;
    int8_t new_max_lookups = compute_max_lookups(num_buckets);
    // 额外分配了max_lookups个slots,避免了find时bound checking的开销
    EntryPointer new_buckets(
        AllocatorTraits::allocate(*this, num_buckets + new_max_lookups));
    EntryPointer special_end_item =
        new_buckets + static_cast<ptrdiff_t>(num_buckets + new_max_lookups - 1);
    for (EntryPointer it = new_buckets; it != special_end_item; ++it)
      it->distance_from_desired = -1;
    special_end_item->distance_from_desired = Entry::special_end_value;
    std::swap(entries, new_buckets);
    std::swap(num_slots_minus_one, num_buckets);
    --num_slots_minus_one;
    hash_policy.commit(new_prime_index);
    int8_t old_max_lookups = max_lookups;
    max_lookups = new_max_lookups;
    num_elements = 0;
    // new_buckets其实是旧的entries..
    // num_buckets其实也是旧的值。。因为已经被swap了
    for (EntryPointer
             it = new_buckets,
             end = it + static_cast<ptrdiff_t>(num_buckets + old_max_lookups);
         it != end; ++it) {
      if (it->has_value()) {
        emplace(std::move(it->value));
        it->destroy_value();
      }
    }
    deallocate_data(new_buckets, num_buckets, old_max_lookups);
  }

一些其他trick

通过多分配log(N)的slot消除bound checking的开销

代码见上面的rehash部分 由于每个元素的最大distance_from_desired不会超过log(N),因此可以保证查找时不需要做Bound checking. 使得find部分的实现非常简洁:

使用素数个slot而不是2的整数次幂个slot

2的整数次幂个slot是一种很常见的实现。 这种实现的主要好处是在将hash转换为index时,避免了代价高昂取模操作,而是用代价很小的按位与(&)替代

但是使用2的整数次幂个slot的缺点是,取模后得到的结果较少,比起使用素数个slot更容易发生冲突。 ska::flat_hash_map的作者借鉴了boost::multi_index中的做法,将变量展开为compile time const(对compile time const做取模运算要远远快于变量的取模运算),从而减小了这部分开销的影响。


struct prime_number_hash_policy {
  static size_t mod0(size_t) {
    return 0llu;
  }
  static size_t mod2(size_t hash) {
    return hash % 2llu;
  }
  static size_t mod3(size_t hash) {
    return hash % 3llu;
  }
  static size_t mod5(size_t hash) {
    return hash % 5llu;
  }
  static size_t mod7(size_t hash) {
    return hash % 7llu;
  }
  static size_t mod11(size_t hash) {
    return hash % 11llu;
  }
  static size_t mod13(size_t hash) {
    return hash % 13llu;
  }
  static size_t mod17(size_t hash) {
    return hash % 17llu;
  }
  static size_t mod23(size_t hash) {
    return hash % 23llu;
  }
  static size_t mod29(size_t hash) {
    return hash % 29llu;
  }
  static size_t mod37(size_t hash) {
    return hash % 37llu;
  }
  static size_t mod47(size_t hash) {
    return hash % 47llu;
  }
  ...