arm_logsumexp_f32.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_logsumexp_f32.c 4 * Description: LogSumExp 5 * 6 * $Date: 23 April 2021 7 * $Revision: V1.9.0 8 * 9 * Target Processor: Cortex-M and Cortex-A cores 10 * -------------------------------------------------------------------- */ 11 /* 12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 13 * 14 * SPDX-License-Identifier: Apache-2.0 15 * 16 * Licensed under the Apache License, Version 2.0 (the License); you may 17 * not use this file except in compliance with the License. 18 * You may obtain a copy of the License at 19 * 20 * www.apache.org/licenses/LICENSE-2.0 21 * 22 * Unless required by applicable law or agreed to in writing, software 23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 * See the License for the specific language governing permissions and 26 * limitations under the License. 27 */ 28 29 #include "dsp/statistics_functions.h" 30 #include <limits.h> 31 #include <math.h> 32 33 34 /** 35 * @addtogroup LogSumExp 36 * @{ 37 */ 38 39 40 /** 41 * @brief Computation of the LogSumExp 42 * 43 * In probabilistic computations, the dynamic of the probability values can be very 44 * wide because they come from gaussian functions. 45 * To avoid underflow and overflow issues, the values are represented by their log. 46 * In this representation, multiplying the original exp values is easy : their logs are added. 47 * But adding the original exp values is requiring some special handling and it is the 48 * goal of the LogSumExp function. 49 * 50 * If the values are x1...xn, the function is computing: 51 * 52 * ln(exp(x1) + ... + exp(xn)) and the computation is done in such a way that 53 * rounding issues are minimised. 54 * 55 * The max xm of the values is extracted and the function is computing: 56 * xm + ln(exp(x1 - xm) + ... + exp(xn - xm)) 57 * 58 * @param[in] *in Pointer to an array of input values. 59 * @param[in] blockSize Number of samples in the input array. 60 * @return LogSumExp 61 * 62 */ 63 64 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) 65 66 #include "arm_helium_utils.h" 67 #include "arm_vec_math.h" 68 69 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize) 70 { 71 float32_t maxVal; 72 const float32_t *pIn; 73 int32_t blkCnt; 74 float32_t accum=0.0f; 75 float32_t tmp; 76 77 78 arm_max_no_idx_f32((float32_t *) in, blockSize, &maxVal); 79 80 81 blkCnt = blockSize; 82 pIn = in; 83 84 85 f32x4_t vSum = vdupq_n_f32(0.0f); 86 blkCnt = blockSize >> 2; 87 while(blkCnt > 0) 88 { 89 f32x4_t vecIn = vld1q(pIn); 90 f32x4_t vecExp; 91 92 vecExp = vexpq_f32(vsubq_n_f32(vecIn, maxVal)); 93 94 vSum = vaddq_f32(vSum, vecExp); 95 96 /* 97 * Decrement the blockSize loop counter 98 * Advance vector source and destination pointers 99 */ 100 pIn += 4; 101 blkCnt --; 102 } 103 104 /* sum + log */ 105 accum = vecAddAcrossF32Mve(vSum); 106 107 blkCnt = blockSize & 0x3; 108 while(blkCnt > 0) 109 { 110 tmp = *pIn++; 111 accum += expf(tmp - maxVal); 112 blkCnt--; 113 114 } 115 116 accum = maxVal + log(accum); 117 118 return (accum); 119 } 120 121 #else 122 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE) 123 124 #include "NEMath.h" 125 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize) 126 { 127 float32_t maxVal; 128 float32_t tmp; 129 float32x4_t tmpV, tmpVb; 130 float32x4_t maxValV; 131 uint32x4_t idxV; 132 float32x4_t accumV; 133 float32x2_t accumV2; 134 135 const float32_t *pIn; 136 uint32_t blkCnt; 137 float32_t accum; 138 139 pIn = in; 140 141 blkCnt = blockSize; 142 143 if (blockSize <= 3) 144 { 145 maxVal = *pIn++; 146 blkCnt--; 147 148 while(blkCnt > 0) 149 { 150 tmp = *pIn++; 151 152 if (tmp > maxVal) 153 { 154 maxVal = tmp; 155 } 156 blkCnt--; 157 } 158 } 159 else 160 { 161 maxValV = vld1q_f32(pIn); 162 pIn += 4; 163 blkCnt = (blockSize - 4) >> 2; 164 165 while(blkCnt > 0) 166 { 167 tmpVb = vld1q_f32(pIn); 168 pIn += 4; 169 170 idxV = vcgtq_f32(tmpVb, maxValV); 171 maxValV = vbslq_f32(idxV, tmpVb, maxValV ); 172 173 blkCnt--; 174 } 175 176 accumV2 = vpmax_f32(vget_low_f32(maxValV),vget_high_f32(maxValV)); 177 accumV2 = vpmax_f32(accumV2,accumV2); 178 maxVal = vget_lane_f32(accumV2, 0) ; 179 180 blkCnt = (blockSize - 4) & 3; 181 182 while(blkCnt > 0) 183 { 184 tmp = *pIn++; 185 186 if (tmp > maxVal) 187 { 188 maxVal = tmp; 189 } 190 blkCnt--; 191 } 192 193 } 194 195 196 197 maxValV = vdupq_n_f32(maxVal); 198 pIn = in; 199 accum = 0; 200 accumV = vdupq_n_f32(0.0f); 201 202 blkCnt = blockSize >> 2; 203 204 while(blkCnt > 0) 205 { 206 tmpV = vld1q_f32(pIn); 207 pIn += 4; 208 tmpV = vsubq_f32(tmpV, maxValV); 209 tmpV = vexpq_f32(tmpV); 210 accumV = vaddq_f32(accumV, tmpV); 211 212 blkCnt--; 213 214 } 215 accumV2 = vpadd_f32(vget_low_f32(accumV),vget_high_f32(accumV)); 216 accum = vget_lane_f32(accumV2, 0) + vget_lane_f32(accumV2, 1); 217 218 blkCnt = blockSize & 0x3; 219 while(blkCnt > 0) 220 { 221 tmp = *pIn++; 222 accum += expf(tmp - maxVal); 223 blkCnt--; 224 225 } 226 227 accum = maxVal + logf(accum); 228 229 return(accum); 230 } 231 #else 232 float32_t arm_logsumexp_f32(const float32_t *in, uint32_t blockSize) 233 { 234 float32_t maxVal; 235 float32_t tmp; 236 const float32_t *pIn; 237 uint32_t blkCnt; 238 float32_t accum; 239 240 pIn = in; 241 blkCnt = blockSize; 242 243 maxVal = *pIn++; 244 blkCnt--; 245 246 while(blkCnt > 0) 247 { 248 tmp = *pIn++; 249 250 if (tmp > maxVal) 251 { 252 maxVal = tmp; 253 } 254 blkCnt--; 255 256 } 257 258 blkCnt = blockSize; 259 pIn = in; 260 accum = 0; 261 while(blkCnt > 0) 262 { 263 tmp = *pIn++; 264 accum += expf(tmp - maxVal); 265 blkCnt--; 266 267 } 268 accum = maxVal + logf(accum); 269 270 return(accum); 271 } 272 #endif 273 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 274 275 /** 276 * @} end of LogSumExp group 277 */