/ Drivers / CMSIS / DSP / Source / MatrixFunctions / arm_mat_vec_mult_q7.c
arm_mat_vec_mult_q7.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_mat_vec_mult_q7.c
  4   * Description:  Q7 matrix and vector multiplication
  5   *
  6   * $Date:        23 April 2021
  7   *
  8   * $Revision:    V1.9.0
  9   *
 10   * Target Processor: Cortex-M and Cortex-A cores
 11   * -------------------------------------------------------------------- */
 12  /*
 13   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 14   *
 15   * SPDX-License-Identifier: Apache-2.0
 16   *
 17   * Licensed under the Apache License, Version 2.0 (the License); you may
 18   * not use this file except in compliance with the License.
 19   * You may obtain a copy of the License at
 20   *
 21   * www.apache.org/licenses/LICENSE-2.0
 22   *
 23   * Unless required by applicable law or agreed to in writing, software
 24   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 25   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 26   * See the License for the specific language governing permissions and
 27   * limitations under the License.
 28   */
 29  
 30  #include "dsp/matrix_functions.h"
 31  
 32  /**
 33   * @ingroup groupMatrix
 34   */
 35  
 36  
 37  
 38  /**
 39   * @addtogroup MatrixVectMult
 40   * @{
 41   */
 42  
 43  /**
 44   * @brief Q7 matrix and vector multiplication.
 45   * @param[in]       *pSrcMat points to the input matrix structure
 46   * @param[in]       *pVec points to the input vector
 47   * @param[out]      *pDst points to the output vector
 48   */
 49  #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
 50  
 51  #include "arm_helium_utils.h"
 52  
 53  void arm_mat_vec_mult_q7(
 54      const arm_matrix_instance_q7 * pSrcMat,
 55      const q7_t     *pSrcVec,
 56      q7_t           *pDstVec)
 57  {
 58      const q7_t *pMatSrc = pSrcMat->pData;
 59      const q7_t *pMat0, *pMat1;
 60      uint32_t     numRows = pSrcMat->numRows;
 61      uint32_t     numCols = pSrcMat->numCols;
 62      q7_t       *px;
 63      int32_t      row;
 64      uint16_t     blkCnt;           /* loop counters */
 65  
 66      row = numRows;
 67      px = pDstVec;
 68  
 69      /*
 70       * compute 4x64-bit accumulators per loop
 71       */
 72      while (row >= 4)
 73      {
 74          q7_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pMat3Vec, *pVec;
 75          const q7_t  *pMat2, *pMat3;
 76          q7_t const  *pSrcVecPtr = pSrcVec;
 77          q31_t        acc0, acc1, acc2, acc3;
 78          q7x16_t      vecMatA0, vecMatA1, vecMatA2, vecMatA3, vecIn;
 79  
 80          pVec = pSrcVec;
 81          /*
 82           * Initialize the pointer pIn1 to point to the starting address of the column being processed
 83           */
 84          pMat0 = pMatSrc;
 85          pMat1 = pMat0 + numCols;
 86          pMat2 = pMat1 + numCols;
 87          pMat3 = pMat2 + numCols;
 88  
 89          acc0 = 0L;
 90          acc1 = 0L;
 91          acc2 = 0L;
 92          acc3 = 0L;
 93  
 94          pMat0Vec = pMat0;
 95          pMat1Vec = pMat1;
 96          pMat2Vec = pMat2;
 97          pMat3Vec = pMat3;
 98          pVec = pSrcVecPtr;
 99  
100          blkCnt = numCols >> 4;
101          while (blkCnt > 0U)
102          {
103  
104              vecMatA0 = vld1q(pMat0Vec); 
105              pMat0Vec += 16;
106              vecMatA1 = vld1q(pMat1Vec); 
107              pMat1Vec += 16;
108              vecMatA2 = vld1q(pMat2Vec); 
109              pMat2Vec += 16;
110              vecMatA3 = vld1q(pMat3Vec); 
111              pMat3Vec += 16;
112              vecIn = vld1q(pVec);        
113              pVec += 16;
114  
115              acc0 = vmladavaq(acc0, vecIn, vecMatA0);
116              acc1 = vmladavaq(acc1, vecIn, vecMatA1);
117              acc2 = vmladavaq(acc2, vecIn, vecMatA2);
118              acc3 = vmladavaq(acc3, vecIn, vecMatA3);
119  
120              blkCnt--;
121          }
122          /*
123           * tail
124           * (will be merged thru tail predication)
125           */
126          blkCnt = numCols & 0xF;
127          if (blkCnt > 0U)
128          {
129              mve_pred16_t p0 = vctp8q(blkCnt);
130  
131              vecMatA0 = vld1q(pMat0Vec);
132              vecMatA1 = vld1q(pMat1Vec);
133              vecMatA2 = vld1q(pMat2Vec);
134              vecMatA3 = vld1q(pMat3Vec);
135              vecIn = vldrbq_z_s8(pVec, p0);
136  
137              acc0 = vmladavaq(acc0, vecIn, vecMatA0);
138              acc1 = vmladavaq(acc1, vecIn, vecMatA1);
139              acc2 = vmladavaq(acc2, vecIn, vecMatA2);
140              acc3 = vmladavaq(acc3, vecIn, vecMatA3);
141          }
142  
143          *px++ = __SSAT(acc0 >> 7, 8);
144          *px++ = __SSAT(acc1 >> 7, 8);
145          *px++ = __SSAT(acc2 >> 7, 8);
146          *px++ = __SSAT(acc3 >> 7, 8);
147  
148          pMatSrc += numCols * 4;
149          /*
150           * Decrement the row loop counter
151           */
152          row -= 4;
153      }
154  
155      /*
156       * process any remaining rows pair
157       */
158      if (row >= 2)
159      {
160          q7_t const  *pMat0Vec, *pMat1Vec, *pVec;
161          q7_t const  *pSrcVecPtr = pSrcVec;
162          q31_t         acc0, acc1;
163          q7x16_t     vecMatA0, vecMatA1, vecIn;
164  
165          /*
166           * For every row wise process, the pInVec pointer is set
167           * to the starting address of the vector
168           */
169          pVec = pSrcVec;
170  
171          /*
172           * Initialize the pointer pIn1 to point to the starting address of the column being processed
173           */
174          pMat0 = pMatSrc;
175          pMat1 = pMat0 + numCols;
176  
177          acc0 = 0;
178          acc1 = 0;
179  
180          pMat0Vec = pMat0;
181          pMat1Vec = pMat1;
182          pVec = pSrcVecPtr;
183  
184          blkCnt = numCols >> 4;
185          while (blkCnt > 0U)
186          {
187              vecMatA0 = vld1q(pMat0Vec); 
188              pMat0Vec += 16;
189              vecMatA1 = vld1q(pMat1Vec); 
190              pMat1Vec += 16;
191              vecIn = vld1q(pVec);        
192              pVec += 16;
193  
194              acc0 = vmladavaq(acc0, vecIn, vecMatA0);
195              acc1 = vmladavaq(acc1, vecIn, vecMatA1);
196  
197              blkCnt--;
198          }
199  
200          /*
201           * tail
202           * (will be merged thru tail predication)
203           */
204          blkCnt = numCols & 0xF;
205          if (blkCnt > 0U)
206          {
207              mve_pred16_t p0 = vctp8q(blkCnt);
208  
209              vecMatA0 = vld1q(pMat0Vec);
210              vecMatA1 = vld1q(pMat1Vec);
211              vecIn = vldrbq_z_s8(pVec, p0);
212  
213              acc0 = vmladavaq(acc0, vecIn, vecMatA0);
214              acc1 = vmladavaq(acc1, vecIn, vecMatA1);
215          }
216  
217          *px++ = __SSAT(acc0 >> 7, 8);
218          *px++ = __SSAT(acc1 >> 7, 8);
219  
220          pMatSrc += numCols * 2;
221          /*
222           * Decrement the row loop counter
223           */
224          row -= 2;
225      }
226  
227      if (row >= 1)
228      {
229          q7_t const  *pMat0Vec, *pVec;
230          q7_t const  *pSrcVecPtr = pSrcVec;
231          q31_t         acc0;
232          q7x16_t     vecMatA0, vecIn;
233  
234          /*
235           * For every row wise process, the pInVec pointer is set
236           * to the starting address of the vector
237           */
238          pVec = pSrcVec;
239  
240          /*
241           * Initialize the pointer pIn1 to point to the starting address of the column being processed
242           */
243          pMat0 = pMatSrc;
244  
245          acc0 = 0LL;
246  
247          pMat0Vec = pMat0;
248          pVec = pSrcVecPtr;
249  
250          blkCnt = numCols >> 4;
251          while (blkCnt > 0U)
252          {
253              vecMatA0 = vld1q(pMat0Vec); 
254              pMat0Vec += 16;
255              vecIn = vld1q(pVec);        
256              pVec += 16;
257  
258              acc0 = vmladavaq(acc0, vecIn, vecMatA0);
259              blkCnt--;
260          }
261          /*
262           * tail
263           * (will be merged thru tail predication)
264           */
265          blkCnt = numCols & 0xF;
266          if (blkCnt > 0U)
267          {
268              mve_pred16_t p0 = vctp8q(blkCnt);
269  
270              vecMatA0 = vld1q(pMat0Vec);
271              vecIn = vldrbq_z_s8(pVec, p0);
272              acc0 = vmladavaq(acc0, vecIn, vecMatA0);
273          }
274          *px++ = __SSAT(acc0 >> 7, 8);
275      }
276  }
277  
278  #else
279  void arm_mat_vec_mult_q7(const arm_matrix_instance_q7 *pSrcMat, const q7_t *pVec, q7_t *pDst)
280  {
281      uint32_t numRows = pSrcMat->numRows;
282      uint32_t numCols = pSrcMat->numCols;
283      const q7_t *pSrcA = pSrcMat->pData;
284      const q7_t *pInA1;       /* input data matrix pointer of Q7 type */
285      const q7_t *pInA2;       /* input data matrix pointer of Q7 type */
286      const q7_t *pInA3;       /* input data matrix pointer of Q7 type */
287      const q7_t *pInA4;       /* input data matrix pointer of Q7 type */
288      const q7_t *pInVec;      /* input data vector pointer of Q7 type */
289      q7_t *px;                /* output data pointer */
290      uint32_t i, row, colCnt; /* loop counters */
291  
292      q31_t matData, matData2, vecData, vecData2;
293  
294  
295      /* Process 4 rows at a time */
296      row = numRows >> 2;
297      i = 0u;
298      px = pDst;
299  
300  
301  
302      /* The following loop performs the dot-product of each row in pSrcA with the vector */
303      while (row > 0) {
304          /* Initialize accumulators */
305          q31_t sum1 = 0;
306          q31_t sum2 = 0;
307          q31_t sum3 = 0;
308          q31_t sum4 = 0;
309  
310          /* For every row wise process, the pInVec pointer is set
311           ** to the starting address of the vector */
312          pInVec = pVec;
313  
314          /* Loop unrolling: process 4 columns per iteration */
315          colCnt = numCols >> 2;
316  
317          /* Initialize row pointers so we can track 4 rows at once */
318          pInA1 = pSrcA + i;
319          pInA2 = pInA1 + numCols;
320          pInA3 = pInA2 + numCols;
321          pInA4 = pInA3 + numCols;
322  
323  
324          // Inner loop: matrix-vector multiplication
325  
326          while (colCnt > 0u) {
327              // Read 4 values from vector
328              vecData = read_q7x4_ia (&pInVec);
329              vecData2 = __SXTB16(__ROR(vecData, 8));
330              vecData = __SXTB16(vecData);
331              // Read 16 values from the matrix - 4 values from each of 4 rows, and do multiply accumulate
332              matData = read_q7x4_ia (&pInA1);
333              matData2 = __SXTB16(__ROR(matData, 8));
334              matData = __SXTB16(matData);
335              sum1 = __SMLAD(matData, vecData, sum1);
336              sum1 = __SMLAD(matData2, vecData2, sum1);
337              matData = read_q7x4_ia (&pInA2);
338              matData2 = __SXTB16(__ROR(matData, 8));
339              matData = __SXTB16(matData);
340              sum2 = __SMLAD(matData, vecData, sum2);
341              sum2 = __SMLAD(matData2, vecData2, sum2);
342              matData = read_q7x4_ia (&pInA3);
343              matData2 = __SXTB16(__ROR(matData, 8));
344              matData = __SXTB16(matData);
345              sum3 = __SMLAD(matData, vecData, sum3);
346              sum3 = __SMLAD(matData2, vecData2, sum3);
347              matData = read_q7x4_ia (&pInA4);
348              matData2 = __SXTB16(__ROR(matData, 8));
349              matData = __SXTB16(matData);
350              sum4 = __SMLAD(matData, vecData, sum4);
351              sum4 = __SMLAD(matData2, vecData2, sum4);
352  
353              // Decrement the loop counter
354              colCnt--;
355          }
356  
357          /* process any remaining columns */
358  
359          colCnt = numCols & 3u;
360  
361          while (colCnt > 0) {
362              vecData = *pInVec++;
363              sum1 += *pInA1++ * vecData;
364              sum2 += *pInA2++ * vecData;
365              sum3 += *pInA3++ * vecData;
366              sum4 += *pInA4++ * vecData;
367              colCnt--;
368          }
369  
370          /* Saturate and store the result in the destination buffer */
371          *px++ = (q7_t)(__SSAT((sum1 >> 7), 8));
372          *px++ = (q7_t)(__SSAT((sum2 >> 7), 8));
373          *px++ = (q7_t)(__SSAT((sum3 >> 7), 8));
374          *px++ = (q7_t)(__SSAT((sum4 >> 7), 8));
375  
376          i = i + numCols * 4;
377  
378          /* Decrement the row loop counter */
379          row--;
380      }
381  
382      /* process any remaining rows */
383      row = numRows & 3u;
384      while (row > 0) {
385  
386          q31_t sum = 0;
387          pInVec = pVec;
388          pInA1 = pSrcA + i;
389  
390          // loop unrolling - process 4 elements at a time
391          colCnt = numCols >> 2;
392  
393          while (colCnt > 0) {
394              vecData = read_q7x4_ia (&pInVec);
395              vecData2 = __SXTB16(__ROR(vecData, 8));
396              vecData = __SXTB16(vecData);
397              matData = read_q7x4_ia (&pInA1);
398              matData2 = __SXTB16(__ROR(matData, 8));
399              matData = __SXTB16(matData);
400              sum = __SMLAD(matData, vecData, sum);
401              sum = __SMLAD(matData2, vecData2, sum);
402              colCnt--;
403          }
404  
405          // process remainder of row
406          colCnt = numCols & 3u;
407          while (colCnt > 0) {
408              sum += *pInA1++ * *pInVec++;
409              colCnt--;
410          }
411          *px++ = (q7_t)(__SSAT((sum >> 7), 8));
412          i = i + numCols;
413          row--;
414      }
415  }
416  #endif /* defined(ARM_MATH_MVEI) */
417  
418  /**
 419   * @} end of MatrixVectMult group
420   */