arm_svm_rbf_predict_f32.c
1 /* ---------------------------------------------------------------------- 2 * Project: CMSIS DSP Library 3 * Title: arm_svm_rbf_predict_f32.c 4 * Description: SVM Radial Basis Function Classifier 5 * 6 * $Date: 23 April 2021 7 * $Revision: V1.9.0 8 * 9 * Target Processor: Cortex-M and Cortex-A cores 10 * -------------------------------------------------------------------- */ 11 /* 12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. 13 * 14 * SPDX-License-Identifier: Apache-2.0 15 * 16 * Licensed under the Apache License, Version 2.0 (the License); you may 17 * not use this file except in compliance with the License. 18 * You may obtain a copy of the License at 19 * 20 * www.apache.org/licenses/LICENSE-2.0 21 * 22 * Unless required by applicable law or agreed to in writing, software 23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT 24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 * See the License for the specific language governing permissions and 26 * limitations under the License. 27 */ 28 29 #include "dsp/svm_functions.h" 30 #include <limits.h> 31 #include <math.h> 32 33 34 /** 35 * @addtogroup rbfsvm 36 * @{ 37 */ 38 39 40 /** 41 * @brief SVM rbf prediction 42 * @param[in] S Pointer to an instance of the rbf SVM structure. 43 * @param[in] in Pointer to input vector 44 * @param[out] pResult decision value 45 * @return none. 46 * 47 */ 48 49 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) 50 51 #include "arm_helium_utils.h" 52 #include "arm_vec_math.h" 53 54 void arm_svm_rbf_predict_f32( 55 const arm_svm_rbf_instance_f32 *S, 56 const float32_t * in, 57 int32_t * pResult) 58 { 59 /* inlined Matrix x Vector function interleaved with dot prod */ 60 uint32_t numRows = S->nbOfSupportVectors; 61 uint32_t numCols = S->vectorDimension; 62 const float32_t *pSupport = S->supportVectors; 63 const float32_t *pSrcA = pSupport; 64 const float32_t *pInA0; 65 const float32_t *pInA1; 66 uint32_t row; 67 uint32_t blkCnt; /* loop counters */ 68 const float32_t *pDualCoef = S->dualCoefficients; 69 float32_t sum = S->intercept; 70 f32x4_t vSum = vdupq_n_f32(0); 71 72 row = numRows; 73 74 /* 75 * compute 4 rows in parrallel 76 */ 77 while (row >= 4) { 78 const float32_t *pInA2, *pInA3; 79 float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec; 80 f32x4_t vecIn, acc0, acc1, acc2, acc3; 81 float32_t const *pSrcVecPtr = in; 82 83 /* 84 * Initialize the pointers to 4 consecutive MatrixA rows 85 */ 86 pInA0 = pSrcA; 87 pInA1 = pInA0 + numCols; 88 pInA2 = pInA1 + numCols; 89 pInA3 = pInA2 + numCols; 90 /* 91 * Initialize the vector pointer 92 */ 93 pInVec = pSrcVecPtr; 94 /* 95 * reset accumulators 96 */ 97 acc0 = vdupq_n_f32(0.0f); 98 acc1 = vdupq_n_f32(0.0f); 99 acc2 = vdupq_n_f32(0.0f); 100 acc3 = vdupq_n_f32(0.0f); 101 102 pSrcA0Vec = pInA0; 103 pSrcA1Vec = pInA1; 104 pSrcA2Vec = pInA2; 105 pSrcA3Vec = pInA3; 106 107 blkCnt = numCols >> 2; 108 while (blkCnt > 0U) { 109 f32x4_t vecA; 110 f32x4_t vecDif; 111 112 vecIn = vld1q(pInVec); 113 pInVec += 4; 114 vecA = vld1q(pSrcA0Vec); 115 pSrcA0Vec += 4; 116 vecDif = vsubq(vecIn, vecA); 117 acc0 = vfmaq(acc0, vecDif, vecDif); 118 vecA = vld1q(pSrcA1Vec); 119 pSrcA1Vec += 4; 120 vecDif = vsubq(vecIn, vecA); 121 acc1 = vfmaq(acc1, vecDif, vecDif); 122 vecA = vld1q(pSrcA2Vec); 123 pSrcA2Vec += 4; 124 vecDif = vsubq(vecIn, vecA); 125 acc2 = vfmaq(acc2, vecDif, vecDif); 126 vecA = vld1q(pSrcA3Vec); 127 pSrcA3Vec += 4; 128 vecDif = vsubq(vecIn, vecA); 129 acc3 = vfmaq(acc3, vecDif, vecDif); 130 131 blkCnt--; 132 } 133 /* 134 * tail 135 * (will be merged thru tail predication) 136 */ 137 blkCnt = numCols & 3; 138 if (blkCnt > 0U) { 139 mve_pred16_t p0 = vctp32q(blkCnt); 140 f32x4_t vecA; 141 f32x4_t vecDif; 142 143 vecIn = vldrwq_z_f32(pInVec, p0); 144 vecA = vldrwq_z_f32(pSrcA0Vec, p0); 145 vecDif = vsubq(vecIn, vecA); 146 acc0 = vfmaq(acc0, vecDif, vecDif); 147 vecA = vldrwq_z_f32(pSrcA1Vec, p0); 148 vecDif = vsubq(vecIn, vecA); 149 acc1 = vfmaq(acc1, vecDif, vecDif); 150 vecA = vldrwq_z_f32(pSrcA2Vec, p0);; 151 vecDif = vsubq(vecIn, vecA); 152 acc2 = vfmaq(acc2, vecDif, vecDif); 153 vecA = vldrwq_z_f32(pSrcA3Vec, p0); 154 vecDif = vsubq(vecIn, vecA); 155 acc3 = vfmaq(acc3, vecDif, vecDif); 156 } 157 /* 158 * Sum the partial parts 159 */ 160 161 //sum += *pDualCoef++ * expf(-S->gamma * vecReduceF32Mve(acc0)); 162 f32x4_t vtmp = vuninitializedq_f32(); 163 vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0); 164 vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1); 165 vtmp = vsetq_lane(vecAddAcrossF32Mve(acc2), vtmp, 2); 166 vtmp = vsetq_lane(vecAddAcrossF32Mve(acc3), vtmp, 3); 167 168 vSum = 169 vfmaq_f32(vSum, vld1q(pDualCoef), 170 vexpq_f32(vmulq_n_f32(vtmp, -S->gamma))); 171 pDualCoef += 4; 172 pSrcA += numCols * 4; 173 /* 174 * Decrement the row loop counter 175 */ 176 row -= 4; 177 } 178 179 /* 180 * compute 2 rows in parrallel 181 */ 182 if (row >= 2) { 183 float32_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec; 184 f32x4_t vecIn, acc0, acc1; 185 float32_t const *pSrcVecPtr = in; 186 187 /* 188 * Initialize the pointers to 2 consecutive MatrixA rows 189 */ 190 pInA0 = pSrcA; 191 pInA1 = pInA0 + numCols; 192 /* 193 * Initialize the vector pointer 194 */ 195 pInVec = pSrcVecPtr; 196 /* 197 * reset accumulators 198 */ 199 acc0 = vdupq_n_f32(0.0f); 200 acc1 = vdupq_n_f32(0.0f); 201 pSrcA0Vec = pInA0; 202 pSrcA1Vec = pInA1; 203 204 blkCnt = numCols >> 2; 205 while (blkCnt > 0U) { 206 f32x4_t vecA; 207 f32x4_t vecDif; 208 209 vecIn = vld1q(pInVec); 210 pInVec += 4; 211 vecA = vld1q(pSrcA0Vec); 212 pSrcA0Vec += 4; 213 vecDif = vsubq(vecIn, vecA); 214 acc0 = vfmaq(acc0, vecDif, vecDif);; 215 vecA = vld1q(pSrcA1Vec); 216 pSrcA1Vec += 4; 217 vecDif = vsubq(vecIn, vecA); 218 acc1 = vfmaq(acc1, vecDif, vecDif); 219 220 blkCnt--; 221 } 222 /* 223 * tail 224 * (will be merged thru tail predication) 225 */ 226 blkCnt = numCols & 3; 227 if (blkCnt > 0U) { 228 mve_pred16_t p0 = vctp32q(blkCnt); 229 f32x4_t vecA, vecDif; 230 231 vecIn = vldrwq_z_f32(pInVec, p0); 232 vecA = vldrwq_z_f32(pSrcA0Vec, p0); 233 vecDif = vsubq(vecIn, vecA); 234 acc0 = vfmaq(acc0, vecDif, vecDif); 235 vecA = vldrwq_z_f32(pSrcA1Vec, p0); 236 vecDif = vsubq(vecIn, vecA); 237 acc1 = vfmaq(acc1, vecDif, vecDif); 238 } 239 /* 240 * Sum the partial parts 241 */ 242 f32x4_t vtmp = vuninitializedq_f32(); 243 vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0); 244 vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1); 245 246 vSum = 247 vfmaq_m_f32(vSum, vld1q(pDualCoef), 248 vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)), vctp32q(2)); 249 pDualCoef += 2; 250 251 pSrcA += numCols * 2; 252 row -= 2; 253 } 254 255 if (row >= 1) { 256 f32x4_t vecIn, acc0; 257 float32_t const *pSrcA0Vec, *pInVec; 258 float32_t const *pSrcVecPtr = in; 259 /* 260 * Initialize the pointers to last MatrixA row 261 */ 262 pInA0 = pSrcA; 263 /* 264 * Initialize the vector pointer 265 */ 266 pInVec = pSrcVecPtr; 267 /* 268 * reset accumulators 269 */ 270 acc0 = vdupq_n_f32(0.0f); 271 272 pSrcA0Vec = pInA0; 273 274 blkCnt = numCols >> 2; 275 while (blkCnt > 0U) { 276 f32x4_t vecA, vecDif; 277 278 vecIn = vld1q(pInVec); 279 pInVec += 4; 280 vecA = vld1q(pSrcA0Vec); 281 pSrcA0Vec += 4; 282 vecDif = vsubq(vecIn, vecA); 283 acc0 = vfmaq(acc0, vecDif, vecDif); 284 285 blkCnt--; 286 } 287 /* 288 * tail 289 * (will be merged thru tail predication) 290 */ 291 blkCnt = numCols & 3; 292 if (blkCnt > 0U) { 293 mve_pred16_t p0 = vctp32q(blkCnt); 294 f32x4_t vecA, vecDif; 295 296 vecIn = vldrwq_z_f32(pInVec, p0); 297 vecA = vldrwq_z_f32(pSrcA0Vec, p0); 298 vecDif = vsubq(vecIn, vecA); 299 acc0 = vfmaq(acc0, vecDif, vecDif); 300 } 301 /* 302 * Sum the partial parts 303 */ 304 f32x4_t vtmp = vuninitializedq_f32(); 305 vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0); 306 307 vSum = 308 vfmaq_m_f32(vSum, vld1q(pDualCoef), 309 vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)), vctp32q(1)); 310 311 } 312 313 314 sum += vecAddAcrossF32Mve(vSum); 315 *pResult = S->classes[STEP(sum)]; 316 } 317 318 319 #else 320 #if defined(ARM_MATH_NEON) 321 322 #include "NEMath.h" 323 324 void arm_svm_rbf_predict_f32( 325 const arm_svm_rbf_instance_f32 *S, 326 const float32_t * in, 327 int32_t * pResult) 328 { 329 float32_t sum = S->intercept; 330 331 float32_t dot; 332 float32x4_t dotV; 333 334 float32x4_t accuma,accumb,accumc,accumd,accum; 335 float32x2_t accum2; 336 float32x4_t temp; 337 float32x4_t vec1; 338 339 float32x4_t vec2,vec2a,vec2b,vec2c,vec2d; 340 341 uint32_t blkCnt; 342 uint32_t vectorBlkCnt; 343 344 const float32_t *pIn = in; 345 346 const float32_t *pSupport = S->supportVectors; 347 348 const float32_t *pSupporta = S->supportVectors; 349 const float32_t *pSupportb; 350 const float32_t *pSupportc; 351 const float32_t *pSupportd; 352 353 pSupportb = pSupporta + S->vectorDimension; 354 pSupportc = pSupportb + S->vectorDimension; 355 pSupportd = pSupportc + S->vectorDimension; 356 357 const float32_t *pDualCoefs = S->dualCoefficients; 358 359 360 vectorBlkCnt = S->nbOfSupportVectors >> 2; 361 while (vectorBlkCnt > 0U) 362 { 363 accuma = vdupq_n_f32(0); 364 accumb = vdupq_n_f32(0); 365 accumc = vdupq_n_f32(0); 366 accumd = vdupq_n_f32(0); 367 368 pIn = in; 369 370 blkCnt = S->vectorDimension >> 2; 371 while (blkCnt > 0U) 372 { 373 374 vec1 = vld1q_f32(pIn); 375 vec2a = vld1q_f32(pSupporta); 376 vec2b = vld1q_f32(pSupportb); 377 vec2c = vld1q_f32(pSupportc); 378 vec2d = vld1q_f32(pSupportd); 379 380 pIn += 4; 381 pSupporta += 4; 382 pSupportb += 4; 383 pSupportc += 4; 384 pSupportd += 4; 385 386 temp = vsubq_f32(vec1, vec2a); 387 accuma = vmlaq_f32(accuma, temp, temp); 388 389 temp = vsubq_f32(vec1, vec2b); 390 accumb = vmlaq_f32(accumb, temp, temp); 391 392 temp = vsubq_f32(vec1, vec2c); 393 accumc = vmlaq_f32(accumc, temp, temp); 394 395 temp = vsubq_f32(vec1, vec2d); 396 accumd = vmlaq_f32(accumd, temp, temp); 397 398 blkCnt -- ; 399 } 400 accum2 = vpadd_f32(vget_low_f32(accuma),vget_high_f32(accuma)); 401 dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,0); 402 403 accum2 = vpadd_f32(vget_low_f32(accumb),vget_high_f32(accumb)); 404 dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,1); 405 406 accum2 = vpadd_f32(vget_low_f32(accumc),vget_high_f32(accumc)); 407 dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,2); 408 409 accum2 = vpadd_f32(vget_low_f32(accumd),vget_high_f32(accumd)); 410 dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,3); 411 412 413 blkCnt = S->vectorDimension & 3; 414 while (blkCnt > 0U) 415 { 416 dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,0) + SQ(*pIn - *pSupporta), dotV,0); 417 dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,1) + SQ(*pIn - *pSupportb), dotV,1); 418 dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,2) + SQ(*pIn - *pSupportc), dotV,2); 419 dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,3) + SQ(*pIn - *pSupportd), dotV,3); 420 421 pSupporta++; 422 pSupportb++; 423 pSupportc++; 424 pSupportd++; 425 426 pIn++; 427 428 blkCnt -- ; 429 } 430 431 vec1 = vld1q_f32(pDualCoefs); 432 pDualCoefs += 4; 433 434 // To vectorize later 435 dotV = vmulq_n_f32(dotV, -S->gamma); 436 dotV = vexpq_f32(dotV); 437 438 accum = vmulq_f32(vec1,dotV); 439 accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum)); 440 sum += vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1); 441 442 pSupporta += 3*S->vectorDimension; 443 pSupportb += 3*S->vectorDimension; 444 pSupportc += 3*S->vectorDimension; 445 pSupportd += 3*S->vectorDimension; 446 447 vectorBlkCnt -- ; 448 } 449 450 pSupport = pSupporta; 451 vectorBlkCnt = S->nbOfSupportVectors & 3; 452 453 while (vectorBlkCnt > 0U) 454 { 455 accum = vdupq_n_f32(0); 456 dot = 0.0f; 457 pIn = in; 458 459 blkCnt = S->vectorDimension >> 2; 460 while (blkCnt > 0U) 461 { 462 463 vec1 = vld1q_f32(pIn); 464 vec2 = vld1q_f32(pSupport); 465 pIn += 4; 466 pSupport += 4; 467 468 temp = vsubq_f32(vec1,vec2); 469 accum = vmlaq_f32(accum, temp,temp); 470 471 blkCnt -- ; 472 } 473 accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum)); 474 dot = vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1); 475 476 477 blkCnt = S->vectorDimension & 3; 478 while (blkCnt > 0U) 479 { 480 481 dot = dot + SQ(*pIn - *pSupport); 482 pIn++; 483 pSupport++; 484 485 blkCnt -- ; 486 } 487 488 sum += *pDualCoefs++ * expf(-S->gamma * dot); 489 vectorBlkCnt -- ; 490 } 491 492 *pResult=S->classes[STEP(sum)]; 493 } 494 #else 495 void arm_svm_rbf_predict_f32( 496 const arm_svm_rbf_instance_f32 *S, 497 const float32_t * in, 498 int32_t * pResult) 499 { 500 float32_t sum=S->intercept; 501 float32_t dot=0; 502 uint32_t i,j; 503 const float32_t *pSupport = S->supportVectors; 504 505 for(i=0; i < S->nbOfSupportVectors; i++) 506 { 507 dot=0; 508 for(j=0; j < S->vectorDimension; j++) 509 { 510 dot = dot + SQ(in[j] - *pSupport); 511 pSupport++; 512 } 513 sum += S->dualCoefficients[i] * expf(-S->gamma * dot); 514 } 515 *pResult=S->classes[STEP(sum)]; 516 } 517 #endif 518 519 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */ 520 521 /** 522 * @} end of rbfsvm group 523 */