/ vulkan-shaders / argsort.comp
argsort.comp
#version 450

#include "types.comp"

// Invocations per workgroup; also the capacity of the shared index buffer.
#define BLOCK_SIZE 1024
// Value of p.order that selects ascending order; any other value sorts descending.
#define ASC 0

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

// Values to rank, laid out row after row with p.ncols elements per row.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
// Result: for every row, the column indices of that row in sorted order.
layout (binding = 1)          buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;      // number of real columns in each row
    uint ncols_pad;  // width of the sorting network
                     // NOTE(review): presumably ncols rounded up to a power of two and
                     // no larger than BLOCK_SIZE - confirm against the host code
    uint order;      // ASC for ascending, anything else for descending
} p;

// Per-workgroup scratch holding the column indices of the row being sorted.
// Slots whose stored index is >= p.ncols are padding and carry no data.
shared int dst_row[BLOCK_SIZE];

// Exchange two entries of the shared index buffer.
void swap(uint idx0, uint idx1) {
    int tmp = dst_row[idx0];
    dst_row[idx0] = dst_row[idx1];
    dst_row[idx1] = tmp;
}

// Argsort of one row per workgroup (row = gl_WorkGroupID.y) using a bitonic
// sorting network over the indices held in shared memory. Each invocation
// owns one slot `col` and compare-exchanges it with its partner `col ^ j`.
// Padding indices (>= p.ncols) are ordered after every real element so that
// the first p.ncols slots of the final sequence hold the result.
void main() {
    // bitonic sort
    const int col = int(gl_LocalInvocationID.x);
    const uint row = gl_WorkGroupID.y;

    const uint row_offset = row * p.ncols;

    // Invocations at or beyond ncols_pad have no slot to sort. They must NOT
    // return early: barrier() has to be reached by every invocation of the
    // workgroup in uniform control flow, so they stay alive, skip the memory
    // accesses below and only take part in the barriers.
    const bool in_sort_range = col < p.ncols_pad;

    // initialize indices
    if (in_sort_range) {
        dst_row[col] = col;
    }
    barrier();

    // Loop bounds depend only on push constants, so every invocation executes
    // the same number of iterations and hits the same barriers.
    for (uint k = 2; k <= p.ncols_pad; k *= 2) {
        for (uint j = k / 2; j > 0; j /= 2) {
            const uint ixj = col ^ j;
            // Only the lower slot of each pair performs the compare-exchange.
            if (in_sort_range && ixj > col) {
                if ((col & k) == 0) {
                    // Requested direction: padding moves to the upper slot,
                    // real elements are swapped when they are out of order.
                    if (dst_row[col] >= p.ncols ||
                        (dst_row[ixj] < p.ncols && (p.order == ASC ?
                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
                    ) {
                        swap(col, ixj);
                    }
                } else {
                    // Opposite direction: padding moves to the lower slot.
                    if (dst_row[ixj] >= p.ncols ||
                        (dst_row[col] < p.ncols && (p.order == ASC ?
                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
                    ) {
                        swap(col, ixj);
                    }
                }
            }
            // Make this step's swaps visible before the next step reads them.
            barrier();
        }
    }

    // Only the first ncols slots correspond to real columns.
    if (col < p.ncols) {
        data_d[row_offset + col] = dst_row[col];
    }
}