/ Drivers / CMSIS / DSP / Source / ComplexMathFunctions / arm_cmplx_dot_prod_f32.c
arm_cmplx_dot_prod_f32.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_cmplx_dot_prod_f32.c
  4   * Description:  Floating-point complex dot product
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/complex_math_functions.h"
 30  
 31  /**
 32    @ingroup groupCmplxMath
 33   */
 34  
 35  /**
 36    @defgroup cmplx_dot_prod Complex Dot Product
 37  
 38    Computes the dot product of two complex vectors.
 39    The vectors are multiplied element-by-element and then summed.
 40  
 41    The <code>pSrcA</code> points to the first complex input vector and
 42    <code>pSrcB</code> points to the second complex input vector.
 43    <code>numSamples</code> specifies the number of complex samples
 44    and the data in each array is stored in an interleaved fashion
 45    (real, imag, real, imag, ...).
 46    Each array has a total of <code>2*numSamples</code> values.
 47  
  The underlying algorithm is:
 49  
 50    <pre>
 51    realResult = 0;
 52    imagResult = 0;
 53    for (n = 0; n < numSamples; n++) {
 54        realResult += pSrcA[(2*n)+0] * pSrcB[(2*n)+0] - pSrcA[(2*n)+1] * pSrcB[(2*n)+1];
 55        imagResult += pSrcA[(2*n)+0] * pSrcB[(2*n)+1] + pSrcA[(2*n)+1] * pSrcB[(2*n)+0];
 56    }
 57    </pre>
 58  
 59    There are separate functions for floating-point, Q15, and Q31 data types.
 60   */
 61  
 62  /**
 63    @addtogroup cmplx_dot_prod
 64    @{
 65   */
 66  
/**
  @brief         Floating-point complex dot product.
  @param[in]     pSrcA       points to the first complex input vector (interleaved real/imag)
  @param[in]     pSrcB       points to the second complex input vector (interleaved real/imag)
  @param[in]     numSamples  number of complex samples in each vector
  @param[out]    realResult  real part of the result returned here
  @param[out]    imagResult  imaginary part of the result returned here
  @return        none
 */
 76  
 77  #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
 78  
/*
 * Helium (MVE) implementation.
 *
 * Each f32x4_t holds two interleaved complex samples (re, im, re, im), so
 * every vld1q() consumes 4 floats and the source pointers advance by 4.
 * For each pair of input vectors, vcmlaq() (rotation 0) followed by
 * vcmlaq_rot90() (rotation 90) accumulates the complex product A*B into
 * vec_acc.  vec_acc therefore ends up holding two complex partial sums,
 * which are added together at the end (lanes 0+2 real, lanes 1+3 imag).
 */
void arm_cmplx_dot_prod_f32(
    const float32_t * pSrcA,
    const float32_t * pSrcB,
    uint32_t numSamples,
    float32_t * realResult,
    float32_t * imagResult)
{
    int32_t         blkCnt;
    float32_t       real_sum, imag_sum;
    f32x4_t         vecSrcA, vecSrcB;
    f32x4_t         vec_acc = vdupq_n_f32(0.0f);
    f32x4_t         vecSrcC, vecSrcD;

    /*
     * blkCnt = number of full groups of 4 complex samples (two vectors per
     * operand), minus one: the last group is handled by the epilogue after
     * the loop.  The pipelined path is therefore only taken when
     * numSamples >= 8; shorter inputs use the predicated loop in the else
     * branch.
     */
    blkCnt = numSamples >> 2;
    blkCnt -= 1;
    if (blkCnt > 0) {
        /*
         * Prologue: preload the first vector of each operand so the loop can
         * overlap each load with the previous multiply-accumulate
         * (gives the compiler more freedom to generate stall-free code).
         */
        vecSrcA = vld1q(pSrcA);
        vecSrcB = vld1q(pSrcB);
        pSrcA += 4;
        pSrcB += 4;

        while (blkCnt > 0) {
            /*
             * Each iteration consumes 4 complex samples: it accumulates the
             * preloaded (A, B) pair, loads and accumulates (C, D), then
             * preloads the next (A, B) pair for the following iteration.
             * The interleaving of loads and vcmlaq calls is deliberate;
             * do not reorder these statements.
             */
            vec_acc = vcmlaq(vec_acc, vecSrcA, vecSrcB);
            vecSrcC = vld1q(pSrcA);
            pSrcA += 4;

            vec_acc = vcmlaq_rot90(vec_acc, vecSrcA, vecSrcB);
            vecSrcD = vld1q(pSrcB);
            pSrcB += 4;

            vec_acc = vcmlaq(vec_acc, vecSrcC, vecSrcD);
            vecSrcA = vld1q(pSrcA);
            pSrcA += 4;

            vec_acc = vcmlaq_rot90(vec_acc, vecSrcC, vecSrcD);
            vecSrcB = vld1q(pSrcB);
            pSrcB += 4;
            /*
             * Decrement the block loop counter
             */
            blkCnt--;
        }

        /*
         * Epilogue: accumulate the last preloaded (A, B) pair and the final
         * (C, D) pair.  Processed out of the loop to avoid armclang breaking
         * the software pipeline.  The pointers are intentionally NOT advanced
         * past the vectors loaded into vecSrcC / vecSrcD.
         */
        vec_acc = vcmlaq(vec_acc, vecSrcA, vecSrcB);
        vecSrcC = vld1q(pSrcA);

        vec_acc = vcmlaq_rot90(vec_acc, vecSrcA, vecSrcB);
        vecSrcD = vld1q(pSrcB);

        vec_acc = vcmlaq(vec_acc, vecSrcC, vecSrcD);
        vec_acc = vcmlaq_rot90(vec_acc, vecSrcC, vecSrcD);

        /*
         * Tail: the remaining (numSamples & 3) complex samples, i.e. at most
         * 6 floats, handled with predicated loads.  Because pSrcA / pSrcB
         * still point at the vectors loaded into vecSrcC / vecSrcD, they are
         * advanced by 4 BEFORE each predicated load.
         */
        blkCnt = CMPLX_DIM * (numSamples & 3);
        while (blkCnt > 0) {
            /* Predicate enabling only the lanes still inside the input (min(blkCnt, 4)). */
            mve_pred16_t    p = vctp32q(blkCnt);
            pSrcA += 4;
            pSrcB += 4;
            vecSrcA = vldrwq_z_f32(pSrcA, p);
            vecSrcB = vldrwq_z_f32(pSrcB, p);
            vec_acc = vcmlaq_m(vec_acc, vecSrcA, vecSrcB, p);
            vec_acc = vcmlaq_rot90_m(vec_acc, vecSrcA, vecSrcB, p);
            blkCnt -= 4;
        }
    } else {
        /*
         * Small vector (numSamples <= 7): walk all CMPLX_DIM * numSamples
         * floats, 4 at a time, with predicated loads and predicated
         * multiply-accumulates.  The do/while body executes at least once.
         */
        blkCnt = numSamples * CMPLX_DIM;
        /* NOTE(review): redundant - vec_acc is already zeroed at its declaration. */
        vec_acc = vdupq_n_f32(0.0f);

        do {
            /* Predicate enabling only the lanes still inside the input (min(blkCnt, 4)). */
            mve_pred16_t    p = vctp32q(blkCnt);

            vecSrcA = vldrwq_z_f32(pSrcA, p);
            vecSrcB = vldrwq_z_f32(pSrcB, p);

            vec_acc = vcmlaq_m(vec_acc, vecSrcA, vecSrcB, p);
            vec_acc = vcmlaq_rot90_m(vec_acc, vecSrcA, vecSrcB, p);

            /*
             * Advance both source pointers by one vector (4 floats) and
             * decrement the count of floats still to process.
             */
            pSrcA += 4;
            pSrcB += 4;
            blkCnt -= 4;
        }
        while (blkCnt > 0);
    }

    /* Reduce the two complex partial sums: lanes 0/2 are real, lanes 1/3 imaginary. */
    real_sum = vgetq_lane(vec_acc, 0) + vgetq_lane(vec_acc, 2);
    imag_sum = vgetq_lane(vec_acc, 1) + vgetq_lane(vec_acc, 3);

    /*
     * Write the real and imaginary parts of the result through the output pointers
     */
    *realResult = real_sum;
    *imagResult = imag_sum;
}
181  
182  #else
183  void arm_cmplx_dot_prod_f32(
184    const float32_t * pSrcA,
185    const float32_t * pSrcB,
186          uint32_t numSamples,
187          float32_t * realResult,
188          float32_t * imagResult)
189  {
190          uint32_t blkCnt;                               /* Loop counter */
191          float32_t real_sum = 0.0f, imag_sum = 0.0f;    /* Temporary result variables */
192          float32_t a0,b0,c0,d0;
193  
194  #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
195      float32x4x2_t vec1,vec2,vec3,vec4;
196      float32x4_t accR,accI;
197      float32x2_t accum = vdup_n_f32(0);
198  
199      accR = vdupq_n_f32(0.0f);
200      accI = vdupq_n_f32(0.0f);
201  
202      /* Loop unrolling: Compute 8 outputs at a time */
203      blkCnt = numSamples >> 3U;
204  
205      while (blkCnt > 0U)
206      {
207  	/* C = (A[0]+jA[1])*(B[0]+jB[1]) + ...  */
208          /* Calculate dot product and then store the result in a temporary buffer. */
209  
210  	      vec1 = vld2q_f32(pSrcA);
211          vec2 = vld2q_f32(pSrcB);
212  
213  	/* Increment pointers */
214          pSrcA += 8;
215          pSrcB += 8;
216  
217  	/* Re{C} = Re{A}*Re{B} - Im{A}*Im{B} */
218          accR = vmlaq_f32(accR,vec1.val[0],vec2.val[0]);
219          accR = vmlsq_f32(accR,vec1.val[1],vec2.val[1]);
220  
221  	/* Im{C} = Re{A}*Im{B} + Im{A}*Re{B} */
222          accI = vmlaq_f32(accI,vec1.val[1],vec2.val[0]);
223          accI = vmlaq_f32(accI,vec1.val[0],vec2.val[1]);
224  
225          vec3 = vld2q_f32(pSrcA);
226          vec4 = vld2q_f32(pSrcB);
227  	
228  	/* Increment pointers */
229          pSrcA += 8;
230          pSrcB += 8;
231  
232  	/* Re{C} = Re{A}*Re{B} - Im{A}*Im{B} */
233          accR = vmlaq_f32(accR,vec3.val[0],vec4.val[0]);
234          accR = vmlsq_f32(accR,vec3.val[1],vec4.val[1]);
235  
236  	/* Im{C} = Re{A}*Im{B} + Im{A}*Re{B} */
237          accI = vmlaq_f32(accI,vec3.val[1],vec4.val[0]);
238          accI = vmlaq_f32(accI,vec3.val[0],vec4.val[1]);
239  
240          /* Decrement the loop counter */
241          blkCnt--;
242      }
243  
244      accum = vpadd_f32(vget_low_f32(accR), vget_high_f32(accR));
245      real_sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
246  
247      accum = vpadd_f32(vget_low_f32(accI), vget_high_f32(accI));
248      imag_sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
249  
250      /* Tail */
251      blkCnt = numSamples & 0x7;
252  
253  #else
254  #if defined (ARM_MATH_LOOPUNROLL) && !defined(ARM_MATH_AUTOVECTORIZE)
255  
256    /* Loop unrolling: Compute 4 outputs at a time */
257    blkCnt = numSamples >> 2U;
258  
259    while (blkCnt > 0U)
260    {
261      a0 = *pSrcA++;
262      b0 = *pSrcA++;
263      c0 = *pSrcB++;
264      d0 = *pSrcB++;
265  
266      real_sum += a0 * c0;
267      imag_sum += a0 * d0;
268      real_sum -= b0 * d0;
269      imag_sum += b0 * c0;
270  
271      a0 = *pSrcA++;
272      b0 = *pSrcA++;
273      c0 = *pSrcB++;
274      d0 = *pSrcB++;
275  
276      real_sum += a0 * c0;
277      imag_sum += a0 * d0;
278      real_sum -= b0 * d0;
279      imag_sum += b0 * c0;
280  
281      a0 = *pSrcA++;
282      b0 = *pSrcA++;
283      c0 = *pSrcB++;
284      d0 = *pSrcB++;
285  
286      real_sum += a0 * c0;
287      imag_sum += a0 * d0;
288      real_sum -= b0 * d0;
289      imag_sum += b0 * c0;
290  
291      a0 = *pSrcA++;
292      b0 = *pSrcA++;
293      c0 = *pSrcB++;
294      d0 = *pSrcB++;
295  
296      real_sum += a0 * c0;
297      imag_sum += a0 * d0;
298      real_sum -= b0 * d0;
299      imag_sum += b0 * c0;
300  
301      /* Decrement loop counter */
302      blkCnt--;
303    }
304  
305    /* Loop unrolling: Compute remaining outputs */
306    blkCnt = numSamples % 0x4U;
307  
308  #else
309  
310    /* Initialize blkCnt with number of samples */
311    blkCnt = numSamples;
312  
313  #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
314  #endif /* #if defined(ARM_MATH_NEON) */
315  
316    while (blkCnt > 0U)
317    {
318      a0 = *pSrcA++;
319      b0 = *pSrcA++;
320      c0 = *pSrcB++;
321      d0 = *pSrcB++;
322  
323      real_sum += a0 * c0;
324      imag_sum += a0 * d0;
325      real_sum -= b0 * d0;
326      imag_sum += b0 * c0;
327  
328      /* Decrement loop counter */
329      blkCnt--;
330    }
331  
332    /* Store real and imaginary result in destination buffer. */
333    *realResult = real_sum;
334    *imagResult = imag_sum;
335  }
336  #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
337  
338  /**
339    @} end of cmplx_dot_prod group
340   */