/ Drivers / CMSIS / DSP / Source / MatrixFunctions / arm_mat_mult_f16.c
arm_mat_mult_f16.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_mat_mult_f16.c
  4   * Description:  Floating-point matrix multiplication
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/matrix_functions_f16.h"
 30  
 31  #if defined(ARM_FLOAT16_SUPPORTED)
 32  
 33  
 34  /**
 35   * @ingroup groupMatrix
 36   */
 37  
 38  
 39  /**
 40   * @addtogroup MatrixMult
 41   * @{
 42   */
 43  
 44  /**
 45   * @brief Floating-point matrix multiplication.
 46   * @param[in]       pSrcA  points to the first input matrix structure
 47   * @param[in]       pSrcB  points to the second input matrix structure
 48   * @param[out]      pDst   points to the output matrix structure
 49   * @return          The function returns either
 50   * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 51   */
 52  
 53  #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
 54  
 55  __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_2x2_mve(
 56      const arm_matrix_instance_f16 *pSrcA,
 57      const arm_matrix_instance_f16 *pSrcB,
 58      arm_matrix_instance_f16 *pDst)
 59  {
 60      static const uint16_t offsetA[8] = { 0, 0, 2, 2, 0, 0, 2, 2 };
 61      /* offsetB allows to read and duplicate 1 row of B */
 62      static const uint16_t offsetB[8] = { 0, 1, 0, 1, 0, 1, 0, 1 };
 63      uint16x8_t    vecOffsA, vecOffsB;
 64      f16x8_t       vecInA, vecInB, vecDst;
 65      float16_t      *pOut = pDst->pData;  /* output data matrix pointer */
 66  
 67      /*
 68       * load initial offsets
 69       */
 70      vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
 71      vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
 72      /*
 73       * load {a00 a00 a10 a10 x x x x }
 74       */
 75      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
 76      /*
 77       * load {b00 b01 b00 b01 x x x x }
 78       */
 79      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
 80      /*
 81       *  { a00 b00       a00 b01
 82       *    a10 b00       a10 b01
 83       *       x             x
 84       *       x             x   }
 85       */
 86      vecDst = vmulq(vecInA, vecInB);
 87      /*
 88       * move to 2nd column of matrix A
 89       */
 90      vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
 91      /*
 92       * load {a01 a01 a11 a11 x x x x}
 93       */
 94      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
 95      /*
 96       * move to next B row
 97       */
 98      vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 2);
 99      /*
100       * load {b10, b11, b10, b11, x x x x }
101       */
102      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
103      /*
104       *  { a00 b00 + a01 b10   a00 b01 + a01 b11
105       *    a10 b00 + a11 b10     a10 b01 + a11 b11
106       *             x                    x
107       *             x                    x       }
108       */
109      vecDst = vfmaq(vecDst, vecInA, vecInB);
110  
111      mve_pred16_t p0 = vctp16q(2*2);
112      /*
113       * Store the result in the destination buffer
114       * (lower half of the vector)
115       */
116      vstrhq_p(pOut, vecDst, p0);
117  
118      return (ARM_MATH_SUCCESS);
119  }
120  
121  
122  
123  
124  __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_3x3_mve(
125      const arm_matrix_instance_f16 *pSrcA,
126      const arm_matrix_instance_f16 *pSrcB,
127      arm_matrix_instance_f16 *pDst)
128  {
129      static const uint16_t offsetA[8] = { 0, 0, 0, 3, 3, 3, 6, 6 };
130      /* offsetB allows to read and duplicate 1 row of B */
131      static const uint16_t offsetB[8] = { 0, 1, 2, 0, 1, 2, 0, 1 };
132      uint16x8_t    vecOffsA, vecOffsB;
133      f16x8_t       vecInA, vecInB, vecDst;
134      float16_t      *pOut = pDst->pData;  /* output data matrix pointer */
135  
136      /*
137       * load initial offsets
138       */
139      vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
140      vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
141  
142      /*
143       * load {a00 a00 a00 a10 a10 a10 a20 a20}
144       */
145      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
146      /*
147       * load {b00 b01 b02 b00 b01 b02 b00 b01}
148       */
149      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
150      /*
151       *  { a00 b00       a00 b01     a00 b02
152       *    a10 b00       a10 b01     a10 b02
153       *    a20 b00       a20 b01}
154       */
155      vecDst = vmulq(vecInA, vecInB);
156  
157      /*
158       * move to 2nd column of matrix A
159       */
160      vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
161      /*
162       * load {a01 a01 a01 a11 a11 a11 a21 a21}
163       */
164      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
165      /*
166       * move to next B row
167       */
168      vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3);
169      /*
170       * load {b10, b11, b12, b10, b11, b12, b10, b11}
171       */
172      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
173      /*
174       *  { a00 b00 + a01 b10   a00 b01 + a01 b11     a00 b02 + a01 b12
175       *    a10 b00 + a11 b10     a10 b01 + a11 b11     a10 b02 + a11 b12
176       *    a20 b00 + a21 b10     a20 b01 + a21 b11   }
177       */
178      vecDst = vfmaq(vecDst, vecInA, vecInB);
179      /*
180       * move to 3rd column of matrix A
181       */
182      vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
183      /*
184       * load {a02 a02 a02 a12 a12 a12 a22 a22}
185       */
186      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
187      /*
188       * move to next B row
189       */
190      vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3);
191      /*
192       * load {b20, b21, b22, b20, b21, b22, b20, b21}
193       */
194      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
195      /*
196       *  {a00 b00 + a01 b10 + a02 b20  a00 b01 + a01 b11 + a02 b21     a00 b02 + a01 b12 + a02 b22},
197       *   a10 b00 + a11 b10 + a12 b20    a10 b01 + a11 b11 + a12 b21     a10 b02 + a11 b12 + a12 b22},
198       *   a20 b00 + a21 b10 + a22 b20    a20 b01 + a21 b11 + a22 b21   }
199       */
200      vecDst = vfmaq(vecDst, vecInA, vecInB);
201  
202      /*
203       * Store the result in the destination buffer
204       */
205      vst1q(pOut, vecDst); pOut += 8;
206  
207      /* last element computed in scalar mode
208       * a20 b02 + a21 b12 + a22 b22
209       */
210      _Float16 * pA = (_Float16 *)pSrcA->pData;
211      _Float16 * pB = (_Float16 *)pSrcB->pData;
212      *pOut = pA[2*3] * pB[2] + pA[2*3+1] * pB[3+2] + pA[2*3+2] * pB[2*3+2];
213  
214      return (ARM_MATH_SUCCESS);
215  }
216  
217  
218  
219  
220  
221  __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_4x4_mve(
222      const arm_matrix_instance_f16 *pSrcA,
223      const arm_matrix_instance_f16 *pSrcB,
224      arm_matrix_instance_f16 *pDst)
225  {
226      /* offsetA allows to read and duplicate 2 successive column elements of A */
227      static const uint16_t offsetA[8] = { 0, 0, 0, 0, 4, 4, 4, 4 };
228      /* offsetB allows to read and duplicate 1 row of B */
229      static const uint16_t offsetB[8] = { 0, 1, 2, 3, 0, 1, 2, 3 };
230      uint16x8_t    vecOffsA, vecOffsB;
231      f16x8_t       vecInA, vecInB, vecDst0, vecDst1;
232      float16_t      *pOut = pDst->pData;  /* output data matrix pointer */
233  
234      /*
235       * load initial offsets
236       */
237      vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
238      vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
239  
240      /*
241       * load {a00 a00 a00 a00 a10 a10 a10 a10}
242       */
243      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
244      /*
245       * load {b00 b01 b02 b03 b00 b01 b02 b03}
246       */
247      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
248      /*
249       *  { a00 b00       a00 b01     a00 b02     a00 b03
250       *    a10 b00       a10 b01     a10 b02     a10 b03 }
251       */
252      vecDst0 = vmulq(vecInA, vecInB);
253      /*
254       * jump 2 x A rows (2nd half of matrix)
255       */
256      vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
257      /*
258       * load {a20 a20 a20 a20 a30 a30 a30 a30}
259       */
260      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
261      /*
262       *  { a20 b00       a20 b01     a20 b02     a20 b03
263       *    a30 b00       a30 b01     a30 b02 +   a31 b12 }
264       */
265      vecDst1 = vmulq(vecInA, vecInB);
266      /*
267       * rewind back to top half of the A matrix (2nd column)
268       */
269      vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
270      /*
271       * load {a01 a01 a01 a01 a11 a11 a11 a11}
272       */
273      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
274      /*
275       * move to next B row
276       */
277      vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
278      /*
279       * load {b10, b11, b12, b13, b10, b11, b12, b13}
280       */
281      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
282      /*
283       *  { a00 b00 + a01 b10         a00 b01 + a01 b11       a00 b02 + a01 b12       a00 b03 + a01 b13
284       *    a10 b00 + a11 b10         a10 b01 + a11 b11       a10 b02 + a11 b12       a10 b03 + a11 b13 }
285       */
286      vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
287      /*
288       * jump 2 x A rows (2nd half of matrix)
289       */
290      vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
291      /*
292       * load {a21 a21 a21 a21 a31 a31 a31 a31}
293       */
294      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
295      /*
296       *  {a20 b00 + a21 b10      a20 b01 + a21 b11       a20 b02 + a21 b12       a20 b03 + a21 b13
297       *   a30 b00 + a31 b10      a30 b01 + a31 b11       a30 b02 + a31 b12       a30 b03 + a31 b13 }
298       */
299      vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
300  
301      /*
302       * rewind back to top half of the A matrix (3rd column)
303       */
304      vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
305      /*
306       * load {a02 a02 a02 a02 a12 a12 a12 a12}
307       */
308      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
309      /*
310       * move to next B row
311       */
312      vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
313      /*
314       * load {b20, b21, b22, b23, b20, b21, b22, b23}
315       */
316      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
317      /*
318       *  { a00 b00 + a01 b10 + a02 b20    a00 b01 + a01 b11 + a02 b21    a00 b02 + a01 b12 + a02 b22   a00 b03 + a01 b13 + a02 b23
319       *    a10 b00 + a11 b10 + a12 b20    a10 b01 + a11 b11 + a12 b21    a10 b02 + a11 b12 + a12 b22   a10 b03 + a11 b13 + a12 b23 }
320       */
321      vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
322      /*
323       * jump 2 x A rows
324       */
325      vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
326  
327      /*
328       * load {a22 a22 a22 a22 a32 a32 a32 a32}
329       */
330      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
331      /*
332       *  {a20 b00 + a21 b10 + a22 b20   a20 b01 + a21 b11 + a22 b21  a20 b02 + a21 b12 + a22 b22    a20 b03 + a21 b13 + a22 b23
333       *   a30 b00 + a31 b10 + a32 b20   a30 b01 + a31 b11 + a32 b21  a30 b02 + a31 b12 + a32 b22    a30 b03 + a31 b13 + a32 b23 }
334       */
335      vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
336  
337      /*
338       * rewind back to top half of the A matrix (4th column)
339       */
340      vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
341      /*
342       * load {a03 a03 a03 a03 a13 a13 a13 a13}
343       */
344      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
345      /*
346       * move to next B row
347       */
348      vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
349      /*
350       * load {b30, b31, b32, b33, b30, b31, b32, b33}
351       */
352      vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
353      /*
354       * { a00 b00 +...+ a03 b30,    a00 b01 +...+ a03 b31,   a00 b02 +...+ a03 b32,   a00 b03 +...+ a03 b33
355       *   a10 b00 +...+ a13 b30,    a10 b01 +...+ a13 b31,   a10 b02 +...+ a13 b32,   a10 b03 +...+ a13 b33 }
356       */
357      vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
358      /*
359       * jump 2 x A rows
360       */
361      vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
362      /*
363       * load {a23 a23 a23 a23 a33 a33 a33 a33}
364       */
365      vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
366      /*
367       *  {a20 b00 +...+ a23 b30,   a20 b01 +...+ a23 b31,   a20 b02 +...+ a23 b32,   a20 b03 +...+ a23 b33
368       *   a30 b00 +...+ a33 b30,   a30 b01 +...+ a33 b31,   a30 b02 +...+ a33 b32,   a30 b03 +...+ a33 b33 }
369       */
370      vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
371  
372      /*
373       * Store the result in the destination buffer
374       */
375      vst1q(pOut, vecDst0); pOut += 8;
376      vst1q(pOut, vecDst1);
377  
378      return (ARM_MATH_SUCCESS);
379  }
380  
381  
382  arm_status arm_mat_mult_f16(
383    const arm_matrix_instance_f16 * pSrcA,
384    const arm_matrix_instance_f16 * pSrcB,
385    arm_matrix_instance_f16 * pDst)
386  {
387         float16_t  *pInB = pSrcB->pData;        /* input data matrix pointer B */
388      float16_t  *pInA = pSrcA->pData;        /* input data matrix pointer A  */
389      float16_t  *pOut = pDst->pData;         /* output data matrix pointer */
390      int         numRowsA = pSrcA->numRows;  /* number of rows of input matrix A */
391      int         numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
392      int         numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
393      uint32_t    blkCnt;                     /* loop counters */
394      int         i;
395  
396  
397  #ifdef ARM_MATH_MATRIX_CHECK
398  
399    /* Check for matrix mismatch condition */
400    if ((pSrcA->numCols != pSrcB->numRows) ||
401        (pSrcA->numRows != pDst->numRows)  ||
402        (pSrcB->numCols != pDst->numCols)    )
403    {
404      /* Set status as ARM_MATH_SIZE_MISMATCH */
405      return(ARM_MATH_SIZE_MISMATCH);
406    }
407    else
408  
409  #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
410  {
411      /* small squared matrix specialized routines */
412      if(numRowsA == numColsB && numColsB == numColsA) {
413          if(numRowsA == 2)
414              return arm_mat_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
415          else if(numRowsA == 3)
416              return arm_mat_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
417          else if(numRowsA == 4)
418              return arm_mat_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
419      }
420  
421      /* main loop process 4 rows */
422      i = numRowsA / 4;
423      while(i > 0)
424      {
425          float16_t   *pInA0, *pInA1, *pInA2, *pInA3;
426          float16_t   *pInB0;
427          float16_t   *pOut0, *pOut1, *pOut2, *pOut3;
428          f16x8_t    vecMac0, vecMac1, vecMac2, vecMac3;
429          f16x8_t    vecInB;
430  
431          /* pointers to 4 consecutive output rows */
432          pOut0 = pOut;
433          pOut1 = pOut0 + numColsB;
434          pOut2 = pOut1 + numColsB;
435          pOut3 = pOut2 + numColsB;
436          pInB0 = pInB;
437  
438          int       k = numColsB >> 3;
439          while(k > 0)
440          {
441              /* pointers to 4 consecutive Matrix A rows */
442              pInA0 = pInA;
443              pInA1 = pInA0 + numColsA;
444              pInA2 = pInA1 + numColsA;
445              pInA3 = pInA2 + numColsA;
446  
447              vecMac0 = vdupq_n_f16(0.0f16);
448              vecMac1 = vdupq_n_f16(0.0f16);
449              vecMac2 = vdupq_n_f16(0.0f16);
450              vecMac3 = vdupq_n_f16(0.0f16);
451  
452              blkCnt = numColsA;
453  
454              while (blkCnt > 0U)
455              {
456                  /*
457                   * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3..., bi,4n+7}
458                   */
459                  vecInB = *(f16x8_t *)pInB0; /* vldrhq_f16(pInB0, 0); */
460  
461                  vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
462                  vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
463                  vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
464                  vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
465  
466                  pInB0 = pInB0 + numColsB;
467                  /*
468                   * Decrement the blockSize loop counter
469                   */
470                  blkCnt--;
471              }
472  
473              /* Store the results (4 x 8 block) in the destination buffer */
474              vst1q(pOut0, vecMac0);  pOut0 += 8;
475              vst1q(pOut1, vecMac1);  pOut1 += 8;
476              vst1q(pOut2, vecMac2);  pOut2 += 8;
477              vst1q(pOut3, vecMac3);  pOut3 += 8;
478              /*
479               * rewind
480               */
481              pInB0 -= (numColsB * numColsA) - 8;
482              k--;
483          }
484  
485          int       colBLeft = numColsB & 7;
486          if (colBLeft)
487          {
488              pInA0 = pInA;
489              pInA1 = pInA0 + numColsA;
490              pInA2 = pInA1 + numColsA;
491              pInA3 = pInA2 + numColsA;
492              mve_pred16_t p0 = vctp16q(colBLeft);
493  
494              vecMac0 = vdupq_n_f16(0.0f16);
495              vecMac1 = vdupq_n_f16(0.0f16);
496              vecMac2 = vdupq_n_f16(0.0f16);
497              vecMac3 = vdupq_n_f16(0.0f16);
498  
499              blkCnt = numColsA;
500  
501              while (blkCnt > 0U)
502              {
503                  /*
504                   * load {bi,4n+0, bi,4n+1, bi,4n+2, ..bi,4n+colBLeft-1, 0, ..}
505                   */
506                  vecInB = vldrhq_z_f16(pInB0, p0);
507  
508                  vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
509                  vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
510                  vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
511                  vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
512  
513                  pInB0 = pInB0 + numColsB;
514                  /*
515                   * Decrement the blockSize loop counter
516                   */
517                  blkCnt--;
518              }
519  
520              /* Store the results (4 x colBLeft block) in the destination buffer */
521              vstrhq_p_f16(pOut0, vecMac0, p0);
522              vstrhq_p_f16(pOut1, vecMac1, p0);
523              vstrhq_p_f16(pOut2, vecMac2, p0);
524              vstrhq_p_f16(pOut3, vecMac3, p0);
525          }
526  
527          pInA += 4 * numColsA;
528          pOut += 4 * numColsB;
529          i--;
530      }
531  
532      /*
533       * non multiple of 4 rows for Matrix A
534       * process single row
535       */
536      if (numRowsA & 3)
537      {
538          i = numRowsA & 3;
539          do
540          {
541              float16_t   *pInA0;
542              float16_t   *pInB0;
543              float16_t   *pOut0;
544              f16x8_t    vecInB;
545              f16x8_t    vecMac0;
546  
547              pOut0 = pOut;
548              pInB0 = pInB;
549  
550              int       k = numColsB >> 3;
551              while(k > 0)
552              {
553                  pInA0 = pInA;
554  
555                  vecMac0 = vdupq_n_f16(0.0f16);
556                  blkCnt = numColsA;
557  
558                  while (blkCnt > 0U)
559                  {
560                      /*
561                       * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3, ...bi,4n+7}
562                       */
563                      vecInB = *(f16x8_t *)pInB0; /* vldrhq_f16(pInB0, 0); */
564  
565                      vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
566  
567                      pInB0 = pInB0 + numColsB;
568                      /*
569                       * Decrement the blockSize loop counter
570                       */
571                      blkCnt--;
572                  }
573                  /* Store the results (1 x 8 block) in the destination buffer */
574                  vst1q(pOut0, vecMac0);   pOut0 += 8;
575                  /*
576                   * rewind
577                   */
578                  pInB0 -= (numColsB * numColsA) - 8;
579                  k--;
580              }
581  
582              int  colBLeft = numColsB & 7;
583              if (colBLeft)
584              {
585                  pInA0 = pInA;
586                  mve_pred16_t p0 = vctp16q(colBLeft);
587  
588                  vecMac0 = vdupq_n_f16(0.0f16);
589                  blkCnt = numColsA;
590  
591                  while (blkCnt > 0U)
592                  {
593                      /*
594                       * load {bi,4n+0, bi,4n+1, bi,4n+2, ..., bi,4n+colBLeft, 0, ...}
595                       */
596                      vecInB = vldrhq_z_f16(pInB0, p0);
597  
598                      vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
599  
600                      pInB0 = pInB0 + numColsB;
601                      /*
602                       * Decrement the blockSize loop counter
603                       */
604                      blkCnt--;
605                  }
606                  /* Store the results (1 x colBLeft block) in the destination buffer */
607                  vstrhq_p_f16(pOut0, vecMac0, p0);
608              }
609  
610              pInA += 1 * numColsA;
611              pOut += 1 * numColsB;
612          }
613          while (--i);
614      }
615      /*
616       * Return to application
617       */
618      return (ARM_MATH_SUCCESS);
619    }
620  }
621  #else
622  
623  
624  arm_status arm_mat_mult_f16(
625    const arm_matrix_instance_f16 * pSrcA,
626    const arm_matrix_instance_f16 * pSrcB,
627          arm_matrix_instance_f16 * pDst)
628  {
629    float16_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
630    float16_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
631    float16_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
632    float16_t *pInB = pSrcB->pData;                /* Input data matrix pointer B */
633    float16_t *pOut = pDst->pData;                 /* Output data matrix pointer */
634    float16_t *px;                                 /* Temporary output data matrix pointer */
635    _Float16 sum;                                 /* Accumulator */
636    uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
637    uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
638    uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
639    uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
640    arm_status status;                             /* Status of matrix multiplication */
641  
642  #ifdef ARM_MATH_MATRIX_CHECK
643  
644    /* Check for matrix mismatch condition */
645    if ((pSrcA->numCols != pSrcB->numRows) ||
646        (pSrcA->numRows != pDst->numRows)  ||
647        (pSrcB->numCols != pDst->numCols)    )
648    {
649      /* Set status as ARM_MATH_SIZE_MISMATCH */
650      status = ARM_MATH_SIZE_MISMATCH;
651    }
652    else
653  
654  #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
655  
656    {
657      /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
658      /* row loop */
659      do
660      {
661        /* Output pointer is set to starting address of row being processed */
662        px = pOut + i;
663  
664        /* For every row wise process, column loop counter is to be initiated */
665        col = numColsB;
666  
667        /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
668        pIn2 = pSrcB->pData;
669  
670        /* column loop */
671        do
672        {
673          /* Set the variable sum, that acts as accumulator, to zero */
674          sum = 0.0f16;
675  
676          /* Initialize pointer pIn1 to point to starting address of column being processed */
677          pIn1 = pInA;
678  
679  #if defined (ARM_MATH_LOOPUNROLL)
680  
681          /* Loop unrolling: Compute 4 MACs at a time. */
682          colCnt = numColsA >> 2U;
683  
684          /* matrix multiplication */
685          while (colCnt > 0U)
686          {
687            /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
688  
689            /* Perform the multiply-accumulates */
690            sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
691            pIn2 += numColsB;
692  
693            sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
694            pIn2 += numColsB;
695  
696            sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
697            pIn2 += numColsB;
698  
699            sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
700            pIn2 += numColsB;
701  
702            /* Decrement loop counter */
703            colCnt--;
704          }
705  
706          /* Loop unrolling: Compute remaining MACs */
707          colCnt = numColsA % 0x4U;
708  
709  #else
710  
711          /* Initialize cntCnt with number of columns */
712          colCnt = numColsA;
713  
714  #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
715  
716          while (colCnt > 0U)
717          {
718            /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
719  
720            /* Perform the multiply-accumulates */
721            sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
722            pIn2 += numColsB;
723  
724            /* Decrement loop counter */
725            colCnt--;
726          }
727  
728          /* Store result in destination buffer */
729          *px++ = sum;
730  
731          /* Decrement column loop counter */
732          col--;
733  
734          /* Update pointer pIn2 to point to starting address of next column */
735          pIn2 = pInB + (numColsB - col);
736  
737        } while (col > 0U);
738  
739        /* Update pointer pInA to point to starting address of next row */
740        i = i + numColsB;
741        pInA = pInA + numColsA;
742  
743        /* Decrement row loop counter */
744        row--;
745  
746      } while (row > 0U);
747  
748      /* Set status as ARM_MATH_SUCCESS */
749      status = ARM_MATH_SUCCESS;
750    }
751  
752    /* Return to application */
753    return (status);
754  }
755  
 756  #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
757  
758  /**
759   * @} end of MatrixMult group
760   */
761  
762  #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 
763