3 changes: 3 additions & 0 deletions examples/.gitignore
@@ -16,3 +16,6 @@ cascade_train_debug
xor_test_debug
xor_test_fixed_debug
xor_train_debug
robot_adam
robot_adam.net

4 changes: 2 additions & 2 deletions examples/Makefile
@@ -4,8 +4,8 @@ GCC=gcc
CFLAGS=-I ../src/include
LDFLAGS=-L ../src/

TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train
DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug
TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train robot_adam
DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug robot_adam_debug

all: $(TARGETS)

72 changes: 72 additions & 0 deletions examples/robot_adam.c
@@ -0,0 +1,72 @@
/*
Fast Artificial Neural Network Library (fann)
Copyright (C) 2003-2016 Steffen Nissen (steffen.fann@gmail.com)

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/

#include <stdio.h>

#include "fann.h"

int main()
{
    const unsigned int num_layers = 3;
    const unsigned int num_neurons_hidden = 96;
    const float desired_error = (const float) 0.00001;
    struct fann *ann;
    struct fann_train_data *train_data, *test_data;

    unsigned int i = 0;

    printf("Creating network.\n");

    train_data = fann_read_train_from_file("../datasets/robot.train");

    ann = fann_create_standard(num_layers,
                               train_data->num_input, num_neurons_hidden, train_data->num_output);

    printf("Training network.\n");

    fann_set_training_algorithm(ann, FANN_TRAIN_ADAM);
    fann_set_adam_beta1(ann, 0.9f);
    fann_set_adam_beta2(ann, 0.999f);
    fann_set_adam_epsilon(ann, 1e-8f);
    fann_set_learning_rate(ann, 0.05f);

    fann_train_on_data(ann, train_data, 10000, 100, desired_error);

    printf("Testing network.\n");

    test_data = fann_read_train_from_file("../datasets/robot.test");

    fann_reset_MSE(ann);
    for(i = 0; i < fann_length_train_data(test_data); i++)
    {
        fann_test(ann, test_data->input[i], test_data->output[i]);
    }
    printf("MSE error on test data: %f\n", fann_get_MSE(ann));

    printf("Saving network.\n");

    fann_save(ann, "robot_adam.net");

    printf("Cleaning up.\n");
    fann_destroy_train(train_data);
    fann_destroy_train(test_data);
    fann_destroy(ann);

    return 0;
}
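Not part of the diff, but as a usage note: after "make robot_adam" in examples/, the run leaves robot_adam.net behind, which can be loaded back with FANN's standard I/O API. A minimal sketch using only existing FANN calls:

#include <stdio.h>

#include "fann.h"

int main()
{
    struct fann *ann = fann_create_from_file("robot_adam.net");
    struct fann_train_data *test = fann_read_train_from_file("../datasets/robot.test");
    fann_type *output = fann_run(ann, test->input[0]); /* run the first test sample */
    printf("output[0] = %f\n", output[0]);
    fann_destroy_train(test);
    fann_destroy(ann);
    return 0;
}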
44 changes: 44 additions & 0 deletions src/fann.c
@@ -748,6 +748,8 @@ FANN_EXTERNAL void FANN_API fann_destroy(struct fann *ann) {
  fann_safe_free(ann->prev_train_slopes);
  fann_safe_free(ann->prev_steps);
  fann_safe_free(ann->prev_weights_deltas);
  fann_safe_free(ann->adam_m);
  fann_safe_free(ann->adam_v);
  fann_safe_free(ann->errstr);
  fann_safe_free(ann->cascade_activation_functions);
  fann_safe_free(ann->cascade_activation_steepnesses);
@@ -888,6 +890,14 @@ FANN_EXTERNAL struct fann *FANN_API fann_copy(struct fann *orig) {
  copy->rprop_delta_max = orig->rprop_delta_max;
  copy->rprop_delta_zero = orig->rprop_delta_zero;

  /* Copy Adam optimizer parameters */
  copy->adam_beta1 = orig->adam_beta1;
  copy->adam_beta2 = orig->adam_beta2;
  copy->adam_epsilon = orig->adam_epsilon;
  copy->adam_timestep = orig->adam_timestep;
  copy->adam_m = NULL;
  copy->adam_v = NULL;

  /* user_data is not deep copied. user should use fann_copy_with_user_data() for that */
  copy->user_data = orig->user_data;

@@ -1008,6 +1018,29 @@ FANN_EXTERNAL struct fann *FANN_API fann_copy(struct fann *orig) {
           copy->total_connections_allocated * sizeof(fann_type));
  }

  /* Copy Adam optimizer moment vectors if they exist */
  if (orig->adam_m) {
    copy->adam_m = (fann_type *)malloc(copy->total_connections_allocated * sizeof(fann_type));
    if (copy->adam_m == NULL) {
      fann_error((struct fann_error *)orig, FANN_E_CANT_ALLOCATE_MEM);
      fann_destroy(copy);
      return NULL;
    }
    memcpy(copy->adam_m, orig->adam_m,
           copy->total_connections_allocated * sizeof(fann_type));
  }

  if (orig->adam_v) {
    copy->adam_v = (fann_type *)malloc(copy->total_connections_allocated * sizeof(fann_type));
    if (copy->adam_v == NULL) {
      fann_error((struct fann_error *)orig, FANN_E_CANT_ALLOCATE_MEM);
      fann_destroy(copy);
      return NULL;
    }
    memcpy(copy->adam_v, orig->adam_v,
           copy->total_connections_allocated * sizeof(fann_type));
  }

  return copy;
}

@@ -1171,6 +1204,9 @@ FANN_EXTERNAL void FANN_API fann_print_parameters(struct fann *ann) {
printf("RPROP decrease factor :%8.3f\n", ann->rprop_decrease_factor);
printf("RPROP delta min :%8.3f\n", ann->rprop_delta_min);
printf("RPROP delta max :%8.3f\n", ann->rprop_delta_max);
printf("Adam beta1 :%f\n", ann->adam_beta1);
printf("Adam beta2 :%f\n", ann->adam_beta2);
printf("Adam epsilon :%.8f\n", ann->adam_epsilon);
printf("Cascade output change fraction :%11.6f\n", ann->cascade_output_change_fraction);
printf("Cascade candidate change fraction :%11.6f\n", ann->cascade_candidate_change_fraction);
printf("Cascade output stagnation epochs :%4d\n", ann->cascade_output_stagnation_epochs);
@@ -1552,6 +1588,14 @@ struct fann *fann_allocate_structure(unsigned int num_layers) {
  ann->sarprop_temperature = 0.015f;
  ann->sarprop_epoch = 0;

  /* Variables for use with Adam training (reasonable defaults) */
  ann->adam_m = NULL;
  ann->adam_v = NULL;
  ann->adam_beta1 = 0.9f;
  ann->adam_beta2 = 0.999f;
  ann->adam_epsilon = 1e-8f;
  ann->adam_timestep = 0;

  fann_init_error_data((struct fann_error *)ann);

#ifdef FIXEDFANN
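Taken together, the fann.c changes give the Adam state the same lifecycle as FANN's other training buffers: NULL defaults in fann_allocate_structure, a deep copy in fann_copy, and a free in fann_destroy (allocation itself happens lazily in fann_clear_train_arrays, further down in this diff). A minimal sketch of what the deep-copy guarantee buys (function body only; error checks omitted, and the dataset path is a placeholder):

struct fann_train_data *data = fann_read_train_from_file("some.data"); /* placeholder path */
struct fann *ann = fann_create_standard(3, data->num_input, 8, data->num_output);

fann_set_training_algorithm(ann, FANN_TRAIN_ADAM);
fann_train_epoch(ann, data);        /* lazily allocates adam_m / adam_v */

struct fann *copy = fann_copy(ann); /* adam_m, adam_v, and adam_timestep are copied */
fann_destroy(ann);                  /* the copy keeps its own moment vectors */

fann_train_epoch(copy, data);       /* resumes from the copied optimizer state */
fann_destroy(copy);
fann_destroy_train(data);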
19 changes: 19 additions & 0 deletions src/fann_io.c
@@ -175,6 +175,9 @@ int fann_save_internal_fd(struct fann *ann, FILE *conf, const char *configuratio
fprintf(conf, "rprop_delta_min=%f\n", ann->rprop_delta_min);
fprintf(conf, "rprop_delta_max=%f\n", ann->rprop_delta_max);
fprintf(conf, "rprop_delta_zero=%f\n", ann->rprop_delta_zero);
fprintf(conf, "adam_beta1=%f\n", ann->adam_beta1);
fprintf(conf, "adam_beta2=%f\n", ann->adam_beta2);
fprintf(conf, "adam_epsilon=%.8f\n", ann->adam_epsilon);
fprintf(conf, "cascade_output_stagnation_epochs=%u\n", ann->cascade_output_stagnation_epochs);
fprintf(conf, "cascade_candidate_change_fraction=%f\n", ann->cascade_candidate_change_fraction);
fprintf(conf, "cascade_candidate_stagnation_epochs=%u\n",
@@ -322,6 +325,18 @@ struct fann *fann_create_from_fd_1_1(FILE *conf, const char *configuration_file)
  } \
}

/* Optional scanf that sets a default value if the field is not present in the file.
 * This is used for new parameters to maintain backward compatibility with older saved networks.
 */
#define fann_scanf_optional(type, name, val, default_val) \
  { \
    long pos = ftell(conf); \
    if (fscanf(conf, name "=" type "\n", val) != 1) { \
      fseek(conf, pos, SEEK_SET); \
      *(val) = (default_val); \
    } \
  }

#define fann_skip(name) \
  { \
    if (fscanf(conf, name) != 0) { \
@@ -420,6 +435,10 @@ struct fann *fann_create_from_fd(FILE *conf, const char *configuration_file) {
fann_scanf("%f", "rprop_delta_min", &ann->rprop_delta_min);
fann_scanf("%f", "rprop_delta_max", &ann->rprop_delta_max);
fann_scanf("%f", "rprop_delta_zero", &ann->rprop_delta_zero);
/* Adam parameters are optional for backward compatibility with older saved networks */
fann_scanf_optional("%f", "adam_beta1", &ann->adam_beta1, 0.9f);
fann_scanf_optional("%f", "adam_beta2", &ann->adam_beta2, 0.999f);
fann_scanf_optional("%f", "adam_epsilon", &ann->adam_epsilon, 1e-8f);
fann_scanf("%u", "cascade_output_stagnation_epochs", &ann->cascade_output_stagnation_epochs);
fann_scanf("%f", "cascade_candidate_change_fraction", &ann->cascade_candidate_change_fraction);
fann_scanf("%u", "cascade_candidate_stagnation_epochs",
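For illustration, the optional read of adam_beta1 above expands to roughly the following once the macro arguments are substituted (a sketch, not literal preprocessor output):

{
  long pos = ftell(conf);
  if (fscanf(conf, "adam_beta1=%f\n", &ann->adam_beta1) != 1) {
    fseek(conf, pos, SEEK_SET); /* roll back so the next fann_scanf still lines up */
    ann->adam_beta1 = 0.9f;     /* default for files written before this change */
  }
}

Because the file position is restored on a miss, networks saved by older FANN versions (which lack the adam_* keys) still parse: the three optional reads fail cleanly, the defaults are applied, and cascade_output_stagnation_epochs is read from where it always was.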
97 changes: 96 additions & 1 deletion src/fann_train.c
@@ -527,6 +527,34 @@ void fann_clear_train_arrays(struct fann *ann) {
  } else {
    memset(ann->prev_train_slopes, 0, (ann->total_connections_allocated) * sizeof(fann_type));
  }

  /* Allocate and initialize Adam optimizer arrays if using Adam */
  if (ann->training_algorithm == FANN_TRAIN_ADAM) {
    /* Allocate first moment vector (m) */
    if (ann->adam_m == NULL) {
      ann->adam_m = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
      if (ann->adam_m == NULL) {
        fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
        return;
      }
    } else {
      memset(ann->adam_m, 0, (ann->total_connections_allocated) * sizeof(fann_type));
    }

    /* Allocate second moment vector (v) */
    if (ann->adam_v == NULL) {
      ann->adam_v = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
      if (ann->adam_v == NULL) {
        fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
        return;
      }
    } else {
      memset(ann->adam_v, 0, (ann->total_connections_allocated) * sizeof(fann_type));
    }

    /* Reset timestep */
    ann->adam_timestep = 0;
  }
}

/* INTERNAL FUNCTION
@@ -682,9 +710,70 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight,
  }
}

/* INTERNAL FUNCTION
   The Adam (Adaptive Moment Estimation) algorithm

   Adam combines ideas from momentum and RMSProp:
   - Maintains exponential moving averages of gradients (first moment, m)
   - Maintains exponential moving averages of squared gradients (second moment, v)
   - Uses bias correction to account for initialization at zero

   Parameters:
   - beta1: exponential decay rate for first moment (default 0.9)
   - beta2: exponential decay rate for second moment (default 0.999)
   - epsilon: small constant for numerical stability (default 1e-8)
*/
void fann_update_weights_adam(struct fann *ann, unsigned int num_data, unsigned int first_weight,
                              unsigned int past_end) {
  fann_type *train_slopes = ann->train_slopes;
  fann_type *weights = ann->weights;
  fann_type *m = ann->adam_m;
  fann_type *v = ann->adam_v;

  const float learning_rate = ann->learning_rate;
  const float beta1 = ann->adam_beta1;
  const float beta2 = ann->adam_beta2;
  const float epsilon = ann->adam_epsilon;
  const float gradient_scale = 1.0f / num_data;

  unsigned int i;
  float gradient, m_hat, v_hat;
  float beta1_t, beta2_t;

  /* Increment timestep */
  ann->adam_timestep++;

  /* Compute bias correction terms: 1 - beta^t */
  beta1_t = 1.0f - powf(beta1, (float)ann->adam_timestep);
  beta2_t = 1.0f - powf(beta2, (float)ann->adam_timestep);

  for (i = first_weight; i != past_end; i++) {
    /* Compute gradient (average over batch) */
    gradient = train_slopes[i] * gradient_scale;

    /* Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t */
    m[i] = beta1 * m[i] + (1.0f - beta1) * gradient;

    /* Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2 */
    v[i] = beta2 * v[i] + (1.0f - beta2) * gradient * gradient;

    /* Compute bias-corrected first moment: m_hat = m_t / (1 - beta1^t) */
    m_hat = m[i] / beta1_t;

    /* Compute bias-corrected second moment: v_hat = v_t / (1 - beta2^t) */
    v_hat = v[i] / beta2_t;

    /* Update weights: w_t = w_{t-1} + learning_rate * m_hat / (sqrt(v_hat) + epsilon).
       FANN accumulates train_slopes in the descent direction (cf. fann_update_weights_batch),
       hence += here rather than the -= in the original Adam paper. */
    weights[i] += learning_rate * m_hat / (sqrt(v_hat) + epsilon);

    /* Clear slope for next iteration */
    train_slopes[i] = 0.0f;
  }
}

/* INTERNAL FUNCTION
   The SARprop- algorithm
*/
void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight,
                                 unsigned int past_end) {
  fann_type *train_slopes = ann->train_slopes;
@@ -919,3 +1008,9 @@ FANN_GET_SET(float, sarprop_temperature)
FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
FANN_GET_SET(fann_type, bit_fail_limit)
FANN_GET_SET(float, learning_momentum)

FANN_GET_SET(float, adam_beta1)
FANN_GET_SET(float, adam_beta2)
FANN_GET_SET(float, adam_epsilon)
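Assuming the FANN_GET_SET macro keeps its existing convention (it already generates the sarprop accessors above), these three invocations yield the public getter/setter pairs that robot_adam.c calls, along the lines of:

FANN_EXTERNAL float FANN_API fann_get_adam_beta1(struct fann *ann);
FANN_EXTERNAL void FANN_API fann_set_adam_beta1(struct fann *ann, float adam_beta1);
/* ...and likewise for adam_beta2 and adam_epsilon */

The matching public declarations would normally live in the training header; that file is not shown in this diff.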
26 changes: 26 additions & 0 deletions src/fann_train_data.c
@@ -137,6 +137,30 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data) {
  return fann_get_MSE(ann);
}

/*
 * Internal train function
 */
float fann_train_epoch_adam(struct fann *ann, struct fann_train_data *data) {
  unsigned int i;

  if (ann->adam_m == NULL) {
    fann_clear_train_arrays(ann);
  }

  fann_reset_MSE(ann);

  for (i = 0; i < data->num_data; i++) {
    fann_run(ann, data->input[i]);
    fann_compute_MSE(ann, data->output[i]);
    fann_backpropagate_MSE(ann);
    fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
  }

  fann_update_weights_adam(ann, data->num_data, 0, ann->total_connections);

  return fann_get_MSE(ann);
}

/*
 * Internal train function
 */
@@ -211,6 +235,8 @@ FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_trai
      return fann_train_epoch_irpropm(ann, data);
    case FANN_TRAIN_SARPROP:
      return fann_train_epoch_sarprop(ann, data);
    case FANN_TRAIN_ADAM:
      return fann_train_epoch_adam(ann, data);
    case FANN_TRAIN_BATCH:
      return fann_train_epoch_batch(ann, data);
    case FANN_TRAIN_INCREMENTAL:
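As a usage note: because fann_train_epoch now dispatches on FANN_TRAIN_ADAM, a hand-rolled training loop reaches the new code path with nothing more than the algorithm switch. A sketch, assuming ann, data, i, and desired_error are set up as in robot_adam.c:

fann_set_training_algorithm(ann, FANN_TRAIN_ADAM);
for (i = 0; i < 10000; i++) {
  float mse = fann_train_epoch(ann, data); /* dispatches to fann_train_epoch_adam */
  if (mse < desired_error) break;
}

fann_train_on_data, as used in the example, wraps this same per-epoch call.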