Skip to content

Commit 148e421

Browse files
YassinHajajmaibin
authored andcommitted
BAEL-3204 (eugenp#8085)
* BAEL-3204 * BAEL-3204
1 parent a602f92 commit 148e421

File tree

8 files changed

+384
-0
lines changed

8 files changed

+384
-0
lines changed

machine-learning/pom.xml

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
2+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3+
<modelVersion>4.0.0</modelVersion>
4+
5+
<artifactId>machine-learning</artifactId>
6+
<version>1.0-SNAPSHOT</version>
7+
<name>Supervised Learning</name>
8+
<packaging>jar</packaging>
9+
10+
<parent>
11+
<groupId>com.baeldung</groupId>
12+
<artifactId>parent-modules</artifactId>
13+
<version>1.0.0-SNAPSHOT</version>
14+
</parent>
15+
16+
<properties>
17+
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
18+
<maven.compiler.source>1.7</maven.compiler.source>
19+
<maven.compiler.target>1.7</maven.compiler.target>
20+
<kotlin.version>1.3.50</kotlin.version>
21+
<dl4j.version>0.9.1</dl4j.version>
22+
</properties>
23+
24+
<dependencies>
25+
<dependency>
26+
<groupId>org.jetbrains.kotlin</groupId>
27+
<artifactId>kotlin-stdlib-jdk8</artifactId>
28+
<version>${kotlin.version}</version>
29+
</dependency>
30+
<dependency>
31+
<groupId>org.nd4j</groupId>
32+
<artifactId>nd4j-native-platform</artifactId>
33+
<version>${dl4j.version}</version>
34+
</dependency>
35+
<dependency>
36+
<groupId>org.deeplearning4j</groupId>
37+
<artifactId>deeplearning4j-core</artifactId>
38+
<version>${dl4j.version}</version>
39+
</dependency>
40+
<dependency>
41+
<groupId>org.jetbrains.kotlin</groupId>
42+
<artifactId>kotlin-stdlib-jdk8</artifactId>
43+
<version>${kotlin.version}</version>
44+
</dependency>
45+
<dependency>
46+
<groupId>org.jetbrains.kotlin</groupId>
47+
<artifactId>kotlin-test</artifactId>
48+
<version>${kotlin.version}</version>
49+
<scope>test</scope>
50+
</dependency>
51+
<dependency>
52+
<groupId>org.jetbrains.kotlin</groupId>
53+
<artifactId>kotlin-stdlib-jdk8</artifactId>
54+
<version>${kotlin.version}</version>
55+
</dependency>
56+
</dependencies>
57+
<build>
58+
<sourceDirectory>src/main/kotlin</sourceDirectory>
59+
<testSourceDirectory>src/test</testSourceDirectory>
60+
<pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
61+
<plugins>
62+
<!-- clean lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#clean_Lifecycle -->
63+
<plugin>
64+
<artifactId>maven-clean-plugin</artifactId>
65+
<version>3.1.0</version>
66+
</plugin>
67+
<!-- default lifecycle, jar packaging: see https://maven.apache.org/ref/current/maven-core/default-bindings.html#Plugin_bindings_for_jar_packaging -->
68+
<plugin>
69+
<artifactId>maven-resources-plugin</artifactId>
70+
<version>3.0.2</version>
71+
</plugin>
72+
<plugin>
73+
<artifactId>maven-compiler-plugin</artifactId>
74+
<version>3.8.0</version>
75+
</plugin>
76+
<plugin>
77+
<artifactId>maven-surefire-plugin</artifactId>
78+
<version>2.22.1</version>
79+
</plugin>
80+
<plugin>
81+
<artifactId>maven-jar-plugin</artifactId>
82+
<version>3.0.2</version>
83+
</plugin>
84+
<plugin>
85+
<artifactId>maven-install-plugin</artifactId>
86+
<version>2.5.2</version>
87+
</plugin>
88+
<plugin>
89+
<artifactId>maven-deploy-plugin</artifactId>
90+
<version>2.8.2</version>
91+
</plugin>
92+
<!-- site lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#site_Lifecycle -->
93+
<plugin>
94+
<artifactId>maven-site-plugin</artifactId>
95+
<version>3.7.1</version>
96+
</plugin>
97+
<plugin>
98+
<artifactId>maven-project-info-reports-plugin</artifactId>
99+
<version>3.0.0</version>
100+
</plugin>
101+
</plugins>
102+
</pluginManagement>
103+
<plugins>
104+
<plugin>
105+
<groupId>org.jetbrains.kotlin</groupId>
106+
<artifactId>kotlin-maven-plugin</artifactId>
107+
<version>${kotlin.version}</version>
108+
<executions>
109+
<execution>
110+
<id>compile</id>
111+
<phase>compile</phase>
112+
<goals>
113+
<goal>compile</goal>
114+
</goals>
115+
</execution>
116+
<execution>
117+
<id>test-compile</id>
118+
<phase>test-compile</phase>
119+
<goals>
120+
<goal>test-compile</goal>
121+
</goals>
122+
</execution>
123+
</executions>
124+
<configuration>
125+
<jvmTarget>1.8</jvmTarget>
126+
</configuration>
127+
</plugin>
128+
<plugin>
129+
<groupId>org.apache.maven.plugins</groupId>
130+
<artifactId>maven-compiler-plugin</artifactId>
131+
<executions>
132+
<execution>
133+
<id>compile</id>
134+
<phase>compile</phase>
135+
<goals>
136+
<goal>compile</goal>
137+
</goals>
138+
</execution>
139+
<execution>
140+
<id>testCompile</id>
141+
<phase>test-compile</phase>
142+
<goals>
143+
<goal>testCompile</goal>
144+
</goals>
145+
</execution>
146+
</executions>
147+
</plugin>
148+
</plugins>
149+
</build>
150+
</project>
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package com.baeldung.cnn
2+
3+
import org.datavec.api.records.reader.impl.collection.ListStringRecordReader
4+
import org.datavec.api.split.ListStringSplit
5+
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator
6+
import org.deeplearning4j.eval.Evaluation
7+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
8+
import org.deeplearning4j.nn.conf.inputs.InputType
9+
import org.deeplearning4j.nn.conf.layers.*
10+
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
11+
import org.deeplearning4j.nn.weights.WeightInit
12+
import org.nd4j.linalg.activations.Activation
13+
import org.nd4j.linalg.learning.config.Adam
14+
import org.nd4j.linalg.lossfunctions.LossFunctions
15+
16+
object ConvolutionalNeuralNetwork {
17+
18+
@JvmStatic
19+
fun main(args: Array<String>) {
20+
val dataset = ZalandoMNISTDataSet().load()
21+
dataset.shuffle()
22+
val trainDatasetIterator = createDatasetIterator(dataset.subList(0, 50_000))
23+
val testDatasetIterator = createDatasetIterator(dataset.subList(50_000, 60_000))
24+
25+
val cnn = buildCNN()
26+
learning(cnn, trainDatasetIterator)
27+
testing(cnn, testDatasetIterator)
28+
}
29+
30+
private fun createDatasetIterator(dataset: MutableList<List<String>>): RecordReaderDataSetIterator {
31+
val listStringRecordReader = ListStringRecordReader()
32+
listStringRecordReader.initialize(ListStringSplit(dataset))
33+
return RecordReaderDataSetIterator(listStringRecordReader, 128, 28 * 28, 10)
34+
}
35+
36+
private fun buildCNN(): MultiLayerNetwork {
37+
val multiLayerNetwork = MultiLayerNetwork(NeuralNetConfiguration.Builder()
38+
.seed(123)
39+
.l2(0.0005)
40+
.updater(Adam())
41+
.weightInit(WeightInit.XAVIER)
42+
.list()
43+
.layer(0, buildInitialConvolutionLayer())
44+
.layer(1, buildBatchNormalizationLayer())
45+
.layer(2, buildPoolingLayer())
46+
.layer(3, buildConvolutionLayer())
47+
.layer(4, buildBatchNormalizationLayer())
48+
.layer(5, buildPoolingLayer())
49+
.layer(6, buildDenseLayer())
50+
.layer(7, buildBatchNormalizationLayer())
51+
.layer(8, buildDenseLayer())
52+
.layer(9, buildOutputLayer())
53+
.setInputType(InputType.convolutionalFlat(28, 28, 1))
54+
.backprop(true)
55+
.build())
56+
multiLayerNetwork.init()
57+
return multiLayerNetwork
58+
}
59+
60+
private fun buildOutputLayer(): OutputLayer? {
61+
return OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
62+
.nOut(10)
63+
.activation(Activation.SOFTMAX)
64+
.build()
65+
}
66+
67+
private fun buildDenseLayer(): DenseLayer? {
68+
return DenseLayer.Builder().activation(Activation.RELU)
69+
.nOut(500)
70+
.dropOut(0.5)
71+
.build()
72+
}
73+
74+
private fun buildPoolingLayer(): SubsamplingLayer? {
75+
return SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
76+
.kernelSize(2, 2)
77+
.stride(2, 2)
78+
.build()
79+
}
80+
81+
private fun buildBatchNormalizationLayer() = BatchNormalization.Builder().build()
82+
83+
private fun buildConvolutionLayer(): ConvolutionLayer? {
84+
return ConvolutionLayer.Builder(5, 5)
85+
.stride(1, 1) // nIn need not specified in later layers
86+
.nOut(50)
87+
.activation(Activation.IDENTITY)
88+
.build()
89+
}
90+
91+
private fun buildInitialConvolutionLayer(): ConvolutionLayer? {
92+
return ConvolutionLayer.Builder(5, 5)
93+
.nIn(1)
94+
.stride(1, 1)
95+
.nOut(20)
96+
.activation(Activation.IDENTITY)
97+
.build()
98+
}
99+
100+
private fun learning(cnn: MultiLayerNetwork, trainSet: RecordReaderDataSetIterator) {
101+
for (i in 0 until 10) {
102+
cnn.fit(trainSet)
103+
}
104+
}
105+
106+
private fun testing(cnn: MultiLayerNetwork, testSet: RecordReaderDataSetIterator) {
107+
val evaluation = Evaluation(10)
108+
while (testSet.hasNext()) {
109+
val next = testSet.next()
110+
val output = cnn.output(next.features)
111+
evaluation.eval(next.labels, output)
112+
}
113+
114+
println(evaluation.stats())
115+
println(evaluation.confusionToString())
116+
}
117+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.baeldung.cnn
2+
3+
import java.io.File
4+
import java.nio.ByteBuffer
5+
import java.util.*
6+
import java.util.stream.Collectors
7+
import kotlin.streams.asStream
8+
9+
class ZalandoMNISTDataSet {
10+
private val OFFSET_SIZE = 4 //in bytes
11+
private val NUM_ITEMS_OFFSET = 4
12+
private val ITEMS_SIZE = 4
13+
private val ROWS = 28
14+
private val COLUMNS = 28
15+
private val IMAGE_OFFSET = 16
16+
private val IMAGE_SIZE = ROWS * COLUMNS
17+
18+
fun load(): MutableList<List<String>> {
19+
val labelsFile = File("machine-learning/src/main/resources/train-labels-idx1-ubyte")
20+
val imagesFile = File("machine-learning/src/main/resources/train-images-idx3-ubyte")
21+
22+
val labelBytes = labelsFile.readBytes()
23+
val imageBytes = imagesFile.readBytes()
24+
25+
val byteLabelCount = Arrays.copyOfRange(labelBytes, NUM_ITEMS_OFFSET, NUM_ITEMS_OFFSET + ITEMS_SIZE)
26+
val numberOfLabels = ByteBuffer.wrap(byteLabelCount).int
27+
28+
val list = mutableListOf<List<String>>()
29+
30+
for (i in 0 until numberOfLabels) {
31+
val label = labelBytes[OFFSET_SIZE + ITEMS_SIZE + i]
32+
val startBoundary = i * IMAGE_SIZE + IMAGE_OFFSET
33+
val endBoundary = i * IMAGE_SIZE + IMAGE_OFFSET + IMAGE_SIZE
34+
val imageData = Arrays.copyOfRange(imageBytes, startBoundary, endBoundary)
35+
36+
val imageDataList = imageData.iterator()
37+
.asSequence()
38+
.asStream().map { b -> b.toString() }
39+
.collect(Collectors.toList())
40+
imageDataList.add(label.toString())
41+
list.add(imageDataList)
42+
}
43+
return list
44+
}
45+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package com.baeldung.simplelinearregression
2+
3+
import kotlin.math.pow
4+
5+
class SimpleLinearRegression(private val xs: List<Int>, private val ys: List<Int>) {
6+
var slope: Double = 0.0
7+
var yIntercept: Double = 0.0
8+
9+
init {
10+
val covariance = calculateCovariance(xs, ys)
11+
val variance = calculateVariance(xs)
12+
slope = calculateSlope(covariance, variance)
13+
yIntercept = calculateYIntercept(ys, slope, xs)
14+
}
15+
16+
fun predict(independentVariable: Double) = slope * independentVariable + yIntercept
17+
18+
fun calculateRSquared(): Double {
19+
val sst = ys.sumByDouble { y -> (y - ys.average()).pow(2) }
20+
val ssr = xs.zip(ys) { x, y -> (y - predict(x.toDouble())).pow(2) }.sum()
21+
return (sst - ssr) / sst
22+
}
23+
24+
private fun calculateYIntercept(ys: List<Int>, slope: Double, xs: List<Int>) = ys.average() - slope * xs.average()
25+
26+
private fun calculateSlope(covariance: Double, variance: Double) = covariance / variance
27+
28+
private fun calculateCovariance(xs: List<Int>, ys: List<Int>) = xs.zip(ys) { x, y -> (x - xs.average()) * (y - ys.average()) }.sum()
29+
30+
private fun calculateVariance(xs: List<Int>) = xs.sumByDouble { x -> (x - xs.average()).pow(2) }
31+
}
Binary file not shown.
Binary file not shown.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package com.baeldung.simplelinearregression
2+
3+
import org.junit.Assert.assertEquals
4+
import org.junit.jupiter.api.Test
5+
6+
class SimpleLinearRegressionUnitTest {
7+
@Test
8+
fun givenAProperDataSetWhenFedToASimpleLinearRegressionModelThenItPredictsCorrectly() {
9+
val xs = arrayListOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
10+
val ys = arrayListOf(25, 35, 49, 60, 75, 90, 115, 130, 150, 200)
11+
12+
val model = SimpleLinearRegression(xs, ys)
13+
14+
val predictionOne = model.predict(2.5)
15+
assertEquals(38.99, predictionOne, 0.01)
16+
17+
val predictionTwo = model.predict(7.5)
18+
assertEquals(128.84, predictionTwo, 0.01)
19+
}
20+
21+
@Test
22+
fun givenAPredictableDataSetWhenCalculatingTheLossFunctionThenTheModelIsConsideredReliable() {
23+
val xs = arrayListOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
24+
val ys = arrayListOf(25, 35, 49, 60, 75, 90, 115, 130, 150, 200)
25+
26+
val model = SimpleLinearRegression(xs, ys)
27+
28+
assertEquals(0.95, model.calculateRSquared(), 0.01)
29+
}
30+
31+
@Test
32+
fun givenAnUnpredictableDataSetWhenCalculatingTheLossFunctionThenTheModelIsConsideredUnreliable() {
33+
val xs = arrayListOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
34+
val ys = arrayListOf(200, 0, 200, 0, 0, 0, -115, 1000, 0, 1)
35+
36+
val model = SimpleLinearRegression(xs, ys)
37+
38+
assertEquals(0.01, model.calculateRSquared(), 0.01)
39+
}
40+
}

pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@
625625

626626
<module>spring-boot-nashorn</module>
627627
<module>java-blockchain</module>
628+
<module>machine-learning</module>
628629
<module>wildfly</module>
629630
</modules>
630631

0 commit comments

Comments
 (0)