// ConvertWavToText.cs

using System.Management;
using System.Management.Automation;
using Whisper.net;
using Whisper.net.Ggml;
 
namespace GenXdev.Helpers
{
    [Cmdlet(VerbsData.Convert, "WavToText")]
    public class ConvertWavToText : Cmdlet
    {
        #region Cmdlet Parameters
        // Every optional parameter below is forwarded to the Whisper.net processor builder in
        // ProcessRecord only when it was supplied (nullable has a value / switch is present),
        // unless noted otherwise.

        /// <summary>
        /// Location of the model. ProcessRecord combines this value with the file name
        /// "ggml-largeV3Turbo.bin", so it is used as a directory path.
        /// </summary>
        // NOTE(review): the help message says "model file" but the value is used as a directory — confirm intent.
        [Parameter(Mandatory = true, HelpMessage = "Path to the model file")]
        public string ModelFilePath { get; set; }
 
        /// <summary>Path of the WAV file that ProcessRecord opens for reading and transcribes.</summary>
        [Parameter(Mandatory = true, HelpMessage = "Path to the 16Khz mono .WAV file")]
        public string WaveFile { get; set; }
 
        /// <summary>When set, each segment object is written to the pipeline instead of only its Text.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Returns objects instead of strings")]
        public SwitchParameter Passthru { get; set; }
 
        /// <summary>Enables token timestamps on the builder (together with <see cref="TokenTimestampsSumThreshold"/>).</summary>
        [Parameter(Mandatory = false, HelpMessage = "Whether to include token timestamps")]
        public SwitchParameter WithTokenTimestamps { get; set; }
 
        /// <summary>Sum threshold passed to the builder; only applied when <see cref="WithTokenTimestamps"/> is set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Sum threshold for token timestamps, defaults to 0.5")]
        public float TokenTimestampsSumThreshold { get; set; } = 0.5f;
 
        /// <summary>Calls the builder's SplitOnWord when set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Whether to split on word boundaries")]
        public SwitchParameter SplitOnWord { get; set; }
 
        /// <summary>Forwarded to the builder's WithMaxTokensPerSegment.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Maximum number of tokens per segment")]
        public int? MaxTokensPerSegment { get; set; }
 
        /// <summary>Language passed unconditionally to the builder's WithLanguage; defaults to "en".</summary>
        [Parameter(Mandatory = false, HelpMessage = "Language to detect, defaults to 'en'")]
        public string Language { get; set; } = "en";
 
        /// <summary>
        /// Thread count passed to the builder's WithThreads. When not greater than zero,
        /// ProcessRecord substitutes the physical core count reported by WMI.
        /// </summary>
        [Parameter(Mandatory = false, HelpMessage = "Number of CPU threads, defaults to 0 (auto)")]
        public int CpuThreads { get; set; } = 0;
 
        /// <summary>Forwarded to the builder's WithTemperature.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Temperature for speech generation")]
        [ValidateRange(0, 1)]
        public float? Temperature { get; set; }
 
        /// <summary>Forwarded to the builder's WithTemperatureInc.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Temperature increment")]
        [ValidateRange(0, 1)]
        public float? TemperatureInc { get; set; }
 
        /// <summary>Calls the builder's WithTranslate when set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Whether to translate the output")]
        public SwitchParameter WithTranslate { get; set; }
 
        /// <summary>Forwarded to the builder's WithPrompt when not null/whitespace.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Prompt to use for the model")]
        public string Prompt { get; set; }
 
        /// <summary>Forwarded to the builder's WithSuppressRegex when not null/whitespace.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Regex to suppress tokens from the output")]
        public string SuppressRegex { get; set; }
 
        /// <summary>Registers a progress handler that reports percentage progress via WriteProgress.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Whether to show progress")]
        public SwitchParameter WithProgress { get; set; }
 
        /// <summary>Forwarded to the builder's WithAudioContextSize.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Size of the audio context")]
        public int? AudioContextSize { get; set; }
 
        /// <summary>Calls the builder's WithoutSuppressBlank when set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Whether to NOT suppress blank lines")]
        public SwitchParameter DontSuppressBlank { get; set; }
 
        /// <summary>Forwarded to the builder's WithDuration.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Maximum duration of the audio")]
        public TimeSpan? MaxDuration { get; set; }
 
        /// <summary>Forwarded to the builder's WithOffset.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Offset for the audio")]
        public TimeSpan? Offset { get; set; }
 
        /// <summary>Forwarded to the builder's WithMaxLastTextTokens.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Maximum number of last text tokens")]
        public int? MaxLastTextTokens { get; set; }
 
        /// <summary>Calls the builder's WithSingleSegment when set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Whether to use single segment only")]
        public SwitchParameter SingleSegmentOnly { get; set; }
 
        /// <summary>Calls the builder's WithPrintSpecialTokens when set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Whether to print special tokens")]
        public SwitchParameter PrintSpecialTokens { get; set; }
 
        /// <summary>Forwarded to the builder's WithMaxSegmentLength.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Maximum segment length")]
        public int? MaxSegmentLength { get; set; }
 
        /// <summary>Forwarded to the builder's WithMaxInitialTs, truncated to whole seconds.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Start timestamps at this moment")]
        public TimeSpan? MaxInitialTimestamp { get; set; }
 
        /// <summary>Forwarded to the builder's WithLengthPenalty.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Length penalty")]
        [ValidateRange(0, 1)]
        public float? LengthPenalty { get; set; }
 
        /// <summary>Forwarded to the builder's WithEntropyThreshold.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Entropy threshold")]
        [ValidateRange(0, 1)]
        public float? EntropyThreshold { get; set; }
 
        /// <summary>Forwarded to the builder's WithLogProbThreshold.</summary>
        // NOTE(review): log-probability thresholds are usually negative values; the 0..1
        // validation range may reject every useful setting — confirm against Whisper.net.
        [Parameter(Mandatory = false, HelpMessage = "Log probability threshold")]
        [ValidateRange(0, 1)]
        public float? LogProbThreshold { get; set; }
 
        /// <summary>Forwarded to the builder's WithNoSpeechThreshold.</summary>
        [Parameter(Mandatory = false, HelpMessage = "No speech threshold")]
        [ValidateRange(0, 1)]
        public float? NoSpeechThreshold { get; set; }
 
        /// <summary>Calls the builder's WithNoContext when set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Don't use context")]
        public SwitchParameter NoContext { get; set; }
 
        /// <summary>Calls the builder's WithBeamSearchSamplingStrategy when set.</summary>
        [Parameter(Mandatory = false, HelpMessage = "Use beam search sampling strategy")]
        public SwitchParameter WithBeamSearchSamplingStrategy { get; set; }
        #endregion
 
        protected override void ProcessRecord()
        {
            base.ProcessRecord();
 
            var ggmlType = GgmlType.LargeV3Turbo;
            var modelFileName = Path.GetFullPath(Path.Combine(ModelFilePath, "ggml-largeV3Turbo.bin"));
 
            if (!File.Exists(modelFileName))
            {
                DownloadModel(modelFileName, ggmlType).GetAwaiter().GetResult();
            }
 
            using var whisperFactory = WhisperFactory.FromPath(modelFileName);
 
            var builder = whisperFactory.CreateBuilder()
                .WithLanguage(Language);
             
            int physicalCoreCount = 0;
            var searcher = new ManagementObjectSearcher("select NumberOfCores from Win32_Processor");
            foreach (var item in searcher.Get())
            {
                physicalCoreCount += Convert.ToInt32(item["NumberOfCores"]);
            }
 
            builder.WithThreads(CpuThreads > 0 ? CpuThreads : physicalCoreCount);
 
            if (Temperature.HasValue) builder.WithTemperature(Temperature.Value);
            if (TemperatureInc.HasValue) builder.WithTemperatureInc(TemperatureInc.Value);
            if (WithTokenTimestamps) builder.WithTokenTimestamps().WithTokenTimestampsSumThreshold(TokenTimestampsSumThreshold);
            if (WithTranslate) builder.WithTranslate();
            if (!string.IsNullOrWhiteSpace(Prompt)) builder.WithPrompt(Prompt);
            if (!string.IsNullOrWhiteSpace(SuppressRegex)) builder.WithSuppressRegex(SuppressRegex);
            if (WithProgress) builder.WithProgressHandler(progress => WriteProgress(new ProgressRecord(1, "Processing", $"Progress: {progress}%") { PercentComplete = progress }));
            if (SplitOnWord) builder.SplitOnWord();
            if (MaxTokensPerSegment.HasValue) builder.WithMaxTokensPerSegment(MaxTokensPerSegment.Value);
            if (AudioContextSize.HasValue) builder.WithAudioContextSize(AudioContextSize.Value);
            if (DontSuppressBlank) builder.WithoutSuppressBlank();
            if (MaxDuration.HasValue) builder.WithDuration(MaxDuration.Value);
            if (Offset.HasValue) builder.WithOffset(Offset.Value);
            if (MaxLastTextTokens.HasValue) builder.WithMaxLastTextTokens(MaxLastTextTokens.Value);
            if (SingleSegmentOnly) builder.WithSingleSegment();
            if (PrintSpecialTokens) builder.WithPrintSpecialTokens();
            if (MaxSegmentLength.HasValue) builder.WithMaxSegmentLength(MaxSegmentLength.Value);
            if (MaxInitialTimestamp.HasValue) builder.WithMaxInitialTs((int)MaxInitialTimestamp.Value.TotalSeconds);
            if (LengthPenalty.HasValue) builder.WithLengthPenalty(LengthPenalty.Value);
            if (EntropyThreshold.HasValue) builder.WithEntropyThreshold(EntropyThreshold.Value);
            if (LogProbThreshold.HasValue) builder.WithLogProbThreshold(LogProbThreshold.Value);
            if (NoSpeechThreshold.HasValue) builder.WithNoSpeechThreshold(NoSpeechThreshold.Value);
            if (NoContext) builder.WithNoContext();
            if (WithBeamSearchSamplingStrategy) builder.WithBeamSearchSamplingStrategy();
 
            using var processor = builder.Build();
            using var stream = File.OpenRead(WaveFile);
 
            var cts = new CancellationTokenSource();
            Console.WriteLine("Processing WAV file. Press Q to abort...");
 
            var processTask = Task.Run(async () =>
            {
                await foreach (var segment in processor.ProcessAsync(stream, cts.Token))
                {
                    WriteObject(Passthru ? segment : segment.Text);
                }
            });
 
            while (!processTask.IsCompleted)
            {
                if (Console.KeyAvailable && Console.ReadKey(true).Key == ConsoleKey.Q)
                {
                    cts.Cancel();
                    break;
                }
                Thread.Sleep(100);
            }
 
            processTask.Wait();
        }
 
        private static async Task DownloadModel(string fileName, GgmlType ggmlType)
        {
            Console.WriteLine($"Downloading Model {fileName}");
            using var modelStream = await WhisperGgmlDownloader.GetGgmlModelAsync(ggmlType);
            using var fileWriter = File.OpenWrite(fileName);
            await modelStream.CopyToAsync(fileWriter);
        }
    }
}