How to extract best parameters from a CrossValidatorModel
Solution 1
One method to get a proper ParamMap object is to use CrossValidatorModel.avgMetrics: Array[Double] to find the argmax ParamMap:
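import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.CrossValidatorModel

// Pairs each candidate ParamMap with its average cross-validation metric and
// returns the one with the highest metric (assumes larger is better, e.g.
// areaUnderROC from BinaryClassificationEvaluator).
implicit class BestParamMapCrossValidatorModel(cvModel: CrossValidatorModel) {
  def bestEstimatorParamMap: ParamMap = {
    cvModel.getEstimatorParamMaps
      .zip(cvModel.avgMetrics)
      .maxBy(_._2)
      ._1
  }
}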
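The `maxBy` above only picks the right entry when a larger metric means a better model, as with the default `areaUnderROC` of `BinaryClassificationEvaluator`; for an error metric, where smaller is better, you would want `minBy`. The sketch below is plain Scala with no Spark dependency, and the names `GridSearch`, `grid` and `bestBy` are hypothetical. It builds a grid as the Cartesian product of candidate values (which is what `ParamGridBuilder` produces), pairs each grid point with the metric at the same index (the way `getEstimatorParamMaps` lines up with `avgMetrics`), and picks the best point in either direction. In later Spark versions the direction is available as `cvModel.getEvaluator.isLargerBetter`.

```scala
// Plain-Scala sketch; these names are hypothetical and not part of Spark.
object GridSearch {
  // Cartesian product of candidate values, one Map per grid point,
  // similar in spirit to what ParamGridBuilder builds from addGrid calls.
  def grid(axes: Seq[(String, Seq[Any])]): Seq[Map[String, Any]] =
    axes.foldLeft(Seq(Map.empty[String, Any])) { case (acc, (name, values)) =>
      for (point <- acc; v <- values) yield point + (name -> v)
    }

  // Pair each grid point with the metric at the same index and keep the best,
  // honouring whether a larger metric is better (cf. Spark's isLargerBetter).
  def bestBy[P](points: Seq[P], metrics: Seq[Double], largerIsBetter: Boolean): (P, Double) = {
    require(points.nonEmpty && points.length == metrics.length,
      "points and metrics must be non-empty and aligned by index")
    val scored = points.zip(metrics)
    if (largerIsBetter) scored.maxBy(_._2) else scored.minBy(_._2)
  }
}
```

With Spark types the call would look like `GridSearch.bestBy(cvModel.getEstimatorParamMaps.toSeq, cvModel.avgMetrics.toSeq, largerIsBetter = true)`.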
When run on the CrossValidatorModel trained in the Pipeline Example you cited, this gives:
scala> println(cvModel.bestEstimatorParamMap)
{
hashingTF_2b0b8ccaeeec-numFeatures: 100,
logreg_950a13184247-regParam: 0.1
}
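Printing only the winner hides how close the other grid points came. A small plain-Scala sketch (the `Ranking.rankCandidates` helper is hypothetical and uses no Spark types) that sorts every candidate by its metric, best first:

```scala
// Hypothetical helper, not part of Spark: rank all candidates by metric.
object Ranking {
  def rankCandidates[P](points: Seq[P], metrics: Seq[Double], largerIsBetter: Boolean): Seq[(P, Double)] = {
    require(points.length == metrics.length, "points and metrics must be aligned by index")
    val scored = points.zip(metrics)
    // Best candidate first in either direction.
    if (largerIsBetter) scored.sortBy(p => -p._2) else scored.sortBy(_._2)
  }
}
```

With a trained model, something like `Ranking.rankCandidates(cvModel.getEstimatorParamMaps.toSeq, cvModel.avgMetrics.toSeq, largerIsBetter = true).foreach(println)` would list every ParamMap next to its average metric.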
Solution 2
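import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.feature.HashingTF

// The best model of a Pipeline estimator is a PipelineModel
val bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]
val stages = bestPipelineModel.stages
// Stage order in the Pipeline Example: tokenizer, hashingTF, lr
val hashingStage = stages(1).asInstanceOf[HashingTF]
println("numFeatures = " + hashingStage.getNumFeatures)
val lrStage = stages(2).asInstanceOf[LogisticRegressionModel]
println("regParam = " + lrStage.getRegParam)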
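Indexing `stages(1)` and `stages(2)` ties the code to the exact stage order of the Pipeline Example. A sturdier pattern is to pick stages by type with `collectFirst`; with Spark that reads `bestPipelineModel.stages.collectFirst { case h: HashingTF => h }`. The sketch below shows the idea in plain Scala; `Stage`, `TokenizerStage`, `HashingStage` and `LRModelStage` are hypothetical stand-ins, not Spark classes.

```scala
// Hypothetical stand-ins for pipeline stages (NOT Spark classes).
sealed trait Stage
final case class TokenizerStage(inputCol: String) extends Stage
final case class HashingStage(numFeatures: Int) extends Stage
final case class LRModelStage(regParam: Double) extends Stage

object StageLookup {
  // Find a stage by its type, wherever it sits in the pipeline.
  def numFeatures(stages: Seq[Stage]): Option[Int] =
    stages.collectFirst { case HashingStage(n) => n }

  def regParam(stages: Seq[Stage]): Option[Double] =
    stages.collectFirst { case LRModelStage(r) => r }
}
```

Returning an `Option` also makes a missing stage explicit instead of failing with a `ClassCastException` or an index error.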
Solution 3
To print everything in paramMap, you actually don't have to call parent:
cvModel.bestModel.extractParamMap()
To answer OP's question, to get a single best parameter, for example regParam:
cvModel.bestModel.extractParamMap().apply(cvModel.bestModel.getParam("regParam"))
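`getParam` only looks at the params of the object it is called on, so the lookup above works when `bestModel` is the logistic-regression model itself; when the estimator is a Pipeline, `regParam` belongs to the LR stage, not to the `PipelineModel`, and `getParam` fails. In Spark, `bestModel.hasParam("regParam")` can guard the call. The plain-Scala sketch below models an Option-based lookup by name; the flattened `(owner uid, param name) -> value` map and the `ParamLookup` object are hypothetical, loosely mirroring keys printed as `logreg_950a13184247-regParam`.

```scala
// Hypothetical flattened view of params, keyed by (owner uid, param name).
object ParamLookup {
  type Flat = Map[(String, String), Any]

  // Return the first value whose param name matches, or None instead of throwing.
  def byName(params: Flat, name: String): Option[Any] =
    params.collectFirst { case ((_, n), v) if n == name => v }
}
```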
Solution 4
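This is how you get the chosen parameters when the estimator passed to CrossValidator is a LogisticRegression rather than a Pipeline:
import org.apache.spark.ml.classification.LogisticRegressionModel
val bestLr = cvModel.bestModel.asInstanceOf[LogisticRegressionModel] // bestModel is typed as Model[_]
println(bestLr.getMaxIter)
println(bestLr.getRegParam)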
Solution 5
This Java code should work:
cvModel.bestModel().parent().extractParamMap();
You can translate it to Scala. The parent() method returns the Estimator that produced the best model, and you can get the best params from it.
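Why `parent` carries the winning settings: in Spark's ML API, `Estimator.fit(dataset, paramMap)` fits a copy of the estimator with the param map applied, and the fitted model records the estimator that produced it as its `parent`. The plain-Scala sketch below mimics that relationship with hypothetical `ToyEstimator` and `ToyModel` classes (not Spark's).

```scala
// Hypothetical toy classes mimicking the Estimator/Model relationship (NOT Spark's API).
final case class ToyModel(parent: ToyEstimator)

final case class ToyEstimator(params: Map[String, Any]) {
  // Apply the overrides to a copy of this estimator, "fit" that copy,
  // and record the copy as the resulting model's parent.
  def fit(overrides: Map[String, Any]): ToyModel = {
    val configured = copy(params = params ++ overrides)
    ToyModel(parent = configured)
  }
}
```

The original estimator keeps its defaults, while `model.parent` holds the values that were actually used for the fit.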
Mohammad
Updated on June 10, 2022
I want to find the parameters of ParamGridBuilder that make the best model in CrossValidator in Spark 1.4.x.
In the Pipeline Example in the Spark documentation, they add different parameters (numFeatures, regParam) by using ParamGridBuilder in the Pipeline. Then they make the best model with the following line of code:
val cvModel = crossval.fit(training.toDF)
Now I want to know which parameters (numFeatures, regParam) from ParamGridBuilder produce the best model.
I already tried the following commands without success:
cvModel.bestModel.extractParamMap().toString()
cvModel.params.toList.mkString("(", ",", ")")
cvModel.estimatorParamMaps.toString()
cvModel.explainParams()
cvModel.getEstimatorParamMaps.mkString("(", ",", ")")
cvModel.toString()
Any help?
Thanks in advance,