arm_mat_vec_mult_f16.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_mat_vec_mult_f16.c 4 * Description: Floating-point matrix and vector multiplication 5 * 6 * $Date: 23 April 2021 7 * 8 * $Revision: V1.9.0 9 * 10 * Target Processor: Cortex-M and Cortex-A cores 11 * -------------------------------------------------------------------- */ 12 /* 13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 14 * 15 * SPDX-License-Identifier: Apache-2.0 16 * 17 * Licensed under the Apache License, Version 2.0 (the License); you may 18 * not use this file except in compliance with the License. 19 * You may obtain a copy of the License at 20 * 21 * www.apache.org/licenses/LICENSE-2.0 22 * 23 * Unless required by applicable law or agreed to in writing, software 24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 * See the License for the specific language governing permissions and 27 * limitations under the License. 28 */ 29 30 #include "dsp/matrix_functions_f16.h" 31 32 #if defined(ARM_FLOAT16_SUPPORTED) 33 34 35 /** 36 * @ingroup groupMatrix 37 */ 38 39 40 /** 41 * @addtogroup MatrixVectMult 42 * @{ 43 */ 44 45 /** 46 * @brief Floating-point matrix and vector multiplication. 47 * @param[in] *pSrcMat points to the input matrix structure 48 * @param[in] *pVec points to input vector 49 * @param[out] *pDst points to output vector 50 */ 51 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) 52 53 #include "arm_helium_utils.h" 54 55 void arm_mat_vec_mult_f16( 56 const arm_matrix_instance_f16 *pSrcMat, 57 const float16_t *pSrcVec, 58 float16_t *pDstVec) 59 { 60 uint32_t numRows = pSrcMat->numRows; 61 uint32_t numCols = pSrcMat->numCols; 62 const float16_t *pSrcA = pSrcMat->pData; 63 const float16_t *pInA0; 64 const float16_t *pInA1; 65 float16_t *px; 66 int32_t row; 67 uint32_t blkCnt; /* loop counters */ 68 69 row = numRows; 70 px = pDstVec; 71 72 /* 73 * compute 4 rows in parallel 74 */ 75 while (row >= 4) 76 { 77 const float16_t *pInA2, *pInA3; 78 float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec; 79 f16x8_t vecIn, acc0, acc1, acc2, acc3; 80 float16_t const *pSrcVecPtr = pSrcVec; 81 82 /* 83 * Initialize the pointers to 4 consecutive MatrixA rows 84 */ 85 pInA0 = pSrcA; 86 pInA1 = pInA0 + numCols; 87 pInA2 = pInA1 + numCols; 88 pInA3 = pInA2 + numCols; 89 /* 90 * Initialize the vector pointer 91 */ 92 pInVec = pSrcVecPtr; 93 /* 94 * reset accumulators 95 */ 96 acc0 = vdupq_n_f16(0.0f); 97 acc1 = vdupq_n_f16(0.0f); 98 acc2 = vdupq_n_f16(0.0f); 99 acc3 = vdupq_n_f16(0.0f); 100 101 pSrcA0Vec = pInA0; 102 pSrcA1Vec = pInA1; 103 pSrcA2Vec = pInA2; 104 pSrcA3Vec = pInA3; 105 106 blkCnt = numCols >> 3; 107 while (blkCnt > 0U) 108 { 109 f16x8_t vecA; 110 111 vecIn = vld1q(pInVec); 112 pInVec += 8; 113 vecA = vld1q(pSrcA0Vec); 114 pSrcA0Vec += 8; 115 acc0 = vfmaq(acc0, vecIn, vecA); 116 vecA = vld1q(pSrcA1Vec); 117 pSrcA1Vec += 8; 118 acc1 = vfmaq(acc1, vecIn, vecA); 119 vecA = vld1q(pSrcA2Vec); 120 pSrcA2Vec += 8; 121 acc2 = vfmaq(acc2, vecIn, vecA); 122 vecA = vld1q(pSrcA3Vec); 123 pSrcA3Vec += 8; 124 acc3 = vfmaq(acc3, vecIn, vecA); 125 126 blkCnt--; 127 } 128 /* 129 * tail 130 * (will be merged thru tail predication) 131 */ 132 blkCnt = numCols & 7; 133 if (blkCnt > 0U) 134 { 135 mve_pred16_t p0 = vctp16q(blkCnt); 136 f16x8_t vecA; 137 138 vecIn = vldrhq_z_f16(pInVec, p0); 139 vecA = vld1q(pSrcA0Vec); 140 acc0 = vfmaq(acc0, vecIn, vecA); 141 vecA = vld1q(pSrcA1Vec); 142 acc1 = vfmaq(acc1, vecIn, vecA); 143 vecA = vld1q(pSrcA2Vec); 144 acc2 = vfmaq(acc2, vecIn, vecA); 145 vecA = vld1q(pSrcA3Vec); 146 acc3 = vfmaq(acc3, vecIn, vecA); 147 } 148 /* 149 * Sum the partial parts 150 */ 151 *px++ = vecAddAcrossF16Mve(acc0); 152 *px++ = vecAddAcrossF16Mve(acc1); 153 *px++ = vecAddAcrossF16Mve(acc2); 154 *px++ = vecAddAcrossF16Mve(acc3); 155 156 pSrcA += numCols * 4; 157 /* 158 * Decrement the row loop counter 159 */ 160 row -= 4; 161 } 162 163 /* 164 * compute 2 rows in parrallel 165 */ 166 if (row >= 2) 167 { 168 float16_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec; 169 f16x8_t vecIn, acc0, acc1; 170 float16_t const *pSrcVecPtr = pSrcVec; 171 172 /* 173 * Initialize the pointers to 2 consecutive MatrixA rows 174 */ 175 pInA0 = pSrcA; 176 pInA1 = pInA0 + numCols; 177 /* 178 * Initialize the vector pointer 179 */ 180 pInVec = pSrcVecPtr; 181 /* 182 * reset accumulators 183 */ 184 acc0 = vdupq_n_f16(0.0f); 185 acc1 = vdupq_n_f16(0.0f); 186 pSrcA0Vec = pInA0; 187 pSrcA1Vec = pInA1; 188 189 blkCnt = numCols >> 3; 190 while (blkCnt > 0U) 191 { 192 f16x8_t vecA; 193 194 vecIn = vld1q(pInVec); 195 pInVec += 8; 196 vecA = vld1q(pSrcA0Vec); 197 pSrcA0Vec += 8; 198 acc0 = vfmaq(acc0, vecIn, vecA); 199 vecA = vld1q(pSrcA1Vec); 200 pSrcA1Vec += 8; 201 acc1 = vfmaq(acc1, vecIn, vecA); 202 203 blkCnt--; 204 } 205 /* 206 * tail 207 * (will be merged thru tail predication) 208 */ 209 blkCnt = numCols & 7; 210 if (blkCnt > 0U) 211 { 212 mve_pred16_t p0 = vctp16q(blkCnt); 213 f16x8_t vecA; 214 215 vecIn = vldrhq_z_f16(pInVec, p0); 216 vecA = vld1q(pSrcA0Vec); 217 acc0 = vfmaq(acc0, vecIn, vecA); 218 vecA = vld1q(pSrcA1Vec); 219 acc1 = vfmaq(acc1, vecIn, vecA); 220 } 221 /* 222 * Sum the partial parts 223 */ 224 *px++ = vecAddAcrossF16Mve(acc0); 225 *px++ = vecAddAcrossF16Mve(acc1); 226 227 pSrcA += numCols * 2; 228 row -= 2; 229 } 230 231 if (row >= 1) 232 { 233 f16x8_t vecIn, acc0; 234 float16_t const *pSrcA0Vec, *pInVec; 235 float16_t const *pSrcVecPtr = pSrcVec; 236 /* 237 * Initialize the pointers to last MatrixA row 238 */ 239 pInA0 = pSrcA; 240 /* 241 * Initialize the vector pointer 242 */ 243 pInVec = pSrcVecPtr; 244 /* 245 * reset accumulators 246 */ 247 acc0 = vdupq_n_f16(0.0f); 248 249 pSrcA0Vec = pInA0; 250 251 blkCnt = numCols >> 3; 252 while (blkCnt > 0U) 253 { 254 f16x8_t vecA; 255 256 vecIn = vld1q(pInVec); 257 pInVec += 8; 258 vecA = vld1q(pSrcA0Vec); 259 pSrcA0Vec += 8; 260 acc0 = vfmaq(acc0, vecIn, vecA); 261 262 blkCnt--; 263 } 264 /* 265 * tail 266 * (will be merged thru tail predication) 267 */ 268 blkCnt = numCols & 7; 269 if (blkCnt > 0U) 270 { 271 mve_pred16_t p0 = vctp16q(blkCnt); 272 f16x8_t vecA; 273 274 vecIn = vldrhq_z_f16(pInVec, p0); 275 vecA = vld1q(pSrcA0Vec); 276 acc0 = vfmaq(acc0, vecIn, vecA); 277 } 278 /* 279 * Sum the partial parts 280 */ 281 *px++ = vecAddAcrossF16Mve(acc0); 282 } 283 } 284 #else 285 void arm_mat_vec_mult_f16(const arm_matrix_instance_f16 *pSrcMat, const float16_t *pVec, float16_t *pDst) 286 { 287 uint32_t numRows = pSrcMat->numRows; 288 uint32_t numCols = pSrcMat->numCols; 289 const float16_t *pSrcA = pSrcMat->pData; 290 const float16_t *pInA1; /* input data matrix pointer A of Q31 type */ 291 const float16_t *pInA2; /* input data matrix pointer A of Q31 type */ 292 const float16_t *pInA3; /* input data matrix pointer A of Q31 type */ 293 const float16_t *pInA4; /* input data matrix pointer A of Q31 type */ 294 const float16_t *pInVec; /* input data matrix pointer B of Q31 type */ 295 float16_t *px; /* Temporary output data matrix pointer */ 296 uint16_t i, row, colCnt; /* loop counters */ 297 float16_t matData, matData2, vecData, vecData2; 298 299 300 /* Process 4 rows at a time */ 301 row = numRows >> 2; 302 i = 0u; 303 px = pDst; 304 305 /* The following loop performs the dot-product of each row in pSrcA with the vector */ 306 /* row loop */ 307 while (row > 0) { 308 /* For every row wise process, the pInVec pointer is set 309 ** to the starting address of the vector */ 310 pInVec = pVec; 311 312 /* Initialize accumulators */ 313 float16_t sum1 = 0.0f16; 314 float16_t sum2 = 0.0f16; 315 float16_t sum3 = 0.0f16; 316 float16_t sum4 = 0.0f16; 317 318 /* Loop unrolling: process 2 columns per iteration */ 319 colCnt = numCols; 320 321 /* Initialize pointers to the starting address of the column being processed */ 322 pInA1 = pSrcA + i; 323 pInA2 = pInA1 + numCols; 324 pInA3 = pInA2 + numCols; 325 pInA4 = pInA3 + numCols; 326 327 328 // Main loop: matrix-vector multiplication 329 while (colCnt > 0u) { 330 // Read 2 values from vector 331 vecData = *(pInVec)++; 332 // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate 333 matData = *(pInA1)++; 334 sum1 += (_Float16)matData * (_Float16)vecData; 335 matData = *(pInA2)++; 336 sum2 += (_Float16)matData * (_Float16)vecData; 337 matData = *(pInA3)++; 338 sum3 += (_Float16)matData * (_Float16)vecData; 339 matData = *(pInA4)++; 340 sum4 += (_Float16)matData * (_Float16)vecData; 341 342 // Decrement the loop counter 343 colCnt--; 344 } 345 346 /* Saturate and store the result in the destination buffer */ 347 *px++ = sum1; 348 *px++ = sum2; 349 *px++ = sum3; 350 *px++ = sum4; 351 352 i = i + numCols * 4; 353 354 /* Decrement the row loop counter */ 355 row--; 356 } 357 358 /* process any remaining rows */ 359 row = numRows & 3u; 360 while (row > 0) { 361 362 float16_t sum = 0.0f16; 363 pInVec = pVec; 364 pInA1 = pSrcA + i; 365 366 colCnt = numCols >> 1; 367 368 while (colCnt > 0) { 369 vecData = *(pInVec)++; 370 vecData2 = *(pInVec)++; 371 matData = *(pInA1)++; 372 matData2 = *(pInA1)++; 373 sum += (_Float16)matData * (_Float16)vecData; 374 sum += (_Float16)matData2 * (_Float16)vecData2; 375 colCnt--; 376 } 377 // process remainder of row 378 colCnt = numCols & 1u; 379 while (colCnt > 0) { 380 sum += (_Float16)*pInA1++ * (_Float16)*pInVec++; 381 colCnt--; 382 } 383 384 *px++ = sum; 385 i = i + numCols; 386 row--; 387 } 388 } 389 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 390 391 /** 392 * @} end of MatrixMult group 393 */ 394 395 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 396