Skip to content

Commit 468c8e0

Browse files
authored
Merge pull request eugenp#7988 from alessiostalla/BAEL-18260
#BAEL-18260 Restructure ml and deeplearning4j modules
2 parents 2a746e6 + 8956aee commit 468c8e0

File tree

10 files changed

+42
-95
lines changed

10 files changed

+42
-95
lines changed

deeplearning4j/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## Deeplearning4j
22

3-
This module contains articles about Deeplearning4j
3+
This module contains articles about Deeplearning4j.
44

55
### Relevant Articles:
66
- [A Guide to Deeplearning4j](https://www.baeldung.com/deeplearning4j)
7+
- [Logistic Regression in Java](https://www.baeldung.com/java-logistic-regression)

deeplearning4j/pom.xml

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
</parent>
1515

1616
<dependencies>
17+
<dependency>
18+
<groupId>org.nd4j</groupId>
19+
<artifactId>nd4j-api</artifactId>
20+
<version>${dl4j.version}</version>
21+
</dependency>
1722
<dependency>
1823
<groupId>org.nd4j</groupId>
1924
<artifactId>nd4j-native-platform</artifactId>
@@ -24,10 +29,26 @@
2429
<artifactId>deeplearning4j-core</artifactId>
2530
<version>${dl4j.version}</version>
2631
</dependency>
32+
<dependency>
33+
<groupId>org.deeplearning4j</groupId>
34+
<artifactId>deeplearning4j-nn</artifactId>
35+
<version>${dl4j.version}</version>
36+
</dependency>
37+
<!-- https://mvnrepository.com/artifact/org.datavec/datavec-api -->
38+
<dependency>
39+
<groupId>org.datavec</groupId>
40+
<artifactId>datavec-api</artifactId>
41+
<version>${dl4j.version}</version>
42+
</dependency>
43+
<dependency>
44+
<groupId>org.apache.httpcomponents</groupId>
45+
<artifactId>httpclient</artifactId>
46+
<version>4.3.5</version>
47+
</dependency>
2748
</dependencies>
2849

2950
<properties>
30-
<dl4j.version>0.9.1</dl4j.version>
51+
<dl4j.version>0.9.1</dl4j.version> <!-- Latest non beta version -->
3152
</properties>
3253

33-
</project>
54+
</project>

deeplearning4j/src/main/java/com/baeldung/deeplearning4j/IrisClassifier.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import org.datavec.api.records.reader.RecordReader;
44
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
55
import org.datavec.api.split.FileSplit;
6-
import org.datavec.api.util.ClassPathResource;
76
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
87
import org.deeplearning4j.eval.Evaluation;
8+
import org.deeplearning4j.nn.conf.BackpropType;
99
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
1010
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
1111
import org.deeplearning4j.nn.conf.layers.DenseLayer;
@@ -19,6 +19,7 @@
1919
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
2020
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
2121
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
22+
import org.nd4j.linalg.io.ClassPathResource;
2223
import org.nd4j.linalg.lossfunctions.LossFunctions;
2324

2425
import java.io.IOException;
@@ -52,8 +53,8 @@ public static void main(String[] args) throws IOException, InterruptedException
5253
.iterations(1000)
5354
.activation(Activation.TANH)
5455
.weightInit(WeightInit.XAVIER)
55-
.learningRate(0.1)
56-
.regularization(true).l2(0.0001)
56+
.regularization(true)
57+
.learningRate(0.1).l2(0.0001)
5758
.list()
5859
.layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3)
5960
.build())
@@ -62,14 +63,14 @@ public static void main(String[] args) throws IOException, InterruptedException
6263
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
6364
.activation(Activation.SOFTMAX)
6465
.nIn(3).nOut(CLASSES_COUNT).build())
65-
.backprop(true).pretrain(false)
66+
.backpropType(BackpropType.Standard).pretrain(false)
6667
.build();
6768

6869
MultiLayerNetwork model = new MultiLayerNetwork(configuration);
6970
model.init();
7071
model.fit(trainingData);
7172

72-
INDArray output = model.output(testData.getFeatureMatrix());
73+
INDArray output = model.output(testData.getFeatures());
7374

7475
Evaluation eval = new Evaluation(CLASSES_COUNT);
7576
eval.eval(testData.getLabels(), output);

ml/src/main/java/com/baeldung/logreg/MnistClassifier.java renamed to deeplearning4j/src/main/java/com/baeldung/logreg/MnistClassifier.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.datavec.image.loader.NativeImageLoader;
1111
import org.datavec.image.recordreader.ImageRecordReader;
1212
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
13+
import org.deeplearning4j.eval.Evaluation;
1314
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
1415
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
1516
import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -21,15 +22,12 @@
2122
import org.deeplearning4j.nn.weights.WeightInit;
2223
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
2324
import org.deeplearning4j.util.ModelSerializer;
24-
import org.nd4j.evaluation.classification.Evaluation;
2525
import org.nd4j.linalg.activations.Activation;
2626
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
2727
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
2828
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
2929
import org.nd4j.linalg.learning.config.Nesterovs;
3030
import org.nd4j.linalg.lossfunctions.LossFunctions;
31-
import org.nd4j.linalg.schedule.MapSchedule;
32-
import org.nd4j.linalg.schedule.ScheduleType;
3331
import org.slf4j.Logger;
3432
import org.slf4j.LoggerFactory;
3533

@@ -44,7 +42,7 @@
4442

4543
public class MnistClassifier {
4644
private static final Logger logger = LoggerFactory.getLogger(MnistClassifier.class);
47-
private static final String basePath = System.getProperty("java.io.tmpdir") + "mnist" + File.separator;
45+
private static final String basePath = System.getProperty("java.io.tmpdir") + File.separator + "mnist" + File.separator;
4846
private static final File modelPath = new File(basePath + "mnist-model.zip");
4947
private static final String dataUrl = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
5048

@@ -71,8 +69,7 @@ public static void main(String[] args) throws Exception {
7169
String localFilePath = basePath + "mnist_png.tar.gz";
7270
File file = new File(localFilePath);
7371
if (!file.exists()) {
74-
file.getParentFile()
75-
.mkdirs();
72+
file.getParentFile().mkdirs();
7673
Utils.downloadAndSave(dataUrl, file);
7774
Utils.extractTarArchive(file, basePath);
7875
}
@@ -135,15 +132,15 @@ public static void main(String[] args) throws Exception {
135132
.build();
136133
final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(seed)
137134
.l2(0.0005) // ridge regression value
138-
.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
135+
.updater(new Nesterovs()) //TODO new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)
139136
.weightInit(WeightInit.XAVIER)
140137
.list()
141-
.layer(layer1)
142-
.layer(layer2)
143-
.layer(layer3)
144-
.layer(layer2)
145-
.layer(layer4)
146-
.layer(layer5)
138+
.layer(0, layer1)
139+
.layer(1, layer2)
140+
.layer(2, layer3)
141+
.layer(3, layer2)
142+
.layer(4, layer4)
143+
.layer(5, layer5)
147144
.setInputType(InputType.convolutionalFlat(height, width, channels))
148145
.build();
149146

@@ -165,4 +162,4 @@ public static void main(String[] args) throws Exception {
165162
ModelSerializer.writeModel(model, modelPath, true);
166163
logger.info("The MINIST model has been saved in {}", modelPath.getPath());
167164
}
168-
}
165+
}

ml/README.md

Lines changed: 0 additions & 6 deletions
This file was deleted.

ml/pom.xml

Lines changed: 0 additions & 52 deletions
This file was deleted.

ml/src/main/resources/logback.xml

Lines changed: 0 additions & 13 deletions
This file was deleted.

pom.xml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,6 @@
563563
<module>metrics</module>
564564
<!-- <module>micronaut</module> --> <!-- Fixing in BAEL-10877 -->
565565
<module>microprofile</module>
566-
<module>ml</module>
567566
<module>msf4j</module>
568567
<!-- <module>muleesb</module> --> <!-- Fixing in BAEL-10878 -->
569568
<module>mustache</module>
@@ -1322,7 +1321,6 @@
13221321
<module>metrics</module>
13231322
<!-- <module>micronaut</module> --> <!-- Fixing in BAEL-10877 -->
13241323
<module>microprofile</module>
1325-
<module>ml</module>
13261324
<module>msf4j</module>
13271325
<!-- <module>muleesb</module> --> <!-- Fixing in BAEL-10878 -->
13281326
<module>mustache</module>

0 commit comments

Comments
 (0)