arm_dot_prod_f16.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_dot_prod_f16.c 4 * Description: Floating-point dot product 5 * 6 * $Date: 23 April 2021 7 * $Revision: V1.9.0 8 * 9 * Target Processor: Cortex-M and Cortex-A cores 10 * -------------------------------------------------------------------- */ 11 /* 12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 13 * 14 * SPDX-License-Identifier: Apache-2.0 15 * 16 * Licensed under the Apache License, Version 2.0 (the License); you may 17 * not use this file except in compliance with the License. 18 * You may obtain a copy of the License at 19 * 20 * www.apache.org/licenses/LICENSE-2.0 21 * 22 * Unless required by applicable law or agreed to in writing, software 23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 * See the License for the specific language governing permissions and 26 * limitations under the License. 27 */ 28 29 #include "dsp/basic_math_functions_f16.h" 30 31 /** 32 @ingroup groupMath 33 */ 34 35 /** 36 @defgroup BasicDotProd Vector Dot Product 37 38 Computes the dot product of two vectors. 39 The vectors are multiplied element-by-element and then summed. 40 41 <pre> 42 sum = pSrcA[0]*pSrcB[0] + pSrcA[1]*pSrcB[1] + ... + pSrcA[blockSize-1]*pSrcB[blockSize-1] 43 </pre> 44 45 There are separate functions for floating-point, Q7, Q15, and Q31 data types. 46 */ 47 48 /** 49 @addtogroup BasicDotProd 50 @{ 51 */ 52 53 /** 54 @brief Dot product of floating-point vectors. 55 @param[in] pSrcA points to the first input vector. 56 @param[in] pSrcB points to the second input vector. 57 @param[in] blockSize number of samples in each vector. 58 @param[out] result output result returned here. 59 @return none 60 */ 61 62 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) 63 64 #include "arm_helium_utils.h" 65 66 67 void arm_dot_prod_f16( 68 const float16_t * pSrcA, 69 const float16_t * pSrcB, 70 uint32_t blockSize, 71 float16_t * result) 72 { 73 f16x8_t vecA, vecB; 74 f16x8_t vecSum; 75 uint32_t blkCnt; 76 float16_t sum = 0.0f; 77 vecSum = vdupq_n_f16(0.0f); 78 79 /* Compute 4 outputs at a time */ 80 blkCnt = blockSize >> 3U; 81 while (blkCnt > 0U) 82 { 83 /* 84 * C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1] 85 * Calculate dot product and then store the result in a temporary buffer. 86 * and advance vector source and destination pointers 87 */ 88 vecA = vld1q(pSrcA); 89 pSrcA += 8; 90 91 vecB = vld1q(pSrcB); 92 pSrcB += 8; 93 94 vecSum = vfmaq(vecSum, vecA, vecB); 95 /* 96 * Decrement the blockSize loop counter 97 */ 98 blkCnt --; 99 } 100 101 102 blkCnt = blockSize & 7; 103 if (blkCnt > 0U) 104 { 105 /* C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1] */ 106 107 mve_pred16_t p0 = vctp16q(blkCnt); 108 vecA = vld1q(pSrcA); 109 vecB = vld1q(pSrcB); 110 vecSum = vfmaq_m(vecSum, vecA, vecB, p0); 111 } 112 113 sum = vecAddAcrossF16Mve(vecSum); 114 115 /* Store result in destination buffer */ 116 *result = sum; 117 118 } 119 120 #else 121 #if defined(ARM_FLOAT16_SUPPORTED) 122 void arm_dot_prod_f16( 123 const float16_t * pSrcA, 124 const float16_t * pSrcB, 125 uint32_t blockSize, 126 float16_t * result) 127 { 128 uint32_t blkCnt; /* Loop counter */ 129 _Float16 sum = 0.0f; /* Temporary return variable */ 130 131 132 #if defined (ARM_MATH_LOOPUNROLL) && !defined(ARM_MATH_AUTOVECTORIZE) 133 134 /* Loop unrolling: Compute 4 outputs at a time */ 135 blkCnt = blockSize >> 2U; 136 137 /* First part of the processing with loop unrolling. Compute 4 outputs at a time. 138 ** a second loop below computes the remaining 1 to 3 samples. */ 139 while (blkCnt > 0U) 140 { 141 /* C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1] */ 142 143 /* Calculate dot product and store result in a temporary buffer. */ 144 sum += (_Float16)(*pSrcA++) * (_Float16)(*pSrcB++); 145 146 sum += (_Float16)(*pSrcA++) * (_Float16)(*pSrcB++); 147 148 sum += (_Float16)(*pSrcA++) * (_Float16)(*pSrcB++); 149 150 sum += (_Float16)(*pSrcA++) * (_Float16)(*pSrcB++); 151 152 /* Decrement loop counter */ 153 blkCnt--; 154 } 155 156 /* Loop unrolling: Compute remaining outputs */ 157 blkCnt = blockSize % 0x4U; 158 159 #else 160 161 /* Initialize blkCnt with number of samples */ 162 blkCnt = blockSize; 163 164 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */ 165 166 while (blkCnt > 0U) 167 { 168 /* C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1] */ 169 170 /* Calculate dot product and store result in a temporary buffer. */ 171 sum += (_Float16)(*pSrcA++) * (_Float16)(*pSrcB++); 172 173 /* Decrement loop counter */ 174 blkCnt--; 175 } 176 177 /* Store result in destination buffer */ 178 *result = sum; 179 } 180 #endif 181 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 182 /** 183 @} end of BasicDotProd group 184 */