diff --git a/my-first-dl4j-project/pom.xml b/my-first-dl4j-project/pom.xml index 88713ebf..18aeae10 100644 --- a/my-first-dl4j-project/pom.xml +++ b/my-first-dl4j-project/pom.xml @@ -35,8 +35,8 @@ UTF-8 - 1.0.0-beta7 - 1.0.0-beta7 + 1.0.0-M1.1 + 1.0.0-M1.1 1.8 2.0.0-alpha1 2.4.3 @@ -55,14 +55,14 @@ - + - + @@ -73,6 +73,12 @@ ${dl4j.version} + + + org.deeplearning4j + deeplearning4j-ui + ${dl4j.version} + diff --git a/my-first-dl4j-project/src/main/java/ai/certifai/MyFirstDL4JProject.java b/my-first-dl4j-project/src/main/java/ai/certifai/MyFirstDL4JProject.java index 597152fb..f496560e 100644 --- a/my-first-dl4j-project/src/main/java/ai/certifai/MyFirstDL4JProject.java +++ b/my-first-dl4j-project/src/main/java/ai/certifai/MyFirstDL4JProject.java @@ -20,6 +20,7 @@ import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; +import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -28,6 +29,9 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.ui.api.UIServer; +import org.deeplearning4j.ui.model.stats.StatsListener; +import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -86,7 +90,12 @@ public static void main( String[] args ) throws Exception { MultiLayerNetwork network = new MultiLayerNetwork(config); network.init(); - network.setListeners(new ScoreIterationListener(1)); + + // Setup UI and listeners + UIServer server = UIServer.getInstance(); + StatsStorage storage = new InMemoryStatsStorage(); + server.attach(storage); + network.setListeners(new ScoreIterationListener(1), new StatsListener(storage)); network.fit(iterator, 10);