/ Drivers / CMSIS / DSP / Source / StatisticsFunctions / arm_logsumexp_f16.c
arm_logsumexp_f16.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_logsumexp_f16.c
  4   * Description:  LogSumExp
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/statistics_functions_f16.h"
 30  
 31  #if defined(ARM_FLOAT16_SUPPORTED)
 32  
 33  #include <limits.h>
 34  #include <math.h>
 35  
 36  
 37  /**
 38   * @addtogroup LogSumExp
 39   * @{
 40   */
 41  
 42  
 43  /**
 44   * @brief Computation of the LogSumExp
 45   *
 * In probabilistic computations, the dynamic range of the probability values can be very
 * wide because they come from Gaussian functions.
 * To avoid underflow and overflow issues, the values are represented by their logarithms.
 * In this representation, multiplying the original values is easy: their logs are added.
 * Adding the original values, however, requires special handling, and that is the
 * goal of the LogSumExp function.
 52   *
 53   * If the values are x1...xn, the function is computing:
 54   *
 55   * ln(exp(x1) + ... + exp(xn)) and the computation is done in such a way that
 56   * rounding issues are minimised.
 57   *
 58   * The max xm of the values is extracted and the function is computing:
 59   * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
 60   *
 61   * @param[in]  *in         Pointer to an array of input values.
 62   * @param[in]  blockSize   Number of samples in the input array.
 63   * @return LogSumExp
 64   *
 65   */
 66  
 67  #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
 68  
 69  #include "arm_helium_utils.h"
 70  #include "arm_vec_math_f16.h"
 71  
 72  float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize)
 73  {
 74      float16_t       maxVal;
 75      const float16_t *pIn;
 76      int32_t         blkCnt;
 77      _Float16       accum=0.0f16;
 78      _Float16       tmp;
 79  
 80  
 81      arm_max_no_idx_f16((float16_t *) in, blockSize, &maxVal);
 82  
 83  
 84      blkCnt = blockSize;
 85      pIn = in;
 86  
 87  
 88      f16x8_t         vSum = vdupq_n_f16(0.0f16);
 89      blkCnt = blockSize >> 3;
 90      while(blkCnt > 0)
 91      {
 92          f16x8_t         vecIn = vld1q(pIn);
 93          f16x8_t         vecExp;
 94  
 95          vecExp = vexpq_f16(vsubq_n_f16(vecIn, maxVal));
 96  
 97          vSum = vaddq_f16(vSum, vecExp);
 98  
 99          /*
100           * Decrement the blockSize loop counter
101           * Advance vector source and destination pointers
102           */
103          pIn += 8;
104          blkCnt --;
105      }
106  
107      /* sum + log */
108      accum = vecAddAcrossF16Mve(vSum);
109  
110      blkCnt = blockSize & 0x7;
111      while(blkCnt > 0)
112      {
113         tmp = *pIn++;
114         accum += expf(tmp - maxVal);
115         blkCnt--;
116      
117      }
118  
119      accum = maxVal + logf(accum);
120  
121      return (accum);
122  }
123  
124  #else
125  float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize)
126  {
127      _Float16 maxVal;
128      _Float16 tmp;
129      const float16_t *pIn;
130      uint32_t blkCnt;
131      _Float16 accum;
132   
133      pIn = in;
134      blkCnt = blockSize;
135  
136      maxVal = *pIn++;
137      blkCnt--;
138  
139      while(blkCnt > 0)
140      {
141         tmp = *pIn++;
142  
143         if (tmp > maxVal)
144         {
145            maxVal = tmp;
146         }
147         blkCnt--;
148      
149      }
150  
151      blkCnt = blockSize;
152      pIn = in;
153      accum = 0;
154      while(blkCnt > 0)
155      {
156         tmp = *pIn++;
157         accum += expf(tmp - maxVal);
158         blkCnt--;
159      
160      }
161      accum = maxVal + logf(accum);
162  
163      return(accum);
164  }
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
166  
167  /**
168   * @} end of LogSumExp group
169   */
170  
171  #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 
172