Libs/TokenizerLib/src/Gpt2Tokenizer.cs

// This code is forked from microsoft/Tokenizer.
// The original code is licensed under the MIT License and can be downloaded from this link:
// https://github.com/microsoft/Tokenizer/blob/858c5155997237088f4f24d1b0f732ea84224215/Tokenizer_C%23/TokenizerLib/TikTokenizer.cs
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Reflection;
using PSOpenAI.TokenizerLib.Utils;
 
namespace PSOpenAI.TokenizerLib
{
    /// <summary>
    /// This is a C# implementation of OpenAI's tiktoken implementation of
    /// byte pair encoding (BPE): https://en.wikipedia.org/wiki/Byte_pair_encoding.
    /// The goal is to support context tokenization for OpenAI large language models
    /// in the .NET runtime.
    /// Reference: https://github.com/openai/tiktoken/blob/main/src/lib.rs
    /// </summary>
    public static class Gpt2Tokenizer
    {
        private static readonly string s_bpeFile = @"gpt2.tiktoken";
        private static readonly IReadOnlyDictionary<string, int> SpecialTokensEncoder = new Dictionary<string, int>{
            { "<|endoftext|>", 50256}
        };
        private static readonly Regex s_encodingRegex = new Regex(
            @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+",
            RegexOptions.Compiled);
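        // The pattern above is the GPT-2 pre-tokenization regex: it splits text into
        // contractions ('s, 't, 're, ...), runs of letters, runs of digits, and runs of
        // other symbols (each optionally preceded by a single space), plus whitespace.
        // For example, "Hello world!" splits into the pieces "Hello", " world", "!",
        // and each piece is then byte-pair encoded independently.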
        private static readonly IReadOnlyDictionary<byte[], int> Encoder;
        private static readonly IReadOnlyDictionary<int, byte[]> Decoder;
        private static readonly Regex SpecialTokensRegex = new Regex(string.Join("|", SpecialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
        private static readonly IReadOnlyDictionary<int, string> SpecialTokensDecoder = SpecialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
        // Maps a regex piece to its encoded token ids so repeated pieces skip the BPE step.
        private static readonly LruCache<string, int[]> Cache = new LruCache<string, int[]>(4096);
 
        // Load the BPE ranks once per process and build the inverse mapping for decoding.
        static Gpt2Tokenizer()
        {
            Encoder = ReadBpeFile();
            Decoder = Encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
        }
 
        // Load the BPE rank dictionary from a file. Each non-empty line has the form
        // "<base64-encoded token bytes> <rank>"; for example, a line like "IQ== 0"
        // would map the single byte of "!" to rank 0.
        private static Dictionary<byte[], int> ReadBpeFile()
        {
            var assemblyDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!;
            var bpePath = Path.Combine(assemblyDirectory, s_bpeFile);
            var bpeDict = new Dictionary<byte[], int>(new ByteArrayComparer());
            try
            {
                using (StreamReader reader = new StreamReader(bpePath))
                {
                    while (!reader.EndOfStream)
                    {
                        string? line = reader.ReadLine();
                        if (string.IsNullOrWhiteSpace(line))
                        {
                            continue;
                        }
 
                        var tokens = line.Split(' ');
                        if (tokens.Length != 2)
                        {
                            throw new FormatException("Invalid format in the BPE encoder file stream");
                        }
 
                        var tokenBytes = Convert.FromBase64String(tokens[0]);
                        if (int.TryParse(tokens[1], out int rank))
                        {
                            bpeDict[tokenBytes] = rank;
                        }
                        else
                        {
                            throw new FormatException($"Can't parse '{tokens[1]}' to an integer");
                        }
                    }
                }
            }
            catch (Exception ex)
            {
                throw new InvalidOperationException($"Failed to load from BPE encoder file stream: {ex.Message}", ex);
            }
 
            return bpeDict;
        }
 
        // Encode a string with no special tokens allowed.
        public static List<int> Encode(string text)
        {
            return Encode(text, Array.Empty<string>());
        }
 
        // Encode a string with a set of allowed special tokens that are not broken apart.
        public static List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
        {
            var tokenIds = new List<int>();
            int start = 0;
            while (true)
            {
                Match nextSpecial;
                int end;
                FindNextSpecialToken(text, allowedSpecial, start, out nextSpecial, out end);
                if (end > start)
                {
                    Encode(text, tokenIds, start, end);
                }
 
                if (nextSpecial.Success)
                {
                    start = EncodeSpecialToken(tokenIds, nextSpecial);
                    if (start >= text.Length)
                    {
                        break;
                    }
                }
                else
                {
                    break;
                }
            }
 
            return tokenIds;
        }
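
        // Usage sketch for Encode (illustrative; individual ids depend on the BPE file,
        // but "<|endoftext|>" is defined above to encode to the single id 50256):
        //
        //   var ids = Gpt2Tokenizer.Encode("Hello world<|endoftext|>",
        //       new[] { "<|endoftext|>" });
        //   // "Hello world" is BPE-encoded piece by piece, then the allowed special
        //   // token is appended as the single id 50256.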
 
        // Encode a special token and return the index just past its match.
        private static int EncodeSpecialToken(List<int> tokenIds, Match nextSpecial)
        {
            var token = SpecialTokensEncoder[nextSpecial.Value];
            tokenIds.Add(token);
            return nextSpecial.Index + nextSpecial.Length;
        }
 
        // Search for the next allowed special token in the text starting at `start`.
        // Matches not present in allowedSpecial are skipped (they will be encoded as
        // ordinary text); `end` marks where ordinary-text encoding should stop.
        private static void FindNextSpecialToken(string text, IReadOnlyCollection<string> allowedSpecial, int start, out Match nextSpecial, out int end)
        {
            int startFind = start;
            while (true)
            {
                nextSpecial = SpecialTokensRegex.Match(text, startFind);
                if (!nextSpecial.Success) break;
                if (allowedSpecial.Contains(nextSpecial.Value)) break;
                startFind = nextSpecial.Index + 1;
            }
            end = nextSpecial.Success ? nextSpecial.Index : text.Length;
        }
 
        // Encode the substring of text between the start and end indexes.
        private static void Encode(string text, List<int> tokenIds, int start, int end)
        {
            foreach (Match match in s_encodingRegex.Matches(text[start..end]))
            {
                if (Cache.Lookup(match.Value, out int[] tokens))
                {
                    tokenIds.AddRange(tokens);
                }
                else
                {
                    // Cache miss: try a whole-piece vocabulary hit first, then fall back to BPE.
                    var bytes = Encoding.UTF8.GetBytes(match.Value);
                    if (Encoder.TryGetValue(bytes, out int token))
                    {
                        tokenIds.Add(token);
                    }
                    else
                    {
                        var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
                        tokenIds.AddRange(encodedTokens);
                        Cache.Add(match.Value, encodedTokens.ToArray());
                    }
                }
            }
        }
 
        // Encode a string from the start to the end index, stopping once maxTokenCount
        // is reached. Returns the running token count and the encoded character length.
        private static (int TokenCount, int EncodeLength) EncodeTrimSuffix(string text, List<int> tokenIds, int start, int end, int maxTokenCount, int tokenCount, int encodeLength)
        {
            foreach (Match match in s_encodingRegex.Matches(text[start..end]))
            {
                var piece = match.Value;
                if (Cache.Lookup(piece, out int[] tokens))
                {
                    tokenCount += tokens.Length;
                    if (tokenCount <= maxTokenCount)
                    {
                        encodeLength += piece.Length;
                        tokenIds.AddRange(tokens);
                    }
                    else
                    {
                        break;
                    }
                }
                else
                {
                    // Cache miss: try a whole-piece vocabulary hit first, then fall back to BPE.
                    var bytes = Encoding.UTF8.GetBytes(piece);
                    if (Encoder.TryGetValue(bytes, out int token))
                    {
                        tokenCount++;
                        if (tokenCount <= maxTokenCount)
                        {
                            encodeLength += piece.Length;
                            tokenIds.Add(token);
                        }
                        else
                        {
                            break;
                        }
                    }
                    else
                    {
                        var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
                        Cache.Add(piece, encodedTokens.ToArray());
                        tokenCount += encodedTokens.Count;
                        if (tokenCount <= maxTokenCount)
                        {
                            encodeLength += piece.Length;
                            tokenIds.AddRange(encodedTokens);
                        }
                        else
                        {
                            break;
                        }
                    }
                }
                if (tokenCount >= maxTokenCount) break;
            }
            return (tokenCount, encodeLength);
        }
 
        // Encode a piece of text limited to a max token count by trimming its suffix.
        // Returns the token ids and the leading portion of the text they cover.
        public static (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
        {
            var tokenIds = new List<int>();
 
            int start = 0;
            int tokenCount = 0;
            var encodeLength = 0;
            while (true)
            {
                Match nextSpecial;
                int end;
                FindNextSpecialToken(text, allowedSpecial, start, out nextSpecial, out end);
 
                if (end > start)
                {
                    (tokenCount, encodeLength) = EncodeTrimSuffix(text, tokenIds, start, end, maxTokenCount, tokenCount, encodeLength);
 
                    if (tokenCount >= maxTokenCount)
                    {
                        break;
                    }
                }
 
                if (nextSpecial.Success)
                {
                    tokenCount++;
                    if (tokenCount <= maxTokenCount)
                    {
                        start = EncodeSpecialToken(tokenIds, nextSpecial);
                        encodeLength += nextSpecial.Value.Length;
                        if (start >= text.Length)
                        {
                            break;
                        }
                    }
                    if (tokenCount >= maxTokenCount)
                    {
                        break;
                    }
                }
                else
                {
                    break;
                }
            }
 
            var encodedText = encodeLength == text.Length ? text : text[..encodeLength];
 
            return (tokenIds, encodedText);
        }
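
        // Usage sketch for EncodeTrimSuffix (illustrative; `longPrompt` is a hypothetical
        // variable, and the exact cut point depends on piece boundaries):
        //
        //   var (ids, encoded) = Gpt2Tokenizer.EncodeTrimSuffix(
        //       longPrompt, new[] { "<|endoftext|>" }, maxTokenCount: 1024);
        //   // ids holds at most 1024 token ids, and encoded is the leading part of
        //   // longPrompt that those ids actually cover; the rest is trimmed away.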
 
        // Encode a piece of text limited to a max token count by trimming its prefix.
        // Returns the token ids and the trailing portion of the text they cover.
        public static (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
        {
            var tokenIds = new List<int>();
 
            int start = 0;
            int tokenCount = 0;
            var encodeLength = 0;
            var tokenCountMap = new SortedDictionary<int, int>();
            tokenCountMap.Add(tokenCount, encodeLength);
            while (true)
            {
                Match nextSpecial;
                int end;
                FindNextSpecialToken(text, allowedSpecial, start, out nextSpecial, out end);
 
                if (end > start)
                {
                    foreach (Match match in s_encodingRegex.Matches(text[start..end]))
                    {
                        var piece = match.Value;
 
                        if (Cache.Lookup(piece, out int[] tokens))
                        {
                            tokenCount += tokens.Length;
                            encodeLength += piece.Length;
                            tokenIds.AddRange(tokens);
                            tokenCountMap[tokenCount] = encodeLength;
                        }
                        else
                        {
                            var bytes = Encoding.UTF8.GetBytes(piece);
                            if (Encoder.TryGetValue(bytes, out int token))
                            {
                                tokenCount++;
                                encodeLength += piece.Length;
                                tokenIds.Add(token);
                                tokenCountMap[tokenCount] = encodeLength;
                            }
                            else
                            {
                                var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
                                Cache.Add(piece, encodedTokens.ToArray());
                                tokenCount += encodedTokens.Count;
                                encodeLength += piece.Length;
                                tokenIds.AddRange(encodedTokens);
                                tokenCountMap[tokenCount] = encodeLength;
                            }
                        }
                    }
                }
 
                if (nextSpecial.Success)
                {
                    start = EncodeSpecialToken(tokenIds, nextSpecial);
                    tokenCount++;
                    encodeLength += nextSpecial.Value.Length;
                    tokenCountMap[tokenCount] = encodeLength;
                    if (start >= text.Length)
                    {
                        break;
                    }
                }
                else
                {
                    break;
                }
            }
 
            if (tokenCount <= maxTokenCount)
            {
                return (tokenIds, text);
            }
 
            // More tokens than allowed: at least (tokenCount - maxTokenCount) tokens must
            // be dropped from the front. tokenCountMap records the encoded character length
            // at each cumulative token count, so the smallest recorded count at or above
            // that prefix gives a cut point that never splits a BPE piece.
            var prefixTokenCount = tokenCount - maxTokenCount;
            var actualPrefixTokenCount = 0;
            var actualPrefixStrLength = 0;
            foreach (var pair in tokenCountMap)
            {
                if (pair.Key >= prefixTokenCount)
                {
                    actualPrefixTokenCount = pair.Key;
                    actualPrefixStrLength = pair.Value;
                    break;
                }
            }
 
            return (tokenIds.Skip(actualPrefixTokenCount).ToList(), text[actualPrefixStrLength..]);
        }
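
        // Usage sketch for EncodeTrimPrefix (illustrative; `chatHistory` is a hypothetical
        // variable, and the exact cut point depends on piece boundaries):
        //
        //   var (ids, kept) = Gpt2Tokenizer.EncodeTrimPrefix(
        //       chatHistory, new[] { "<|endoftext|>" }, maxTokenCount: 1024);
        //   // ids holds at most 1024 token ids covering the tail of chatHistory;
        //   // kept is that trailing substring, with the oldest text trimmed away.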
 
        // Decode an array of integer token ids back into a string; unknown ids are skipped.
        public static string Decode(int[] tokens)
        {
            var decoded = new List<byte>(tokens.Length * 2);
            foreach (var token in tokens)
            {
                if (Decoder.TryGetValue(token, out var value))
                {
                    decoded.AddRange(value);
                }
                else if (SpecialTokensDecoder.TryGetValue(token, out var specialTokenValue))
                {
                    decoded.AddRange(Encoding.UTF8.GetBytes(specialTokenValue));
                }
            }
 
            return Encoding.UTF8.GetString(decoded.ToArray());
        }
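
        // Usage sketch for Decode: because Decoder is the exact inverse of Encoder,
        // decoding the ids produced by Encode reproduces the original text:
        //
        //   var ids = Gpt2Tokenizer.Encode("Hello world");
        //   var text = Gpt2Tokenizer.Decode(ids.ToArray());
        //   // text == "Hello world"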
    }
}