1
+ package com .baeldung .matrices .benchmark ;
2
+
3
+ import cern .colt .matrix .DoubleFactory2D ;
4
+ import cern .colt .matrix .DoubleMatrix2D ;
5
+ import cern .colt .matrix .linalg .Algebra ;
6
+ import com .baeldung .matrices .HomemadeMatrix ;
7
+ import org .apache .commons .math3 .linear .Array2DRowRealMatrix ;
8
+ import org .apache .commons .math3 .linear .RealMatrix ;
9
+ import org .ejml .simple .SimpleMatrix ;
10
+ import org .la4j .Matrix ;
11
+ import org .la4j .matrix .dense .Basic2DMatrix ;
12
+ import org .nd4j .linalg .api .ndarray .INDArray ;
13
+ import org .nd4j .linalg .factory .Nd4j ;
14
+ import org .openjdk .jmh .annotations .Benchmark ;
15
+ import org .openjdk .jmh .annotations .Mode ;
16
+ import org .openjdk .jmh .runner .Runner ;
17
+ import org .openjdk .jmh .runner .options .ChainedOptionsBuilder ;
18
+ import org .openjdk .jmh .runner .options .OptionsBuilder ;
19
+
20
+ import java .util .Arrays ;
21
+ import java .util .Map ;
22
+ import java .util .concurrent .TimeUnit ;
23
+ import java .util .stream .Collectors ;
24
+
25
+ public class BigMatrixMultiplicationBenchmarking {
26
+ private static final int DEFAULT_FORKS = 2 ;
27
+ private static final int DEFAULT_WARMUP_ITERATIONS = 5 ;
28
+ private static final int DEFAULT_MEASUREMENT_ITERATIONS = 10 ;
29
+
30
+ public static void main (String [] args ) throws Exception {
31
+ Map <String , String > parameters = parseParameters (args );
32
+
33
+ ChainedOptionsBuilder builder = new OptionsBuilder ()
34
+ .include (BigMatrixMultiplicationBenchmarking .class .getSimpleName ())
35
+ .mode (Mode .AverageTime )
36
+ .forks (forks (parameters ))
37
+ .warmupIterations (warmupIterations (parameters ))
38
+ .measurementIterations (measurementIterations (parameters ))
39
+ .timeUnit (TimeUnit .SECONDS );
40
+
41
+ parameters .forEach (builder ::param );
42
+
43
+ new Runner (builder .build ()).run ();
44
+ }
45
+
46
+ private static Map <String , String > parseParameters (String [] args ) {
47
+ return Arrays .stream (args )
48
+ .map (arg -> arg .split ("=" ))
49
+ .collect (Collectors .toMap (
50
+ arg -> arg [0 ],
51
+ arg -> arg [1 ]
52
+ ));
53
+ }
54
+
55
+ private static int forks (Map <String , String > parameters ) {
56
+ String forks = parameters .remove ("forks" );
57
+ return parseOrDefault (forks , DEFAULT_FORKS );
58
+ }
59
+
60
+ private static int warmupIterations (Map <String , String > parameters ) {
61
+ String warmups = parameters .remove ("warmupIterations" );
62
+ return parseOrDefault (warmups , DEFAULT_WARMUP_ITERATIONS );
63
+ }
64
+
65
+ private static int measurementIterations (Map <String , String > parameters ) {
66
+ String measurements = parameters .remove ("measurementIterations" );
67
+ return parseOrDefault (measurements , DEFAULT_MEASUREMENT_ITERATIONS );
68
+ }
69
+
70
+ private static int parseOrDefault (String parameter , int defaultValue ) {
71
+ return parameter != null ? Integer .parseInt (parameter ) : defaultValue ;
72
+ }
73
+
74
+ @ Benchmark
75
+ public Object homemadeMatrixMultiplication (BigMatrixProvider matrixProvider ) {
76
+ return HomemadeMatrix .multiplyMatrices (matrixProvider .getFirstMatrix (), matrixProvider .getSecondMatrix ());
77
+ }
78
+
79
+ @ Benchmark
80
+ public Object ejmlMatrixMultiplication (BigMatrixProvider matrixProvider ) {
81
+ SimpleMatrix firstMatrix = new SimpleMatrix (matrixProvider .getFirstMatrix ());
82
+ SimpleMatrix secondMatrix = new SimpleMatrix (matrixProvider .getSecondMatrix ());
83
+
84
+ return firstMatrix .mult (secondMatrix );
85
+ }
86
+
87
+ @ Benchmark
88
+ public Object apacheCommonsMatrixMultiplication (BigMatrixProvider matrixProvider ) {
89
+ RealMatrix firstMatrix = new Array2DRowRealMatrix (matrixProvider .getFirstMatrix ());
90
+ RealMatrix secondMatrix = new Array2DRowRealMatrix (matrixProvider .getSecondMatrix ());
91
+
92
+ return firstMatrix .multiply (secondMatrix );
93
+ }
94
+
95
+ @ Benchmark
96
+ public Object la4jMatrixMultiplication (BigMatrixProvider matrixProvider ) {
97
+ Matrix firstMatrix = new Basic2DMatrix (matrixProvider .getFirstMatrix ());
98
+ Matrix secondMatrix = new Basic2DMatrix (matrixProvider .getSecondMatrix ());
99
+
100
+ return firstMatrix .multiply (secondMatrix );
101
+ }
102
+
103
+ @ Benchmark
104
+ public Object nd4jMatrixMultiplication (BigMatrixProvider matrixProvider ) {
105
+ INDArray firstMatrix = Nd4j .create (matrixProvider .getFirstMatrix ());
106
+ INDArray secondMatrix = Nd4j .create (matrixProvider .getSecondMatrix ());
107
+
108
+ return firstMatrix .mmul (secondMatrix );
109
+ }
110
+
111
+ @ Benchmark
112
+ public Object coltMatrixMultiplication (BigMatrixProvider matrixProvider ) {
113
+ DoubleFactory2D doubleFactory2D = DoubleFactory2D .dense ;
114
+
115
+ DoubleMatrix2D firstMatrix = doubleFactory2D .make (matrixProvider .getFirstMatrix ());
116
+ DoubleMatrix2D secondMatrix = doubleFactory2D .make (matrixProvider .getSecondMatrix ());
117
+
118
+ Algebra algebra = new Algebra ();
119
+ return algebra .mult (firstMatrix , secondMatrix );
120
+ }
121
+ }
0 commit comments