Cmdlets/src/XpandPosh.Cmdlets/AsyncCmdlet.cs
using System;
using System.Collections.Concurrent; using System.Collections.Generic; using System.Management.Automation; using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; namespace XpandPosh.CmdLets{ /// <summary> /// Base class for Cmdlets that run asynchronously. /// </summary> /// <remarks> /// Inherit from this class if your Cmdlet needs to use <c>async</c> / <c>await</c> functionality. /// </remarks> public abstract class AsyncCmdlet : PSCmdlet, IDisposable{ /// <summary> /// The source for cancellation tokens that can be used to cancel the operation. /// </summary> private readonly CancellationTokenSource _cancellationSource = new CancellationTokenSource(); /// <summary> /// Dispose of resources being used by the Cmdlet. /// </summary> public void Dispose(){ Dispose(true); GC.SuppressFinalize(this); } /// <summary> /// Finaliser for <see cref="AsyncCmdlet" />. /// </summary> ~AsyncCmdlet(){ Dispose(false); } /// <summary> /// Dispose of resources being used by the Cmdlet. /// </summary> /// <param name="disposing"> /// Explicit disposal? /// </param> protected virtual void Dispose(bool disposing){ if (disposing) _cancellationSource.Dispose(); } /// <summary> /// Asynchronously perform Cmdlet pre-processing. /// </summary> /// <returns> /// A <see cref="Task" /> representing the asynchronous operation. /// </returns> protected virtual Task BeginProcessingAsync(){ return BeginProcessingAsync(_cancellationSource.Token); } /// <summary> /// Asynchronously perform Cmdlet pre-processing. /// </summary> /// <param name="cancellationToken"> /// A <see cref="CancellationToken" /> that can be used to cancel the asynchronous operation. /// </param> /// <returns> /// A <see cref="Task" /> representing the asynchronous operation. /// </returns> protected virtual Task BeginProcessingAsync(CancellationToken cancellationToken){ return Task.CompletedTask; } /// <summary> /// Asynchronously perform Cmdlet processing. /// </summary> /// <returns> /// A <see cref="Task" /> representing the asynchronous operation. /// </returns> protected virtual Task ProcessRecordAsync(){ return ProcessRecordAsync(_cancellationSource.Token); } /// <summary> /// Asynchronously perform Cmdlet processing. /// </summary> /// <param name="cancellationToken"> /// A <see cref="CancellationToken" /> that can be used to cancel the asynchronous operation. /// </param> /// <returns> /// A <see cref="Task" /> representing the asynchronous operation. /// </returns> protected virtual Task ProcessRecordAsync(CancellationToken cancellationToken){ return Task.CompletedTask; } /// <summary> /// Asynchronously perform Cmdlet post-processing. /// </summary> /// <returns> /// A <see cref="Task" /> representing the asynchronous operation. /// </returns> protected virtual Task EndProcessingAsync(){ return EndProcessingAsync(_cancellationSource.Token); } /// <summary> /// Asynchronously perform Cmdlet post-processing. /// </summary> /// <param name="cancellationToken"> /// A <see cref="CancellationToken" /> that can be used to cancel the asynchronous operation. /// </param> /// <returns> /// A <see cref="Task" /> representing the asynchronous operation. /// </returns> protected virtual Task EndProcessingAsync(CancellationToken cancellationToken){ return Task.CompletedTask; } /// <summary> /// Perform Cmdlet pre-processing. /// </summary> protected sealed override void BeginProcessing(){ ThreadAffinitiveSynchronizationContext.RunSynchronized(BeginProcessingAsync); } /// <summary> /// Perform Cmdlet processing. /// </summary> protected sealed override void ProcessRecord(){ ThreadAffinitiveSynchronizationContext.RunSynchronized(ProcessRecordAsync); } /// <summary> /// Perform Cmdlet post-processing. /// </summary> protected sealed override void EndProcessing(){ ThreadAffinitiveSynchronizationContext.RunSynchronized(EndProcessingAsync); } /// <summary> /// Interrupt Cmdlet processing (if possible). /// </summary> protected sealed override void StopProcessing(){ _cancellationSource.Cancel(); base.StopProcessing(); } /// <summary> /// Write a progress record to the output stream, and as a verbose message. /// </summary> /// <param name="progressRecord"> /// The progress record to write. /// </param> protected void WriteVerboseProgress(ProgressRecord progressRecord){ if (progressRecord == null) throw new ArgumentNullException(nameof(progressRecord)); WriteProgress(progressRecord); WriteVerbose(progressRecord.StatusDescription); } /// <summary> /// Write a progress record to the output stream, and as a verbose message. /// </summary> /// <param name="progressRecord"> /// The progress record to write. /// </param> /// <param name="messageOrFormat"> /// The message or message-format specifier. /// </param> /// <param name="formatArguments"> /// Optional format arguments. /// </param> protected void WriteVerboseProgress(ProgressRecord progressRecord, string messageOrFormat, params object[] formatArguments){ if (progressRecord == null) throw new ArgumentNullException(nameof(progressRecord)); if (string.IsNullOrWhiteSpace(messageOrFormat)) throw new ArgumentException( "Argument cannot be null, empty, or composed entirely of whitespace: 'messageOrFormat'.", nameof(messageOrFormat)); if (formatArguments == null) throw new ArgumentNullException(nameof(formatArguments)); progressRecord.StatusDescription = string.Format(messageOrFormat, formatArguments); WriteVerboseProgress(progressRecord); } /// <summary> /// Write a completed progress record to the output stream. /// </summary> /// <param name="progressRecord"> /// The progress record to complete. /// </param> /// <param name="completionMessageOrFormat"> /// The completion message or message-format specifier. /// </param> /// <param name="formatArguments"> /// Optional format arguments. /// </param> protected void WriteProgressCompletion(ProgressRecord progressRecord, string completionMessageOrFormat, params object[] formatArguments){ if (progressRecord == null) throw new ArgumentNullException(nameof(progressRecord)); if (string.IsNullOrWhiteSpace(completionMessageOrFormat)) throw new ArgumentException( "Argument cannot be null, empty, or composed entirely of whitespace: 'completionMessageOrFormat'.", nameof(completionMessageOrFormat)); if (formatArguments == null) throw new ArgumentNullException(nameof(formatArguments)); progressRecord.StatusDescription = string.Format(completionMessageOrFormat, formatArguments); progressRecord.PercentComplete = 100; progressRecord.RecordType = ProgressRecordType.Completed; WriteProgress(progressRecord); WriteVerbose(progressRecord.StatusDescription); } } /// <summary> /// A synchronisation context that runs all calls scheduled on it (via <see cref="SynchronizationContext.Post" />) on a /// single thread. /// </summary> /// <remarks> /// With thanks to Stephen Toub. /// </remarks> public sealed class ThreadAffinitiveSynchronizationContext : SynchronizationContext, IDisposable{ /// <summary> /// A blocking collection (effectively a queue) of work items to execute, consisting of callback delegates and their /// callback state (if any). /// </summary> private BlockingCollection<KeyValuePair<SendOrPostCallback, object>> _workItemQueue = new BlockingCollection<KeyValuePair<SendOrPostCallback, object>>(); /// <summary> /// Create a new thread-affinitive synchronisation context. /// </summary> private ThreadAffinitiveSynchronizationContext(){ } /// <summary> /// Dispose of resources being used by the synchronisation context. /// </summary> void IDisposable.Dispose(){ if (_workItemQueue != null){ _workItemQueue.Dispose(); _workItemQueue = null; } } /// <summary> /// Check if the synchronisation context has been disposed. /// </summary> private void CheckDisposed(){ if (_workItemQueue == null) throw new ObjectDisposedException(GetType().Name); } /// <summary> /// Run the message pump for the callback queue on the current thread. /// </summary> private void RunMessagePump(){ CheckDisposed(); while (_workItemQueue.TryTake(out var workItem, Timeout.InfiniteTimeSpan)){ workItem.Key(workItem.Value); // Has the synchronisation context been disposed? if (_workItemQueue == null) break; } } /// <summary> /// Terminate the message pump once all callbacks have completed. /// </summary> private void TerminateMessagePump(){ CheckDisposed(); _workItemQueue.CompleteAdding(); } /// <summary> /// Dispatch an asynchronous message to the synchronization context. /// </summary> /// <param name="callback"> /// The <see cref="SendOrPostCallback" /> delegate to call in the synchronisation context. /// </param> /// <param name="callbackState"> /// Optional state data passed to the callback. /// </param> /// <exception cref="InvalidOperationException"> /// The message pump has already been started, and then terminated by calling <see cref="TerminateMessagePump" />. /// </exception> public override void Post(SendOrPostCallback callback, object callbackState){ if (callback == null) throw new ArgumentNullException(nameof(callback)); CheckDisposed(); try{ _workItemQueue.Add( new KeyValuePair<SendOrPostCallback, object>( callback, callbackState ) ); } catch (InvalidOperationException eMessagePumpAlreadyTerminated){ throw new InvalidOperationException( "Cannot enqueue the specified callback because the synchronisation context's message pump has already been terminated.", eMessagePumpAlreadyTerminated ); } } /// <summary> /// Run an asynchronous operation using the current thread as its synchronisation context. /// </summary> /// <param name="asyncOperation"> /// A <see cref="Func{TResult}" /> delegate representing the asynchronous operation to run. /// </param> public static void RunSynchronized(Func<Task> asyncOperation){ if (asyncOperation == null) throw new ArgumentNullException(nameof(asyncOperation)); var savedContext = Current; try{ using (var synchronizationContext = new ThreadAffinitiveSynchronizationContext()){ SetSynchronizationContext(synchronizationContext); var rootOperationTask = asyncOperation(); if (rootOperationTask == null) throw new InvalidOperationException("The asynchronous operation delegate cannot return null."); rootOperationTask.ContinueWith( operationTask => // ReSharper disable once AccessToDisposedClosure synchronizationContext.TerminateMessagePump(), TaskScheduler.Default ); synchronizationContext.RunMessagePump(); try{ rootOperationTask .GetAwaiter() .GetResult(); } catch (AggregateException eWaitForTask ) // The TPL will almost always wrap an AggregateException around any exception thrown by the async operation. { // Is this just a wrapped exception? var flattenedAggregate = eWaitForTask.Flatten(); if (flattenedAggregate.InnerExceptions.Count != 1) throw; // Nope, genuine aggregate. // Yep, so rethrow (preserving original stack-trace). ExceptionDispatchInfo.Capture(flattenedAggregate.InnerExceptions[0]).Throw(); } } } finally{ SetSynchronizationContext(savedContext); } } /// <summary> /// Run an asynchronous operation using the current thread as its synchronisation context. /// </summary> /// <typeparam name="TResult"> /// The operation result type. /// </typeparam> /// <param name="asyncOperation"> /// A <see cref="Func{TResult}" /> delegate representing the asynchronous operation to run. /// </param> /// <returns> /// The operation result. /// </returns> public static TResult RunSynchronized<TResult>(Func<Task<TResult>> asyncOperation){ if (asyncOperation == null) throw new ArgumentNullException(nameof(asyncOperation)); var savedContext = Current; try{ using (var synchronizationContext = new ThreadAffinitiveSynchronizationContext()){ SetSynchronizationContext(synchronizationContext); var rootOperationTask = asyncOperation(); if (rootOperationTask == null) throw new InvalidOperationException("The asynchronous operation delegate cannot return null."); rootOperationTask.ContinueWith( operationTask => // ReSharper disable once AccessToDisposedClosure synchronizationContext.TerminateMessagePump(), TaskScheduler.Default ); synchronizationContext.RunMessagePump(); try{ return rootOperationTask .GetAwaiter() .GetResult(); } catch (AggregateException eWaitForTask ) // The TPL will almost always wrap an AggregateException around any exception thrown by the async operation. { // Is this just a wrapped exception? var flattenedAggregate = eWaitForTask.Flatten(); if (flattenedAggregate.InnerExceptions.Count != 1) throw; // Nope, genuine aggregate. // Yep, so rethrow (preserving original stack-trace). ExceptionDispatchInfo .Capture( flattenedAggregate .InnerExceptions[0] ) .Throw(); throw; // Never reached. } } } finally{ SetSynchronizationContext(savedContext); } } } } |