2012-12-06 4 views
8

Per impostazione predefinita, il decollo di una matrice di visualizzazione numpy perde la relazione di visualizzazione, anche se anche la base dell'array viene decapata. La mia situazione è che ho alcuni oggetti contenitori complessi che sono in salamoia. E in alcuni casi, alcuni dati contenuti sono visualizzazioni in altri. Il salvataggio di una matrice indipendente di ciascuna vista non è solo una perdita di spazio, ma anche i dati ricaricati hanno perso la relazione della vista.Conservazione della visualizzazione numpy durante il decapaggio

Un semplice esempio potrebbe essere (ma nel mio caso il contenitore è più complesso di un dizionario):

import numpy as np 
import cPickle 

tmp = np.zeros(2) 
d1 = dict(a=tmp,b=tmp[:]) # d1 to be saved: b is a view on a 

pickled = cPickle.dumps(d1) 
d2 = cPickle.loads(pickled) # d2 reloaded copy of d1 container 

print 'd1 before:', d1 
d1['b'][:] = 1 
print 'd1 after: ', d1 

print 'd2 before:', d2 
d2['b'][:] = 1 
print 'd2 after: ', d2 

che sarebbe stampare:

d1 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d1 after: {'a': array([ 1., 1.]), 'b': array([ 1., 1.])} 
d2 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d2 after: {'a': array([ 0., 0.]), 'b': array([ 1., 1.])} # not a view anymore 

La mia domanda:

(1) C'è un modo per preservarlo? (2) (ancora meglio) c'è un modo per farlo solo se la base è in salamoia

Per la (1) Penso che ci può essere qualche modo cambiando il __setstate__, __reduce_ex_, ecc ... del vedi matrice. Ma per ora non mi confido con questi. Per il (2) non ne ho idea.

risposta

7

Questo non viene eseguito in NumPy, perché non sempre ha senso mettere sottosopra l'array di base, e pickle non espone la possibilità di controllare se un altro oggetto viene decapitato come parte della sua API.

Ma questo tipo di controllo può essere eseguito in un contenitore personalizzato per gli array NumPy. Per esempio:

import numpy as np 
import pickle 

def byte_offset(array, source): 
    return array.__array_interface__['data'][0] - np.byte_bounds(source)[0] 

class SharedPickleList(object): 
    def __init__(self, arrays): 
     self.arrays = list(arrays) 

    def __getstate__(self): 
     unique_ids = {id(array) for array in self.arrays} 
     source_arrays = {} 
     view_tuples = {} 
     for array in self.arrays: 
      if array.base is None or id(array.base) not in unique_ids: 
       # only use views if the base is also being pickled 
       source_arrays[id(array)] = array 
      else: 
       view_tuples[id(array)] = (array.shape, 
              array.dtype, 
              id(array.base), 
              byte_offset(array, array.base), 
              array.strides) 
     order = [id(array) for array in self.arrays] 
     return (source_arrays, view_tuples, order) 

    def __setstate__(self, state): 
     source_arrays, view_tuples, order = state 
     view_arrays = {} 
     for k, view_state in view_tuples.items(): 
      (shape, dtype, source_id, offset, strides) = view_state 
      buffer = source_arrays[source_id].data 
      array = np.ndarray(shape, dtype, buffer, offset, strides) 
      view_arrays[k] = array 
     self.arrays = [source_arrays[i] 
         if i in source_arrays 
         else view_arrays[i] 
         for i in order] 

# unit tests 
def check_roundtrip(arrays): 
    unpickled_arrays = pickle.loads(pickle.dumps(
     SharedPickleList(arrays))).arrays 
    assert all(a.shape == b.shape and (a == b).all() 
       for a, b in zip(arrays, unpickled_arrays)) 

indexers = [0, None, slice(None), slice(2), slice(None, -1), 
      slice(None, None, -1), slice(None, 6, 2)] 

source0 = np.random.randint(100, size=10) 
arrays0 = [np.asarray(source0[k1]) for k1 in indexers] 
check_roundtrip([source0] + arrays0) 

source1 = np.random.randint(100, size=(8, 10)) 
arrays1 = [np.asarray(source1[k1, k2]) for k1 in indexers for k2 in indexers] 
check_roundtrip([source1] + arrays1) 

Ciò si traduce in un notevole risparmio di spazio:

source = np.random.rand(1000) 
arrays = [source] + [source[n:] for n in range(99)] 
print(len(pickle.dumps(arrays, protocol=-1))) 
# 766372 
print(len(pickle.dumps(SharedPickleList(arrays), protocol=-1))) 
# 11833