Libs/TokenizerLib/src/Utils/LRUCache.cs
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Collections.Generic;

namespace PSOpenAI.TokenizerLib.Utils
{
    internal class LruCache<TKey, TValue>
    {
        /// <summary>
        /// The default LRU cache size.
        /// </summary>
        public const int DefaultCacheSize = 4096;

        private readonly object _lockObject = new object();

        private class CacheItem
        {
            public readonly TKey Key;
            public TValue Value;

            public CacheItem(TKey key, TValue value)
            {
                Key = key;
                Value = value;
            }
        }

        private readonly Dictionary<TKey, LinkedListNode<CacheItem>> _cache;
        private readonly LinkedList<CacheItem> _lruList;
        private readonly int _cacheSize;

        /// <summary>
        /// Constructs an <see cref="LruCache{TKey,TValue}" /> object.
        /// </summary>
        /// <param name="cacheSize">
        /// The maximum number of <typeparamref name="TKey" /> to <typeparamref name="TValue" /> mappings
        /// that can be cached. This defaults to <see cref="DefaultCacheSize" />, which is set to
        /// <value>4096</value>.
        /// </param>
        public LruCache(int cacheSize = DefaultCacheSize)
        {
            _cache = new Dictionary<TKey, LinkedListNode<CacheItem>>();
            _lruList = new LinkedList<CacheItem>();
            _cacheSize = cacheSize;
        }

        /// <summary>
        /// Retrieves the value associated with the specified <paramref name="key" /> object.
        /// </summary>
        /// <param name="key">The object to be used as a key.</param>
        /// <param name="value">
        /// An <code>out</code> parameter that is set to the value mapped to <paramref name="key" /> if
        /// the cache contains a mapping for <paramref name="key" />.
        /// </param>
        /// <returns>
        /// <code>true</code> if the cache contains a mapping for <paramref name="key" />, <code>false</code> otherwise.
        /// </returns>
        public bool Lookup(TKey key, out TValue value)
        {
            lock (_lockObject)
            {
                LinkedListNode<CacheItem> cached;
                if (_cache.TryGetValue(key, out cached))
                {
                    // Move the entry to the front of the list so it becomes the most recently used.
                    _lruList.Remove(cached);
                    _lruList.AddFirst(cached);
                    value = cached.Value.Value;
                    return true;
                }

                value = default!;
                return false;
            }
        }

        /// <summary>
        /// Called for each value evicted from the cache. Derived classes can override this
        /// to dispose or otherwise clean up evicted values.
        /// </summary>
        /// <param name="evictedValue">The value that was removed from the cache.</param>
        protected virtual void OnEviction(TValue evictedValue)
        {
        }

        private void EvictIfNeeded()
        {
            while (_cache.Count >= _cacheSize)
            {
                // The least recently used entry is at the tail of the list.
                LinkedListNode<CacheItem> nodeToEvict = _lruList.Last;
                _lruList.RemoveLast();
                _cache.Remove(nodeToEvict.Value.Key);
                OnEviction(nodeToEvict.Value.Value);
            }
        }

        /// <summary>
        /// Adds or replaces a mapping in the cache.
        /// </summary>
        /// <param name="key">The key whose mapped <paramref name="value" /> is to be created or replaced.</param>
        /// <param name="value">The new value to be mapped to the <paramref name="key" />.</param>
        public void Add(TKey key, TValue value) => Replace(key, value, out _);

        /// <summary>
        /// Adds or replaces a mapping in the cache, returning any previously mapped value.
        /// </summary>
        /// <param name="key">The key whose mapped <paramref name="value" /> is to be created or replaced.</param>
        /// <param name="value">The new value to be mapped to the <paramref name="key" />.</param>
        /// <param name="oldValue">The value previously mapped to <paramref name="key" />, if one existed.</param>
        /// <returns><code>true</code> if an existing mapping was replaced, <code>false</code> if a new mapping was added.</returns>
        public bool Replace(TKey key, TValue value, out TValue oldValue)
        {
            lock (_lockObject)
            {
                return ReplaceInternal(key, value, out oldValue);
            }
        }

        private bool ReplaceInternal(TKey key, TValue value, out TValue oldValue)
        {
            if (_cache.TryGetValue(key, out LinkedListNode<CacheItem> cached))
            {
                oldValue = cached.Value.Value;
                cached.Value.Value = value;
                _lruList.Remove(cached);
                _lruList.AddFirst(cached);
                return true;
            }

            EvictIfNeeded();
            var node = new LinkedListNode<CacheItem>(new CacheItem(key, value));
            _cache[key] = node;
            _lruList.AddFirst(node);
            oldValue = default!;
            return false;
        }

        /// <summary>
        /// The number of entries currently present in the cache.
        /// </summary>
        public int Count => _cache.Count;

        /// <summary>
        /// Clears the contents of this cache.
        /// </summary>
        public void Clear()
        {
            // Take the same lock as the other mutating operations so a concurrent
            // Lookup/Replace cannot observe the dictionary and list out of sync.
            lock (_lockObject)
            {
                _cache.Clear();
                _lruList.Clear();
            }
        }
    }
}
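
// Usage sketch (illustrative only, not part of the library): the key and value types below are
// hypothetical; the cache behaves the same for any TKey/TValue pair, e.g. mapping a piece of text
// to its previously computed token ids so repeated inputs skip re-encoding.
//
//     var cache = new LruCache<string, int[]>(cacheSize: 2);
//     cache.Add("hello", new[] { 1 });
//     cache.Add("world", new[] { 2 });
//
//     cache.Lookup("hello", out int[] ids);   // true; also marks "hello" as most recently used
//     cache.Add("again", new[] { 3 });        // evicts "world", the least recently used entry
//     cache.Lookup("world", out _);           // false: "world" was evicted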