Skip to content

Commit f7c0563

Browse files
François Dupiremaibin
authored andcommitted
dupirefr/[email protected] [BAEL-3606] Matrix Multiplication Libraries Comparison (eugenp#8298)
* Added benchmarking on larger matrices * [BAEL-3606] Moved benchmarking to production code * [BAEL-3606] Added minor fix
1 parent 1c5e524 commit f7c0563

File tree

4 files changed

+172
-3
lines changed

4 files changed

+172
-3
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package com.baeldung.matrices.benchmark;
2+
3+
import cern.colt.matrix.DoubleFactory2D;
4+
import cern.colt.matrix.DoubleMatrix2D;
5+
import cern.colt.matrix.linalg.Algebra;
6+
import com.baeldung.matrices.HomemadeMatrix;
7+
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
8+
import org.apache.commons.math3.linear.RealMatrix;
9+
import org.ejml.simple.SimpleMatrix;
10+
import org.la4j.Matrix;
11+
import org.la4j.matrix.dense.Basic2DMatrix;
12+
import org.nd4j.linalg.api.ndarray.INDArray;
13+
import org.nd4j.linalg.factory.Nd4j;
14+
import org.openjdk.jmh.annotations.Benchmark;
15+
import org.openjdk.jmh.annotations.Mode;
16+
import org.openjdk.jmh.runner.Runner;
17+
import org.openjdk.jmh.runner.options.ChainedOptionsBuilder;
18+
import org.openjdk.jmh.runner.options.OptionsBuilder;
19+
20+
import java.util.Arrays;
21+
import java.util.Map;
22+
import java.util.concurrent.TimeUnit;
23+
import java.util.stream.Collectors;
24+
25+
public class BigMatrixMultiplicationBenchmarking {
26+
private static final int DEFAULT_FORKS = 2;
27+
private static final int DEFAULT_WARMUP_ITERATIONS = 5;
28+
private static final int DEFAULT_MEASUREMENT_ITERATIONS = 10;
29+
30+
public static void main(String[] args) throws Exception {
31+
Map<String, String> parameters = parseParameters(args);
32+
33+
ChainedOptionsBuilder builder = new OptionsBuilder()
34+
.include(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
35+
.mode(Mode.AverageTime)
36+
.forks(forks(parameters))
37+
.warmupIterations(warmupIterations(parameters))
38+
.measurementIterations(measurementIterations(parameters))
39+
.timeUnit(TimeUnit.SECONDS);
40+
41+
parameters.forEach(builder::param);
42+
43+
new Runner(builder.build()).run();
44+
}
45+
46+
private static Map<String, String> parseParameters(String[] args) {
47+
return Arrays.stream(args)
48+
.map(arg -> arg.split("="))
49+
.collect(Collectors.toMap(
50+
arg -> arg[0],
51+
arg -> arg[1]
52+
));
53+
}
54+
55+
private static int forks(Map<String, String> parameters) {
56+
String forks = parameters.remove("forks");
57+
return parseOrDefault(forks, DEFAULT_FORKS);
58+
}
59+
60+
private static int warmupIterations(Map<String, String> parameters) {
61+
String warmups = parameters.remove("warmupIterations");
62+
return parseOrDefault(warmups, DEFAULT_WARMUP_ITERATIONS);
63+
}
64+
65+
private static int measurementIterations(Map<String, String> parameters) {
66+
String measurements = parameters.remove("measurementIterations");
67+
return parseOrDefault(measurements, DEFAULT_MEASUREMENT_ITERATIONS);
68+
}
69+
70+
private static int parseOrDefault(String parameter, int defaultValue) {
71+
return parameter != null ? Integer.parseInt(parameter) : defaultValue;
72+
}
73+
74+
@Benchmark
75+
public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) {
76+
return HomemadeMatrix.multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix());
77+
}
78+
79+
@Benchmark
80+
public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) {
81+
SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix());
82+
SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix());
83+
84+
return firstMatrix.mult(secondMatrix);
85+
}
86+
87+
@Benchmark
88+
public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) {
89+
RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix());
90+
RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix());
91+
92+
return firstMatrix.multiply(secondMatrix);
93+
}
94+
95+
@Benchmark
96+
public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
97+
Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix());
98+
Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix());
99+
100+
return firstMatrix.multiply(secondMatrix);
101+
}
102+
103+
@Benchmark
104+
public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
105+
INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix());
106+
INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix());
107+
108+
return firstMatrix.mmul(secondMatrix);
109+
}
110+
111+
@Benchmark
112+
public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) {
113+
DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;
114+
115+
DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix());
116+
DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix());
117+
118+
Algebra algebra = new Algebra();
119+
return algebra.mult(firstMatrix, secondMatrix);
120+
}
121+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.baeldung.matrices.benchmark;
2+
3+
import org.openjdk.jmh.annotations.Param;
4+
import org.openjdk.jmh.annotations.Scope;
5+
import org.openjdk.jmh.annotations.Setup;
6+
import org.openjdk.jmh.annotations.State;
7+
import org.openjdk.jmh.infra.BenchmarkParams;
8+
9+
import java.util.Random;
10+
import java.util.stream.DoubleStream;
11+
12+
@State(Scope.Benchmark)
13+
public class BigMatrixProvider {
14+
@Param({})
15+
private int matrixSize;
16+
private double[][] firstMatrix;
17+
private double[][] secondMatrix;
18+
19+
public BigMatrixProvider() {}
20+
21+
@Setup
22+
public void setup(BenchmarkParams parameters) {
23+
firstMatrix = createMatrix(matrixSize);
24+
secondMatrix = createMatrix(matrixSize);
25+
}
26+
27+
private double[][] createMatrix(int matrixSize) {
28+
Random random = new Random();
29+
30+
double[][] result = new double[matrixSize][matrixSize];
31+
for (int row = 0; row < result.length; row++) {
32+
for (int col = 0; col < result[row].length; col++) {
33+
result[row][col] = random.nextDouble();
34+
}
35+
}
36+
return result;
37+
}
38+
39+
public double[][] getFirstMatrix() {
40+
return firstMatrix;
41+
}
42+
43+
public double[][] getSecondMatrix() {
44+
return secondMatrix;
45+
}
46+
}

java-math/src/test/java/com/baeldung/matrices/MatrixMultiplicationBenchmarking.java renamed to java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixMultiplicationBenchmarking.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
package com.baeldung.matrices;
1+
package com.baeldung.matrices.benchmark;
22

33
import cern.colt.matrix.DoubleFactory2D;
44
import cern.colt.matrix.DoubleMatrix2D;
55
import cern.colt.matrix.linalg.Algebra;
6+
import com.baeldung.matrices.HomemadeMatrix;
67
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
78
import org.apache.commons.math3.linear.RealMatrix;
89
import org.ejml.simple.SimpleMatrix;
@@ -23,9 +24,10 @@ public class MatrixMultiplicationBenchmarking {
2324
public static void main(String[] args) throws Exception {
2425
Options opt = new OptionsBuilder()
2526
.include(MatrixMultiplicationBenchmarking.class.getSimpleName())
27+
.exclude(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
2628
.mode(Mode.AverageTime)
2729
.forks(2)
28-
.warmupIterations(5)
30+
.warmupIterations(10)
2931
.measurementIterations(10)
3032
.timeUnit(TimeUnit.MICROSECONDS)
3133
.build();

java-math/src/test/java/com/baeldung/matrices/MatrixProvider.java renamed to java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.baeldung.matrices;
1+
package com.baeldung.matrices.benchmark;
22

33
import org.openjdk.jmh.annotations.Scope;
44
import org.openjdk.jmh.annotations.State;

0 commit comments

Comments
 (0)