# L-22 MCS 572 Wed 16 Oct 2024 : firstsum.jl
# Illustration using shared memory with CUDA.jl,
# to sum the first N numbers.

using CUDA

"""
    sum(x, y, N, threadsPerBlock, blocksPerGrid)

CUDA kernel: compute per-block partial sums of the first `N` entries of `x`
and store the partial sum of block `b` in `y[b]`, using shared memory.

The caller adds up `y[1:gridDim]` to obtain the total.

# Arguments
- `x`: device vector holding at least `N` integers.
- `y`: device vector with at least one entry per launched block.
- `N`: number of leading entries of `x` to add up.
- `threadsPerBlock`: length of the shared-memory cache; the launch must
  provide `threadsPerBlock * sizeof(Int64)` bytes through `shmem`.
- `blocksPerGrid`: unused inside the kernel, kept for the call signature.

# Notes
- The tree reduction halves its stride at every step, so the number of threads
  in a block must be a power of two; otherwise some cache entries are never
  added to the result.
- The name shadows `Base.sum` in the enclosing module; call `Base.sum`
  explicitly when the ordinary reduction is needed.
"""
function sum(x, y, N, threadsPerBlock, blocksPerGrid)
    # shared memory cache for the current block, one slot per thread
    cache = CuDynamicSharedArray(Int64, threadsPerBlock)

    # 1-based position of this thread in its block and in the whole grid
    lane = threadIdx().x
    idx = lane + (blockIdx().x - 1) * blockDim().x
    stride = blockDim().x * gridDim().x

    # every thread accumulates the entries idx, idx + stride, idx + 2*stride, ...
    partial = zero(Int64)
    while idx <= N
        partial += x[idx]
        idx += stride
    end

    cache[lane] = partial

    # all partial sums must be in the cache before the reduction starts
    sync_threads()

    # tree reduction: at each step the lower half adds the upper half
    half = blockDim().x ÷ 2
    while half != 0
        if lane <= half
            cache[lane] += cache[lane + half]
        end
        # executed by every thread of the block, outside of the branch
        sync_threads()
        half ÷= 2
    end

    # cache[1] now contains the sum of the numbers handled by this block
    if lane == 1
        y[blockIdx().x] = cache[1]
    end

    return nothing
end

"""
    main(N::Integer = 33 * 1024)

Test the kernel on the first `N` natural numbers: launch it on the vector
`1:N`, add the per-block results on the host, and print whether the total
equals `N*(N+1) ÷ 2`.

# Throws
- `ArgumentError` if `N` is negative.
"""
function main(N::Integer = 33 * 1024)
    N >= 0 || throw(ArgumentError("N must be non-negative, got $N"))

    len = Int64(N)
    threadsPerBlock = 256   # must be a power of two for the reduction
    # integer ceiling division; at least one block, at most 32 blocks
    blocksPerGrid = clamp(cld(len, threadsPerBlock), 1, 32)

    println("size of the vector : ", len)
    println("  number of blocks : ", blocksPerGrid)
    println(" threads per block : ", threadsPerBlock)
    println(" number of threads : ", blocksPerGrid * threadsPerBlock)

    # input array on the host
    x_h = collect(1:len)

    # make the arrays on the device
    x_d = CuArray(x_h)
    y_d = CUDA.zeros(Int64, blocksPerGrid)

    # execute the kernel. Note the shmem argument - this is necessary to
    # allocate space for the dynamic shared memory cache used by the kernel
    shmemBytes = threadsPerBlock * sizeof(Int64)
    @cuda blocks=blocksPerGrid threads=threadsPerBlock shmem=shmemBytes sum(
        x_d, y_d, len, threadsPerBlock, blocksPerGrid
    )

    # copy the result from device to the host
    y_h = Array(y_d)

    # Base.sum is qualified because the kernel above is also named sum
    result = Base.sum(y_h)

    # check whether output is correct
    expected = len * (len + 1) ÷ 2
    print("Does GPU value ", result, " = ", expected, " ? ")
    println(result == expected)

    return nothing
end

main()