/ pyod / models / sklearn_base.py
sklearn_base.py
  1  # -*- coding: utf-8 -*-
  2  """Utility function copied over from sklearn/base.py
  3  """
  4  # Author: Yue Zhao <zhaoy@cmu.edu>
  5  # License: BSD 2 clause
  6  
  7  
  8  import numpy as np
  9  from joblib.parallel import cpu_count
 10  
 11  
 12  def _get_n_jobs(n_jobs):
 13      """Get number of jobs for the computation.
 14      See sklearn/utils/__init__.py for more information.
 15  
 16      This function reimplements the logic of joblib to determine the actual
 17      number of jobs depending on the cpu count. If -1 all CPUs are used.
 18      If 1 is given, no parallel computing code is used at all, which is useful
 19      for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used.
 20      Thus for n_jobs = -2, all CPUs but one are used.
 21      Parameters
 22      ----------
 23      n_jobs : int
 24          Number of jobs stated in joblib convention.
 25      Returns
 26      -------
 27      n_jobs : int
 28          The actual number of jobs as positive integer.
 29      """
 30      if n_jobs < 0:
 31          return max(cpu_count() + 1 + n_jobs, 1)
 32      elif n_jobs == 0:
 33          raise ValueError('Parameter n_jobs == 0 has no meaning.')
 34      else:
 35          return n_jobs
 36  
 37  
 38  def _partition_estimators(n_estimators, n_jobs):
 39      """Private function used to partition estimators between jobs.
 40      See sklearn/ensemble/base.py for more information.
 41      """
 42      # Compute the number of jobs
 43      n_jobs = min(_get_n_jobs(n_jobs), n_estimators)
 44  
 45      # Partition estimators between jobs
 46      n_estimators_per_job = (n_estimators // n_jobs) * np.ones(n_jobs,
 47                                                                dtype=int)
 48      n_estimators_per_job[:n_estimators % n_jobs] += 1
 49      starts = np.cumsum(n_estimators_per_job)
 50  
 51      return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()
 52  
 53  
 54  def _pprint(params, offset=0, printer=repr):
 55      # noinspection PyPep8
 56      """Pretty print the dictionary 'params'
 57  
 58      See http://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html
 59      and sklearn/base.py for more information.
 60  
 61      :param params: The dictionary to pretty print
 62      :type params: dict
 63  
 64      :param offset: The offset in characters to add at the begin of each line.
 65      :type offset: int
 66  
 67      :param printer: The function to convert entries to strings, typically
 68          the builtin str or repr
 69      :type printer: callable
 70  
 71      :return: None
 72      """
 73  
 74      # Do a multi-line justified repr:
 75      options = np.get_printoptions()
 76      np.set_printoptions(precision=5, threshold=64, edgeitems=2)
 77      params_list = list()
 78      this_line_length = offset
 79      line_sep = ',\n' + (1 + offset // 2) * ' '
 80      for i, (k, v) in enumerate(sorted(params.items())):
 81          if type(v) is float:
 82              # use str for representing floating point numbers
 83              # this way we get consistent representation across
 84              # architectures and versions.
 85              this_repr = '%s=%s' % (k, str(v))
 86          else:
 87              # use repr of the rest
 88              this_repr = '%s=%s' % (k, printer(v))
 89          if len(this_repr) > 500:
 90              this_repr = this_repr[:300] + '...' + this_repr[-100:]
 91          if i > 0:
 92              if this_line_length + len(this_repr) >= 75 or '\n' in this_repr:
 93                  params_list.append(line_sep)
 94                  this_line_length = len(line_sep)
 95              else:
 96                  params_list.append(', ')
 97                  this_line_length += 2
 98          params_list.append(this_repr)
 99          this_line_length += len(this_repr)
100  
101      np.set_printoptions(**options)
102      lines = ''.join(params_list)
103      # Strip trailing space to avoid nightmare in doctests
104      lines = '\n'.join(l.rstrip(' ') for l in lines.split('\n'))
105      return lines