101 Logo
onenoughtone

Strassen's Matrix Multiplication

What is Strassen's Algorithm?

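Strassen's algorithm is a divide-and-conquer approach to matrix multiplication. It splits each n×n matrix into four (n/2)×(n/2) blocks and needs only 7 recursive block multiplications instead of the 8 used by the straightforward block method, which lowers the asymptotic running time.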

The standard matrix multiplication algorithm has a time complexity of O(n³), while Strassen's algorithm achieves O(n^2.81).
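For reference, here is a minimal sketch of the standard algorithm that the O(n³) bound refers to. The function name standardMultiply is hypothetical, and it assumes both inputs are square n×n arrays of numbers. The three nested loops perform exactly n · n · n scalar multiplications.

```javascript
// Standard (schoolbook) multiplication of two n×n matrices.
// Hypothetical helper for illustration; assumes square inputs of equal size.
function standardMultiply(A, B) {
  const n = A.length;
  const C = [];
  for (let i = 0; i < n; i++) {
    C[i] = new Array(n).fill(0);
    for (let j = 0; j < n; j++) {
      let sum = 0;
      for (let k = 0; k < n; k++) {
        sum += A[i][k] * B[k][j]; // one scalar multiplication per (i, j, k)
      }
      C[i][j] = sum;
    }
  }
  return C;
}

const C = standardMultiply([[1, 2], [3, 4]], [[5, 6], [7, 8]]);
console.log(C); // [[19, 22], [43, 50]]
```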

Key Insight

The key insight of Strassen's algorithm is that it's possible to multiply two 2×2 matrices using only 7 multiplications instead of the 8 required by the standard algorithm.

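The formulas never rely on multiplication being commutative, so they remain valid when the entries are themselves matrix blocks. Applied recursively to n×n matrices split into (n/2)×(n/2) quadrants, each level performs 7 block multiplications instead of 8, at the cost of extra block additions and subtractions, which take only O(n²) time. Saving one multiplication at every level of the recursion is what lowers the exponent below 3.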
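The effect of saving one multiplication per level can be seen by counting scalar multiplications. This is a small sketch assuming n is a power of 2 and that recursion stops at 1×1 blocks; the function name scalarMultiplications and its callsPerLevel parameter are hypothetical, chosen only for this illustration.

```javascript
// Count the scalar multiplications performed when an n×n product
// (n a power of 2) is computed by recursive 2×2 block splitting.
// callsPerLevel is 8 for the standard block scheme, 7 for Strassen.
function scalarMultiplications(n, callsPerLevel) {
  if (n === 1) return 1; // base case: one scalar product
  return callsPerLevel * scalarMultiplications(n / 2, callsPerLevel);
}

for (const n of [2, 16, 1024]) {
  const standard = scalarMultiplications(n, 8); // equals n³
  const strassen = scalarMultiplications(n, 7); // equals 7^log2(n)
  console.log(n, standard, strassen, (standard / strassen).toFixed(2));
}
```

For n = 1024 (ten levels of recursion) this gives 8^10 = 1,073,741,824 multiplications for the standard scheme versus 7^10 = 282,475,249 for Strassen, roughly a 3.8× reduction. The gap keeps growing with n because each extra level multiplies the ratio by 8/7.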

The Seven Products

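Split A and B into four equal quadrants, A11, A12, A21, A22 and B11, B12, B21, B22 (the first index is the row block, the second the column block). Strassen's algorithm then computes the following seven products: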

  1. P1 = (A11 + A22) × (B11 + B22)
  2. P2 = (A21 + A22) × B11
  3. P3 = A11 × (B12 - B22)
  4. P4 = A22 × (B21 - B11)
  5. P5 = (A11 + A12) × B22
  6. P6 = (A21 - A11) × (B11 + B12)
  7. P7 = (A12 - A22) × (B21 + B22)

These products are then combined to form the four quadrants of the result matrix:

  • C11 = P1 + P4 - P5 + P7
  • C12 = P3 + P5
  • C21 = P2 + P4
  • C22 = P1 + P3 - P2 + P6
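To check the formulas on concrete numbers, the sketch below applies them to 2×2 matrices of plain numbers and compares the result with the usual 8-multiplication formula. The function names strassen2x2 and standard2x2 are hypothetical and used only for this illustration.

```javascript
// Multiply two 2×2 matrices of numbers using Strassen's seven products,
// then combine them into the four entries of the result.
function strassen2x2(A, B) {
  const [[a11, a12], [a21, a22]] = A;
  const [[b11, b12], [b21, b22]] = B;

  const p1 = (a11 + a22) * (b11 + b22);
  const p2 = (a21 + a22) * b11;
  const p3 = a11 * (b12 - b22);
  const p4 = a22 * (b21 - b11);
  const p5 = (a11 + a12) * b22;
  const p6 = (a21 - a11) * (b11 + b12);
  const p7 = (a12 - a22) * (b21 + b22);

  return [
    [p1 + p4 - p5 + p7, p3 + p5],
    [p2 + p4, p1 + p3 - p2 + p6],
  ];
}

// The same product with the usual 8 multiplications, for comparison.
function standard2x2(A, B) {
  return [
    [A[0][0] * B[0][0] + A[0][1] * B[1][0], A[0][0] * B[0][1] + A[0][1] * B[1][1]],
    [A[1][0] * B[0][0] + A[1][1] * B[1][0], A[1][0] * B[0][1] + A[1][1] * B[1][1]],
  ];
}

const A = [[1, 2], [3, 4]];
const B = [[5, 6], [7, 8]];
console.log(strassen2x2(A, B)); // [[19, 22], [43, 50]]
console.log(standard2x2(A, B)); // [[19, 22], [43, 50]]
```

For this input the seven products are P1 = 65, P2 = 35, P3 = -2, P4 = 8, P5 = 24, P6 = 22 and P7 = -30. Combining them gives C11 = 65 + 8 - 24 - 30 = 19, C12 = -2 + 24 = 22, C21 = 35 + 8 = 43 and C22 = 65 - 2 - 35 + 22 = 50.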

Implementation

Here's an implementation of Strassen's matrix multiplication algorithm:

function strassenMultiply(A, B) {
  const n = A.length;

  // Base case: 1x1 matrices
  if (n === 1) {
    return [[A[0][0] * B[0][0]]];
  }

  // This version requires n to be a power of 2 (for simplicity).
  // In a real implementation, you would pad the matrices with zeros.
  if (n === 0 || (n & (n - 1)) !== 0) {
    throw new Error("strassenMultiply: n must be a power of 2");
  }

  // Divide matrices into quadrants
  const mid = n / 2;

  // Create submatrices
  const a11 = submatrix(A, 0, 0, mid);
  const a12 = submatrix(A, 0, mid, mid);
  const a21 = submatrix(A, mid, 0, mid);
  const a22 = submatrix(A, mid, mid, mid);
  const b11 = submatrix(B, 0, 0, mid);
  const b12 = submatrix(B, 0, mid, mid);
  const b21 = submatrix(B, mid, 0, mid);
  const b22 = submatrix(B, mid, mid, mid);

  // Compute the seven products (recursively)
  const p1 = strassenMultiply(add(a11, a22), add(b11, b22));
  const p2 = strassenMultiply(add(a21, a22), b11);
  const p3 = strassenMultiply(a11, subtract(b12, b22));
  const p4 = strassenMultiply(a22, subtract(b21, b11));
  const p5 = strassenMultiply(add(a11, a12), b22);
  const p6 = strassenMultiply(subtract(a21, a11), add(b11, b12));
  const p7 = strassenMultiply(subtract(a12, a22), add(b21, b22));

  // Compute the quadrants of the result
  const c11 = add(subtract(add(p1, p4), p5), p7); // P1 + P4 - P5 + P7
  const c12 = add(p3, p5);                        // P3 + P5
  const c21 = add(p2, p4);                        // P2 + P4
  const c22 = add(subtract(add(p1, p3), p2), p6); // P1 + P3 - P2 + P6

  // Combine the quadrants into a single matrix
  return combineMatrices(c11, c12, c21, c22);
}

// Helper functions

// Copy the size x size block of `matrix` whose top-left corner is (startRow, startCol)
function submatrix(matrix, startRow, startCol, size) {
  const result = [];
  for (let i = 0; i < size; i++) {
    result[i] = [];
    for (let j = 0; j < size; j++) {
      result[i][j] = matrix[startRow + i][startCol + j];
    }
  }
  return result;
}

// Element-wise sum of two n x n matrices
function add(A, B) {
  const n = A.length;
  const result = [];
  for (let i = 0; i < n; i++) {
    result[i] = [];
    for (let j = 0; j < n; j++) {
      result[i][j] = A[i][j] + B[i][j];
    }
  }
  return result;
}

// Element-wise difference A - B of two n x n matrices
function subtract(A, B) {
  const n = A.length;
  const result = [];
  for (let i = 0; i < n; i++) {
    result[i] = [];
    for (let j = 0; j < n; j++) {
      result[i][j] = A[i][j] - B[i][j];
    }
  }
  return result;
}

// Assemble four n x n quadrants into one 2n x 2n matrix
function combineMatrices(c11, c12, c21, c22) {
  const n = c11.length;
  const result = [];
  for (let i = 0; i < n * 2; i++) {
    result[i] = [];
    for (let j = 0; j < n * 2; j++) {
      if (i < n && j < n) {
        result[i][j] = c11[i][j];
      } else if (i < n && j >= n) {
        result[i][j] = c12[i][j - n];
      } else if (i >= n && j < n) {
        result[i][j] = c21[i - n][j];
      } else {
        result[i][j] = c22[i - n][j - n];
      }
    }
  }
  return result;
}
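The implementation assumes n is a power of 2 and mentions zero-padding for other sizes. A minimal sketch of that padding step follows; the helper names nextPowerOfTwo, padMatrix and cropMatrix are hypothetical, and the code assumes square input matrices.

```javascript
// Smallest power of 2 that is >= n (for n >= 1).
function nextPowerOfTwo(n) {
  let size = 1;
  while (size < n) size *= 2;
  return size;
}

// Copy a square n×n matrix into the top-left corner of a size×size zero matrix.
function padMatrix(M, size) {
  const result = [];
  for (let i = 0; i < size; i++) {
    result[i] = new Array(size).fill(0);
    if (i < M.length) {
      for (let j = 0; j < M.length; j++) result[i][j] = M[i][j];
    }
  }
  return result;
}

// Keep only the top-left n×n block of a (possibly larger) matrix.
function cropMatrix(M, n) {
  return M.slice(0, n).map((row) => row.slice(0, n));
}

const A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]];
const size = nextPowerOfTwo(A.length); // 4
console.log(padMatrix(A, size));
// rows: [1, 2, 3, 0], [4, 5, 6, 0], [7, 8, 9, 0], [0, 0, 0, 0]
```

To multiply arbitrary square n×n matrices with strassenMultiply, one could compute m = nextPowerOfTwo(n) and call cropMatrix(strassenMultiply(padMatrix(A, m), padMatrix(B, m)), n). Zero-padding is safe because [[A, 0], [0, 0]] × [[B, 0], [0, 0]] = [[AB, 0], [0, 0]], so the top-left n×n block of the padded product is exactly AB. The trade-off is that a size just above a power of 2 nearly doubles the working dimension.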

Time and Space Complexity

Time Complexity:

O(n^log₂7) ≈ O(n^2.81), where the inputs are n×n matrices.
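The bound follows from a recurrence on the number of scalar operations. A worked derivation, assuming n = 2^k, that recursion stops at 1×1 blocks (one operation), and that each of the 18 additions and subtractions in the formulas above (10 to form the products, 8 to combine them) acts on (n/2)×(n/2) matrices:

```latex
\begin{aligned}
T(1) &= 1, \qquad T(n) = 7\,T\!\left(\tfrac{n}{2}\right) + 18\left(\tfrac{n}{2}\right)^2 \\
T(n) &= 7^k\,T(1) + \sum_{i=0}^{k-1} 7^i \cdot 18\left(\tfrac{n}{2^{i+1}}\right)^2
      = 7^k + \tfrac{9}{2}\,n^2 \sum_{i=0}^{k-1}\left(\tfrac{7}{4}\right)^i \\
     &= 7^k + \tfrac{9}{2}\,n^2 \cdot \tfrac{4}{3}\left(\left(\tfrac{7}{4}\right)^k - 1\right)
      = 7^k + 6\left(7^k - n^2\right) \\
     &= 7\,n^{\log_2 7} - 6\,n^2 = \Theta\!\left(n^{\log_2 7}\right)
\end{aligned}
```

The last two steps use n²(7/4)^k = 4^k(7/4)^k = 7^k and 7^k = 7^(log₂ n) = n^(log₂ 7). As a sanity check, n = 2 gives 7·7 - 6·4 = 25 operations: 7 multiplications plus 18 additions and subtractions. With 8 recursive calls instead of 7, the same unrolling yields Θ(n³).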

Space Complexity:

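O(n²) peak memory for the matrices and intermediate results: each level of recursion holds a constant number of quadrant-sized temporaries, and these sizes shrink geometrically (n², n²/4, n²/16, and so on) along the recursion path.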
