14

Di seguito è la mia pipeline e sembra che non posso passare i parametri ai miei modelli utilizzando la classe ModelTransformer, che prendo dal link (http://zacstewart.com/2014/08/05/pipelines-of-featureunions-of-pipelines.html)(Python - sklearn) Come passare parametri alla classe ModelTransformer personalizzata da gridsearchcv

Il messaggio di errore ha senso per me, ma non so come risolvere questo problema. Qualche idea su come risolvere questo problema? Grazie.

# define a pipeline 
pipeline = Pipeline([ 
('vect', DictVectorizer(sparse=False)), 
('scale', preprocessing.MinMaxScaler()), 
('ess', FeatureUnion(n_jobs=-1, 
        transformer_list=[ 
    ('rfc', ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1, n_estimators=100))), 
    ('svc', ModelTransformer(SVC(random_state=1))),], 
        transformer_weights=None)), 
('es', EnsembleClassifier1()), 
]) 

# define the parameters for the pipeline 
parameters = { 
'ess__rfc__n_estimators': (100, 200), 
} 

# ModelTransformer class. It takes it from the link 
(http://zacstewart.com/2014/08/05/pipelines-of-featureunions-of-pipelines.html) 
class ModelTransformer(TransformerMixin): 
    def __init__(self, model): 
     self.model = model 
    def fit(self, *args, **kwargs): 
     self.model.fit(*args, **kwargs) 
     return self 
    def transform(self, X, **transform_params): 
     return DataFrame(self.model.predict(X)) 

grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1, refit=True) 

messaggio di errore: ValueError: n_estimators parametro non valido per stimatore ModelTransformer.

+0

Grazie per avermelo chiesto - Ho avuto la stessa domanda. Lascia che ti chieda un'altra cosa. Sai perché * self.model.fit (* args, ** kwargs) * funziona? Voglio dire che di solito non si passano iperparametri come n_estimators quando si chiama il metodo fit, ma quando si definisce l'istanza della classe, ad esempio rfc = RandomForestClassifier (n_estimators = 100), rfc.fit (X, y) – drake

+0

@drake, quando si crea un'istanza ModelTransformer, è necessario passare un modello con i relativi parametri. Ad esempio, ModelTransformer (RandomForestClassifier (n_jobs = -1, random_state = 1, n_estimators = 100))). E qui self.model.fit (* args, ** kwargs) significa principalmente self.model.fit (X, y). – nkhuyu

+0

Grazie, @nkhuyu. So che è così che funziona. Stavo chiedendo perché. Poiché self.model = model, self.model = RandomForestClassifier (n_jobs = -1, random_state = 1, n_estimators = 100). Capisco * args è decompressione (X, y), ma non capisco PERCHÉ è necessario ** kwargs nel metodo di adattamento quando self.model conosce già gli iperparametri. – drake

risposta

15

GridSearchCV ha una convenzione di denominazione speciale per oggetti nidificati. Nel tuo caso ess__rfc__n_estimators sta per ess.rfc.n_estimators, e, secondo la definizione del pipeline, che punti alla proprietà n_estimators di

ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1, n_estimators=100))) 

Ovviamente, ModelTransformer casi non hanno tale proprietà.

La correzione è semplice: per accedere all'oggetto sottostante di ModelTransformer è necessario utilizzare il campo model. Così, parametri di rete diventano

parameters = { 
    'ess__rfc__model__n_estimators': (100, 200), 
} 

P.S. non è l'unico problema con il tuo codice. Per utilizzare più lavori in GridSearchCV, è necessario creare tutti gli oggetti che si stanno utilizzando in grado di copiare. Ciò si ottiene implementando i metodi get_params e set_params, è possibile prendere in prestito dal mix BaseEstimator.

+0

puoi espandere un po 'questo PS? Penso di avere lo stesso problema in cui quando provo a usare gridsearchcv con union feature pipeline ottengo l'errore AttributeError: L'oggetto 'SelectColumns' non ha attributo 'get_params' dove SelectColumns è una classe che ho scritto per la pipeline. –

+7

@B_Miner, dovresti ereditare la tua classe 'SelectColumns' dal [' BaseEstimator'] (http://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html) che fornisce i 'set_params 'di cui sopra e 'get_params'. In alternativa, puoi implementare i tuoi, ma la maggior parte delle volte non vuoi. –

+2

Stavo cercando BaseEstimatorMixin. Ho ereditato da BaseEstimator e ha funzionato come un incantesimo, grazie! –