/ Drivers / CMSIS / DSP / Source / MatrixFunctions / arm_mat_vec_mult_f16.c
arm_mat_vec_mult_f16.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_mat_vec_mult_f16.c
  4   * Description:  Floating-point matrix and vector multiplication
  5   *
  6   * $Date:        23 April 2021
  7   *
  8   * $Revision:    V1.9.0
  9   *
 10   * Target Processor: Cortex-M and Cortex-A cores
 11   * -------------------------------------------------------------------- */
 12  /*
 13   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 14   *
 15   * SPDX-License-Identifier: Apache-2.0
 16   *
 17   * Licensed under the Apache License, Version 2.0 (the License); you may
 18   * not use this file except in compliance with the License.
 19   * You may obtain a copy of the License at
 20   *
 21   * www.apache.org/licenses/LICENSE-2.0
 22   *
 23   * Unless required by applicable law or agreed to in writing, software
 24   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 25   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 26   * See the License for the specific language governing permissions and
 27   * limitations under the License.
 28   */
 29  
 30  #include "dsp/matrix_functions_f16.h"
 31  
 32  #if defined(ARM_FLOAT16_SUPPORTED)
 33  
 34  
 35  /**
 36   * @ingroup groupMatrix
 37   */
 38  
 39  
 40  /**
 41   * @addtogroup MatrixVectMult
 42   * @{
 43   */
 44  
 45  /**
 46   * @brief Half-precision floating-point matrix and vector multiplication.
 47   * @param[in]       pSrcMat points to the input matrix structure
 48   * @param[in]       pVec points to the input vector
 49   * @param[out]      pDst points to the output vector
 50   */
 51  #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
 52  
 53  #include "arm_helium_utils.h"
 54  
 55  void arm_mat_vec_mult_f16(
 56      const arm_matrix_instance_f16   *pSrcMat,
 57      const float16_t                 *pSrcVec,
 58      float16_t                       *pDstVec)
 59  {
 60      uint32_t         numRows = pSrcMat->numRows;
 61      uint32_t         numCols = pSrcMat->numCols;
 62      const float16_t *pSrcA = pSrcMat->pData;
 63      const float16_t *pInA0;
 64      const float16_t *pInA1;
 65      float16_t       *px;
 66      int32_t          row;
 67      uint32_t         blkCnt;           /* loop counters */
 68  
 69      row = numRows;
 70      px = pDstVec;
 71  
 72      /*
 73       * compute 4 rows in parallel
 74       */
 75      while (row >= 4)
 76      {
 77          const float16_t     *pInA2, *pInA3;
 78          float16_t const    *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
 79          f16x8_t            vecIn, acc0, acc1, acc2, acc3;
 80          float16_t const     *pSrcVecPtr = pSrcVec;
 81  
 82          /*
 83           * Initialize the pointers to 4 consecutive MatrixA rows
 84           */
 85          pInA0 = pSrcA;
 86          pInA1 = pInA0 + numCols;
 87          pInA2 = pInA1 + numCols;
 88          pInA3 = pInA2 + numCols;
 89          /*
 90           * Initialize the vector pointer
 91           */
 92          pInVec =  pSrcVecPtr;
 93          /*
 94           * reset accumulators
 95           */
 96          acc0 = vdupq_n_f16(0.0f);
 97          acc1 = vdupq_n_f16(0.0f);
 98          acc2 = vdupq_n_f16(0.0f);
 99          acc3 = vdupq_n_f16(0.0f);
100  
101          pSrcA0Vec = pInA0;
102          pSrcA1Vec = pInA1;
103          pSrcA2Vec = pInA2;
104          pSrcA3Vec = pInA3;
105  
106          blkCnt = numCols >> 3;
107          while (blkCnt > 0U)
108          {
109              f16x8_t vecA;
110  
111              vecIn = vld1q(pInVec);      
112              pInVec += 8;
113              vecA = vld1q(pSrcA0Vec);    
114              pSrcA0Vec += 8;
115              acc0 = vfmaq(acc0, vecIn, vecA);
116              vecA = vld1q(pSrcA1Vec);  
117              pSrcA1Vec += 8;
118              acc1 = vfmaq(acc1, vecIn, vecA);
119              vecA = vld1q(pSrcA2Vec);  
120              pSrcA2Vec += 8;
121              acc2 = vfmaq(acc2, vecIn, vecA);
122              vecA = vld1q(pSrcA3Vec);  
123              pSrcA3Vec += 8;
124              acc3 = vfmaq(acc3, vecIn, vecA);
125  
126              blkCnt--;
127          }
128          /*
129           * tail
130           * (will be merged thru tail predication)
131           */
132          blkCnt = numCols & 7;
133          if (blkCnt > 0U)
134          {
135              mve_pred16_t p0 = vctp16q(blkCnt);
136              f16x8_t vecA;
137  
138              vecIn = vldrhq_z_f16(pInVec, p0);
139              vecA = vld1q(pSrcA0Vec);
140              acc0 = vfmaq(acc0, vecIn, vecA);
141              vecA = vld1q(pSrcA1Vec);
142              acc1 = vfmaq(acc1, vecIn, vecA);
143              vecA = vld1q(pSrcA2Vec);
144              acc2 = vfmaq(acc2, vecIn, vecA);
145              vecA = vld1q(pSrcA3Vec);
146              acc3 = vfmaq(acc3, vecIn, vecA);
147          }
148          /*
149           * Sum the partial parts
150           */
151          *px++ = vecAddAcrossF16Mve(acc0);
152          *px++ = vecAddAcrossF16Mve(acc1);
153          *px++ = vecAddAcrossF16Mve(acc2);
154          *px++ = vecAddAcrossF16Mve(acc3);
155  
156          pSrcA += numCols * 4;
157          /*
158           * Decrement the row loop counter
159           */
160          row -= 4;
161      }
162  
163      /*
164       * compute 2 rows in parrallel
165       */
166      if (row >= 2)
167      {
168          float16_t const    *pSrcA0Vec, *pSrcA1Vec, *pInVec;
169          f16x8_t            vecIn, acc0, acc1;
170          float16_t const     *pSrcVecPtr = pSrcVec;
171  
172          /*
173           * Initialize the pointers to 2 consecutive MatrixA rows
174           */
175          pInA0 = pSrcA;
176          pInA1 = pInA0 + numCols;
177          /*
178           * Initialize the vector pointer
179           */
180          pInVec = pSrcVecPtr;
181          /*
182           * reset accumulators
183           */
184          acc0 = vdupq_n_f16(0.0f);
185          acc1 = vdupq_n_f16(0.0f);
186          pSrcA0Vec = pInA0;
187          pSrcA1Vec = pInA1;
188  
189          blkCnt = numCols >> 3;
190          while (blkCnt > 0U)
191          {
192              f16x8_t vecA;
193  
194              vecIn = vld1q(pInVec);      
195              pInVec += 8;
196              vecA = vld1q(pSrcA0Vec);    
197              pSrcA0Vec += 8;
198              acc0 = vfmaq(acc0, vecIn, vecA);
199              vecA = vld1q(pSrcA1Vec);    
200              pSrcA1Vec += 8;
201              acc1 = vfmaq(acc1, vecIn, vecA);
202  
203              blkCnt--;
204          }
205          /*
206           * tail
207           * (will be merged thru tail predication)
208           */
209          blkCnt = numCols & 7;
210          if (blkCnt > 0U)
211          {
212              mve_pred16_t p0 = vctp16q(blkCnt);
213              f16x8_t vecA;
214  
215              vecIn = vldrhq_z_f16(pInVec, p0);
216              vecA = vld1q(pSrcA0Vec);
217              acc0 = vfmaq(acc0, vecIn, vecA);
218              vecA = vld1q(pSrcA1Vec);
219              acc1 = vfmaq(acc1, vecIn, vecA);
220          }
221          /*
222           * Sum the partial parts
223           */
224          *px++ = vecAddAcrossF16Mve(acc0);
225          *px++ = vecAddAcrossF16Mve(acc1);
226  
227          pSrcA += numCols * 2;
228          row -= 2;
229      }
230  
231      if (row >= 1)
232      {
233          f16x8_t             vecIn, acc0;
234          float16_t const     *pSrcA0Vec, *pInVec;
235          float16_t const      *pSrcVecPtr = pSrcVec;
236          /*
237           * Initialize the pointers to last MatrixA row
238           */
239          pInA0 = pSrcA;
240          /*
241           * Initialize the vector pointer
242           */
243          pInVec = pSrcVecPtr;
244          /*
245           * reset accumulators
246           */
247          acc0 = vdupq_n_f16(0.0f);
248  
249          pSrcA0Vec = pInA0;
250  
251          blkCnt = numCols >> 3;
252          while (blkCnt > 0U)
253          {
254              f16x8_t vecA;
255  
256              vecIn = vld1q(pInVec);      
257              pInVec += 8;
258              vecA = vld1q(pSrcA0Vec);    
259              pSrcA0Vec += 8;
260              acc0 = vfmaq(acc0, vecIn, vecA);
261  
262              blkCnt--;
263          }
264          /*
265           * tail
266           * (will be merged thru tail predication)
267           */
268          blkCnt = numCols & 7;
269          if (blkCnt > 0U)
270          {
271              mve_pred16_t p0 = vctp16q(blkCnt);
272              f16x8_t vecA;
273  
274              vecIn = vldrhq_z_f16(pInVec, p0);
275              vecA = vld1q(pSrcA0Vec);
276              acc0 = vfmaq(acc0, vecIn, vecA);
277          }
278          /*
279           * Sum the partial parts
280           */
281          *px++ = vecAddAcrossF16Mve(acc0);
282      }
283  }
284  #else
285  void arm_mat_vec_mult_f16(const arm_matrix_instance_f16 *pSrcMat, const float16_t *pVec, float16_t *pDst)
286  {
287      uint32_t numRows = pSrcMat->numRows;
288      uint32_t numCols = pSrcMat->numCols;
289      const float16_t *pSrcA = pSrcMat->pData;
290      const float16_t *pInA1;      /* input data matrix pointer A of Q31 type */
291      const float16_t *pInA2;      /* input data matrix pointer A of Q31 type */
292      const float16_t *pInA3;      /* input data matrix pointer A of Q31 type */
293      const float16_t *pInA4;      /* input data matrix pointer A of Q31 type */
294      const float16_t *pInVec;     /* input data matrix pointer B of Q31 type */
295      float16_t *px;               /* Temporary output data matrix pointer */
296      uint16_t i, row, colCnt; /* loop counters */
297      float16_t matData, matData2, vecData, vecData2;
298  
299  
300      /* Process 4 rows at a time */
301      row = numRows >> 2;
302      i = 0u;
303      px = pDst;
304  
305      /* The following loop performs the dot-product of each row in pSrcA with the vector */
306      /* row loop */
307      while (row > 0) {
308          /* For every row wise process, the pInVec pointer is set
309           ** to the starting address of the vector */
310          pInVec = pVec;
311  
312          /* Initialize accumulators */
313          float16_t sum1 = 0.0f16;
314          float16_t sum2 = 0.0f16;
315          float16_t sum3 = 0.0f16;
316          float16_t sum4 = 0.0f16;
317  
318          /* Loop unrolling: process 2 columns per iteration */
319          colCnt = numCols;
320  
321          /* Initialize pointers to the starting address of the column being processed */
322          pInA1 = pSrcA + i;
323          pInA2 = pInA1 + numCols;
324          pInA3 = pInA2 + numCols;
325          pInA4 = pInA3 + numCols;
326  
327  
328          // Main loop: matrix-vector multiplication
329          while (colCnt > 0u) {
330              // Read 2 values from vector
331              vecData = *(pInVec)++;
332              // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
333              matData = *(pInA1)++;
334              sum1 += (_Float16)matData * (_Float16)vecData;
335              matData = *(pInA2)++;
336              sum2 += (_Float16)matData * (_Float16)vecData;
337              matData = *(pInA3)++;
338              sum3 += (_Float16)matData * (_Float16)vecData;
339              matData = *(pInA4)++;
340              sum4 += (_Float16)matData * (_Float16)vecData;
341  
342              // Decrement the loop counter
343              colCnt--;
344          }
345  
346          /* Saturate and store the result in the destination buffer */
347          *px++ = sum1;
348          *px++ = sum2;
349          *px++ = sum3;
350          *px++ = sum4;
351  
352          i = i + numCols * 4;
353  
354          /* Decrement the row loop counter */
355          row--;
356      }
357  
358      /* process any remaining rows */
359      row = numRows & 3u;
360      while (row > 0) {
361  
362          float16_t sum = 0.0f16;
363          pInVec = pVec;
364          pInA1 = pSrcA + i;
365  
366          colCnt = numCols >> 1;
367  
368          while (colCnt > 0) {
369              vecData = *(pInVec)++;
370              vecData2 = *(pInVec)++;
371              matData = *(pInA1)++;
372              matData2 = *(pInA1)++;
373              sum += (_Float16)matData * (_Float16)vecData;
374              sum += (_Float16)matData2 * (_Float16)vecData2;
375              colCnt--;
376          }
377          // process remainder of row
378          colCnt = numCols & 1u;
379          while (colCnt > 0) {
380              sum += (_Float16)*pInA1++ * (_Float16)*pInVec++;
381              colCnt--;
382          }
383  
384          *px++ = sum;
385          i = i + numCols;
386          row--;
387      }
388  }
389  #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
390  
391  /**
392   * @} end of MatrixVectMult group
393   */
394  
395  #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 
396