// RecordRealTimeSpeechToText.cs

using System.Management.Automation;
using NAudio.Wave;
using Whisper.net;
using Whisper.net.Ggml;
using System.Management;
using System.Collections.Concurrent;
 
namespace GenXdev.Helpers
{
    /// <summary>
    /// Records audio from the default recording device (or, with
    /// <c>-UseDesktopAudioCapture</c>, the desktop audio loopback) and transcribes it
    /// with the Whisper large-v3-turbo GGML model while recording is in progress.
    /// Each recognized segment is written to the pipeline as text, or as the
    /// Whisper.net segment object when <c>-Passthru</c> is given.
    /// </summary>
    /// <remarks>
    /// Recording stops when Q is pressed, when <c>-MaxDuration</c> has elapsed, when
    /// <c>-MaxDurationOfSilence</c> of continuous silence has been measured, when the
    /// capture device reports a failure, or when the pipeline is stopped (Ctrl+C).
    /// </remarks>
    [Cmdlet(VerbsCommunications.Receive, "RealTimeSpeechToText")]
    public class RecordRealTimeSpeechToText : Cmdlet
    {
        #region Cmdlet Parameters
        [Parameter(Mandatory = true, HelpMessage = "Path to the model file, or to the directory where ggml-largeV3Turbo.bin is stored or will be downloaded")]
        public string ModelFilePath { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to use desktop audio capture instead of microphone")]
        public SwitchParameter UseDesktopAudioCapture { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Returns objects instead of strings")]
        public SwitchParameter Passthru { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to include token timestamps")]
        public SwitchParameter WithTokenTimestamps { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Sum threshold for token timestamps, defaults to 0.5")]
        public float TokenTimestampsSumThreshold { get; set; } = 0.5f;

        [Parameter(Mandatory = false, HelpMessage = "Whether to split on word boundaries")]
        public SwitchParameter SplitOnWord { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Maximum number of tokens per segment")]
        public int? MaxTokensPerSegment { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to ignore silence")]
        public SwitchParameter IgnoreSilence { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Maximum duration of silence before stopping")]
        public TimeSpan? MaxDurationOfSilence { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Silence detect threshold (0..32767 defaults to 30)")]
        [ValidateRange(0, 32767)]
        public int? SilenceThreshold { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Language to detect, defaults to 'en'")]
        public string Language { get; set; } = "en";

        [Parameter(Mandatory = false, HelpMessage = "Number of CPU threads, defaults to 0 (auto)")]
        public int CpuThreads { get; set; } = 0;

        [Parameter(Mandatory = false, HelpMessage = "Temperature for speech generation")]
        [ValidateRange(0, 1)]
        public float? Temperature { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Temperature increment")]
        [ValidateRange(0, 1)]
        public float? TemperatureInc { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to translate the output")]
        public SwitchParameter WithTranslate { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Prompt to use for the model")]
        public string Prompt { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Regex to suppress tokens from the output")]
        public string SuppressRegex { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to show progress")]
        public SwitchParameter WithProgress { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Size of the audio context")]
        public int? AudioContextSize { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to NOT suppress blank lines")]
        public SwitchParameter DontSuppressBlank { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Maximum duration of the audio")]
        public TimeSpan? MaxDuration { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Offset for the audio")]
        public TimeSpan? Offset { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Maximum number of last text tokens")]
        public int? MaxLastTextTokens { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to use single segment only")]
        public SwitchParameter SingleSegmentOnly { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Whether to print special tokens")]
        public SwitchParameter PrintSpecialTokens { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Maximum segment length")]
        public int? MaxSegmentLength { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Start timestamps at this moment")]
        public TimeSpan? MaxInitialTimestamp { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Length penalty")]
        [ValidateRange(0, 1)]
        public float? LengthPenalty { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Entropy threshold")]
        [ValidateRange(0, 1)]
        public float? EntropyThreshold { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Log probability threshold")]
        [ValidateRange(0, 1)]
        public float? LogProbThreshold { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "No speech threshold")]
        [ValidateRange(0, 1)]
        public float? NoSpeechThreshold { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Don't use context")]
        public SwitchParameter NoContext { get; set; }

        [Parameter(Mandatory = false, HelpMessage = "Use beam search sampling strategy")]
        public SwitchParameter WithBeamSearchSamplingStrategy { get; set; }
        #endregion

        /// <summary>Sample rate, in Hz, at which audio is captured and fed to Whisper.</summary>
        private const int WhisperSampleRate = 16000;

        /// <summary>
        /// Amount of audio collected before each transcription pass. Whisper needs some
        /// context to transcribe well, so audio is batched; a few seconds trades latency
        /// for accuracy. Words spanning a chunk boundary may be split.
        /// </summary>
        private const double ProcessingWindowSeconds = 3.0;

        /// <summary>Interval at which the pipeline thread polls for keys and queued output.</summary>
        private const int PollIntervalMs = 100;

        /// <summary>Upper bound on waiting for the capture device to report that it stopped.</summary>
        private static readonly TimeSpan RecordingStopTimeout = TimeSpan.FromSeconds(2);

        /// <summary>Model that is loaded, and downloaded when missing.</summary>
        private const GgmlType ModelType = GgmlType.LargeV3Turbo;

        /// <summary>File name of <see cref="ModelType"/> inside a model directory.</summary>
        private const string ModelFileName = "ggml-largeV3Turbo.bin";

        // Segments and progress percentages produced on the transcription task. PowerShell
        // only permits WriteObject/WriteProgress on the pipeline thread, so they are queued
        // here and written by WritePendingOutput.
        private readonly ConcurrentQueue<SegmentData> _results = new();
        private readonly ConcurrentQueue<int> _progress = new();

        // Graceful stop: recording ends, already captured audio is still transcribed.
        private CancellationTokenSource _cts;

        // Immediate abort (pipeline stopped or failure): pending audio is discarded.
        private CancellationTokenSource _abortCts;

        private WhisperProcessor _processor;

        protected override void BeginProcessing()
        {
            base.BeginProcessing();
            _cts = new CancellationTokenSource();
            _abortCts = new CancellationTokenSource();
        }

        /// <summary>
        /// Ensures the model is present (downloading it if needed), builds the Whisper
        /// processor and runs the record/transcribe loop.
        /// </summary>
        protected override void ProcessRecord()
        {
            base.ProcessRecord();

            var modelFileName = ResolveModelFileName();
            if (!File.Exists(modelFileName))
            {
                DownloadModel(modelFileName, ModelType).GetAwaiter().GetResult();
            }

            // The factory is declared first so it is disposed last: the processor it builds
            // is disposed in the finally block below, while the factory is still alive.
            using var whisperFactory = WhisperFactory.FromPath(modelFileName);
            _processor = ConfigureWhisperBuilder(whisperFactory.CreateBuilder()).Build();
            try
            {
                RecordAndTranscribe();
            }
            finally
            {
                _processor.Dispose();
                _processor = null;
            }
        }

        /// <summary>
        /// Returns the full path of the model file. <see cref="ModelFilePath"/> may name an
        /// existing model file directly, or a directory that holds (or will receive)
        /// <see cref="ModelFileName"/>.
        /// </summary>
        private string ResolveModelFileName()
        {
            var fullPath = Path.GetFullPath(ModelFilePath);
            return File.Exists(fullPath) ? fullPath : Path.Combine(fullPath, ModelFileName);
        }

        /// <summary>
        /// Captures audio until a stop condition is met, transcribes it on a background task
        /// and writes the results to the pipeline. Returns only after the background task has
        /// finished, so the caller may safely dispose <see cref="_processor"/> afterwards.
        /// </summary>
        private void RecordAndTranscribe()
        {
            var bufferQueue = new ConcurrentQueue<byte[]>();
            var silenceTracker = MaxDurationOfSilence.HasValue
                ? new SilenceTracker(SilenceThreshold ?? 30, WhisperSampleRate)
                : null;
            Exception captureError = null;

            // Declared before the capture device so they are disposed after it
            // (using-declarations dispose in reverse order) and remain valid for its handlers.
            using var inputCompleted = new ManualResetEventSlim(false);
            using var recordingStopped = new ManualResetEventSlim(false);
            using IWaveIn waveIn = UseDesktopAudioCapture ? new WasapiLoopbackCapture() : new WaveInEvent();
            // NOTE(review): WASAPI loopback usually captures in the device mix format; confirm
            // that the NAudio version in use honours this 16 kHz mono 16-bit request.
            waveIn.WaveFormat = new WaveFormat(WhisperSampleRate, 1);

            waveIn.DataAvailable += (sender, args) =>
            {
                if (args.BytesRecorded <= 0)
                {
                    return;
                }

                // Queue a copy of the recorded part only; the capture buffer is reused.
                bufferQueue.Enqueue(args.Buffer[0..args.BytesRecorded]);

                if (silenceTracker != null &&
                    silenceTracker.Add(args.Buffer, args.BytesRecorded) > MaxDurationOfSilence.Value.TotalSeconds)
                {
                    _cts.Cancel();
                }
            };

            waveIn.RecordingStopped += (sender, args) =>
            {
                // A device failure ends the capture; stop instead of polling forever.
                if (args.Exception != null)
                {
                    captureError = args.Exception;
                    _cts.Cancel();
                }

                recordingStopped.Set();
            };

            var processingTask = Task.Run(() => ProcessAudioBuffer(bufferQueue, inputCompleted));
            var finishedNormally = false;
            try
            {
                try
                {
                    waveIn.StartRecording();
                    Console.WriteLine("Recording started. Press Q to stop...");
                    WaitForStopRequest(processingTask);
                }
                finally
                {
                    waveIn.StopRecording();
                    // Let the capture thread hand over its last buffer before declaring the
                    // input complete, so the trailing audio is still transcribed.
                    recordingStopped.Wait(RecordingStopTimeout);
                    inputCompleted.Set();
                }

                // Keep emitting output while the final (partial) chunk is transcribed.
                while (!processingTask.IsCompleted)
                {
                    WritePendingOutput();
                    Thread.Sleep(PollIntervalMs);
                }

                WritePendingOutput();
                finishedNormally = true;
            }
            finally
            {
                if (!finishedNormally)
                {
                    // The processor must not be disposed while the task still uses it. The
                    // exception that brought us here is already propagating, so a secondary
                    // failure of the task is intentionally dropped.
                    _abortCts.Cancel();
                    try
                    {
                        processingTask.Wait();
                    }
                    catch (AggregateException)
                    {
                    }
                }
            }

            // Rethrows a transcription failure with its original exception type.
            processingTask.GetAwaiter().GetResult();

            if (captureError != null)
            {
                WriteError(new ErrorRecord(captureError, "AudioCaptureFailed", ErrorCategory.DeviceError, null));
            }
        }

        /// <summary>
        /// Polls on the pipeline thread until Q is pressed, <see cref="MaxDuration"/> elapses,
        /// a stop is requested through <see cref="_cts"/>, or the transcription task ends.
        /// Writes queued output while waiting, then signals the stop through <see cref="_cts"/>.
        /// </summary>
        private void WaitForStopRequest(Task processingTask)
        {
            var startTime = System.DateTime.UtcNow;
            // Console.KeyAvailable throws when stdin is redirected (non-interactive hosts).
            var canReadKeys = !Console.IsInputRedirected;

            while (!_cts.IsCancellationRequested)
            {
                if (canReadKeys && Console.KeyAvailable && Console.ReadKey(true).Key == ConsoleKey.Q)
                {
                    break;
                }

                if (MaxDuration.HasValue && (System.DateTime.UtcNow - startTime) > MaxDuration.Value)
                {
                    break;
                }

                // A failed transcription task would otherwise leave us recording indefinitely.
                if (processingTask.IsCompleted)
                {
                    break;
                }

                WritePendingOutput();
                _cts.Token.WaitHandle.WaitOne(PollIntervalMs);
            }

            _cts.Cancel();
        }

        /// <summary>Writes queued progress records and segments; must run on the pipeline thread.</summary>
        private void WritePendingOutput()
        {
            while (_progress.TryDequeue(out var percent))
            {
                // ProgressRecord.PercentComplete rejects values above 100.
                var clamped = Math.Clamp(percent, 0, 100);
                WriteProgress(new ProgressRecord(1, "Processing", $"Progress: {clamped}%") { PercentComplete = clamped });
            }

            while (_results.TryDequeue(out var segment))
            {
                WriteObject(Passthru ? segment : segment.Text);
            }
        }

        /// <summary>Applies the cmdlet parameters to a Whisper processor builder.</summary>
        /// <param name="builder">Builder obtained from the loaded <c>WhisperFactory</c>.</param>
        /// <returns>The same builder, configured.</returns>
        private WhisperProcessorBuilder ConfigureWhisperBuilder(WhisperProcessorBuilder builder)
        {
            builder.WithLanguage(Language)
                   .WithThreads(CpuThreads > 0 ? CpuThreads : GetDefaultThreadCount());

            if (Temperature.HasValue) builder.WithTemperature(Temperature.Value);
            if (TemperatureInc.HasValue) builder.WithTemperatureInc(TemperatureInc.Value);
            if (WithTokenTimestamps) builder.WithTokenTimestamps().WithTokenTimestampsSumThreshold(TokenTimestampsSumThreshold);
            if (WithTranslate) builder.WithTranslate();
            if (!string.IsNullOrWhiteSpace(Prompt)) builder.WithPrompt(Prompt);
            if (!string.IsNullOrWhiteSpace(SuppressRegex)) builder.WithSuppressRegex(SuppressRegex);
            // Progress is reported from the transcription thread, so it is queued and written
            // later by WritePendingOutput on the pipeline thread.
            if (WithProgress) builder.WithProgressHandler(progress => _progress.Enqueue(progress));
            if (SplitOnWord) builder.SplitOnWord();
            if (MaxTokensPerSegment.HasValue) builder.WithMaxTokensPerSegment(MaxTokensPerSegment.Value);
            if (IgnoreSilence) builder.WithNoSpeechThreshold(0.6f);
            if (AudioContextSize.HasValue) builder.WithAudioContextSize(AudioContextSize.Value);
            if (DontSuppressBlank) builder.WithoutSuppressBlank();
            // NOTE(review): duration and offset are applied to every processed chunk rather than
            // to the recording as a whole; confirm that is the intended meaning.
            if (MaxDuration.HasValue) builder.WithDuration(MaxDuration.Value);
            if (Offset.HasValue) builder.WithOffset(Offset.Value);
            if (MaxLastTextTokens.HasValue) builder.WithMaxLastTextTokens(MaxLastTextTokens.Value);
            if (SingleSegmentOnly) builder.WithSingleSegment();
            if (PrintSpecialTokens) builder.WithPrintSpecialTokens();
            if (MaxSegmentLength.HasValue) builder.WithMaxSegmentLength(MaxSegmentLength.Value);
            if (MaxInitialTimestamp.HasValue) builder.WithMaxInitialTs((int)MaxInitialTimestamp.Value.TotalSeconds);
            if (LengthPenalty.HasValue) builder.WithLengthPenalty(LengthPenalty.Value);
            if (EntropyThreshold.HasValue) builder.WithEntropyThreshold(EntropyThreshold.Value);
            if (LogProbThreshold.HasValue) builder.WithLogProbThreshold(LogProbThreshold.Value);
            // Applied after IgnoreSilence so an explicit threshold takes precedence.
            if (NoSpeechThreshold.HasValue) builder.WithNoSpeechThreshold(NoSpeechThreshold.Value);
            if (NoContext) builder.WithNoContext();
            if (WithBeamSearchSamplingStrategy) builder.WithBeamSearchSamplingStrategy();

            return builder;
        }

        /// <summary>
        /// Returns the number of physical cores reported by WMI, or the logical processor
        /// count when WMI is unavailable or reports no cores.
        /// </summary>
        private static int GetDefaultThreadCount()
        {
            var physicalCoreCount = 0;
            try
            {
                using var searcher = new ManagementObjectSearcher("select NumberOfCores from Win32_Processor");
                using var cpus = searcher.Get();
                foreach (var cpu in cpus)
                {
                    using (cpu)
                    {
                        physicalCoreCount += Convert.ToInt32(cpu["NumberOfCores"]);
                    }
                }
            }
            catch (Exception ex) when (ex is ManagementException
                                          || ex is PlatformNotSupportedException
                                          || ex is System.Runtime.InteropServices.COMException)
            {
                // WMI is unavailable; fall back to the logical processor count below.
                physicalCoreCount = 0;
            }

            return physicalCoreCount > 0 ? physicalCoreCount : Environment.ProcessorCount;
        }

        /// <summary>
        /// Background consumer: collects captured 16-bit PCM buffers, converts them to float
        /// samples and transcribes them in windows of <see cref="ProcessingWindowSeconds"/>.
        /// Once <paramref name="inputCompleted"/> is set, the remaining audio is flushed and the
        /// task ends. Returns early, discarding audio, when <see cref="_abortCts"/> is cancelled.
        /// </summary>
        private async Task ProcessAudioBuffer(ConcurrentQueue<byte[]> bufferQueue, ManualResetEventSlim inputCompleted)
        {
            var abort = _abortCts.Token;
            var windowSampleCount = (int)(WhisperSampleRate * ProcessingWindowSeconds);
            var pending = new List<float>(windowSampleCount * 2);

            while (!abort.IsCancellationRequested)
            {
                // Read the flag before draining: everything enqueued before completion was
                // signalled is then guaranteed to be drained in this same pass.
                var inputDone = inputCompleted.IsSet;

                while (bufferQueue.TryDequeue(out var buffer))
                {
                    AppendPcm16Samples(buffer, pending);
                }

                // Transcribe once a full window is buffered; flush the remainder at the end.
                // NOTE(review): whisper.cpp may skip inputs shorter than about one second, so a
                // very short trailing chunk can yield no segments.
                if (pending.Count >= windowSampleCount || (inputDone && pending.Count > 0))
                {
                    var samples = pending.ToArray();
                    pending.Clear();

                    try
                    {
                        await foreach (var segment in _processor.ProcessAsync(samples, abort))
                        {
                            _results.Enqueue(segment);
                        }
                    }
                    catch (OperationCanceledException) when (abort.IsCancellationRequested)
                    {
                        return;
                    }
                }

                if (inputDone)
                {
                    return;
                }

                await Task.Delay(50);
            }
        }

        /// <summary>
        /// Appends the 16-bit PCM samples in <paramref name="pcm"/> to
        /// <paramref name="destination"/>, scaled to the range [-1, 1).
        /// </summary>
        private static void AppendPcm16Samples(byte[] pcm, List<float> destination)
        {
            for (var i = 0; i + 1 < pcm.Length; i += 2)
            {
                destination.Add(BitConverter.ToInt16(pcm, i) / 32768f);
            }
        }

        /// <summary>
        /// Called by PowerShell on another thread when the pipeline is stopped (e.g. Ctrl+C);
        /// aborts recording and transcription so ProcessRecord can return promptly.
        /// </summary>
        protected override void StopProcessing()
        {
            try
            {
                _abortCts?.Cancel();
                _cts?.Cancel();
            }
            catch (ObjectDisposedException)
            {
                // EndProcessing already ran; there is nothing left to stop.
            }

            base.StopProcessing();
        }

        protected override void EndProcessing()
        {
            // ProcessRecord normally disposes the processor already; this is a safety net.
            _processor?.Dispose();
            _processor = null;
            _cts?.Dispose();
            _abortCts?.Dispose();
            base.EndProcessing();
        }

        /// <summary>
        /// Downloads <paramref name="ggmlType"/> to <paramref name="fileName"/>, creating the
        /// directory when needed. The data is written to a temporary file first and moved into
        /// place only once complete, so an interrupted download never leaves a truncated model
        /// that a later <c>File.Exists</c> check would accept.
        /// </summary>
        private static async Task DownloadModel(string fileName, GgmlType ggmlType)
        {
            Console.WriteLine($"Downloading Model {fileName}");

            var directory = Path.GetDirectoryName(fileName);
            if (!string.IsNullOrEmpty(directory))
            {
                Directory.CreateDirectory(directory);
            }

            var tempFileName = fileName + ".download";
            try
            {
                using (var modelStream = await WhisperGgmlDownloader.GetGgmlModelAsync(ggmlType).ConfigureAwait(false))
                using (var fileWriter = new FileStream(tempFileName, FileMode.Create, FileAccess.Write, FileShare.None))
                {
                    await modelStream.CopyToAsync(fileWriter).ConfigureAwait(false);
                }

                File.Move(tempFileName, fileName, true);
            }
            catch
            {
                // Best-effort cleanup; never let it mask the original download error.
                try
                {
                    File.Delete(tempFileName);
                }
                catch (IOException)
                {
                }
                catch (UnauthorizedAccessException)
                {
                }

                throw;
            }
        }

        /// <summary>
        /// Measures how long mono 16-bit PCM audio has been continuously quiet. The mean
        /// absolute amplitude is evaluated per window of just over <c>windowSeconds</c>; a
        /// window whose mean is not above the threshold counts as silent.
        /// </summary>
        private sealed class SilenceTracker
        {
            private readonly int _threshold;
            private readonly double _bytesPerSecond;
            private readonly double _windowSeconds;
            private double _windowElapsed;
            private double _windowSum;
            private long _windowSampleCount;
            private double _silentSeconds;

            public SilenceTracker(int threshold, int sampleRate, double windowSeconds = 0.85)
            {
                _threshold = threshold;
                _bytesPerSecond = sampleRate * 2d;
                _windowSeconds = windowSeconds;
            }

            /// <summary>Consumes the first <paramref name="byteCount"/> bytes of <paramref name="buffer"/>.</summary>
            /// <returns>The length, in seconds, of the current run of silent windows.</returns>
            public double Add(byte[] buffer, int byteCount)
            {
                for (var i = 0; i + 1 < byteCount; i += 2)
                {
                    // Widen before Math.Abs: Math.Abs(short.MinValue) throws OverflowException.
                    _windowSum += Math.Abs((int)BitConverter.ToInt16(buffer, i));
                    _windowSampleCount++;
                }

                _windowElapsed += byteCount / _bytesPerSecond;
                if (_windowElapsed > _windowSeconds)
                {
                    var mean = _windowSampleCount > 0 ? _windowSum / _windowSampleCount : 0;
                    // Audio above the threshold resets the silence run; otherwise it grows.
                    _silentSeconds = mean > _threshold ? 0 : _silentSeconds + _windowElapsed;
                    _windowElapsed = 0;
                    _windowSum = 0;
                    _windowSampleCount = 0;
                }

                return _silentSeconds;
            }
        }
    }
}