/* libm/typehelper-vec.h */
  1  /*
  2   * Copyright (C) 2018-2020, Advanced Micro Devices, Inc. All rights reserved.
  3   *
  4   * Redistribution and use in source and binary forms, with or without modification,
  5   * are permitted provided that the following conditions are met:
  6   * 1. Redistributions of source code must retain the above copyright notice,
  7   *    this list of conditions and the following disclaimer.
  8   * 2. Redistributions in binary form must reproduce the above copyright notice,
  9   *    this list of conditions and the following disclaimer in the documentation
 10   *    and/or other materials provided with the distribution.
 11   * 3. Neither the name of the copyright holder nor the names of its contributors
 12   *    may be used to endorse or promote products derived from this software without
 13   *    specific prior written permission.
 14   *
 15   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 16   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 17   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 18   * IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
 19   * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 20   * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 21   * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 22   * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 23   * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 24   * POSSIBILITY OF SUCH DAMAGE.
 25   *
 26   */
 27  
 28  #ifndef __LIBM_TYPEHELPER_VEC_H__
 29  #define __LIBM_TYPEHELPER_VEC_H__
 30  
#include <libm/types.h>

#include <emmintrin.h>
#include <immintrin.h>  /* SSE4.1/AVX: _mm_testz_si128, _mm256_testz_si256, _mm256_cvtps_pd */
 34  
/*
 * Type-checked splat macros: build a vector literal with every lane set
 * to (x). The _Generic selection restricts the argument to the exact
 * element type, so e.g. passing a double to a float splat is a
 * compile-time error rather than a silent conversion.
 */
#define _MM_SET1_PS4(x)                           \
    _Generic((x),                                 \
             float: (__m128){(x), (x), (x), (x)})

#define _MM_SET1_PS8(x)                           \
    _Generic((x),                                 \
             float: (__m256){(x), (x), (x), (x),  \
                     (x), (x), (x), (x)})

#define _MM_SET1_PD2(x)                         \
    _Generic((x),                               \
             double: (__m128d){(x), (x)})

#define _MM_SET1_PD4(x)                             \
    _Generic((x),                                   \
             double: (__m256d){(x), (x), (x), (x)})

/*
 * Untyped brace-list splats: usable as initializers for any of the v_*
 * vector types (no type checking, unlike the _Generic forms above).
 */
/* TODO: check if _MM_SET1_I64x2 is used */
#define _MM_SET1_I64x2(x) {(x), (x)}

#define _MM_SET1_I32(x) {(x), (x), (x), (x)}

#define _MM_SET1_I64(x) {(x), (x), (x), (x)}

#define _MM256_SET1_I32(x) {(x), (x), (x), (x), (x), (x), (x), (x) }

#define _MM256_SET1_PS8(x) {(x), (x), (x), (x), (x), (x), (x), (x) }
 62  
 63  
 64  /*
 65   * Naming convention
 66   *  1. Access as different data
 67   *      ``as_v<v>_<to>_<from>``
 *        eg: as_v2_u32_f32  - access a given vector f32x2 element as u32x2
 69   *                             (without changing memory contents)
 70   *  2. Cast - Portable C style
 71   *      - cast_v<v>_<to>_<from>
 72   *  3. Check
 73   *      - any_v<v>_<type>
 74   *
 75   *      <v> - values: 2, 4, 8 - shows number of elements in vector
 76   *      <to>/<from>/<type>
 77   *              - u32, f32, u64, f64 shows underlying type
 78   *
 79   *  4. Convert - Non-portable x86 style
 80   *      - cvt_v<v>_<to>_<from>
 81   */
 82  
 83  
 84  /*
 85   * 1. Low-level-access          as_vN_A_B()
 86   * 2. Cast functions            cast_vN_A_to_B()
 87   * 3. Converters                cvt_vN_A_to_B()
 88   * 4. Condition Checkers        any_vN_A()
 89   * 5. Function callers          call_vN_A()
 90   *
 91   * where: N is a number of vector elements 2,4,8 etc,
 92   *        A is a type like f32, f64, u32, u64, i32, i64 etc,
 93   *        B is the other type just like 'A'
 94   */
 95  
 96  /* v4 - single precision */
 97  
 98  /* Access a f32x4 as u32x4 */
 99  inline v_u32x4_t
100  as_v4_u32_f32(v_f32x4_t x)
101  {
102      union {
103          v_f32x4_t f; v_u32x4_t u;
104      } r = {.f = x};
105  
106      return r.u;
107  }
108  
109  /* Access a u32x4 as f32x4 */
110  inline v_f32x4_t
111  as_v4_f32_u32(v_u32x4_t x)
112  {
113      union {
114          v_f32x4_t f; v_u32x4_t u;
115      } r = {.u = x};
116  
117      return r.f;
118  }
119  
120  /* v2 double precision */
121  
122  static inline v_f64x2_t
123  as_v2_f64_u64(v_u64x2_t x)
124  {
125      union {
126          v_u64x2_t _xi; v_f64x2_t _xf;
127      } val = { ._xi = x };
128  
129      return val._xf;
130  }
131  
132  static inline v_u64x2_t
133  as_v2_u64_f64 (v_f64x2_t x)
134  {
135      union {
136          v_f64x2_t f; v_u64x2_t u;
137      } r = {.f = x};
138  
139      return r.u;
140  }
141  
142  /* v4 double precision */
143  
144  /* Access a u64x4 as f64x4 */
145  static inline v_f64x4_t
146  as_v4_f64_u64(v_u64x4_t x)
147  {
148      union {
149          v_f64x4_t f; v_u64x4_t u;
150      } r = {.u = x};
151  
152      return r.f;
153  }
154  
155  /* Access a u64x4 as f64x4 */
156  static inline v_u64x4_t
157  as_v4_u64_f64(v_f64x4_t x)
158  {
159      union {
160          v_f64x4_t f; v_u64x4_t u;
161      } r = {.f = x};
162      return r.u;
163  }
164  
165  /*
166   * v8 single precision
167   */
168  static inline v_f32x8_t
169  as_v8_f32_u32(v_u32x8_t x)
170  {
171      union {
172          v_u32x8_t _xi; v_f32x8_t _xf;
173      } val = { ._xi = x};
174  
175      return val._xf;
176  }
177  
178  static inline v_u32x8_t
179  as_v8_u32_f32(v_f32x8_t x)
180  {
181      union {
182          v_u32x8_t _xi; v_f32x8_t _xf;
183      } val = { ._xf = x};
184  
185      return val._xi;
186  }
187  
188  
189  /*
190   * Casting
191   */
192  
193  /* v4 unsigned int 64 -> 32 */
194  static inline v_u32x4_t
195  cast_v4_u64_to_u32(v_u64x4_t _xu64)
196  {
197      return (v_u32x4_t){_xu64[0], _xu64[1], _xu64[2], _xu64[3]};
198  }
199  
200  /* v4 signed int -> float */
201  static inline v_f32x4_t
202  cast_v4_s32_to_f32(v_i32x4_t _xi32)
203  {
204      return (v_f32x4_t){_xi32[0], _xi32[1], _xi32[2], _xi32[3]};
205  }
206  
207  /* v4 float -> double */
208  inline v_f64x4_t
209  cast_v4_f32_to_f64(v_f32x4_t _x)
210  {
211      return (v_f64x4_t){_x[0], _x[1], _x[2], _x[3]};
212  }
213  
214  /* v4 double -> float */
215  inline v_f32x4_t
216  cast_v4_f64_to_f32(v_f64x4_t _x)
217  {
218      return (v_f32x4_t){_x[0], _x[1], _x[2], _x[3]};
219  }
220  
221  // v4 double -> int64
222  static inline v_i64x4_t
223  cast_v4_f64_to_i64(v_f64x4_t _xf64)
224  {
225      return (v_i64x4_t){_xf64[0], _xf64[1], _xf64[2], _xf64[3]};
226  }
227  
228  // v2 double -> int64
229  static inline v_i64x2_t
230  cast_v2_f64_to_i64(v_f64x2_t _xf64)
231  {
232      return (v_i64x2_t){_xf64[0], _xf64[1]};
233  }
234  
235  /*
236  static inline v_u32x8_t
237  cast_v8_u64_to_u32(v_u32x8_t _xf64)
238  {
239      return (v_u32x8_t){
240          _xf64[0], _xf64[1], _xf64[2], _xf64[3],
241              _xf64[4], _xf64[5], _xf64[6], _xf64[7]
242              };
243  }
244  */
245  
246  static inline v_f64x4_t
247  cvt_v4_f32_to_f64(v_f32x4_t _xf32 /* cond */)
248  {
249      return _mm256_cvtps_pd(_xf32);
250  }
251  
252  static inline v_f32x4_t
253  cvt_v4_f64_to_f32(v_f64x4_t _xf64 /* cond */)
254  {
255      return _mm256_cvtpd_ps(_xf64);
256  }
257  
258  /*
259   * Condition Check
260   *    check if any of the vector elements are set
261   */
262  
263  /*
264   * On x86, 'cond' contains all 0's for false, and all 1's for true
265   * IOW, 0=>false, -1=>true
266   */
267  
268  static inline int
269  any_v4_u32(v_i32x4_t cond)
270  {
271      const v_i32x4_t zero = _MM_SET1_I32(0);
272      return ! _mm_testz_si128(cond, zero);
273  }
274  
275  static inline int
276  any_v8_u32(v_i32x8_t cond)
277  {
278      const v_i32x8_t zero = {0,};
279      return ! _mm256_testz_si256(cond, zero);
280  }
281  
282  static inline int
283  any_v4_u64(v_i64x4_t cond)
284  {
285      const v_i64x4_t zero = _MM_SET1_I64(0);
286      return ! _mm256_testz_si256(cond, zero);
287  }
288  
289  static inline int
290  any_v2_u64(v_i64x2_t cond)
291  {
292      const v_i64x2_t zero = _MM_SET1_I64x2(0);
293      return ! _mm_testz_si128(cond, zero);
294  }
295  
296  // Condition check with for loop for better performance
297  static inline int
298  any_v4_u32_loop(v_i32x4_t cond)
299  {
300      int ret = 0;
301  
302      for (int i = 0; i < 4; i++) {
303          if (cond[i] !=0) {
304              ret= 1;
305              break;
306          }
307      }
308  
309      return ret;
310  }
311  
312  // Condition check with for loop for better performance
313  static inline int
314  any_v2_u64_loop(v_i64x2_t cond)
315  {
316      int ret = 0;
317  
318      for (int i = 0; i < 2; i++) {
319          if (cond[i] !=0) {
320              ret= 1;
321              break;
322          }
323      }
324  
325      return ret;
326  }
327  
328  // Condition check with for loop for better performance
329  static inline int
330  any_v4_u64_loop(v_i64x4_t cond)
331  {
332      int ret = 0;
333      for (int i = 0; i < 4; i++) {
334          if (cond[i] != 0) {
335              ret = 1;
336              break;
337          }
338      }
339  
340      return ret;
341  }
342  
343  
344  #ifndef ALM_HAS_V8_CALL_F32
345  #define ALM_HAS_V8_CALL_F32
346  
347  static inline v_f32x8_t
348  call_v8_f32(float (*fn)(float),
349              v_f32x8_t x,
350              v_f32x8_t result,
351              v_i32x8_t cond)
352  {
353      return (v_f32x8_t) {
354          cond[0] ? fn(x[0]) : result[0],
355              cond[1] ? fn(x[1]) : result[1],
356              cond[2] ? fn(x[2]) : result[2],
357              cond[3] ? fn(x[3]) : result[3],
358              cond[4] ? fn(x[4]) : result[4],
359              cond[5] ? fn(x[5]) : result[5],
360              cond[6] ? fn(x[6]) : result[6],
361              cond[7] ? fn(x[7]) : result[7]
362              };
363  }
364  #endif
365  
366  /*
367   * TODO: Convert all following to format
368   *    call_vN_A
369   *    call2_vN_A  - call a function with 2 args
370   * Where
371   *     N - number of vector elements 2, 4, 8
372   *     A - vector element type f32/f64
373   */
374  #ifndef ALM_HAS_V8_CALL2_F32
375  #define ALM_HAS_V8_CALL2_F32
376  
377  static inline v_f32x8_t
378  call2_v8_f32(float (*fn)(float, float),
379          v_f32x8_t x,
380          v_f32x8_t y,
381          v_f32x8_t result,
382          v_i32x8_t cond)
383  {
384      return (v_f32x8_t) {
385          cond[0] ? fn(x[0], y[0]) : result[0],
386              cond[1] ? fn(x[1], y[1]) : result[1],
387              cond[2] ? fn(x[2], y[2]) : result[2],
388              cond[3] ? fn(x[3], y[3]) : result[3],
389              cond[4] ? fn(x[4], y[4]) : result[4],
390              cond[5] ? fn(x[5], y[5]) : result[5],
391              cond[6] ? fn(x[6], y[6]) : result[6],
392              cond[7] ? fn(x[7], y[7]) : result[7]
393              };
394  }
395  
396  #endif
397  
398  #ifndef ALM_HAS_V4_CALL_F32
399  #define ALM_HAS_V4_CALL_F32
400  
401  static inline v_f32x4_t
402  call_v4_f32(float (*fn)(float),
403             v_f32x4_t orig,
404             v_f32x4_t result,
405             v_i32x4_t cond)
406  {
407      return (v_f32x4_t){cond[0] ? fn(orig[0]) : result[0],
408              cond[1] ? fn(orig[1]) : result[1],
409              cond[2] ? fn(orig[2]) : result[2],
410              cond[3] ? fn(orig[3]) : result[3]};
411  }
412  #endif
413  
414  #ifndef ALM_HAS_V4_CALL_2_F32
415  #define ALM_HAS_V4_CALL_2_F32
416  static inline v_f32x4_t
417  v_call2_f32(float (*fn)(float, float),
418              v_f32x4_t x,
419              v_f32x4_t y,
420              v_f32x4_t result,
421              v_i32x4_t cond)
422  {
423      return (v_f32x4_t){cond[0] ? fn(x[0], y[0]) : result[0],
424              cond[1] ? fn(x[1], y[1]) : result[1],
425              cond[2] ? fn(x[2], y[2]) : result[2],
426              cond[3] ? fn(x[3], y[3]) : result[3]};
427  }
428  
429  #endif
430  
431  #ifndef ALM_HAS_V4_CALL_F64
432  #define ALM_HAS_V4_CALL_F64
433  static inline v_f64x4_t
434  v_call_f64(double (*fn)(double),
435             v_f64x4_t orig,
436             v_f64x4_t result,
437             v_i64x4_t cond)
438  {
439      return (v_f64x4_t){cond[0] ? fn(orig[0]) : result[0],
440              cond[1] ? fn(orig[1]) : result[1],
441              cond[2] ? fn(orig[2]) : result[2],
442              cond[3] ? fn(orig[3]) : result[3]};
443  }
444  #endif
445  
446  #ifndef ALM_HAS_V4_CALL_2_F64
447  #define ALM_HAS_V4_CALL_2_F64
448  static inline v_f64x4_t
449  call2_v4_f64(double (*fn)(double, double),
450         v_f64x4_t x,
451         v_f64x4_t y,
452         v_f64x4_t result,
453         v_i64x4_t cond)
454  {
455      return (v_f64x4_t){cond[0] ? fn(x[0], y[0]) : result[0],
456              cond[1] ? fn(x[1], y[1]) : result[1],
457              cond[2] ? fn(x[2], y[2]) : result[2],
458              cond[3] ? fn(x[3], y[3]) : result[3]};
459  }
460  #endif
461  
462  #ifndef ALM_HAS_V2_CALL_2_F64
463  #define ALM_HAS_V2_CALL_2_F64
464  static inline v_f64x2_t
465  call2_v2_f64(double (*fn)(double, double),
466         v_f64x2_t x,
467         v_f64x2_t y,
468         v_f64x2_t result,
469         v_i64x2_t cond)
470  {
471      return (v_f64x2_t){cond[0] ? fn(x[0], y[0]) : result[0],
472              cond[1] ? fn(x[1], y[1]) : result[1]};
473  }
474  #endif
475  
476  
477  
478  // v_f32x8_t to v_i32x8_t
479  static inline v_i32x8_t
480  cast_v8_f32_to_i32(v_f32x8_t _xf32)
481  {
482      return (v_i32x8_t){_xf32[0], _xf32[1], _xf32[2], _xf32[3],
483              _xf32[4], _xf32[5], _xf32[6], _xf32[7]};
484  }
485  
486  // v_i32x8_t to v_f32x8_t
487  static inline v_f32x8_t
488  cast_v8_f32_to_s32(v_i32x8_t _xi32)
489  {
490      return (v_f32x8_t){_xi32[0], _xi32[1], _xi32[2], _xi32[3],
491              _xi32[4], _xi32[5], _xi32[6], _xi32[7] };
492  }
493  
494  // Condition check with for loop for better performance
495  static inline int
496  any_v8_u32_loop(v_i32x8_t cond)
497  {
498      int ret = 0;
499  
500      for (int i = 0; i < 8; i++) {
501          if (cond[i] !=0) {
502              ret= 1;
503              break;
504          }
505      }
506  
507      return ret;
508  }
509  
510  #endif