2013-08-07 9 views
16

Per gli array NumPy 1-D, queste due espressioni dovrebbero produrre lo stesso risultato (teoricamente):Numpy: Differenza tra punto (a, b) e (a * b) .sum()

(a*b).sum()/a.sum() 
dot(a, b)/a.sum() 

L' quest'ultimo utilizza dot() ed è più veloce. Ma quale è più preciso? Perché?

Segue un contesto.

Volevo calcolare la varianza ponderata di un campione utilizzando numpy. Ho trovato l'espressione dot() in another answer, con un commento che indica che dovrebbe essere più preciso. Tuttavia nessuna spiegazione è data lì.

+1

Questa è una media ponderata, giusto? Potresti voler semplicemente usare ['np.average'] (http://docs.scipy.org/doc/numpy/reference/generated/numpy.average.html). – user2357112

+0

Penso che la parte "numericamente precisa" si riferisse a sottrarre la media dai valori, piuttosto che usare 'punto'. – user2357112

risposta

9

Numpy dot è una delle routine che chiama la libreria BLAS che si collega su Compile (o crea il proprio). L'importanza di ciò è che la libreria BLAS può fare uso di operazioni Multiply-accumulate (in genere Fused-Multiply Add) che limitano il numero di arrotondamenti eseguiti dal calcolo.

adottare le seguenti:

>>> a=np.ones(1000,dtype=np.float128)+1E-14 
>>> (a*a).sum() 
1000.0000000000199948 
>>> np.dot(a,a) 
1000.0000000000199948 

non è esatto, ma abbastanza vicino.

>>> a=np.ones(1000,dtype=np.float64)+1E-14 
>>> np.dot(a,a) 
1000.0000000000176 #off by 2.3948e-12 
>>> (a*a).sum() 
1000.0000000000059 #off by 1.40948e-11 

Il np.dot(a, a) sarà più precisa dei due come si usa circa la metà del numero di arrotondamenti virgola mobile che il naif (a*a).sum() fa.

Un libro di Nvidia ha il seguente esempio per 4 cifre di precisione. rn sta per 4 intorno alla prossima 4 cifre:

x = 1.0008 
x2 = 1.00160064     # true value 
rn(x2 − 1) = 1.6006 × 10−4   # fused multiply-add 
rn(rn(x2) − 1) = 1.6000 × 10−4  # multiply, then add 

di numeri in virgola mobile dei corsi non sono arrotondati al decimale 16 in base 10, ma si ottiene l'idea.

Posizionamento np.dot(a,a) nella notazione sopra con alcuni pseudo codice addizionale:

out=0 
for x in a: 
    out=rn(x*x+out) #Fused multiply add 

Mentre (a*a).sum() è:

arr=np.zeros(a.shape[0]) 
for x in range(len(arr)): 
    arr[x]=rn(a[x]*a[x]) 

out=0 
for x in arr: 
    out=rn(x+out) 

Da questa sua facile vedere che il numero viene arrotondato doppio delle volte usando (a*a).sum() rispetto a np.dot(a,a). Queste piccole differenze sommate possono cambiare la risposta in modo preciso. Ulteriori exmaples possono essere trovati here.

+5

Se numpy sta usando un blas ottimizzato per la macchina dell'utente, e ha un processore che ha fma. Uno non dovrebbe fare troppe ipotesi basate su "il blas * può * farlo ..." –

+0

Sì, anche per Intel Ivy Bridge 'a + b * c' compila a' mulss' seguito da 'addss'. –

+0

Hai aggiunto opzioni di compilazione adeguate? (come -march = core-avx2 -mavx -mfma ... con gcc) –