Examples

Matrix multiplication

First, import the modules needed for this example:

from test_ex_matmul in numba/cuda/tests/doc_examples/test_matmul.py
from numba import cuda, float32
import numpy as np
import math

Here is a naïve implementation of matrix multiplication using a CUDA kernel:

from test_ex_matmul in numba/cuda/tests/doc_examples/test_matmul.py
@cuda.jit
def matmul(A, B, C):
    """Perform square matrix multiplication of C = A * B."""
    i, j = cuda.grid(2)
    if i < C.shape[0] and j < C.shape[1]:
        tmp = 0.
        for k in range(A.shape[1]):
            tmp += A[i, k] * B[k, j]
        C[i, j] = tmp
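To see how cuda.grid(2) and the bounds check divide the work among threads, the kernel can be emulated on the CPU. The sketch below uses a hypothetical helper, emulate_matmul, written only for illustration and not part of Numba: it loops over every (block, thread) pair of a 2D launch, rebuilds the global index the way cuda.grid(2) does (blockIdx * blockDim + threadIdx in each dimension), and runs the kernel body for that index. It uses plain Python lists, so it needs neither a GPU nor Numba.

```python
def emulate_matmul(A, B, blockspergrid, threadsperblock):
    """Run the body of the naive matmul kernel once per emulated CUDA thread.

    Hypothetical helper for illustration: A and B are lists of lists, and the
    launch configuration uses the same (x, y) tuples as a Numba kernel launch.
    Returns the result and the number of threads that passed the bounds check.
    """
    rows, cols, inner = len(A), len(B[0]), len(A[0])
    C = [[0.0] * cols for _ in range(rows)]
    active = 0
    for bx in range(blockspergrid[0]):
        for by in range(blockspergrid[1]):
            for tx in range(threadsperblock[0]):
                for ty in range(threadsperblock[1]):
                    # Equivalent of i, j = cuda.grid(2)
                    i = bx * threadsperblock[0] + tx
                    j = by * threadsperblock[1] + ty
                    # Threads that fall outside C do nothing, as in the kernel
                    if i < rows and j < cols:
                        active += 1
                        tmp = 0.0
                        for k in range(inner):
                            tmp += A[i][k] * B[k][j]
                        C[i][j] = tmp
    return C, active

A = [[1, 2], [3, 4], [5, 6]]   # 3x2
B = [[1, 0, 2], [0, 1, 3]]     # 2x3
C, active = emulate_matmul(A, B, blockspergrid=(1, 1), threadsperblock=(4, 4))
print(C)       # [[1.0, 2.0, 8.0], [3.0, 4.0, 18.0], [5.0, 6.0, 28.0]]
print(active)  # 9 of the 16 launched threads wrote an element
```

Each emulated thread computes one full dot product, which is exactly the work each CUDA thread performs in matmul.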

An example usage of this function is as follows:

from test_ex_matmul in numba/cuda/tests/doc_examples/test_matmul.py
x_h = np.arange(16).reshape([4, 4])
y_h = np.ones([4, 4])
z_h = np.zeros([4, 4])

x_d = cuda.to_device(x_h)
y_d = cuda.to_device(y_h)
z_d = cuda.to_device(z_h)

threadsperblock = (16, 16)
blockspergrid_x = math.ceil(z_h.shape[0] / threadsperblock[0])
blockspergrid_y = math.ceil(z_h.shape[1] / threadsperblock[1])
blockspergrid = (blockspergrid_x, blockspergrid_y)

matmul[blockspergrid, threadsperblock](x_d, y_d, z_d)
z_h = z_d.copy_to_host()
print(z_h)
print(x_h @ y_h)
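Because blockspergrid is rounded up, the launch usually starts more threads than C has elements, and the bounds check in matmul discards the extras. The sketch below uses a hypothetical helper, launch_config, that repeats the grid arithmetic from the example in pure Python and counts the idle threads. For the 4x4 example, a single 16x16 block launches 256 threads, of which only 16 do any work.

```python
import math

def launch_config(shape, threadsperblock):
    """Round the grid up so every element of an array of `shape` gets a thread.

    Returns (blockspergrid, threads launched, threads rejected by the bounds check).
    """
    blockspergrid = tuple(
        math.ceil(n / t) for n, t in zip(shape, threadsperblock)
    )
    launched = math.prod(b * t for b, t in zip(blockspergrid, threadsperblock))
    idle = launched - math.prod(shape)
    return blockspergrid, launched, idle

print(launch_config((4, 4), (16, 16)))       # ((1, 1), 256, 240)
print(launch_config((1000, 500), (16, 16)))  # ((63, 32), 516096, 16096)
```

For large arrays the idle fraction is small, so rounding up and guarding with a bounds check is a cheap way to handle shapes that are not multiples of the block size.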

This implementation is straightforward and intuitive but performs poorly, because the same matrix elements are loaded many times from device memory, which is slow: each element of A is read once by every thread computing its row of C, and each element of B once by every thread computing its column. (Some devices have transparent data caches, but they may not be large enough to hold the entire inputs at once.)
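To put numbers on the redundancy, the sketch below counts global-memory reads with simple arithmetic; it is not a measurement and ignores caches. With one thread per element of an N x N result, each thread in the naive kernel reads 2 * N input elements, for 2 * N**3 reads in total, although the inputs hold only 2 * N**2 distinct values. For comparison it also counts a TPB x TPB tiled scheme, in which each thread loads one element of A and one of B per tile, assuming N is a multiple of TPB. The function name global_reads is hypothetical.

```python
def global_reads(n, tpb=None):
    """Count global-memory reads for an n x n matmul, ignoring hardware caches.

    tpb=None models the naive kernel; an integer models TPB x TPB tiling.
    """
    if tpb is None:
        per_thread = 2 * n  # a full row of A and a full column of B
    else:
        if n % tpb:
            raise ValueError("this sketch assumes n is a multiple of tpb")
        per_thread = 2 * (n // tpb)  # one load into sA and one into sB per tile
    return n * n * per_thread  # one thread per element of C

n = 1024
print(global_reads(n))          # 2147483648 (2 * n**3)
print(global_reads(n, tpb=16))  # 134217728 (16 times fewer)
print(2 * n * n)                # 2097152 distinct input elements
```

In this model, tiling divides device-memory traffic by TPB, which is the motivation for the shared-memory version.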

It will be faster if we use a blocked algorithm to reduce accesses to device memory. CUDA provides fast shared memory that the threads in a block can use to cooperate on a task. The following implements a faster version of square matrix multiplication using shared memory:

from test_ex_matmul in numba/cuda/tests/doc_examples/test_matmul.py
# Controls threads per block and shared memory usage.
# The computation will be done on blocks of TPBxTPB elements.
# TPB should not be larger than 32 in this example, because a TPBxTPB
# block would then exceed the limit of 1024 threads per block.
TPB = 16

@cuda.jit
def fast_matmul(A, B, C):
    """
    Perform matrix multiplication of C = A * B using CUDA shared memory.

    Reference: https://stackoverflow.com/a/64198479/13697228 by @RobertCrovella
    """
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x    # blocks per grid, also used as the number of tiles

    # Each thread computes one element in the result matrix.
    # The dot product is chunked into dot products of TPB-long vectors.
    tmp = float32(0.)
    for i in range(bpg):
        # Preload data into shared memory, zero-filling out-of-range elements
        sA[ty, tx] = 0
        sB[ty, tx] = 0
        if y < A.shape[0] and (tx + i * TPB) < A.shape[1]:
            sA[ty, tx] = A[y, tx + i * TPB]
        if x < B.shape[1] and (ty + i * TPB) < B.shape[0]:
            sB[ty, tx] = B[ty + i * TPB, x]

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Compute the partial product from shared memory
        for j in range(TPB):
            tmp += sA[ty, j] * sB[j, tx]

        # Wait until all threads finish computing
        cuda.syncthreads()
    if y < C.shape[0] and x < C.shape[1]:
        C[y, x] = tmp

Because shared memory is a limited resource, the code preloads one small block of each input array at a time. It then calls syncthreads() to wait until all threads have finished preloading before computing on the shared memory. It synchronizes again after the computation so that all threads have finished with the data in shared memory before it is overwritten in the next loop iteration.
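The loop structure can be checked on the CPU with a NumPy sketch that mirrors fast_matmul one block at a time. It is an illustration, not Numba code: the hypothetical function tiled_matmul copies TPB x TPB tiles into zero-filled buffers standing in for sA and sB, so out-of-range elements contribute nothing, just as the sA[ty, tx] = 0 and sB[ty, tx] = 0 initialisation does in the kernel. Whole-tile NumPy operations replace the per-thread work, and the end of each tile iteration plays the role of the syncthreads() calls. Unlike the kernel, this sketch derives the tile count from the shared dimension rather than from the grid size.

```python
import numpy as np

def tiled_matmul(A, B, tpb=16):
    """CPU sketch of the shared-memory tiling used by fast_matmul."""
    rows, inner = A.shape
    cols = B.shape[1]
    C = np.zeros((rows, cols), dtype=np.float32)
    n_tiles = -(-inner // tpb)  # ceiling division: tiles along the shared dimension
    for by in range(0, rows, tpb):       # one output block per (by, bx)
        for bx in range(0, cols, tpb):
            tmp = np.zeros((tpb, tpb), dtype=np.float32)  # one value per thread
            for i in range(n_tiles):
                k0 = i * tpb
                # Zero-filled stand-ins for the shared arrays sA and sB
                sA = np.zeros((tpb, tpb), dtype=np.float32)
                sB = np.zeros((tpb, tpb), dtype=np.float32)
                a = A[by:by + tpb, k0:k0 + tpb]  # slices stop at the array edge
                b = B[k0:k0 + tpb, bx:bx + tpb]
                sA[:a.shape[0], :a.shape[1]] = a
                sB[:b.shape[0], :b.shape[1]] = b
                # Every thread's partial dot product for this tile at once
                tmp += sA @ sB
            # Write back only the part of the block that lies inside C
            h, w = min(tpb, rows - by), min(tpb, cols - bx)
            C[by:by + h, bx:bx + w] = tmp[:h, :w]
    return C

A = np.arange(115, dtype=np.float32).reshape(5, 23)
B = np.ones((23, 7), dtype=np.float32)
print(np.allclose(tiled_matmul(A, B), A @ B))  # True
```

The zero padding is what lets the kernel handle shapes that are not multiples of TPB without a separate code path for partial tiles.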

An example usage of the fast_matmul function is as follows:

from test_ex_matmul in numba/cuda/tests/doc_examples/test_matmul.py
x_h = np.arange(16).reshape([4, 4])
y_h = np.ones([4, 4])
z_h = np.zeros([4, 4])

x_d = cuda.to_device(x_h)
y_d = cuda.to_device(y_h)
z_d = cuda.to_device(z_h)

threadsperblock = (TPB, TPB)
blockspergrid_x = math.ceil(z_h.shape[0] / threadsperblock[0])
blockspergrid_y = math.ceil(z_h.shape[1] / threadsperblock[1])
blockspergrid = (blockspergrid_x, blockspergrid_y)

fast_matmul[blockspergrid, threadsperblock](x_d, y_d, z_d)
z_h = z_d.copy_to_host()
print(z_h)
print(x_h @ y_h)

This passes a CUDA memory check test, which can help with debugging. Running the code above produces the following output:

$ python fast_matmul.py
[[ 6.  6.  6.  6.]
 [22. 22. 22. 22.]
 [38. 38. 38. 38.]
 [54. 54. 54. 54.]]
[[ 6.  6.  6.  6.]
 [22. 22. 22. 22.]
 [38. 38. 38. 38.]
 [54. 54. 54. 54.]]
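The printed values can be checked by hand. Since y_h is all ones, entry (r, c) of the product is the sum of row r of x_h, so every column is the same. For x_h = np.arange(n_rows * n_cols).reshape([n_rows, n_cols]), row r holds the n_cols consecutive integers starting at r * n_cols, whose sum is r * n_cols**2 + n_cols * (n_cols - 1) / 2. The hypothetical helper row_sum evaluates that formula; with n_cols = 23 it also reproduces the rows of the non-square example later in this section.

```python
def row_sum(r, n_cols):
    """Sum of row r of np.arange(n_rows * n_cols).reshape([n_rows, n_cols])."""
    # n_cols terms starting at r * n_cols, plus 0 + 1 + ... + (n_cols - 1)
    return r * n_cols**2 + n_cols * (n_cols - 1) // 2

print([row_sum(r, 4) for r in range(4)])   # [6, 22, 38, 54]
print([row_sum(r, 23) for r in range(5)])  # [253, 782, 1311, 1840, 2369]
```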

Note

For high performance matrix multiplication in CUDA, see also the CuPy implementation.

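The approach outlined here generalizes to non-square matrix multiplication by adjusting the blockspergrid variable. Because fast_matmul uses cuda.gridDim.x as the number of tiles to iterate over, the x-dimension of the grid must cover the shared dimension of the two inputs as well as the columns of the output; sizing the grid from the larger input dimensions, as below, guarantees this.

Again, here is an example usage: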

from test_ex_matmul in numba/cuda/tests/doc_examples/test_matmul.py
x_h = np.arange(115).reshape([5, 23])
y_h = np.ones([23, 7])
z_h = np.zeros([5, 7])

x_d = cuda.to_device(x_h)
y_d = cuda.to_device(y_h)
z_d = cuda.to_device(z_h)

threadsperblock = (TPB, TPB)
grid_y_max = max(x_h.shape[0], y_h.shape[0])
grid_x_max = max(x_h.shape[1], y_h.shape[1])
blockspergrid_x = math.ceil(grid_x_max / threadsperblock[0])
blockspergrid_y = math.ceil(grid_y_max / threadsperblock[1])
blockspergrid = (blockspergrid_x, blockspergrid_y)

fast_matmul[blockspergrid, threadsperblock](x_d, y_d, z_d)
z_h = z_d.copy_to_host()
print(z_h)
print(x_h @ y_h)

and the corresponding output:

$ python nonsquare_matmul.py
[[ 253.  253.  253.  253.  253.  253.  253.]
 [ 782.  782.  782.  782.  782.  782.  782.]
 [1311. 1311. 1311. 1311. 1311. 1311. 1311.]
 [1840. 1840. 1840. 1840. 1840. 1840. 1840.]
 [2369. 2369. 2369. 2369. 2369. 2369. 2369.]]
[[ 253.  253.  253.  253.  253.  253.  253.]
 [ 782.  782.  782.  782.  782.  782.  782.]
 [1311. 1311. 1311. 1311. 1311. 1311. 1311.]
 [1840. 1840. 1840. 1840. 1840. 1840. 1840.]
 [2369. 2369. 2369. 2369. 2369. 2369. 2369.]]
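The grid sizing in the non-square example can be checked with a small pure-Python sketch. The function fast_matmul_grid_ok is hypothetical, written only for illustration; it follows the example's convention that grid x indexes columns of C and grid y indexes rows, and treats gridDim.x * TPB as the span of the shared dimension that the kernel's tile loop reads. For the 5x23 by 23x7 case it accepts the grid used above and rejects a grid sized from the output shape alone.

```python
import math

def fast_matmul_grid_ok(a_shape, b_shape, blockspergrid, tpb):
    """Check whether a launch of fast_matmul covers A (a_shape) @ B (b_shape)."""
    rows, inner = a_shape
    cols = b_shape[1]
    gx, gy = blockspergrid  # x indexes columns of C, y indexes rows
    covers_output = gx * tpb >= cols and gy * tpb >= rows
    covers_inner = gx * tpb >= inner  # the kernel runs gridDim.x tiles of width tpb
    return covers_output and covers_inner

TPB = 16
a_shape, b_shape = (5, 23), (23, 7)

# Grid sized as in the example (from the larger input dimensions)
doc_grid = (math.ceil(max(a_shape[1], b_shape[1]) / TPB),
            math.ceil(max(a_shape[0], b_shape[0]) / TPB))
# Grid sized from the output shape alone
out_grid = (math.ceil(b_shape[1] / TPB), math.ceil(a_shape[0] / TPB))

print(doc_grid, fast_matmul_grid_ok(a_shape, b_shape, doc_grid, TPB))  # (2, 2) True
print(out_grid, fast_matmul_grid_ok(a_shape, b_shape, out_grid, TPB))  # (1, 1) False
```

A different design, not the one used in this example, would compute the tile count inside the kernel from A.shape[1] (for instance as (A.shape[1] + TPB - 1) // TPB), which would let the grid be sized from C alone.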