diff --git a/concurrent_map.go b/concurrent_map.go index baccab0..f789c9c 100644 --- a/concurrent_map.go +++ b/concurrent_map.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "sync" + "sync/atomic" ) var SHARD_COUNT = 32 @@ -18,6 +19,7 @@ type Stringer interface { type ConcurrentMap[K comparable, V any] struct { shards []*ConcurrentMapShared[K, V] sharding func(key K) uint32 + count *int32 } // A "thread" safe string to anything map. @@ -29,6 +31,7 @@ type ConcurrentMapShared[K comparable, V any] struct { func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] { m := ConcurrentMap[K, V]{ sharding: sharding, + count: new(int32), shards: make([]*ConcurrentMapShared[K, V], SHARD_COUNT), } for i := 0; i < SHARD_COUNT; i++ { @@ -59,10 +62,7 @@ func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] { func (m ConcurrentMap[K, V]) MSet(data map[K]V) { for key, value := range data { - shard := m.GetShard(key) - shard.Lock() - shard.items[key] = value - shard.Unlock() + m.Set(key, value) } } @@ -70,9 +70,16 @@ func (m ConcurrentMap[K, V]) MSet(data map[K]V) { func (m ConcurrentMap[K, V]) Set(key K, value V) { // Get map shard. shard := m.GetShard(key) + b := len(shard.items) shard.Lock() shard.items[key] = value + // if add an item added is true + added := len(shard.items) != b shard.Unlock() + if added { + atomic.AddInt32(m.count, 1) + } + } // Callback to return new element to be inserted into the map @@ -102,6 +109,9 @@ func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool { shard.items[key] = value } shard.Unlock() + if !ok { + atomic.AddInt32(m.count, 1) + } return !ok } @@ -118,14 +128,7 @@ func (m ConcurrentMap[K, V]) Get(key K) (V, bool) { // Count returns the number of elements within the map. func (m ConcurrentMap[K, V]) Count() int { - count := 0 - for i := 0; i < SHARD_COUNT; i++ { - shard := m.shards[i] - shard.RLock() - count += len(shard.items) - shard.RUnlock() - } - return count + return int(atomic.LoadInt32(m.count)) } // Looks up an item under specified key @@ -143,9 +146,13 @@ func (m ConcurrentMap[K, V]) Has(key K) bool { func (m ConcurrentMap[K, V]) Remove(key K) { // Try to get shard. shard := m.GetShard(key) + if _, ok := m.Get(key); !ok { + return + } shard.Lock() delete(shard.items, key) shard.Unlock() + atomic.AddInt32(m.count, -1) } // RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held diff --git a/concurrent_map_bench_test.go b/concurrent_map_bench_test.go index 50cd075..8869c7d 100644 --- a/concurrent_map_bench_test.go +++ b/concurrent_map_bench_test.go @@ -292,7 +292,6 @@ func BenchmarkMultiGetSetBlock_256_Shard(b *testing.B) { runWithShards(benchmarkMultiGetSetBlock, b, 256) } - func GetSet[K comparable, V any](m ConcurrentMap[K, V], finished chan struct{}) (set func(key K, value V), get func(key K, value V)) { return func(key K, value V) { for i := 0; i < 10; i++ { @@ -341,3 +340,39 @@ func BenchmarkKeys(b *testing.B) { m.Keys() } } + +func BenchmarkCount(b *testing.B) { + m := New[Animal]() + + // Insert 100 elements. + for i := 0; i < 10000; i++ { + m.Set(strconv.Itoa(i), Animal{strconv.Itoa(i)}) + } + for i := 0; i < b.N; i++ { + m.Count() + } +} + +func BenchmarkRemoveExists(b *testing.B) { + m := New[Animal]() + + // Insert 100 elements. + for i := 0; i < 10000; i++ { + m.Set(strconv.Itoa(i), Animal{strconv.Itoa(i)}) + } + for i := 0; i < b.N; i++ { + m.Remove(strconv.Itoa(i)) + } +} + +func BenchmarkRemoveNotExists(b *testing.B) { + m := New[Animal]() + + // Insert 100 elements. + for i := 0; i < 10000; i++ { + m.Set(strconv.Itoa(i), Animal{strconv.Itoa(i)}) + } + for i := 0; i < b.N; i++ { + m.Remove(strconv.Itoa(0)) + } +}