test/functional/interface_usdt_coinselection.py
interface_usdt_coinselection.py
  1  #!/usr/bin/env python3
  2  # Copyright (c) 2022-present The Bitcoin Core developers
  3  # Distributed under the MIT software license, see the accompanying
  4  # file COPYING or http://www.opensource.org/licenses/mit-license.php.
  5  
  6  """  Tests the coin_selection:* tracepoint API interface.
  7       See https://github.com/bitcoin/bitcoin/blob/master/doc/tracing.md#context-coin_selection
  8  """
  9  
 10  # Test will be skipped if we don't have bcc installed
 11  try:
 12      from bcc import BPF, USDT # type: ignore[import]
 13  except ImportError:
 14      pass
 15  from test_framework.test_framework import BitcoinTestFramework
 16  from test_framework.util import (
 17      assert_equal,
 18      assert_greater_than,
 19      assert_raises_rpc_error,
 20      bpf_cflags,
 21  )
 22  
# eBPF program (C source compiled by bcc) for the four coin_selection:*
# USDT tracepoints enabled in run_test(). Each handler zeroes an event_data
# record, tags it with a `type` code, copies in that tracepoint's arguments
# and pushes the record onto the `coin_selection_events` queue (capacity
# 1024), which the test later pops:
#   type 1: trace_selected_coins   <- coin_selection:selected_coins
#   type 2: trace_normal_create_tx <- coin_selection:normal_create_tx_internal
#   type 3: trace_attempt_aps      <- coin_selection:attempting_aps_create_tx
#   type 4: trace_aps_create_tx    <- coin_selection:aps_create_tx_internal
# Fields a handler does not fill in stay zero because of the memset.
# Wallet and algorithm names are copied with a limit of WALLET_NAME_LENGTH /
# ALGO_NAME_LENGTH (16) bytes, so longer names come through truncated.
coinselection_tracepoints_program = """
#include <uapi/linux/ptrace.h>

#define WALLET_NAME_LENGTH 16
#define ALGO_NAME_LENGTH 16

struct event_data
{
    u8 type;
    char wallet_name[WALLET_NAME_LENGTH];

    // selected coins event
    char algo[ALGO_NAME_LENGTH];
    s64 target;
    s64 waste;
    s64 selected_value;

    // create tx event
    bool success;
    s64 fee;
    s32 change_pos;

    // aps create tx event
    bool use_aps;
};

BPF_QUEUE(coin_selection_events, struct event_data, 1024);

int trace_selected_coins(struct pt_regs *ctx) {
    struct event_data data;
    void *pwallet_name = NULL, *palgo = NULL;
    __builtin_memset(&data, 0, sizeof(data));
    data.type = 1;
    bpf_usdt_readarg(1, ctx, &pwallet_name);
    bpf_probe_read_user_str(&data.wallet_name, WALLET_NAME_LENGTH, pwallet_name);
    bpf_usdt_readarg(2, ctx, &palgo);
    bpf_probe_read_user_str(&data.algo, ALGO_NAME_LENGTH, palgo);
    bpf_usdt_readarg(3, ctx, &data.target);
    bpf_usdt_readarg(4, ctx, &data.waste);
    bpf_usdt_readarg(5, ctx, &data.selected_value);
    coin_selection_events.push(&data, 0);
    return 0;
}

int trace_normal_create_tx(struct pt_regs *ctx) {
    struct event_data data;
    void *pwallet_name = NULL;
    __builtin_memset(&data, 0, sizeof(data));
    data.type = 2;
    bpf_usdt_readarg(1, ctx, &pwallet_name);
    bpf_probe_read_user_str(&data.wallet_name, WALLET_NAME_LENGTH, pwallet_name);
    bpf_usdt_readarg(2, ctx, &data.success);
    bpf_usdt_readarg(3, ctx, &data.fee);
    bpf_usdt_readarg(4, ctx, &data.change_pos);
    coin_selection_events.push(&data, 0);
    return 0;
}

int trace_attempt_aps(struct pt_regs *ctx) {
    struct event_data data;
    void *pwallet_name = NULL;
    __builtin_memset(&data, 0, sizeof(data));
    data.type = 3;
    bpf_usdt_readarg(1, ctx, &pwallet_name);
    bpf_probe_read_user_str(&data.wallet_name, WALLET_NAME_LENGTH, pwallet_name);
    coin_selection_events.push(&data, 0);
    return 0;
}

int trace_aps_create_tx(struct pt_regs *ctx) {
    struct event_data data;
    void *pwallet_name = NULL;
    __builtin_memset(&data, 0, sizeof(data));
    data.type = 4;
    bpf_usdt_readarg(1, ctx, &pwallet_name);
    bpf_probe_read_user_str(&data.wallet_name, WALLET_NAME_LENGTH, pwallet_name);
    bpf_usdt_readarg(2, ctx, &data.use_aps);
    bpf_usdt_readarg(3, ctx, &data.success);
    bpf_usdt_readarg(4, ctx, &data.fee);
    bpf_usdt_readarg(5, ctx, &data.change_pos);
    coin_selection_events.push(&data, 0);
    return 0;
}
"""
107  
108  
class CoinSelectionTracepointTest(BitcoinTestFramework):
    """Test the coin_selection:* USDT tracepoints.

    Attaches ``coinselection_tracepoints_program`` to the node's process,
    creates transactions through the wallet RPCs, and checks the order and
    contents of the events the tracepoints emit.
    """

    def set_test_params(self):
        self.num_nodes = 1
        self.setup_clean_chain = True

    def skip_test_if_missing_module(self):
        # Attaching to USDT tracepoints needs Linux, a bitcoind built with
        # tracepoints, the bcc Python module and permission to load BPF
        # programs; the scenarios below also drive the wallet RPCs.
        self.skip_if_platform_not_linux()
        self.skip_if_no_bitcoind_tracepoints()
        self.skip_if_no_python_bcc()
        self.skip_if_no_bpf_permissions()
        self.skip_if_no_wallet()

    def get_tracepoints(self, expected_types):
        """Pop the queued coin_selection events and check them against expectations.

        Args:
            expected_types: event type codes (1-4, as set by the BPF handlers)
                in the order the events are expected to be popped.

        Returns:
            The popped events, one per entry in ``expected_types``.

        Raises:
            AssertionError: if an event's wallet name or type does not match,
                if fewer events than expected were queued, or if an event is
                still queued after all expected events have been consumed.
        """
        queue = self.bpf["coin_selection_events"]
        events = []
        for expected_type in expected_types:
            try:
                event = queue.pop()
            except KeyError:
                # pop() raises KeyError once the queue is empty: fewer events
                # fired than expected. Reported by the length check below.
                break
            assert_equal(event.wallet_name.decode(), self.default_wallet_name)
            assert_equal(event.type, expected_type)
            events.append(event)
        assert_equal(len(events), len(expected_types))

        # All expected events were consumed, so the queue must now be empty;
        # anything left means more tracepoints fired than expected.
        try:
            extra = queue.pop()
        except KeyError:
            return events
        raise AssertionError(
            f"Unexpected extra coin_selection event of type {extra.type}; "
            f"expected only {len(expected_types)} events"
        )

    def determine_selection_from_usdt(self, events):
        """Summarize one transaction's coin_selection events.

        Args:
            events: events as returned by get_tracepoints(), in emission order.

        Returns:
            A tuple ``(success, use_aps, algo, waste, change_pos)``.

            * ``success`` comes from the normal_create_tx_internal event (type 2).
            * By default, ``algo``/``waste`` come from the first selected_coins
              event (type 1) and ``change_pos`` from the type 2 event.
            * An APS attempt (type 3) may be followed by an
              aps_create_tx_internal event (type 4) whose ``use_aps`` is set.
              In that case ``use_aps`` is True, ``algo``/``waste`` come from
              the second selected_coins event, and ``change_pos`` comes from
              the type 4 event.
            * Values not reported by the events stay None; ``use_aps`` is
              never False.
        """
        success = None
        use_aps = None
        algo = None
        waste = None
        change_pos = None

        is_aps = False  # set once the attempting_aps_create_tx event is seen
        sc_events = []  # every selected_coins event, in order
        for event in events:
            if event.type == 1:
                # Only a selection made before the APS attempt describes the
                # normal transaction.
                if not is_aps:
                    algo = event.algo.decode()
                    waste = event.waste
                sc_events.append(event)
            elif event.type == 2:
                success = event.success
                if not is_aps:
                    change_pos = event.change_pos
            elif event.type == 3:
                is_aps = True
            elif event.type == 4:
                # An aps_create_tx_internal event must follow an APS attempt.
                assert is_aps
                if event.use_aps:
                    use_aps = True
                    # The APS attempt ran its own coin selection, so its
                    # result is the second selected_coins event.
                    assert_equal(len(sc_events), 2)
                    algo = sc_events[1].algo.decode()
                    waste = sc_events[1].waste
                    change_pos = event.change_pos
        return success, use_aps, algo, waste, change_pos

    def run_test(self):
        self.log.info("hook into the coin_selection tracepoints")
        ctx = USDT(pid=self.nodes[0].process.pid)
        ctx.enable_probe(probe="coin_selection:selected_coins", fn_name="trace_selected_coins")
        ctx.enable_probe(probe="coin_selection:normal_create_tx_internal", fn_name="trace_normal_create_tx")
        ctx.enable_probe(probe="coin_selection:attempting_aps_create_tx", fn_name="trace_attempt_aps")
        ctx.enable_probe(probe="coin_selection:aps_create_tx_internal", fn_name="trace_aps_create_tx")
        self.bpf = BPF(text=coinselection_tracepoints_program, usdt_contexts=[ctx], debug=0, cflags=bpf_cflags())

        # Always detach the BPF program, even if a scenario below fails.
        try:
            self.log.info("Prepare wallets")
            self.generate(self.nodes[0], 101)
            wallet = self.nodes[0].get_wallet_rpc(self.default_wallet_name)

            self.log.info("Sending a transaction should result in all tracepoints")
            # We should have 5 tracepoints in the order:
            # 1. selected_coins (type 1)
            # 2. normal_create_tx_internal (type 2)
            # 3. attempting_aps_create_tx (type 3)
            # 4. selected_coins (type 1)
            # 5. aps_create_tx_internal (type 4)
            wallet.sendtoaddress(wallet.getnewaddress(), 10)
            events = self.get_tracepoints([1, 2, 3, 1, 4])
            success, use_aps, _algo, _waste, change_pos = self.determine_selection_from_usdt(events)
            assert_equal(success, True)
            assert_greater_than(change_pos, -1)

            self.log.info("Failing to fund results in 1 tracepoint")
            # We should have 1 tracepoint:
            # 1. normal_create_tx_internal (type 2)
            assert_raises_rpc_error(-6, "Insufficient funds", wallet.sendtoaddress, wallet.getnewaddress(), 102 * 50)
            events = self.get_tracepoints([2])
            success, use_aps, _algo, _waste, change_pos = self.determine_selection_from_usdt(events)
            assert_equal(success, False)

            self.log.info("Explicitly enabling APS results in 2 tracepoints")
            # We should have 2 tracepoints in the order:
            # 1. selected_coins (type 1)
            # 2. normal_create_tx_internal (type 2)
            wallet.setwalletflag("avoid_reuse")
            wallet.sendtoaddress(address=wallet.getnewaddress(), amount=10, avoid_reuse=True)
            events = self.get_tracepoints([1, 2])
            success, use_aps, _algo, _waste, change_pos = self.determine_selection_from_usdt(events)
            assert_equal(success, True)
            assert_equal(use_aps, None)

            self.log.info("Change position is -1 if no change is created with APS when APS was initially not used")
            # We should have 5 tracepoints in the order:
            # 1. selected_coins (type 1)
            # 2. normal_create_tx_internal (type 2)
            # 3. attempting_aps_create_tx (type 3)
            # 4. selected_coins (type 1)
            # 5. aps_create_tx_internal (type 4)
            wallet.sendtoaddress(address=wallet.getnewaddress(), amount=wallet.getbalance(), subtractfeefromamount=True, avoid_reuse=False)
            events = self.get_tracepoints([1, 2, 3, 1, 4])
            success, use_aps, _algo, _waste, change_pos = self.determine_selection_from_usdt(events)
            assert_equal(success, True)
            assert_equal(change_pos, -1)

            self.log.info("Change position is -1 if no change is created normally and APS is not used")
            # We should have 2 tracepoints in the order:
            # 1. selected_coins (type 1)
            # 2. normal_create_tx_internal (type 2)
            wallet.sendtoaddress(address=wallet.getnewaddress(), amount=wallet.getbalance(), subtractfeefromamount=True)
            events = self.get_tracepoints([1, 2])
            success, use_aps, _algo, _waste, change_pos = self.determine_selection_from_usdt(events)
            assert_equal(success, True)
            assert_equal(change_pos, -1)
        finally:
            self.bpf.cleanup()
238  
239  
# Script entry point: build the test object for this file and hand control
# to the framework's main().
if __name__ == "__main__":
    coin_selection_test = CoinSelectionTracepointTest(__file__)
    coin_selection_test.main()