InvokeProcessFast.cs

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Management.Automation;
using System.Management.Automation.Language;
using System.Net;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Nito.AsyncEx;

namespace PowerProcess
{

    /// <summary>
    /// This class implements the Invoke-ProcessFast command.
    /// </summary>
    /// <remarks>
    /// Monad2021:
    /// Added support for buffering streams.
    /// It's not possible to call the shell by mistake.
    /// Support for the new VT terminals [WIP].
    /// Correct treatment of argument lists (just use Sys.Diag.Proc).
    /// Possibility of merging Out and Error at the source.
    /// </remarks>
    [Cmdlet(VerbsLifecycle.Invoke, "ProcessFast", SupportsShouldProcess = true, HelpUri = "https://go.microsoft.com/fwlink/?LinkID=2097141")]
    [OutputType(typeof(Process))]
    public sealed class InvokeProcessFastCommand : PSCmdlet, IDisposable
    {
        // The started process; set in BeginProcessing and cleared when ownership moves
        // elsewhere (background job) or when processing is stopped.
        private Process? _process = null;
        // Signalled by the process Exited event (or by StopProcessing); only created for -Wait.
        private ManualResetEvent? _waitHandle = null;
        // Token source handed to the blocking output reader; cancelled by StopProcessing.
        private CancellationTokenSource? _cancellationSource = null;

        #region Parameters
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.

        private const string DefaultParameterSet = "ScriptBlock";
        private const string WinEnvParameterSet = "WinEnv";

        // NOTE(review): FilePath, ArgumentList and WorkingDirectory are declared only in the
        // default parameter set, while Credential, LoadUserProfile and UseNewEnvironment are
        // declared only in the WinEnv set. Confirm that this split is intended, since it looks
        // like the WinEnv switches cannot be combined with a FilePath.

        /// <summary>
        /// Path/FileName of the process to start.
        /// </summary>
        [Parameter(ParameterSetName = DefaultParameterSet, Mandatory = true, Position = 0)]
        [ValidateNotNullOrEmpty]
        [Alias("PSPath", "Path")]
        public string FilePath { get; set; }

        /// <summary>
        /// Arguments for the process. Each element is passed as a separate argument.
        /// </summary>
        [Parameter(ParameterSetName = DefaultParameterSet, Position = 1)]
        [Alias("Args")]
        public string[]? ArgumentList { get; set; }

        /// <summary>
        /// Credentials under which the process is started.
        /// </summary>
        [Parameter(ParameterSetName = WinEnvParameterSet)]
        [Alias("RunAs")]
        [ValidateNotNullOrEmpty]
        [Credential]
        public PSCredential? Credential { get; set; }

        /// <summary>
        /// Working directory of the process. Defaults to the current file system location.
        /// </summary>
        [Parameter(ParameterSetName = DefaultParameterSet)]
        [ValidateNotNullOrEmpty]
        public string? WorkingDirectory { get; set; }

        /// <summary>
        /// Load user profile from registry.
        /// </summary>
        [Parameter(ParameterSetName = WinEnvParameterSet)]
        [Alias("Lup")]
        public SwitchParameter LoadUserProfile { get; set; }

        /// <summary>
        /// Write the started <see cref="Process"/> object to the pipeline.
        /// </summary>
        [Parameter]
        public SwitchParameter PassThru { get; set; }

        /// <summary>
        /// Do not redirect the standard output and standard error of the process.
        /// When set, no process output is read by this cmdlet.
        /// </summary>
        [Parameter]
        [Alias("NoRedir")]
        public SwitchParameter DontRedirectOutputs { get; set; }

        /// <summary>
        /// Merge Error to Output: standard error lines are routed to the same target as
        /// standard output lines.
        /// </summary>
        [Parameter]
        [Alias("Merge")]
        public SwitchParameter MergeStandardErrorToOutput { get; set; }

        /// <summary>
        /// Wrap each output line in a <see cref="WrapObject"/> that records the stream it came from.
        /// </summary>
        [Parameter]
        [Alias("Obj")]
        public SwitchParameter WrapOutputStream { get; set; }

        /// <summary>
        /// Wait for the process to terminate.
        /// </summary>
        [Parameter]
        public SwitchParameter Wait { get; set; }

        /// <summary>
        /// Start the process with an environment rebuilt from the machine and user
        /// environment variables instead of the current process environment.
        /// </summary>
        [Parameter(ParameterSetName = WinEnvParameterSet)]
        public SwitchParameter UseNewEnvironment { get; set; }

#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
        #endregion

        #region Pipeline

        /// <summary>
        /// Pipeline input; each value is written as one line to the standard input of the process.
        /// </summary>
        [Parameter(ValueFromPipeline = true)]
        public string? InputObject { get; set; }

        /// <summary>
        /// Number of output lines collected before they are emitted as one batch.
        /// Defaults to 256 when not specified; a value of 1 emits every line individually.
        /// </summary>
        [Parameter]
        public int? OutputBuffer { get; set; }

        #endregion

        #region Overrides

        /// <summary>
        /// BeginProcessing.
        /// </summary>
        protected override void BeginProcessing()
        {

            ProcessStartInfo startInfo = new();

            // Path = Mandatory parameter -> Will not be empty.
            try
            {
                var cmdinfo = base.InvokeCommand.GetCommand(
                    FilePath, CommandTypes.Application | CommandTypes.ExternalScript);
                startInfo.FileName = cmdinfo.Definition;
            }
            catch (CommandNotFoundException)
            {
                startInfo.FileName = FilePath;
            }

            if (ArgumentList != null)
            {
                foreach (var arg in ArgumentList)
                {
                    startInfo.ArgumentList.Add(arg);
                }
            }

            if (WorkingDirectory != null)
            {
                // WorkingDirectory -> Not Exist -> Throw Error
                WorkingDirectory = ResolveFilePath(WorkingDirectory);
                if (!Directory.Exists(WorkingDirectory))
                {
                    var message = StringUtil.Format(ProcessResources.InvalidInput, nameof(WorkingDirectory));
                    var er = new ErrorRecord(new DirectoryNotFoundException(message), nameof(DirectoryNotFoundException), ErrorCategory.InvalidOperation, null);
                    WriteError(er);
                    return;
                }

                startInfo.WorkingDirectory = WorkingDirectory;
            }
            else
            {
                // Working Directory not specified -> Assign Current Path.
                startInfo.WorkingDirectory = base.SessionState.Path.CurrentFileSystemLocation.Path;
            }

            if (this.ParameterSetName.Equals(WinEnvParameterSet))
            {
                startInfo.UseShellExecute = false;

                if (UseNewEnvironment)
                {
                    startInfo.EnvironmentVariables.Clear();
                    LoadEnvironmentVariable(startInfo, Environment.GetEnvironmentVariables(EnvironmentVariableTarget.Machine));
                    LoadEnvironmentVariable(startInfo, Environment.GetEnvironmentVariables(EnvironmentVariableTarget.User));
                }

                startInfo.CreateNoWindow = true;
#if !UNIX
#pragma warning disable CA1416 // Validate platform compatibility
                startInfo.LoadUserProfile = LoadUserProfile;
                if (Credential != null)
                {
                    NetworkCredential nwcredential = Credential.GetNetworkCredential();
                    startInfo.UserName = nwcredential.UserName;
                    if (string.IsNullOrEmpty(nwcredential.Domain))
                    {
                        startInfo.Domain = ".";
                    }
                    else
                    {
                        startInfo.Domain = nwcredential.Domain;
                    }

                    startInfo.Password = Credential.Password;
                }
#pragma warning restore CA1416 // Validate platform compatibility
#endif
            }

            string targetMessage = StringUtil.Format(ProcessResources.StartProcessTarget, startInfo.FileName, startInfo.Arguments.Trim());
            if (!ShouldProcess(targetMessage)) { return; }

            _process = Start(startInfo);

            if (PassThru.IsPresent)
            {
                if (_process != null)
                {
                    WriteObject(_process);
                }
                else
                {
                    var message = StringUtil.Format(ProcessResources.CannotStartTheProcess);
                    var er = new ErrorRecord(new InvalidOperationException(message), nameof(InvalidOperationException), ErrorCategory.InvalidOperation, null);
                    ThrowTerminatingError(er);
                    return;
                }
            }

            _cancellationSource = new CancellationTokenSource();

            if (!Wait.IsPresent)
            {
                return;
            }

            if (_process != null)
            {
                if (_process.HasExited)
                {
                    ConsumeAvailableNativeProcessOutput(blocking: true, _process, _cancellationSource.Token);
                    SetLastExitCode(_process);
                    _process = null;
                }
                else
                {
                    _process.Exited += myProcess_Exited;
                    _process.EnableRaisingEvents = true;
                    _waitHandle = new ManualResetEvent(false);
                }
            }
            else
            {
                var message = StringUtil.Format(ProcessResources.CannotStartTheProcess);
                var er = new ErrorRecord(new InvalidOperationException(message), nameof(InvalidOperationException), ErrorCategory.InvalidOperation, null);
                ThrowTerminatingError(er);
            }
        }

        /// <summary>
        /// Pass parameter from pipeline to the process.
        /// </summary>
        protected override void ProcessRecord()
        {
            if (!base.MyInvocation.ExpectingInput) return;
            if (_process != null && _cancellationSource != null)
            {
                ProduceNativeProcessInput(_process);
            }
            else
            {
                var message = StringUtil.Format(ProcessResources.ProcessIsNotStarted);
                var er = new ErrorRecord(new InvalidOperationException(message), nameof(InvalidOperationException), ErrorCategory.InvalidOperation, null);
                ThrowTerminatingError(er);
            }
        }

        /// <summary>
        /// Wait for the process to terminate.
        /// </summary>
        protected override void EndProcessing()
        {
            if (_process == null || _cancellationSource == null) return;

            if (Wait.IsPresent && _waitHandle != null)
            {
                ConsumeAvailableNativeProcessOutput(blocking: true, _process, _cancellationSource.Token);

                _waitHandle.WaitOne();

                if (_process.HasExited)
                {
                    SetLastExitCode(_process);
                }
                else
                {
                    var message = StringUtil.Format(ProcessResources.ProcessIsNotTerminated);
                    var er = new ErrorRecord(new InvalidOperationException(message), nameof(InvalidOperationException), ErrorCategory.InvalidOperation, null);
                    ThrowTerminatingError(er);
                }
            }
            else
            {
                var p = _process;
                _process = null; //suppress finalize
                ConsumeAvailableNativeProcessOutput(blocking: false, p, _cancellationSource.Token);
                SetLastExitCode(0);
            }
        }

        /// <summary>
        /// Implements ^c, after creating a process.
        /// </summary>
        protected override void StopProcessing()
        {
            if (_cancellationSource != null)
            {
                _cancellationSource.Cancel();
                _cancellationSource = null;
            }
            if (_waitHandle != null)
            {
                _waitHandle.Set();
            }
            if (_process != null)
            {
                if (!_process.HasExited)
                {
                    _process.Kill(true);
                }
                _process = null;
            }
        }

        #endregion

        #region IDisposable Overrides

        /// <summary>
        /// Dispose WaitHandle used to honor -Wait parameter.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            System.GC.SuppressFinalize(this);
        }

        private void Dispose(bool isDisposing)
        {
            if (_waitHandle != null)
            {
                _waitHandle.Dispose();
                _waitHandle = null;
            }
            try
            {
                // Dispose the process if it's already created
                if (_process != null)
                {
                    _process.Dispose();
                }
            }
            catch (Exception)
            {
            }
        }

        #endregion

        #region Private Methods

        /// <summary>
        /// When Process exits the wait handle is set.
        /// </summary>
        private void myProcess_Exited(object? sender, System.EventArgs e)
        {
            if (_waitHandle != null)
            {
                _waitHandle.Set();
            }
        }

        private string ResolveFilePath(string path)
        {
            return base.GetResolvedProviderPathFromPSPath(path, out _)[0];
        }

        private static void LoadEnvironmentVariable(ProcessStartInfo startinfo, IDictionary EnvironmentVariables)
        {
            var processEnvironment = startinfo.EnvironmentVariables;
            foreach (DictionaryEntry entry in EnvironmentVariables)
            {
                var key = entry.Key.ToString();
                if (key == null) continue;
                if (processEnvironment.ContainsKey(key))
                {
                    processEnvironment.Remove(key);
                }

                if (key.Equals("PATH"))
                {
                    processEnvironment.Add(key,
                        Environment.GetEnvironmentVariable(key, EnvironmentVariableTarget.Machine) + ";" +
                        Environment.GetEnvironmentVariable(key, EnvironmentVariableTarget.User));
                }
                else
                {
                    processEnvironment.Add(key, entry.Value?.ToString());
                }
            }
        }

        private Process Start(ProcessStartInfo startInfo)
        {
            var process = new Process() { StartInfo = startInfo };
            SetupInputOutputRedirection(process);
            process.Start();
            return process;
        }

        private void SetupInputOutputRedirection(Process p)
        {
            p.StartInfo.RedirectStandardInput = base.MyInvocation.ExpectingInput;
            p.StartInfo.RedirectStandardOutput = !DontRedirectOutputs;
            p.StartInfo.RedirectStandardError = !DontRedirectOutputs;
        }

        /// <summary>
        /// Read the input from the pipeline and send it down the native process.
        /// </summary>
        private void ProduceNativeProcessInput(Process p)
        {
            p.StandardInput.WriteLine(InputObject);
        }

        /// <summary>
        /// Read the output from the native process and send it down the line.
        /// </summary>
        private void ConsumeAvailableNativeProcessOutput(bool blocking, Process p, CancellationToken ct)
        {
            if (DontRedirectOutputs) return;
            var _buffer = OutputBuffer ?? 256;
            var _merge = MergeStandardErrorToOutput.ToBool();
            var _wrap = WrapOutputStream.ToBool();
            Func<TaskJob?, PSCmdlet, Func<Task<bool>>> _task = ((_job, _cmdlet) => () =>
                   ConsumeAvailableNativeProcessOutputAsync(
                       _process: p,
                       _cmdlet: this,
                       _job: _job,
                       _merge: _merge,
                       _wrap: _wrap,
                       _buffer: _buffer));
            if (blocking)
            {
                AsyncContext.Run(_task(null, this), ct);
            }
            else
            {
                var cts = new CancellationTokenSource();
                var job = new TaskJob(
                    this,
                    p.ProcessName,
                    job => AsyncContext.Run(_task(job, this), cts.Token),
                    cts);
                cts.Token.Register(() =>
                {
                    if (!p.HasExited)
                    {
                        p.Kill(true);
                    }
                    p.Dispose();
                });
                job.StartJobAsync();
            }
        }

        /// <summary>
        /// Reads standard output and standard error of the process line by line until both
        /// streams end, wrapping each line and routing it either directly to its target or
        /// into a buffer that is flushed in batches.
        /// </summary>
        /// <param name="_process">Process whose StandardOutput and StandardError are read.</param>
        /// <param name="_cmdlet">Cmdlet that receives the output when no job is given.</param>
        /// <param name="_job">Job that receives the output, or null to write through the cmdlet.</param>
        /// <param name="_merge">Route standard error to the same target as standard output.</param>
        /// <param name="_wrap">Wrap lines in objects that record their source stream.</param>
        /// <param name="_buffer">Maximum number of lines handled before buffered lists are flushed.</param>
        /// <returns>Always true once both streams have ended.</returns>
        private static async Task<bool> ConsumeAvailableNativeProcessOutputAsync(
            Process _process,
            PSCmdlet _cmdlet,
            TaskJob? _job,
            bool _merge,
            bool _wrap,
            int _buffer)
        {
            var _out = _process.StandardOutput;
            var _err = _process.StandardError;

            // Index 0 is standard output, index 1 is standard error. "ids" maps a position
            // in "tasks" back to the stream index, because entries are removed from both
            // lists when a stream ends.
            var streams = new[] { _out, _err };
            var ids = new List<int>(2) { 0, 1 };
            var tasks = new List<Task<string?>?>(2);

            // calculate redirect
            CalculareRedirect(
                _job != null,
                _merge,
                _wrap,
                _buffer,
                out var src,
                out var buf_tgt,
                out var buf_stream,
                out var tgt);

            // stream from source to target
            foreach (var strm in streams)
            {
                tasks.Add(strm.ReadLineAsync());
            }
            do
            {
                // Number of lines handled in the current batch.
                var count = 0;
                do
                {
                    // wait
                    var t = await Task.WhenAny(tasks!);

                    // find which stream
                    var i = tasks.IndexOf(t);
                    // NOTE(review): a task returned by WhenAny should already be completed,
                    // so this guard looks unreachable; confirm before removing.
                    if (!t.IsCompleted) continue;

                    // check for end of stream
                    var r = t.Result!;
                    if (r == null)
                    {
                        // Drop the finished stream; tasks and ids stay index-aligned.
                        tasks.RemoveAt(i);
                        ids.RemoveAt(i);
                        continue;
                    }

                    // get and clear for next read if completed
                    var id = ids[i];
                    tasks[i] = streams[id].ReadLineAsync();

                    // now redirect
                    var o = WrapMessage(r, src[id]);
                    RedirectMessage(o, buf_tgt[id], buf_stream[id], _job, _cmdlet);
                    count++;

                    // until buffer full or end of any streams
                } while (count < _buffer && tasks.Count > 0);

                // send collected results for each stream respectively
                for (var i = 0; i < streams.Length; i++)
                {
                    RedirectList(buf_stream[i], tgt[i], _job, _cmdlet);
                }

                // until end of all streams
            } while (tasks.Count > 0);

            return true;
        }

        private static void CalculareRedirect(
            bool task, bool merge, bool wrap, int buffer,
            out WrapSource[] wrapSrc,
            out RedirTarget[] redirTgt, out object?[] redirStream,
            out FinalTarget[] finalTgt)
        {
            if (!wrap)
            {
                if (buffer != 1)
                {
                    wrapSrc = new[] { WrapSource.Str, WrapSource.Str };
                    redirTgt = new[] { RedirTarget.StrLst, RedirTarget.StrLst };
                    redirStream = new[] { NewList<string>(buffer), NewList<string>(buffer) };
                    finalTgt = new[] { FinalTarget.OutStrLst, FinalTarget.ErrStrLst };
                }
                else
                {
                    wrapSrc = new[] { WrapSource.Pso, WrapSource.Rcd };
                    redirTgt = new[] { RedirTarget.Out, RedirTarget.Err };
                    redirStream = new[] { (object?)null, (object?)null };
                    finalTgt = new[] { FinalTarget.Nop, FinalTarget.Nop };
                }
            }
            else
            {
                if (buffer != 1)
                {
                    wrapSrc = new[] { WrapSource.Out, WrapSource.Err };
                    redirTgt = new[] { RedirTarget.ObjLst, RedirTarget.ObjLst };
                    redirStream = new[] { NewList<object>(buffer), NewList<object>(buffer) };
                    finalTgt = new[] { FinalTarget.OutObjLst, FinalTarget.ErrObjLst };
                }
                else
                {
                    wrapSrc = new[] { WrapSource.Out, WrapSource.Wse };
                    redirTgt = new[] { RedirTarget.Out, RedirTarget.Err };
                    redirStream = new[] { (object?)null, (object?)null };
                    finalTgt = new[] { FinalTarget.Nop, FinalTarget.Nop };
                }
            }
            if (merge)
            {
                if (wrapSrc[1] == WrapSource.Wse) wrapSrc[1] = WrapSource.Err;
                redirTgt[1] = redirTgt[0];
                redirStream[1] = redirStream[0];
                finalTgt[1] = finalTgt[0];
            }
            if (task)
            {
                wrapSrc[0] = WrapTaskStream(wrapSrc[0]);
                wrapSrc[1] = WrapTaskStream(wrapSrc[1]);
            }
        }

        private static WrapSource WrapTaskStream(
            WrapSource source)
        {
            return source switch
            {
                WrapSource.Str => WrapSource.Pso,
                WrapSource.Out => WrapSource.Wso,
                WrapSource.Err => WrapSource.Wse,
                _ => source
            };
        }

        private static object WrapMessage(
            string message,
            WrapSource source)
        {
            var m = message;
            return source switch
            {
                WrapSource.Str => m,
                WrapSource.Out => WrapObject.Output(m),
                WrapSource.Err => WrapObject.Error(m),
                WrapSource.Pso => PSObject.AsPSObject(m),
                WrapSource.Rcd => MakeError(m),
                WrapSource.Wso => MakeError(WrapObject.Output(m)),
                WrapSource.Wse => MakeError(WrapObject.Error(m)),
                _ => throw new InvalidOperationException(),
            };
        }

        private static void RedirectMessage(
            object message,
            RedirTarget target,
            object? stream,
            TaskJob? job,
            PSCmdlet cmdlet)
        {
            switch (target)
            {
                case RedirTarget.Out:
                    if (job != null) job.Output.Add((PSObject)message);
                    else cmdlet.WriteObject(message);
                    break;

                case RedirTarget.Err:
                    if (job != null) job.Error.Add((ErrorRecord)message);
                    else cmdlet.WriteError((ErrorRecord)message);
                    break;

                case RedirTarget.StrLst:
                    ((List<string>)stream!).Add((string)message);
                    break;

                case RedirTarget.ObjLst:
                    ((List<object>)stream!).Add(message);
                    break;
            }
        }

        private static void RedirectList(
            object? stream,
            FinalTarget target,
            TaskJob? job,
            PSCmdlet cmdlet)
        {
            var l = stream as IList;
            var lst = target switch
            {
                FinalTarget.OutObjLst => ((List<string>)stream!).ToArray(),
                FinalTarget.ErrObjLst => ((List<string>)stream!).ToArray(),
                FinalTarget.OutStrLst => ((List<object>)stream!).ToArray(),
                FinalTarget.ErrStrLst => ((List<object>)stream!).ToArray(),
                _ => null,
            };
            if (lst == null || lst.Length == 0) return;
            switch (target)
            {
                case FinalTarget.OutObjLst:
                case FinalTarget.OutStrLst:
                    if (job != null) job.Output.Add(PSObject.AsPSObject(lst));
                    else cmdlet.WriteObject(PSObject.AsPSObject(lst));
                    break;

                case FinalTarget.ErrStrLst:
                case FinalTarget.ErrObjLst:
                    if (job != null) job.Error.Add(MakeError(lst));
                    else cmdlet.WriteError(MakeError(lst));
                    break;
            }
            l!.Clear();
        }

        /// <summary>
        /// Selects how WrapMessage converts an output line.
        /// </summary>
        private enum WrapSource
        {
            // The raw string, unchanged.
            Str,
            // A WrapObject tagged as standard output.
            Out,
            // A WrapObject tagged as standard error.
            Err,
            // The string as a PSObject.
            Pso,
            // An ErrorRecord built from the string.
            Rcd,
            // Job variant of Out, selected by WrapTaskStream; see WrapMessage for the exact wrapping.
            Wso,
            // An ErrorRecord built from a WrapObject tagged as standard error.
            Wse,
        }

        /// <summary>
        /// Immediate destination of a wrapped message in RedirectMessage.
        /// </summary>
        private enum RedirTarget
        {
            // Output stream of the job or cmdlet.
            Out,
            // Error stream of the job or cmdlet.
            Err,
            // Appended to a List<string> buffer.
            StrLst,
            // Appended to a List<object> buffer.
            ObjLst,
        }

        /// <summary>
        /// How RedirectList emits a buffer list.
        /// </summary>
        private enum FinalTarget
        {
            // No buffer to flush.
            Nop,
            // String buffer emitted to the output stream.
            OutStrLst,
            // Object buffer emitted to the output stream.
            OutObjLst,
            // String buffer emitted to the error stream.
            ErrStrLst,
            // Object buffer emitted to the error stream.
            ErrObjLst,
        }

        #endregion

        #region Helpers

        private static ErrorRecord MakeError(string message)
        {
            return new ErrorRecord(new StdErr(message), null, ErrorCategory.FromStdErr, null);
        }

        private static ErrorRecord MakeError(WrapObject wrapped)
        {
            return new ErrorRecord(new StdErr(wrapped), null, ErrorCategory.FromStdErr, null);
        }

        private static ErrorRecord MakeError(object messages)
        {
            return new ErrorRecord(new StdErr(messages), null, ErrorCategory.FromStdErr, null);
        }

        private class StdErr : Exception
        {
            private static readonly string[] empty = Array.Empty<string>();
            private readonly string[] err;

            public StdErr(object lst) : this(ConvertList(lst))
            {
            }

            public StdErr(string[] err) : base(ToString(err))
            {
                this.err = err;
            }

            public StdErr(WrapObject wrapped) : base(wrapped.Message)
            {
                this.err = empty;
            }

            public StdErr(string message) : base(message)
            {
                this.err = empty;
            }

            public IList<string> ErrorList => err;

            public static string[] ConvertList(object lst)
            {
                if (lst is string[] s)
                    return s;
                if (lst is object[] o)
                    return Array.ConvertAll(o, i => i.ToString()!);
                return empty;
            }

            public static string ToString(string[] lst)
            {
                var sb = new StringBuilder();
                foreach (var str in lst) sb.Append(str);
                return sb.ToString();
            }

        }

        public struct WrapObject
        {

            public WrapObject(RedirectionStream stream, string message)
            {
                Stream = stream;
                Message = message;
            }

            public RedirectionStream Stream { get; }
            public string Message { get; }


            internal static WrapObject Error(string message)
            {
                return new WrapObject(RedirectionStream.Error, message);
            }

            internal static WrapObject Output(string message)
            {
                return new WrapObject(RedirectionStream.Output, message);
            }

            public override string ToString()
            {
                return $"{Stream}: {Message}";
            }
        }

        private static List<T> NewList<T>(int capacity)
        {
            return capacity == int.MaxValue ? new List<T>(32) : new List<T>(capacity);
        }

        private void SetLastExitCode(Process process)
        {
            SetLastExitCode(process.ExitCode);
        }

        private void SetLastExitCode(int exitCode)
        {
            base.SessionState.PSVariable.Set("LASTEXITCODE", exitCode);
        }

        #endregion

    }

}