/ src / modules / cmdpal / Core / Microsoft.CmdPal.Core.Common / Helpers / SupersedingAsyncGate.cs
SupersedingAsyncGate.cs
  1  // Copyright (c) Microsoft Corporation
  2  // The Microsoft Corporation licenses this file to you under the MIT license.
  3  // See the LICENSE file in the project root for more information.
  4  
  5  using System;
  6  using System.Threading;
  7  using System.Threading.Tasks;
  8  
  9  namespace Microsoft.CmdPal.Core.Common.Helpers;
 10  
 11  /// <summary>
 12  /// An async gate that ensures only one operation runs at a time.
 13  /// If ExecuteAsync is called while already executing, it cancels the current execution
 14  /// and starts the operation again (superseding behavior).
 15  /// </summary>
 16  public sealed partial class SupersedingAsyncGate : IDisposable
 17  {
 18      private readonly Func<CancellationToken, Task> _action;
 19      private readonly Lock _lock = new();
 20      private int _callId;
 21      private TaskCompletionSource<bool>? _currentTcs;
 22      private CancellationTokenSource? _currentCancellationSource;
 23      private Task? _executingTask;
 24  
 25      public SupersedingAsyncGate(Func<CancellationToken, Task> action)
 26      {
 27          ArgumentNullException.ThrowIfNull(action);
 28          _action = action;
 29      }
 30  
 31      /// <summary>
 32      /// Executes the configured action. If another execution is running, this call will
 33      /// cancel the current execution and restart the operation.
 34      /// </summary>
 35      /// <param name="cancellationToken">Optional external cancellation token</param>
 36      public async Task ExecuteAsync(CancellationToken cancellationToken = default)
 37      {
 38          TaskCompletionSource<bool> tcs;
 39  
 40          lock (_lock)
 41          {
 42              _currentCancellationSource?.Cancel();
 43              _currentTcs?.TrySetException(new OperationCanceledException("Superseded by newer call"));
 44  
 45              tcs = new();
 46              _currentTcs = tcs;
 47              _callId++;
 48  
 49              var shouldStartExecution = _executingTask is null;
 50              if (shouldStartExecution)
 51              {
 52                  _executingTask = Task.Run(ExecuteLoop, CancellationToken.None);
 53              }
 54          }
 55  
 56          await using var ctr = cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken));
 57          await tcs.Task;
 58      }
 59  
 60      private async Task ExecuteLoop()
 61      {
 62          try
 63          {
 64              while (true)
 65              {
 66                  TaskCompletionSource<bool>? currentTcs;
 67                  CancellationTokenSource? currentCts;
 68                  int currentCallId;
 69  
 70                  lock (_lock)
 71                  {
 72                      currentTcs = _currentTcs;
 73                      currentCallId = _callId;
 74  
 75                      if (currentTcs is null)
 76                      {
 77                          break;
 78                      }
 79  
 80                      _currentCancellationSource?.Dispose();
 81                      _currentCancellationSource = new();
 82                      currentCts = _currentCancellationSource;
 83                  }
 84  
 85                  try
 86                  {
 87                      await _action(currentCts.Token);
 88                      CompleteIfCurrent(currentTcs, currentCallId, static t => t.TrySetResult(true));
 89                  }
 90                  catch (OperationCanceledException)
 91                  {
 92                      CompleteIfCurrent(currentTcs, currentCallId, tcs => tcs.TrySetCanceled(currentCts.Token));
 93                  }
 94                  catch (Exception ex)
 95                  {
 96                      CompleteIfCurrent(currentTcs, currentCallId, tcs => tcs.TrySetException(ex));
 97                  }
 98              }
 99          }
100          finally
101          {
102              lock (_lock)
103              {
104                  _currentTcs = null;
105                  _currentCancellationSource?.Dispose();
106                  _currentCancellationSource = null;
107                  _executingTask = null;
108              }
109          }
110      }
111  
112      private void CompleteIfCurrent(
113          TaskCompletionSource<bool> candidate,
114          int id,
115          Action<TaskCompletionSource<bool>> complete)
116      {
117          lock (_lock)
118          {
119              if (_currentTcs == candidate && _callId == id)
120              {
121                  complete(candidate);
122                  _currentTcs = null;
123              }
124          }
125      }
126  
127      public void Dispose()
128      {
129          lock (_lock)
130          {
131              _currentCancellationSource?.Cancel();
132              _currentCancellationSource?.Dispose();
133              _currentTcs?.TrySetException(new ObjectDisposedException(nameof(SupersedingAsyncGate)));
134              _currentTcs = null;
135          }
136  
137          GC.SuppressFinalize(this);
138      }
139  }