Optimizing Matrix Multiplication
AlphaTensor for faster matrix multiplication, explained

Matrix multiplication is executed so often in deep learning, video games, and scientific computing that even a slight acceleration can save substantial amounts of processing time. New work finds ways to speed up this crucial operation.

What's new: Alhussein Fawzi and colleagues at DeepMind developed AlphaTensor. This reinforcement learning agent discovers algorithms that multiply matrices faster than those previously developed by humans.

Composition and decomposition: Computers need more time to multiply than to add or subtract. Developers often take advantage of algebraic properties, for instance a^2 - b^2 = (a+b)(a-b), to manually find matrix multiplication algorithms that require fewer multiplications. To minimize the number of multiplications systematically, we can take advantage of the fact that a three-dimensional tensor (a generalization of a matrix to more than two dimensions) can represent matrix multiplication for a given pair of matrix sizes. It's easy to compose a tensor from three matrices: sum the outer products of their corresponding columns. However, decomposing a tensor (the reverse operation) is not straightforward; the procedure could result in any of thousands of potential sets of matrices. Any valid decomposition of the tensor into three matrices represents a valid algorithm for matrix multiplication, and the number of columns in each matrix equals the number of multiplications the algorithm requires.
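To make the tensor view concrete, here is a minimal sketch in plain Python that checks Strassen's well-known seven-multiplication scheme for 2x2 matrices against the 2x2 matrix multiplication tensor. The helper names (`matmul_tensor`, `compose`) and the row-by-row flattening of the matrices are assumptions made for illustration; the paper's own index convention may differ.

```python
def matmul_tensor(n):
    """Return the n^2 x n^2 x n^2 tensor of n x n matrix multiplication.

    Entry [i][j][k] is 1 when the product a_i * b_j contributes to c_k,
    with A, B, and C flattened row by row (a convention assumed here).
    """
    size = n * n
    tensor = [[[0] * size for _ in range(size)] for _ in range(size)]
    for p in range(n):
        for q in range(n):
            for r in range(n):
                # C[p][q] += A[p][r] * B[r][q]
                tensor[p * n + r][r * n + q][p * n + q] = 1
    return tensor


def compose(factors, size):
    """Sum the outer products of every (u, v, w) triple into one tensor."""
    tensor = [[[0] * size for _ in range(size)] for _ in range(size)]
    for u, v, w in factors:
        for i in range(size):
            for j in range(size):
                for k in range(size):
                    tensor[i][j][k] += u[i] * v[j] * w[k]
    return tensor


# One triple per multiplication. u and v hold coefficients over
# (a11, a12, a21, a22) and (b11, b12, b21, b22); w says which entries
# of (c11, c12, c21, c22) the product is added to or subtracted from.
STRASSEN = [
    ((1, 0, 0, 1), (1, 0, 0, 1), (1, 0, 0, 1)),    # M1 = (a11 + a22)(b11 + b22)
    ((0, 0, 1, 1), (1, 0, 0, 0), (0, 0, 1, -1)),   # M2 = (a21 + a22) b11
    ((1, 0, 0, 0), (0, 1, 0, -1), (0, 1, 0, 1)),   # M3 = a11 (b12 - b22)
    ((0, 0, 0, 1), (-1, 0, 1, 0), (1, 0, 1, 0)),   # M4 = a22 (b21 - b11)
    ((1, 1, 0, 0), (0, 0, 0, 1), (-1, 1, 0, 0)),   # M5 = (a11 + a12) b22
    ((-1, 0, 1, 0), (1, 1, 0, 0), (0, 0, 0, 1)),   # M6 = (a21 - a11)(b11 + b12)
    ((0, 1, 0, -1), (0, 0, 1, 1), (1, 0, 0, 0)),   # M7 = (a12 - a22)(b21 + b22)
]

# Seven triples reproduce the tensor, so seven multiplications suffice.
assert compose(STRASSEN, 4) == matmul_tensor(2)
```

Stacking the seven `u` vectors side by side gives the first of the three matrices, and likewise for `v` and `w`, which is why the column count matches the multiplication count.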

Key insight: Just as DeepMind's AlphaZero learned via reinforcement learning to play Go by simulating future game-board states and, based on those states, predicting the likelihood that it would win, a reinforcement learning model can learn to win a game of decomposing tensors by predicting the columns of three matrices.

How it works: Given a tensor that represents a matrix multiplication algorithm, AlphaTensor played a game in which it decomposed the tensor into three matrices with as few columns, and thus as few multiplications, as possible. (The values in the predicted columns were limited to {-2,-1,0,1,2} to avoid precision issues that could have occurred with floating-point values.) At each turn, it predicted the entries in one column of each of the three matrices. The game updated the tensor's state by subtracting the outer product of the predicted columns. It ended when all entries in the tensor equalled 0. AlphaTensor received a negative reward after predicting each set of columns, which encouraged it to decompose the tensor into matrices that had few columns. It received a positive reward for predicting all columns of the three matrices.
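The rules of the game can be sketched in a few lines of plain Python. This is an illustration only: the function name `play`, the reward values (`step_penalty=-1.0`, `completion_bonus=0.0`), and the example tensor are assumptions, not the authors' settings. The example replays the schoolbook algorithm for 2x2 matrices, which uses one move per scalar multiplication.

```python
from itertools import product

ALLOWED = {-2, -1, 0, 1, 2}


def play(tensor, moves, step_penalty=-1.0, completion_bonus=0.0):
    """Replay a list of (u, v, w) moves; return (total_reward, solved).

    The reward values are illustrative assumptions.
    """
    size = len(tensor)
    state = [[row[:] for row in plane] for plane in tensor]  # deep copy
    total = 0.0
    for u, v, w in moves:
        assert all(x in ALLOWED for x in (*u, *v, *w))
        # Update the state by subtracting the outer product of u, v, w.
        for i, j, k in product(range(size), repeat=3):
            state[i][j][k] -= u[i] * v[j] * w[k]
        total += step_penalty  # every move costs reward
        if all(x == 0 for plane in state for row in plane for x in row):
            return total + completion_bonus, True  # tensor fully decomposed
    return total, False


def one_hot(index, size):
    return tuple(1 if i == index else 0 for i in range(size))


# 2x2 matrix multiplication tensor and the schoolbook moves that solve it.
n, size = 2, 4
target = [[[0] * size for _ in range(size)] for _ in range(size)]
schoolbook = []
for p, q, r in product(range(n), repeat=3):
    a, b, c = p * n + r, r * n + q, p * n + q  # C[p][q] += A[p][r] * B[r][q]
    target[a][b][c] = 1
    schoolbook.append((one_hot(a, size), one_hot(b, size), one_hot(c, size)))

print(play(target, schoolbook))  # eight moves, so eight step penalties
```

Under these assumed rewards, a seven-move solution such as Strassen's would score higher than the eight-move schoolbook solution, which is the pressure that pushes the agent toward fewer multiplications.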

• The authors constructed the training dataset of tensor decompositions by randomly generating three matrices and composing them into a tensor.
• Given a tensor's state (starting with the tensor to be decomposed), AlphaTensor embedded the tensor using a series of axial attention layers.
• Given the tensor embedding, AlphaTensor predicted columns using two components: a transformer that predicted likely next columns and a vanilla neural network that predicted the future total reward for those columns.
• Of the predicted columns, AlphaTensor chose a set that wasn't often previously predicted and had a high probability and high predicted reward.
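Two of the steps above lend themselves to a short sketch: generating a synthetic training example, and choosing among candidate columns. The selection score below follows the general shape of the rule used in AlphaZero-style tree search (predicted value plus an exploration bonus that shrinks with visit count); whether AlphaTensor uses exactly this formula is an assumption, as are all names and the constant `c`.

```python
import math
import random

VALUES = (-2, -1, 0, 1, 2)


def random_example(rank, size, rng):
    """Sample `rank` random (u, v, w) triples and compose them into a tensor.

    The triples are a known decomposition of the returned tensor.
    """
    factors = [
        tuple([rng.choice(VALUES) for _ in range(size)] for _ in range(3))
        for _ in range(rank)
    ]
    tensor = [
        [
            [sum(u[i] * v[j] * w[k] for u, v, w in factors) for k in range(size)]
            for j in range(size)
        ]
        for i in range(size)
    ]
    return tensor, factors


def selection_score(prior, value, visits, parent_visits, c=1.5):
    """Favor high predicted reward, high probability, and few past visits."""
    return value + c * prior * math.sqrt(parent_visits) / (1 + visits)


def choose(candidates, c=1.5):
    """Pick the candidate (a dict, hypothetical format) with the best score."""
    parent_visits = sum(cand["visits"] for cand in candidates)
    return max(
        candidates,
        key=lambda cand: selection_score(
            cand["prior"], cand["value"], cand["visits"], parent_visits, c
        ),
    )


tensor, factors = random_example(rank=3, size=4, rng=random.Random(0))
candidates = [
    {"name": "a", "prior": 0.5, "value": -7.0, "visits": 10},
    {"name": "b", "prior": 0.4, "value": -7.0, "visits": 0},
]
print(choose(candidates)["name"])  # the rarely visited candidate wins here
```

With equal predicted reward, the candidate that has been tried less often gets the larger bonus, which matches the bullet's description of preferring sets that weren't often previously predicted.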

Results: AlphaTensor rediscovered known matrix multiplication algorithms for matrices as large as five rows and columns (5x5). Notably, to multiply two 4x4 matrices that contain binary numbers, AlphaTensor discovered an algorithm that requires 47 multiplications, compared to 49 for Strassen's algorithm applied recursively, which had not been improved upon since its creation in 1969. To multiply 4x5 and 5x5 matrices that contain real numbers, AlphaTensor found an algorithm that requires 76 multiplications; the previous best takes 80. After training AlphaTensor with an additional reward that reduced hardware-specific compute time, the authors found algorithms for an Nvidia V100 GPU that are, on median, 8.5 percent faster than the usual implementation. Optimized for TPUs, AlphaTensor sped up matrix multiplication by 10.3 percent.
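For context, the baseline counts behind these figures are simple arithmetic. The sketch below (function names are illustrative) computes the schoolbook count and the count for Strassen's 2x2 scheme nested two levels deep, which is where the figure of 49 for 4x4 matrices comes from.

```python
def naive_multiplications(n, m, p):
    """Scalar multiplications in the schoolbook product of n x m and m x p matrices."""
    return n * m * p


def strassen_multiplications(levels):
    """Multiplications when Strassen's 2x2 scheme is nested `levels` times.

    Applies to square matrices of size 2**levels.
    """
    return 7 ** levels


print(naive_multiplications(4, 4, 4))   # schoolbook 4x4 by 4x4
print(strassen_multiplications(2))      # Strassen nested twice for 4x4
print(naive_multiplications(4, 5, 5))   # schoolbook 4x5 by 5x5
```

These come to 64, 49, and 100 multiplications respectively, against which the reported 47 and 76 can be compared.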

Why it matters: Neural networks learn from data how to perform a particular task reasonably well (for instance, they may be correct 95 percent of the time). But is "reasonably well" sufficient for a field such as mathematics, in which results are provably true or false? This paper stands alongside achievements such as a neural theorem finder and neural theorem prover, showing that deep learning can advance even the most exacting fields.

We're thinking: This work shows deep learning's potential for synergy between humans and machines: People supply an algorithm (such as matrix multiplication) and AI accelerates its runtime.
