Matrix Chain Multiplication Algorithm

Matrix Chain Multiplication is a dynamic programming algorithm for finding the optimal order in which to multiply a chain of matrices, meaning the order that minimizes the total number of scalar multiplications. The problem is not to perform the multiplications themselves, but only to decide in which order (with which parenthesization) they should be carried out. The approach relies on the associativity of matrix multiplication: every parenthesization of the chain yields the same product. For example, given three matrices A, B, and C, we can compute either (A * B) * C or A * (B * C), and both give the same result. However, the number of scalar multiplications can differ greatly between orders. If A is $10 \times 30$, B is $30 \times 5$, and C is $5 \times 60$, then (A * B) * C costs $10 \cdot 30 \cdot 5 + 10 \cdot 5 \cdot 60 = 4500$ multiplications, whereas A * (B * C) costs $30 \cdot 5 \cdot 60 + 10 \cdot 30 \cdot 60 = 27000$. The algorithm works bottom-up, solving smaller subchains first and combining their results to solve the whole chain. Write matrix $A_i$ as having dimensions $p_{i-1} \times p_i$. The algorithm fills a cost table $m$ in which $m[i][j]$ is the minimum number of scalar multiplications needed to compute $A_i \cdots A_j$, using $m[i][i] = 0$ and $m[i][j] = \min_{i \le k < j} \left( m[i][k] + m[k+1][j] + p_{i-1} p_k p_j \right)$. Alongside it, a split table $s$ records, for each subchain, the index $k$ that achieves this minimum. The optimal parenthesization of the whole chain is then recovered by backtracking through $s$, starting from $s[1][n]$.
package DynamicProgramming;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Scanner;

/**
 * Matrix chain multiplication solved with bottom-up dynamic programming.
 *
 * <p>{@link #main(String[])} is an interactive console demo. It reads the dimensions of each
 * matrix in the chain, two integers per line such as {@code 10 20}. A first value of {@code 0},
 * or end of input, finishes the list. It then prints the cost table, the split table, the
 * minimum number of scalar multiplications, and an optimal parenthesization.
 *
 * <p>{@link #minimumCost(int[])} and {@link #optimalParenthesization(int[])} expose the same
 * computation without console I/O or shared state.
 */
public class MatrixChainMultiplication {
    private static final Scanner scan = new Scanner(System.in);
    private static final ArrayList<Matrix> mArray = new ArrayList<>();
    // Number of matrices in the chain read by main().
    private static int size;
    // m[i][j]: minimum scalar multiplications for A(i)..A(j), 1-based; -1 marks unused cells.
    // long rather than int: products of realistic dimensions overflow 32 bits.
    private static long[][] m;
    // s[i][j]: split index k of an optimal product (A(i)..A(k)) * (A(k+1)..A(j)); -1 if unused.
    private static int[][] s;
    // Dimension vector: A(i) is p[i - 1] x p[i].
    private static int[] p;

    /**
     * Reads a matrix chain from standard input and prints its optimal multiplication order.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int count = 1;
        while (true) {
            String[] mSize = input("input size of matrix A(" + count + ") ( ex. 10 20 ) : ");
            if (mSize.length == 0) {
                break; // end of input: behave like the terminating 0
            }
            int col;
            int row;
            try {
                col = Integer.parseInt(mSize[0]);
                if (col == 0) {
                    break;
                }
                // A missing second value is reported as invalid below instead of crashing.
                row = mSize.length > 1 ? Integer.parseInt(mSize[1]) : -1;
            } catch (NumberFormatException e) {
                col = -1;
                row = -1;
            }
            if (col <= 0 || row <= 0) {
                System.out.println(
                        "invalid size: enter two positive integers ( ex. 10 20 ), or 0 to finish");
                continue;
            }
            mArray.add(new Matrix(count, col, row));
            count++;
        }
        for (Matrix matrix : mArray) {
            System.out.format("A(%d)  =  %2d  x  %2d%n", matrix.count(), matrix.col(), matrix.row());
        }

        if (mArray.isEmpty()) {
            System.out.println("no matrices entered");
            return;
        }
        // Adjacent matrices must share a dimension. Otherwise the chain cannot be multiplied,
        // and the dimension vector built below would silently describe a different chain.
        for (int i = 1; i < mArray.size(); i++) {
            Matrix left = mArray.get(i - 1);
            Matrix right = mArray.get(i);
            if (left.row() != right.col()) {
                System.out.format("A(%d) (%d x %d) cannot be multiplied by A(%d) (%d x %d)%n",
                        left.count(), left.col(), left.row(),
                        right.count(), right.col(), right.row());
                return;
            }
        }

        size = mArray.size();
        m = new long[size + 1][size + 1];
        s = new int[size + 1][size + 1];
        p = new int[size + 1];

        for (int i = 0; i < size + 1; i++) {
            Arrays.fill(m[i], -1);
            Arrays.fill(s[i], -1);
        }

        for (int i = 0; i < p.length; i++) {
            p[i] = i == 0 ? mArray.get(i).col() : mArray.get(i - 1).row();
        }

        matrixChainOrder();
        printSeparator();
        printArray(m);
        printSeparator();
        printArray(s);
        printSeparator();

        System.out.println("Optimal solution : " + m[1][size]);
        System.out.print("Optimal parens : ");
        printOptimalParens(1, size);
        System.out.println();
    }

    /**
     * Returns the minimum number of scalar multiplications needed to multiply a matrix chain.
     *
     * @param dims dimension vector of length n + 1, where matrix A(i) is dims[i - 1] x dims[i]
     * @return the minimum cost; 0 for a chain of zero or one matrices
     * @throws NullPointerException if {@code dims} is null
     * @throws IllegalArgumentException if {@code dims} is empty or has a non-positive entry
     */
    public static long minimumCost(int[] dims) {
        int n = checkDimensions(dims);
        if (n == 0) {
            return 0;
        }
        long[][] cost = new long[n + 1][n + 1];
        fillTables(dims, cost, new int[n + 1][n + 1]);
        return cost[1][n];
    }

    /**
     * Returns an optimal parenthesization such as {@code ((A1A2)A3)}.
     *
     * @param dims dimension vector of length n + 1, where matrix A(i) is dims[i - 1] x dims[i]
     * @return the parenthesization; {@code ""} for an empty chain, {@code "A1"} for one matrix
     * @throws NullPointerException if {@code dims} is null
     * @throws IllegalArgumentException if {@code dims} is empty or has a non-positive entry
     */
    public static String optimalParenthesization(int[] dims) {
        int n = checkDimensions(dims);
        if (n == 0) {
            return "";
        }
        int[][] split = new int[n + 1][n + 1];
        fillTables(dims, new long[n + 1][n + 1], split);
        StringBuilder out = new StringBuilder();
        appendParens(out, split, 1, n);
        return out.toString();
    }

    // Validates a dimension vector and returns the number of matrices it describes.
    private static int checkDimensions(int[] dims) {
        if (dims == null) {
            throw new NullPointerException("dims");
        }
        if (dims.length == 0) {
            throw new IllegalArgumentException("dims must contain at least one dimension");
        }
        for (int d : dims) {
            if (d <= 0) {
                throw new IllegalArgumentException(
                        "dimensions must be positive: " + Arrays.toString(dims));
            }
        }
        return dims.length - 1;
    }

    // Prints the optimal parenthesization of A(i)..A(j) recorded in the static split table s.
    private static void printOptimalParens(int i, int j) {
        StringBuilder out = new StringBuilder();
        appendParens(out, s, i, j);
        System.out.print(out);
    }

    // Appends the parenthesization of A(i)..A(j) by recursing on the recorded split points.
    private static void appendParens(StringBuilder out, int[][] split, int i, int j) {
        if (i == j) {
            out.append('A').append(i);
            return;
        }
        out.append('(');
        appendParens(out, split, i, split[i][j]);
        appendParens(out, split, split[i][j] + 1, j);
        out.append(')');
    }

    // Prints one dashed rule sized to the table width.
    private static void printSeparator() {
        for (int i = 0; i < size; i++) {
            System.out.print("-------");
        }
        System.out.println();
    }

    // Prints the 1-based size x size part of the cost table.
    private static void printArray(long[][] array) {
        for (int i = 1; i < size + 1; i++) {
            for (int j = 1; j < size + 1; j++) {
                System.out.print(String.format("%7d", array[i][j]));
            }
            System.out.println();
        }
    }

    // Prints the 1-based size x size part of the split table.
    private static void printArray(int[][] array) {
        for (int i = 1; i < size + 1; i++) {
            for (int j = 1; j < size + 1; j++) {
                System.out.print(String.format("%7d", array[i][j]));
            }
            System.out.println();
        }
    }

    // Fills the static tables m and s from the static dimension vector p.
    private static void matrixChainOrder() {
        fillTables(p, m, s);
    }

    /**
     * Runs the bottom-up DP over chain lengths 2..n.
     *
     * <p>Uses cost[i][j] = min over i <= k < j of
     * cost[i][k] + cost[k+1][j] + dims[i-1]*dims[k]*dims[j].
     * The first minimizing k is stored in split[i][j]. Both tables must be at least
     * (n + 1) x (n + 1), where n = dims.length - 1.
     */
    private static void fillTables(int[] dims, long[][] cost, int[][] split) {
        int n = dims.length - 1;
        for (int i = 1; i <= n; i++) {
            cost[i][i] = 0;
        }
        for (int l = 2; l <= n; l++) {
            for (int i = 1; i <= n - l + 1; i++) {
                int j = i + l - 1;
                cost[i][j] = Long.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    // Widen before multiplying so the product cannot overflow int.
                    long q = cost[i][k] + cost[k + 1][j] + (long) dims[i - 1] * dims[k] * dims[j];
                    if (q < cost[i][j]) {
                        cost[i][j] = q;
                        split[i][j] = k;
                    }
                }
            }
        }
    }

    /**
     * Prints a prompt and reads one line, split on runs of whitespace.
     *
     * @return the tokens of the line, or an empty array when standard input is exhausted
     */
    private static String[] input(String string) {
        System.out.print(string);
        if (!scan.hasNextLine()) {
            return new String[0];
        }
        return scan.nextLine().trim().split("\\s+");
    }

}

/**
 * One matrix of the chain as entered by the user.
 *
 * <p>Naming note: {@code col} holds the first number entered and {@code row} the second. The
 * algorithm treats A(i) as {@code col x row}, so {@code col} is the leading dimension and
 * {@code row} the trailing one, the reverse of the usual rows x columns naming. The accessor
 * names are kept unchanged.
 */
class Matrix {
    private final int count;
    private final int col;
    private final int row;

    /**
     * @param count 1-based position of the matrix in the chain
     * @param col   first (leading) dimension
     * @param row   second (trailing) dimension
     */
    Matrix(int count, int col, int row) {
        this.count = count;
        this.col = col;
        this.row = row;
    }

    /** Returns the 1-based position of this matrix in the chain. */
    int count() {
        return count;
    }

    /** Returns the first (leading) dimension. */
    int col() {
        return col;
    }

    /** Returns the second (trailing) dimension. */
    int row() {
        return row;
    }
}

LANGUAGE:

DARK MODE: