class RubyFann::TrainData

Public Class Methods

new(hash) → new ruby-fann training data object (RubyFann::TrainData)
Initialize using one of the following forms:

# This is a flat file with training data as described in FANN docs.
RubyFann::TrainData.new(:filename => 'path/to/training_file.train')

OR

# Train with inputs (array of arrays) & desired_outputs (array of arrays)
# inputs & desired_outputs must have the same length
# All sub-arrays in inputs must have the same length
# All sub-arrays in desired_outputs must have the same length
# Sub-arrays in inputs & desired_outputs may differ in size from one another
RubyFann::TrainData.new(:inputs=>[[0.2, 0.3, 0.4], [0.8, 0.9, 0.7]], :desired_outputs=>[[3.14], [6.33]])
/*
 * RubyFann::TrainData#initialize(hash)
 *
 * Accepts a Hash with Symbol keys, in one of two forms:
 *   :filename        - String path to a FANN-format training file; or
 *   :inputs          - Array of input arrays, together with
 *   :desired_outputs - Array of desired-output arrays (same length as :inputs).
 * If :filename is a String it takes precedence over :inputs.
 *
 * Raises RuntimeError when neither form is supplied, when either array is
 * empty, when the arrays differ in length, or when the training file cannot
 * be read. Returns self.
 */
static VALUE fann_train_data_initialize(VALUE self, VALUE hash)
{
    struct fann_train_data *train_data = NULL;
    Check_Type(hash, T_HASH);

    VALUE filename = rb_hash_aref(hash, ID2SYM(rb_intern("filename")));
    VALUE inputs = rb_hash_aref(hash, ID2SYM(rb_intern("inputs")));
    VALUE desired_outputs = rb_hash_aref(hash, ID2SYM(rb_intern("desired_outputs")));

    if (TYPE(filename) == T_STRING)
    {
        train_data = fann_read_train_from_file(StringValuePtr(filename));

        // fann_read_train_from_file yields NULL on failure; storing NULL in
        // DATA_PTR would make later accessors dereference a null pointer.
        if (train_data == NULL)
        {
            rb_raise(rb_eRuntimeError, "Unable to read training data from file: %s", StringValuePtr(filename));
        }
    }
    else if (TYPE(inputs) == T_ARRAY)
    {
        if (TYPE(desired_outputs) != T_ARRAY)
        {
            rb_raise(rb_eRuntimeError, "[desired_outputs] must be present when [inputs] used.");
        }

        if (RARRAY_LEN(inputs) < 1)
        {
            rb_raise(rb_eRuntimeError, "[inputs] must contain at least one value.");
        }

        if (RARRAY_LEN(desired_outputs) < 1)
        {
            rb_raise(rb_eRuntimeError, "[desired_outputs] must contain at least one value.");
        }

        // The data is here, start constructing:
        if (RARRAY_LEN(inputs) != RARRAY_LEN(desired_outputs))
        {
            rb_raise(
                rb_eRuntimeError,
                "Number of inputs must match number of outputs: (%d != %d)",
                (int)RARRAY_LEN(inputs),
                (int)RARRAY_LEN(desired_outputs));
        }

        train_data = fann_create_train_from_rb_ary(inputs, desired_outputs);
    }
    else
    {
        rb_raise(rb_eRuntimeError, "Must construct with a filename(string) or inputs/desired_outputs(arrays).  All args passed via hash with symbols as keys.");
    }

    DATA_PTR(self) = train_data;

    // Ruby discards #initialize's return value, but returning self keeps the
    // function from handing a raw C pointer back disguised as a VALUE.
    return self;
}

Public Instance Methods

length()

Returns the length of the training data (the number of training patterns).

/*
 * RubyFann::TrainData#length
 *
 * Returns, as a Ruby Integer, the length of the wrapped training data as
 * reported by fann_length_train_data.
 */
static VALUE length_train_data(VALUE self)
{
    struct fann_train_data *t;
    Data_Get_Struct(self, struct fann_train_data, t);
    return UINT2NUM(fann_length_train_data(t));
}
save(filename)

Saves the training data to the given filename.

/*
 * RubyFann::TrainData#save(filename)
 *
 * Writes the wrapped training data to +filename+ (which must be a String)
 * using fann_save_train. Raises RuntimeError if fann_save_train reports a
 * failure. Returns self.
 */
static VALUE training_save(VALUE self, VALUE filename)
{
    Check_Type(filename, T_STRING);
    struct fann_train_data *t;
    Data_Get_Struct(self, struct fann_train_data, t);

    // fann_save_train returns 0 on success and non-zero on failure; surface
    // failures instead of letting a lost save go unnoticed.
    if (fann_save_train(t, StringValuePtr(filename)) != 0)
    {
        rb_raise(rb_eRuntimeError, "Unable to save training data to file: %s", StringValuePtr(filename));
    }

    return self;
}
shuffle()

Shuffles the training data, randomizing the order of the patterns. This is recommended for incremental training; it has no effect on batch training.

/*
 * RubyFann::TrainData#shuffle
 *
 * Randomizes the order of the wrapped training data by delegating to
 * fann_shuffle_train_data, then returns self so calls can be chained.
 */
static VALUE shuffle(VALUE self)
{
    struct fann_train_data *data;

    Data_Get_Struct(self, struct fann_train_data, data);
    fann_shuffle_train_data(data);

    return self;
}