sklearn_base.py
# -*- coding: utf-8 -*-
"""Utility function copied over from sklearn/base.py
"""
# Author: Yue Zhao <zhaoy@cmu.edu>
# License: BSD 2 clause


import numpy as np
from joblib.parallel import cpu_count


def _get_n_jobs(n_jobs):
    """Get number of jobs for the computation.
    See sklearn/utils/__init__.py for more information.

    This function reimplements the logic of joblib to determine the actual
    number of jobs depending on the cpu count. If -1 all CPUs are used.
    If 1 is given, no parallel computing code is used at all, which is useful
    for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used.
    Thus for n_jobs = -2, all CPUs but one are used.

    Parameters
    ----------
    n_jobs : int
        Number of jobs stated in joblib convention.

    Returns
    -------
    n_jobs : int
        The actual number of jobs as positive integer.

    Raises
    ------
    ValueError
        If ``n_jobs`` is 0.
    """
    if n_jobs < 0:
        # Clamp to 1 so a very negative request still yields one job.
        return max(cpu_count() + 1 + n_jobs, 1)
    elif n_jobs == 0:
        raise ValueError('Parameter n_jobs == 0 has no meaning.')
    else:
        return n_jobs


def _partition_estimators(n_estimators, n_jobs):
    """Private function used to partition estimators between jobs.
    See sklearn/ensemble/base.py for more information.

    Parameters
    ----------
    n_estimators : int
        Total number of estimators to distribute.

    n_jobs : int
        Number of jobs stated in joblib convention (see ``_get_n_jobs``).

    Returns
    -------
    n_jobs : int
        The number of jobs actually used; never more than ``n_estimators``.

    n_estimators_per_job : list of int
        How many estimators each job receives.

    starts : list of int
        Cumulative boundaries, beginning with 0, such that job ``i`` handles
        the estimators in ``starts[i]:starts[i + 1]``.
    """
    # Compute the number of jobs
    n_jobs = min(_get_n_jobs(n_jobs), n_estimators)

    # Partition estimators between jobs: everyone gets the floor share and
    # the first (n_estimators % n_jobs) jobs take one extra.
    n_estimators_per_job = (n_estimators // n_jobs) * np.ones(n_jobs,
                                                              dtype=int)
    n_estimators_per_job[:n_estimators % n_jobs] += 1
    starts = np.cumsum(n_estimators_per_job)

    return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()


def _pprint(params, offset=0, printer=repr):
    # noinspection PyPep8
    """Pretty print the dictionary 'params'

    See http://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html
    and sklearn/base.py for more information.

    Parameters
    ----------
    params : dict
        The dictionary to pretty print.

    offset : int, optional (default=0)
        The offset in characters to add at the begin of each line.

    printer : callable, optional (default=repr)
        The function to convert entries to strings, typically
        the builtin str or repr.

    Returns
    -------
    lines : str
        The ``key=value`` pairs, sorted by key, joined with ``', '`` and
        wrapped onto continuation lines, with trailing spaces stripped.
    """

    # Do a multi-line justified repr:
    options = np.get_printoptions()
    np.set_printoptions(precision=5, threshold=64, edgeitems=2)
    params_list = list()
    this_line_length = offset
    line_sep = ',\n' + (1 + offset // 2) * ' '
    try:
        for i, (k, v) in enumerate(sorted(params.items())):
            # Exact type check on purpose: only builtin floats take the
            # str() path, float subclasses still go through ``printer``.
            if type(v) is float:
                # use str for representing floating point numbers
                # this way we get consistent representation across
                # architectures and versions.
                this_repr = '%s=%s' % (k, str(v))
            else:
                # use repr of the rest
                this_repr = '%s=%s' % (k, printer(v))
            if len(this_repr) > 500:
                # Keep the head and the tail of overly long entries.
                this_repr = this_repr[:300] + '...' + this_repr[-100:]
            if i > 0:
                if (this_line_length + len(this_repr) >= 75 or
                        '\n' in this_repr):
                    params_list.append(line_sep)
                    this_line_length = len(line_sep)
                else:
                    params_list.append(', ')
                    this_line_length += 2
            params_list.append(this_repr)
            this_line_length += len(this_repr)
    finally:
        # The print options are process-wide state: restore them even when
        # ``printer`` (or the key sort) raises.
        np.set_printoptions(**options)

    lines = ''.join(params_list)
    # Strip trailing space to avoid nightmare in doctests
    lines = '\n'.join(line.rstrip(' ') for line in lines.split('\n'))
    return lines