/ Drivers / CMSIS / DSP / Source / MatrixFunctions / arm_mat_vec_mult_f32.c
arm_mat_vec_mult_f32.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_mat_vec_mult_f32.c
  4   * Description:  Floating-point matrix and vector multiplication
  5   *
  6   * $Date:        23 April 2021
  7   *
  8   * $Revision:    V1.9.0
  9   *
 10   * Target Processor: Cortex-M and Cortex-A cores
 11   * -------------------------------------------------------------------- */
 12  /*
 13   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 14   *
 15   * SPDX-License-Identifier: Apache-2.0
 16   *
 17   * Licensed under the Apache License, Version 2.0 (the License); you may
 18   * not use this file except in compliance with the License.
 19   * You may obtain a copy of the License at
 20   *
 21   * www.apache.org/licenses/LICENSE-2.0
 22   *
 23   * Unless required by applicable law or agreed to in writing, software
 24   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 25   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 26   * See the License for the specific language governing permissions and
 27   * limitations under the License.
 28   */
 29  
 30  #include "dsp/matrix_functions.h"
 31  
 32  
 33  /**
 34   * @ingroup groupMatrix
 35   */
 36  
 37  /**
 38   * @defgroup MatrixVectMult Matrix Vector Multiplication
 39   *
 40   * Multiplies a matrix and a vector.
 41   *
 42   */
 43  
 44  /**
 45   * @addtogroup MatrixVectMult
 46   * @{
 47   */
 48  
/**
 * @brief         Floating-point matrix and vector multiplication: pDst = pSrcMat * pVec.
 * @param[in]     pSrcMat  points to the input matrix structure (row-major, numRows x numCols)
 * @param[in]     pVec     points to the input vector (numCols elements)
 * @param[out]    pDst     points to the output vector (numRows elements)
 * @return        none
 */
 55  #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
 56  
 57  #include "arm_helium_utils.h"
 58  
 59  void arm_mat_vec_mult_f32(
 60      const arm_matrix_instance_f32   *pSrcMat,
 61      const float32_t                 *pSrcVec,
 62      float32_t                       *pDstVec)
 63  {
 64      uint32_t         numRows = pSrcMat->numRows;
 65      uint32_t         numCols = pSrcMat->numCols;
 66      const float32_t *pSrcA = pSrcMat->pData;
 67      const float32_t *pInA0;
 68      const float32_t *pInA1;
 69      float32_t       *px;
 70      int32_t          row;
 71      uint32_t         blkCnt;           /* loop counters */
 72  
 73      row = numRows;
 74      px = pDstVec;
 75  
 76      /*
 77       * compute 4 rows in parallel
 78       */
 79      while (row >= 4)
 80      {
 81          const float32_t     *pInA2, *pInA3;
 82          float32_t const    *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
 83          f32x4_t            vecIn, acc0, acc1, acc2, acc3;
 84          float32_t const     *pSrcVecPtr = pSrcVec;
 85  
 86          /*
 87           * Initialize the pointers to 4 consecutive MatrixA rows
 88           */
 89          pInA0 = pSrcA;
 90          pInA1 = pInA0 + numCols;
 91          pInA2 = pInA1 + numCols;
 92          pInA3 = pInA2 + numCols;
 93          /*
 94           * Initialize the vector pointer
 95           */
 96          pInVec =  pSrcVecPtr;
 97          /*
 98           * reset accumulators
 99           */
100          acc0 = vdupq_n_f32(0.0f);
101          acc1 = vdupq_n_f32(0.0f);
102          acc2 = vdupq_n_f32(0.0f);
103          acc3 = vdupq_n_f32(0.0f);
104  
105          pSrcA0Vec = pInA0;
106          pSrcA1Vec = pInA1;
107          pSrcA2Vec = pInA2;
108          pSrcA3Vec = pInA3;
109  
110          blkCnt = numCols >> 2;
111          while (blkCnt > 0U)
112          {
113              f32x4_t vecA;
114  
115              vecIn = vld1q(pInVec);      
116              pInVec += 4;
117              vecA = vld1q(pSrcA0Vec);    
118              pSrcA0Vec += 4;
119              acc0 = vfmaq(acc0, vecIn, vecA);
120              vecA = vld1q(pSrcA1Vec);  
121              pSrcA1Vec += 4;
122              acc1 = vfmaq(acc1, vecIn, vecA);
123              vecA = vld1q(pSrcA2Vec);  
124              pSrcA2Vec += 4;
125              acc2 = vfmaq(acc2, vecIn, vecA);
126              vecA = vld1q(pSrcA3Vec);  
127              pSrcA3Vec += 4;
128              acc3 = vfmaq(acc3, vecIn, vecA);
129  
130              blkCnt--;
131          }
132          /*
133           * tail
134           * (will be merged thru tail predication)
135           */
136          blkCnt = numCols & 3;
137          if (blkCnt > 0U)
138          {
139              mve_pred16_t p0 = vctp32q(blkCnt);
140              f32x4_t vecA;
141  
142              vecIn = vldrwq_z_f32(pInVec, p0);
143              vecA = vld1q(pSrcA0Vec);
144              acc0 = vfmaq(acc0, vecIn, vecA);
145              vecA = vld1q(pSrcA1Vec);
146              acc1 = vfmaq(acc1, vecIn, vecA);
147              vecA = vld1q(pSrcA2Vec);
148              acc2 = vfmaq(acc2, vecIn, vecA);
149              vecA = vld1q(pSrcA3Vec);
150              acc3 = vfmaq(acc3, vecIn, vecA);
151          }
152          /*
153           * Sum the partial parts
154           */
155          *px++ = vecAddAcrossF32Mve(acc0);
156          *px++ = vecAddAcrossF32Mve(acc1);
157          *px++ = vecAddAcrossF32Mve(acc2);
158          *px++ = vecAddAcrossF32Mve(acc3);
159  
160          pSrcA += numCols * 4;
161          /*
162           * Decrement the row loop counter
163           */
164          row -= 4;
165      }
166  
167      /*
168       * compute 2 rows in parrallel
169       */
170      if (row >= 2)
171      {
172          float32_t const    *pSrcA0Vec, *pSrcA1Vec, *pInVec;
173          f32x4_t            vecIn, acc0, acc1;
174          float32_t const     *pSrcVecPtr = pSrcVec;
175  
176          /*
177           * Initialize the pointers to 2 consecutive MatrixA rows
178           */
179          pInA0 = pSrcA;
180          pInA1 = pInA0 + numCols;
181          /*
182           * Initialize the vector pointer
183           */
184          pInVec = pSrcVecPtr;
185          /*
186           * reset accumulators
187           */
188          acc0 = vdupq_n_f32(0.0f);
189          acc1 = vdupq_n_f32(0.0f);
190          pSrcA0Vec = pInA0;
191          pSrcA1Vec = pInA1;
192  
193          blkCnt = numCols >> 2;
194          while (blkCnt > 0U)
195          {
196              f32x4_t vecA;
197  
198              vecIn = vld1q(pInVec);      
199              pInVec += 4;
200              vecA = vld1q(pSrcA0Vec);    
201              pSrcA0Vec += 4;
202              acc0 = vfmaq(acc0, vecIn, vecA);
203              vecA = vld1q(pSrcA1Vec);    
204              pSrcA1Vec += 4;
205              acc1 = vfmaq(acc1, vecIn, vecA);
206  
207              blkCnt--;
208          }
209          /*
210           * tail
211           * (will be merged thru tail predication)
212           */
213          blkCnt = numCols & 3;
214          if (blkCnt > 0U)
215          {
216              mve_pred16_t p0 = vctp32q(blkCnt);
217              f32x4_t vecA;
218  
219              vecIn = vldrwq_z_f32(pInVec, p0);
220              vecA = vld1q(pSrcA0Vec);
221              acc0 = vfmaq(acc0, vecIn, vecA);
222              vecA = vld1q(pSrcA1Vec);
223              acc1 = vfmaq(acc1, vecIn, vecA);
224          }
225          /*
226           * Sum the partial parts
227           */
228          *px++ = vecAddAcrossF32Mve(acc0);
229          *px++ = vecAddAcrossF32Mve(acc1);
230  
231          pSrcA += numCols * 2;
232          row -= 2;
233      }
234  
235      if (row >= 1)
236      {
237          f32x4_t             vecIn, acc0;
238          float32_t const     *pSrcA0Vec, *pInVec;
239          float32_t const      *pSrcVecPtr = pSrcVec;
240          /*
241           * Initialize the pointers to last MatrixA row
242           */
243          pInA0 = pSrcA;
244          /*
245           * Initialize the vector pointer
246           */
247          pInVec = pSrcVecPtr;
248          /*
249           * reset accumulators
250           */
251          acc0 = vdupq_n_f32(0.0f);
252  
253          pSrcA0Vec = pInA0;
254  
255          blkCnt = numCols >> 2;
256          while (blkCnt > 0U)
257          {
258              f32x4_t vecA;
259  
260              vecIn = vld1q(pInVec);      
261              pInVec += 4;
262              vecA = vld1q(pSrcA0Vec);    
263              pSrcA0Vec += 4;
264              acc0 = vfmaq(acc0, vecIn, vecA);
265  
266              blkCnt--;
267          }
268          /*
269           * tail
270           * (will be merged thru tail predication)
271           */
272          blkCnt = numCols & 3;
273          if (blkCnt > 0U)
274          {
275              mve_pred16_t p0 = vctp32q(blkCnt);
276              f32x4_t vecA;
277  
278              vecIn = vldrwq_z_f32(pInVec, p0);
279              vecA = vld1q(pSrcA0Vec);
280              acc0 = vfmaq(acc0, vecIn, vecA);
281          }
282          /*
283           * Sum the partial parts
284           */
285          *px++ = vecAddAcrossF32Mve(acc0);
286      }
287  }
288  #else
289  
290  void arm_mat_vec_mult_f32(const arm_matrix_instance_f32 *pSrcMat, const float32_t *pVec, float32_t *pDst)
291  {
292      uint32_t numRows = pSrcMat->numRows;
293      uint32_t numCols = pSrcMat->numCols;
294      const float32_t *pSrcA = pSrcMat->pData;
295      const float32_t *pInA1;      /* input data matrix pointer A of Q31 type */
296      const float32_t *pInA2;      /* input data matrix pointer A of Q31 type */
297      const float32_t *pInA3;      /* input data matrix pointer A of Q31 type */
298      const float32_t *pInA4;      /* input data matrix pointer A of Q31 type */
299      const float32_t *pInVec;     /* input data matrix pointer B of Q31 type */
300      float32_t *px;               /* Temporary output data matrix pointer */
301      uint16_t i, row, colCnt; /* loop counters */
302      float32_t matData, matData2, vecData, vecData2;
303  
304  
305      /* Process 4 rows at a time */
306      row = numRows >> 2;
307      i = 0u;
308      px = pDst;
309  
310      /* The following loop performs the dot-product of each row in pSrcA with the vector */
311      /* row loop */
312      while (row > 0) {
313          /* Initialize accumulators */
314          float32_t sum1 = 0.0f;
315          float32_t sum2 = 0.0f;
316          float32_t sum3 = 0.0f;
317          float32_t sum4 = 0.0f;
318  
319          /* For every row wise process, the pInVec pointer is set
320           ** to the starting address of the vector */
321          pInVec = pVec;
322  
323          /* Loop unrolling: process 2 columns per iteration */
324          colCnt = numCols;
325  
326          /* Initialize pointers to the starting address of the column being processed */
327          pInA1 = pSrcA + i;
328          pInA2 = pInA1 + numCols;
329          pInA3 = pInA2 + numCols;
330          pInA4 = pInA3 + numCols;
331  
332  
333          // Main loop: matrix-vector multiplication
334          while (colCnt > 0u) {
335              // Read 2 values from vector
336              vecData = *(pInVec)++;
337              // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
338              matData = *(pInA1)++;
339              sum1 += matData * vecData;
340              matData = *(pInA2)++;
341              sum2 += matData * vecData;
342              matData = *(pInA3)++;
343              sum3 += matData * vecData;
344              matData = *(pInA4)++;
345              sum4 += matData * vecData;
346  
347              // Decrement the loop counter
348              colCnt--;
349          }
350  
351          /* Saturate and store the result in the destination buffer */
352          *px++ = sum1;
353          *px++ = sum2;
354          *px++ = sum3;
355          *px++ = sum4;
356  
357          i = i + numCols * 4;
358  
359          /* Decrement the row loop counter */
360          row--;
361      }
362  
363      /* process any remaining rows */
364      row = numRows & 3u;
365      while (row > 0) {
366  
367          float32_t sum = 0.0f;
368          pInVec = pVec;
369          pInA1 = pSrcA + i;
370  
371          colCnt = numCols >> 1;
372          while (colCnt > 0) {
373              vecData = *(pInVec)++;
374              vecData2 = *(pInVec)++;
375              matData = *(pInA1)++;
376              matData2 = *(pInA1)++;
377              sum += matData * vecData;
378              sum += matData2 * vecData2;
379              colCnt--;
380          }
381          // process remainder of row
382          colCnt = numCols & 1u;
383  
384  
385          while (colCnt > 0) {
386              sum += *pInA1++ * *pInVec++;
387              colCnt--;
388          }
389  
390          *px++ = sum;
391          i = i + numCols;
392          row--;
393      }
394  }
395  #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
396  
/**
 * @} end of MatrixVectMult group
 */