/ rpcperms / middleware_handler_test.go
middleware_handler_test.go
 1  package rpcperms
 2  
 3  import (
 4  	"encoding/json"
 5  	"testing"
 6  
 7  	"github.com/lightningnetwork/lnd/lnrpc"
 8  	"github.com/stretchr/testify/require"
 9  )
10  
11  // TestReplaceProtoMsg makes sure the proto message replacement works as
12  // expected.
13  func TestReplaceProtoMsg(t *testing.T) {
14  	testCases := []struct {
15  		name        string
16  		original    interface{}
17  		replacement interface{}
18  		expectedErr string
19  	}{{
20  		name: "simple content replacement",
21  		original: &lnrpc.Invoice{
22  			Memo:  "This is a memo string",
23  			Value: 123456,
24  		},
25  		replacement: &lnrpc.Invoice{
26  			Memo:  "This is the replaced string",
27  			Value: 654321,
28  		},
29  	}, {
30  		name: "replace with empty message",
31  		original: &lnrpc.Invoice{
32  			Memo:  "This is a memo string",
33  			Value: 123456,
34  		},
35  		replacement: &lnrpc.Invoice{},
36  	}, {
37  		name: "replace with fewer fields",
38  		original: &lnrpc.Invoice{
39  			Memo:  "This is a memo string",
40  			Value: 123456,
41  		},
42  		replacement: &lnrpc.Invoice{
43  			Value: 654321,
44  		},
45  	}, {
46  		name: "wrong replacement type",
47  		original: &lnrpc.Invoice{
48  			Memo:  "This is a memo string",
49  			Value: 123456,
50  		},
51  		replacement: &lnrpc.AddInvoiceResponse{},
52  		expectedErr: "replacement message is of wrong type",
53  	}, {
54  		name:     "wrong original type",
55  		original: &interceptRequest{},
56  		replacement: &lnrpc.Invoice{
57  			Memo:  "This is the replaced string",
58  			Value: 654321,
59  		},
60  		expectedErr: "target is not a proto message",
61  	}}
62  
63  	for _, tc := range testCases {
64  		t.Run(tc.name, func(tt *testing.T) {
65  			err := replaceProtoMsg(tc.original, tc.replacement)
66  
67  			if tc.expectedErr != "" {
68  				require.Error(tt, err)
69  				require.Contains(
70  					tt, err.Error(), tc.expectedErr,
71  				)
72  
73  				return
74  			}
75  
76  			require.NoError(tt, err)
77  			jsonEqual(tt, tc.replacement, tc.original)
78  		})
79  	}
80  }
81  
82  func jsonEqual(t *testing.T, expected, actual interface{}) {
83  	expectedJSON, err := json.Marshal(expected)
84  	require.NoError(t, err)
85  
86  	actualJSON, err := json.Marshal(actual)
87  	require.NoError(t, err)
88  
89  	require.JSONEq(t, string(expectedJSON), string(actualJSON))
90  }