// vulkan-shaders/argsort.comp
 1  #version 450
 2  
 3  #include "types.comp"
 4  
 5  #define BLOCK_SIZE 1024
 6  #define ASC 0
 7  
 8  layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 9  
10  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
11  layout (binding = 1)          buffer D {int data_d[];};
12  
13  layout (push_constant) uniform parameter {
14      uint ncols;
15      uint ncols_pad;
16      uint order;
17  } p;
18  
19  shared int dst_row[BLOCK_SIZE];
20  
// Exchange the two entries of the shared index array dst_row at positions
// idx0 and idx1.
void swap(uint idx0, uint idx1) {
    const int saved = dst_row[idx1];
    dst_row[idx1] = dst_row[idx0];
    dst_row[idx0] = saved;
}
26  
// Bitonic argsort of one row of data_a per workgroup.
//
// The row is selected by gl_WorkGroupID.y. Each invocation owns slot `col` of
// the shared index array dst_row. The network runs over p.ncols_pad slots.
// Slots holding a padding index (>= p.ncols) are never compared by value:
// they are moved toward the end of each forward-ordered sub-sequence, so after
// the final merge they sit past the first p.ncols slots regardless of p.order.
// The first p.ncols sorted indices are then written to data_d for this row.
//
// NOTE(review): assumes p.ncols_pad is a power of two no larger than
// BLOCK_SIZE (the XOR partner scheme and the dst_row size rely on it) —
// confirm on the host side.
void main() {
    const int col = int(gl_LocalInvocationID.x);
    const uint row = gl_WorkGroupID.y;

    const uint row_offset = row * p.ncols;

    // Invocations at or beyond ncols_pad must NOT return early. Every
    // invocation of the workgroup has to reach each barrier() below, because
    // barrier() is only defined in uniform control flow. Those invocations
    // therefore stay in the loop and merely skip the compare-exchange.
    const bool active = col < p.ncols_pad;

    // Initialize indices. col < BLOCK_SIZE because local_size_x == BLOCK_SIZE,
    // so this write is always in bounds.
    dst_row[col] = col;
    barrier();

    // Loop bounds depend only on a push constant, so every invocation executes
    // the same number of barrier() calls.
    for (uint k = 2; k <= p.ncols_pad; k *= 2) {
        for (uint j = k / 2; j > 0; j /= 2) {
            const uint ixj = col ^ j;
            // Only the lower index of each (col, ixj) pair performs the
            // exchange, so no two invocations touch the same pair in one step.
            if (active && ixj > col) {
                if ((col & k) == 0) {
                    // Sub-sequence ordered in p.order; padding moves up to ixj.
                    if (dst_row[col] >= p.ncols ||
                        (dst_row[ixj] < p.ncols && (p.order == ASC ?
                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
                    ) {
                        swap(col, ixj);
                    }
                } else {
                    // Sub-sequence ordered opposite to p.order; padding moves down to col.
                    if (dst_row[ixj] >= p.ncols ||
                        (dst_row[col] < p.ncols && (p.order == ASC ?
                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
                    ) {
                        swap(col, ixj);
                    }
                }
            }
            barrier();
        }
    }

    // Only the first ncols slots hold real column indices after sorting.
    if (col < p.ncols) {
        data_d[row_offset + col] = dst_row[col];
    }
}