2015-10-10 13 views
10

Tracciare 2 distplots o grafici a dispersione in una sottotrama grandi opere:Come tracciare 2 lmplot di seaborn side-by-side?

import matplotlib.pyplot as plt 
import numpy as np 
import seaborn as sns 
import pandas as pd 
%matplotlib inline 

# create df 
x = np.linspace(0, 2 * np.pi, 400) 
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)}) 

# Two subplots 
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) 
ax1.plot(df.x, df.y) 
ax1.set_title('Sharing Y axis') 
ax2.scatter(df.x, df.y) 

plt.show() 

Subplot example

Ma quando faccio lo stesso con un lmplot al posto di uno dei due altri tipi di grafici ottengo un errore:

AttributeError: 'AxesSubplot' object has no attribute 'lmplot'

C'è un modo per tracciare questi tipi di grafici affiancati?

+0

BTW: il tuo esempio non viene eseguito. La variabile 'x' non è definita nella definizione dataframe della colonna' "y" '. –

+0

Grazie per aver notato @PaulH. Corretto. – samthebrand

risposta

24

Si ottiene questo errore perché matplotlib e i suoi oggetti sono completamente inconsapevoli delle funzioni di Seaborn.

Passare gli oggetti assi (ad esempio, ax1 e ax2) per seaborn.regplot o si può saltare la definizione quelli e utilizzare il col kwarg di seaborn.lmplot

Con le stesse importazioni, pre-definiscono i vostri assi e utilizzando regplot assomiglia a questo :

# create df 
x = np.linspace(0, 2 * np.pi, 400) 
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)}) 
df.index.names = ['obs'] 
df.columns.names = ['vars'] 

idx = np.array(df.index.tolist(), dtype='float') # make an array of x-values 

# call regplot on each axes 
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True) 
sns.regplot(x=idx, y=df['x'], ax=ax1) 
sns.regplot(x=idx, y=df['y'], ax=ax2) 

enter image description here

Utilizzando lmplot richiede la vostra dataframe to be tidy. Continuando dal codice di cui sopra:

tidy = (
    df.stack() # pull the columns into row variables 
     .to_frame() # convert the resulting Series to a DataFrame 
     .reset_index() # pull the resulting MultiIndex into the columns 
     .rename(columns={0: 'val'}) # rename the unnamed column 
) 
sns.lmplot(x='obs', y='val', col='vars', hue='vars', data=tidy) 

enter image description here