arm_logsumexp_f16.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_logsumexp_f16.c 4 * Description: LogSumExp 5 * 6 * $Date: 23 April 2021 7 * $Revision: V1.9.0 8 * 9 * Target Processor: Cortex-M and Cortex-A cores 10 * -------------------------------------------------------------------- */ 11 /* 12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 13 * 14 * SPDX-License-Identifier: Apache-2.0 15 * 16 * Licensed under the Apache License, Version 2.0 (the License); you may 17 * not use this file except in compliance with the License. 18 * You may obtain a copy of the License at 19 * 20 * www.apache.org/licenses/LICENSE-2.0 21 * 22 * Unless required by applicable law or agreed to in writing, software 23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 * See the License for the specific language governing permissions and 26 * limitations under the License. 27 */ 28 29 #include "dsp/statistics_functions_f16.h" 30 31 #if defined(ARM_FLOAT16_SUPPORTED) 32 33 #include <limits.h> 34 #include <math.h> 35 36 37 /** 38 * @addtogroup LogSumExp 39 * @{ 40 */ 41 42 43 /** 44 * @brief Computation of the LogSumExp 45 * 46 * In probabilistic computations, the dynamic of the probability values can be very 47 * wide because they come from gaussian functions. 48 * To avoid underflow and overflow issues, the values are represented by their log. 49 * In this representation, multiplying the original exp values is easy : their logs are added. 50 * But adding the original exp values is requiring some special handling and it is the 51 * goal of the LogSumExp function. 52 * 53 * If the values are x1...xn, the function is computing: 54 * 55 * ln(exp(x1) + ... + exp(xn)) and the computation is done in such a way that 56 * rounding issues are minimised. 57 * 58 * The max xm of the values is extracted and the function is computing: 59 * xm + ln(exp(x1 - xm) + ... + exp(xn - xm)) 60 * 61 * @param[in] *in Pointer to an array of input values. 62 * @param[in] blockSize Number of samples in the input array. 63 * @return LogSumExp 64 * 65 */ 66 67 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) 68 69 #include "arm_helium_utils.h" 70 #include "arm_vec_math_f16.h" 71 72 float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize) 73 { 74 float16_t maxVal; 75 const float16_t *pIn; 76 int32_t blkCnt; 77 _Float16 accum=0.0f16; 78 _Float16 tmp; 79 80 81 arm_max_no_idx_f16((float16_t *) in, blockSize, &maxVal); 82 83 84 blkCnt = blockSize; 85 pIn = in; 86 87 88 f16x8_t vSum = vdupq_n_f16(0.0f16); 89 blkCnt = blockSize >> 3; 90 while(blkCnt > 0) 91 { 92 f16x8_t vecIn = vld1q(pIn); 93 f16x8_t vecExp; 94 95 vecExp = vexpq_f16(vsubq_n_f16(vecIn, maxVal)); 96 97 vSum = vaddq_f16(vSum, vecExp); 98 99 /* 100 * Decrement the blockSize loop counter 101 * Advance vector source and destination pointers 102 */ 103 pIn += 8; 104 blkCnt --; 105 } 106 107 /* sum + log */ 108 accum = vecAddAcrossF16Mve(vSum); 109 110 blkCnt = blockSize & 0x7; 111 while(blkCnt > 0) 112 { 113 tmp = *pIn++; 114 accum += expf(tmp - maxVal); 115 blkCnt--; 116 117 } 118 119 accum = maxVal + logf(accum); 120 121 return (accum); 122 } 123 124 #else 125 float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize) 126 { 127 _Float16 maxVal; 128 _Float16 tmp; 129 const float16_t *pIn; 130 uint32_t blkCnt; 131 _Float16 accum; 132 133 pIn = in; 134 blkCnt = blockSize; 135 136 maxVal = *pIn++; 137 blkCnt--; 138 139 while(blkCnt > 0) 140 { 141 tmp = *pIn++; 142 143 if (tmp > maxVal) 144 { 145 maxVal = tmp; 146 } 147 blkCnt--; 148 149 } 150 151 blkCnt = blockSize; 152 pIn = in; 153 accum = 0; 154 while(blkCnt > 0) 155 { 156 tmp = *pIn++; 157 accum += expf(tmp - maxVal); 158 blkCnt--; 159 160 } 161 accum = maxVal + logf(accum); 162 163 return(accum); 164 } 165 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 166 167 /** 168 * @} end of LogSumExp group 169 */ 170 171 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 172