2016-07-15 129 views
6

Quando ho calcolato terzo ordine momenti di una matrice X con N righe e n colonne, io di solito uso einsum:Alternative a NumPy einsum

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N 

Questo di solito funziona bene, ma ora sto lavorando con i valori più grandi, vale a dire n = 120 e N = 100000, e einsum restituisce il seguente errore:

ValueError: iterator is too large

l'alternativa di fare 3 cicli annidati è unfeasable, così ho mi sto chiedendo se ci sia qualche tipo di alternativa.

risposta

4

Si noti che il calcolo di tale dovrà fare almeno ~ n × N = 173 miliardi operazioni (non considerando la simmetria), quindi sarà lenta a meno che NumPy ha accesso alla GPU o qualcosa del genere. Su un computer moderno con una CPU da ~ 3 GHz, l'intero calcolo dovrebbe richiedere circa 60 secondi, supponendo che SIMD/parallelo non accelerino.


Per il test, cominciamo con N = 1000. Useremo questo per controllare la correttezza e le prestazioni:

#!/usr/bin/env python3 

import numpy 
import time 

numpy.random.seed(0) 

n = 120 
N = 1000 
X = numpy.random.random((N, n)) 

start_time = time.time() 

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X) 

end_time = time.time() 

print('check:', M3[2,4,6], '= 125.401852515?') 
print('check:', M3[4,2,6], '= 125.401852515?') 
print('check:', M3[6,4,2], '= 125.401852515?') 
print('check:', numpy.sum(M3), '= 218028826.631?') 
print('total time =', end_time - start_time) 

Questa operazione richiede circa 8 secondi. Questa è la linea di base.

Cominciamo con il ciclo nidificato 3 come l'alternativa:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l]) 
# ~27 seconds 

Questo richiede circa mezzo minuto, non va bene! Uno dei motivi è dato dal fatto che si tratta in realtà di quattro cicli annidati: numpy.sum può anche essere considerato un ciclo.

Notiamo che la somma può essere trasformato in un prodotto scalare per rimuovere questo 4 ° ciclo:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l] 
# 14 seconds 

molto meglio ora, ma ancora lento. Ma notiamo che il prodotto il punto può essere cambiata in una moltiplicazione di matrici per rimuovere un loop:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     M3[j,k] = X[:,j] * X[:,k] @ X 
# ~0.5 seconds 

Eh? Ora questo è persino molto più efficiente di einsum! Potremmo anche verificare che la risposta sia effettivamente corretta.

Possiamo andare oltre? Sì! Potremmo eliminare l'anello k da:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = numpy.repeat(X[:,j], n).reshape((N, n)) 
    M3[j] = (Y * X).T @ X 
# ~0.3 seconds 

ci potrebbe anche usare radiodiffusione (cioè a * [b,c] == [a*b, a*c] per ogni fila di X) per evitare di fare il numpy.repeat (grazie @Divakar):

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = X[:,j].reshape((N, 1)) 
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j] 
    M3[j] = (Y * X).T @ X 
# ~0.16 seconds 

Se si scala questo a N = 100000 ci si aspetta che il programma impieghi 16 secondi, che è entro il limite teorico, quindi eliminare lo j potrebbe non essere di grande aiuto (ma questo potrebbe rendere il codice davvero difficile da capire). Potremmo accettare come soluzione finale.


Nota: Se si sta utilizzando Python 2, a @ b è equivalente a a.dot(b).

+0

ottima risposta, grazie! –

+0

Ottima idea davvero. Se posso aggiungere un po 'di trasmissione qui, potremmo evitare di creare 'Y' e ottenere direttamente l'output iterativo:' (X [:, None, j] * X) .T @ X'. Questo dovrebbe darci un ulteriore incremento delle prestazioni. – Divakar

+0

@Divakar: Grazie! Aggiornato. – kennytm