arm_mat_cmplx_mult_f16.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_mat_cmplx_mult_f16.c 4 * Description: Floating-point matrix multiplication 5 * 6 * $Date: 23 April 2021 7 * $Revision: V1.9.0 8 * 9 * Target Processor: Cortex-M and Cortex-A cores 10 * -------------------------------------------------------------------- */ 11 /* 12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 13 * 14 * SPDX-License-Identifier: Apache-2.0 15 * 16 * Licensed under the Apache License, Version 2.0 (the License); you may 17 * not use this file except in compliance with the License. 18 * You may obtain a copy of the License at 19 * 20 * www.apache.org/licenses/LICENSE-2.0 21 * 22 * Unless required by applicable law or agreed to in writing, software 23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 * See the License for the specific language governing permissions and 26 * limitations under the License. 27 */ 28 29 #include "dsp/matrix_functions_f16.h" 30 31 #if defined(ARM_FLOAT16_SUPPORTED) 32 33 34 /** 35 @ingroup groupMatrix 36 */ 37 38 39 /** 40 @addtogroup CmplxMatrixMult 41 @{ 42 */ 43 44 /** 45 @brief Floating-point Complex matrix multiplication. 46 @param[in] pSrcA points to first input complex matrix structure 47 @param[in] pSrcB points to second input complex matrix structure 48 @param[out] pDst points to output complex matrix structure 49 @return execution status 50 - \ref ARM_MATH_SUCCESS : Operation successful 51 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed 52 */ 53 54 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && defined(__CMSIS_GCC_H) 55 #pragma GCC warning "Scalar version of arm_mat_cmplx_mult_f16 built. Helium version has build issues with gcc." 56 #endif 57 58 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H) 59 60 #include "arm_helium_utils.h" 61 62 #define DONTCARE 0 /* inactive lane content */ 63 64 65 __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_2x2_mve( 66 const arm_matrix_instance_f16 * pSrcA, 67 const arm_matrix_instance_f16 * pSrcB, 68 arm_matrix_instance_f16 * pDst) 69 { 70 #define MATRIX_DIM 2 71 float16_t const *pInB = pSrcB->pData; /* input data matrix pointer B */ 72 float16_t *pInA = pSrcA->pData; /* input data matrix pointer A */ 73 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 74 uint16x8_t vecColBOffs0,vecColAOffs0,vecColAOffs1; 75 float16_t *pInA0 = pInA; 76 f16x8_t acc0, acc1; 77 f16x8_t vecB, vecA0, vecA1; 78 f16x8_t vecTmp; 79 uint16_t tmp; 80 static const uint16_t offsetB0[8] = { 0, 1, 81 MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1, 82 2, 3, 83 MATRIX_DIM * CMPLX_DIM + 2 , MATRIX_DIM * CMPLX_DIM + 3, 84 }; 85 86 87 vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0); 88 89 tmp = 0; 90 vecColAOffs0 = viwdupq_u16(tmp, 4, 1); 91 92 tmp = (CMPLX_DIM * MATRIX_DIM); 93 vecColAOffs1 = vecColAOffs0 + (uint16_t)(CMPLX_DIM * MATRIX_DIM); 94 95 96 pInB = (float16_t const *)pSrcB->pData; 97 98 vecA0 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs0); 99 vecA1 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs1); 100 101 102 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0); 103 104 acc0 = vcmulq(vecA0, vecB); 105 acc0 = vcmlaq_rot90(acc0, vecA0, vecB); 106 107 acc1 = vcmulq(vecA1, vecB); 108 acc1 = vcmlaq_rot90(acc1, vecA1, vecB); 109 110 111 /* 112 * Compute 113 * re0+re1 | im0+im1 | re0+re1 | im0+im1 114 * re2+re3 | im2+im3 | re2+re3 | im2+im3 115 */ 116 117 vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc0); 118 vecTmp = vaddq(vecTmp, acc0); 119 120 121 *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0]; 122 *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2]; 123 124 vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc1); 125 vecTmp = vaddq(vecTmp, acc1); 126 127 *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0]; 128 *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2]; 129 130 /* 131 * Return to application 132 */ 133 return (ARM_MATH_SUCCESS); 134 #undef MATRIX_DIM 135 } 136 137 138 139 __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_3x3_mve( 140 const arm_matrix_instance_f16 * pSrcA, 141 const arm_matrix_instance_f16 * pSrcB, 142 arm_matrix_instance_f16 * pDst) 143 { 144 #define MATRIX_DIM 3 145 float16_t const *pInB = pSrcB->pData; /* input data matrix pointer B */ 146 float16_t *pInA = pSrcA->pData; /* input data matrix pointer A */ 147 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 148 uint16x8_t vecColBOffs0; 149 float16_t *pInA0 = pInA; 150 float16_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM; 151 float16_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM; 152 f16x8_t acc0, acc1, acc2; 153 f16x8_t vecB, vecA0, vecA1, vecA2; 154 static const uint16_t offsetB0[8] = { 0, 1, 155 MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1, 156 2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1, 157 DONTCARE, DONTCARE 158 }; 159 160 161 /* enable predication to disable upper half complex vector element */ 162 mve_pred16_t p0 = vctp16q(MATRIX_DIM * CMPLX_DIM); 163 164 vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0); 165 166 pInB = (float16_t const *)pSrcB->pData; 167 168 vecA0 = vldrhq_f16(pInA0); 169 vecA1 = vldrhq_f16(pInA1); 170 vecA2 = vldrhq_f16(pInA2); 171 172 vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0); 173 174 acc0 = vcmulq(vecA0, vecB); 175 acc0 = vcmlaq_rot90(acc0, vecA0, vecB); 176 177 acc1 = vcmulq(vecA1, vecB); 178 acc1 = vcmlaq_rot90(acc1, vecA1, vecB); 179 180 acc2 = vcmulq(vecA2, vecB); 181 acc2 = vcmlaq_rot90(acc2, vecA2, vecB); 182 183 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]); 184 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]); 185 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]); 186 pOut += CMPLX_DIM; 187 /* 188 * move to next B column 189 */ 190 pInB = pInB + CMPLX_DIM; 191 192 vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0); 193 194 acc0 = vcmulq(vecA0, vecB); 195 acc0 = vcmlaq_rot90(acc0, vecA0, vecB); 196 197 acc1 = vcmulq(vecA1, vecB); 198 acc1 = vcmlaq_rot90(acc1, vecA1, vecB); 199 200 acc2 = vcmulq(vecA2, vecB); 201 acc2 = vcmlaq_rot90(acc2, vecA2, vecB); 202 203 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]); 204 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]); 205 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]); 206 pOut += CMPLX_DIM; 207 /* 208 * move to next B column 209 */ 210 pInB = pInB + CMPLX_DIM; 211 212 vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0); 213 214 acc0 = vcmulq(vecA0, vecB); 215 acc0 = vcmlaq_rot90(acc0, vecA0, vecB); 216 217 acc1 = vcmulq(vecA1, vecB); 218 acc1 = vcmlaq_rot90(acc1, vecA1, vecB); 219 220 acc2 = vcmulq(vecA2, vecB); 221 acc2 = vcmlaq_rot90(acc2, vecA2, vecB); 222 223 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]); 224 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]); 225 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]); 226 /* 227 * Return to application 228 */ 229 return (ARM_MATH_SUCCESS); 230 #undef MATRIX_DIM 231 } 232 233 234 235 236 __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_4x4_mve( 237 const arm_matrix_instance_f16 * pSrcA, 238 const arm_matrix_instance_f16 * pSrcB, 239 arm_matrix_instance_f16 * pDst) 240 { 241 #define MATRIX_DIM 4 242 float16_t const *pInB = pSrcB->pData; /* input data matrix pointer B */ 243 float16_t *pInA = pSrcA->pData; /* input data matrix pointer A */ 244 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 245 uint16x8_t vecColBOffs0; 246 float16_t *pInA0 = pInA; 247 float16_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM; 248 float16_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM; 249 float16_t *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM; 250 f16x8_t acc0, acc1, acc2, acc3; 251 f16x8_t vecB, vecA; 252 static const uint16_t offsetB0[8] = { 0, 1, 253 MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1, 254 2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1, 255 3 * MATRIX_DIM * CMPLX_DIM, 3 * MATRIX_DIM * CMPLX_DIM + 1 256 }; 257 258 vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0); 259 260 pInB = (float16_t const *)pSrcB->pData; 261 262 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0); 263 264 vecA = vldrhq_f16(pInA0); 265 acc0 = vcmulq(vecA, vecB); 266 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 267 268 vecA = vldrhq_f16(pInA1); 269 acc1 = vcmulq(vecA, vecB); 270 acc1 = vcmlaq_rot90(acc1, vecA, vecB); 271 272 vecA = vldrhq_f16(pInA2); 273 acc2 = vcmulq(vecA, vecB); 274 acc2 = vcmlaq_rot90(acc2, vecA, vecB); 275 276 vecA = vldrhq_f16(pInA3); 277 acc3 = vcmulq(vecA, vecB); 278 acc3 = vcmlaq_rot90(acc3, vecA, vecB); 279 280 281 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]); 282 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]); 283 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]); 284 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]); 285 pOut += CMPLX_DIM; 286 /* 287 * move to next B column 288 */ 289 pInB = pInB + CMPLX_DIM; 290 291 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0); 292 293 vecA = vldrhq_f16(pInA0); 294 acc0 = vcmulq(vecA, vecB); 295 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 296 297 vecA = vldrhq_f16(pInA1); 298 acc1 = vcmulq(vecA, vecB); 299 acc1 = vcmlaq_rot90(acc1, vecA, vecB); 300 301 vecA = vldrhq_f16(pInA2); 302 acc2 = vcmulq(vecA, vecB); 303 acc2 = vcmlaq_rot90(acc2, vecA, vecB); 304 305 vecA = vldrhq_f16(pInA3); 306 acc3 = vcmulq(vecA, vecB); 307 acc3 = vcmlaq_rot90(acc3, vecA, vecB); 308 309 310 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]); 311 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]); 312 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]); 313 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]); 314 pOut += CMPLX_DIM; 315 /* 316 * move to next B column 317 */ 318 pInB = pInB + CMPLX_DIM; 319 320 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0); 321 322 vecA = vldrhq_f16(pInA0); 323 acc0 = vcmulq(vecA, vecB); 324 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 325 326 vecA = vldrhq_f16(pInA1); 327 acc1 = vcmulq(vecA, vecB); 328 acc1 = vcmlaq_rot90(acc1, vecA, vecB); 329 330 vecA = vldrhq_f16(pInA2); 331 acc2 = vcmulq(vecA, vecB); 332 acc2 = vcmlaq_rot90(acc2, vecA, vecB); 333 334 vecA = vldrhq_f16(pInA3); 335 acc3 = vcmulq(vecA, vecB); 336 acc3 = vcmlaq_rot90(acc3, vecA, vecB); 337 338 339 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]); 340 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]); 341 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]); 342 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]); 343 pOut += CMPLX_DIM; 344 /* 345 * move to next B column 346 */ 347 pInB = pInB + CMPLX_DIM; 348 349 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0); 350 351 vecA = vldrhq_f16(pInA0); 352 acc0 = vcmulq(vecA, vecB); 353 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 354 355 vecA = vldrhq_f16(pInA1); 356 acc1 = vcmulq(vecA, vecB); 357 acc1 = vcmlaq_rot90(acc1, vecA, vecB); 358 359 vecA = vldrhq_f16(pInA2); 360 acc2 = vcmulq(vecA, vecB); 361 acc2 = vcmlaq_rot90(acc2, vecA, vecB); 362 363 vecA = vldrhq_f16(pInA3); 364 acc3 = vcmulq(vecA, vecB); 365 acc3 = vcmlaq_rot90(acc3, vecA, vecB); 366 367 368 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]); 369 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]); 370 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]); 371 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]); 372 /* 373 * Return to application 374 */ 375 return (ARM_MATH_SUCCESS); 376 #undef MATRIX_DIM 377 } 378 379 380 381 arm_status arm_mat_cmplx_mult_f16( 382 const arm_matrix_instance_f16 * pSrcA, 383 const arm_matrix_instance_f16 * pSrcB, 384 arm_matrix_instance_f16 * pDst) 385 { 386 float16_t const *pInB = (float16_t const *) pSrcB->pData; /* input data matrix pointer B */ 387 float16_t const *pInA = (float16_t const *) pSrcA->pData; /* input data matrix pointer A */ 388 float16_t *pOut = pDst->pData; /* output data matrix pointer */ 389 float16_t *px; /* Temporary output data matrix pointer */ 390 uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */ 391 uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */ 392 uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */ 393 uint16_t col, i = 0U, row = numRowsA; /* loop counters */ 394 arm_status status; /* status of matrix multiplication */ 395 uint16x8_t vecOffs, vecColBOffs; 396 uint32_t blkCnt,rowCnt; /* loop counters */ 397 398 #ifdef ARM_MATH_MATRIX_CHECK 399 400 /* Check for matrix mismatch condition */ 401 if ((pSrcA->numCols != pSrcB->numRows) || 402 (pSrcA->numRows != pDst->numRows) || 403 (pSrcB->numCols != pDst->numCols) ) 404 { 405 /* Set status as ARM_MATH_SIZE_MISMATCH */ 406 status = ARM_MATH_SIZE_MISMATCH; 407 } 408 else 409 410 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */ 411 412 { 413 414 /* 415 * small squared matrix specialized routines 416 */ 417 if (numRowsA == numColsB && numColsB == numColsA) 418 { 419 if (numRowsA == 1) 420 { 421 pOut[0] = (_Float16)pInA[0] * (_Float16)pInB[0] - (_Float16)pInA[1] * (_Float16)pInB[1]; 422 pOut[1] = (_Float16)pInA[0] * (_Float16)pInB[1] + (_Float16)pInA[1] * (_Float16)pInB[0]; 423 return (ARM_MATH_SUCCESS); 424 } 425 else if (numRowsA == 2) 426 return arm_mat_cmplx_mult_f16_2x2_mve(pSrcA, pSrcB, pDst); 427 else if (numRowsA == 3) 428 return arm_mat_cmplx_mult_f16_3x3_mve(pSrcA, pSrcB, pDst); 429 else if (numRowsA == 4) 430 return arm_mat_cmplx_mult_f16_4x4_mve(pSrcA, pSrcB, pDst); 431 } 432 433 vecColBOffs[0] = 0; 434 vecColBOffs[1] = 1; 435 vecColBOffs[2] = numColsB * CMPLX_DIM; 436 vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1; 437 vecColBOffs[4] = 2*numColsB * CMPLX_DIM; 438 vecColBOffs[5] = 2*(numColsB * CMPLX_DIM) + 1; 439 vecColBOffs[6] = 3*numColsB * CMPLX_DIM; 440 vecColBOffs[7] = 3*(numColsB * CMPLX_DIM) + 1; 441 442 /* 443 * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB 444 */ 445 446 /* 447 * row loop 448 */ 449 rowCnt = row >> 2; 450 while (rowCnt > 0u) 451 { 452 /* 453 * Output pointer is set to starting address of the row being processed 454 */ 455 px = pOut + i * CMPLX_DIM; 456 i = i + 4 * numColsB; 457 /* 458 * For every row wise process, the column loop counter is to be initiated 459 */ 460 col = numColsB; 461 /* 462 * For every row wise process, the pInB pointer is set 463 * to the starting address of the pSrcB data 464 */ 465 pInB = (float16_t const *) pSrcB->pData; 466 /* 467 * column loop 468 */ 469 while (col > 0u) 470 { 471 /* 472 * generate 4 columns elements 473 */ 474 /* 475 * Matrix A columns number of MAC operations are to be performed 476 */ 477 478 float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec; 479 float16_t const *pInA0 = pInA; 480 float16_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM; 481 float16_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM; 482 float16_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM; 483 f16x8_t acc0, acc1, acc2, acc3; 484 485 acc0 = vdupq_n_f16(0.0f16); 486 acc1 = vdupq_n_f16(0.0f16); 487 acc2 = vdupq_n_f16(0.0f16); 488 acc3 = vdupq_n_f16(0.0f16); 489 490 pSrcA0Vec = (float16_t const *) pInA0; 491 pSrcA1Vec = (float16_t const *) pInA1; 492 pSrcA2Vec = (float16_t const *) pInA2; 493 pSrcA3Vec = (float16_t const *) pInA3; 494 495 vecOffs = vecColBOffs; 496 497 /* 498 * process 1 x 4 block output 499 */ 500 blkCnt = (numColsA * CMPLX_DIM) >> 3; 501 while (blkCnt > 0U) 502 { 503 f16x8_t vecB, vecA; 504 505 vecB = vldrhq_gather_shifted_offset_f16(pInB, vecOffs); 506 /* 507 * move Matrix B read offsets, 4 rows down 508 */ 509 vecOffs = vaddq_n_u16(vecOffs , (uint16_t) (numColsB * 4 * CMPLX_DIM)); 510 511 vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 8; 512 acc0 = vcmlaq(acc0, vecA, vecB); 513 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 514 515 vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 8; 516 acc1 = vcmlaq(acc1, vecA, vecB); 517 acc1 = vcmlaq_rot90(acc1, vecA, vecB); 518 519 vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 8; 520 acc2 = vcmlaq(acc2, vecA, vecB); 521 acc2 = vcmlaq_rot90(acc2, vecA, vecB); 522 523 vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 8; 524 acc3 = vcmlaq(acc3, vecA, vecB); 525 acc3 = vcmlaq_rot90(acc3, vecA, vecB); 526 527 blkCnt--; 528 } 529 /* 530 * Unsupported addressing mode compiler crash 531 */ 532 /* 533 * tail 534 * (will be merged thru tail predication) 535 */ 536 blkCnt = (numColsA * CMPLX_DIM) & 7; 537 if (blkCnt > 0U) 538 { 539 mve_pred16_t p0 = vctp16q(blkCnt); 540 f16x8_t vecB, vecA; 541 542 vecB = vldrhq_gather_shifted_offset_z_f16(pInB, vecOffs, p0); 543 /* 544 * move Matrix B read offsets, 4 rows down 545 */ 546 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM)); 547 548 vecA = vld1q(pSrcA0Vec); 549 acc0 = vcmlaq(acc0, vecA, vecB); 550 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 551 552 vecA = vld1q(pSrcA1Vec); 553 acc1 = vcmlaq(acc1, vecA, vecB); 554 acc1 = vcmlaq_rot90(acc1, vecA, vecB); 555 556 vecA = vld1q(pSrcA2Vec); 557 acc2 = vcmlaq(acc2, vecA, vecB); 558 acc2 = vcmlaq_rot90(acc2, vecA, vecB); 559 560 vecA = vld1q(pSrcA3Vec); 561 acc3 = vcmlaq(acc3, vecA, vecB); 562 acc3 = vcmlaq_rot90(acc3, vecA, vecB); 563 564 } 565 566 567 mve_cmplx_sum_intra_vec_f16(acc0, &px[0 * CMPLX_DIM * numColsB + 0]); 568 mve_cmplx_sum_intra_vec_f16(acc1, &px[1 * CMPLX_DIM * numColsB + 0]); 569 mve_cmplx_sum_intra_vec_f16(acc2, &px[2 * CMPLX_DIM * numColsB + 0]); 570 mve_cmplx_sum_intra_vec_f16(acc3, &px[3 * CMPLX_DIM * numColsB + 0]); 571 572 px += CMPLX_DIM; 573 /* 574 * Decrement the column loop counter 575 */ 576 col--; 577 /* 578 * Update the pointer pInB to point to the starting address of the next column 579 */ 580 pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM; 581 } 582 583 /* 584 * Update the pointer pInA to point to the starting address of the next row 585 */ 586 pInA += (numColsA * 4) * CMPLX_DIM; 587 /* 588 * Decrement the row loop counter 589 */ 590 rowCnt --; 591 592 } 593 594 rowCnt = row & 3; 595 while (rowCnt > 0u) 596 { 597 /* 598 * Output pointer is set to starting address of the row being processed 599 */ 600 px = pOut + i * CMPLX_DIM; 601 i = i + numColsB; 602 /* 603 * For every row wise process, the column loop counter is to be initiated 604 */ 605 col = numColsB; 606 /* 607 * For every row wise process, the pInB pointer is set 608 * to the starting address of the pSrcB data 609 */ 610 pInB = (float16_t const *) pSrcB->pData; 611 /* 612 * column loop 613 */ 614 while (col > 0u) 615 { 616 /* 617 * generate 4 columns elements 618 */ 619 /* 620 * Matrix A columns number of MAC operations are to be performed 621 */ 622 623 float16_t const *pSrcA0Vec; 624 float16_t const *pInA0 = pInA; 625 f16x8_t acc0; 626 627 acc0 = vdupq_n_f16(0.0f16); 628 629 pSrcA0Vec = (float16_t const *) pInA0; 630 631 vecOffs = vecColBOffs; 632 633 /* 634 * process 1 x 4 block output 635 */ 636 blkCnt = (numColsA * CMPLX_DIM) >> 3; 637 while (blkCnt > 0U) 638 { 639 f16x8_t vecB, vecA; 640 641 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs); 642 /* 643 * move Matrix B read offsets, 4 rows down 644 */ 645 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (4*numColsB * CMPLX_DIM)); 646 647 vecA = vld1q(pSrcA0Vec); 648 pSrcA0Vec += 8; 649 acc0 = vcmlaq(acc0, vecA, vecB); 650 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 651 652 653 blkCnt--; 654 } 655 656 657 /* 658 * tail 659 */ 660 blkCnt = (numColsA * CMPLX_DIM) & 7; 661 if (blkCnt > 0U) 662 { 663 mve_pred16_t p0 = vctp16q(blkCnt); 664 f16x8_t vecB, vecA; 665 666 vecB = vldrhq_gather_shifted_offset_z(pInB, vecOffs, p0); 667 668 vecA = vld1q(pSrcA0Vec); 669 acc0 = vcmlaq(acc0, vecA, vecB); 670 acc0 = vcmlaq_rot90(acc0, vecA, vecB); 671 672 } 673 674 mve_cmplx_sum_intra_vec_f16(acc0, &px[0]); 675 676 677 px += CMPLX_DIM; 678 /* 679 * Decrement the column loop counter 680 */ 681 col--; 682 /* 683 * Update the pointer pInB to point to the starting address of the next column 684 */ 685 pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM; 686 } 687 688 /* 689 * Update the pointer pInA to point to the starting address of the next row 690 */ 691 pInA += numColsA * CMPLX_DIM; 692 rowCnt--; 693 } 694 695 /* 696 * set status as ARM_MATH_SUCCESS 697 */ 698 status = ARM_MATH_SUCCESS; 699 } 700 /* 701 * Return to application 702 */ 703 return (status); 704 } 705 #else 706 707 arm_status arm_mat_cmplx_mult_f16( 708 const arm_matrix_instance_f16 * pSrcA, 709 const arm_matrix_instance_f16 * pSrcB, 710 arm_matrix_instance_f16 * pDst) 711 { 712 float16_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */ 713 float16_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */ 714 float16_t *pInA = pSrcA->pData; /* Input data matrix pointer A */ 715 float16_t *pOut = pDst->pData; /* Output data matrix pointer */ 716 float16_t *px; /* Temporary output data matrix pointer */ 717 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */ 718 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */ 719 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */ 720 _Float16 sumReal, sumImag; /* Accumulator */ 721 _Float16 a1, b1, c1, d1; 722 uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */ 723 arm_status status; /* status of matrix multiplication */ 724 725 #if defined (ARM_MATH_LOOPUNROLL) 726 _Float16 a0, b0, c0, d0; 727 #endif 728 729 #ifdef ARM_MATH_MATRIX_CHECK 730 731 /* Check for matrix mismatch condition */ 732 if ((pSrcA->numCols != pSrcB->numRows) || 733 (pSrcA->numRows != pDst->numRows) || 734 (pSrcB->numCols != pDst->numCols) ) 735 { 736 /* Set status as ARM_MATH_SIZE_MISMATCH */ 737 status = ARM_MATH_SIZE_MISMATCH; 738 } 739 else 740 741 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */ 742 743 { 744 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */ 745 /* row loop */ 746 do 747 { 748 /* Output pointer is set to starting address of the row being processed */ 749 px = pOut + 2 * i; 750 751 /* For every row wise process, the column loop counter is to be initiated */ 752 col = numColsB; 753 754 /* For every row wise process, the pIn2 pointer is set 755 ** to the starting address of the pSrcB data */ 756 pIn2 = pSrcB->pData; 757 758 j = 0U; 759 760 /* column loop */ 761 do 762 { 763 /* Set the variable sum, that acts as accumulator, to zero */ 764 sumReal = 0.0f16; 765 sumImag = 0.0f16; 766 767 /* Initiate pointer pIn1 to point to starting address of column being processed */ 768 pIn1 = pInA; 769 770 #if defined (ARM_MATH_LOOPUNROLL) 771 772 /* Apply loop unrolling and compute 4 MACs simultaneously. */ 773 colCnt = numColsA >> 2U; 774 775 /* matrix multiplication */ 776 while (colCnt > 0U) 777 { 778 779 /* Reading real part of complex matrix A */ 780 a0 = *pIn1; 781 782 /* Reading real part of complex matrix B */ 783 c0 = *pIn2; 784 785 /* Reading imaginary part of complex matrix A */ 786 b0 = *(pIn1 + 1U); 787 788 /* Reading imaginary part of complex matrix B */ 789 d0 = *(pIn2 + 1U); 790 791 /* Multiply and Accumlates */ 792 sumReal += a0 * c0; 793 sumImag += b0 * c0; 794 795 /* update pointers */ 796 pIn1 += 2U; 797 pIn2 += 2 * numColsB; 798 799 /* Multiply and Accumlates */ 800 sumReal -= b0 * d0; 801 sumImag += a0 * d0; 802 803 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */ 804 805 /* read real and imag values from pSrcA and pSrcB buffer */ 806 a1 = *(pIn1 ); 807 c1 = *(pIn2 ); 808 b1 = *(pIn1 + 1U); 809 d1 = *(pIn2 + 1U); 810 811 /* Multiply and Accumlates */ 812 sumReal += a1 * c1; 813 sumImag += b1 * c1; 814 815 /* update pointers */ 816 pIn1 += 2U; 817 pIn2 += 2 * numColsB; 818 819 /* Multiply and Accumlates */ 820 sumReal -= b1 * d1; 821 sumImag += a1 * d1; 822 823 a0 = *(pIn1 ); 824 c0 = *(pIn2 ); 825 b0 = *(pIn1 + 1U); 826 d0 = *(pIn2 + 1U); 827 828 /* Multiply and Accumlates */ 829 sumReal += a0 * c0; 830 sumImag += b0 * c0; 831 832 /* update pointers */ 833 pIn1 += 2U; 834 pIn2 += 2 * numColsB; 835 836 /* Multiply and Accumlates */ 837 sumReal -= b0 * d0; 838 sumImag += a0 * d0; 839 840 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */ 841 842 a1 = *(pIn1 ); 843 c1 = *(pIn2 ); 844 b1 = *(pIn1 + 1U); 845 d1 = *(pIn2 + 1U); 846 847 /* Multiply and Accumlates */ 848 sumReal += a1 * c1; 849 sumImag += b1 * c1; 850 851 /* update pointers */ 852 pIn1 += 2U; 853 pIn2 += 2 * numColsB; 854 855 /* Multiply and Accumlates */ 856 sumReal -= b1 * d1; 857 sumImag += a1 * d1; 858 859 /* Decrement loop count */ 860 colCnt--; 861 } 862 863 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here. 864 ** No loop unrolling is used. */ 865 colCnt = numColsA % 0x4U; 866 867 #else 868 869 /* Initialize blkCnt with number of samples */ 870 colCnt = numColsA; 871 872 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */ 873 874 while (colCnt > 0U) 875 { 876 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */ 877 a1 = *(pIn1 ); 878 c1 = *(pIn2 ); 879 b1 = *(pIn1 + 1U); 880 d1 = *(pIn2 + 1U); 881 882 /* Multiply and Accumlates */ 883 sumReal += a1 * c1; 884 sumImag += b1 * c1; 885 886 /* update pointers */ 887 pIn1 += 2U; 888 pIn2 += 2 * numColsB; 889 890 /* Multiply and Accumlates */ 891 sumReal -= b1 * d1; 892 sumImag += a1 * d1; 893 894 /* Decrement loop counter */ 895 colCnt--; 896 } 897 898 /* Store result in destination buffer */ 899 *px++ = sumReal; 900 *px++ = sumImag; 901 902 /* Update pointer pIn2 to point to starting address of next column */ 903 j++; 904 pIn2 = pSrcB->pData + 2U * j; 905 906 /* Decrement column loop counter */ 907 col--; 908 909 } while (col > 0U); 910 911 /* Update pointer pInA to point to starting address of next row */ 912 i = i + numColsB; 913 pInA = pInA + 2 * numColsA; 914 915 /* Decrement row loop counter */ 916 row--; 917 918 } while (row > 0U); 919 920 /* Set status as ARM_MATH_SUCCESS */ 921 status = ARM_MATH_SUCCESS; 922 } 923 924 /* Return to application */ 925 return (status); 926 } 927 928 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 929 930 /** 931 @} end of MatrixMult group 932 */ 933 934 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 935