/ actor / future_test.go
future_test.go
  1  package actor
  2  
  3  import (
  4  	"context"
  5  	"fmt"
  6  	"sync"
  7  	"sync/atomic"
  8  	"testing"
  9  	"time"
 10  
 11  	"github.com/lightningnetwork/lnd/fn/v2"
 12  	"github.com/stretchr/testify/require"
 13  	"pgregory.net/rapid"
 14  )
 15  
 16  // TestFutureAwaitContextCancellation tests that Await respects context
 17  // cancellation if the context is cancelled before the future resolves.
 18  func TestFutureAwaitContextCancellation(t *testing.T) {
 19  	t.Parallel()
 20  
 21  	rapid.Check(t, func(t *rapid.T) {
 22  		// Test cancellation when the Await context is cancelled via
 23  		// context.Cancel. The underlying future will not be completed, allowing
 24  		// us to test the cancellation path of Await.
 25  		prom1 := NewPromise[int]()
 26  		fut1 := prom1.Future()
 27  		ctx1, cancel1 := context.WithCancel(context.Background())
 28  
 29  		// We'll cancel the future immediately after creating it.
 30  		cancel1()
 31  
 32  		result1 := fut1.Await(ctx1)
 33  
 34  		require.True(t, result1.IsErr())
 35  		require.ErrorIs(
 36  			t, result1.Err(), context.Canceled,
 37  			"await with immediate cancel",
 38  		)
 39  
 40  		// Test cancellation when the Await context times out. The
 41  		// underlying future will also not be completed.
 42  		prom2 := NewPromise[int]()
 43  		fut2 := prom2.Future()
 44  
 45  		// Use a very short timeout that will trigger.
 46  		ctx2, cancel2 := context.WithTimeout(
 47  			context.Background(), 1*time.Nanosecond,
 48  		)
 49  		defer cancel2()
 50  
 51  		// Await the future; it should fall through to the timeout
 52  		// because the future itself is not completed.
 53  		result2 := fut2.Await(ctx2)
 54  
 55  		require.True(t, result2.IsErr())
 56  		require.ErrorIs(
 57  			t, result2.Err(), context.DeadlineExceeded,
 58  			"await with timeout",
 59  		)
 60  	})
 61  }
 62  
 63  // TestFutureAwaitFutureCompletes tests that Await returns the future's
 64  // result if the context is not cancelled before the future resolves.
 65  func TestFutureAwaitFutureCompletes(t *testing.T) {
 66  	t.Parallel()
 67  
 68  	rapid.Check(t, func(t *rapid.T) {
 69  		valToSet := rapid.Int().Draw(t, "valToSet")
 70  
 71  		// With a 50% chance, configure the test to complete the future
 72  		// with an error instead of a successful value.
 73  		var errToSet error
 74  		if rapid.Bool().Draw(t, "have_error") {
 75  			errToSet = fmt.Errorf("err")
 76  		}
 77  
 78  		promise := NewPromise[int]()
 79  		fut := promise.Future()
 80  
 81  		// Use a background context for Await, as we expect the future
 82  		// to complete normally.
 83  		ctx := context.Background()
 84  
 85  		// Complete the future in a separate goroutine to simulate an
 86  		// asynchronous operation.
 87  		go func() {
 88  			if errToSet != nil {
 89  				promise.Complete(fn.Err[int](errToSet))
 90  			} else {
 91  				promise.Complete(fn.Ok(valToSet))
 92  			}
 93  		}()
 94  
 95  		// Now we'll wait for the future to complete, then verify below
 96  		// that the result (value or error) is as expected.
 97  		result := fut.Await(ctx)
 98  
 99  		if errToSet != nil {
100  			// If an error was set, verify that Await returns that
101  			// specific error.
102  			require.True(t, result.IsErr())
103  			require.ErrorIs(
104  				t, result.Err(), errToSet,
105  				"await with error",
106  			)
107  		} else {
108  			// If no error was set, verify that Await returns the
109  			// correct value.
110  			require.False(t, result.IsErr(), "await with value")
111  
112  			result.WhenOk(func(val int) {
113  				require.Equal(
114  					t, valToSet, val, "await with value",
115  				)
116  			})
117  		}
118  	})
119  }
120  
121  // TestFutureThenApplyContextCancellation tests that ThenApply respects its
122  // context, yielding a context error if cancelled before the original future
123  // completes.
124  func TestFutureThenApplyContextCancellation(t *testing.T) {
125  	t.Parallel()
126  
127  	rapid.Check(t, func(t *rapid.T) {
128  		// The original future will not be completed in this test case,
129  		// allowing us to specifically test the cancellation behavior of
130  		// the context passed to ThenApply.
131  		originalPromise := NewPromise[int]()
132  		originalFut := originalPromise.Future()
133  
134  		// Create a context for ThenApply and cancel it immediately.
135  		ctxApply, cancelApply := context.WithCancel(
136  			context.Background(),
137  		)
138  		cancelApply()
139  
140  		var transformCalled atomic.Bool
141  		transform := func(i int) int {
142  			transformCalled.Store(true)
143  			return i * 2
144  		}
145  
146  		// Register the transformation. The ThenApply operation itself
147  		// will start a goroutine to await the originalFut.
148  		newFut := originalFut.ThenApply(ctxApply, transform)
149  
150  		// Await the new (transformed) future. Use a background context
151  		// for this Await to isolate the test to the cancellation of
152  		// ctxApply.
153  		result := newFut.Await(context.Background())
154  
155  		require.True(t, result.IsErr())
156  		require.ErrorIs(
157  			t, result.Err(), context.Canceled,
158  			"ThenApply with cancelled context",
159  		)
160  		require.False(
161  			t, transformCalled.Load(),
162  			"ThenApply transform function called despite "+
163  				"context cancellation",
164  		)
165  	})
166  }
167  
168  // TestFutureThenApplyOriginalFutureCompletes tests ThenApply's behavior when
169  // the original future completes (with a value or error) before ThenApply's
170  // context is cancelled.
171  func TestFutureThenApplyOriginalFutureCompletes(t *testing.T) {
172  	t.Parallel()
173  
174  	rapid.Check(t, func(t *rapid.T) {
175  		initialVal := rapid.Int().Draw(t, "initialVal")
176  
177  		// Configure whether the original future completes with an error
178  		// or a successful value.
179  		var originalErr error
180  		if rapid.Bool().Draw(t, "have_error") {
181  			originalErr = fmt.Errorf("original error")
182  		}
183  
184  		originalPromise := NewPromise[int]()
185  		originalFut := originalPromise.Future()
186  
187  		// Create a context for ThenApply that should not cancel before
188  		// the original future completes.
189  		ctxApply, cancelApply := context.WithTimeout(
190  			context.Background(), 50*time.Millisecond,
191  		)
192  		defer cancelApply()
193  
194  		var transformCalled atomic.Bool
195  		transform := func(i int) int {
196  			transformCalled.Store(true)
197  			return i * 2
198  		}
199  
200  		newFut := originalFut.ThenApply(ctxApply, transform)
201  
202  		// Complete the original future in a separate goroutine to
203  		// simulate asynchrony.
204  		go func() {
205  			if originalErr != nil {
206  				originalPromise.Complete(
207  					fn.Err[int](originalErr),
208  				)
209  			} else {
210  				originalPromise.Complete(fn.Ok(initialVal))
211  			}
212  		}()
213  
214  		// Await our new future which transforms the original future's
215  		// result. Use a background context for this Await.
216  		result := newFut.Await(context.Background())
217  
218  		if originalErr != nil {
219  			// If the original future had an error, the transformed
220  			// future should also yield that same error.
221  			require.True(t, result.IsErr())
222  			require.ErrorIs(
223  				t, result.Err(), originalErr,
224  				"ThenApply with original error",
225  			)
226  			require.False(
227  				t, transformCalled.Load(),
228  				"ThenApply transform function called despite "+
229  					"original future having an error",
230  			)
231  		} else {
232  			// If the original future completed successfully, the
233  			// transformed future should contain the transformed value.
234  			require.False(
235  				t, result.IsErr(),
236  				"ThenApply with original value",
237  			)
238  			require.True(
239  				t, transformCalled.Load(),
240  				"ThenApply transform function not called for "+
241  					"successful original future",
242  			)
243  
244  			result.WhenOk(func(val int) {
245  				expectedTransformedVal := initialVal * 2
246  				require.Equal(
247  					t, expectedTransformedVal, val,
248  					"ThenApply with original value",
249  				)
250  			})
251  		}
252  	})
253  }
254  
255  // TestFutureOnCompleteContextCancellation tests that OnComplete's callback
256  // receives a context error if its context is cancelled before the future
257  // completes.
258  func TestFutureOnCompleteContextCancellation(t *testing.T) {
259  	t.Parallel()
260  
261  	rapid.Check(t, func(t *rapid.T) {
262  		// The original future will not complete in this test, allowing
263  		// us to focus on the cancellation of OnComplete's context.
264  		originalPromise := NewPromise[int]()
265  		originalFut := originalPromise.Future()
266  
267  		// Create a context for OnComplete and cancel it immediately to
268  		// simulate a premature cancellation.
269  		ctxComplete, cancelComplete := context.WithCancel(
270  			context.Background(),
271  		)
272  		cancelComplete()
273  
274  		var wg sync.WaitGroup
275  		wg.Add(1)
276  		var (
277  			callbackInvoked     atomic.Bool
278  			callbackResultValue fn.Result[int]
279  
280  			// mu is a mutex to protect callbackResultValue as it's
281  			// written by the callback goroutine and read by the
282  			// test goroutine.
283  			mu sync.Mutex
284  		)
285  
286  		// Register an OnComplete callback. The callback itself runs in
287  		// a new goroutine started by OnComplete.
288  		originalFut.OnComplete(ctxComplete, func(res fn.Result[int]) {
289  			mu.Lock()
290  			callbackResultValue = res
291  			mu.Unlock()
292  
293  			callbackInvoked.Store(true)
294  			wg.Done()
295  		})
296  
297  		// Use a wait group and a channel to wait for the callback to
298  		// be invoked.
299  		waitChan := make(chan struct{})
300  		go func() {
301  			wg.Wait()
302  			close(waitChan)
303  		}()
304  
305  		select {
306  		// The callback should be invoked, even if with a context error.
307  		case <-waitChan:
308  		case <-time.After(50 * time.Millisecond):
309  			require.Fail(
310  				t, "OnComplete callback timed out waiting "+
311  					"for execution after context cancel",
312  			)
313  		}
314  
315  		require.True(
316  			t, callbackInvoked.Load(),
317  			"OnComplete callback not invoked",
318  		)
319  
320  		mu.Lock()
321  		defer mu.Unlock()
322  
323  		// Verify that the callback received a context.Canceled error
324  		// because its context (ctxComplete) was cancelled.
325  		require.True(t, callbackResultValue.IsErr())
326  		require.ErrorIs(
327  			t, callbackResultValue.Err(), context.Canceled,
328  			"OnComplete with cancelled context",
329  		)
330  	})
331  }
332  
333  // TestFutureOnCompleteFutureCompletes tests OnComplete's behavior when the
334  // future completes (with value or error) before its context is cancelled.
335  func TestFutureOnCompleteFutureCompletes(t *testing.T) {
336  	t.Parallel()
337  
338  	rapid.Check(t, func(t *rapid.T) {
339  		valToSet := rapid.Int().Draw(t, "valToSet")
340  
341  		// Configure whether the original future completes with an error
342  		// or a successful value.
343  		var originalErr error
344  		if rapid.Bool().Draw(t, "have_error") {
345  			originalErr = fmt.Errorf("original error")
346  		}
347  
348  		originalPromise := NewPromise[int]()
349  		originalFut := originalPromise.Future()
350  
351  		// Use a background context for OnComplete, as we expect the
352  		// future to complete normally.
353  		ctxComplete := context.Background()
354  
355  		var wg sync.WaitGroup
356  		wg.Add(1)
357  
358  		var (
359  			callbackInvoked     atomic.Bool
360  			callbackResultValue fn.Result[int]
361  			mu                  sync.Mutex
362  		)
363  
364  		// Register an OnComplete callback. This callback will execute
365  		// once the originalFut completes.
366  		originalFut.OnComplete(ctxComplete, func(res fn.Result[int]) {
367  			mu.Lock()
368  			callbackResultValue = res
369  			mu.Unlock()
370  
371  			callbackInvoked.Store(true)
372  
373  			wg.Done()
374  		})
375  
376  		// Complete the original future in a separate goroutine to
377  		// simulate an asynchronous operation.
378  		go func() {
379  			if originalErr != nil {
380  				originalPromise.Complete(
381  					fn.Err[int](originalErr),
382  				)
383  			} else {
384  				originalPromise.Complete(fn.Ok(valToSet))
385  			}
386  		}()
387  
388  		// Use a wait group and a channel to wait for the callback's
389  		// execution.
390  		waitChan := make(chan struct{})
391  		go func() {
392  			wg.Wait()
393  			close(waitChan)
394  		}()
395  
396  		select {
397  		// The callback should be invoked as the future completes.
398  		case <-waitChan:
399  		case <-time.After(50 * time.Millisecond):
400  			require.Fail(
401  				t, "OnComplete callback timed out waiting "+
402  					"for execution",
403  			)
404  		}
405  
406  		require.True(t, callbackInvoked.Load())
407  
408  		mu.Lock()
409  		defer mu.Unlock()
410  
411  		// Verify that the callback received the correct result (either
412  		// the error or the value from the completed future).
413  		if originalErr != nil {
414  			require.True(t, callbackResultValue.IsErr())
415  			require.ErrorIs(
416  				t, callbackResultValue.Err(), originalErr,
417  				"OnComplete with error",
418  			)
419  		} else {
420  			require.False(
421  				t, callbackResultValue.IsErr(),
422  				"OnComplete with value",
423  			)
424  			callbackResultValue.WhenOk(func(val int) {
425  				require.Equal(
426  					t, valToSet, val,
427  					"OnComplete with value",
428  				)
429  			})
430  		}
431  	})
432  }
433  
434  // TestPromiseCompleteIdempotency verifies that calling Complete on a Promise
435  // multiple times is safe and only the first completion takes effect. Subsequent
436  // calls should return false and not alter the future's result.
437  func TestPromiseCompleteIdempotency(t *testing.T) {
438  	t.Parallel()
439  
440  	promise := NewPromise[string]()
441  	future := promise.Future()
442  
443  	// First completion should succeed.
444  	firstResult := fn.Ok("first-value")
445  	ok := promise.Complete(firstResult)
446  	require.True(t, ok, "first Complete should return true")
447  
448  	// Second completion with a different value should be ignored.
449  	secondResult := fn.Ok("second-value")
450  	ok = promise.Complete(secondResult)
451  	require.False(t, ok, "second Complete should return false")
452  
453  	// Third completion with an error should also be ignored.
454  	thirdResult := fn.Err[string](fmt.Errorf("should be ignored"))
455  	ok = promise.Complete(thirdResult)
456  	require.False(t, ok, "third Complete should return false")
457  
458  	// The future should contain the first value.
459  	result := future.Await(context.Background())
460  	require.False(t, result.IsErr(), "future should not be an error")
461  	result.WhenOk(func(val string) {
462  		require.Equal(
463  			t, "first-value", val,
464  			"future should contain the first completion value",
465  		)
466  	})
467  }