import java.io.FileNotFoundException; /** * */ /** * @author nieminen * */ public class DemoMNIST { /** * @param args * @throws FileNotFoundException */ public static void main(String[] args) throws FileNotFoundException { DataSet mnist = DataSet.createFromTextFile("mnist_reduced.dat",1); // In our reduced set, we took 750/250 first samples from the real MNIST. mnist.splitToTrainAndTest(.75); mnist.normalize(-1.0,1.0); int[] layerSizes = new int[]{mnist.getVecSize(),10,mnist.getNumOfClasses()}; SimpleMLP mlp = new SimpleMLP(layerSizes); for(int i=0;i<100;i++){ mlp.trainGD(mnist.trainInputs(), mnist.trainTargets(), 0.1, 100); System.out.println("Training set:"); int[] outclasses = mlp.classifyMatrix(mnist.trainInputs()); DataSet.printConfusion(outclasses,mnist.trainTargets()); System.out.println("Test set:"); outclasses = mlp.classifyMatrix(mnist.testInputs()); DataSet.printConfusion(outclasses,mnist.testTargets()); } } }