/* Drivers/CMSIS/DSP/Source/SVMFunctions/arm_svm_rbf_predict_f16.c */
  1  /* ----------------------------------------------------------------------
  2   * Project:      CMSIS DSP Library
  3   * Title:        arm_svm_rbf_predict_f16.c
  4   * Description:  SVM Radial Basis Function Classifier
  5   *
  6   * $Date:        23 April 2021
  7   * $Revision:    V1.9.0
  8   *
  9   * Target Processor: Cortex-M and Cortex-A cores
 10   * -------------------------------------------------------------------- */
 11  /*
 12   * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 13   *
 14   * SPDX-License-Identifier: Apache-2.0
 15   *
 16   * Licensed under the Apache License, Version 2.0 (the License); you may
 17   * not use this file except in compliance with the License.
 18   * You may obtain a copy of the License at
 19   *
 20   * www.apache.org/licenses/LICENSE-2.0
 21   *
 22   * Unless required by applicable law or agreed to in writing, software
 23   * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 24   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 25   * See the License for the specific language governing permissions and
 26   * limitations under the License.
 27   */
 28  
 29  #include "dsp/svm_functions_f16.h"
 30  
 31  #if defined(ARM_FLOAT16_SUPPORTED)
 32  
 33  #include <limits.h>
 34  #include <math.h>
 35  
 36  
 37  /**
 38   * @addtogroup rbfsvm
 39   * @{
 40   */
 41  
 42  
 43  /**
 44   * @brief SVM rbf prediction
 45   * @param[in]    S         Pointer to an instance of the rbf SVM structure.
 46   * @param[in]    in        Pointer to input vector
 47   * @param[out]   pResult   decision value
 48   * @return none.
 49   *
 50   */
 51  
 52  #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
 53  
 54  #include "arm_helium_utils.h"
 55  #include "arm_vec_math_f16.h"
 56  
void arm_svm_rbf_predict_f16(
    const arm_svm_rbf_instance_f16 *S,
    const float16_t * in,
    int32_t * pResult)
{
    /* Computes the RBF SVM decision value
     *     sum = intercept + sum_i dualCoef[i] * exp(-gamma * ||in - sv_i||^2)
     * over all support vectors sv_i, then writes the predicted class
     * S->classes[STEP(sum)] to *pResult (STEP maps sign to class index).
     *
     * The squared distances to 4 support vectors at a time (then 2, then 1)
     * are computed with MVE vector instructions, interleaved like a
     * matrix-times-vector kernel; the matching dualCoef * exp(...) terms
     * are folded into vSum with a predicated FMA so only valid lanes
     * contribute.
     */
        /* inlined Matrix x Vector function interleaved with dot prod */
    uint32_t        numRows = S->nbOfSupportVectors;
    uint32_t        numCols = S->vectorDimension;
    const float16_t *pSupport = S->supportVectors;    /* support vectors, row-major: numRows x numCols */
    const float16_t *pSrcA = pSupport;                /* start of the current group of rows */
    const float16_t *pInA0;
    const float16_t *pInA1;
    uint32_t         row;
    uint32_t         blkCnt;     /* loop counters */
    const float16_t *pDualCoef = S->dualCoefficients;
    _Float16       sum = S->intercept;
    f16x8_t         vSum = vdupq_n_f16(0);            /* per-lane accumulation of dualCoef*exp(...) terms */

    row = numRows;

    /*
     * compute 4 rows in parallel
     */
    while (row >= 4) {
        const float16_t *pInA2, *pInA3;
        float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
        f16x8_t         vecIn, acc0, acc1, acc2, acc3;
        float16_t const *pSrcVecPtr = in;

        /*
         * Initialize the pointers to 4 consecutive MatrixA rows
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators (one squared-distance accumulator per row)
         */
        acc0 = vdupq_n_f16(0.0f);
        acc1 = vdupq_n_f16(0.0f);
        acc2 = vdupq_n_f16(0.0f);
        acc3 = vdupq_n_f16(0.0f);

        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;
        pSrcA2Vec = pInA2;
        pSrcA3Vec = pInA3;

        /* Main loop: 8 f16 elements per iteration; acc_k += (in - sv_k)^2 lane-wise */
        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            f16x8_t         vecA;
            f16x8_t         vecDif;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);
            vecA = vld1q(pSrcA2Vec);
            pSrcA2Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc2 = vfmaq(acc2, vecDif, vecDif);
            vecA = vld1q(pSrcA3Vec);
            pSrcA3Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc3 = vfmaq(acc3, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail: remaining numCols % 8 elements, handled with zeroing
         * predicated loads so inactive lanes contribute 0 - 0 = 0
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t    p0 = vctp16q(blkCnt);
            f16x8_t         vecA;
            f16x8_t         vecDif;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA2Vec, p0);;
            vecDif = vsubq(vecIn, vecA);
            acc2 = vfmaq(acc2, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA3Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc3 = vfmaq(acc3, vecDif, vecDif);
        }
        /*
         * Sum the partial parts: reduce each row's accumulator to a scalar
         * squared distance, gathered into lanes 0..3 of vtmp.
         */

        //sum += *pDualCoef++ * expf(-S->gamma * vecReduceF16Mve(acc0));
        f16x8_t         vtmp = vuninitializedq_f16();
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc0), vtmp, 0);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc1), vtmp, 1);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc2), vtmp, 2);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc3), vtmp, 3);

        /* vSum[0..3] += dualCoef[0..3] * exp(-gamma * dist2[0..3]);
         * lanes 4..7 of vtmp are uninitialized but masked off by vctp16q(4). */
        vSum =
            vfmaq_m_f16(vSum, vld1q(pDualCoef),
                      vexpq_f16(vmulq_n_f16(vtmp, -S->gamma)),vctp16q(4));
        pDualCoef += 4;
        pSrcA += numCols * 4;
        /*
         * Decrement the row loop counter
         */
        row -= 4;
    }

    /*
     * compute 2 rows in parallel (at most one pass: row < 4 here)
     */
    if (row >= 2) {
        float16_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
        f16x8_t         vecIn, acc0, acc1;
        float16_t const *pSrcVecPtr = in;

        /*
         * Initialize the pointers to 2 consecutive MatrixA rows
         */
        pInA0 = pSrcA;
        pInA1 = pInA0 + numCols;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f16(0.0f);
        acc1 = vdupq_n_f16(0.0f);
        pSrcA0Vec = pInA0;
        pSrcA1Vec = pInA1;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            f16x8_t         vecA;
            f16x8_t         vecDif;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);;
            vecA = vld1q(pSrcA1Vec);
            pSrcA1Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t    p0 = vctp16q(blkCnt);
            f16x8_t         vecA, vecDif;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
            vecA = vldrhq_z_f16(pSrcA1Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc1 = vfmaq(acc1, vecDif, vecDif);
        }
        /*
         * Sum the partial parts (squared distances into lanes 0..1)
         */
        f16x8_t         vtmp = vuninitializedq_f16();
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc0), vtmp, 0);
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc1), vtmp, 1);

        /* Predicated on 2 lanes: only the two computed distances contribute */
        vSum =
            vfmaq_m_f16(vSum, vld1q(pDualCoef),
                        vexpq_f16(vmulq_n_f16(vtmp, -S->gamma)), vctp16q(2));
        pDualCoef += 2;

        pSrcA += numCols * 2;
        row -= 2;
    }

    /* Handle a final odd row, if any */
    if (row >= 1) {
        f16x8_t         vecIn, acc0;
        float16_t const *pSrcA0Vec, *pInVec;
        float16_t const *pSrcVecPtr = in;
        /*
         * Initialize the pointers to last MatrixA row
         */
        pInA0 = pSrcA;
        /*
         * Initialize the vector pointer
         */
        pInVec = pSrcVecPtr;
        /*
         * reset accumulators
         */
        acc0 = vdupq_n_f16(0.0f);

        pSrcA0Vec = pInA0;

        blkCnt = numCols >> 3;
        while (blkCnt > 0U) {
            f16x8_t         vecA, vecDif;

            vecIn = vld1q(pInVec);
            pInVec += 8;
            vecA = vld1q(pSrcA0Vec);
            pSrcA0Vec += 8;
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);

            blkCnt--;
        }
        /*
         * tail
         * (will be merged thru tail predication)
         */
        blkCnt = numCols & 7;
        if (blkCnt > 0U) {
            mve_pred16_t    p0 = vctp16q(blkCnt);
            f16x8_t         vecA, vecDif;

            vecIn = vldrhq_z_f16(pInVec, p0);
            vecA = vldrhq_z_f16(pSrcA0Vec, p0);
            vecDif = vsubq(vecIn, vecA);
            acc0 = vfmaq(acc0, vecDif, vecDif);
        }
        /*
         * Sum the partial parts (single squared distance into lane 0)
         */
        f16x8_t         vtmp = vuninitializedq_f16();
        vtmp = vsetq_lane(vecAddAcrossF16Mve(acc0), vtmp, 0);

        /* Predicated on 1 lane: only the last row's term contributes */
        vSum =
            vfmaq_m_f16(vSum, vld1q(pDualCoef),
                        vexpq_f16(vmulq_n_f16(vtmp, -S->gamma)), vctp16q(1));

    }

    /* Horizontal reduction of the lane accumulators, then classify */
    sum += vecAddAcrossF16Mve(vSum);
    *pResult = S->classes[STEP(sum)];
}
320  
321  #else
322  void arm_svm_rbf_predict_f16(
323      const arm_svm_rbf_instance_f16 *S,
324      const float16_t * in,
325      int32_t * pResult)
326  {
327      _Float16 sum=S->intercept;
328      _Float16 dot=00.f16;
329      uint32_t i,j;
330      const float16_t *pSupport = S->supportVectors;
331  
332      for(i=0; i < S->nbOfSupportVectors; i++)
333      {
334          dot=0.0f16;
335          for(j=0; j < S->vectorDimension; j++)
336          {
337              dot = dot + SQ((_Float16)in[j] - (_Float16) *pSupport);
338              pSupport++;
339          }
340          sum += (_Float16)S->dualCoefficients[i] * (_Float16)expf(-(_Float16)S->gamma * dot);
341      }
342      *pResult=S->classes[STEP(sum)];
343  }
344  
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
346  
347  /**
348   * @} end of rbfsvm group
349   */
350  
351  #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */ 
352