2014-12-15 5 views
13

Sto lavorando alla classificazione di dati semplici usando KNN con distanza euclidea. Ho visto un esempio di ciò desidero farlo è fatto con la funzione MATLAB knnsearch come illustrato di seguito:Trovare i vicini più vicini di K e la sua implementazione

load fisheriris 
x = meas(:,3:4); 
gscatter(x(:,1),x(:,2),species) 
newpoint = [5 1.45]; 
[n,d] = knnsearch(x,newpoint,'k',10); 
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10) 

Il codice precedente tiene un nuovo punto cioè [5 1.45] e trova i 10 valori vicini al nuovo punto . Qualcuno può mostrarmi un algoritmo MATLAB con una spiegazione dettagliata di cosa fa la funzione knnsearch? C'è un altro modo di fare questo?

+0

È piuttosto semplice. Per un punto particolare, troviamo i 10 punti più vicini tra i dati e questo punto e restituiscono quei punti più vicini che fanno parte dei tuoi dati. Di solito, la distanza euclidea viene usata dove i componenti di un punto vengono usati per confrontare tra i componenti di un altro punto. Questo articolo su Wikipedia è particolarmente utile: http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm – rayryeng

+0

Ohhh ... vuoi implementare questa procedura da solo?Posso certamente fornire una risposta per te. In realtà non è così difficile come pensi di implementare l'algoritmo. Si prega di indicare ciò di cui avete bisogno. – rayryeng

+0

Sì sto cercando di implementare la funzione 'knnsearch' da solo, proprio come il mio esempio di codice, Grazie! –

risposta

34

La base del vicino K-vicino (KNN) algoritmo è che si ha una matrice di dati costituita da N righe e M colonne in cui N è il numero di punti di dati che abbiamo, mentre M è la dimensionalità di ciascun punto dati. Ad esempio, se posizioniamo le coordinate cartesiane all'interno di una matrice di dati, di solito è una matrice N x 2 o N x 3. Con questa matrice di dati, si fornisce un punto interrogativo e si cercano i punti k più vicini all'interno di questa matrice di dati che sono i più vicini a questo punto di ricerca.

Di solito usiamo la distanza euclidea tra la query e il resto dei punti nella matrice di dati per calcolare le nostre distanze. Tuttavia, vengono utilizzate anche altre distanze come la L1 o la distanza City-Block/Manhattan. Dopo questa operazione, avrai N distanze Euclide o Manhattan che simboleggiano le distanze tra la query con ciascun punto corrispondente nel set di dati. Una volta individuati, cerca semplicemente i punti più vicini alla query k ordinando le distanze in ordine crescente e recuperando quei punti k che hanno la distanza minima tra il set di dati e la query.

Supponendo matrice i dati sono stati memorizzati in x, e newpoint è un punto di campionamento in cui ha M colonne (cioè 1 x M), questa è la procedura generale che si seguire in forma di punto:

  1. Trova l'euclidea o Manhattan distanza tra newpoint e ogni punto in x.
  2. Ordinare queste distanze in ordine crescente.
  3. Restituire i punti di dati k in x più vicini a newpoint.

Facciamo ogni passo lentamente.


Passo # 1

Un modo che qualcuno possa fare questo è forse in un ciclo for in questo modo:

N = size(x,1); 
dists = zeros(N,1); 
for idx = 1 : N 
    dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2)); 
end 

Se si volesse implementare la distanza di Manhattan, questo sarebbe semplicemente :

N = size(x,1); 
dists = zeros(N,1); 
for idx = 1 : N 
    dists(idx) = sum(abs(x(idx,:) - newpoint)); 
end 

dists sarebbe un 01.233.603,220982 millionsvettore di elementi che contiene le distanze tra ciascun punto di dati in x e newpoint.Facciamo una sottrazione elemento per elemento tra newpoint e un punto dati in x, piazza le differenze, quindi sum tutte insieme. Questa somma è quindi radicata in quadrato, che completa la distanza euclidea. Per la distanza di Manhattan, dovresti eseguire un elemento per sottrazione di elementi, prendere i valori assoluti, quindi sommare tutti i componenti insieme. Questa è probabilmente la più semplice delle implementazioni da comprendere, ma potrebbe essere probabilmente la più inefficiente ... soprattutto per i set di dati di dimensioni maggiori e una maggiore dimensionalità dei dati.

Un'altra possibile soluzione sarebbe quella di replicare newpoint e rendere questa matrice delle stesse dimensioni x, poi facendo una sottrazione elemento per elemento di questa matrice, quindi sommando su tutte le colonne per ogni riga e facendo la radice quadrata . Pertanto, siamo in grado di fare qualcosa di simile:

N = size(x, 1); 
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2)); 

Per la distanza di Manhattan, si dovrebbe fare:

N = size(x, 1); 
dists = sum(abs(x - repmat(newpoint, N, 1)), 2); 

repmat prende una matrice o vettoriale e ripete loro una certa quantità di volte in una data direzione . Nel nostro caso, vogliamo prendere il nostro vettore newpoint e impilare questo N volte l'uno sopra l'altro per creare una matrice N x M, dove ogni riga è lunga M elementi. Sottragiamo queste due matrici insieme, quindi quadriamo ogni componente. Una volta fatto questo, abbiamo sum su tutte le colonne per ogni riga e alla fine prendiamo la radice quadrata di tutti i risultati. Per la distanza di Manhattan, facciamo la sottrazione, prendiamo il valore assoluto e poi sommiamo.

Tuttavia, il modo più efficiente per farlo a mio parere sarebbe utilizzare bsxfun. Questo essenzialmente fa la replica di cui abbiamo parlato sotto il cofano con una singola chiamata di funzione. Pertanto, il codice sarebbe semplicemente questo:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); 

A me questo sembra molto più pulito e al punto. Per la distanza di Manhattan, si dovrebbe fare:

dists = sum(abs(bsxfun(@minus, x, newpoint)), 2); 

Passo # 2

Ora che abbiamo le nostre distanze, abbiamo semplicemente ordinarli. Possiamo usare sort per ordinare le nostre distanze:

[d,ind] = sort(dists); 

d conterrebbe le distanze ordinati in ordine crescente, mentre ind si dice per ogni valore nella indifferenziati serie dove appare nel ordinate risultato. Dobbiamo utilizzare ind, estrarre i primi elementi k di questo vettore, quindi utilizzare ind per indicizzare nella nostra matrice di dati x per restituire quei punti che erano i più vicini a newpoint.

Passo # 3

Il passo finale è quello di tornare ora quelle k punti di dati che sono più vicini a newpoint.Possiamo farlo molto semplicemente:

ind_closest = ind(1:k); 
x_closest = x(ind_closest,:); 

ind_closest dovrebbe contenere gli indici nella matrice dati originali x che sono più vicino a newpoint. Nello specifico, ind_closest contiene le righe da campionare in x per ottenere i punti più vicini a newpoint. x_closest conterrà quei punti dati effettivi.


Per la vostra copia e piacere incollare, questo è ciò che il codice è simile:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); 
%// Or do this for Manhattan 
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2); 
[d,ind] = sort(dists); 
ind_closest = ind(1:k); 
x_closest = x(ind_closest,:); 

esecuzione attraverso il vostro esempio, vediamo il nostro codice in azione:

load fisheriris 
x = meas(:,3:4); 
newpoint = [5 1.45]; 
k = 10; 

%// Use Euclidean 
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); 
[d,ind] = sort(dists); 
ind_closest = ind(1:k); 
x_closest = x(ind_closest,:); 

Ispezionando ind_closest e x_closest, questo è ciò che otteniamo:

>> ind_closest 

ind_closest = 

    120 
    53 
    73 
    134 
    84 
    77 
    78 
    51 
    64 
    87 

>> x_closest 

x_closest = 

    5.0000 1.5000 
    4.9000 1.5000 
    4.9000 1.5000 
    5.1000 1.5000 
    5.1000 1.6000 
    4.8000 1.4000 
    5.0000 1.7000 
    4.7000 1.4000 
    4.7000 1.4000 
    4.7000 1.5000 

Se è stato eseguito knnsearch, vedrete che la variabile n corrisponde con ind_closest. Tuttavia, la variabile d restituisce le distanze da newpoint a ciascun punto x, non i punti di dati effettivi stessi. Se si desidera che le distanze reali, fare semplicemente il seguente testo dopo il codice che ho scritto:

dist_sorted = d(1:k); 

Nota che la risposta precedente utilizza un solo punto query in un lotto di N esempi. Molto spesso KNN viene utilizzato su più esempi contemporaneamente. Supponendo che abbiamo i punti di query Q che vogliamo testare nella KNN. Ciò risulterebbe in una matrice k x M x Q in cui per ogni esempio o ogni sezione, restituiamo i punti più vicini con una dimensionalità di M. In alternativa, è possibile restituire gli ID dei punti più vicini k risultando così una matrice Q x k. Calcoliamo entrambi.

Un modo ingenuo per fare questo sarebbe applicare il codice sopra in un ciclo e ripetere su ogni esempio.

Qualcosa di simile funzionerebbe dove allochiamo una matrice Q x k e applicare l'approccio basato bsxfun impostare ogni riga della matrice di uscita ai k punti più vicini nell'insieme di dati, dove useremo il set di dati Fisher Iris proprio come quello che aveva prima. Ci sarà anche mantenere la stessa dimensionalità come abbiamo fatto nell'esempio precedente e userò quattro esempi, in modo Q = 4 e M = 2:

%// Load the data and create the query points 
load fisheriris; 
x = meas(:,3:4); 
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; 

%// Define k and the output matrices 
Q = size(newpoints, 1); 
M = size(x, 2); 
k = 10; 
x_closest = zeros(k, M, Q); 
ind_closest = zeros(Q, k); 

%// Loop through each point and do logic as seen above: 
for ii = 1 : Q 
    %// Get the point 
    newpoint = newpoints(ii, :); 

    %// Use Euclidean 
    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); 
    [d,ind] = sort(dists); 

    %// New - Output the IDs of the match as well as the points themselves 
    ind_closest(ii, :) = ind(1 : k).'; 
    x_closest(:, :, ii) = x(ind_closest(ii, :), :); 
end 

Anche se questo è molto bello, siamo in grado di fare ancora meglio. C'è un modo per calcolare in modo efficiente la distanza Euclidea quadrata tra due serie di vettori. Lo lascerò come esercizio se vuoi farlo con Manhattan.Consultando this blog, dato che A è una matrice Q1 x M cui ogni riga è un punto di dimensionalità M con Q1 punti e B è una matrice Q2 x M cui ogni riga è anche un punto di dimensionalità M con Q2 punti, si può calcolare in modo efficiente una matrice di distanza D(i, j) dove l'elemento alla riga i e colonna j indica la distanza tra fila i di A e fila j di B utilizzando la formula seguente matrice:

nA = sum(A.^2, 2); %// Sum of squares for each row of A 
nB = sum(B.^2, 2); %// Sum of squares for each row of B 
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix 
D = sqrt(D); %// Compute square root to complete calculation 

Pertanto, se facciamo sì che A sia una matrice di punti interrogazione e B sia il set di dati costituito dai tuoi dati originali, possiamo determinare i punti più vicini ordinando ogni riga individualmente e determinando le posizioni di ciascuna riga che erano le più piccole. Possiamo inoltre utilizzarlo anche per recuperare i punti effettivi stessi.

Pertanto:

%// Load the data and create the query points 
load fisheriris; 
x = meas(:,3:4); 
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; 

%// Define k and other variables 
k = 10; 
Q = size(newpoints, 1); 
M = size(x, 2); 

nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A 
nB = sum(x.^2, 2); %// Sum of squares for each row of B 
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix 
D = sqrt(D); %// Compute square root to complete calculation 

%// Sort the distances 
[d, ind] = sort(D, 2); 

%// Get the indices of the closest distances 
ind_closest = ind(:, 1:k); 

%// Also get the nearest points 
x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]); 

Vediamo che abbiamo usato la logica per il calcolo della matrice di distanza è lo stesso ma alcune variabili hanno cambiato per soddisfare l'esempio. Inoltre, ordiniamo ogni riga in modo indipendente utilizzando la versione a due ingressi di sort e quindi ind conterrà gli ID per riga e d conterrà le distanze corrispondenti. Individuiamo quindi quali indici sono i più vicini a ciascun punto di query semplicemente troncando questa matrice alle colonne k. Quindi utilizziamo permute e reshape per determinare quali sono i punti più vicini associati. Per prima cosa usiamo tutti gli indici più vicini e creiamo una matrice di punti che impila tutti gli ID uno sopra l'altro in modo da ottenere una matrice Q * k x M. L'utilizzo di reshape e permute ci consente di creare la nostra matrice 3D in modo che diventi una matrice k x M x Q come specificato. Se si desidera ottenere le distanze effettive, possiamo indicizzare in d e prendere ciò di cui abbiamo bisogno. Per fare ciò, è necessario utilizzare sub2ind per ottenere gli indici lineari in modo da poter indicizzare in d in un colpo. I valori di ind_closest ci forniscono già le colonne a cui è necessario accedere. Le righe a cui dobbiamo accedere sono semplicemente 1, k volte, 2, k volte, ecc. Fino a Q. k è per il numero di punti abbiamo voluto tornare:

Quando si esegue il codice sopra per i punti di query di cui sopra, questi sono i indici, punti e distanze otteniamo:

>> ind_closest 

ind_closest = 

    120 134 53 73 84 77 78 51 64 87 
    123 119 118 106 132 108 131 136 126 110 
    107 62 86 122 71 127 139 115 60 52 
    99 65 58 94 60 61 80 44 54 72 

>> x_closest 

x_closest(:,:,1) = 

    5.0000 1.5000 
    6.7000 2.0000 
    4.5000 1.7000 
    3.0000 1.1000 
    5.1000 1.5000 
    6.9000 2.3000 
    4.2000 1.5000 
    3.6000 1.3000 
    4.9000 1.5000 
    6.7000 2.2000 


x_closest(:,:,2) = 

    4.5000 1.6000 
    3.3000 1.0000 
    4.9000 1.5000 
    6.6000 2.1000 
    4.9000 2.0000 
    3.3000 1.0000 
    5.1000 1.6000 
    6.4000 2.0000 
    4.8000 1.8000 
    3.9000 1.4000 


x_closest(:,:,3) = 

    4.8000 1.4000 
    6.3000 1.8000 
    4.8000 1.8000 
    3.5000 1.0000 
    5.0000 1.7000 
    6.1000 1.9000 
    4.8000 1.8000 
    3.5000 1.0000 
    4.7000 1.4000 
    6.1000 2.3000 


x_closest(:,:,4) = 

    5.1000 2.4000 
    1.6000 0.6000 
    4.7000 1.4000 
    6.0000 1.8000 
    3.9000 1.4000 
    4.0000 1.3000 
    4.7000 1.5000 
    6.1000 2.5000 
    4.5000 1.5000 
    4.0000 1.3000 

>> dist_sorted 

dist_sorted = 

    0.0500 0.1118 0.1118 0.1118 0.1803 0.2062 0.2500 0.3041 0.3041 0.3041 
    0.3000 0.3162 0.3606 0.4123 0.6000 0.7280 0.9055 0.9487 1.0198 1.0296 
    0.9434 1.0198 1.0296 1.0296 1.0630 1.0630 1.0630 1.1045 1.1045 1.1180 
    2.6000 2.7203 2.8178 2.8178 2.8320 2.9155 2.9155 2.9275 2.9732 2.9732 

Per confronta questo con knnsearch, dovresti invece specificare una matrice di punti per il secondo parametro in cui ogni riga è un punto di interrogazione e vedrai che gli indici e le distanze ordinate corrispondono tra questa implementazione e knnsearch.


Spero che questo ti aiuti. In bocca al lupo!

+0

Questo è davvero molto utile! Grazie mille! Ora capisco @rayryeng –

+0

@Young_DataAnalyst - Il mio piacere! Per favore considera di accettare la mia risposta se ti ho aiutato :). In bocca al lupo! – rayryeng

+0

Ho anche imparato molto, grazie! – Rashid