arm_svm_sigmoid_predict_f32.c
/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_svm_sigmoid_predict_f32.c
 * Description:  SVM Sigmoid Classifier
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/svm_functions.h"
#include <limits.h>
#include <math.h>

/**
 * @addtogroup sigmoidsvm
 * @{
 */



/**
 * @brief SVM sigmoid prediction
 * @param[in]    S          Pointer to an instance of the sigmoid SVM structure.
 * @param[in]    in         Pointer to input vector
 * @param[out]   pResult    Decision value
 * @return none.
 *
 */

#if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)

#include "arm_helium_utils.h"
#include "arm_vec_math.h"

/*
 * Helium/MVE implementation.
 *
 * The decision value is
 *   sum = intercept + sum_i dualCoefficients[i]
 *                     * tanh(gamma * dot(supportVector[i], in) + coef0)
 * The supportVectors-matrix x input-vector product is computed 4 rows at a
 * time (then 2, then the last one) so that four dot products can be fed at
 * once into the vectorized tanh.
 */
void arm_svm_sigmoid_predict_f32(
    const arm_svm_sigmoid_instance_f32 *S,
    const float32_t * in,
    int32_t * pResult)
{
    /* inlined Matrix x Vector function interleaved with dot prod */
    uint32_t numRows = S->nbOfSupportVectors;
    uint32_t numCols = S->vectorDimension;
    const float32_t *pSupport = S->supportVectors;
    const float32_t *pSrcA = pSupport;      /* start of the current group of support-vector rows */
    const float32_t *pInA0;
    const float32_t *pInA1;
    uint32_t row;
    uint32_t blkCnt;                        /* loop counters */
    const float32_t *pDualCoef = S->dualCoefficients;
    float32_t sum = S->intercept;
    f32x4_t vSum = vdupq_n_f32(0.0f);       /* vector accumulator of dualCoef * tanh(...) terms */

    row = numRows;

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4) {
        const float32_t *pInA2, *pInA3;
        float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
        f32x4_t vecIn, acc0, acc1, acc2, acc3;
        float32_t const *pSrcVecPtr = in;

        /*
         * Initialize the pointers to 4 consecutive MatrixA rows
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f32(0.0f);
        acc1 = vdupq_n_f32(0.0f);
        acc2 = vdupq_n_f32(0.0f);
        acc3 = vdupq_n_f32(0.0f);

        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;
        pSrcA2Vec = pInA2;
        pSrcA3Vec = pInA3;

        /* main loop: 4 float32 lanes per iteration, one shared input load
           feeding the 4 row accumulators */
        blkCnt = numCols >> 2;
        while (blkCnt > 0U) {
            f32x4_t vecA;

            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 4;
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 4;
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 4;
            acc3 = vfmaq(acc3, vecIn, vecA);

            blkCnt--;
        }
        /*
         * tail
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 3;
        if (blkCnt > 0U) {
            /* predicate masks the lanes beyond the remaining columns;
               vldrwq_z_f32 zeroes inactive lanes so the fma is unaffected */
            mve_pred16_t p0 = vctp32q(blkCnt);
            f32x4_t vecA;

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA2Vec, p0);
            acc2 = vfmaq(acc2, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA3Vec, p0);
            acc3 = vfmaq(acc3, vecIn, vecA);
        }
        /*
         * Sum the partial parts: lane i of vtmp = dot(supportVector[i], in)
         */
        f32x4_t vtmp = vuninitializedq_f32();
        vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
        vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1);
        vtmp = vsetq_lane(vecAddAcrossF32Mve(acc2), vtmp, 2);
        vtmp = vsetq_lane(vecAddAcrossF32Mve(acc3), vtmp, 3);

        /* vSum += dualCoef[0..3] * tanh(gamma * dot + coef0), 4 kernels at once */
        vSum =
            vfmaq_f32(vSum, vld1q(pDualCoef),
                      vtanhq_f32(vaddq_n_f32(vmulq_n_f32(vtmp, S->gamma), S->coef0)));

        pDualCoef += 4;

        pSrcA += numCols * 4;
        /*
         * Decrement the row loop counter
         */
        row -= 4;
    }

    /*
     * compute 2 rows in parallel
     */
    if (row >= 2) {
        float32_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
        f32x4_t vecIn, acc0, acc1;
        float32_t const *pSrcVecPtr = in;

        /*
         * Initialize the pointers to 2 consecutive MatrixA rows
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f32(0.0f);
        acc1 = vdupq_n_f32(0.0f);
        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;

        blkCnt = numCols >> 2;
        while (blkCnt > 0U) {
            f32x4_t vecA;

            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 4;
            acc1 = vfmaq(acc1, vecIn, vecA);

            blkCnt--;
        }
        /*
         * tail
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 3;
        if (blkCnt > 0U) {
            mve_pred16_t p0 = vctp32q(blkCnt);
            f32x4_t vecA;

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
            vecA = vldrwq_z_f32(pSrcA1Vec, p0);
            acc1 = vfmaq(acc1, vecIn, vecA);
        }
        /*
         * Sum the partial parts
         */
        f32x4_t vtmp = vuninitializedq_f32();
        vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
        vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1);

        /* predicated fma: only lanes 0..1 carry valid dot products here */
        vSum =
            vfmaq_m_f32(vSum, vld1q(pDualCoef),
                        vtanhq_f32(vaddq_n_f32(vmulq_n_f32(vtmp, S->gamma), S->coef0)),
                        vctp32q(2));

        pSrcA += numCols * 2;
        row -= 2;
    }

    if (row >= 1) {
        f32x4_t vecIn, acc0;
        float32_t const *pSrcA0Vec, *pInVec;
        float32_t const *pSrcVecPtr = in;
        /*
         * Initialize the pointers to last MatrixA row
         */
        pInA0 = pSrcA;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f32(0.0f);

        pSrcA0Vec = pInA0;

        blkCnt = numCols >> 2;
        while (blkCnt > 0U) {
            f32x4_t vecA;

            vecIn = vld1q(pInVec);
            pInVec += 4;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 4;
            acc0 = vfmaq(acc0, vecIn, vecA);

            blkCnt--;
        }
        /*
         * tail
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 3;
        if (blkCnt > 0U) {
            mve_pred16_t p0 = vctp32q(blkCnt);
            f32x4_t vecA;

            vecIn = vldrwq_z_f32(pInVec, p0);
            vecA = vldrwq_z_f32(pSrcA0Vec, p0);
            acc0 = vfmaq(acc0, vecIn, vecA);
        }
        /*
         * Sum the partial parts
         */
        f32x4_t vtmp = vuninitializedq_f32();
        vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);

        /* predicated fma: only lane 0 carries a valid dot product here */
        vSum =
            vfmaq_m_f32(vSum, vld1q(pDualCoef),
                        vtanhq_f32(vaddq_n_f32(vmulq_n_f32(vtmp, S->gamma), S->coef0)),
                        vctp32q(1));
    }
    /* horizontal add of the vector accumulator, then threshold */
    sum += vecAddAcrossF32Mve(vSum);

    /* STEP() maps the sign of the decision value to a 0/1 class index
       (defined in dsp/svm_functions.h) */
    *pResult = S->classes[STEP(sum)];
}

#else
#if defined(ARM_MATH_NEON)
#include "NEMath.h"

/*
 * Neon implementation: same kernel as the scalar version below, computing
 * 4 support-vector dot products per outer iteration so the gamma/coef0/tanh
 * stage is vectorized, then a scalar-ish loop for the remaining rows.
 */
void arm_svm_sigmoid_predict_f32(
    const arm_svm_sigmoid_instance_f32 *S,
    const float32_t * in,
    int32_t * pResult)
{
    float32_t sum = S->intercept;

    float32_t dot;
    float32x4_t dotV;               /* lane i = dot(supportVector[group+i], in) */

    float32x4_t accuma,accumb,accumc,accumd,accum;
    float32x2_t accum2;
    float32x4_t vec1;
    float32x4_t coef0 = vdupq_n_f32(S->coef0);

    float32x4_t vec2,vec2a,vec2b,vec2c,vec2d;

    uint32_t blkCnt;
    uint32_t vectorBlkCnt;

    const float32_t *pIn = in;

    const float32_t *pSupport = S->supportVectors;

    /* 4 row pointers walking consecutive support vectors */
    const float32_t *pSupporta = S->supportVectors;
    const float32_t *pSupportb;
    const float32_t *pSupportc;
    const float32_t *pSupportd;

    pSupportb = pSupporta + S->vectorDimension;
    pSupportc = pSupportb + S->vectorDimension;
    pSupportd = pSupportc + S->vectorDimension;

    const float32_t *pDualCoefs = S->dualCoefficients;

    /* groups of 4 support vectors */
    vectorBlkCnt = S->nbOfSupportVectors >> 2;
    while (vectorBlkCnt > 0U)
    {
        accuma = vdupq_n_f32(0);
        accumb = vdupq_n_f32(0);
        accumc = vdupq_n_f32(0);
        accumd = vdupq_n_f32(0);

        pIn = in;

        /* 4 input samples per iteration, shared across the 4 rows */
        blkCnt = S->vectorDimension >> 2;
        while (blkCnt > 0U)
        {

            vec1 = vld1q_f32(pIn);
            vec2a = vld1q_f32(pSupporta);
            vec2b = vld1q_f32(pSupportb);
            vec2c = vld1q_f32(pSupportc);
            vec2d = vld1q_f32(pSupportd);

            pIn += 4;
            pSupporta += 4;
            pSupportb += 4;
            pSupportc += 4;
            pSupportd += 4;

            accuma = vmlaq_f32(accuma, vec1,vec2a);
            accumb = vmlaq_f32(accumb, vec1,vec2b);
            accumc = vmlaq_f32(accumc, vec1,vec2c);
            accumd = vmlaq_f32(accumd, vec1,vec2d);

            blkCnt -- ;
        }
        /* horizontal add of each accumulator into one lane of dotV */
        accum2 = vpadd_f32(vget_low_f32(accuma),vget_high_f32(accuma));
        dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,0);

        accum2 = vpadd_f32(vget_low_f32(accumb),vget_high_f32(accumb));
        dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,1);

        accum2 = vpadd_f32(vget_low_f32(accumc),vget_high_f32(accumc));
        dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,2);

        accum2 = vpadd_f32(vget_low_f32(accumd),vget_high_f32(accumd));
        dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,3);


        /* scalar tail over the remaining (vectorDimension & 3) columns */
        blkCnt = S->vectorDimension & 3;
        while (blkCnt > 0U)
        {
            dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,0) + *pIn * *pSupporta++, dotV,0);
            dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,1) + *pIn * *pSupportb++, dotV,1);
            dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,2) + *pIn * *pSupportc++, dotV,2);
            dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,3) + *pIn * *pSupportd++, dotV,3);

            pIn++;

            blkCnt -- ;
        }

        vec1 = vld1q_f32(pDualCoefs);
        pDualCoefs += 4;

        // To vectorize later
        /* sigmoid kernel: tanh(gamma * dot + coef0), 4 lanes at once */
        dotV = vmulq_n_f32(dotV, S->gamma);
        dotV = vaddq_f32(dotV, coef0);

        dotV = vtanhq_f32(dotV);

        /* sum += dot(dualCoefs[0..3], tanh lanes) */
        accum = vmulq_f32(vec1,dotV);
        accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum));
        sum += vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1);

        /* each row pointer already advanced one row; skip the other 3 rows
           of this group to reach the next group of 4 */
        pSupporta += 3*S->vectorDimension;
        pSupportb += 3*S->vectorDimension;
        pSupportc += 3*S->vectorDimension;
        pSupportd += 3*S->vectorDimension;

        vectorBlkCnt -- ;
    }

    /* leftover support vectors (nbOfSupportVectors & 3), one at a time */
    pSupport = pSupporta;
    vectorBlkCnt = S->nbOfSupportVectors & 3;

    while (vectorBlkCnt > 0U)
    {
        accum = vdupq_n_f32(0);
        dot = 0.0f;
        pIn = in;

        blkCnt = S->vectorDimension >> 2;
        while (blkCnt > 0U)
        {

            vec1 = vld1q_f32(pIn);
            vec2 = vld1q_f32(pSupport);
            pIn += 4;
            pSupport += 4;

            accum = vmlaq_f32(accum, vec1,vec2);

            blkCnt -- ;
        }
        accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum));
        dot = vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1);


        /* scalar tail of the dot product */
        blkCnt = S->vectorDimension & 3;
        while (blkCnt > 0U)
        {
            dot = dot + *pIn++ * *pSupport++;

            blkCnt -- ;
        }

        sum += *pDualCoefs++ * tanhf(S->gamma * dot + S->coef0);
        vectorBlkCnt -- ;
    }

    /* STEP() maps the sign of the decision value to a 0/1 class index */
    *pResult=S->classes[STEP(sum)];
}
#else
/*
 * Portable scalar implementation:
 *   sum = intercept + sum_i dualCoefficients[i]
 *                     * tanh(gamma * dot(supportVector[i], in) + coef0)
 */
void arm_svm_sigmoid_predict_f32(
    const arm_svm_sigmoid_instance_f32 *S,
    const float32_t * in,
    int32_t * pResult)
{
    float32_t sum=S->intercept;
    float32_t dot=0;
    uint32_t i,j;
    const float32_t *pSupport = S->supportVectors;

    for(i=0; i < S->nbOfSupportVectors; i++)
    {
        dot=0;
        for(j=0; j < S->vectorDimension; j++)
        {
            dot = dot + in[j]* *pSupport++;
        }
        sum += S->dualCoefficients[i] * tanhf(S->gamma * dot + S->coef0);
    }
    /* STEP() maps the sign of the decision value to a 0/1 class index */
    *pResult=S->classes[STEP(sum)];
}

#endif
#endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */

/**
 * @} end of sigmoidsvm group
 */