2011-09-12 7 views
17

Sono interessato a testare le prestazioni SVM per classificare diverse persone in quattro gruppi/classi. Quando si utilizza la funzione svstrain LibSVM di MATLAB, sono in grado di ottenere le tre equazioni utilizzate per classificare quegli individui tra i 4 gruppi, in base ai valori di questa equazione. Uno schema potrebbe essere la seguente:Equazioni SVM dal pacchetto e1071 R?

   All individuals (N)* 
         | 
Group 1 (n1) <--- equation 1 ---> (N-n1) 
             | 
        (N-n1-n2) <--- equation 2 ---> Group 2 (n2) 
         | 
Group 3 (n3) <--- equation 3 ---> Group 4(n4) 

*N = n1+n2+n3+n4 

C'è un modo per ottenere queste equazioni utilizzando la funzione SVM nel pacchetto R E1071?

risposta

39

svm in e1071 utilizza la strategia "uno contro uno" per la classificazione multiclasse (vale a dire la classificazione binaria tra tutte le coppie, seguita dal voto). Quindi, per gestire questa impostazione gerarchica, probabilmente è necessario fare una serie di classificatori binari manualmente, come gruppo 1 vs tutti, allora il gruppo 2 vs tutto ciò che è rimasto, ecc .. Inoltre, la funzione di base svm non si sintonizzare i iperparametri, quindi in genere si desidera utilizzare un wrapper come tune in e1071, oppure train nel pacchetto eccellente caret.

In ogni caso, per classificare nuovi individui in R, non c'è bisogno di inserire i numeri in un'equazione manualmente. Piuttosto, si utilizza la funzione generica predict, che ha metodi per diversi modelli come SVM. Per oggetti modello come questo, puoi anche usare solitamente le funzioni generiche plot e summary. Ecco un esempio dell'idea di base utilizzando uno SVM lineare:

require(e1071) 

# Subset the iris dataset to only 2 labels and 2 features 
iris.part = subset(iris, Species != 'setosa') 
iris.part$Species = factor(iris.part$Species) 
iris.part = iris.part[, c(1,2,5)] 

# Fit svm model 
fit = svm(Species ~ ., data=iris.part, type='C-classification', kernel='linear') 

# Make a plot of the model 
dev.new(width=5, height=5) 
plot(fit, iris.part) 

# Tabulate actual labels vs. fitted labels 
pred = predict(fit, iris.part) 
table(Actual=iris.part$Species, Fitted=pred) 

# Obtain feature weights 
w = t(fit$coefs) %*% fit$SV 

# Calculate decision values manually 
iris.scaled = scale(iris.part[,-3], fit$x.scale[[1]], fit$x.scale[[2]]) 
t(w %*% t(as.matrix(iris.scaled))) - fit$rho 

# Should equal... 
fit$decision.values 

enter image description here

tabulare etichette di classe reali vs previsioni del modello:

> table(Actual=iris.part$Species, Fitted=pred) 
      Fitted 
Actual  versicolor virginica 
    versicolor   38  12 
    virginica   15  35 

includono estratto pesi da svm oggetto del modello (per selezione delle funzioni, ecc.). Qui, Sepal.Length è ovviamente più utile.

> t(fit$coefs) %*% fit$SV 
    Sepal.Length Sepal.Width 
[1,] -1.060146 -0.2664518 

per capire dove i valori decisionali vengono, possiamo calcolare manualmente come il prodotto scalare della funzione di pesi e la caratteristica vettori preprocessati, meno l'intercetta di offset rho. (Pre-elaborato significa forse centrato/scalato e/o kernel trasformato se si utilizza RBF SVM, ecc)

> t(w %*% t(as.matrix(iris.scaled))) - fit$rho 
     [,1] 
51 -1.3997066 
52 -0.4402254 
53 -1.1596819 
54 1.7199970 
55 -0.2796942 
56 0.9996141 
... 

Questo dovrebbe essere uguale ciò che viene calcolato internamente:

> head(fit$decision.values) 
    versicolor/virginica 
51   -1.3997066 
52   -0.4402254 
53   -1.1596819 
54   1.7199970 
55   -0.2796942 
56   0.9996141 
... 
+0

Grazie per voi rispondere, John. Il motivo per cui voglio conoscere queste equazioni è valutare quali parametri del totale hanno più importanza quando classifichi i miei eventi. –

+2

@ ManuelRamón Ahh gotcha. Questi sono chiamati "pesi" per un SVM lineare. Vedere la modifica sopra per come calcolare da un oggetto modello svm. In bocca al lupo! –

+1

L'esempio ha solo due categorie (versicolor e virginica) e si ottiene un vettore con due coefficienti, uno per ciascuna variabile utilizzata per classificare i dati dell'iride. Se ho N categorie ottengo i vettori N-1 da 'con (fit, t (coefs%% *% SV)'. Qual è il significato di ciascun vettore? –