Commit c5a7f5a7 authored by Alessio Netti
Browse files

Analytics: added data shuffling to Regressor plugin

- Added shuffling of data in the training set in an attempt to improve
regression accuracy
parent 6a330f21
......@@ -144,6 +144,7 @@ void RegressorAnalyzer::trainRandomForest() {
// Shifting the training and response sets so as to obtain the desired prediction distance
*_responseSet = _responseSet->rowRange(_targetDistance, _responseSet->size().height-1);
*_trainingSet = _trainingSet->rowRange(0, _trainingSet->size().height-1-_targetDistance);
shuffleTrainingSet();
if(!_rForest->train(*_trainingSet, cv::ml::ROW_SAMPLE, *_responseSet))
throw std::runtime_error("Analyzer " + _name + ": model training failed!");
delete _trainingSet;
......@@ -161,6 +162,30 @@ void RegressorAnalyzer::trainRandomForest() {
}
}
void RegressorAnalyzer::shuffleTrainingSet() {
if(!_trainingSet || !_responseSet)
return;
size_t idx1, idx2, swaps = _trainingSet->size().height / 2;
std::random_device dev;
std::mt19937 rng(dev());
rng.seed((unsigned)getTimestamp());
std::uniform_int_distribution<std::mt19937::result_type> dist(0, _trainingSet->size().height-1);
cv::Mat swapBuf;
for(uint64_t nS=0; nS<swaps; nS++) {
idx1 = dist(rng);
idx2 = dist(rng);
_trainingSet->row(idx1).copyTo(swapBuf);
_trainingSet->row(idx2).copyTo(_trainingSet->row(idx1));
swapBuf.copyTo(_trainingSet->row(idx2));
_responseSet->row(idx1).copyTo(swapBuf);
_responseSet->row(idx2).copyTo(_responseSet->row(idx1));
swapBuf.copyTo(_responseSet->row(idx2));
//LOG(debug) << "Swapping " << idx1 << " and " << idx2;
}
}
void RegressorAnalyzer::computeFeatureVector(U_Ptr unit) {
if(!_currentfVector)
_currentfVector = new cv::Mat(1, unit->getInputs().size()*REG_NUMFEATURES, CV_32F);
......
......@@ -35,6 +35,7 @@
#include "opencv4/opencv2/core/cvstd.hpp"
#include "opencv4/opencv2/ml.hpp"
#include <math.h>
#include <random>
#define REG_NUMFEATURES 6
......@@ -77,6 +78,7 @@ protected:
virtual void compute(U_Ptr unit) override;
void computeFeatureVector(U_Ptr unit);
void trainRandomForest();
void shuffleTrainingSet();
std::string getImportances();
std::string _modelOut;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment