diff --git a/mnist/main.py b/mnist/main.py index 7d7899d9..c03c35ca 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -22,6 +22,7 @@ def forward(self, x): x = self.conv1(x) x = F.relu(x) x = self.conv2(x) + x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1) @@ -76,8 +77,8 @@ def main(): help='input batch size for training (default: 64)') parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=14, metavar='N', - help='number of epochs to train (default: 10)') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 14)') parser.add_argument('--lr', type=float, default=1.0, metavar='LR', help='learning rate (default: 1.0)') parser.add_argument('--gamma', type=float, default=0.7, metavar='M', diff --git a/pytorch-example b/pytorch-example new file mode 160000 index 00000000..787e7533 --- /dev/null +++ b/pytorch-example @@ -0,0 +1 @@ +Subproject commit 787e75331fedbd3331a96f4ce64be609ca5df2e7