/* arm_mat_vec_mult_f32.c */
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_mat_vec_mult_f32.c 4 * Description: Floating-point matrix and vector multiplication 5 * 6 * $Date: 23 April 2021 7 * 8 * $Revision: V1.9.0 9 * 10 * Target Processor: Cortex-M and Cortex-A cores 11 * -------------------------------------------------------------------- */ 12 /* 13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 14 * 15 * SPDX-License-Identifier: Apache-2.0 16 * 17 * Licensed under the Apache License, Version 2.0 (the License); you may 18 * not use this file except in compliance with the License. 19 * You may obtain a copy of the License at 20 * 21 * www.apache.org/licenses/LICENSE-2.0 22 * 23 * Unless required by applicable law or agreed to in writing, software 24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 * See the License for the specific language governing permissions and 27 * limitations under the License. 28 */ 29 30 #include "dsp/matrix_functions.h" 31 32 33 /** 34 * @ingroup groupMatrix 35 */ 36 37 /** 38 * @defgroup MatrixVectMult Matrix Vector Multiplication 39 * 40 * Multiplies a matrix and a vector. 41 * 42 */ 43 44 /** 45 * @addtogroup MatrixVectMult 46 * @{ 47 */ 48 49 /** 50 * @brief Floating-point matrix and vector multiplication. 
51 * @param[in] *pSrcMat points to the input matrix structure 52 * @param[in] *pVec points to input vector 53 * @param[out] *pDst points to output vector 54 */ 55 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) 56 57 #include "arm_helium_utils.h" 58 59 void arm_mat_vec_mult_f32( 60 const arm_matrix_instance_f32 *pSrcMat, 61 const float32_t *pSrcVec, 62 float32_t *pDstVec) 63 { 64 uint32_t numRows = pSrcMat->numRows; 65 uint32_t numCols = pSrcMat->numCols; 66 const float32_t *pSrcA = pSrcMat->pData; 67 const float32_t *pInA0; 68 const float32_t *pInA1; 69 float32_t *px; 70 int32_t row; 71 uint32_t blkCnt; /* loop counters */ 72 73 row = numRows; 74 px = pDstVec; 75 76 /* 77 * compute 4 rows in parallel 78 */ 79 while (row >= 4) 80 { 81 const float32_t *pInA2, *pInA3; 82 float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec; 83 f32x4_t vecIn, acc0, acc1, acc2, acc3; 84 float32_t const *pSrcVecPtr = pSrcVec; 85 86 /* 87 * Initialize the pointers to 4 consecutive MatrixA rows 88 */ 89 pInA0 = pSrcA; 90 pInA1 = pInA0 + numCols; 91 pInA2 = pInA1 + numCols; 92 pInA3 = pInA2 + numCols; 93 /* 94 * Initialize the vector pointer 95 */ 96 pInVec = pSrcVecPtr; 97 /* 98 * reset accumulators 99 */ 100 acc0 = vdupq_n_f32(0.0f); 101 acc1 = vdupq_n_f32(0.0f); 102 acc2 = vdupq_n_f32(0.0f); 103 acc3 = vdupq_n_f32(0.0f); 104 105 pSrcA0Vec = pInA0; 106 pSrcA1Vec = pInA1; 107 pSrcA2Vec = pInA2; 108 pSrcA3Vec = pInA3; 109 110 blkCnt = numCols >> 2; 111 while (blkCnt > 0U) 112 { 113 f32x4_t vecA; 114 115 vecIn = vld1q(pInVec); 116 pInVec += 4; 117 vecA = vld1q(pSrcA0Vec); 118 pSrcA0Vec += 4; 119 acc0 = vfmaq(acc0, vecIn, vecA); 120 vecA = vld1q(pSrcA1Vec); 121 pSrcA1Vec += 4; 122 acc1 = vfmaq(acc1, vecIn, vecA); 123 vecA = vld1q(pSrcA2Vec); 124 pSrcA2Vec += 4; 125 acc2 = vfmaq(acc2, vecIn, vecA); 126 vecA = vld1q(pSrcA3Vec); 127 pSrcA3Vec += 4; 128 acc3 = vfmaq(acc3, vecIn, vecA); 129 130 blkCnt--; 131 } 132 /* 133 * tail 134 * (will be merged thru tail 
predication) 135 */ 136 blkCnt = numCols & 3; 137 if (blkCnt > 0U) 138 { 139 mve_pred16_t p0 = vctp32q(blkCnt); 140 f32x4_t vecA; 141 142 vecIn = vldrwq_z_f32(pInVec, p0); 143 vecA = vld1q(pSrcA0Vec); 144 acc0 = vfmaq(acc0, vecIn, vecA); 145 vecA = vld1q(pSrcA1Vec); 146 acc1 = vfmaq(acc1, vecIn, vecA); 147 vecA = vld1q(pSrcA2Vec); 148 acc2 = vfmaq(acc2, vecIn, vecA); 149 vecA = vld1q(pSrcA3Vec); 150 acc3 = vfmaq(acc3, vecIn, vecA); 151 } 152 /* 153 * Sum the partial parts 154 */ 155 *px++ = vecAddAcrossF32Mve(acc0); 156 *px++ = vecAddAcrossF32Mve(acc1); 157 *px++ = vecAddAcrossF32Mve(acc2); 158 *px++ = vecAddAcrossF32Mve(acc3); 159 160 pSrcA += numCols * 4; 161 /* 162 * Decrement the row loop counter 163 */ 164 row -= 4; 165 } 166 167 /* 168 * compute 2 rows in parrallel 169 */ 170 if (row >= 2) 171 { 172 float32_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec; 173 f32x4_t vecIn, acc0, acc1; 174 float32_t const *pSrcVecPtr = pSrcVec; 175 176 /* 177 * Initialize the pointers to 2 consecutive MatrixA rows 178 */ 179 pInA0 = pSrcA; 180 pInA1 = pInA0 + numCols; 181 /* 182 * Initialize the vector pointer 183 */ 184 pInVec = pSrcVecPtr; 185 /* 186 * reset accumulators 187 */ 188 acc0 = vdupq_n_f32(0.0f); 189 acc1 = vdupq_n_f32(0.0f); 190 pSrcA0Vec = pInA0; 191 pSrcA1Vec = pInA1; 192 193 blkCnt = numCols >> 2; 194 while (blkCnt > 0U) 195 { 196 f32x4_t vecA; 197 198 vecIn = vld1q(pInVec); 199 pInVec += 4; 200 vecA = vld1q(pSrcA0Vec); 201 pSrcA0Vec += 4; 202 acc0 = vfmaq(acc0, vecIn, vecA); 203 vecA = vld1q(pSrcA1Vec); 204 pSrcA1Vec += 4; 205 acc1 = vfmaq(acc1, vecIn, vecA); 206 207 blkCnt--; 208 } 209 /* 210 * tail 211 * (will be merged thru tail predication) 212 */ 213 blkCnt = numCols & 3; 214 if (blkCnt > 0U) 215 { 216 mve_pred16_t p0 = vctp32q(blkCnt); 217 f32x4_t vecA; 218 219 vecIn = vldrwq_z_f32(pInVec, p0); 220 vecA = vld1q(pSrcA0Vec); 221 acc0 = vfmaq(acc0, vecIn, vecA); 222 vecA = vld1q(pSrcA1Vec); 223 acc1 = vfmaq(acc1, vecIn, vecA); 224 } 225 /* 226 * Sum the partial 
parts 227 */ 228 *px++ = vecAddAcrossF32Mve(acc0); 229 *px++ = vecAddAcrossF32Mve(acc1); 230 231 pSrcA += numCols * 2; 232 row -= 2; 233 } 234 235 if (row >= 1) 236 { 237 f32x4_t vecIn, acc0; 238 float32_t const *pSrcA0Vec, *pInVec; 239 float32_t const *pSrcVecPtr = pSrcVec; 240 /* 241 * Initialize the pointers to last MatrixA row 242 */ 243 pInA0 = pSrcA; 244 /* 245 * Initialize the vector pointer 246 */ 247 pInVec = pSrcVecPtr; 248 /* 249 * reset accumulators 250 */ 251 acc0 = vdupq_n_f32(0.0f); 252 253 pSrcA0Vec = pInA0; 254 255 blkCnt = numCols >> 2; 256 while (blkCnt > 0U) 257 { 258 f32x4_t vecA; 259 260 vecIn = vld1q(pInVec); 261 pInVec += 4; 262 vecA = vld1q(pSrcA0Vec); 263 pSrcA0Vec += 4; 264 acc0 = vfmaq(acc0, vecIn, vecA); 265 266 blkCnt--; 267 } 268 /* 269 * tail 270 * (will be merged thru tail predication) 271 */ 272 blkCnt = numCols & 3; 273 if (blkCnt > 0U) 274 { 275 mve_pred16_t p0 = vctp32q(blkCnt); 276 f32x4_t vecA; 277 278 vecIn = vldrwq_z_f32(pInVec, p0); 279 vecA = vld1q(pSrcA0Vec); 280 acc0 = vfmaq(acc0, vecIn, vecA); 281 } 282 /* 283 * Sum the partial parts 284 */ 285 *px++ = vecAddAcrossF32Mve(acc0); 286 } 287 } 288 #else 289 290 void arm_mat_vec_mult_f32(const arm_matrix_instance_f32 *pSrcMat, const float32_t *pVec, float32_t *pDst) 291 { 292 uint32_t numRows = pSrcMat->numRows; 293 uint32_t numCols = pSrcMat->numCols; 294 const float32_t *pSrcA = pSrcMat->pData; 295 const float32_t *pInA1; /* input data matrix pointer A of Q31 type */ 296 const float32_t *pInA2; /* input data matrix pointer A of Q31 type */ 297 const float32_t *pInA3; /* input data matrix pointer A of Q31 type */ 298 const float32_t *pInA4; /* input data matrix pointer A of Q31 type */ 299 const float32_t *pInVec; /* input data matrix pointer B of Q31 type */ 300 float32_t *px; /* Temporary output data matrix pointer */ 301 uint16_t i, row, colCnt; /* loop counters */ 302 float32_t matData, matData2, vecData, vecData2; 303 304 305 /* Process 4 rows at a time */ 306 row = 
numRows >> 2; 307 i = 0u; 308 px = pDst; 309 310 /* The following loop performs the dot-product of each row in pSrcA with the vector */ 311 /* row loop */ 312 while (row > 0) { 313 /* Initialize accumulators */ 314 float32_t sum1 = 0.0f; 315 float32_t sum2 = 0.0f; 316 float32_t sum3 = 0.0f; 317 float32_t sum4 = 0.0f; 318 319 /* For every row wise process, the pInVec pointer is set 320 ** to the starting address of the vector */ 321 pInVec = pVec; 322 323 /* Loop unrolling: process 2 columns per iteration */ 324 colCnt = numCols; 325 326 /* Initialize pointers to the starting address of the column being processed */ 327 pInA1 = pSrcA + i; 328 pInA2 = pInA1 + numCols; 329 pInA3 = pInA2 + numCols; 330 pInA4 = pInA3 + numCols; 331 332 333 // Main loop: matrix-vector multiplication 334 while (colCnt > 0u) { 335 // Read 2 values from vector 336 vecData = *(pInVec)++; 337 // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate 338 matData = *(pInA1)++; 339 sum1 += matData * vecData; 340 matData = *(pInA2)++; 341 sum2 += matData * vecData; 342 matData = *(pInA3)++; 343 sum3 += matData * vecData; 344 matData = *(pInA4)++; 345 sum4 += matData * vecData; 346 347 // Decrement the loop counter 348 colCnt--; 349 } 350 351 /* Saturate and store the result in the destination buffer */ 352 *px++ = sum1; 353 *px++ = sum2; 354 *px++ = sum3; 355 *px++ = sum4; 356 357 i = i + numCols * 4; 358 359 /* Decrement the row loop counter */ 360 row--; 361 } 362 363 /* process any remaining rows */ 364 row = numRows & 3u; 365 while (row > 0) { 366 367 float32_t sum = 0.0f; 368 pInVec = pVec; 369 pInA1 = pSrcA + i; 370 371 colCnt = numCols >> 1; 372 while (colCnt > 0) { 373 vecData = *(pInVec)++; 374 vecData2 = *(pInVec)++; 375 matData = *(pInA1)++; 376 matData2 = *(pInA1)++; 377 sum += matData * vecData; 378 sum += matData2 * vecData2; 379 colCnt--; 380 } 381 // process remainder of row 382 colCnt = numCols & 1u; 383 384 385 while (colCnt > 0) { 386 sum += 
*pInA1++ * *pInVec++; 387 colCnt--; 388 } 389 390 *px++ = sum; 391 i = i + numCols; 392 row--; 393 } 394 } 395 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 396 397 /** 398 * @} end of MatrixMult group 399 */