Libs/TokenizerLib/src/Utils/BytePairEncoder.cs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.Collections.Generic;
 
namespace PSOpenAI.TokenizerLib.Utils
{
    /// <summary>
    /// This class implements the byte pair encoding algorithm.
    /// </summary>
    internal class BytePairEncoder
    {
 
        public static List<int> BytePairEncode(byte[] mergingBytes, IReadOnlyDictionary<byte[], int> ranks)
        {
            if (mergingBytes.Length == 1)
            {
                return new List<int> { ranks[mergingBytes] };
            }
 
            var byteIndicesAndRanks = new List<(int Index, int Rank)>();
            for (int i = 0; i < mergingBytes.Length + 1; i++)
            {
                byteIndicesAndRanks.Add((i, int.MaxValue));
            }
            int GetRank(int startIndex, int skip = 0)
            {
                if (startIndex + skip + 2 < byteIndicesAndRanks.Count)
                {
                    var slice = mergingBytes[byteIndicesAndRanks[startIndex].Index..byteIndicesAndRanks[startIndex + skip + 2].Index];
                    if (ranks.TryGetValue(slice, out var rank))
                    {
                        return rank;
                    }
                }
                return int.MaxValue;
            }
            for (int i = 0; i < byteIndicesAndRanks.Count - 2; i++)
            {
                var rank = GetRank(i);
                if (rank != int.MaxValue)
                {
                    byteIndicesAndRanks[i] = (byteIndicesAndRanks[i].Index, rank);
                }
            }
            while (byteIndicesAndRanks.Count > 1)
            {
                var minRank = (Index: 0, Rank: int.MaxValue);
                for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++)
                {
                    if (byteIndicesAndRanks[i].Rank < minRank.Rank)
                    {
                        minRank = (i, byteIndicesAndRanks[i].Rank);
                    }
                }
                if (minRank.Rank != int.MaxValue)
                {
                    int j = minRank.Index;
                    byteIndicesAndRanks[j] = (byteIndicesAndRanks[j].Index, GetRank(j, 1));
                    if (j > 0)
                    {
                        byteIndicesAndRanks[j - 1] = (byteIndicesAndRanks[j - 1].Index, GetRank(j - 1, 1));
                    }
                    byteIndicesAndRanks.RemoveAt(j + 1);
                }
                else
                {
                    break;
                }
            }
            var outList = new List<int>(byteIndicesAndRanks.Count - 1);
            for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++)
            {
                outList.Add(ranks[mergingBytes[byteIndicesAndRanks[i].Index..byteIndicesAndRanks[i + 1].Index]]);
            }
            return outList;
        }
    }
}