# tests/utils/test_arguments_utils.py
 1  import functools
 2  
 3  import pytest
 4  
 5  from mlflow.utils.arguments_utils import _get_arg_names
 6  
 7  
 8  def no_args():
 9      pass
10  
11  
def positional(a, b):
    """Sample callable with two required positional-or-keyword parameters.

    Returns:
        The tuple ``(a, b)``, echoing the arguments back.
    """
    pair = (a, b)
    return pair
14  
15  
def keyword(a=0, b=0):
    """Sample callable whose two parameters both default to ``0``.

    Returns:
        The tuple ``(a, b)``.
    """
    pair = (a, b)
    return pair
18  
19  
def positional_and_keyword(a, b=0):
    """Sample callable mixing a required parameter with a defaulted one.

    Returns:
        The tuple ``(a, b)``, where ``b`` falls back to ``0``.
    """
    pair = (a, b)
    return pair
22  
23  
def keyword_only(*, a, b=0):
    """Sample callable whose parameters can only be passed by keyword.

    ``a`` is required and ``b`` defaults to ``0``.

    Returns:
        The tuple ``(a, b)``.
    """
    pair = (a, b)
    return pair
26  
27  
def var_positional(*args):
    """Sample callable that gathers any positional arguments.

    Returns:
        The tuple of positional arguments, exactly as received.
    """
    collected = args
    return collected
30  
31  
def var_keyword(**kwargs):
    """Sample callable that gathers any keyword arguments.

    Returns:
        The dict of keyword arguments, exactly as received.
    """
    collected = kwargs
    return collected
34  
35  
def var_positional_and_keyword(*args, **kwargs):
    """Sample callable that gathers both positional and keyword arguments.

    Returns:
        A 2-tuple ``(args, kwargs)``: the positional tuple and the keyword dict.
    """
    positional_part, keyword_part = args, kwargs
    return positional_part, keyword_part
38  
39  
# Generic (*args, **kwargs) forwarder around `positional`. functools.wraps copies
# `positional`'s metadata (__name__, __doc__, ...) onto this function and sets
# __wrapped__, so introspection can see the wrapped signature (a, b) rather than
# the forwarder's own (*args, **kwargs). A docstring here would be overwritten by
# wraps, hence comments only.
@functools.wraps(positional)
def wrapper(*args, **kwargs):
    # Forward every argument unchanged and hand back whatever `positional` returns.
    result = positional(*args, **kwargs)
    return result
43  
44  
@pytest.mark.parametrize(
    ("func", "expected_args"),
    [
        # Explicit ids keep test names readable. Without them pytest may derive the
        # id from each function's __name__, and functools.wraps copies
        # __name__ = "positional" onto `wrapper`, which would mislabel that case.
        pytest.param(no_args, [], id="no_args"),
        pytest.param(positional, ["a", "b"], id="positional"),
        pytest.param(keyword, ["a", "b"], id="keyword"),
        pytest.param(positional_and_keyword, ["a", "b"], id="positional_and_keyword"),
        pytest.param(keyword_only, ["a", "b"], id="keyword_only"),
        pytest.param(var_positional, ["args"], id="var_positional"),
        pytest.param(var_keyword, ["kwargs"], id="var_keyword"),
        pytest.param(
            var_positional_and_keyword,
            ["args", "kwargs"],
            id="var_positional_and_keyword",
        ),
        # The wrapper should report the wrapped function's names, not (*args, **kwargs).
        pytest.param(wrapper, ["a", "b"], id="wrapper"),
    ],
)
def test_get_arg_names(func, expected_args):
    """Check that ``_get_arg_names`` returns ``func``'s parameter names in declaration order.

    The cases cover:
    - positional, defaulted and keyword-only parameters;
    - ``*args`` and ``**kwargs``;
    - a ``functools.wraps`` wrapper.
    """
    assert _get_arg_names(func) == expected_args