arm_mat_vec_mult_q31.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_mat_vec_mult_q31.c 4 * Description: Q31 matrix and vector multiplication 5 * 6 * $Date: 23 April 2021 7 * 8 * $Revision: V1.9.0 9 * 10 * Target Processor: Cortex-M and Cortex-A cores 11 * -------------------------------------------------------------------- */ 12 /* 13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 14 * 15 * SPDX-License-Identifier: Apache-2.0 16 * 17 * Licensed under the Apache License, Version 2.0 (the License); you may 18 * not use this file except in compliance with the License. 19 * You may obtain a copy of the License at 20 * 21 * www.apache.org/licenses/LICENSE-2.0 22 * 23 * Unless required by applicable law or agreed to in writing, software 24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 * See the License for the specific language governing permissions and 27 * limitations under the License. 28 */ 29 30 #include "dsp/matrix_functions.h" 31 32 /** 33 * @ingroup groupMatrix 34 */ 35 36 37 38 /** 39 * @addtogroup MatrixVectMult 40 * @{ 41 */ 42 43 /** 44 * @brief Q31 matrix and vector multiplication. 45 * @param[in] *pSrcMat points to the input matrix structure 46 * @param[in] *pVec points to the input vector 47 * @param[out] *pDst points to the output vector 48 */ 49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE) 50 void arm_mat_vec_mult_q31( 51 const arm_matrix_instance_q31 * pSrcMat, 52 const q31_t *pSrcVec, 53 q31_t *pDstVec) 54 { 55 const q31_t *pMatSrc = pSrcMat->pData; 56 const q31_t *pMat0, *pMat1; 57 uint32_t numRows = pSrcMat->numRows; 58 uint32_t numCols = pSrcMat->numCols; 59 q31_t *px; 60 int32_t row; 61 uint16_t blkCnt; /* loop counters */ 62 63 row = numRows; 64 px = pDstVec; 65 66 /* 67 * compute 3x64-bit accumulators per loop 68 */ 69 while (row >= 3) 70 { 71 q31_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pVec; 72 const q31_t *pMat2; 73 q31_t const *pSrcVecPtr = pSrcVec; 74 q63_t acc0, acc1, acc2; 75 q31x4_t vecMatA0, vecMatA1, vecMatA2, vecIn; 76 77 78 pVec = pSrcVec; 79 /* 80 * Initialize the pointer pIn1 to point to the starting address of the column being processed 81 */ 82 pMat0 = pMatSrc; 83 pMat1 = pMat0 + numCols; 84 pMat2 = pMat1 + numCols; 85 86 acc0 = 0LL; 87 acc1 = 0LL; 88 acc2 = 0LL; 89 90 pMat0Vec = pMat0; 91 pMat1Vec = pMat1; 92 pMat2Vec = pMat2; 93 pVec = pSrcVecPtr; 94 95 blkCnt = numCols >> 2; 96 while (blkCnt > 0U) 97 { 98 vecMatA0 = vld1q(pMat0Vec); 99 pMat0Vec += 4; 100 vecMatA1 = vld1q(pMat1Vec); 101 pMat1Vec += 4; 102 vecMatA2 = vld1q(pMat2Vec); 103 pMat2Vec += 4; 104 vecIn = vld1q(pVec); 105 pVec += 4; 106 107 acc0 = vmlaldavaq(acc0, vecIn, vecMatA0); 108 acc1 = vmlaldavaq(acc1, vecIn, vecMatA1); 109 acc2 = vmlaldavaq(acc2, vecIn, vecMatA2); 110 111 blkCnt--; 112 } 113 /* 114 * tail 115 * (will be merged thru tail predication) 116 */ 117 blkCnt = numCols & 3; 118 if (blkCnt > 0U) 119 { 120 mve_pred16_t p0 = vctp32q(blkCnt); 121 122 vecMatA0 = vld1q(pMat0Vec); 123 vecMatA1 = vld1q(pMat1Vec); 124 vecMatA2 = vld1q(pMat2Vec); 125 vecIn = vldrwq_z_s32(pVec, p0); 126 127 acc0 = vmlaldavaq(acc0, vecIn, vecMatA0); 128 acc1 = vmlaldavaq(acc1, vecIn, vecMatA1); 129 acc2 = vmlaldavaq(acc2, vecIn, vecMatA2); 130 } 131 132 *px++ = asrl(acc0, 31); 133 *px++ = asrl(acc1, 31); 134 *px++ = asrl(acc2, 31); 135 136 pMatSrc += numCols * 3; 137 /* 138 * Decrement the row loop counter 139 */ 140 row -= 3; 141 } 142 143 /* 144 * process any remaining rows pair 145 */ 146 if (row >= 2) 147 { 148 q31_t const *pMat0Vec, *pMat1Vec, *pVec; 149 q31_t const *pSrcVecPtr = pSrcVec; 150 q63_t acc0, acc1; 151 q31x4_t vecMatA0, vecMatA1, vecIn; 152 153 /* 154 * For every row wise process, the pInVec pointer is set 155 * to the starting address of the vector 156 */ 157 pVec = pSrcVec; 158 159 /* 160 * Initialize the pointer pIn1 to point to the starting address of the column being processed 161 */ 162 pMat0 = pMatSrc; 163 pMat1 = pMat0 + numCols; 164 165 acc0 = 0LL; 166 acc1 = 0LL; 167 168 pMat0Vec = pMat0; 169 pMat1Vec = pMat1; 170 pVec = pSrcVecPtr; 171 172 blkCnt = numCols >> 2; 173 while (blkCnt > 0U) 174 { 175 vecMatA0 = vld1q(pMat0Vec); 176 pMat0Vec += 4; 177 vecMatA1 = vld1q(pMat1Vec); 178 pMat1Vec += 4; 179 vecIn = vld1q(pVec); 180 pVec += 4; 181 182 acc0 = vmlaldavaq(acc0, vecIn, vecMatA0); 183 acc1 = vmlaldavaq(acc1, vecIn, vecMatA1); 184 185 blkCnt--; 186 } 187 188 /* 189 * tail 190 * (will be merged thru tail predication) 191 */ 192 blkCnt = numCols & 3; 193 if (blkCnt > 0U) 194 { 195 mve_pred16_t p0 = vctp32q(blkCnt); 196 197 vecMatA0 = vld1q(pMat0Vec); 198 vecMatA1 = vld1q(pMat1Vec); 199 vecIn = vldrwq_z_s32(pVec, p0); 200 201 acc0 = vmlaldavaq(acc0, vecIn, vecMatA0); 202 acc1 = vmlaldavaq(acc1, vecIn, vecMatA1); 203 } 204 205 *px++ = asrl(acc0, 31); 206 *px++ = asrl(acc1, 31); 207 208 pMatSrc += numCols * 2; 209 /* 210 * Decrement the row loop counter 211 */ 212 row -= 2; 213 } 214 215 if (row >= 1) 216 { 217 q31_t const *pMat0Vec, *pVec; 218 q31_t const *pSrcVecPtr = pSrcVec; 219 q63_t acc0; 220 q31x4_t vecMatA0, vecIn; 221 222 /* 223 * For every row wise process, the pInVec pointer is set 224 * to the starting address of the vector 225 */ 226 pVec = pSrcVec; 227 228 /* 229 * Initialize the pointer pIn1 to point to the starting address of the column being processed 230 */ 231 pMat0 = pMatSrc; 232 233 acc0 = 0LL; 234 235 pMat0Vec = pMat0; 236 pVec = pSrcVecPtr; 237 238 blkCnt = numCols >> 2; 239 while (blkCnt > 0U) 240 { 241 vecMatA0 = vld1q(pMat0Vec); 242 pMat0Vec += 4; 243 vecIn = vld1q(pVec); 244 pVec += 4; 245 acc0 = vmlaldavaq(acc0, vecIn, vecMatA0); 246 blkCnt--; 247 } 248 /* 249 * tail 250 * (will be merged thru tail predication) 251 */ 252 blkCnt = numCols & 3; 253 if (blkCnt > 0U) 254 { 255 mve_pred16_t p0 = vctp32q(blkCnt); 256 257 vecMatA0 = vld1q(pMat0Vec); 258 vecIn = vldrwq_z_s32(pVec, p0); 259 acc0 = vmlaldavaq(acc0, vecIn, vecMatA0); 260 } 261 262 *px++ = asrl(acc0, 31); 263 } 264 } 265 #else 266 void arm_mat_vec_mult_q31(const arm_matrix_instance_q31 *pSrcMat, const q31_t *pVec, q31_t *pDst) 267 { 268 uint32_t numRows = pSrcMat->numRows; 269 uint32_t numCols = pSrcMat->numCols; 270 const q31_t *pSrcA = pSrcMat->pData; 271 const q31_t *pInA1; /* input data matrix pointer A of Q31 type */ 272 const q31_t *pInA2; /* input data matrix pointer A of Q31 type */ 273 const q31_t *pInA3; /* input data matrix pointer A of Q31 type */ 274 const q31_t *pInA4; /* input data matrix pointer A of Q31 type */ 275 const q31_t *pInVec; /* input data matrix pointer B of Q31 type */ 276 q31_t *px; /* Temporary output data matrix pointer */ 277 uint16_t i, row, colCnt; /* loop counters */ 278 q31_t matData, matData2, vecData, vecData2; 279 280 281 /* Process 4 rows at a time */ 282 row = numRows >> 2; 283 i = 0u; 284 px = pDst; 285 286 /* The following loop performs the dot-product of each row in pSrcA with the vector */ 287 /* row loop */ 288 while (row > 0) { 289 /* Initialize accumulators */ 290 q63_t sum1 = 0; 291 q63_t sum2 = 0; 292 q63_t sum3 = 0; 293 q63_t sum4 = 0; 294 295 /* For every row wise process, the pInVec pointer is set 296 ** to the starting address of the vector */ 297 pInVec = pVec; 298 299 /* Loop unrolling: process 2 columns per iteration */ 300 colCnt = numCols; 301 302 /* Initialize pointers to the starting address of the column being processed */ 303 pInA1 = pSrcA + i; 304 pInA2 = pInA1 + numCols; 305 pInA3 = pInA2 + numCols; 306 pInA4 = pInA3 + numCols; 307 308 309 // Main loop: matrix-vector multiplication 310 while (colCnt > 0u) { 311 // Read 2 values from vector 312 vecData = *(pInVec)++; 313 314 // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate 315 matData = *(pInA1)++; 316 sum1 += (q63_t)matData * vecData; 317 matData = *(pInA2)++; 318 sum2 += (q63_t)matData * vecData; 319 matData = *(pInA3)++; 320 sum3 += (q63_t)matData * vecData; 321 matData = *(pInA4)++; 322 sum4 += (q63_t)matData * vecData; 323 324 // Decrement the loop counter 325 colCnt--; 326 } 327 328 /* Saturate and store the result in the destination buffer */ 329 *px++ = (q31_t)(sum1 >> 31); 330 *px++ = (q31_t)(sum2 >> 31); 331 *px++ = (q31_t)(sum3 >> 31); 332 *px++ = (q31_t)(sum4 >> 31); 333 334 i = i + numCols * 4; 335 336 /* Decrement the row loop counter */ 337 row--; 338 } 339 340 /* process any remaining rows */ 341 row = numRows & 3u; 342 while (row > 0) { 343 344 q63_t sum = 0; 345 pInVec = pVec; 346 pInA1 = pSrcA + i; 347 348 colCnt = numCols >> 1; 349 350 while (colCnt > 0) { 351 vecData = *(pInVec)++; 352 vecData2 = *(pInVec)++; 353 matData = *(pInA1)++; 354 matData2 = *(pInA1)++; 355 sum += (q63_t)matData * vecData; 356 sum += (q63_t)matData2 * vecData2; 357 colCnt--; 358 } 359 360 // process remainder of row 361 colCnt = numCols & 1u; 362 while (colCnt > 0) { 363 sum += (q63_t)*pInA1++ * *pInVec++; 364 colCnt--; 365 } 366 367 *px++ = (q31_t)(sum >> 31); 368 i = i + numCols; 369 row--; 370 } 371 } 372 #endif /* defined(ARM_MATH_MVEI) */ 373 374 /** 375 * @} end of MatrixMult group 376 */