//
//  test.metal
//  test
//
//  Created by Teguh Hofstee on 12/18/24.
//

#include <metal_stdlib>
using namespace metal;

// Shader-side logger used to trace threadgroup start order and tree-climb decisions.
constant os_log logger(/*subsystem=*/"com.metal.xyz", /*category=*/"abc");

/// Resets the threadgroup-order counter that `sum` uses for logging.
///
/// Only `threadgroup_order` is reset here; the `visited` counters that `sum`
/// depends on are NOT cleared by this kernel.
/// NOTE(review): presumably the host zeroes `visited` before each `sum`
/// dispatch — confirm.
[[kernel]]
// clang-format off
void init(device int& threadgroup_order) {
  // clang-format on
  threadgroup_order = 0;
}

/// Single-pass tree reduction of `in[0 .. size-1]` into `out[0]`.
///
/// `out` is laid out as an implicit binary heap of 2*size-1 nodes:
/// - node c has children 2c+1 and 2c+2;
/// - the leaves are at indices size-1 .. 2*size-2;
/// - the root, out[0], receives the total.
///
/// Each thread writes its own leaf and then climbs toward the root. At every
/// internal node it increments visited[node]:
/// - The first thread to arrive (old value 0) stops.
/// - The second (old value 1) knows both children have been written. It
///   stores their sum in the node and continues upward.
/// Each internal node is therefore summed by exactly one thread, and no global
/// barrier is needed.
///
/// @param size              Number of input elements (leaves).
/// @param in                Input values, indexed by grid position.
/// @param out               Tree storage; must hold at least 2*size-1 ints.
/// @param visited           One arrival counter per internal node
///                          (indices 0 .. size-2). NOTE(review): must be zero
///                          before the dispatch; nothing in this file clears it.
/// @param threadgroup_order Counter used only to log the order in which
///                          threadgroups begin executing.
///
/// NOTE(review): cross-threadgroup visibility of the plain `out` accesses relies
/// on the release/acquire fence pairing below. On Metal 3.2+ it may also be
/// worth declaring `out` coherent(device) — confirm against the MSL spec.
[[kernel]]
// clang-format off
void sum(device const int& size,
         device const int* __restrict__ in,
         device int* __restrict__ out,
         device atomic_int* visited,
         device atomic_int* threadgroup_order,
         uint i [[thread_position_in_grid]],
         uint k [[thread_position_in_threadgroup]],
         uint j [[threadgroup_position_in_grid]]) {
  // clang-format on
  if (k == 0) {
    int order = atomic_fetch_add_explicit(threadgroup_order, 1, memory_order_relaxed);
    logger.log_info("threadgroup %d executed %d", j, order);
  }

  // Threads past the end of the input own no leaf. This happens e.g. when the
  // grid is rounded up to whole threadgroups. Without this guard such a thread
  // would read `in` out of bounds, write past the end of `out`, and bump
  // visited[] slots that are not internal nodes.
  if (i >= uint(size)) {
    return;
  }

  uint cur = uint(size) + i - 1;  // heap index of this thread's leaf
  out[cur] = in[i];

  // With a single element the leaf IS the root: out[0] already holds the
  // result, and computing the parent of node 0 would wrap around to a huge index.
  if (cur == 0) {
    return;
  }

  // Release: publish the leaf value before announcing arrival at the parent.
  atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst);

  int MAX_ITER = 1000;  // safety cap; the climb needs at most ~log2(size) steps
  cur = (cur - 1) / 2;
  int proceed = atomic_fetch_add_explicit(&visited[cur], 1, memory_order_relaxed);
  if (i == 8191 || i == 8192) {
    // Debug trace for two hard-coded adjacent threads.
    logger.log_info("[%d] got %d", i, proceed);
  }

  // proceed == 1 means the sibling subtree already finished, so both children
  // of `cur` are populated and this thread is responsible for summing them.
  while (proceed == 1 && MAX_ITER-- > 0) {
    // Acquire: pairs with the sibling's release fence so its child write is
    // visible here. A relaxed RMW by itself does not order the reads below.
    atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst);

    uint left = 2 * cur + 1;
    uint right = 2 * cur + 2;
    // Sum in uint so overflow wraps instead of being signed-overflow UB.
    uint val_left = uint(out[left]);
    uint val_right = uint(out[right]);
    out[cur] = int(val_left + val_right);
    if (cur == 0) {
      break;  // root reached; out[0] now holds the total
    }
    cur = (cur - 1) / 2;
    // Release: publish out[child] before arriving at the next parent.
    atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst);
    proceed = atomic_fetch_add_explicit(&visited[cur], 1, memory_order_relaxed);
  }
}