SupersedingAsyncGate.cs
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.CmdPal.Core.Common.Helpers;

/// <summary>
/// An async gate that ensures only one operation runs at a time.
/// If ExecuteAsync is called while already executing, it cancels the current execution
/// and starts the operation again (superseding behavior).
/// </summary>
/// <remarks>
/// All mutable state is guarded by a single lock. The action runs on one background loop
/// (started via Task.Run) that keeps re-running it while a newer request is pending and
/// exits as soon as no request is outstanding.
/// </remarks>
public sealed partial class SupersedingAsyncGate : IDisposable
{
    private readonly Func<CancellationToken, Task> _action;
    private readonly Lock _lock = new();

    // Incremented on every ExecuteAsync call; together with _currentTcs it identifies stale completions.
    private int _callId;

    // Set by Dispose; afterwards ExecuteAsync fails with ObjectDisposedException.
    private bool _disposed;

    // Completion source of the newest caller that has not been completed yet (null when idle).
    private TaskCompletionSource<bool>? _currentTcs;

    // Source of the token handed to the action run currently in flight (null between runs).
    // Only ExecuteLoop disposes it, and only after removing it from this field under the lock,
    // so Cancel() performed under the lock never touches a disposed source.
    private CancellationTokenSource? _currentCancellationSource;

    // The active ExecuteLoop task, or null when no loop is running.
    private Task? _executingTask;

    /// <summary>
    /// Creates a gate around <paramref name="action"/>.
    /// </summary>
    /// <param name="action">The operation to run; it receives a token that is canceled when the run is superseded or the gate is disposed.</param>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
    public SupersedingAsyncGate(Func<CancellationToken, Task> action)
    {
        ArgumentNullException.ThrowIfNull(action);
        _action = action;
    }

    /// <summary>
    /// Executes the configured action. If another execution is running, this call will
    /// cancel the current execution and restart the operation.
    /// </summary>
    /// <param name="cancellationToken">
    /// Optional external cancellation token. Canceling it only cancels this caller's wait;
    /// it does not cancel the action run that is already in flight.
    /// </param>
    /// <returns>
    /// A task that completes when the action run serving this call finishes. It ends up canceled
    /// when this call is superseded by a newer one, when <paramref name="cancellationToken"/> is
    /// canceled, or when the action throws <see cref="OperationCanceledException"/>. It faults with
    /// the action's exception, or with <see cref="ObjectDisposedException"/> if the gate is or
    /// becomes disposed.
    /// </returns>
    public async Task ExecuteAsync(CancellationToken cancellationToken = default)
    {
        TaskCompletionSource<bool> tcs;

        lock (_lock)
        {
            ObjectDisposedException.ThrowIf(_disposed, this);

            // Completions happen while _lock is held, so continuations must not run inline there:
            // a continuation that re-entered ExecuteAsync could otherwise have its request wiped out.
            tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

            // Install the new request *before* cancelling: Cancel() may run the superseded action's
            // continuations inline on this thread, and those must already see that they are stale.
            var superseded = _currentTcs;
            _currentTcs = tcs;
            _callId++;

            superseded?.TrySetException(new OperationCanceledException("Superseded by newer call"));
            _currentCancellationSource?.Cancel();

            // Start the background loop unless one is running; a running loop picks up the
            // new request on its next iteration.
            _executingTask ??= Task.Run(ExecuteLoop, CancellationToken.None);
        }

        await using var ctr = cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken));
        await tcs.Task;
    }

    /// <summary>
    /// Background loop: runs the action for the newest pending request, repeating while requests
    /// keep arriving, and exits once none is pending.
    /// </summary>
    private async Task ExecuteLoop()
    {
        while (true)
        {
            TaskCompletionSource<bool> currentTcs;
            CancellationTokenSource currentCts;
            int currentCallId;

            lock (_lock)
            {
                if (_currentTcs is null)
                {
                    // Deciding to exit and clearing _executingTask happen under the same lock
                    // acquisition: an ExecuteAsync call either ran before this check (its request
                    // is non-null here) or runs after it and starts a fresh loop.
                    _executingTask = null;
                    return;
                }

                currentTcs = _currentTcs;
                currentCallId = _callId;
                currentCts = new CancellationTokenSource();
                _currentCancellationSource = currentCts;
            }

            var token = currentCts.Token;
            try
            {
                await _action(token);
                CompleteIfCurrent(currentTcs, currentCallId, static t => t.TrySetResult(true));
            }
            catch (OperationCanceledException)
            {
                CompleteIfCurrent(currentTcs, currentCallId, t => t.TrySetCanceled(token));
            }
            catch (Exception ex)
            {
                CompleteIfCurrent(currentTcs, currentCallId, t => t.TrySetException(ex));
            }
            finally
            {
                // Unpublish the source under the lock first, then dispose it, so no concurrent
                // ExecuteAsync/Dispose can call Cancel() on an already-disposed source.
                lock (_lock)
                {
                    if (ReferenceEquals(_currentCancellationSource, currentCts))
                    {
                        _currentCancellationSource = null;
                    }
                }

                currentCts.Dispose();
            }
        }
    }

    /// <summary>
    /// Completes <paramref name="candidate"/> via <paramref name="complete"/> only if it is still the
    /// newest request (no ExecuteAsync or Dispose happened since its run started), and clears the
    /// pending request so the loop can exit.
    /// </summary>
    private void CompleteIfCurrent(
        TaskCompletionSource<bool> candidate,
        int id,
        Action<TaskCompletionSource<bool>> complete)
    {
        lock (_lock)
        {
            if (!ReferenceEquals(_currentTcs, candidate) || _callId != id)
            {
                // Superseded or disposed: that caller was already completed by ExecuteAsync/Dispose.
                return;
            }

            // Clear the slot before completing so nothing triggered by the completion can
            // observe (or be overwritten by) a stale pending request.
            _currentTcs = null;
            complete(candidate);
        }
    }

    /// <summary>
    /// Cancels the action run in flight (if any), fails the pending caller with
    /// <see cref="ObjectDisposedException"/>, and makes later ExecuteAsync calls fail the same way.
    /// Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        lock (_lock)
        {
            _disposed = true;

            // Only cancel here; ExecuteLoop disposes the source once the action has returned,
            // so the action never works with a disposed CancellationTokenSource.
            _currentCancellationSource?.Cancel();
            _currentTcs?.TrySetException(new ObjectDisposedException(nameof(SupersedingAsyncGate)));
            _currentTcs = null;
        }

        GC.SuppressFinalize(this);
    }
}