/ Drivers / CMSIS / DSP / Source / SVMFunctions / arm_svm_rbf_predict_f32.c
arm_svm_rbf_predict_f32.c
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_svm_rbf_predict_f32.c
  4   * Description:  SVM Radial Basis Function Classifier
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/svm_functions.h"
 30  #include <limits.h>
 31  #include <math.h>
 32  
 33  
 34  /**
 35   * @addtogroup rbfsvm
 36   * @{
 37   */
 38  
 39  
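/**
 * @brief SVM rbf prediction
 * @param[in]    S         Pointer to an instance of the rbf SVM structure.
 * @param[in]    in        Pointer to input vector (S->vectorDimension values).
 * @param[out]   pResult   Predicted class: the entry of S->classes selected by
 *                         STEP() applied to the decision value
 *                         intercept + sum_i dualCoefficients[i] * exp(-gamma * ||in - sv_i||^2).
 * @return none.
 *
 */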
 48  
 49  #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
 50  
 51  #include "arm_helium_utils.h"
 52  #include "arm_vec_math.h"
 53  
 54  void arm_svm_rbf_predict_f32(
 55      const arm_svm_rbf_instance_f32 *S,
 56      const float32_t * in,
 57      int32_t * pResult)
 58  {
 59          /* inlined Matrix x Vector function interleaved with dot prod */
 60      uint32_t        numRows = S->nbOfSupportVectors;
 61      uint32_t        numCols = S->vectorDimension;
 62      const float32_t *pSupport = S->supportVectors;
 63      const float32_t *pSrcA = pSupport;
 64      const float32_t *pInA0;
 65      const float32_t *pInA1;
 66      uint32_t         row;
 67      uint32_t         blkCnt;     /* loop counters */
 68      const float32_t *pDualCoef = S->dualCoefficients;
 69      float32_t       sum = S->intercept;
 70      f32x4_t         vSum = vdupq_n_f32(0);
 71  
 72      row = numRows;
 73  
 74      /*
 75       * compute 4 rows in parrallel
 76       */
 77      while (row >= 4) {
 78          const float32_t *pInA2, *pInA3;
 79          float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
 80          f32x4_t         vecIn, acc0, acc1, acc2, acc3;
 81          float32_t const *pSrcVecPtr = in;
 82  
 83          /*
 84           * Initialize the pointers to 4 consecutive MatrixA rows
 85           */
 86          pInA0 = pSrcA;
 87          pInA1 = pInA0 + numCols;
 88          pInA2 = pInA1 + numCols;
 89          pInA3 = pInA2 + numCols;
 90          /*
 91           * Initialize the vector pointer
 92           */
 93          pInVec = pSrcVecPtr;
 94          /*
 95           * reset accumulators
 96           */
 97          acc0 = vdupq_n_f32(0.0f);
 98          acc1 = vdupq_n_f32(0.0f);
 99          acc2 = vdupq_n_f32(0.0f);
100          acc3 = vdupq_n_f32(0.0f);
101  
102          pSrcA0Vec = pInA0;
103          pSrcA1Vec = pInA1;
104          pSrcA2Vec = pInA2;
105          pSrcA3Vec = pInA3;
106  
107          blkCnt = numCols >> 2;
108          while (blkCnt > 0U) {
109              f32x4_t         vecA;
110              f32x4_t         vecDif;
111  
112              vecIn = vld1q(pInVec);
113              pInVec += 4;
114              vecA = vld1q(pSrcA0Vec);
115              pSrcA0Vec += 4;
116              vecDif = vsubq(vecIn, vecA);
117              acc0 = vfmaq(acc0, vecDif, vecDif);
118              vecA = vld1q(pSrcA1Vec);
119              pSrcA1Vec += 4;
120              vecDif = vsubq(vecIn, vecA);
121              acc1 = vfmaq(acc1, vecDif, vecDif);
122              vecA = vld1q(pSrcA2Vec);
123              pSrcA2Vec += 4;
124              vecDif = vsubq(vecIn, vecA);
125              acc2 = vfmaq(acc2, vecDif, vecDif);
126              vecA = vld1q(pSrcA3Vec);
127              pSrcA3Vec += 4;
128              vecDif = vsubq(vecIn, vecA);
129              acc3 = vfmaq(acc3, vecDif, vecDif);
130  
131              blkCnt--;
132          }
133          /*
134           * tail
135           * (will be merged thru tail predication)
136           */
137          blkCnt = numCols & 3;
138          if (blkCnt > 0U) {
139              mve_pred16_t    p0 = vctp32q(blkCnt);
140              f32x4_t         vecA;
141              f32x4_t         vecDif;
142  
143              vecIn = vldrwq_z_f32(pInVec, p0);
144              vecA = vldrwq_z_f32(pSrcA0Vec, p0);
145              vecDif = vsubq(vecIn, vecA);
146              acc0 = vfmaq(acc0, vecDif, vecDif);
147              vecA = vldrwq_z_f32(pSrcA1Vec, p0);
148              vecDif = vsubq(vecIn, vecA);
149              acc1 = vfmaq(acc1, vecDif, vecDif);
150              vecA = vldrwq_z_f32(pSrcA2Vec, p0);;
151              vecDif = vsubq(vecIn, vecA);
152              acc2 = vfmaq(acc2, vecDif, vecDif);
153              vecA = vldrwq_z_f32(pSrcA3Vec, p0);
154              vecDif = vsubq(vecIn, vecA);
155              acc3 = vfmaq(acc3, vecDif, vecDif);
156          }
157          /*
158           * Sum the partial parts
159           */
160  
161          //sum += *pDualCoef++ * expf(-S->gamma * vecReduceF32Mve(acc0));
162          f32x4_t         vtmp = vuninitializedq_f32();
163          vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
164          vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1);
165          vtmp = vsetq_lane(vecAddAcrossF32Mve(acc2), vtmp, 2);
166          vtmp = vsetq_lane(vecAddAcrossF32Mve(acc3), vtmp, 3);
167  
168          vSum =
169              vfmaq_f32(vSum, vld1q(pDualCoef),
170                        vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)));
171          pDualCoef += 4;
172          pSrcA += numCols * 4;
173          /*
174           * Decrement the row loop counter
175           */
176          row -= 4;
177      }
178  
179      /*
180       * compute 2 rows in parrallel
181       */
182      if (row >= 2) {
183          float32_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
184          f32x4_t         vecIn, acc0, acc1;
185          float32_t const *pSrcVecPtr = in;
186  
187          /*
188           * Initialize the pointers to 2 consecutive MatrixA rows
189           */
190          pInA0 = pSrcA;
191          pInA1 = pInA0 + numCols;
192          /*
193           * Initialize the vector pointer
194           */
195          pInVec = pSrcVecPtr;
196          /*
197           * reset accumulators
198           */
199          acc0 = vdupq_n_f32(0.0f);
200          acc1 = vdupq_n_f32(0.0f);
201          pSrcA0Vec = pInA0;
202          pSrcA1Vec = pInA1;
203  
204          blkCnt = numCols >> 2;
205          while (blkCnt > 0U) {
206              f32x4_t         vecA;
207              f32x4_t         vecDif;
208  
209              vecIn = vld1q(pInVec);
210              pInVec += 4;
211              vecA = vld1q(pSrcA0Vec);
212              pSrcA0Vec += 4;
213              vecDif = vsubq(vecIn, vecA);
214              acc0 = vfmaq(acc0, vecDif, vecDif);;
215              vecA = vld1q(pSrcA1Vec);
216              pSrcA1Vec += 4;
217              vecDif = vsubq(vecIn, vecA);
218              acc1 = vfmaq(acc1, vecDif, vecDif);
219  
220              blkCnt--;
221          }
222          /*
223           * tail
224           * (will be merged thru tail predication)
225           */
226          blkCnt = numCols & 3;
227          if (blkCnt > 0U) {
228              mve_pred16_t    p0 = vctp32q(blkCnt);
229              f32x4_t         vecA, vecDif;
230  
231              vecIn = vldrwq_z_f32(pInVec, p0);
232              vecA = vldrwq_z_f32(pSrcA0Vec, p0);
233              vecDif = vsubq(vecIn, vecA);
234              acc0 = vfmaq(acc0, vecDif, vecDif);
235              vecA = vldrwq_z_f32(pSrcA1Vec, p0);
236              vecDif = vsubq(vecIn, vecA);
237              acc1 = vfmaq(acc1, vecDif, vecDif);
238          }
239          /*
240           * Sum the partial parts
241           */
242          f32x4_t         vtmp = vuninitializedq_f32();
243          vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
244          vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1);
245  
246          vSum =
247              vfmaq_m_f32(vSum, vld1q(pDualCoef),
248                          vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)), vctp32q(2));
249          pDualCoef += 2;
250  
251          pSrcA += numCols * 2;
252          row -= 2;
253      }
254  
255      if (row >= 1) {
256          f32x4_t         vecIn, acc0;
257          float32_t const *pSrcA0Vec, *pInVec;
258          float32_t const *pSrcVecPtr = in;
259          /*
260           * Initialize the pointers to last MatrixA row
261           */
262          pInA0 = pSrcA;
263          /*
264           * Initialize the vector pointer
265           */
266          pInVec = pSrcVecPtr;
267          /*
268           * reset accumulators
269           */
270          acc0 = vdupq_n_f32(0.0f);
271  
272          pSrcA0Vec = pInA0;
273  
274          blkCnt = numCols >> 2;
275          while (blkCnt > 0U) {
276              f32x4_t         vecA, vecDif;
277  
278              vecIn = vld1q(pInVec);
279              pInVec += 4;
280              vecA = vld1q(pSrcA0Vec);
281              pSrcA0Vec += 4;
282              vecDif = vsubq(vecIn, vecA);
283              acc0 = vfmaq(acc0, vecDif, vecDif);
284  
285              blkCnt--;
286          }
287          /*
288           * tail
289           * (will be merged thru tail predication)
290           */
291          blkCnt = numCols & 3;
292          if (blkCnt > 0U) {
293              mve_pred16_t    p0 = vctp32q(blkCnt);
294              f32x4_t         vecA, vecDif;
295  
296              vecIn = vldrwq_z_f32(pInVec, p0);
297              vecA = vldrwq_z_f32(pSrcA0Vec, p0);
298              vecDif = vsubq(vecIn, vecA);
299              acc0 = vfmaq(acc0, vecDif, vecDif);
300          }
301          /*
302           * Sum the partial parts
303           */
304          f32x4_t         vtmp = vuninitializedq_f32();
305          vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
306  
307          vSum =
308              vfmaq_m_f32(vSum, vld1q(pDualCoef),
309                          vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)), vctp32q(1));
310  
311      }
312  
313  
314      sum += vecAddAcrossF32Mve(vSum);
315      *pResult = S->classes[STEP(sum)];
316  }
317  
318  
319  #else
320  #if defined(ARM_MATH_NEON)
321  
322  #include "NEMath.h"
323  
324  void arm_svm_rbf_predict_f32(
325      const arm_svm_rbf_instance_f32 *S,
326      const float32_t * in,
327      int32_t * pResult)
328  {
329      float32_t sum = S->intercept;
330     
331      float32_t dot;
332      float32x4_t dotV; 
333  
334      float32x4_t accuma,accumb,accumc,accumd,accum;
335      float32x2_t accum2;
336      float32x4_t temp;
337      float32x4_t vec1;
338  
339      float32x4_t vec2,vec2a,vec2b,vec2c,vec2d;
340  
341      uint32_t blkCnt;   
342      uint32_t vectorBlkCnt;   
343  
344      const float32_t *pIn = in;
345  
346      const float32_t *pSupport = S->supportVectors;
347  
348      const float32_t *pSupporta = S->supportVectors;
349      const float32_t *pSupportb;
350      const float32_t *pSupportc;
351      const float32_t *pSupportd;
352  
353      pSupportb = pSupporta + S->vectorDimension;
354      pSupportc = pSupportb + S->vectorDimension;
355      pSupportd = pSupportc + S->vectorDimension;
356  
357      const float32_t *pDualCoefs = S->dualCoefficients;
358  
359  
360      vectorBlkCnt = S->nbOfSupportVectors >> 2;
361      while (vectorBlkCnt > 0U)
362      {
363          accuma = vdupq_n_f32(0);
364          accumb = vdupq_n_f32(0);
365          accumc = vdupq_n_f32(0);
366          accumd = vdupq_n_f32(0);
367  
368          pIn = in;
369  
370          blkCnt = S->vectorDimension >> 2;
371          while (blkCnt > 0U)
372          {
373          
374              vec1 = vld1q_f32(pIn);
375              vec2a = vld1q_f32(pSupporta);
376              vec2b = vld1q_f32(pSupportb);
377              vec2c = vld1q_f32(pSupportc);
378              vec2d = vld1q_f32(pSupportd);
379  
380              pIn += 4;
381              pSupporta += 4;
382              pSupportb += 4;
383              pSupportc += 4;
384              pSupportd += 4;
385  
386              temp = vsubq_f32(vec1, vec2a);
387              accuma = vmlaq_f32(accuma, temp, temp);
388  
389              temp = vsubq_f32(vec1, vec2b);
390              accumb = vmlaq_f32(accumb, temp, temp);
391  
392              temp = vsubq_f32(vec1, vec2c);
393              accumc = vmlaq_f32(accumc, temp, temp);
394  
395              temp = vsubq_f32(vec1, vec2d);
396              accumd = vmlaq_f32(accumd, temp, temp);
397  
398              blkCnt -- ;
399          }
400          accum2 = vpadd_f32(vget_low_f32(accuma),vget_high_f32(accuma));
401          dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,0);
402  
403          accum2 = vpadd_f32(vget_low_f32(accumb),vget_high_f32(accumb));
404          dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,1);
405  
406          accum2 = vpadd_f32(vget_low_f32(accumc),vget_high_f32(accumc));
407          dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,2);
408  
409          accum2 = vpadd_f32(vget_low_f32(accumd),vget_high_f32(accumd));
410          dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,3);
411  
412  
413          blkCnt = S->vectorDimension & 3;
414          while (blkCnt > 0U)
415          {
416              dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,0) + SQ(*pIn - *pSupporta), dotV,0);
417              dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,1) + SQ(*pIn - *pSupportb), dotV,1);
418              dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,2) + SQ(*pIn - *pSupportc), dotV,2);
419              dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,3) + SQ(*pIn - *pSupportd), dotV,3);
420  
421              pSupporta++;
422              pSupportb++;
423              pSupportc++;
424              pSupportd++;
425  
426              pIn++;
427  
428              blkCnt -- ;
429          }
430  
431          vec1 = vld1q_f32(pDualCoefs);
432          pDualCoefs += 4; 
433  
434          // To vectorize later
435          dotV = vmulq_n_f32(dotV, -S->gamma);
436          dotV = vexpq_f32(dotV);
437  
438          accum = vmulq_f32(vec1,dotV);
439          accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum));
440          sum += vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1);
441  
442          pSupporta += 3*S->vectorDimension;
443          pSupportb += 3*S->vectorDimension;
444          pSupportc += 3*S->vectorDimension;
445          pSupportd += 3*S->vectorDimension;
446  
447          vectorBlkCnt -- ;
448      }
449  
450      pSupport = pSupporta;
451      vectorBlkCnt = S->nbOfSupportVectors & 3;
452  
453      while (vectorBlkCnt > 0U)
454      {
455          accum = vdupq_n_f32(0);
456          dot = 0.0f;
457          pIn = in;
458  
459          blkCnt = S->vectorDimension >> 2;
460          while (blkCnt > 0U)
461          {
462          
463              vec1 = vld1q_f32(pIn);
464              vec2 = vld1q_f32(pSupport);
465              pIn += 4;
466              pSupport += 4;
467  
468              temp = vsubq_f32(vec1,vec2);
469              accum = vmlaq_f32(accum, temp,temp);
470  
471              blkCnt -- ;
472          }
473          accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum));
474          dot = vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1);
475  
476  
477          blkCnt = S->vectorDimension & 3;
478          while (blkCnt > 0U)
479          {
480  
481              dot = dot + SQ(*pIn - *pSupport);
482              pIn++;
483              pSupport++;
484  
485              blkCnt -- ;
486          }
487  
488          sum += *pDualCoefs++ * expf(-S->gamma * dot);
489          vectorBlkCnt -- ;
490      }
491  
492      *pResult=S->classes[STEP(sum)];
493  }
494  #else
495  void arm_svm_rbf_predict_f32(
496      const arm_svm_rbf_instance_f32 *S,
497      const float32_t * in,
498      int32_t * pResult)
499  {
500      float32_t sum=S->intercept;
501      float32_t dot=0;
502      uint32_t i,j;
503      const float32_t *pSupport = S->supportVectors;
504  
505      for(i=0; i < S->nbOfSupportVectors; i++)
506      {
507          dot=0;
508          for(j=0; j < S->vectorDimension; j++)
509          {
510              dot = dot + SQ(in[j] - *pSupport);
511              pSupport++;
512          }
513          sum += S->dualCoefficients[i] * expf(-S->gamma * dot);
514      }
515      *pResult=S->classes[STEP(sum)];
516  }
517  #endif
518  
519  #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
520  
521  /**
522   * @} end of rbfsvm group
523   */