diff --git a/examples/.gitignore b/examples/.gitignore
index 7cf8e682..bb90fcbf 100644
--- a/examples/.gitignore
+++ b/examples/.gitignore
@@ -16,3 +16,6 @@ cascade_train_debug
 xor_test_debug
 xor_test_fixed_debug
 xor_train_debug
+robot_adam
+robot_adam.net
+
diff --git a/examples/Makefile b/examples/Makefile
index b3d9b0f3..db1b45f9 100644
--- a/examples/Makefile
+++ b/examples/Makefile
@@ -4,8 +4,8 @@ GCC=gcc
 CFLAGS=-I ../src/include
 LDFLAGS=-L ../src/
 
-TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train
-DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug
+TARGETS = xor_train xor_test xor_test_fixed simple_train steepness_train simple_test robot mushroom cascade_train scaling_test scaling_train robot_adam
+DEBUG_TARGETS = xor_train_debug xor_test_debug xor_test_fixed_debug cascade_train_debug robot_adam_debug
 
 all: $(TARGETS)
diff --git a/examples/robot_adam.c b/examples/robot_adam.c
new file mode 100644
index 00000000..68a890b1
--- /dev/null
+++ b/examples/robot_adam.c
@@ -0,0 +1,72 @@
+/*
+Fast Artificial Neural Network Library (fann)
+Copyright (C) 2003-2016 Steffen Nissen (steffen.fann@gmail.com)
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+*/
+
+#include <stdio.h>
+
+#include "fann.h"
+
+int main()
+{
+    const unsigned int num_layers = 3;
+    const unsigned int num_neurons_hidden = 96;
+    const float desired_error = (const float) 0.00001;
+    struct fann *ann;
+    struct fann_train_data *train_data, *test_data;
+
+    unsigned int i = 0;
+
+    printf("Creating network.\n");
+
+    train_data = fann_read_train_from_file("../datasets/robot.train");
+
+    ann = fann_create_standard(num_layers,
+                               train_data->num_input, num_neurons_hidden, train_data->num_output);
+
+    printf("Training network.\n");
+
+    fann_set_training_algorithm(ann, FANN_TRAIN_ADAM);
+    fann_set_adam_beta1(ann, 0.9f);
+    fann_set_adam_beta2(ann, 0.999f);
+    fann_set_adam_epsilon(ann, 1e-8f);
+    fann_set_learning_rate(ann, 0.05f);
+
+    fann_train_on_data(ann, train_data, 10000, 100, desired_error);
+
+    printf("Testing network.\n");
+
+    test_data = fann_read_train_from_file("../datasets/robot.test");
+
+    fann_reset_MSE(ann);
+    for(i = 0; i < fann_length_train_data(test_data); i++)
+    {
+        fann_test(ann, test_data->input[i], test_data->output[i]);
+    }
+    printf("MSE error on test data: %f\n", fann_get_MSE(ann));
+
+    printf("Saving network.\n");
+
+    fann_save(ann, "robot_adam.net");
+
+    printf("Cleaning up.\n");
+    fann_destroy_train(train_data);
+    fann_destroy_train(test_data);
+    fann_destroy(ann);
+
+    return 0;
+}
diff --git a/src/fann.c b/src/fann.c
index b1ee769a..4c0db12d 100644
--- a/src/fann.c
+++ b/src/fann.c
@@ -748,6 +748,8 @@ FANN_EXTERNAL void FANN_API fann_destroy(struct fann *ann) {
   fann_safe_free(ann->prev_train_slopes);
   fann_safe_free(ann->prev_steps);
   fann_safe_free(ann->prev_weights_deltas);
+  fann_safe_free(ann->adam_m);
+  fann_safe_free(ann->adam_v);
   fann_safe_free(ann->errstr);
   fann_safe_free(ann->cascade_activation_functions);
   fann_safe_free(ann->cascade_activation_steepnesses);
@@ -888,6 +890,14 @@ FANN_EXTERNAL struct fann *FANN_API fann_copy(struct fann *orig) {
   copy->rprop_delta_max = orig->rprop_delta_max;
   copy->rprop_delta_zero = orig->rprop_delta_zero;
 
+  /* Copy Adam optimizer parameters */
+  copy->adam_beta1 = orig->adam_beta1;
+  copy->adam_beta2 = orig->adam_beta2;
+  copy->adam_epsilon = orig->adam_epsilon;
+  copy->adam_timestep = orig->adam_timestep;
+  copy->adam_m = NULL;
+  copy->adam_v = NULL;
+
   /* user_data is not deep copied. user should use fann_copy_with_user_data() for that */
   copy->user_data = orig->user_data;
 
@@ -1008,6 +1018,29 @@ FANN_EXTERNAL struct fann *FANN_API fann_copy(struct fann *orig) {
                copy->total_connections_allocated * sizeof(fann_type));
   }
 
+  /* Copy Adam optimizer moment vectors if they exist */
+  if (orig->adam_m) {
+    copy->adam_m = (fann_type *)malloc(copy->total_connections_allocated * sizeof(fann_type));
+    if (copy->adam_m == NULL) {
+      fann_error((struct fann_error *)orig, FANN_E_CANT_ALLOCATE_MEM);
+      fann_destroy(copy);
+      return NULL;
+    }
+    memcpy(copy->adam_m, orig->adam_m,
+           copy->total_connections_allocated * sizeof(fann_type));
+  }
+
+  if (orig->adam_v) {
+    copy->adam_v = (fann_type *)malloc(copy->total_connections_allocated * sizeof(fann_type));
+    if (copy->adam_v == NULL) {
+      fann_error((struct fann_error *)orig, FANN_E_CANT_ALLOCATE_MEM);
+      fann_destroy(copy);
+      return NULL;
+    }
+    memcpy(copy->adam_v, orig->adam_v,
+           copy->total_connections_allocated * sizeof(fann_type));
+  }
+
   return copy;
 }
@@ -1171,6 +1204,9 @@ FANN_EXTERNAL void FANN_API fann_print_parameters(struct fann *ann) {
   printf("RPROP decrease factor :%8.3f\n", ann->rprop_decrease_factor);
   printf("RPROP delta min :%8.3f\n", ann->rprop_delta_min);
   printf("RPROP delta max :%8.3f\n", ann->rprop_delta_max);
+  printf("Adam beta1 :%f\n", ann->adam_beta1);
+  printf("Adam beta2 :%f\n", ann->adam_beta2);
+  printf("Adam epsilon :%.8f\n", ann->adam_epsilon);
   printf("Cascade output change fraction :%11.6f\n", ann->cascade_output_change_fraction);
   printf("Cascade candidate change fraction :%11.6f\n", ann->cascade_candidate_change_fraction);
   printf("Cascade output stagnation epochs :%4d\n", ann->cascade_output_stagnation_epochs);
@@ -1552,6 +1588,14 @@ struct fann *fann_allocate_structure(unsigned int num_layers) {
   ann->sarprop_temperature = 0.015f;
   ann->sarprop_epoch = 0;
 
+  /* Variables for use with Adam training (reasonable defaults) */
+  ann->adam_m = NULL;
+  ann->adam_v = NULL;
+  ann->adam_beta1 = 0.9f;
+  ann->adam_beta2 = 0.999f;
+  ann->adam_epsilon = 1e-8f;
+  ann->adam_timestep = 0;
+
   fann_init_error_data((struct fann_error *)ann);
 
 #ifdef FIXEDFANN
diff --git a/src/fann_io.c b/src/fann_io.c
index bafb63eb..0e97386d 100644
--- a/src/fann_io.c
+++ b/src/fann_io.c
@@ -175,6 +175,9 @@ int fann_save_internal_fd(struct fann *ann, FILE *conf, const char *configuratio
   fprintf(conf, "rprop_delta_min=%f\n", ann->rprop_delta_min);
   fprintf(conf, "rprop_delta_max=%f\n", ann->rprop_delta_max);
   fprintf(conf, "rprop_delta_zero=%f\n", ann->rprop_delta_zero);
+  fprintf(conf, "adam_beta1=%f\n", ann->adam_beta1);
+  fprintf(conf, "adam_beta2=%f\n", ann->adam_beta2);
+  fprintf(conf, "adam_epsilon=%.8f\n", ann->adam_epsilon);
   fprintf(conf, "cascade_output_stagnation_epochs=%u\n", ann->cascade_output_stagnation_epochs);
   fprintf(conf,
           "cascade_candidate_change_fraction=%f\n", ann->cascade_candidate_change_fraction);
   fprintf(conf, "cascade_candidate_stagnation_epochs=%u\n",
@@ -322,6 +325,18 @@ struct fann *fann_create_from_fd_1_1(FILE *conf, const char *configuration_file)
     } \
   }
 
+/* Optional scanf that sets a default value if the field is not present in the file.
+ * This is used for new parameters to maintain backward compatibility with older saved networks.
+ */
+#define fann_scanf_optional(type, name, val, default_val) \
+  { \
+    long pos = ftell(conf); \
+    if (fscanf(conf, name "=" type "\n", val) != 1) { \
+      fseek(conf, pos, SEEK_SET); \
+      *(val) = (default_val); \
+    } \
+  }
+
 #define fann_skip(name) \
   { \
     if (fscanf(conf, name) != 0) { \
@@ -420,6 +435,10 @@ struct fann *fann_create_from_fd(FILE *conf, const char *configuration_file) {
   fann_scanf("%f", "rprop_delta_min", &ann->rprop_delta_min);
   fann_scanf("%f", "rprop_delta_max", &ann->rprop_delta_max);
   fann_scanf("%f", "rprop_delta_zero", &ann->rprop_delta_zero);
+  /* Adam parameters are optional for backward compatibility with older saved networks */
+  fann_scanf_optional("%f", "adam_beta1", &ann->adam_beta1, 0.9f);
+  fann_scanf_optional("%f", "adam_beta2", &ann->adam_beta2, 0.999f);
+  fann_scanf_optional("%f", "adam_epsilon", &ann->adam_epsilon, 1e-8f);
   fann_scanf("%u", "cascade_output_stagnation_epochs", &ann->cascade_output_stagnation_epochs);
   fann_scanf("%f", "cascade_candidate_change_fraction", &ann->cascade_candidate_change_fraction);
   fann_scanf("%u", "cascade_candidate_stagnation_epochs",
diff --git a/src/fann_train.c b/src/fann_train.c
index 8f974bdf..db74b2f4 100644
--- a/src/fann_train.c
+++ b/src/fann_train.c
@@ -527,6 +527,34 @@ void fann_clear_train_arrays(struct fann *ann) {
   } else {
     memset(ann->prev_train_slopes, 0, (ann->total_connections_allocated) * sizeof(fann_type));
   }
+
+  /* Allocate and initialize Adam optimizer arrays if using Adam */
+  if (ann->training_algorithm == FANN_TRAIN_ADAM) {
+    /* Allocate first moment vector (m) */
+    if (ann->adam_m == NULL) {
+      ann->adam_m = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
+      if (ann->adam_m == NULL) {
+        fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
+        return;
+      }
+    } else {
+      memset(ann->adam_m, 0, (ann->total_connections_allocated) * sizeof(fann_type));
+    }
+
+    /* Allocate second moment vector (v) */
+    if (ann->adam_v == NULL) {
+      ann->adam_v = (fann_type *)calloc(ann->total_connections_allocated, sizeof(fann_type));
+      if (ann->adam_v == NULL) {
+        fann_error((struct fann_error *)ann, FANN_E_CANT_ALLOCATE_MEM);
+        return;
+      }
+    } else {
+      memset(ann->adam_v, 0, (ann->total_connections_allocated) * sizeof(fann_type));
+    }
+
+    /* Reset timestep */
+    ann->adam_timestep = 0;
+  }
 }
 
 /* INTERNAL FUNCTION
@@ -682,9 +710,70 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight, unsigned int past_end) {
   }
 }
 
+/* INTERNAL FUNCTION
+   The Adam (Adaptive Moment Estimation) algorithm
+
+   Adam combines ideas from momentum and RMSProp:
+   - Maintains exponential moving averages of gradients (first moment, m)
+   - Maintains exponential moving averages of squared gradients (second moment, v)
+   - Uses bias correction to account for initialization at zero
+
+   Parameters:
+   - beta1: exponential decay rate for first moment (default 0.9)
+   - beta2: exponential decay rate for second moment (default 0.999)
+   - epsilon: small constant for numerical stability (default 1e-8)
+ */
+void fann_update_weights_adam(struct fann *ann, unsigned int num_data, unsigned int first_weight,
+                              unsigned int past_end) {
+  fann_type *train_slopes = ann->train_slopes;
+  fann_type *weights = ann->weights;
+  fann_type *m = ann->adam_m;
+  fann_type *v = ann->adam_v;
+
+  const float learning_rate = ann->learning_rate;
+  const float beta1 = ann->adam_beta1;
+  const float beta2 = ann->adam_beta2;
+  const float epsilon = ann->adam_epsilon;
+  const float gradient_scale = 1.0f / num_data;
+
+  unsigned int i;
+  float gradient, m_hat, v_hat;
+  float beta1_t, beta2_t;
+
+  /* Increment timestep */
+  ann->adam_timestep++;
+
+  /* Compute bias correction terms: 1 - beta^t */
+  beta1_t = 1.0f - powf(beta1, (float)ann->adam_timestep);
+  beta2_t = 1.0f - powf(beta2, (float)ann->adam_timestep);
+
+  for (i = first_weight; i != past_end; i++) {
+    /* Compute gradient (average over batch) */
+    gradient = train_slopes[i] * gradient_scale;
+
+    /* Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t */
+    m[i] = beta1 * m[i] + (1.0f - beta1) * gradient;
+
+    /* Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2 */
+    v[i] = beta2 * v[i] + (1.0f - beta2) * gradient * gradient;
+
+    /* Compute bias-corrected first moment: m_hat = m_t / (1 - beta1^t) */
+    m_hat = m[i] / beta1_t;
+
+    /* Compute bias-corrected second moment: v_hat = v_t / (1 - beta2^t) */
+    v_hat = v[i] / beta2_t;
+
+    /* Update weights: w_t = w_{t-1} + learning_rate * m_hat / (sqrt(v_hat) + epsilon).
+       FANN accumulates train_slopes pointing in the descent direction (as in
+       fann_update_weights_batch), so the step is added rather than subtracted. */
+    weights[i] += learning_rate * m_hat / (sqrt(v_hat) + epsilon);
+
+    /* Clear slope for next iteration */
+    train_slopes[i] = 0.0f;
+  }
+}
+
 /* INTERNAL FUNCTION
    The SARprop- algorithm
-*/
+ */
 void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight,
                                  unsigned int past_end) {
   fann_type *train_slopes = ann->train_slopes;
@@ -919,3 +1008,9 @@ FANN_GET_SET(float, sarprop_temperature)
 FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
 FANN_GET_SET(fann_type, bit_fail_limit)
 FANN_GET_SET(float, learning_momentum)
+
+FANN_GET_SET(float, adam_beta1)
+
+FANN_GET_SET(float, adam_beta2)
+
+FANN_GET_SET(float, adam_epsilon)
diff --git a/src/fann_train_data.c b/src/fann_train_data.c
index ba992975..000759af 100644
--- a/src/fann_train_data.c
+++ b/src/fann_train_data.c
@@ -137,6 +137,30 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data) {
   return fann_get_MSE(ann);
 }
 
+/*
+ * Internal train function
+ */
+float fann_train_epoch_adam(struct fann *ann, struct fann_train_data *data) {
+  unsigned int i;
+
+  if (ann->adam_m == NULL) {
+    fann_clear_train_arrays(ann);
+  }
+
+  fann_reset_MSE(ann);
+
+  for (i = 0; i < data->num_data; i++) {
+    fann_run(ann, data->input[i]);
+    fann_compute_MSE(ann, data->output[i]);
+    fann_backpropagate_MSE(ann);
+    fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
+  }
+
+  fann_update_weights_adam(ann, data->num_data, 0, ann->total_connections);
+
+  return fann_get_MSE(ann);
+}
+
 /*
  * Internal train function
  */
@@ -211,6 +235,8 @@ FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_trai
     return fann_train_epoch_irpropm(ann, data);
   case FANN_TRAIN_SARPROP:
     return fann_train_epoch_sarprop(ann, data);
+  case FANN_TRAIN_ADAM:
+    return fann_train_epoch_adam(ann, data);
   case FANN_TRAIN_BATCH:
     return fann_train_epoch_batch(ann, data);
   case FANN_TRAIN_INCREMENTAL:
diff --git a/src/include/fann_data.h b/src/include/fann_data.h
index f9040c27..7e60534a 100644
--- a/src/include/fann_data.h
+++ b/src/include/fann_data.h
@@ -79,7 +79,8 @@ enum fann_train_enum {
   FANN_TRAIN_BATCH,
   FANN_TRAIN_RPROP,
   FANN_TRAIN_QUICKPROP,
-  FANN_TRAIN_SARPROP
+  FANN_TRAIN_SARPROP,
+  FANN_TRAIN_ADAM
 };
 
 /* Constant: FANN_TRAIN_NAMES
@@ -95,7 +96,7 @@ enum fann_train_enum {
 */
 static char const *const FANN_TRAIN_NAMES[] = {"FANN_TRAIN_INCREMENTAL", "FANN_TRAIN_BATCH",
                                                "FANN_TRAIN_RPROP", "FANN_TRAIN_QUICKPROP",
-                                               "FANN_TRAIN_SARPROP"};
+                                               "FANN_TRAIN_SARPROP", "FANN_TRAIN_ADAM"};
 
 /* Enums: fann_activationfunc_enum
@@ -754,6 +755,25 @@ struct fann {
   */
   fann_type *prev_weights_deltas;
 
+  /* Adam optimizer parameters */
+  /* First moment vector (mean of gradients) for Adam optimizer */
+  fann_type *adam_m;
+
+  /* Second moment vector (variance of gradients) for Adam optimizer */
+  fann_type *adam_v;
+
+  /* Exponential decay rate for the first moment estimates (default 0.9) */
+  float adam_beta1;
+
+  /* Exponential decay rate for the second moment estimates (default 0.999) */
+  float adam_beta2;
+
+  /* Small constant for numerical stability (default 1e-8) */
+  float adam_epsilon;
+
+  /* Current timestep for Adam optimizer */
+  unsigned int adam_timestep;
+
 #ifndef FIXEDFANN
   /* Arithmetic mean used to remove steady component in input data. */
   float *scale_mean_in;
diff --git a/src/include/fann_internal.h b/src/include/fann_internal.h
index 4aac0700..8eedab2e 100644
--- a/src/include/fann_internal.h
+++ b/src/include/fann_internal.h
@@ -85,6 +85,8 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight,
                                  unsigned int past_end);
 void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight,
                                  unsigned int past_end);
+void fann_update_weights_adam(struct fann *ann, unsigned int num_data, unsigned int first_weight,
+                              unsigned int past_end);
 
 void fann_clear_train_arrays(struct fann *ann);
diff --git a/src/include/fann_train.h b/src/include/fann_train.h
index 18f0486c..f47b9025 100644
--- a/src/include/fann_train.h
+++ b/src/include/fann_train.h
@@ -817,6 +817,75 @@ FANN_EXTERNAL float FANN_API fann_get_learning_momentum(struct fann *ann);
 */
 FANN_EXTERNAL void FANN_API fann_set_learning_momentum(struct fann *ann, float learning_momentum);
 
+/* Function: fann_get_adam_beta1
+
+   Get the Adam optimizer beta1 parameter (exponential decay rate for first moment estimates).
+
+   The default beta1 is 0.9.
+
+   See also:
+   <fann_set_adam_beta1>
+
+   This function appears in FANN >= 2.3.0.
+ */
+FANN_EXTERNAL float FANN_API fann_get_adam_beta1(struct fann *ann);
+
+/* Function: fann_set_adam_beta1
+
+   Set the Adam optimizer beta1 parameter (exponential decay rate for first moment estimates).
+
+   Typical values are close to 1.0, with the default being 0.9.
+
+   This function appears in FANN >= 2.3.0.
+ */
+FANN_EXTERNAL void FANN_API fann_set_adam_beta1(struct fann *ann, float adam_beta1);
+
+/* Function: fann_get_adam_beta2
+
+   Get the Adam optimizer beta2 parameter (exponential decay rate for second moment estimates).
+
+   The default beta2 is 0.999.
+
+   See also:
+   <fann_set_adam_beta2>
+
+   This function appears in FANN >= 2.3.0.
+ */
+FANN_EXTERNAL float FANN_API fann_get_adam_beta2(struct fann *ann);
+
+/* Function: fann_set_adam_beta2
+
+   Set the Adam optimizer beta2 parameter (exponential decay rate for second moment estimates).
+
+   Typical values are close to 1.0, with the default being 0.999.
+
+   This function appears in FANN >= 2.3.0.
+ */
+FANN_EXTERNAL void FANN_API fann_set_adam_beta2(struct fann *ann, float adam_beta2);
+
+/* Function: fann_get_adam_epsilon
+
+   Get the Adam optimizer epsilon parameter (small constant for numerical stability).
+
+   The default epsilon is 1e-8.
+
+   See also:
+   <fann_set_adam_epsilon>
+
+   This function appears in FANN >= 2.3.0.
+ */
+FANN_EXTERNAL float FANN_API fann_get_adam_epsilon(struct fann *ann);
+
+/* Function: fann_set_adam_epsilon
+
+   Set the Adam optimizer epsilon parameter (small constant for numerical stability).
+
+   This is used to prevent division by zero. Typical values are very small, with the default being 1e-8.
+
+   This function appears in FANN >= 2.3.0.
+ */
+FANN_EXTERNAL void FANN_API fann_set_adam_epsilon(struct fann *ann, float adam_epsilon);
+
 /* Function: fann_get_activation_function
 
    Get the activation function for neuron number *neuron* in layer number *layer*,
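For reference, the update loop in fann_update_weights_adam above follows the standard bias-corrected Adam rule. The stand-alone sketch below restates that rule outside of FANN on a toy one-dimensional quadratic; it is illustrative only, and every name in it (w, m, v, g) is invented for the example rather than taken from the library. The textbook form subtracts the step, whereas the patch adds it, because FANN's train_slopes already carry the sign of the descent direction.

/* Minimal stand-alone sketch of the bias-corrected Adam update (not FANN code).
   Minimizes f(w) = (w - 3)^2 using the textbook rule, which subtracts the step. */
#include <math.h>
#include <stdio.h>

int main(void) {
  const float learning_rate = 0.1f;
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f;

  float w = 0.0f; /* parameter being optimized */
  float m = 0.0f; /* first moment estimate */
  float v = 0.0f; /* second moment estimate */
  unsigned int t;

  for (t = 1; t <= 200; t++) {
    float g = 2.0f * (w - 3.0f); /* gradient of (w - 3)^2 */
    float m_hat, v_hat;

    m = beta1 * m + (1.0f - beta1) * g;     /* biased first moment */
    v = beta2 * v + (1.0f - beta2) * g * g; /* biased second moment */

    m_hat = m / (1.0f - powf(beta1, (float)t)); /* bias correction */
    v_hat = v / (1.0f - powf(beta2, (float)t));

    w -= learning_rate * m_hat / (sqrtf(v_hat) + epsilon);
  }

  printf("w after 200 Adam steps: %f (target 3.0)\n", w);
  return 0;
}

Because the bias-corrected ratio m_hat / sqrt(v_hat) stays close to +/-1 while the gradient keeps its sign, the step size is roughly the learning rate regardless of the gradient's scale, which is part of why a single fixed learning rate such as the 0.05 used in robot_adam.c is a reasonable starting point.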