È necessario modificare np.linalg.det
per ottenere la velocità. L'idea è che det()
è una funzione Python, fa un sacco di controlli prima e chiama la routine fortran e calcola alcuni array per ottenere il risultato.
Ecco il codice da NumPy:
def slogdet(a):
a = asarray(a)
_assertRank2(a)
_assertSquareness(a)
t, result_t = _commonType(a)
a = _fastCopyAndTranspose(t, a)
a = _to_native_byte_order(a)
n = a.shape[0]
if isComplexType(t):
lapack_routine = lapack_lite.zgetrf
else:
lapack_routine = lapack_lite.dgetrf
pivots = zeros((n,), fortran_int)
results = lapack_routine(n, n, a, n, pivots, 0)
info = results['info']
if (info < 0):
raise TypeError, "Illegal input to Fortran routine"
elif (info > 0):
return (t(0.0), _realType(t)(-Inf))
sign = 1. - 2. * (add.reduce(pivots != arange(1, n + 1)) % 2)
d = diagonal(a)
absd = absolute(d)
sign *= multiply.reduce(d/absd)
log(absd, absd)
logdet = add.reduce(absd, axis=-1)
return sign, logdet
def det(a):
sign, logdet = slogdet(a)
return sign * exp(logdet)
per accelerare questa funzione, è possibile omettere il controllo (diventa vostra responsabilità di mantenere l'ingresso a destra), e raccogliere i risultati FORTRAN in una matrice, e eseguire i calcoli finali per tutti i piccoli array senza ciclo.
Ecco il mio risultato:
import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite
N = 1000
M = np.random.rand(N*10*10).reshape(N, 10, 10)
def dets(a):
length = a.shape[0]
dm = np.zeros(length)
for i in xrange(length):
dm[i] = np.linalg.det(M[i])
return dm
def dets_fast(a):
m = a.shape[0]
n = a.shape[1]
lapack_routine = lapack_lite.dgetrf
pivots = np.zeros((m, n), intc)
flags = np.arange(1, n + 1).reshape(1, -1)
for i in xrange(m):
tmp = a[i]
lapack_routine(n, n, tmp, n, pivots[i], 0)
sign = 1. - 2. * (np.add.reduce(pivots != flags, axis=1) % 2)
idx = np.arange(n)
d = a[:, idx, idx]
absd = np.absolute(d)
sign *= np.multiply.reduce(d/absd, axis=1)
np.log(absd, absd)
logdet = np.add.reduce(absd, axis=-1)
return sign * np.exp(logdet)
print np.allclose(dets(M), dets_fast(M.copy()))
e la velocità è:
timeit dets(M)
10 loops, best of 3: 159 ms per loop
timeit dets_fast(M)
100 loops, best of 3: 10.7 ms per loop
Quindi, in questo modo, è possibile velocizzare di 15 volte. Questo è un buon risultato senza alcun codice compilato.
note: ometto il controllo degli errori per la routine fortran.
La ringrazio molto per il vostro codice di esempio e che persino fatto i tempi. Funziona molto bene per piccole matrici quadratiche (O (MxM)) e non peggiora del numpy.linalg.det implementato per N ~ M. – user1825991