Introduction to fast matrix multiplication

Recall how to multiply n\times n matrices A and B:

\displaystyle{(AB)_{ik}=\sum_{j=1}^n A_{ij}B_{jk}}

For each (i,k), it takes O(n) operations to calculate this sum, and so naive matrix multiplication takes O(n^3) operations total. Let \omega denote the infimum of \alpha for which there exists an algorithm that computes the product of any two n\times n matrices in O(n^{\alpha}) operations.
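To fix ideas, here is a minimal Python/NumPy sketch of the naive algorithm (the name naive_matmul is ours):

import numpy as np

def naive_matmul(A, B):
    # (AB)_{ik} = sum_j A_{ij} B_{jk}: n^2 entries, each an O(n) sum,
    # for O(n^3) operations total.
    n = A.shape[0]
    P = np.zeros((n, n), dtype=np.result_type(A, B))
    for i in range(n):
        for k in range(n):
            for j in range(n):
                P[i, k] += A[i, j] * B[j, k]
    return P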

Conjecture. \omega=2.

A 3-way tensor is a complex-valued function of index triples, i.e., T\colon I\times J\times K\rightarrow\mathbb{C}. A tensor has rank 1 if it enjoys a factorization of the form

\displaystyle{T_{ijk}=x_iy_jz_k \qquad \forall (i,j,k)\in I\times J\times K.}

The rank of a tensor is the smallest number of rank-1 tensors that sum to it (in analogy with matrix rank).
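Concretely, in NumPy terms (an illustrative sketch, with arbitrary shapes and names):

import numpy as np

# A rank-1 tensor is an outer product of three vectors:
x, y, z = np.random.randn(4), np.random.randn(5), np.random.randn(6)
T1 = np.einsum('i,j,k->ijk', x, y, z)     # T1[i,j,k] = x[i] * y[j] * z[k]

# A tensor of rank at most r is a sum of r rank-1 tensors:
r = 3
X, Y, Z = np.random.randn(r, 4), np.random.randn(r, 5), np.random.randn(r, 6)
T = np.einsum('li,lj,lk->ijk', X, Y, Z)   # sum of r outer products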

Warning 1. Tensor rank is strange. Consider

\displaystyle{A_\epsilon:=\Bigg[\bigg[\begin{array}{cc} 1 & 0 \\ \epsilon & 1 \end{array}\bigg];\bigg[\begin{array}{cc} 0 & 1 \\ 0 & 0 \end{array}\bigg]\Bigg].}

Then \mathrm{rank}(A_\epsilon)=2 for every \epsilon>0, but \mathrm{rank}(A_0)=3.
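To see what is going on, here is a small NumPy experiment (our own illustration: it diagonalizes S_2S_1^{-1}, where S_1,S_2 are the two frontal slices of A_\epsilon, to produce an explicit rank-2 decomposition for \epsilon>0). The factors blow up as \epsilon\to0, which is why no rank-2 decomposition survives in the limit:

import numpy as np

def A_slices(eps):
    # Frontal slices of A_eps
    S1 = np.array([[1.0, 0.0], [eps, 1.0]])
    S2 = np.array([[0.0, 1.0], [0.0, 0.0]])
    return S1, S2

def rank2_factors(eps):
    # For eps > 0, S2 @ inv(S1) = [[-eps, 1], [0, 0]] has distinct eigenvalues,
    # and diagonalizing it splits A_eps into two rank-1 terms.
    S1, S2 = A_slices(eps)
    lam, V = np.linalg.eig(S2 @ np.linalg.inv(S1))
    X = V.T                                   # X[l] = l-th eigenvector
    Y = np.linalg.inv(V) @ S1                 # Y[l] = l-th row of V^{-1} S1
    Z = np.stack([np.ones(2), lam], axis=1)   # Z[l] = (1, lambda_l)
    return X, Y, Z

eps = 1e-4
S1, S2 = A_slices(eps)
X, Y, Z = rank2_factors(eps)
T = np.einsum('li,lj,lk->ijk', X, Y, Z)
print(np.allclose(T[:, :, 0], S1), np.allclose(T[:, :, 1], S2))  # True True
print(np.abs(Y).max())  # ~ 1/eps: the rank-1 terms blow up (and cancel) as eps -> 0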

Warning 2. Tensor rank is NP-hard to compute [Håstad 90].

Regardless, we use tensors (and their low-rank decompositions) to multiply matrices faster than O(n^3). Define the matrix multiplication tensor as follows:

\displaystyle{M_{(a,b),(c,d),(e,f)}^{(n)} := \left\{ \begin{array}{ll} 1 & \mbox{if}~b=c,~d=e,~f=a \\ 0 & \mbox{otherwise}.\end{array} \right.}
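Before going further, here is a quick NumPy sanity check (the name matmul_tensor is ours) that this tensor really encodes matrix multiplication:

import numpy as np

def matmul_tensor(n):
    # M[(a,b),(c,d),(e,f)] = 1 iff b = c, d = e, f = a.
    M = np.zeros((n, n, n, n, n, n))
    for a in range(n):
        for b in range(n):
            for d in range(n):
                M[a, b, b, d, d, a] = 1
    return M

n = 3
M = matmul_tensor(n)
A, B = np.random.randn(n, n), np.random.randn(n, n)
# (AB)_{ik} = sum_{a,b,c,d} M[(a,b),(c,d),(k,i)] * A[a,b] * B[c,d]:
P = np.einsum('abcdef,ab,cd->fe', M, A, B)
print(np.allclose(P, A @ B))   # True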

Suppose \mathrm{rank}(M^{(n)})\leq r, i.e.,

\displaystyle{M_{(a,b),(c,d),(e,f)}^{(n)}=\sum_{\ell=1}^r x_{ab}^\ell y_{cd}^\ell z_{ef}^\ell.}

Then we can use this decomposition to re-express matrix multiplication:

\begin{aligned}  (AB)_{ik}  &=\sum_{j=1}^n A_{ij}B_{jk}\\  &=\sum_{j=1}^n M_{(i,j),(j,k),(k,i)}^{(n)}A_{ij}B_{jk}\\  &=\sum_{a=1}^n\sum_{b=1}^n\sum_{c=1}^n\sum_{d=1}^n M_{(a,b),(c,d),(k,i)}^{(n)}A_{ab}B_{cd}\\  &=\sum_{a=1}^n\sum_{b=1}^n\sum_{c=1}^n\sum_{d=1}^n \bigg(\sum_{\ell=1}^r x_{ab}^\ell y_{cd}^\ell z_{ki}^\ell\bigg)A_{ab}B_{cd}\\  &=\sum_{\ell=1}^r \underbrace{\bigg(\sum_{a=1}^n\sum_{b=1}^n x_{ab}^\ell A_{ab}\bigg)}_{C_\ell} \underbrace{\bigg(\sum_{c=1}^n\sum_{d=1}^n y_{cd}^\ell B_{cd}\bigg)}_{D_\ell} z_{ki}^\ell.  \end{aligned}

This can be implemented as follows (here, P=AB; ignore the notes in brackets for now):


Initialize P as the matrix of all zeros      [O(n^2), O(n^{2p})]
for \ell=1:r
    Compute C_\ell and D_\ell      [O(n^2), O(n^{2p})]
    E_\ell\leftarrow C_\ell D_\ell      [O(1), \#_{p-1}]
    for i=1:n, k=1:n
        P_{ik}\leftarrow P_{ik}+E_\ell z_{ki}^\ell      [O(1), O(n^{2p-2})]
    endfor
endfor


In each line, the first quantity in brackets counts the number of operations performed. Adding these quantities (and accounting for the for loops), we get that the total complexity is O(n^2)+r(O(n^2)+O(1)+n^2\cdot O(1))=O(rn^2).
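Here is a runnable NumPy version of this loop (a sketch; the name bilinear_matmul is ours). It accepts any rank-r decomposition x, y, z of M^{(n)}; lacking anything cleverer, the demo feeds it the trivial rank-n^3 decomposition M^{(n)}=\sum_{a,b,d}e_{ab}\otimes e_{bd}\otimes e_{da}, where e_{ab} denotes the matrix unit with a 1 in entry (a,b):

import numpy as np

def bilinear_matmul(x, y, z, A, B):
    # Given x, y, z of shape (r, n, n) with M^(n) = sum_l x^l ⊗ y^l ⊗ z^l,
    # compute P = AB using the r products E_l = C_l * D_l.
    r, n, _ = x.shape
    P = np.zeros((n, n))
    for l in range(r):
        C = np.sum(x[l] * A)   # C_l = sum_{ab} x^l_{ab} A_{ab}
        D = np.sum(y[l] * B)   # D_l = sum_{cd} y^l_{cd} B_{cd}
        E = C * D              # the l-th product (a block product in the recursive version)
        P += E * z[l].T        # P_{ik} += E_l z^l_{ki}
    return P

n = 3
x = np.zeros((n**3, n, n))
y = np.zeros((n**3, n, n))
z = np.zeros((n**3, n, n))
for l, (a, b, d) in enumerate(np.ndindex(n, n, n)):
    x[l, a, b] = y[l, b, d] = z[l, d, a] = 1
A, B = np.random.randn(n, n), np.random.randn(n, n)
print(np.allclose(bilinear_matmul(x, y, z, A, B), A @ B))   # True

Unfortunately, O(rn^2) is not very good: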

Lemma. r\geq n^2.

Proof: Suppose r<n^2. Then the r matrices \overline{x^\ell} span a proper subspace of \mathbb{C}^{n\times n}, so we may pick a nonzero A such that A\perp \overline{x^\ell} for every \ell=1,\ldots,r. Then C_\ell=0 for every \ell=1,\ldots,r, and so the algorithm reports AB=0 for every B. Taking B=A^* gives AA^*=0, and so \|A\|_F^2=\mathrm{tr}(AA^*)=0, forcing A=0. Contradiction.     \Box

Now we fix n and multiply N\times N matrices with N=n^p. Consider blocks of size n^{p-1}\times n^{p-1}, and note that we can re-interpret the previous derivation of (AB)_{ik} as a formula for the (i,k)th block of AB in terms of block multiplication. In the algorithm described above, the second quantity in brackets reports the number of operations under this new interpretation. Due to the calculation of E_\ell, the total number of operations then follows a recursion relation:

\displaystyle{\#_p =O(n^{2p})+r\Big(O(n^{2p})+\#_{p-1}+n^2O(n^{2p-2})\Big) =r\cdot\#_{p-1}+O(rn^{2p}).}

Note that \#_0=1. Let’s hunt for a pattern:

\begin{aligned} \#_1&\asymp rn^2\\ \#_2&\asymp r^2n^2+rn^4\\ \#_3&\asymp r^3n^2+r^2n^4+rn^6\\ \#_p&\asymp r^pn^2+r^{p-1}n^4+\cdots+r^2n^{2(p-1)}+rn^{2p}. \end{aligned}

Since r\geq n^2, the first term is dominant:

\displaystyle{\#_p \ll r^p\cdot n^2\cdot p = N^{\log_n r} \cdot N^{2/\log_n N} \cdot \log_nN \ll N^{\log_n r+\epsilon} \quad \forall \epsilon>0.}

As such, we can multiply two N\times N matrices in only O(N^{\log_nr+\epsilon}) operations (even if N is not a power of n; why?). To summarize:

Theorem. If \mathrm{rank}(M^{(n)})\leq r, then \omega\leq \log_n r.

Note: \mathrm{rank}(M^{(2)})=7 [Strassen 69, Landsberg 05], and so \omega\leq\log_27\leq 2.808. MATLAB’s built-in matrix multiplication does not use Strassen’s algorithm due to local memory constraints, but Strassen provides speedups for large matrices (N\geq1000).
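For the curious, here is a toy recursive implementation of Strassen's scheme in NumPy (a sketch assuming N is a power of 2; a practical implementation would fall back to naive multiplication below some cutoff rather than recursing down to 1\times 1 blocks):

import numpy as np

def strassen(A, B):
    # Recursive Strassen: 7 block products instead of 8, for O(N^{log_2 7}) total.
    N = A.shape[0]
    if N == 1:
        return A * B
    m = N // 2
    A11, A12, A21, A22 = A[:m, :m], A[:m, m:], A[m:, :m], A[m:, m:]
    B11, B12, B21, B22 = B[:m, :m], B[:m, m:], B[m:, :m], B[m:, m:]
    M1 = strassen(A11 + A22, B11 + B22)
    M2 = strassen(A21 + A22, B11)
    M3 = strassen(A11, B12 - B22)
    M4 = strassen(A22, B21 - B11)
    M5 = strassen(A11 + A12, B22)
    M6 = strassen(A21 - A11, B11 + B12)
    M7 = strassen(A12 - A22, B21 + B22)
    return np.block([[M1 + M4 - M5 + M7, M3 + M5],
                     [M2 + M4,           M1 - M2 + M3 + M6]])

A, B = np.random.randn(8, 8), np.random.randn(8, 8)
print(np.allclose(strassen(A, B), A @ B))   # True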
