class RubyFann::TrainData
Public Class Methods
new(hash) → new ruby-fann training data object (RubyFann::TrainData)
click to toggle source
Initialize in one of the following forms:

    # Load a flat file containing training data in the format described in the FANN docs.
    RubyFann::TrainData.new(:filename => 'path/to/training_file.train')
OR
    # Train with inputs (array of arrays) and desired_outputs (array of arrays).
    # inputs and desired_outputs must have the same number of entries.
    # All sub-arrays of inputs must have the same length.
    # All sub-arrays of desired_outputs must have the same length.
    # Sub-arrays of inputs and of desired_outputs may differ in size from one another.
    RubyFann::TrainData.new(:inputs=>[[0.2, 0.3, 0.4], [0.8, 0.9, 0.7]], :desired_outputs=>[[3.14], [6.33]])
/*
 * RubyFann::TrainData#initialize(hash)
 *
 * Builds the native FANN training set wrapped by +self+ from an option hash
 * whose keys are symbols:
 *   :filename        - String path to a FANN-format training file. When this
 *                      is a String it takes precedence over the array form.
 *   :inputs          - Array of input sub-arrays.
 *   :desired_outputs - Array of output sub-arrays. Required together with
 *                      :inputs and must have the same number of entries.
 *
 * Raises TypeError if +hash+ is not a Hash. Raises ArgumentError if the
 * filename contains an embedded NUL byte. Raises RuntimeError when:
 *   - neither form is supplied;
 *   - the arrays are missing, empty, or of mismatched length;
 *   - the training file cannot be read.
 *
 * Returns self. Ruby's +new+ ignores this value, but returning self keeps a
 * direct #initialize call from handing back a raw C pointer cast to VALUE.
 */
static VALUE fann_train_data_initialize(VALUE self, VALUE hash)
{
    struct fann_train_data *train_data = NULL;
    VALUE filename;
    VALUE inputs;
    VALUE desired_outputs;

    Check_Type(hash, T_HASH);

    filename = rb_hash_aref(hash, ID2SYM(rb_intern("filename")));
    inputs = rb_hash_aref(hash, ID2SYM(rb_intern("inputs")));
    desired_outputs = rb_hash_aref(hash, ID2SYM(rb_intern("desired_outputs")));

    if (TYPE(filename) == T_STRING)
    {
        /* StringValueCStr rejects paths with embedded NUL bytes, which would
           otherwise silently truncate the file name passed to C. */
        const char *path = StringValueCStr(filename);

        train_data = fann_read_train_from_file(path);
        if (train_data == NULL)
        {
            /* FANN signals a failed read by returning NULL. Storing NULL
               would make every later method dereference a null pointer. */
            rb_raise(rb_eRuntimeError, "Could not read training data from file: %s", path);
        }
    }
    else if (TYPE(inputs) == T_ARRAY)
    {
        if (TYPE(desired_outputs) != T_ARRAY)
        {
            rb_raise(rb_eRuntimeError, "[desired_outputs] must be present when [inputs] used.");
        }

        if (RARRAY_LEN(inputs) < 1)
        {
            rb_raise(rb_eRuntimeError, "[inputs] must contain at least one value.");
        }

        if (RARRAY_LEN(desired_outputs) < 1)
        {
            rb_raise(rb_eRuntimeError, "[desired_outputs] must contain at least one value.");
        }

        /* Each input row must pair with exactly one desired-output row. */
        if (RARRAY_LEN(inputs) != RARRAY_LEN(desired_outputs))
        {
            rb_raise(
                rb_eRuntimeError,
                "Number of inputs must match number of outputs: (%d != %d)",
                (int)RARRAY_LEN(inputs),
                (int)RARRAY_LEN(desired_outputs));
        }

        train_data = fann_create_train_from_rb_ary(inputs, desired_outputs);
    }
    else
    {
        rb_raise(rb_eRuntimeError, "Must construct with a filename(string) or inputs/desired_outputs(arrays). All args passed via hash with symbols as keys.");
    }

    DATA_PTR(self) = train_data;
    return self;
}
Public Instance Methods
length()
click to toggle source
Returns the number of training pairs (input/desired-output sets) in the training data.
/*
 * RubyFann::TrainData#length
 *
 * Returns, as a Ruby Integer, the value reported by fann_length_train_data
 * for the wrapped training set (the number of training pairs).
 */
static VALUE length_train_data(VALUE self)
{
    struct fann_train_data *t;

    Data_Get_Struct(self, struct fann_train_data, t);
    return UINT2NUM(fann_length_train_data(t));
}
save(filename)
click to toggle source
Saves the training data to the given filename, in the same format read by RubyFann::TrainData.new(:filename => ...).
/*
 * RubyFann::TrainData#save(filename)
 *
 * Writes the wrapped training set to +filename+ via fann_save_train.
 *
 * Raises:
 *   - TypeError if +filename+ is not a String;
 *   - ArgumentError if it contains an embedded NUL byte;
 *   - RuntimeError if FANN reports that the write failed.
 *
 * Returns self.
 */
static VALUE training_save(VALUE self, VALUE filename)
{
    struct fann_train_data *t;
    const char *path;

    Check_Type(filename, T_STRING);
    /* StringValueCStr guards against silent truncation at an embedded NUL. */
    path = StringValueCStr(filename);

    Data_Get_Struct(self, struct fann_train_data, t);

    /* fann_save_train reports failure with a negative return (-1); previously
       this was ignored, so a failed write looked like success. */
    if (fann_save_train(t, path) < 0)
    {
        rb_raise(rb_eRuntimeError, "Could not save training data to file: %s", path);
    }

    return self;
}
shuffle()
click to toggle source
Shuffles the training data in place, randomizing the order of the training pairs. This is recommended for incremental training; it has no effect on batch training.
/*
 * RubyFann::TrainData#shuffle
 *
 * Randomizes the order of the wrapped training set in place by delegating
 * to fann_shuffle_train_data, then returns self.
 */
static VALUE shuffle(VALUE self)
{
    struct fann_train_data *data = NULL;

    /* Unwrap the native training set that backs this Ruby object. */
    Data_Get_Struct(self, struct fann_train_data, data);

    fann_shuffle_train_data(data);

    return self;
}