diff --git a/include/crucible/cache.h b/include/crucible/cache.h index 60d2a7f..767b790 100644 --- a/include/crucible/cache.h +++ b/include/crucible/cache.h @@ -18,7 +18,7 @@ namespace crucible { public: using Key = tuple; using Func = function; - using Time = unsigned; + using Time = size_t; using Value = pair; private: Func m_fn; @@ -28,7 +28,7 @@ namespace crucible { size_t m_max_size; mutex m_mutex; - void check_overflow(); + bool check_overflow(); public: LRUCache(Func f = Func(), size_t max_size = 100); @@ -52,21 +52,24 @@ namespace crucible { } template - void + bool LRUCache::check_overflow() { - if (m_map.size() <= m_max_size) return; - vector> map_contents; - map_contents.reserve(m_map.size()); - for (auto i : m_map) { - map_contents.push_back(make_pair(i.first, i.second.first)); + if (m_map.size() <= m_max_size) { + return false; } - sort(map_contents.begin(), map_contents.end(), [](const pair &a, const pair &b) { + vector> key_times; + key_times.reserve(m_map.size()); + for (auto i : m_map) { + key_times.push_back(make_pair(i.first, i.second.first)); + } + sort(key_times.begin(), key_times.end(), [](const pair &a, const pair &b) { return a.second < b.second; }); - for (size_t i = 0; i < map_contents.size() / 2; ++i) { - m_map.erase(map_contents[i].first); + for (size_t i = 0; i < key_times.size() / 2; ++i) { + m_map.erase(key_times[i].first); } + return true; } template @@ -141,9 +144,14 @@ namespace crucible { // We hold a lock on this key so we are the ones to insert it THROW_CHECK0(runtime_error, inserted); - // Release key lock and clean out overflow + // Release key lock, keep the cache lock key_lock.unlock(); - check_overflow(); + + // Check to see if we have too many items and reduce if so. + if (check_overflow()) { + // Reset iterator + found = m_map.find(k); + } } } @@ -207,7 +215,12 @@ namespace crucible { // Release key lock and clean out overflow key_lock.unlock(); - check_overflow(); + + // Check to see if we have too many items and reduce if so. + if (check_overflow()) { + // Reset iterator + found = m_map.find(k); + } } }