arm_mat_mult_f16.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_mat_mult_f16.c 4 * Description: Floating-point matrix multiplication 5 * 6 * $Date: 23 April 2021 7 * $Revision: V1.9.0 8 * 9 * Target Processor: Cortex-M and Cortex-A cores 10 * -------------------------------------------------------------------- */ 11 /* 12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 13 * 14 * SPDX-License-Identifier: Apache-2.0 15 * 16 * Licensed under the Apache License, Version 2.0 (the License); you may 17 * not use this file except in compliance with the License. 18 * You may obtain a copy of the License at 19 * 20 * www.apache.org/licenses/LICENSE-2.0 21 * 22 * Unless required by applicable law or agreed to in writing, software 23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 * See the License for the specific language governing permissions and 26 * limitations under the License. 27 */ 28 29 #include "dsp/matrix_functions_f16.h" 30 31 #if defined(ARM_FLOAT16_SUPPORTED) 32 33 34 /** 35 * @ingroup groupMatrix 36 */ 37 38 39 /** 40 * @addtogroup MatrixMult 41 * @{ 42 */ 43 44 /** 45 * @brief Floating-point matrix multiplication. 46 * @param[in] *pSrcA points to the first input matrix structure 47 * @param[in] *pSrcB points to the second input matrix structure 48 * @param[out] *pDst points to output matrix structure 49 * @return The function returns either 50 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking. 51 */ 52 53 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) 54 55 __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_2x2_mve( 56 const arm_matrix_instance_f16 *pSrcA, 57 const arm_matrix_instance_f16 *pSrcB, 58 arm_matrix_instance_f16 *pDst) 59 { 60 static const uint16_t offsetA[8] = { 0, 0, 2, 2, 0, 0, 2, 2 }; 61 /* offsetB allows to read and duplicate 1 row of B */ 62 static const uint16_t offsetB[8] = { 0, 1, 0, 1, 0, 1, 0, 1 }; 63 uint16x8_t vecOffsA, vecOffsB; 64 f16x8_t vecInA, vecInB, vecDst; 65 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 66 67 /* 68 * load initial offsets 69 */ 70 vecOffsA = vldrhq_u16((uint16_t const *) offsetA); 71 vecOffsB = vldrhq_u16((uint16_t const *) offsetB); 72 /* 73 * load {a00 a00 a10 a10 x x x x } 74 */ 75 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 76 /* 77 * load {b00 b01 b00 b01 x x x x } 78 */ 79 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 80 /* 81 * { a00 b00 a00 b01 82 * a10 b00 a10 b01 83 * x x 84 * x x } 85 */ 86 vecDst = vmulq(vecInA, vecInB); 87 /* 88 * move to 2nd column of matrix A 89 */ 90 vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1); 91 /* 92 * load {a01 a01 a11 a11 x x x x} 93 */ 94 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 95 /* 96 * move to next B row 97 */ 98 vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 2); 99 /* 100 * load {b10, b11, b10, b11, x x x x } 101 */ 102 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 103 /* 104 * { a00 b00 + a01 b10 a00 b01 + a01 b11 105 * a10 b00 + a11 b10 a10 b01 + a11 b11 106 * x x 107 * x x } 108 */ 109 vecDst = vfmaq(vecDst, vecInA, vecInB); 110 111 mve_pred16_t p0 = vctp16q(2*2); 112 /* 113 * Store the result in the destination buffer 114 * (lower half of the vector) 115 */ 116 vstrhq_p(pOut, vecDst, p0); 117 118 return (ARM_MATH_SUCCESS); 119 } 120 121 122 123 124 __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_3x3_mve( 125 const arm_matrix_instance_f16 *pSrcA, 126 const arm_matrix_instance_f16 *pSrcB, 127 arm_matrix_instance_f16 *pDst) 128 { 129 static const uint16_t offsetA[8] = { 0, 0, 0, 3, 3, 3, 6, 6 }; 130 /* offsetB allows to read and duplicate 1 row of B */ 131 static const uint16_t offsetB[8] = { 0, 1, 2, 0, 1, 2, 0, 1 }; 132 uint16x8_t vecOffsA, vecOffsB; 133 f16x8_t vecInA, vecInB, vecDst; 134 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 135 136 /* 137 * load initial offsets 138 */ 139 vecOffsA = vldrhq_u16((uint16_t const *) offsetA); 140 vecOffsB = vldrhq_u16((uint16_t const *) offsetB); 141 142 /* 143 * load {a00 a00 a00 a10 a10 a10 a20 a20} 144 */ 145 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 146 /* 147 * load {b00 b01 b02 b00 b01 b02 b00 b01} 148 */ 149 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 150 /* 151 * { a00 b00 a00 b01 a00 b02 152 * a10 b00 a10 b01 a10 b02 153 * a20 b00 a20 b01} 154 */ 155 vecDst = vmulq(vecInA, vecInB); 156 157 /* 158 * move to 2nd column of matrix A 159 */ 160 vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1); 161 /* 162 * load {a01 a01 a01 a11 a11 a11 a21 a21} 163 */ 164 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 165 /* 166 * move to next B row 167 */ 168 vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3); 169 /* 170 * load {b10, b11, b12, b10, b11, b12, b10, b11} 171 */ 172 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 173 /* 174 * { a00 b00 + a01 b10 a00 b01 + a01 b11 a00 b02 + a01 b12 175 * a10 b00 + a11 b10 a10 b01 + a11 b11 a10 b02 + a11 b12 176 * a20 b00 + a21 b10 a20 b01 + a21 b11 } 177 */ 178 vecDst = vfmaq(vecDst, vecInA, vecInB); 179 /* 180 * move to 3rd column of matrix A 181 */ 182 vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1); 183 /* 184 * load {a02 a02 a02 a12 a12 a12 a22 a22} 185 */ 186 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 187 /* 188 * move to next B row 189 */ 190 vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3); 191 /* 192 * load {b20, b21, b22, b20, b21, b22, b20, b21} 193 */ 194 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 195 /* 196 * {a00 b00 + a01 b10 + a02 b20 a00 b01 + a01 b11 + a02 b21 a00 b02 + a01 b12 + a02 b22}, 197 * a10 b00 + a11 b10 + a12 b20 a10 b01 + a11 b11 + a12 b21 a10 b02 + a11 b12 + a12 b22}, 198 * a20 b00 + a21 b10 + a22 b20 a20 b01 + a21 b11 + a22 b21 } 199 */ 200 vecDst = vfmaq(vecDst, vecInA, vecInB); 201 202 /* 203 * Store the result in the destination buffer 204 */ 205 vst1q(pOut, vecDst); pOut += 8; 206 207 /* last element computed in scalar mode 208 * a20 b02 + a21 b12 + a22 b22 209 */ 210 _Float16 * pA = (_Float16 *)pSrcA->pData; 211 _Float16 * pB = (_Float16 *)pSrcB->pData; 212 *pOut = pA[2*3] * pB[2] + pA[2*3+1] * pB[3+2] + pA[2*3+2] * pB[2*3+2]; 213 214 return (ARM_MATH_SUCCESS); 215 } 216 217 218 219 220 221 __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_4x4_mve( 222 const arm_matrix_instance_f16 *pSrcA, 223 const arm_matrix_instance_f16 *pSrcB, 224 arm_matrix_instance_f16 *pDst) 225 { 226 /* offsetA allows to read and duplicate 2 successive column elements of A */ 227 static const uint16_t offsetA[8] = { 0, 0, 0, 0, 4, 4, 4, 4 }; 228 /* offsetB allows to read and duplicate 1 row of B */ 229 static const uint16_t offsetB[8] = { 0, 1, 2, 3, 0, 1, 2, 3 }; 230 uint16x8_t vecOffsA, vecOffsB; 231 f16x8_t vecInA, vecInB, vecDst0, vecDst1; 232 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 233 234 /* 235 * load initial offsets 236 */ 237 vecOffsA = vldrhq_u16((uint16_t const *) offsetA); 238 vecOffsB = vldrhq_u16((uint16_t const *) offsetB); 239 240 /* 241 * load {a00 a00 a00 a00 a10 a10 a10 a10} 242 */ 243 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 244 /* 245 * load {b00 b01 b02 b03 b00 b01 b02 b03} 246 */ 247 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 248 /* 249 * { a00 b00 a00 b01 a00 b02 a00 b03 250 * a10 b00 a10 b01 a10 b02 a10 b03 } 251 */ 252 vecDst0 = vmulq(vecInA, vecInB); 253 /* 254 * jump 2 x A rows (2nd half of matrix) 255 */ 256 vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8); 257 /* 258 * load {a20 a20 a20 a20 a30 a30 a30 a30} 259 */ 260 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 261 /* 262 * { a20 b00 a20 b01 a20 b02 a20 b03 263 * a30 b00 a30 b01 a30 b02 + a31 b12 } 264 */ 265 vecDst1 = vmulq(vecInA, vecInB); 266 /* 267 * rewind back to top half of the A matrix (2nd column) 268 */ 269 vecOffsA = vsubq(vecOffsA, (uint16_t) 7); 270 /* 271 * load {a01 a01 a01 a01 a11 a11 a11 a11} 272 */ 273 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 274 /* 275 * move to next B row 276 */ 277 vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4); 278 /* 279 * load {b10, b11, b12, b13, b10, b11, b12, b13} 280 */ 281 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 282 /* 283 * { a00 b00 + a01 b10 a00 b01 + a01 b11 a00 b02 + a01 b12 a00 b03 + a01 b13 284 * a10 b00 + a11 b10 a10 b01 + a11 b11 a10 b02 + a11 b12 a10 b03 + a11 b13 } 285 */ 286 vecDst0 = vfmaq(vecDst0, vecInA, vecInB); 287 /* 288 * jump 2 x A rows (2nd half of matrix) 289 */ 290 vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8); 291 /* 292 * load {a21 a21 a21 a21 a31 a31 a31 a31} 293 */ 294 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 295 /* 296 * {a20 b00 + a21 b10 a20 b01 + a21 b11 a20 b02 + a21 b12 a20 b03 + a21 b13 297 * a30 b00 + a31 b10 a30 b01 + a31 b11 a30 b02 + a31 b12 a30 b03 + a31 b13 } 298 */ 299 vecDst1 = vfmaq(vecDst1, vecInA, vecInB); 300 301 /* 302 * rewind back to top half of the A matrix (3rd column) 303 */ 304 vecOffsA = vsubq(vecOffsA, (uint16_t) 7); 305 /* 306 * load {a02 a02 a02 a02 a12 a12 a12 a12} 307 */ 308 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 309 /* 310 * move to next B row 311 */ 312 vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4); 313 /* 314 * load {b20, b21, b22, b23, b20, b21, b22, b23} 315 */ 316 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 317 /* 318 * { a00 b00 + a01 b10 + a02 b20 a00 b01 + a01 b11 + a02 b21 a00 b02 + a01 b12 + a02 b22 a00 b03 + a01 b13 + a02 b23 319 * a10 b00 + a11 b10 + a12 b20 a10 b01 + a11 b11 + a12 b21 a10 b02 + a11 b12 + a12 b22 a10 b03 + a11 b13 + a12 b23 } 320 */ 321 vecDst0 = vfmaq(vecDst0, vecInA, vecInB); 322 /* 323 * jump 2 x A rows 324 */ 325 vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8); 326 327 /* 328 * load {a22 a22 a22 a22 a32 a32 a32 a32} 329 */ 330 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 331 /* 332 * {a20 b00 + a21 b10 + a22 b20 a20 b01 + a21 b11 + a22 b21 a20 b02 + a21 b12 + a22 b22 a20 b03 + a21 b13 + a22 b23 333 * a30 b00 + a31 b10 + a32 b20 a30 b01 + a31 b11 + a32 b21 a30 b02 + a31 b12 + a32 b22 a30 b03 + a31 b13 + a32 b23 } 334 */ 335 vecDst1 = vfmaq(vecDst1, vecInA, vecInB); 336 337 /* 338 * rewind back to top half of the A matrix (4th column) 339 */ 340 vecOffsA = vsubq(vecOffsA, (uint16_t) 7); 341 /* 342 * load {a03 a03 a03 a03 a13 a13 a13 a13} 343 */ 344 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 345 /* 346 * move to next B row 347 */ 348 vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4); 349 /* 350 * load {b30, b31, b32, b33, b30, b31, b32, b33} 351 */ 352 vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB); 353 /* 354 * { a00 b00 +...+ a03 b30, a00 b01 +...+ a03 b31, a00 b02 +...+ a03 b32, a00 b03 +...+ a03 b33 355 * a10 b00 +...+ a13 b30, a10 b01 +...+ a13 b31, a10 b02 +...+ a13 b32, a10 b03 +...+ a13 b33 } 356 */ 357 vecDst0 = vfmaq(vecDst0, vecInA, vecInB); 358 /* 359 * jump 2 x A rows 360 */ 361 vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8); 362 /* 363 * load {a23 a23 a23 a23 a33 a33 a33 a33} 364 */ 365 vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA); 366 /* 367 * {a20 b00 +...+ a23 b30, a20 b01 +...+ a23 b31, a20 b02 +...+ a23 b32, a20 b03 +...+ a23 b33 368 * a30 b00 +...+ a33 b30, a30 b01 +...+ a33 b31, a30 b02 +...+ a33 b32, a30 b03 +...+ a33 b33 } 369 */ 370 vecDst1 = vfmaq(vecDst1, vecInA, vecInB); 371 372 /* 373 * Store the result in the destination buffer 374 */ 375 vst1q(pOut, vecDst0); pOut += 8; 376 vst1q(pOut, vecDst1); 377 378 return (ARM_MATH_SUCCESS); 379 } 380 381 382 arm_status arm_mat_mult_f16( 383 const arm_matrix_instance_f16 * pSrcA, 384 const arm_matrix_instance_f16 * pSrcB, 385 arm_matrix_instance_f16 * pDst) 386 { 387 float16_t *pInB = pSrcB->pData; /* input data matrix pointer B */ 388 float16_t *pInA = pSrcA->pData; /* input data matrix pointer A */ 389 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 390 int numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ 391 int numColsB = pSrcB->numCols; /* number of columns of input matrix B */ 392 int numColsA = pSrcA->numCols; /* number of columns of input matrix A */ 393 uint32_t blkCnt; /* loop counters */ 394 int i; 395 396 397 #ifdef ARM_MATH_MATRIX_CHECK 398 399 /* Check for matrix mismatch condition */ 400 if ((pSrcA->numCols != pSrcB->numRows) || 401 (pSrcA->numRows != pDst->numRows) || 402 (pSrcB->numCols != pDst->numCols) ) 403 { 404 /* Set status as ARM_MATH_SIZE_MISMATCH */ 405 return(ARM_MATH_SIZE_MISMATCH); 406 } 407 else 408 409 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */ 410 { 411 /* small squared matrix specialized routines */ 412 if(numRowsA == numColsB && numColsB == numColsA) { 413 if(numRowsA == 2) 414 return arm_mat_mult_f16_2x2_mve(pSrcA, pSrcB, pDst); 415 else if(numRowsA == 3) 416 return arm_mat_mult_f16_3x3_mve(pSrcA, pSrcB, pDst); 417 else if(numRowsA == 4) 418 return arm_mat_mult_f16_4x4_mve(pSrcA, pSrcB, pDst); 419 } 420 421 /* main loop process 4 rows */ 422 i = numRowsA / 4; 423 while(i > 0) 424 { 425 float16_t *pInA0, *pInA1, *pInA2, *pInA3; 426 float16_t *pInB0; 427 float16_t *pOut0, *pOut1, *pOut2, *pOut3; 428 f16x8_t vecMac0, vecMac1, vecMac2, vecMac3; 429 f16x8_t vecInB; 430 431 /* pointers to 4 consecutive output rows */ 432 pOut0 = pOut; 433 pOut1 = pOut0 + numColsB; 434 pOut2 = pOut1 + numColsB; 435 pOut3 = pOut2 + numColsB; 436 pInB0 = pInB; 437 438 int k = numColsB >> 3; 439 while(k > 0) 440 { 441 /* pointers to 4 consecutive Matrix A rows */ 442 pInA0 = pInA; 443 pInA1 = pInA0 + numColsA; 444 pInA2 = pInA1 + numColsA; 445 pInA3 = pInA2 + numColsA; 446 447 vecMac0 = vdupq_n_f16(0.0f16); 448 vecMac1 = vdupq_n_f16(0.0f16); 449 vecMac2 = vdupq_n_f16(0.0f16); 450 vecMac3 = vdupq_n_f16(0.0f16); 451 452 blkCnt = numColsA; 453 454 while (blkCnt > 0U) 455 { 456 /* 457 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3..., bi,4n+7} 458 */ 459 vecInB = *(f16x8_t *)pInB0; /* vldrhq_f16(pInB0, 0); */ 460 461 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++); 462 vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++); 463 vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++); 464 vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++); 465 466 pInB0 = pInB0 + numColsB; 467 /* 468 * Decrement the blockSize loop counter 469 */ 470 blkCnt--; 471 } 472 473 /* Store the results (4 x 8 block) in the destination buffer */ 474 vst1q(pOut0, vecMac0); pOut0 += 8; 475 vst1q(pOut1, vecMac1); pOut1 += 8; 476 vst1q(pOut2, vecMac2); pOut2 += 8; 477 vst1q(pOut3, vecMac3); pOut3 += 8; 478 /* 479 * rewind 480 */ 481 pInB0 -= (numColsB * numColsA) - 8; 482 k--; 483 } 484 485 int colBLeft = numColsB & 7; 486 if (colBLeft) 487 { 488 pInA0 = pInA; 489 pInA1 = pInA0 + numColsA; 490 pInA2 = pInA1 + numColsA; 491 pInA3 = pInA2 + numColsA; 492 mve_pred16_t p0 = vctp16q(colBLeft); 493 494 vecMac0 = vdupq_n_f16(0.0f16); 495 vecMac1 = vdupq_n_f16(0.0f16); 496 vecMac2 = vdupq_n_f16(0.0f16); 497 vecMac3 = vdupq_n_f16(0.0f16); 498 499 blkCnt = numColsA; 500 501 while (blkCnt > 0U) 502 { 503 /* 504 * load {bi,4n+0, bi,4n+1, bi,4n+2, ..bi,4n+colBLeft-1, 0, ..} 505 */ 506 vecInB = vldrhq_z_f16(pInB0, p0); 507 508 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++); 509 vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++); 510 vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++); 511 vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++); 512 513 pInB0 = pInB0 + numColsB; 514 /* 515 * Decrement the blockSize loop counter 516 */ 517 blkCnt--; 518 } 519 520 /* Store the results (4 x colBLeft block) in the destination buffer */ 521 vstrhq_p_f16(pOut0, vecMac0, p0); 522 vstrhq_p_f16(pOut1, vecMac1, p0); 523 vstrhq_p_f16(pOut2, vecMac2, p0); 524 vstrhq_p_f16(pOut3, vecMac3, p0); 525 } 526 527 pInA += 4 * numColsA; 528 pOut += 4 * numColsB; 529 i--; 530 } 531 532 /* 533 * non multiple of 4 rows for Matrix A 534 * process single row 535 */ 536 if (numRowsA & 3) 537 { 538 i = numRowsA & 3; 539 do 540 { 541 float16_t *pInA0; 542 float16_t *pInB0; 543 float16_t *pOut0; 544 f16x8_t vecInB; 545 f16x8_t vecMac0; 546 547 pOut0 = pOut; 548 pInB0 = pInB; 549 550 int k = numColsB >> 3; 551 while(k > 0) 552 { 553 pInA0 = pInA; 554 555 vecMac0 = vdupq_n_f16(0.0f16); 556 blkCnt = numColsA; 557 558 while (blkCnt > 0U) 559 { 560 /* 561 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3, ...bi,4n+7} 562 */ 563 vecInB = *(f16x8_t *)pInB0; /* vldrhq_f16(pInB0, 0); */ 564 565 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++); 566 567 pInB0 = pInB0 + numColsB; 568 /* 569 * Decrement the blockSize loop counter 570 */ 571 blkCnt--; 572 } 573 /* Store the results (1 x 8 block) in the destination buffer */ 574 vst1q(pOut0, vecMac0); pOut0 += 8; 575 /* 576 * rewind 577 */ 578 pInB0 -= (numColsB * numColsA) - 8; 579 k--; 580 } 581 582 int colBLeft = numColsB & 7; 583 if (colBLeft) 584 { 585 pInA0 = pInA; 586 mve_pred16_t p0 = vctp16q(colBLeft); 587 588 vecMac0 = vdupq_n_f16(0.0f16); 589 blkCnt = numColsA; 590 591 while (blkCnt > 0U) 592 { 593 /* 594 * load {bi,4n+0, bi,4n+1, bi,4n+2, ..., bi,4n+colBLeft, 0, ...} 595 */ 596 vecInB = vldrhq_z_f16(pInB0, p0); 597 598 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++); 599 600 pInB0 = pInB0 + numColsB; 601 /* 602 * Decrement the blockSize loop counter 603 */ 604 blkCnt--; 605 } 606 /* Store the results (1 x colBLeft block) in the destination buffer */ 607 vstrhq_p_f16(pOut0, vecMac0, p0); 608 } 609 610 pInA += 1 * numColsA; 611 pOut += 1 * numColsB; 612 } 613 while (--i); 614 } 615 /* 616 * Return to application 617 */ 618 return (ARM_MATH_SUCCESS); 619 } 620 } 621 #else 622 623 624 arm_status arm_mat_mult_f16( 625 const arm_matrix_instance_f16 * pSrcA, 626 const arm_matrix_instance_f16 * pSrcB, 627 arm_matrix_instance_f16 * pDst) 628 { 629 float16_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */ 630 float16_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */ 631 float16_t *pInA = pSrcA->pData; /* Input data matrix pointer A */ 632 float16_t *pInB = pSrcB->pData; /* Input data matrix pointer B */ 633 float16_t *pOut = pDst->pData; /* Output data matrix pointer */ 634 float16_t *px; /* Temporary output data matrix pointer */ 635 _Float16 sum; /* Accumulator */ 636 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */ 637 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */ 638 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */ 639 uint32_t col, i = 0U, row = numRowsA, colCnt; /* Loop counters */ 640 arm_status status; /* Status of matrix multiplication */ 641 642 #ifdef ARM_MATH_MATRIX_CHECK 643 644 /* Check for matrix mismatch condition */ 645 if ((pSrcA->numCols != pSrcB->numRows) || 646 (pSrcA->numRows != pDst->numRows) || 647 (pSrcB->numCols != pDst->numCols) ) 648 { 649 /* Set status as ARM_MATH_SIZE_MISMATCH */ 650 status = ARM_MATH_SIZE_MISMATCH; 651 } 652 else 653 654 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */ 655 656 { 657 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */ 658 /* row loop */ 659 do 660 { 661 /* Output pointer is set to starting address of row being processed */ 662 px = pOut + i; 663 664 /* For every row wise process, column loop counter is to be initiated */ 665 col = numColsB; 666 667 /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */ 668 pIn2 = pSrcB->pData; 669 670 /* column loop */ 671 do 672 { 673 /* Set the variable sum, that acts as accumulator, to zero */ 674 sum = 0.0f16; 675 676 /* Initialize pointer pIn1 to point to starting address of column being processed */ 677 pIn1 = pInA; 678 679 #if defined (ARM_MATH_LOOPUNROLL) 680 681 /* Loop unrolling: Compute 4 MACs at a time. */ 682 colCnt = numColsA >> 2U; 683 684 /* matrix multiplication */ 685 while (colCnt > 0U) 686 { 687 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */ 688 689 /* Perform the multiply-accumulates */ 690 sum += (_Float16)*pIn1++ * (_Float16)*pIn2; 691 pIn2 += numColsB; 692 693 sum += (_Float16)*pIn1++ * (_Float16)*pIn2; 694 pIn2 += numColsB; 695 696 sum += (_Float16)*pIn1++ * (_Float16)*pIn2; 697 pIn2 += numColsB; 698 699 sum += (_Float16)*pIn1++ * (_Float16)*pIn2; 700 pIn2 += numColsB; 701 702 /* Decrement loop counter */ 703 colCnt--; 704 } 705 706 /* Loop unrolling: Compute remaining MACs */ 707 colCnt = numColsA % 0x4U; 708 709 #else 710 711 /* Initialize cntCnt with number of columns */ 712 colCnt = numColsA; 713 714 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */ 715 716 while (colCnt > 0U) 717 { 718 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */ 719 720 /* Perform the multiply-accumulates */ 721 sum += (_Float16)*pIn1++ * (_Float16)*pIn2; 722 pIn2 += numColsB; 723 724 /* Decrement loop counter */ 725 colCnt--; 726 } 727 728 /* Store result in destination buffer */ 729 *px++ = sum; 730 731 /* Decrement column loop counter */ 732 col--; 733 734 /* Update pointer pIn2 to point to starting address of next column */ 735 pIn2 = pInB + (numColsB - col); 736 737 } while (col > 0U); 738 739 /* Update pointer pInA to point to starting address of next row */ 740 i = i + numColsB; 741 pInA = pInA + numColsA; 742 743 /* Decrement row loop counter */ 744 row--; 745 746 } while (row > 0U); 747 748 /* Set status as ARM_MATH_SUCCESS */ 749 status = ARM_MATH_SUCCESS; 750 } 751 752 /* Return to application */ 753 return (status); 754 } 755 756 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 757 758 /** 759 * @} end of MatrixMult group 760 */ 761 762 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 763