11.08., 09:00 - 11:00: Due to updates, GitLab will be unavailable for a few minutes between 09:00 and 11:00.

Commit 5cb934b8 authored by Alessio Netti

Analytics: enforcing categorical variables in classifier plugin

parent c90a5d90
......@@ -60,7 +60,7 @@ void ClassifierOperator::compute(U_Ptr unit) {
_currentClass = (int)_currentTarget;
_responseSet->push_back(_currentClass);
if ((uint64_t)_trainingSet->size().height >= _trainingSamples + _targetDistance)
trainRandomForest();
trainRandomForest(true);
}
if(_rForest.empty() || !(_rForest->isTrained() || (_trainingPending && _streaming)))
throw std::runtime_error("Operator " + _name + ": cannot perform prediction, the model is untrained!");
......
......@@ -120,7 +120,7 @@ void RegressorOperator::compute(U_Ptr unit) {
_trainingSet->push_back(*_currentfVector);
_responseSet->push_back(_currentTarget);
if ((uint64_t)_trainingSet->size().height >= _trainingSamples + _targetDistance)
trainRandomForest();
trainRandomForest(false);
}
if(_rForest.empty() || !(_rForest->isTrained() || (_trainingPending && _streaming)))
throw std::runtime_error("Operator " + _name + ": cannot perform prediction, the model is untrained!");
......@@ -132,7 +132,7 @@ void RegressorOperator::compute(U_Ptr unit) {
}
}
void RegressorOperator::trainRandomForest() {
void RegressorOperator::trainRandomForest(bool categorical) {
if(!_trainingSet || _rForest.empty())
throw std::runtime_error("Operator " + _name + ": cannot perform training, missing model!");
if((uint64_t)_responseSet->size().height <= _targetDistance)
......@@ -141,8 +141,16 @@ void RegressorOperator::trainRandomForest() {
*_responseSet = _responseSet->rowRange(_targetDistance, _responseSet->size().height-1);
*_trainingSet = _trainingSet->rowRange(0, _trainingSet->size().height-1-_targetDistance);
shuffleTrainingSet();
if(!_rForest->train(*_trainingSet, cv::ml::ROW_SAMPLE, *_responseSet))
cv::Mat varType = cv::Mat(_trainingSet->size().width + 1, 1, CV_8U);
varType.setTo(cv::Scalar(cv::ml::VAR_NUMERICAL));
varType.at<unsigned char>(_trainingSet->size().width, 0) = categorical ? cv::ml::VAR_CATEGORICAL : cv::ml::VAR_NUMERICAL;
cv::Ptr<cv::ml::TrainData> td = cv::ml::TrainData::create(*_trainingSet, cv::ml::ROW_SAMPLE, *_responseSet, cv::noArray(), cv::noArray(), cv::noArray(), varType);
if(!_rForest->train(td))
throw std::runtime_error("Operator " + _name + ": model training failed!");
td.release();
delete _trainingSet;
_trainingSet = nullptr;
delete _responseSet;
......
......@@ -77,7 +77,7 @@ protected:
virtual void compute(U_Ptr unit) override;
void computeFeatureVector(U_Ptr unit);
void trainRandomForest();
void trainRandomForest(bool categorical=false);
void shuffleTrainingSet();
std::string getImportances();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment