44 changes: 28 additions & 16 deletions Cargo.toml
@@ -1,21 +1,33 @@
[package]

name = "nn"
version = "0.1.6"
authors = ["Jack Montgomery <jackm321@gmail.com>"]
repository = "https://github.com/jackm321/RustNN"
documentation = "https://jackm321.github.io/RustNN/doc/nn/"
license = "Apache-2.0"
readme = "README.md"
version = "0.6.0"
authors = ["https://github.com/jackm321/RustNN"]

description = """
A multilayer feedforward backpropagation neural network library
"""
[dependencies]
rand = "0.3.*"
serde = "1.*"
serde_derive = "1.*"
serde_json = "1.*"

keywords = ["nn", "neural-network", "classifier", "backpropagation",
"machine-learning"]

[dependencies]
rand = "0.3.7"
rustc-serialize = "0.3.12"
time = "0.1.24"

[profile.dev]
opt-level = 3
lto = true
panic = "unwind"
debug = true
debug-assertions = true

[profile.test]
opt-level = 0
lto = false
panic = "unwind"
debug = true
debug-assertions = true

[profile.release]
opt-level = 3
lto = true
panic = "unwind"
debug = false
debug-assertions = false
14 changes: 7 additions & 7 deletions README.md
@@ -1,12 +1,8 @@
# RustNN

[![Build Status](https://travis-ci.org/jackm321/RustNN.svg?branch=master)](https://travis-ci.org/jackm321/RustNN)

An easy-to-use neural network library written in Rust.

[Crate](https://crates.io/crates/nn)

[Documentation](https://jackm321.github.io/RustNN/doc/nn/)
For the documentation, see the original library or generate it locally with `cargo doc`.

## Description
RustNN is a [feedforward neural network ](http://en.wikipedia.org/wiki/Feedforward_neural_network)
@@ -15,6 +11,10 @@ generates fully connected multi-layer artificial neural networks that
are trained via [backpropagation](http://en.wikipedia.org/wiki/Backpropagation).
Networks are trained using an incremental training mode.

## Fork
This fork adds L2 regularization and several additional activation functions to the original crate, along with a few minor improvements.
The regularization strength lambda is set just like the learning rate, via the trainer. The activation functions for the hidden and output layers are passed to `NN::new` as the second and third parameters respectively.
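The fork's two additions described above can be sketched together as follows. This is a minimal sketch against the forked crate's API as shown in this README; the layer sizes, training data, rate, and lambda values are illustrative, and it assumes the forked `nn` crate is declared as a dependency:

```rust
extern crate nn;

use nn::{NN, HaltCondition, Activation};

fn main() {
    // second and third arguments pick the hidden- and output-layer
    // activation functions introduced by this fork
    let mut net = NN::new(&[2, 3, 1], Activation::PELU, Activation::Sigmoid);

    // XOR truth table as training examples
    let examples = [
        (vec![0.0, 0.0], vec![0.0]),
        (vec![0.0, 1.0], vec![1.0]),
        (vec![1.0, 0.0], vec![1.0]),
        (vec![1.0, 1.0], vec![0.0]),
    ];

    net.train(&examples)
        .rate(0.3)       // learning rate
        .lambda(0.0001)  // L2 regularization strength, set just like the rate
        .halt_condition(HaltCondition::MSE(0.01))
        .go();
}
```
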

## XOR example

This example creates a neural network with `2` nodes in the input layer,
@@ -27,7 +27,7 @@ given examples. See the documentation for the `NN` and `Trainer` structs
for more details.

```rust
use nn::{NN, HaltCondition};
use nn::{NN, HaltCondition, Activation};

// create examples of the XOR function
// the network is trained on tuples of vectors where the first vector
@@ -43,7 +43,7 @@ let examples = [
// that specifies the number of layers and the number of nodes in each layer
// in this case we have an input layer with 2 nodes, one hidden layer
// with 3 nodes and the output layer has 1 node
let mut net = NN::new(&[2, 3, 1]);
let mut net = NN::new(&[2, 3, 1], Activation::PELU, Activation::Sigmoid);

// train the network on the examples of the XOR function
// all methods seen here are optional except go() which must be called to begin training
```
43 changes: 43 additions & 0 deletions examples/selector.rs
@@ -0,0 +1,43 @@
extern crate nn;

use nn::{NN, HaltCondition, Activation};

const ACTIONS: u32 = 10;

fn main() {
    // create one-hot training examples: input i maps to an output
    // vector with a 1.0 at index i and 0.0 everywhere else
    let mut examples = Vec::new();
    for i in 0..ACTIONS {
        let mut result = Vec::new();
        for j in 0..ACTIONS {
            if j == i { result.push(1.0); } else { result.push(0.0); }
        }
        let example = (vec![i as f64], result);
        examples.push(example);
    }

    // create a new neural network: 1 input node, a hidden layer of
    // 10 nodes, and one output node per action
    let mut nn = NN::new(&[1, 10, ACTIONS], Activation::PELU, Activation::Sigmoid);

    // train the network
    nn.train(&examples)
        .log_interval(Some(1000))
        .halt_condition(HaltCondition::MSE(0.01))
        .rate(0.025)
        .momentum(0.5)
        .lambda(0.00005)
        .go();

    // print results of the trained network
    for &(ref input, _) in examples.iter() {
        let result = nn.run(input);
        let print: Vec<String> = result.iter()
            .map(|x: &f64| format!("{:4.2}", (*x * 100.0).round() / 100.0))
            .collect();
        println!("{:1.0} -> {:?}", input[0], print);
    }
}