2010-03-12 4 views
5

Ho un array Nx5 contenente N vettori di forma 'id', 'x', 'y', 'z' ed 'energia'. Ho bisogno di rimuovere i punti duplicati (vale a dire dove x, y, z corrispondono tutti) entro una tolleranza di dire 0.1. Idealmente potrei creare una funzione in cui passo nell'array, colonne che devono corrispondere e una tolleranza sulla partita.Rimozione di duplicati (entro una tolleranza data) da una schiera di vettori Numpy

Seguendo this thread on Scipy-user, posso rimuovere i duplicati in base a un array completo utilizzando gli array di record, ma devo semplicemente abbinare una parte di un array. Inoltre questo non corrisponderà entro una certa tolleranza.

Potrei faticosamente scorrere con un ciclo for in Python ma esiste un modo migliore per Numponic?

+1

C'è un problema intrinseco w/le specifiche che si danno, motivo per cui è improbabile di trovare una soluzione precotti: dire per chiarezza la tolleranza è in realtà 0,11, yez sempre identico, e il ' x's sono 0, 0.1, 0.2, 0.3, 0.4, ... - ora quali sono i "duplicati"? Con la def, 0.1 è "un duplicato" sia di 0 che di 0.2, ma questi due NON sono duplicati l'uno dell'altro - quindi la relazione "duplicata" non è transitiva e quindi non può indurre una partizione! Dovrai definire tu stesso delle euristiche, poiché non esiste una soluzione veramente "matematicamente corretta" (non può essere: nessuna partizione!). –

+1

Vedo il tuo punto. Nel dominio del problema sto lavorando all'interno anche se prevedo il clustering, cioè la spaziatura media tra i punti all'interno della tolleranza ~ dei cluster mentre la spaziatura media tra i cluster >> spaziatura media tra i punti all'interno di un cluster. La dimensione della tolleranza dovrebbe essere tale che per i tuoi scopi qualsiasi punto nel cluster potrebbe essere il punto "canonico". – Brendan

risposta

2

Si potrebbe guardare scipy.spatial.KDTree. Quanto è grande N?
Aggiunto: oops, tree.query_pairs non è in scipy 0.7.1.

In caso di dubbio, usare la forza bruta: dividere lo spazio (qui lato^3) in piccole celle, un punto per ogni cella:

""" scatter points to little cells, 1 per cell """ 
from __future__ import division   
import sys        
import numpy as np      

side = 100        
npercell = 1 # 1: ~ 1/e empty   
exec "\n".join(sys.argv[1:]) # side= ... 
N = side**3 * npercell     
print "side: %d npercell: %d N: %d" % (side, npercell, N) 
np.random.seed(1)      
points = np.random.uniform(0, side, size=(N,3)) 

cells = np.zeros((side,side,side), dtype=np.uint) 
id = 1 
for p in points.astype(int): 
    cells[tuple(p)] = id     
    id += 1        

cells = cells.flatten() 
    # A C, an E-flat, and a G walk into a bar. 
    # The bartender says, "Sorry, but we don't serve minors." 
nz = np.nonzero(cells)[0]    
print "%d cells have points" % len(nz) 
print "first few ids:", cells[nz][:10] 
+0

Usare KDTree è una grande idea, potrei implementarla più tardi – Brendan

0

non ho ancora testato questo, ma se si ordina l'array lungo x allora y questo dovrebbe darti la lista dei duplicati. È quindi necessario scegliere quale mantenere.

def find_dup_xyz(anarray, x, y, z): #for example in an data = array([id,x,y,z,energy]) x=1 y=2 z=3 
    dup_xyz=[] 
    for i, row in enumerated(sortedArray): 
     nx=1 
     while (abs(row[x] - sortedArray[i+nx[x])<0.1) and (abs(row[z] and sortedArray[i+nx[y])<0.1) and (abs(row[z] - sortedArray[i+nx[z])<0.1): 
       nx=+1 
       dup_xyz.append(row) 
return dup_xyz 

anche appena trovato questo http://mail.scipy.org/pipermail/scipy-user/2008-April/016504.html

0

ho finalmente ottenuto una soluzione che io sono felice con, questo è un taglio leggermente ripulito e incolla dal mio codice. Potrebbero esserci ancora alcuni bug.

Nota: che utilizza ancora un ciclo "for". Potrei usare l'idea di Denis di KDTree sopra accoppiato con l'arrotondamento per ottenere la soluzione completa.

import numpy as np 

def remove_duplicates(data, dp_tol=None, cols=None, sort_by=None): 
    ''' 
    Removes duplicate vectors from a list of data points 
    Parameters: 
     data  An MxN array of N vectors of dimension M 
     cols  An iterable of the columns that must match 
        in order to constitute a duplicate 
        (default: [1,2,3] for typical Klist data array) 
     dp_tol  An iterable of three tolerances or a single 
        tolerance for all dimensions. Uses this to round 
        the values to specified number of decimal places 
        before performing the removal. 
        (default: None) 
     sort_by  An iterable of columns to sort by (default: [0]) 

    Returns: 
     MxI Array An array of I vectors (minus the 
        duplicates) 

    EXAMPLES: 

    Remove a duplicate 

    >>> import wien2k.utils 
    >>> import numpy as np 
    >>> vecs1 = np.array([[1, 0, 0, 0], 
    ...  [2, 0, 0, 0], 
    ...  [3, 0, 0, 1]]) 
    >>> remove_duplicates(vecs1) 
    array([[1, 0, 0, 0], 
      [3, 0, 0, 1]]) 

    Remove duplicates with a tolerance 

    >>> vecs2 = np.array([[1, 0, 0, 0 ], 
    ...  [2, 0, 0, 0.001 ], 
    ...  [3, 0, 0, 0.02 ], 
    ...  [4, 0, 0, 1  ]]) 
    >>> remove_duplicates(vecs2, dp_tol=2) 
    array([[ 1. , 0. , 0. , 0. ], 
      [ 3. , 0. , 0. , 0.02], 
      [ 4. , 0. , 0. , 1. ]]) 

    Remove duplicates and sort by k values 

    >>> vecs3 = np.array([[1, 0, 0, 0], 
    ...  [2, 0, 0, 2], 
    ...  [3, 0, 0, 0], 
    ...  [4, 0, 0, 1]]) 
    >>> remove_duplicates(vecs3, sort_by=[3]) 
    array([[1, 0, 0, 0], 
      [4, 0, 0, 1], 
      [2, 0, 0, 2]]) 

    Change the columns that constitute a duplicate 

    >>> vecs4 = np.array([[1, 0, 0, 0], 
    ...  [2, 0, 0, 2], 
    ...  [1, 0, 0, 0], 
    ...  [4, 0, 0, 1]]) 
    >>> remove_duplicates(vecs4, cols=[0]) 
    array([[1, 0, 0, 0], 
      [2, 0, 0, 2], 
      [4, 0, 0, 1]]) 

    ''' 
    # Deal with the parameters 
    if sort_by is None: 
     sort_by = [0] 
    if cols is None: 
     cols = [1,2,3] 
    if dp_tol is not None: 
     # test to see if already an iterable 
     try: 
      null = iter(dp_tol) 
      tols = np.array(dp_tol) 
     except TypeError: 
      tols = np.ones_like(cols) * dp_tol 
     # Convert to numbers of decimal places 
     # Find the 'order' of the axes 
    else: 
     tols = None 

    rnd_data = data.copy() 
    # set the tolerances 
    if tols is not None: 
     for col,tol in zip(cols, tols): 
      rnd_data[:,col] = np.around(rnd_data[:,col], decimals=tol) 

    # TODO: For now, use a slow Python 'for' loop, try to find a more 
    # numponic way later - see: http://stackoverflow.com/questions/2433882/ 
    sorted_indexes = np.lexsort(tuple([rnd_data[:,col] for col in cols])) 
    rnd_data = rnd_data[sorted_indexes] 
    unique_kpts = [] 
    for i in xrange(len(rnd_data)): 
     if i == 0: 
      unique_kpts.append(i)  
     else: 
      if (rnd_data[i, cols] == rnd_data[i-1, cols]).all(): 
       continue 
      else: 
       unique_kpts.append(i)  

    rnd_data = rnd_data[unique_kpts] 
    # Now sort 
    sorted_indexes = np.lexsort(tuple([rnd_data[:,col] for col in sort_by])) 
    rnd_data = rnd_data[sorted_indexes] 
    return rnd_data 



if __name__ == '__main__': 
    import doctest 
    doctest.testmod()