/ Drivers / CMSIS / DSP / Source / StatisticsFunctions / arm_logsumexp_f32.c
arm_logsumexp_f32.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_logsumexp_f32.c
  4   * Description:  LogSumExp
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/statistics_functions.h"
 30  #include <limits.h>
 31  #include <math.h>
 32  
 33  
 34  /**
 35   * @addtogroup LogSumExp
 36   * @{
 37   */
 38  
 39  
/**
 * @brief Computation of the LogSumExp
 *
 * In probabilistic computations, probability values can span a very wide
 * dynamic range because they come from Gaussian functions.
 * To avoid underflow and overflow issues, the values are represented by their logs.
 * In this representation, multiplying the original values is easy: their logs are added.
 * Adding the original values, however, requires special handling, and that is the
 * goal of the LogSumExp function.
 *
 * If the values are x1...xn, the function computes:
 *
 * ln(exp(x1) + ... + exp(xn)) and the computation is done in such a way that
 * overflow and rounding issues are minimised.
 *
 * The maximum xm of the values is extracted and the function computes:
 * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
 *
 * For finite inputs every exponent xi - xm is <= 0, so no exp() term can
 * overflow, and the largest term is exp(0) = 1, so the sum cannot underflow to 0.
 *
 * @param[in]  in          Pointer to an array of input values.
 * @param[in]  blockSize   Number of samples in the input array.
 * @return LogSumExp
 *
 */
 63  
 64  #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
 65  
 66  #include "arm_helium_utils.h"
 67  #include "arm_vec_math.h"
 68  
 69  float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
 70  {
 71      float32_t       maxVal;
 72      const float32_t *pIn;
 73      int32_t         blkCnt;
 74      float32_t       accum=0.0f;
 75      float32_t       tmp;
 76  
 77  
 78      arm_max_no_idx_f32((float32_t *) in, blockSize, &maxVal);
 79  
 80  
 81      blkCnt = blockSize;
 82      pIn = in;
 83  
 84  
 85      f32x4_t         vSum = vdupq_n_f32(0.0f);
 86      blkCnt = blockSize >> 2;
 87      while(blkCnt > 0)
 88      {
 89          f32x4_t         vecIn = vld1q(pIn);
 90          f32x4_t         vecExp;
 91  
 92          vecExp = vexpq_f32(vsubq_n_f32(vecIn, maxVal));
 93  
 94          vSum = vaddq_f32(vSum, vecExp);
 95  
 96          /*
 97           * Decrement the blockSize loop counter
 98           * Advance vector source and destination pointers
 99           */
100          pIn += 4;
101          blkCnt --;
102      }
103  
104      /* sum + log */
105      accum = vecAddAcrossF32Mve(vSum);
106  
107      blkCnt = blockSize & 0x3;
108      while(blkCnt > 0)
109      {
110         tmp = *pIn++;
111         accum += expf(tmp - maxVal);
112         blkCnt--;
113      
114      }
115  
116      accum = maxVal + log(accum);
117  
118      return (accum);
119  }
120  
121  #else
122  #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
123  
124  #include "NEMath.h"
125  float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
126  {
127      float32_t maxVal;
128      float32_t tmp;
129      float32x4_t tmpV, tmpVb;
130      float32x4_t maxValV;
131      uint32x4_t idxV;
132      float32x4_t accumV;
133      float32x2_t accumV2;
134  
135      const float32_t *pIn;
136      uint32_t blkCnt;
137      float32_t accum;
138   
139      pIn = in;
140  
141      blkCnt = blockSize;
142  
143      if (blockSize <= 3)
144      {
145        maxVal = *pIn++;
146        blkCnt--;
147  
148        while(blkCnt > 0)
149        {
150           tmp = *pIn++;
151    
152           if (tmp > maxVal)
153           {
154              maxVal = tmp;
155           }
156           blkCnt--;
157        }
158      }
159      else
160      {
161        maxValV = vld1q_f32(pIn);
162        pIn += 4;
163        blkCnt = (blockSize - 4) >> 2;
164  
165        while(blkCnt > 0)
166        {
167           tmpVb = vld1q_f32(pIn);
168           pIn += 4;
169    
170           idxV = vcgtq_f32(tmpVb, maxValV);
171           maxValV = vbslq_f32(idxV, tmpVb, maxValV );
172  
173           blkCnt--;
174        }
175  
176        accumV2 = vpmax_f32(vget_low_f32(maxValV),vget_high_f32(maxValV));
177        accumV2 = vpmax_f32(accumV2,accumV2);
178        maxVal = vget_lane_f32(accumV2, 0) ;
179  
180        blkCnt = (blockSize - 4) & 3;
181  
182        while(blkCnt > 0)
183        {
184           tmp = *pIn++;
185    
186           if (tmp > maxVal)
187           {
188              maxVal = tmp;
189           }
190           blkCnt--;
191        }
192  
193      }
194  
195      
196  
197      maxValV = vdupq_n_f32(maxVal);
198      pIn = in;
199      accum = 0;
200      accumV = vdupq_n_f32(0.0f);
201  
202      blkCnt = blockSize >> 2;
203  
204      while(blkCnt > 0)
205      {
206         tmpV = vld1q_f32(pIn);
207         pIn += 4;
208         tmpV = vsubq_f32(tmpV, maxValV);
209         tmpV = vexpq_f32(tmpV);
210         accumV = vaddq_f32(accumV, tmpV);
211  
212         blkCnt--;
213      
214      }
215      accumV2 = vpadd_f32(vget_low_f32(accumV),vget_high_f32(accumV));
216      accum = vget_lane_f32(accumV2, 0) + vget_lane_f32(accumV2, 1);
217  
218      blkCnt = blockSize & 0x3;
219      while(blkCnt > 0)
220      {
221         tmp = *pIn++;
222         accum += expf(tmp - maxVal);
223         blkCnt--;
224      
225      }
226  
227      accum = maxVal + logf(accum);
228  
229      return(accum);
230  }
231  #else
232  float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize)
233  {
234      float32_t maxVal;
235      float32_t tmp;
236      const float32_t *pIn;
237      uint32_t blkCnt;
238      float32_t accum;
239   
240      pIn = in;
241      blkCnt = blockSize;
242  
243      maxVal = *pIn++;
244      blkCnt--;
245  
246      while(blkCnt > 0)
247      {
248         tmp = *pIn++;
249  
250         if (tmp > maxVal)
251         {
252            maxVal = tmp;
253         }
254         blkCnt--;
255      
256      }
257  
258      blkCnt = blockSize;
259      pIn = in;
260      accum = 0;
261      while(blkCnt > 0)
262      {
263         tmp = *pIn++;
264         accum += expf(tmp - maxVal);
265         blkCnt--;
266      
267      }
268      accum = maxVal + logf(accum);
269  
270      return(accum);
271  }
272  #endif
273  #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
274  
275  /**
276   * @} end of LogSumExp group
277   */