/ Drivers / CMSIS / DSP / Source / BasicMathFunctions / arm_dot_prod_f32.c
arm_dot_prod_f32.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_dot_prod_f32.c
  4   * Description:  Floating-point dot product
  5   *
  6   * $Date:        05 October 2021
  7   * $Revision:    V1.9.1
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/basic_math_functions.h"
 30  
 31  /**
 32    @ingroup groupMath
 33   */
 34  
 35  /**
 36    @defgroup BasicDotProd Vector Dot Product
 37  
 38    Computes the dot product of two vectors.
 39    The vectors are multiplied element-by-element and then summed.
 40  
 41    <pre>
 42        sum = pSrcA[0]*pSrcB[0] + pSrcA[1]*pSrcB[1] + ... + pSrcA[blockSize-1]*pSrcB[blockSize-1]
 43    </pre>
 44  
 45    There are separate functions for floating-point, Q7, Q15, and Q31 data types.
 46   */
 47  
 48  /**
 49    @addtogroup BasicDotProd
 50    @{
 51   */
 52  
 53  /**
 54    @brief         Dot product of floating-point vectors.
 55    @param[in]     pSrcA      points to the first input vector.
 56    @param[in]     pSrcB      points to the second input vector.
 57    @param[in]     blockSize  number of samples in each vector.
 58    @param[out]    result     output result returned here.
 59    @return        none
 60   */
 61  
 62  #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
 63  
 64  #include "arm_helium_utils.h"
 65  
 66  
 67  void arm_dot_prod_f32(
 68      const float32_t * pSrcA,
 69      const float32_t * pSrcB,
 70      uint32_t    blockSize,
 71      float32_t * result)
 72  {
 73      f32x4_t vecA, vecB;
 74      f32x4_t vecSum;
 75      uint32_t blkCnt; 
 76      float32_t sum = 0.0f;  
 77      vecSum = vdupq_n_f32(0.0f);
 78  
 79      /* Compute 4 outputs at a time */
 80      blkCnt = blockSize >> 2U;
 81      while (blkCnt > 0U)
 82      {
 83          /*
 84           * C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1]
 85           * Calculate dot product and then store the result in a temporary buffer.
 86           * and advance vector source and destination pointers
 87           */
 88          vecA = vld1q(pSrcA);
 89          pSrcA += 4;
 90          
 91          vecB = vld1q(pSrcB);
 92          pSrcB += 4;
 93  
 94          vecSum = vfmaq(vecSum, vecA, vecB);
 95          /*
 96           * Decrement the blockSize loop counter
 97           */
 98          blkCnt --;
 99      }
100  
101  
102      blkCnt = blockSize & 3;
103      if (blkCnt > 0U)
104      {
105          /* C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1] */
106  
107          mve_pred16_t p0 = vctp32q(blkCnt);
108          vecA = vld1q(pSrcA);
109          vecB = vld1q(pSrcB);
110          vecSum = vfmaq_m(vecSum, vecA, vecB, p0);
111      }
112  
113      sum = vecAddAcrossF32Mve(vecSum);
114  
115      /* Store result in destination buffer */
116      *result = sum;
117  
118  }
119  
120  #else
121  
122  void arm_dot_prod_f32(
123    const float32_t * pSrcA,
124    const float32_t * pSrcB,
125          uint32_t blockSize,
126          float32_t * result)
127  {
128          uint32_t blkCnt;                               /* Loop counter */
129          float32_t sum = 0.0f;                          /* Temporary return variable */
130  
131  #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
132      f32x4_t vec1;
133      f32x4_t vec2;
134      f32x4_t accum = vdupq_n_f32(0);   
135  #if !defined(__aarch64__)
136      f32x2_t tmp = vdup_n_f32(0); 
137  #endif   
138  
139      /* Compute 4 outputs at a time */
140      blkCnt = blockSize >> 2U;
141  
142      vec1 = vld1q_f32(pSrcA);
143      vec2 = vld1q_f32(pSrcB);
144  
145      while (blkCnt > 0U)
146      {
147          /* C = A[0]*B[0] + A[1]*B[1] + A[2]*B[2] + ... + A[blockSize-1]*B[blockSize-1] */
148          /* Calculate dot product and then store the result in a temporary buffer. */
149          
150  	      accum = vmlaq_f32(accum, vec1, vec2);
151  	
152          /* Increment pointers */
153          pSrcA += 4;
154          pSrcB += 4; 
155  
156          vec1 = vld1q_f32(pSrcA);
157          vec2 = vld1q_f32(pSrcB);
158          
159          /* Decrement the loop counter */
160          blkCnt--;
161      }
162      
163  #if defined(__aarch64__)
164      sum = vpadds_f32(vpadd_f32(vget_low_f32(accum), vget_high_f32(accum)));
165  #else
166      tmp = vpadd_f32(vget_low_f32(accum), vget_high_f32(accum));
167      sum = vget_lane_f32(tmp, 0) + vget_lane_f32(tmp, 1);
168  
169  #endif    
170  
171      /* Tail */
172      blkCnt = blockSize & 0x3;
173  
174  #else
175  #if defined (ARM_MATH_LOOPUNROLL) && !defined(ARM_MATH_AUTOVECTORIZE)
176  
177    /* Loop unrolling: Compute 4 outputs at a time */
178    blkCnt = blockSize >> 2U;
179  
180    /* First part of the processing with loop unrolling. Compute 4 outputs at a time.
181     ** a second loop below computes the remaining 1 to 3 samples. */
182    while (blkCnt > 0U)
183    {
184      /* C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1] */
185  
186      /* Calculate dot product and store result in a temporary buffer. */
187      sum += (*pSrcA++) * (*pSrcB++);
188  
189      sum += (*pSrcA++) * (*pSrcB++);
190  
191      sum += (*pSrcA++) * (*pSrcB++);
192  
193      sum += (*pSrcA++) * (*pSrcB++);
194  
195      /* Decrement loop counter */
196      blkCnt--;
197    }
198  
199    /* Loop unrolling: Compute remaining outputs */
200    blkCnt = blockSize % 0x4U;
201  
202  #else
203  
204    /* Initialize blkCnt with number of samples */
205    blkCnt = blockSize;
206  
207  #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
208  #endif /* #if defined(ARM_MATH_NEON) */
209  
210    while (blkCnt > 0U)
211    {
212      /* C = A[0]* B[0] + A[1]* B[1] + A[2]* B[2] + .....+ A[blockSize-1]* B[blockSize-1] */
213  
214      /* Calculate dot product and store result in a temporary buffer. */
215      sum += (*pSrcA++) * (*pSrcB++);
216  
217      /* Decrement loop counter */
218      blkCnt--;
219    }
220  
221    /* Store result in destination buffer */
222    *result = sum;
223  }
224  
225  #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
226  /**
227    @} end of BasicDotProd group
228   */