Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
dcdb
dcdb
Commits
5cb934b8
Commit 5cb934b8, authored Mar 18, 2020 by Alessio Netti
Browse files
Analytics: enforcing categorical variables in classifier plugin
parent
c90a5d90
Changes
3
Hide whitespace changes
Inline
Side-by-side
analytics/operators/regressor/ClassifierOperator.cpp
View file @
5cb934b8
...
...
@@ -60,7 +60,7 @@ void ClassifierOperator::compute(U_Ptr unit) {
_currentClass
=
(
int
)
_currentTarget
;
_responseSet
->
push_back
(
_currentClass
);
if
((
uint64_t
)
_trainingSet
->
size
().
height
>=
_trainingSamples
+
_targetDistance
)
trainRandomForest
();
trainRandomForest
(
true
);
}
if
(
_rForest
.
empty
()
||
!
(
_rForest
->
isTrained
()
||
(
_trainingPending
&&
_streaming
)))
throw
std
::
runtime_error
(
"Operator "
+
_name
+
": cannot perform prediction, the model is untrained!"
);
...
...
analytics/operators/regressor/RegressorOperator.cpp
View file @
5cb934b8
...
...
@@ -120,7 +120,7 @@ void RegressorOperator::compute(U_Ptr unit) {
_trainingSet
->
push_back
(
*
_currentfVector
);
_responseSet
->
push_back
(
_currentTarget
);
if
((
uint64_t
)
_trainingSet
->
size
().
height
>=
_trainingSamples
+
_targetDistance
)
trainRandomForest
();
trainRandomForest
(
false
);
}
if
(
_rForest
.
empty
()
||
!
(
_rForest
->
isTrained
()
||
(
_trainingPending
&&
_streaming
)))
throw
std
::
runtime_error
(
"Operator "
+
_name
+
": cannot perform prediction, the model is untrained!"
);
...
...
@@ -132,7 +132,7 @@ void RegressorOperator::compute(U_Ptr unit) {
}
}
void
RegressorOperator
::
trainRandomForest
()
{
void
RegressorOperator
::
trainRandomForest
(
bool
categorical
)
{
if
(
!
_trainingSet
||
_rForest
.
empty
())
throw
std
::
runtime_error
(
"Operator "
+
_name
+
": cannot perform training, missing model!"
);
if
((
uint64_t
)
_responseSet
->
size
().
height
<=
_targetDistance
)
...
...
@@ -141,8 +141,16 @@ void RegressorOperator::trainRandomForest() {
*
_responseSet
=
_responseSet
->
rowRange
(
_targetDistance
,
_responseSet
->
size
().
height
-
1
);
*
_trainingSet
=
_trainingSet
->
rowRange
(
0
,
_trainingSet
->
size
().
height
-
1
-
_targetDistance
);
shuffleTrainingSet
();
if
(
!
_rForest
->
train
(
*
_trainingSet
,
cv
::
ml
::
ROW_SAMPLE
,
*
_responseSet
))
cv
::
Mat
varType
=
cv
::
Mat
(
_trainingSet
->
size
().
width
+
1
,
1
,
CV_8U
);
varType
.
setTo
(
cv
::
Scalar
(
cv
::
ml
::
VAR_NUMERICAL
));
varType
.
at
<
unsigned
char
>
(
_trainingSet
->
size
().
width
,
0
)
=
categorical
?
cv
::
ml
::
VAR_CATEGORICAL
:
cv
::
ml
::
VAR_NUMERICAL
;
cv
::
Ptr
<
cv
::
ml
::
TrainData
>
td
=
cv
::
ml
::
TrainData
::
create
(
*
_trainingSet
,
cv
::
ml
::
ROW_SAMPLE
,
*
_responseSet
,
cv
::
noArray
(),
cv
::
noArray
(),
cv
::
noArray
(),
varType
);
if
(
!
_rForest
->
train
(
td
))
throw
std
::
runtime_error
(
"Operator "
+
_name
+
": model training failed!"
);
td
.
release
();
delete
_trainingSet
;
_trainingSet
=
nullptr
;
delete
_responseSet
;
...
...
analytics/operators/regressor/RegressorOperator.h
View file @
5cb934b8
...
...
@@ -77,7 +77,7 @@ protected:
virtual
void
compute
(
U_Ptr
unit
)
override
;
void
computeFeatureVector
(
U_Ptr
unit
);
void
trainRandomForest
();
void
trainRandomForest
(
bool
categorical
=
false
);
void
shuffleTrainingSet
();
std
::
string
getImportances
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.