/ Drivers / CMSIS / DSP / Source / MatrixFunctions / arm_mat_cmplx_mult_f16.c
arm_mat_cmplx_mult_f16.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_mat_cmplx_mult_f16.c
  4   * Description:  Half-precision floating-point complex matrix multiplication
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/matrix_functions_f16.h"
 30  
 31  #if defined(ARM_FLOAT16_SUPPORTED)
 32  
 33  
 34  /**
 35    @ingroup groupMatrix
 36   */
 37  
 38  
 39  /**
 40    @addtogroup CmplxMatrixMult
 41    @{
 42   */
 43  
 44  /**
 45    @brief         Floating-point Complex matrix multiplication.
 46    @param[in]     pSrcA      points to first input complex matrix structure
 47    @param[in]     pSrcB      points to second input complex matrix structure
 48    @param[out]    pDst       points to output complex matrix structure
 49    @return        execution status
 50                     - \ref ARM_MATH_SUCCESS       : Operation successful
 51                     - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
 52   */
 53  
 54  #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && defined(__CMSIS_GCC_H)
 55  #pragma GCC warning "Scalar version of arm_mat_cmplx_mult_f16 built. Helium version has build issues with gcc."
 56  #endif 
 57  
 58  #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) &&  !defined(__CMSIS_GCC_H)
 59  
 60  #include "arm_helium_utils.h"
 61  
 62  #define DONTCARE            0 /* inactive lane content */
 63  
 64  
 65  __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_2x2_mve(
 66      const arm_matrix_instance_f16 * pSrcA,
 67      const arm_matrix_instance_f16 * pSrcB,
 68      arm_matrix_instance_f16 * pDst)
 69  {
 70  #define MATRIX_DIM 2
 71      float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
 72      float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
 73      float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
 74      uint16x8_t     vecColBOffs0,vecColAOffs0,vecColAOffs1;
 75      float16_t       *pInA0 = pInA;
 76      f16x8_t        acc0, acc1;
 77      f16x8_t        vecB, vecA0, vecA1;
 78      f16x8_t        vecTmp;
 79      uint16_t         tmp;
 80      static const uint16_t offsetB0[8] = { 0, 1,
 81          MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
 82          2, 3,
 83          MATRIX_DIM * CMPLX_DIM + 2 , MATRIX_DIM * CMPLX_DIM + 3,
 84      };
 85  
 86  
 87      vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);
 88  
 89      tmp = 0;
 90      vecColAOffs0 = viwdupq_u16(tmp, 4, 1);
 91  
 92      tmp = (CMPLX_DIM * MATRIX_DIM);
 93      vecColAOffs1 = vecColAOffs0 + (uint16_t)(CMPLX_DIM * MATRIX_DIM);
 94  
 95  
 96      pInB = (float16_t const *)pSrcB->pData;
 97  
 98      vecA0 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs0);
 99      vecA1 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs1);
100  
101  
102      vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
103  
104      acc0 = vcmulq(vecA0, vecB);
105      acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
106  
107      acc1 = vcmulq(vecA1, vecB);
108      acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
109  
110  
111      /*
112       * Compute
113       *  re0+re1 | im0+im1 | re0+re1 | im0+im1
114       *  re2+re3 | im2+im3 | re2+re3 | im2+im3
115       */
116  
117      vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc0);
118      vecTmp = vaddq(vecTmp, acc0);
119  
120  
121      *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
122      *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];
123  
124      vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc1);
125      vecTmp = vaddq(vecTmp, acc1);
126  
127      *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
128      *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];
129  
130      /*
131       * Return to application
132       */
133      return (ARM_MATH_SUCCESS);
134  #undef MATRIX_DIM
135  }
136  
137  
138  
139  __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_3x3_mve(
140      const arm_matrix_instance_f16 * pSrcA,
141      const arm_matrix_instance_f16 * pSrcB,
142      arm_matrix_instance_f16 * pDst)
143  {
144  #define MATRIX_DIM 3
145      float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
146      float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
147      float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
148      uint16x8_t     vecColBOffs0;
149      float16_t       *pInA0 = pInA;
150      float16_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
151      float16_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
152      f16x8_t        acc0, acc1, acc2;
153      f16x8_t        vecB, vecA0, vecA1, vecA2;
154      static const uint16_t offsetB0[8] = { 0, 1,
155          MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
156          2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
157          DONTCARE, DONTCARE
158      };
159  
160      
161      /* enable predication to disable upper half complex vector element */
162      mve_pred16_t p0 = vctp16q(MATRIX_DIM * CMPLX_DIM);
163  
164      vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);
165  
166      pInB = (float16_t const *)pSrcB->pData;
167  
168      vecA0 = vldrhq_f16(pInA0);
169      vecA1 = vldrhq_f16(pInA1);
170      vecA2 = vldrhq_f16(pInA2);
171  
172      vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);
173  
174      acc0 = vcmulq(vecA0, vecB);
175      acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
176  
177      acc1 = vcmulq(vecA1, vecB);
178      acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
179  
180      acc2 = vcmulq(vecA2, vecB);
181      acc2 = vcmlaq_rot90(acc2, vecA2, vecB);
182  
183      mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
184      mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
185      mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
186      pOut += CMPLX_DIM;
187      /*
188       * move to next B column
189       */
190      pInB = pInB + CMPLX_DIM;
191  
192      vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);
193  
194      acc0 = vcmulq(vecA0, vecB);
195      acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
196  
197      acc1 = vcmulq(vecA1, vecB);
198      acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
199  
200      acc2 = vcmulq(vecA2, vecB);
201      acc2 = vcmlaq_rot90(acc2, vecA2, vecB);
202  
203      mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
204      mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
205      mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
206      pOut += CMPLX_DIM;
207      /*
208       * move to next B column
209       */
210      pInB = pInB + CMPLX_DIM;
211  
212      vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);
213  
214      acc0 = vcmulq(vecA0, vecB);
215      acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
216  
217      acc1 = vcmulq(vecA1, vecB);
218      acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
219  
220      acc2 = vcmulq(vecA2, vecB);
221      acc2 = vcmlaq_rot90(acc2, vecA2, vecB);
222  
223      mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
224      mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
225      mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
226      /*
227       * Return to application
228       */
229      return (ARM_MATH_SUCCESS);
230  #undef MATRIX_DIM
231  }
232  
233  
234  
235  
236  __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_4x4_mve(
237      const arm_matrix_instance_f16 * pSrcA,
238      const arm_matrix_instance_f16 * pSrcB,
239      arm_matrix_instance_f16 * pDst)
240  {
241  #define MATRIX_DIM 4
242      float16_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
243      float16_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
244      float16_t       *pOut = pDst->pData;   /* output data matrix pointer */
245      uint16x8_t     vecColBOffs0;
246      float16_t       *pInA0 = pInA;
247      float16_t       *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
248      float16_t       *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
249      float16_t       *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM;
250      f16x8_t        acc0, acc1, acc2, acc3;
251      f16x8_t        vecB, vecA;
252      static const uint16_t offsetB0[8] = { 0, 1,
253          MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
254          2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
255          3 * MATRIX_DIM * CMPLX_DIM, 3 * MATRIX_DIM * CMPLX_DIM + 1
256      };
257  
258      vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);
259  
260      pInB = (float16_t const *)pSrcB->pData;
261  
262      vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
263  
264      vecA = vldrhq_f16(pInA0);
265      acc0 = vcmulq(vecA, vecB);
266      acc0 = vcmlaq_rot90(acc0, vecA, vecB);
267  
268      vecA = vldrhq_f16(pInA1);
269      acc1 = vcmulq(vecA, vecB);
270      acc1 = vcmlaq_rot90(acc1, vecA, vecB);
271  
272      vecA = vldrhq_f16(pInA2);
273      acc2 = vcmulq(vecA, vecB);
274      acc2 = vcmlaq_rot90(acc2, vecA, vecB);
275  
276      vecA = vldrhq_f16(pInA3);
277      acc3 = vcmulq(vecA, vecB);
278      acc3 = vcmlaq_rot90(acc3, vecA, vecB);
279  
280  
281      mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
282      mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
283      mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
284      mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
285      pOut += CMPLX_DIM;
286      /*
287       * move to next B column
288       */
289      pInB = pInB + CMPLX_DIM;
290  
291      vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
292  
293      vecA = vldrhq_f16(pInA0);
294      acc0 = vcmulq(vecA, vecB);
295      acc0 = vcmlaq_rot90(acc0, vecA, vecB);
296  
297      vecA = vldrhq_f16(pInA1);
298      acc1 = vcmulq(vecA, vecB);
299      acc1 = vcmlaq_rot90(acc1, vecA, vecB);
300  
301      vecA = vldrhq_f16(pInA2);
302      acc2 = vcmulq(vecA, vecB);
303      acc2 = vcmlaq_rot90(acc2, vecA, vecB);
304  
305      vecA = vldrhq_f16(pInA3);
306      acc3 = vcmulq(vecA, vecB);
307      acc3 = vcmlaq_rot90(acc3, vecA, vecB);
308  
309  
310      mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
311      mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
312      mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
313      mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
314      pOut += CMPLX_DIM;
315      /*
316       * move to next B column
317       */
318      pInB = pInB + CMPLX_DIM;
319  
320      vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
321  
322      vecA = vldrhq_f16(pInA0);
323      acc0 = vcmulq(vecA, vecB);
324      acc0 = vcmlaq_rot90(acc0, vecA, vecB);
325  
326      vecA = vldrhq_f16(pInA1);
327      acc1 = vcmulq(vecA, vecB);
328      acc1 = vcmlaq_rot90(acc1, vecA, vecB);
329  
330      vecA = vldrhq_f16(pInA2);
331      acc2 = vcmulq(vecA, vecB);
332      acc2 = vcmlaq_rot90(acc2, vecA, vecB);
333  
334      vecA = vldrhq_f16(pInA3);
335      acc3 = vcmulq(vecA, vecB);
336      acc3 = vcmlaq_rot90(acc3, vecA, vecB);
337  
338  
339      mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
340      mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
341      mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
342      mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
343      pOut += CMPLX_DIM;
344      /*
345       * move to next B column
346       */
347      pInB = pInB + CMPLX_DIM;
348  
349      vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
350  
351      vecA = vldrhq_f16(pInA0);
352      acc0 = vcmulq(vecA, vecB);
353      acc0 = vcmlaq_rot90(acc0, vecA, vecB);
354  
355      vecA = vldrhq_f16(pInA1);
356      acc1 = vcmulq(vecA, vecB);
357      acc1 = vcmlaq_rot90(acc1, vecA, vecB);
358  
359      vecA = vldrhq_f16(pInA2);
360      acc2 = vcmulq(vecA, vecB);
361      acc2 = vcmlaq_rot90(acc2, vecA, vecB);
362  
363      vecA = vldrhq_f16(pInA3);
364      acc3 = vcmulq(vecA, vecB);
365      acc3 = vcmlaq_rot90(acc3, vecA, vecB);
366  
367  
368      mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
369      mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
370      mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
371      mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
372      /*
373       * Return to application
374       */
375      return (ARM_MATH_SUCCESS);
376  #undef MATRIX_DIM
377  }
378  
379  
380  
381  arm_status arm_mat_cmplx_mult_f16(
382    const arm_matrix_instance_f16 * pSrcA,
383    const arm_matrix_instance_f16 * pSrcB,
384    arm_matrix_instance_f16 * pDst)
385  {
386      float16_t const *pInB = (float16_t const *) pSrcB->pData;   /* input data matrix pointer B */
387      float16_t const *pInA = (float16_t const *) pSrcA->pData;   /* input data matrix pointer A */
388      float16_t *pOut = pDst->pData;  /* output data matrix pointer */
389      float16_t *px;              /* Temporary output data matrix pointer */
390      uint16_t  numRowsA = pSrcA->numRows;    /* number of rows of input matrix A    */
391      uint16_t  numColsB = pSrcB->numCols;    /* number of columns of input matrix B */
392      uint16_t  numColsA = pSrcA->numCols;    /* number of columns of input matrix A */
393      uint16_t  col, i = 0U, row = numRowsA;  /* loop counters */
394      arm_status status;          /* status of matrix multiplication */
395      uint16x8_t vecOffs, vecColBOffs;
396      uint32_t  blkCnt,rowCnt;           /* loop counters */
397  
398      #ifdef ARM_MATH_MATRIX_CHECK
399  
400    /* Check for matrix mismatch condition */
401  if ((pSrcA->numCols != pSrcB->numRows) ||
402        (pSrcA->numRows != pDst->numRows)  ||
403        (pSrcB->numCols != pDst->numCols)    )
404    {
405      /* Set status as ARM_MATH_SIZE_MISMATCH */
406      status = ARM_MATH_SIZE_MISMATCH;
407    }
408    else
409  
410  #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
411  
412    {
413  
414      /*
415       * small squared matrix specialized routines
416       */
417      if (numRowsA == numColsB && numColsB == numColsA)
418      {
419          if (numRowsA == 1)
420          {
421              pOut[0] = (_Float16)pInA[0] * (_Float16)pInB[0] - (_Float16)pInA[1] * (_Float16)pInB[1];
422              pOut[1] = (_Float16)pInA[0] * (_Float16)pInB[1] + (_Float16)pInA[1] * (_Float16)pInB[0];
423              return (ARM_MATH_SUCCESS);
424          }
425          else if  (numRowsA == 2)
426              return arm_mat_cmplx_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
427          else if (numRowsA == 3)
428              return arm_mat_cmplx_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
429          else if (numRowsA == 4)
430              return arm_mat_cmplx_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
431      }
432  
433      vecColBOffs[0] = 0;
434      vecColBOffs[1] = 1;
435      vecColBOffs[2] = numColsB * CMPLX_DIM;
436      vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
437      vecColBOffs[4] = 2*numColsB * CMPLX_DIM;
438      vecColBOffs[5] = 2*(numColsB * CMPLX_DIM) + 1;
439      vecColBOffs[6] = 3*numColsB * CMPLX_DIM;
440      vecColBOffs[7] = 3*(numColsB * CMPLX_DIM) + 1;
441  
442      /*
443       * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
444       */
445  
446      /*
447       * row loop
448       */
449      rowCnt = row >> 2;
450      while (rowCnt > 0u)
451      {
452          /*
453           * Output pointer is set to starting address of the row being processed
454           */
455          px = pOut + i * CMPLX_DIM;
456          i = i + 4 * numColsB;
457          /*
458           * For every row wise process, the column loop counter is to be initiated
459           */
460          col = numColsB;
461          /*
462           * For every row wise process, the pInB pointer is set
463           * to the starting address of the pSrcB data
464           */
465          pInB = (float16_t const *) pSrcB->pData;
466          /*
467           * column loop
468           */
469          while (col > 0u)
470          {
471              /*
472               * generate 4 columns elements
473               */
474              /*
475               * Matrix A columns number of MAC operations are to be performed
476               */
477  
478              float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
479              float16_t const *pInA0 = pInA;
480              float16_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
481              float16_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM;
482              float16_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM;
483              f16x8_t acc0, acc1, acc2, acc3;
484  
485              acc0 = vdupq_n_f16(0.0f16);
486              acc1 = vdupq_n_f16(0.0f16);
487              acc2 = vdupq_n_f16(0.0f16);
488              acc3 = vdupq_n_f16(0.0f16);
489  
490              pSrcA0Vec = (float16_t const *) pInA0;
491              pSrcA1Vec = (float16_t const *) pInA1;
492              pSrcA2Vec = (float16_t const *) pInA2;
493              pSrcA3Vec = (float16_t const *) pInA3;
494  
495              vecOffs = vecColBOffs;
496  
497              /*
498               * process 1 x 4 block output
499               */
500              blkCnt = (numColsA * CMPLX_DIM) >> 3;
501              while (blkCnt > 0U)
502              {
503                  f16x8_t vecB, vecA;
504  
505                  vecB = vldrhq_gather_shifted_offset_f16(pInB, vecOffs);
506                  /*
507                   * move Matrix B read offsets, 4 rows down
508                   */
509                  vecOffs = vaddq_n_u16(vecOffs , (uint16_t) (numColsB * 4 * CMPLX_DIM));
510  
511                  vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 8;
512                  acc0 = vcmlaq(acc0, vecA, vecB);
513                  acc0 = vcmlaq_rot90(acc0, vecA, vecB);
514  
515                  vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 8;
516                  acc1 = vcmlaq(acc1, vecA, vecB);
517                  acc1 = vcmlaq_rot90(acc1, vecA, vecB);
518  
519                  vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 8;
520                  acc2 = vcmlaq(acc2, vecA, vecB);
521                  acc2 = vcmlaq_rot90(acc2, vecA, vecB);
522  
523                  vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 8;
524                  acc3 = vcmlaq(acc3, vecA, vecB);
525                  acc3 = vcmlaq_rot90(acc3, vecA, vecB);
526  
527                  blkCnt--;
528              }
529              /*
530               * Unsupported addressing mode compiler crash
531               */
532              /*
533               * tail
534               * (will be merged thru tail predication)
535               */
536              blkCnt = (numColsA * CMPLX_DIM) & 7;
537              if (blkCnt > 0U)
538              {
539                  mve_pred16_t p0 = vctp16q(blkCnt);
540                  f16x8_t vecB, vecA;
541  
542                  vecB = vldrhq_gather_shifted_offset_z_f16(pInB, vecOffs, p0);
543                  /*
544                   * move Matrix B read offsets, 4 rows down
545                   */
546                  vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
547  
548                  vecA = vld1q(pSrcA0Vec);
549                  acc0 = vcmlaq(acc0, vecA, vecB);
550                  acc0 = vcmlaq_rot90(acc0, vecA, vecB);
551  
552                  vecA = vld1q(pSrcA1Vec);
553                  acc1 = vcmlaq(acc1, vecA, vecB);
554                  acc1 = vcmlaq_rot90(acc1, vecA, vecB);
555  
556                  vecA = vld1q(pSrcA2Vec);
557                  acc2 = vcmlaq(acc2, vecA, vecB);
558                  acc2 = vcmlaq_rot90(acc2, vecA, vecB);
559  
560                  vecA = vld1q(pSrcA3Vec);
561                  acc3 = vcmlaq(acc3, vecA, vecB);
562                  acc3 = vcmlaq_rot90(acc3, vecA, vecB);
563  
564              }
565  
566  
567              mve_cmplx_sum_intra_vec_f16(acc0, &px[0 * CMPLX_DIM * numColsB + 0]);
568              mve_cmplx_sum_intra_vec_f16(acc1, &px[1 * CMPLX_DIM * numColsB + 0]);
569              mve_cmplx_sum_intra_vec_f16(acc2, &px[2 * CMPLX_DIM * numColsB + 0]);
570              mve_cmplx_sum_intra_vec_f16(acc3, &px[3 * CMPLX_DIM * numColsB + 0]);
571             
572              px += CMPLX_DIM;
573              /*
574               * Decrement the column loop counter
575               */
576              col--;
577              /*
578               * Update the pointer pInB to point to the  starting address of the next column
579               */
580              pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
581          }
582  
583          /*
584           * Update the pointer pInA to point to the  starting address of the next row
585           */
586          pInA += (numColsA * 4) * CMPLX_DIM;
587          /*
588           * Decrement the row loop counter
589           */
590          rowCnt --;
591  
592      }
593  
594      rowCnt = row & 3;
595      while (rowCnt > 0u)
596      {
597             /*
598           * Output pointer is set to starting address of the row being processed
599           */
600          px = pOut + i * CMPLX_DIM;
601          i = i + numColsB;
602          /*
603           * For every row wise process, the column loop counter is to be initiated
604           */
605          col = numColsB;
606          /*
607           * For every row wise process, the pInB pointer is set
608           * to the starting address of the pSrcB data
609           */
610          pInB = (float16_t const *) pSrcB->pData;
611          /*
612           * column loop
613           */
614          while (col > 0u)
615          {
616              /*
617               * generate 4 columns elements
618               */
619              /*
620               * Matrix A columns number of MAC operations are to be performed
621               */
622  
623              float16_t const *pSrcA0Vec;
624              float16_t const *pInA0 = pInA;
625              f16x8_t acc0;
626  
627              acc0 = vdupq_n_f16(0.0f16);
628  
629              pSrcA0Vec = (float16_t const *) pInA0;
630             
631              vecOffs = vecColBOffs;
632  
633              /*
634               * process 1 x 4 block output
635               */
636              blkCnt = (numColsA * CMPLX_DIM) >> 3;
637              while (blkCnt > 0U)
638              {
639                  f16x8_t vecB, vecA;
640  
641                  vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
642                  /*
643                   * move Matrix B read offsets, 4 rows down
644                   */
645                  vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (4*numColsB * CMPLX_DIM));
646  
647                  vecA = vld1q(pSrcA0Vec);  
648                  pSrcA0Vec += 8;
649                  acc0 = vcmlaq(acc0, vecA, vecB);
650                  acc0 = vcmlaq_rot90(acc0, vecA, vecB);
651                  
652  
653                  blkCnt--;
654              }
655  
656  
657              /*
658               * tail
659               */
660              blkCnt = (numColsA * CMPLX_DIM) & 7;
661              if (blkCnt > 0U)
662              {
663                  mve_pred16_t p0 = vctp16q(blkCnt);
664                  f16x8_t vecB, vecA;
665  
666                  vecB = vldrhq_gather_shifted_offset_z(pInB, vecOffs, p0);
667                 
668                  vecA = vld1q(pSrcA0Vec);
669                  acc0 = vcmlaq(acc0, vecA, vecB);
670                  acc0 = vcmlaq_rot90(acc0, vecA, vecB);
671  
672              }
673  
674              mve_cmplx_sum_intra_vec_f16(acc0, &px[0]);
675  
676             
677              px += CMPLX_DIM;
678              /*
679               * Decrement the column loop counter
680               */
681              col--;
682              /*
683               * Update the pointer pInB to point to the  starting address of the next column
684               */
685              pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
686          }
687  
688          /*
689           * Update the pointer pInA to point to the  starting address of the next row
690           */
691          pInA += numColsA  * CMPLX_DIM;
692          rowCnt--;
693      }
694  
695      /*
696       * set status as ARM_MATH_SUCCESS
697       */
698      status = ARM_MATH_SUCCESS;
699   }
700      /*
701       * Return to application
702       */
703      return (status);
704  }
705  #else
706  
707  arm_status arm_mat_cmplx_mult_f16(
708    const arm_matrix_instance_f16 * pSrcA,
709    const arm_matrix_instance_f16 * pSrcB,
710          arm_matrix_instance_f16 * pDst)
711  {
712    float16_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
713    float16_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
714    float16_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
715    float16_t *pOut = pDst->pData;                 /* Output data matrix pointer */
716    float16_t *px;                                 /* Temporary output data matrix pointer */
717    uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
718    uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
719    uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
720    _Float16 sumReal, sumImag;                    /* Accumulator */
721    _Float16 a1, b1, c1, d1;
722    uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */
723    arm_status status;                             /* status of matrix multiplication */
724  
725  #if defined (ARM_MATH_LOOPUNROLL)
726    _Float16 a0, b0, c0, d0;
727  #endif
728  
729  #ifdef ARM_MATH_MATRIX_CHECK
730  
731    /* Check for matrix mismatch condition */
732    if ((pSrcA->numCols != pSrcB->numRows) ||
733        (pSrcA->numRows != pDst->numRows)  ||
734        (pSrcB->numCols != pDst->numCols)    )
735    {
736      /* Set status as ARM_MATH_SIZE_MISMATCH */
737      status = ARM_MATH_SIZE_MISMATCH;
738    }
739    else
740  
741  #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
742  
743    {
744      /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
745      /* row loop */
746      do
747      {
748        /* Output pointer is set to starting address of the row being processed */
749        px = pOut + 2 * i;
750  
751        /* For every row wise process, the column loop counter is to be initiated */
752        col = numColsB;
753  
754        /* For every row wise process, the pIn2 pointer is set
755         ** to the starting address of the pSrcB data */
756        pIn2 = pSrcB->pData;
757  
758        j = 0U;
759  
760        /* column loop */
761        do
762        {
763          /* Set the variable sum, that acts as accumulator, to zero */
764          sumReal = 0.0f16;
765          sumImag = 0.0f16;
766  
767          /* Initiate pointer pIn1 to point to starting address of column being processed */
768          pIn1 = pInA;
769  
770  #if defined (ARM_MATH_LOOPUNROLL)
771  
772          /* Apply loop unrolling and compute 4 MACs simultaneously. */
773          colCnt = numColsA >> 2U;
774  
775          /* matrix multiplication */
776          while (colCnt > 0U)
777          {
778  
779            /* Reading real part of complex matrix A */
780            a0 = *pIn1;
781  
782            /* Reading real part of complex matrix B */
783            c0 = *pIn2;
784  
785            /* Reading imaginary part of complex matrix A */
786            b0 = *(pIn1 + 1U);
787  
788            /* Reading imaginary part of complex matrix B */
789            d0 = *(pIn2 + 1U);
790  
791            /* Multiply and Accumlates */
792            sumReal += a0 * c0;
793            sumImag += b0 * c0;
794  
795            /* update pointers */
796            pIn1 += 2U;
797            pIn2 += 2 * numColsB;
798  
799            /* Multiply and Accumlates */
800            sumReal -= b0 * d0;
801            sumImag += a0 * d0;
802  
803            /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
804  
805            /* read real and imag values from pSrcA and pSrcB buffer */
806            a1 = *(pIn1     );
807            c1 = *(pIn2     );
808            b1 = *(pIn1 + 1U);
809            d1 = *(pIn2 + 1U);
810  
811            /* Multiply and Accumlates */
812            sumReal += a1 * c1;
813            sumImag += b1 * c1;
814  
815            /* update pointers */
816            pIn1 += 2U;
817            pIn2 += 2 * numColsB;
818  
819            /* Multiply and Accumlates */
820            sumReal -= b1 * d1;
821            sumImag += a1 * d1;
822  
823            a0 = *(pIn1     );
824            c0 = *(pIn2     );
825            b0 = *(pIn1 + 1U);
826            d0 = *(pIn2 + 1U);
827  
828            /* Multiply and Accumlates */
829            sumReal += a0 * c0;
830            sumImag += b0 * c0;
831  
832            /* update pointers */
833            pIn1 += 2U;
834            pIn2 += 2 * numColsB;
835  
836            /* Multiply and Accumlates */
837            sumReal -= b0 * d0;
838            sumImag += a0 * d0;
839  
840            /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
841  
842            a1 = *(pIn1     );
843            c1 = *(pIn2     );
844            b1 = *(pIn1 + 1U);
845            d1 = *(pIn2 + 1U);
846  
847            /* Multiply and Accumlates */
848            sumReal += a1 * c1;
849            sumImag += b1 * c1;
850  
851            /* update pointers */
852            pIn1 += 2U;
853            pIn2 += 2 * numColsB;
854  
855            /* Multiply and Accumlates */
856            sumReal -= b1 * d1;
857            sumImag += a1 * d1;
858  
859            /* Decrement loop count */
860            colCnt--;
861          }
862  
863          /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
864           ** No loop unrolling is used. */
865          colCnt = numColsA % 0x4U;
866  
867  #else
868  
869          /* Initialize blkCnt with number of samples */
870          colCnt = numColsA;
871  
872  #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
873  
874          while (colCnt > 0U)
875          {
876            /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
877            a1 = *(pIn1     );
878            c1 = *(pIn2     );
879            b1 = *(pIn1 + 1U);
880            d1 = *(pIn2 + 1U);
881  
882            /* Multiply and Accumlates */
883            sumReal += a1 * c1;
884            sumImag += b1 * c1;
885  
886            /* update pointers */
887            pIn1 += 2U;
888            pIn2 += 2 * numColsB;
889  
890            /* Multiply and Accumlates */
891            sumReal -= b1 * d1;
892            sumImag += a1 * d1;
893  
894            /* Decrement loop counter */
895            colCnt--;
896          }
897  
898          /* Store result in destination buffer */
899          *px++ = sumReal;
900          *px++ = sumImag;
901  
902          /* Update pointer pIn2 to point to starting address of next column */
903          j++;
904          pIn2 = pSrcB->pData + 2U * j;
905  
906          /* Decrement column loop counter */
907          col--;
908  
909        } while (col > 0U);
910  
911        /* Update pointer pInA to point to starting address of next row */
912        i = i + numColsB;
913        pInA = pInA + 2 * numColsA;
914  
915        /* Decrement row loop counter */
916        row--;
917  
918      } while (row > 0U);
919  
920      /* Set status as ARM_MATH_SUCCESS */
921      status = ARM_MATH_SUCCESS;
922    }
923  
924    /* Return to application */
925    return (status);
926  }
927  
 928  #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H) */
929  
930  /**
 931    @} end of CmplxMatrixMult group
932   */
933  
934  #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 
935