2016-03-27 42 views
5

Sono nel mezzo del refactoring del mio codice per sfruttare il DataFrames, Estimators, and Pipelines. Inizialmente utilizzavo MLlib Multiclass LogisticRegressionWithLBFGS su RDD[LabeledPoint]. Mi piace imparare e utilizzare la nuova API, ma non sono sicuro di come salvare il mio nuovo modello e applicarlo su nuovi dati.Spark ML - Save OneVsRestModel

Attualmente l'implementazione ML di LogisticRegression supporta solo la classificazione binaria. Sto, invece usando OneVsRest in questo modo:

val lr = new LogisticRegression().setFitIntercept(true) 
val ovr = new OneVsRest() 
ovr.setClassifier(lr) 
val ovrModel = ovr.fit(training) 

Vorrei ora per salvare la mia OneVsRestModel, ma questo non sembra essere supportato da API. Ho provato:

ovrModel.save("my-ovr") // Cannot resolve symbol save 
ovrModel.models.foreach(_.save("model-" + _.uid)) // Cannot resolve symbol save 

C'è un modo per salvare questo, in modo da poter caricare in una nuova applicazione per fare nuove previsioni?

risposta

5

Spark 2.0.0

OneVsRestModel implementa MLWritable quindi dovrebbe essere possibile salvare direttamente. Il metodo mostrato di seguito può essere ancora utile per salvare separatamente i singoli modelli.

Spark < 2.0.0

Il problema qui è che models restituisce un Array di ClassificationModel[_, _]] non un Array di LogisticRegressionModel (o MLWritable). Per farlo funzionare dovrete essere specifici sui tipi:

import org.apache.spark.ml.classification.LogisticRegressionModel 

ovrModel.models.zipWithIndex.foreach { 
    case (model: LogisticRegressionModel, i: Int) => 
    model.save(s"model-${model.uid}-$i") 
} 

o per essere più generico:

import org.apache.spark.ml.util.MLWritable 

ovrModel.models.zipWithIndex.foreach { 
    case (model: MLWritable, i: Int) => 
    model.save(s"model-${model.uid}-$i") 
} 

Purtroppo come per ora (scintille 1.6) OneVsRestModel non implementa MLWritable così non può essere salvato da solo.

Nota:

Tutti i modelli int il OneVsRest sembra usare la stessa uid quindi abbiamo bisogno di un indice esplicito. Sarà anche utile identificare il modello più tardi.

+1

Vorrei poter +2 questo. Non solo è esattamente ciò di cui avevo bisogno, ma rende molto più facile il calcolo delle probabilità non elaborate. Pensavo di dover personalizzare la src. Grazie! –

+0

@ zero323 c'è una versione pyspark della tua risposta? Cercando di trovare un modo per salvare i modelli pyspark.ml – ajkl

+0

@AjinkyaKale In 1.6? – zero323