//
//  main.swift
//  test
//
//  Created by Teguh Hofstee on 12/18/23.
//

import Metal

// Created before `makeBuffer(_:)` / `makePipeline(named:from:)` because both allocate from it.
let device = MTLCreateSystemDefaultDevice()!

/// Allocates a Metal buffer on `device` holding a byte copy of `items`.
///
/// An empty array still produces a buffer, sized for one `T`, so callers always
/// get something bindable (presumably to avoid a zero-length allocation).
/// - Parameter items: Elements to copy into the new buffer.
/// - Returns: The new buffer, or `nil` if the device could not allocate it.
func makeBuffer<T>(_ items: [T]) -> MTLBuffer? {
    if items.isEmpty {
        return device.makeBuffer(length: MemoryLayout<T>.stride)
    }
    return items.withUnsafeBytes { bytes in
        // Non-empty array, so baseAddress is non-nil.
        device.makeBuffer(bytes: bytes.baseAddress!, length: bytes.count)
    }
}

/// Builds a compute pipeline for the kernel function `name` in `library`.
///
/// Aborts with a descriptive message if the function is missing or the
/// pipeline cannot be created, instead of a bare force-unwrap crash.
func makePipeline(named name: String, from library: MTLLibrary) -> MTLComputePipelineState {
    guard let function = library.makeFunction(name: name) else {
        fatalError("kernel '\(name)' not found in library")
    }
    do {
        return try device.makeComputePipelineState(function: function)
    } catch {
        fatalError("failed to create pipeline for '\(name)': \(error)")
    }
}

let captureManager = MTLCaptureManager.shared()
let CAPTURE = true
if CAPTURE {
    let captureDescriptor = MTLCaptureDescriptor()
    captureDescriptor.captureObject = device
    do {
        try captureManager.startCapture(with: captureDescriptor)
    } catch {
        fatalError("error when trying to capture: \(error)")
    }
}

let library = device.makeDefaultLibrary()!
print(library.functionNames)

let cmdQueue = device.makeCommandQueue()!
let cmdBuf = cmdQueue.makeCommandBuffer()!
let cmdEnc = cmdBuf.makeComputeCommandEncoder()!

let data: [Int32] = Array(1...8193)
// `var` because setBytes(_:length:index:) takes it by address.
var size: Int32 = Int32(data.count)

let input = makeBuffer(data)!
// 2n-1 Int32 slots: the validation below treats slot k as the sum of slots 2k+1 and 2k+2.
let output = device.makeBuffer(length: MemoryLayout<Int32>.stride * (2 * data.count - 1))!
// NOTE(review): element type of `visited` / `order` assumed to be 32-bit ints — confirm against the shader.
let visited = device.makeBuffer(length: MemoryLayout<Int32>.stride * (data.count - 1))!
let order = device.makeBuffer(length: MemoryLayout<Int32>.stride)!

// Dispatch the single-thread "init" kernel on `order`.
do {
    let ps = makePipeline(named: "init", from: library)
    cmdEnc.setComputePipelineState(ps)
    cmdEnc.setBuffer(order, offset: 0, index: 0)
    let single = MTLSizeMake(1, 1, 1)
    cmdEnc.dispatchThreads(single, threadsPerThreadgroup: single)
}

// Dispatch the "sum" kernel: one thread per input element.
do {
    let ps = makePipeline(named: "sum", from: library)
    cmdEnc.setComputePipelineState(ps)
    cmdEnc.setBytes(&size, length: MemoryLayout<Int32>.stride, index: 0)
    cmdEnc.setBuffer(input, offset: 0, index: 1)
    cmdEnc.setBuffer(output, offset: 0, index: 2)
    cmdEnc.setBuffer(visited, offset: 0, index: 3)
    cmdEnc.setBuffer(order, offset: 0, index: 4)
    let gridSize = MTLSizeMake(data.count, 1, 1)
    // Threadgroup no larger than the pipeline allows, nor than the grid itself.
    let threadgroupSize = MTLSizeMake(min(ps.maxTotalThreadsPerThreadgroup, data.count), 1, 1)
    cmdEnc.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize)
    print(gridSize)
}

cmdEnc.endEncoding()
cmdBuf.commit()
cmdBuf.waitUntilCompleted()
print("done!")

if let error = cmdBuf.error {
    // Output from a failed submission may be incomplete; don't report it as results.
    print("command buffer failed: \(error)")
} else {
    let result = output.contents().bindMemory(to: Int32.self, capacity: 2 * data.count - 1)
    for k in 0 ..< data.count - 1 {
        let left = 2 * k + 1
        let right = 2 * k + 2
        // Wrapping add so corrupt GPU output is reported instead of trapping on overflow.
        let expected = result[left] &+ result[right]
        if result[k] != expected {
            print("\(k) -> (\(left), \(right)): expected: \(expected), actual: \(result[k])")
        }
    }
    print("expected: \(data.reduce(0, +))")
    print("actual: \(result[0])")
}

if CAPTURE {
    captureManager.stopCapture()
}