Sto provando a racchiudere la funzione LAPACK dgtsv
(un risolutore per i sistemi di equazioni tridiagonali) utilizzando Cython.Avvolgere una funzione LAPACKE utilizzando Cython
mi sono imbattuto this previous answer, ma dal momento che dgtsv
non è una delle funzioni LAPACK che sono avvolti in scipy.linalg
non credo che posso usare questo particolare approccio. Invece ho cercato di seguire this example.
Ecco il contenuto del mio file di lapacke.pxd
:
ctypedef int lapack_int
cdef extern from "lapacke.h" nogil:
int LAPACK_ROW_MAJOR
int LAPACK_COL_MAJOR
lapack_int LAPACKE_dgtsv(int matrix_order,
lapack_int n,
lapack_int nrhs,
double * dl,
double * d,
double * du,
double * b,
lapack_int ldb)
... ecco la mia sottile involucro Cython in _solvers.pyx
:
#!python
cimport cython
from lapacke cimport *
cpdef TDMA_lapacke(double[::1] DL, double[::1] D, double[::1] DU,
double[:, ::1] B):
cdef:
lapack_int n = D.shape[0]
lapack_int nrhs = B.shape[1]
lapack_int ldb = B.shape[0]
double * dl = &DL[0]
double * d = &D[0]
double * du = &DU[0]
double * b = &B[0, 0]
lapack_int info
info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, n, nrhs, dl, d, du, b, ldb)
return info
... ed ecco uno script Python involucro e di prova:
import numpy as np
from scipy import sparse
from cymodules import _solvers
def trisolve_lapacke(dl, d, du, b, inplace=False):
if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1
or b.shape != d.shape):
raise ValueError('Invalid diagonal shapes')
if b.ndim == 1:
# b is (LDB, NRHS)
b = b[:, None]
# be sure to force a copy of d and b if we're not solving in place
if not inplace:
d = d.copy()
b = b.copy()
# this may also force copies if arrays are improperly typed/noncontiguous
dl, d, du, b = (np.ascontiguousarray(v, dtype=np.float64)
for v in (dl, d, du, b))
# b will now be modified in place to contain the solution
info = _solvers.TDMA_lapacke(dl, d, du, b)
print info
return b.ravel()
def test_trisolve(n=20000):
dl = np.random.randn(n - 1)
d = np.random.randn(n)
du = np.random.randn(n - 1)
M = sparse.diags((dl, d, du), (-1, 0, 1), format='csc')
x = np.random.randn(n)
b = M.dot(x)
x_hat = trisolve_lapacke(dl, d, du, b)
print "||x - x_hat|| = ", np.linalg.norm(x - x_hat)
Sfortunatamente, test_trisolve
solo se gfaults sulla chiamata a _solvers.TDMA_lapacke
. Sono abbastanza sicuro che il mio sia corretto - ldd _solvers.so
mostra che _solvers.so
viene collegato alle librerie condivise corrette in fase di esecuzione.
Non sono davvero sicuro di come procedere da qui - qualche idea?
Un breve aggiornamento:
per i valori minori di n
tendo a non ottenere subito segfaults, ma lo faccio ottenere risultati senza senso (|| x - x_hat || dovrebbe essere molto vicino a 0):
In [28]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 6.23202576396
In [29]: test_trisolve2.test_trisolve(10)
-7
||x - x_hat|| = 3.88623414288
In [30]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 2.60190676562
In [31]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 3.86631743386
In [32]: test_trisolve2.test_trisolve(10)
Segmentation fault
solito LAPACKE_dgtsv
torna con code 0
(che dovrebbe indicare il successo), ma di tanto in tanto ricevo 01.234.924,691 mila, il che significa che l'argomento 7 (b
) aveva un valore non valido. Quello che sta succedendo è che solo il primo valore di b
viene effettivamente modificato sul posto. Se continuo a chiamare test_trisolve
alla fine avrò un segfault anche quando n
è piccolo.