/ Drivers / CMSIS / DSP / Source / MatrixFunctions / arm_mat_vec_mult_q31.c
arm_mat_vec_mult_q31.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_mat_vec_mult_q31.c
  4   * Description:  Q31 matrix and vector multiplication
  5   *
  6   * $Date:        23 April 2021
  7   *
  8   * $Revision:    V1.9.0
  9   *
 10   * Target Processor: Cortex-M and Cortex-A cores
 11   * -------------------------------------------------------------------- */
 12  /*
 13   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 14   *
 15   * SPDX-License-Identifier: Apache-2.0
 16   *
 17   * Licensed under the Apache License, Version 2.0 (the License); you may
 18   * not use this file except in compliance with the License.
 19   * You may obtain a copy of the License at
 20   *
 21   * www.apache.org/licenses/LICENSE-2.0
 22   *
 23   * Unless required by applicable law or agreed to in writing, software
 24   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 25   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 26   * See the License for the specific language governing permissions and
 27   * limitations under the License.
 28   */
 29  
 30  #include "dsp/matrix_functions.h"
 31  
 32  /**
 33   * @ingroup groupMatrix
 34   */
 35  
 36  
 37  
 38  /**
 39   * @addtogroup MatrixVectMult
 40   * @{
 41   */
 42  
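/**
 * @brief Q31 matrix and vector multiplication.
 * @param[in]  pSrcMat  points to the input matrix structure
 * @param[in]  pVec     points to the input vector (numCols elements)
 * @param[out] pDst     points to the output vector (numRows elements)
 *
 * @par Scaling and Overflow Behavior
 *   Each output element is the dot product of one matrix row with the input
 *   vector, accumulated in a 64-bit accumulator. The accumulator is shifted
 *   right by 31 and truncated to q31_t; the result is not saturated.
 */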
 49  #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
 50  void arm_mat_vec_mult_q31(
 51      const arm_matrix_instance_q31 * pSrcMat,
 52      const q31_t     *pSrcVec,
 53      q31_t           *pDstVec)
 54  {
 55      const q31_t *pMatSrc = pSrcMat->pData;
 56      const q31_t *pMat0, *pMat1;
 57      uint32_t     numRows = pSrcMat->numRows;
 58      uint32_t     numCols = pSrcMat->numCols;
 59      q31_t       *px;
 60      int32_t      row;
 61      uint16_t     blkCnt;           /* loop counters */
 62  
 63      row = numRows;
 64      px = pDstVec;
 65  
 66      /*
 67       * compute 3x64-bit accumulators per loop
 68       */
 69      while (row >= 3)
 70      {
 71          q31_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pVec;
 72          const q31_t  *pMat2;
 73          q31_t const  *pSrcVecPtr = pSrcVec;
 74          q63_t         acc0, acc1, acc2;
 75          q31x4_t     vecMatA0, vecMatA1, vecMatA2, vecIn;
 76  
 77  
 78          pVec = pSrcVec;
 79          /*
 80           * Initialize the pointer pIn1 to point to the starting address of the column being processed
 81           */
 82          pMat0 = pMatSrc;
 83          pMat1 = pMat0 + numCols;
 84          pMat2 = pMat1 + numCols;
 85  
 86          acc0 = 0LL;
 87          acc1 = 0LL;
 88          acc2 = 0LL;
 89  
 90          pMat0Vec = pMat0;
 91          pMat1Vec = pMat1;
 92          pMat2Vec = pMat2;
 93          pVec = pSrcVecPtr;
 94  
 95          blkCnt = numCols >> 2;
 96          while (blkCnt > 0U)
 97          {
 98              vecMatA0 = vld1q(pMat0Vec); 
 99              pMat0Vec += 4;
100              vecMatA1 = vld1q(pMat1Vec); 
101              pMat1Vec += 4;
102              vecMatA2 = vld1q(pMat2Vec); 
103              pMat2Vec += 4;
104              vecIn = vld1q(pVec);        
105              pVec += 4;
106  
107              acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
108              acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
109              acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
110  
111              blkCnt--;
112          }
113          /*
114           * tail
115           * (will be merged thru tail predication)
116           */
117          blkCnt = numCols & 3;
118          if (blkCnt > 0U)
119          {
120              mve_pred16_t p0 = vctp32q(blkCnt);
121  
122              vecMatA0 = vld1q(pMat0Vec);
123              vecMatA1 = vld1q(pMat1Vec);
124              vecMatA2 = vld1q(pMat2Vec);
125              vecIn = vldrwq_z_s32(pVec, p0);
126  
127              acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
128              acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
129              acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
130          }
131  
132          *px++ = asrl(acc0, 31);
133          *px++ = asrl(acc1, 31);
134          *px++ = asrl(acc2, 31);
135  
136          pMatSrc += numCols * 3;
137          /*
138           * Decrement the row loop counter
139           */
140          row -= 3;
141      }
142  
143      /*
144       * process any remaining rows pair
145       */
146      if (row >= 2)
147      {
148          q31_t const *pMat0Vec, *pMat1Vec, *pVec;
149          q31_t const  *pSrcVecPtr = pSrcVec;
150          q63_t         acc0, acc1;
151          q31x4_t     vecMatA0, vecMatA1, vecIn;
152  
153          /*
154           * For every row wise process, the pInVec pointer is set
155           * to the starting address of the vector
156           */
157          pVec = pSrcVec;
158  
159          /*
160           * Initialize the pointer pIn1 to point to the starting address of the column being processed
161           */
162          pMat0 = pMatSrc;
163          pMat1 = pMat0 + numCols;
164  
165          acc0 = 0LL;
166          acc1 = 0LL;
167  
168          pMat0Vec = pMat0;
169          pMat1Vec = pMat1;
170          pVec = pSrcVecPtr;
171  
172          blkCnt = numCols >> 2;
173          while (blkCnt > 0U)
174          {
175              vecMatA0 = vld1q(pMat0Vec); 
176              pMat0Vec += 4;
177              vecMatA1 = vld1q(pMat1Vec); 
178              pMat1Vec += 4;
179              vecIn = vld1q(pVec);        
180              pVec += 4;
181  
182              acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
183              acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
184  
185              blkCnt--;
186          }
187  
188          /*
189           * tail
190           * (will be merged thru tail predication)
191           */
192          blkCnt = numCols & 3;
193          if (blkCnt > 0U)
194          {
195              mve_pred16_t p0 = vctp32q(blkCnt);
196  
197              vecMatA0 = vld1q(pMat0Vec);
198              vecMatA1 = vld1q(pMat1Vec);
199              vecIn = vldrwq_z_s32(pVec, p0);
200  
201              acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
202              acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
203          }
204  
205          *px++ = asrl(acc0, 31);
206          *px++ = asrl(acc1, 31);
207  
208          pMatSrc += numCols * 2;
209          /*
210           * Decrement the row loop counter
211           */
212          row -= 2;
213      }
214  
215      if (row >= 1)
216      {
217          q31_t const *pMat0Vec, *pVec;
218          q31_t const  *pSrcVecPtr = pSrcVec;
219          q63_t         acc0;
220          q31x4_t     vecMatA0, vecIn;
221  
222          /*
223           * For every row wise process, the pInVec pointer is set
224           * to the starting address of the vector
225           */
226          pVec = pSrcVec;
227  
228          /*
229           * Initialize the pointer pIn1 to point to the starting address of the column being processed
230           */
231          pMat0 = pMatSrc;
232  
233          acc0 = 0LL;
234  
235          pMat0Vec = pMat0;
236          pVec = pSrcVecPtr;
237  
238          blkCnt = numCols >> 2;
239          while (blkCnt > 0U)
240          {
241              vecMatA0 = vld1q(pMat0Vec); 
242              pMat0Vec += 4;
243              vecIn = vld1q(pVec);        
244              pVec += 4;
245              acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
246              blkCnt--;
247          }
248          /*
249           * tail
250           * (will be merged thru tail predication)
251           */
252          blkCnt = numCols & 3;
253          if (blkCnt > 0U)
254          {
255              mve_pred16_t p0 = vctp32q(blkCnt);
256  
257              vecMatA0 = vld1q(pMat0Vec);
258              vecIn = vldrwq_z_s32(pVec, p0);
259              acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
260          }
261  
262          *px++ = asrl(acc0, 31);
263      }
264  }
265  #else
266  void arm_mat_vec_mult_q31(const arm_matrix_instance_q31 *pSrcMat, const q31_t *pVec, q31_t *pDst)
267  {
268      uint32_t numRows = pSrcMat->numRows;
269      uint32_t numCols = pSrcMat->numCols;
270      const q31_t *pSrcA = pSrcMat->pData;
271      const q31_t *pInA1;      /* input data matrix pointer A of Q31 type */
272      const q31_t *pInA2;      /* input data matrix pointer A of Q31 type */
273      const q31_t *pInA3;      /* input data matrix pointer A of Q31 type */
274      const q31_t *pInA4;      /* input data matrix pointer A of Q31 type */
275      const q31_t *pInVec;     /* input data matrix pointer B of Q31 type */
276      q31_t *px;               /* Temporary output data matrix pointer */
277      uint16_t i, row, colCnt; /* loop counters */
278      q31_t matData, matData2, vecData, vecData2;
279  
280  
281      /* Process 4 rows at a time */
282      row = numRows >> 2;
283      i = 0u;
284      px = pDst;
285  
286      /* The following loop performs the dot-product of each row in pSrcA with the vector */
287      /* row loop */
288      while (row > 0) {
289          /* Initialize accumulators */
290          q63_t sum1 = 0;
291          q63_t sum2 = 0;
292          q63_t sum3 = 0;
293          q63_t sum4 = 0;
294  
295          /* For every row wise process, the pInVec pointer is set
296           ** to the starting address of the vector */
297          pInVec = pVec;
298  
299          /* Loop unrolling: process 2 columns per iteration */
300          colCnt = numCols;
301  
302          /* Initialize pointers to the starting address of the column being processed */
303          pInA1 = pSrcA + i;
304          pInA2 = pInA1 + numCols;
305          pInA3 = pInA2 + numCols;
306          pInA4 = pInA3 + numCols;
307  
308  
309          // Main loop: matrix-vector multiplication
310          while (colCnt > 0u) {
311              // Read 2 values from vector
312              vecData = *(pInVec)++;
313  
314              // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
315              matData = *(pInA1)++;
316              sum1 += (q63_t)matData * vecData;
317              matData = *(pInA2)++;
318              sum2 += (q63_t)matData * vecData;
319              matData = *(pInA3)++;
320              sum3 += (q63_t)matData * vecData;
321              matData = *(pInA4)++;
322              sum4 += (q63_t)matData * vecData;
323  
324              // Decrement the loop counter
325              colCnt--;
326          }
327  
328          /* Saturate and store the result in the destination buffer */
329          *px++ = (q31_t)(sum1 >> 31);
330          *px++ = (q31_t)(sum2 >> 31);
331          *px++ = (q31_t)(sum3 >> 31);
332          *px++ = (q31_t)(sum4 >> 31);
333  
334          i = i + numCols * 4;
335  
336          /* Decrement the row loop counter */
337          row--;
338      }
339  
340      /* process any remaining rows */
341      row = numRows & 3u;
342      while (row > 0) {
343  
344          q63_t sum = 0;
345          pInVec = pVec;
346          pInA1 = pSrcA + i;
347  
348          colCnt = numCols >> 1;
349  
350          while (colCnt > 0) {
351              vecData = *(pInVec)++;
352              vecData2 = *(pInVec)++;
353              matData = *(pInA1)++;
354              matData2 = *(pInA1)++;
355              sum += (q63_t)matData * vecData;
356              sum += (q63_t)matData2 * vecData2;
357              colCnt--;
358          }
359  
360          // process remainder of row
361          colCnt = numCols & 1u;
362          while (colCnt > 0) {
363              sum += (q63_t)*pInA1++ * *pInVec++;
364              colCnt--;
365          }
366  
367          *px++ = (q31_t)(sum >> 31);
368          i = i + numCols;
369          row--;
370      }
371  }
372  #endif /* defined(ARM_MATH_MVEI) */
373  
/**
 * @} end of MatrixVectMult group
 */