Android - OpenCV Error: Bad argument in CvANN_MLP
I'm using OpenCV4Android, and I'm trying to make a little example of a neural network that works out the arithmetic mean. So I've decided to use CvANN_MLP to create the network. It all goes well until I train it; then it fails with the following exception:
OpenCV Error: Bad argument (Output training data should be a floating-point matrix with the number of rows equal to the number of training samples and the number of columns equal to the size of last (output) layer) in CvANN_MLP::prepare_to_train
I've checked the output training data, and its type is CV_32FC1. The number of rows and columns is correct. I don't know what the error is.
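The error message names two shape conditions on the output matrix: its rows must equal the number of training samples, and its columns must equal the size of the last layer. A minimal plain-Java sketch of that check (hypothetical class and method names, no OpenCV calls, and it ignores the floating-point type requirement) shows how a 10 x 1 `trainClasses` matrix could still be rejected, given the layer sizes used in the code in this post:

```java
// Plain-Java sketch of the output-shape condition described in the error
// message. No OpenCV dependency; OutputShapeCheck and outputsMatch are
// hypothetical names used only for illustration.
public class OutputShapeCheck {

    /**
     * Returns true if an output matrix of outRows x outCols satisfies the
     * shape part of the message: rows == number of training samples and
     * cols == size of the last (output) layer.
     */
    public static boolean outputsMatch(int outRows, int outCols, int sampleCount, int[] layerSizes) {
        int outputLayerSize = layerSizes[layerSizes.length - 1];
        return outRows == sampleCount && outCols == outputLayerSize;
    }

    public static void main(String[] args) {
        int sampleCount = 10;
        // Layer sizes as set in the post: 2 input, 2 hidden, 2 output cells.
        int[] postedLayers = {2, 2, 2};
        // Same network with the output layer reduced to 1 cell.
        int[] oneOutputLayers = {2, 2, 1};

        // trainClasses in the post is 10 x 1.
        System.out.println(outputsMatch(10, 1, sampleCount, postedLayers));    // false
        System.out.println(outputsMatch(10, 1, sampleCount, oneOutputLayers)); // true
    }
}
```

With the posted layers, the output layer has 2 cells while `trainClasses` has 1 column, so the column condition fails even though the row count is right. If that is the cause, setting the output layer to 1 cell (`neuralLayers.put(2, 0, 1)`) or giving `trainClasses` 2 columns would likely satisfy the check.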
This is my code; I hope someone can help me. Thanks!
    int train_sample_count = 10;
    float td[][] = new float[10][3];
    // I've created a method to populate td
    populateTrainingData(td);

    Mat trainData = new Mat(train_sample_count, 2, CvType.CV_32FC1);
    Mat trainClasses = new Mat(train_sample_count, 1, CvType.CV_32FC1);
    Mat sampleWts = new Mat(train_sample_count, 1, CvType.CV_32FC1);
    Mat neuralLayers = new Mat(3, 1, CvType.CV_32SC1);

    // input layer has 2 cells
    neuralLayers.put(0, 0, 2);
    // hidden layer has 2 cells
    neuralLayers.put(1, 0, 2);
    // output layer has 2 cells
    neuralLayers.put(2, 0, 2);

    // assemble trainData, trainClasses and weights
    for (int i = 0; i < train_sample_count; i++) {
        trainData.put(i, 0, td[i][0]);
        trainData.put(i, 1, td[i][1]);
        trainClasses.put(i, 0, td[i][2]);
        sampleWts.put(i, 0, 1);
    }
    Log.d(DEBUG_TAG, "assemblage finished");

    // create the neural network with the layers in neuralLayers
    CvANN_MLP machineBrain = new CvANN_MLP(neuralLayers);
    Log.d(DEBUG_TAG, "neural network created");

    // train the neural network with the data
    // parameters for the neural network
    CvANN_MLP_TrainParams trainParams = new CvANN_MLP_TrainParams();
    // backward propagation
    trainParams.set_train_method(CvANN_MLP_TrainParams.BACKPROP);
    // number of iterations and sigmoidal update
    TermCriteria termC = new TermCriteria(TermCriteria.EPS + TermCriteria.COUNT, 10000, 1.0);
    trainParams.set_term_crit(termC);
    // optional value 0
    Mat simpleIndex = new Mat();

    // setting up the neural network
    Log.d(DEBUG_TAG, "setting finished");
    Log.d(DEBUG_TAG, "type of trainClasses: " + (trainClasses.type() == CvType.CV_32FC1));
    machineBrain.train(trainData, trainClasses, sampleWts);
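The post does not show `populateTrainingData`. For the arithmetic-mean task, each row of `td` would presumably hold two inputs and their mean, matching how the loop above reads `td[i][0]`, `td[i][1]` and `td[i][2]`. A hypothetical version in plain Java (the class name, the input range [0, 1) and the fixed random seed are all assumptions, not the author's code):

```java
import java.util.Random;

// Hypothetical populateTrainingData: the original helper is not shown in
// the post. Each row becomes {a, b, (a + b) / 2}, i.e. two inputs and
// their arithmetic mean.
public class MeanTrainingData {

    public static void populateTrainingData(float[][] td) {
        // Fixed seed (an assumption) so the generated data is reproducible.
        Random rng = new Random(42L);
        for (int i = 0; i < td.length; i++) {
            // Inputs drawn uniformly from [0, 1).
            float a = rng.nextFloat();
            float b = rng.nextFloat();
            td[i][0] = a;
            td[i][1] = b;
            // Target: arithmetic mean of the two inputs.
            td[i][2] = (a + b) / 2f;
        }
    }

    public static void main(String[] args) {
        float[][] td = new float[10][3];
        populateTrainingData(td);
        for (float[] row : td) {
            System.out.printf("%.3f %.3f -> %.3f%n", row[0], row[1], row[2]);
        }
    }
}
```

Because this data has a single target per sample, it fits a one-column `trainClasses` matrix, which in turn implies a network whose last layer has 1 cell.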