arm_mat_vec_mult_q7.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_mat_vec_mult_q7.c 4 * Description: Q7 matrix and vector multiplication 5 * 6 * $Date: 23 April 2021 7 * 8 * $Revision: V1.9.0 9 * 10 * Target Processor: Cortex-M and Cortex-A cores 11 * -------------------------------------------------------------------- */ 12 /* 13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 14 * 15 * SPDX-License-Identifier: Apache-2.0 16 * 17 * Licensed under the Apache License, Version 2.0 (the License); you may 18 * not use this file except in compliance with the License. 19 * You may obtain a copy of the License at 20 * 21 * www.apache.org/licenses/LICENSE-2.0 22 * 23 * Unless required by applicable law or agreed to in writing, software 24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 * See the License for the specific language governing permissions and 27 * limitations under the License. 28 */ 29 30 #include "dsp/matrix_functions.h" 31 32 /** 33 * @ingroup groupMatrix 34 */ 35 36 37 38 /** 39 * @addtogroup MatrixVectMult 40 * @{ 41 */ 42 43 /** 44 * @brief Q7 matrix and vector multiplication. 45 * @param[in] *pSrcMat points to the input matrix structure 46 * @param[in] *pVec points to the input vector 47 * @param[out] *pDst points to the output vector 48 */ 49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE) 50 51 #include "arm_helium_utils.h" 52 53 void arm_mat_vec_mult_q7( 54 const arm_matrix_instance_q7 * pSrcMat, 55 const q7_t *pSrcVec, 56 q7_t *pDstVec) 57 { 58 const q7_t *pMatSrc = pSrcMat->pData; 59 const q7_t *pMat0, *pMat1; 60 uint32_t numRows = pSrcMat->numRows; 61 uint32_t numCols = pSrcMat->numCols; 62 q7_t *px; 63 int32_t row; 64 uint16_t blkCnt; /* loop counters */ 65 66 row = numRows; 67 px = pDstVec; 68 69 /* 70 * compute 4x64-bit accumulators per loop 71 */ 72 while (row >= 4) 73 { 74 q7_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pMat3Vec, *pVec; 75 const q7_t *pMat2, *pMat3; 76 q7_t const *pSrcVecPtr = pSrcVec; 77 q31_t acc0, acc1, acc2, acc3; 78 q7x16_t vecMatA0, vecMatA1, vecMatA2, vecMatA3, vecIn; 79 80 pVec = pSrcVec; 81 /* 82 * Initialize the pointer pIn1 to point to the starting address of the column being processed 83 */ 84 pMat0 = pMatSrc; 85 pMat1 = pMat0 + numCols; 86 pMat2 = pMat1 + numCols; 87 pMat3 = pMat2 + numCols; 88 89 acc0 = 0L; 90 acc1 = 0L; 91 acc2 = 0L; 92 acc3 = 0L; 93 94 pMat0Vec = pMat0; 95 pMat1Vec = pMat1; 96 pMat2Vec = pMat2; 97 pMat3Vec = pMat3; 98 pVec = pSrcVecPtr; 99 100 blkCnt = numCols >> 4; 101 while (blkCnt > 0U) 102 { 103 104 vecMatA0 = vld1q(pMat0Vec); 105 pMat0Vec += 16; 106 vecMatA1 = vld1q(pMat1Vec); 107 pMat1Vec += 16; 108 vecMatA2 = vld1q(pMat2Vec); 109 pMat2Vec += 16; 110 vecMatA3 = vld1q(pMat3Vec); 111 pMat3Vec += 16; 112 vecIn = vld1q(pVec); 113 pVec += 16; 114 115 acc0 = vmladavaq(acc0, vecIn, vecMatA0); 116 acc1 = vmladavaq(acc1, vecIn, vecMatA1); 117 acc2 = vmladavaq(acc2, vecIn, vecMatA2); 118 acc3 = vmladavaq(acc3, vecIn, vecMatA3); 119 120 blkCnt--; 121 } 122 /* 123 * tail 124 * (will be merged thru tail predication) 125 */ 126 blkCnt = numCols & 0xF; 127 if (blkCnt > 0U) 128 { 129 mve_pred16_t p0 = vctp8q(blkCnt); 130 131 vecMatA0 = vld1q(pMat0Vec); 132 vecMatA1 = vld1q(pMat1Vec); 133 vecMatA2 = vld1q(pMat2Vec); 134 vecMatA3 = vld1q(pMat3Vec); 135 vecIn = vldrbq_z_s8(pVec, p0); 136 137 acc0 = vmladavaq(acc0, vecIn, vecMatA0); 138 acc1 = vmladavaq(acc1, vecIn, vecMatA1); 139 acc2 = vmladavaq(acc2, vecIn, vecMatA2); 140 acc3 = vmladavaq(acc3, vecIn, vecMatA3); 141 } 142 143 *px++ = __SSAT(acc0 >> 7, 8); 144 *px++ = __SSAT(acc1 >> 7, 8); 145 *px++ = __SSAT(acc2 >> 7, 8); 146 *px++ = __SSAT(acc3 >> 7, 8); 147 148 pMatSrc += numCols * 4; 149 /* 150 * Decrement the row loop counter 151 */ 152 row -= 4; 153 } 154 155 /* 156 * process any remaining rows pair 157 */ 158 if (row >= 2) 159 { 160 q7_t const *pMat0Vec, *pMat1Vec, *pVec; 161 q7_t const *pSrcVecPtr = pSrcVec; 162 q31_t acc0, acc1; 163 q7x16_t vecMatA0, vecMatA1, vecIn; 164 165 /* 166 * For every row wise process, the pInVec pointer is set 167 * to the starting address of the vector 168 */ 169 pVec = pSrcVec; 170 171 /* 172 * Initialize the pointer pIn1 to point to the starting address of the column being processed 173 */ 174 pMat0 = pMatSrc; 175 pMat1 = pMat0 + numCols; 176 177 acc0 = 0; 178 acc1 = 0; 179 180 pMat0Vec = pMat0; 181 pMat1Vec = pMat1; 182 pVec = pSrcVecPtr; 183 184 blkCnt = numCols >> 4; 185 while (blkCnt > 0U) 186 { 187 vecMatA0 = vld1q(pMat0Vec); 188 pMat0Vec += 16; 189 vecMatA1 = vld1q(pMat1Vec); 190 pMat1Vec += 16; 191 vecIn = vld1q(pVec); 192 pVec += 16; 193 194 acc0 = vmladavaq(acc0, vecIn, vecMatA0); 195 acc1 = vmladavaq(acc1, vecIn, vecMatA1); 196 197 blkCnt--; 198 } 199 200 /* 201 * tail 202 * (will be merged thru tail predication) 203 */ 204 blkCnt = numCols & 0xF; 205 if (blkCnt > 0U) 206 { 207 mve_pred16_t p0 = vctp8q(blkCnt); 208 209 vecMatA0 = vld1q(pMat0Vec); 210 vecMatA1 = vld1q(pMat1Vec); 211 vecIn = vldrbq_z_s8(pVec, p0); 212 213 acc0 = vmladavaq(acc0, vecIn, vecMatA0); 214 acc1 = vmladavaq(acc1, vecIn, vecMatA1); 215 } 216 217 *px++ = __SSAT(acc0 >> 7, 8); 218 *px++ = __SSAT(acc1 >> 7, 8); 219 220 pMatSrc += numCols * 2; 221 /* 222 * Decrement the row loop counter 223 */ 224 row -= 2; 225 } 226 227 if (row >= 1) 228 { 229 q7_t const *pMat0Vec, *pVec; 230 q7_t const *pSrcVecPtr = pSrcVec; 231 q31_t acc0; 232 q7x16_t vecMatA0, vecIn; 233 234 /* 235 * For every row wise process, the pInVec pointer is set 236 * to the starting address of the vector 237 */ 238 pVec = pSrcVec; 239 240 /* 241 * Initialize the pointer pIn1 to point to the starting address of the column being processed 242 */ 243 pMat0 = pMatSrc; 244 245 acc0 = 0LL; 246 247 pMat0Vec = pMat0; 248 pVec = pSrcVecPtr; 249 250 blkCnt = numCols >> 4; 251 while (blkCnt > 0U) 252 { 253 vecMatA0 = vld1q(pMat0Vec); 254 pMat0Vec += 16; 255 vecIn = vld1q(pVec); 256 pVec += 16; 257 258 acc0 = vmladavaq(acc0, vecIn, vecMatA0); 259 blkCnt--; 260 } 261 /* 262 * tail 263 * (will be merged thru tail predication) 264 */ 265 blkCnt = numCols & 0xF; 266 if (blkCnt > 0U) 267 { 268 mve_pred16_t p0 = vctp8q(blkCnt); 269 270 vecMatA0 = vld1q(pMat0Vec); 271 vecIn = vldrbq_z_s8(pVec, p0); 272 acc0 = vmladavaq(acc0, vecIn, vecMatA0); 273 } 274 *px++ = __SSAT(acc0 >> 7, 8); 275 } 276 } 277 278 #else 279 void arm_mat_vec_mult_q7(const arm_matrix_instance_q7 *pSrcMat, const q7_t *pVec, q7_t *pDst) 280 { 281 uint32_t numRows = pSrcMat->numRows; 282 uint32_t numCols = pSrcMat->numCols; 283 const q7_t *pSrcA = pSrcMat->pData; 284 const q7_t *pInA1; /* input data matrix pointer of Q7 type */ 285 const q7_t *pInA2; /* input data matrix pointer of Q7 type */ 286 const q7_t *pInA3; /* input data matrix pointer of Q7 type */ 287 const q7_t *pInA4; /* input data matrix pointer of Q7 type */ 288 const q7_t *pInVec; /* input data vector pointer of Q7 type */ 289 q7_t *px; /* output data pointer */ 290 uint32_t i, row, colCnt; /* loop counters */ 291 292 q31_t matData, matData2, vecData, vecData2; 293 294 295 /* Process 4 rows at a time */ 296 row = numRows >> 2; 297 i = 0u; 298 px = pDst; 299 300 301 302 /* The following loop performs the dot-product of each row in pSrcA with the vector */ 303 while (row > 0) { 304 /* Initialize accumulators */ 305 q31_t sum1 = 0; 306 q31_t sum2 = 0; 307 q31_t sum3 = 0; 308 q31_t sum4 = 0; 309 310 /* For every row wise process, the pInVec pointer is set 311 ** to the starting address of the vector */ 312 pInVec = pVec; 313 314 /* Loop unrolling: process 4 columns per iteration */ 315 colCnt = numCols >> 2; 316 317 /* Initialize row pointers so we can track 4 rows at once */ 318 pInA1 = pSrcA + i; 319 pInA2 = pInA1 + numCols; 320 pInA3 = pInA2 + numCols; 321 pInA4 = pInA3 + numCols; 322 323 324 // Inner loop: matrix-vector multiplication 325 326 while (colCnt > 0u) { 327 // Read 4 values from vector 328 vecData = read_q7x4_ia (&pInVec); 329 vecData2 = __SXTB16(__ROR(vecData, 8)); 330 vecData = __SXTB16(vecData); 331 // Read 16 values from the matrix - 4 values from each of 4 rows, and do multiply accumulate 332 matData = read_q7x4_ia (&pInA1); 333 matData2 = __SXTB16(__ROR(matData, 8)); 334 matData = __SXTB16(matData); 335 sum1 = __SMLAD(matData, vecData, sum1); 336 sum1 = __SMLAD(matData2, vecData2, sum1); 337 matData = read_q7x4_ia (&pInA2); 338 matData2 = __SXTB16(__ROR(matData, 8)); 339 matData = __SXTB16(matData); 340 sum2 = __SMLAD(matData, vecData, sum2); 341 sum2 = __SMLAD(matData2, vecData2, sum2); 342 matData = read_q7x4_ia (&pInA3); 343 matData2 = __SXTB16(__ROR(matData, 8)); 344 matData = __SXTB16(matData); 345 sum3 = __SMLAD(matData, vecData, sum3); 346 sum3 = __SMLAD(matData2, vecData2, sum3); 347 matData = read_q7x4_ia (&pInA4); 348 matData2 = __SXTB16(__ROR(matData, 8)); 349 matData = __SXTB16(matData); 350 sum4 = __SMLAD(matData, vecData, sum4); 351 sum4 = __SMLAD(matData2, vecData2, sum4); 352 353 // Decrement the loop counter 354 colCnt--; 355 } 356 357 /* process any remaining columns */ 358 359 colCnt = numCols & 3u; 360 361 while (colCnt > 0) { 362 vecData = *pInVec++; 363 sum1 += *pInA1++ * vecData; 364 sum2 += *pInA2++ * vecData; 365 sum3 += *pInA3++ * vecData; 366 sum4 += *pInA4++ * vecData; 367 colCnt--; 368 } 369 370 /* Saturate and store the result in the destination buffer */ 371 *px++ = (q7_t)(__SSAT((sum1 >> 7), 8)); 372 *px++ = (q7_t)(__SSAT((sum2 >> 7), 8)); 373 *px++ = (q7_t)(__SSAT((sum3 >> 7), 8)); 374 *px++ = (q7_t)(__SSAT((sum4 >> 7), 8)); 375 376 i = i + numCols * 4; 377 378 /* Decrement the row loop counter */ 379 row--; 380 } 381 382 /* process any remaining rows */ 383 row = numRows & 3u; 384 while (row > 0) { 385 386 q31_t sum = 0; 387 pInVec = pVec; 388 pInA1 = pSrcA + i; 389 390 // loop unrolling - process 4 elements at a time 391 colCnt = numCols >> 2; 392 393 while (colCnt > 0) { 394 vecData = read_q7x4_ia (&pInVec); 395 vecData2 = __SXTB16(__ROR(vecData, 8)); 396 vecData = __SXTB16(vecData); 397 matData = read_q7x4_ia (&pInA1); 398 matData2 = __SXTB16(__ROR(matData, 8)); 399 matData = __SXTB16(matData); 400 sum = __SMLAD(matData, vecData, sum); 401 sum = __SMLAD(matData2, vecData2, sum); 402 colCnt--; 403 } 404 405 // process remainder of row 406 colCnt = numCols & 3u; 407 while (colCnt > 0) { 408 sum += *pInA1++ * *pInVec++; 409 colCnt--; 410 } 411 *px++ = (q7_t)(__SSAT((sum >> 7), 8)); 412 i = i + numCols; 413 row--; 414 } 415 } 416 #endif /* defined(ARM_MATH_MVEI) */ 417 418 /** 419 * @} end of MatrixMult group 420 */