我使用的是pandas
的scatter_matrix
,我想知道如何在每个散点矩阵上绘制2D数组?另外,我如何识别输出的哪个AxesSubplot
是输出图上的哪个矩阵?
发布于 2017-12-30 04:45:49
scatter_matrix
是pandas
的一个便利函数,来自pandas.plotting
子模块。虽然the documentation is scarce (和docstring只是更有帮助一点),但这个例子让我们很容易理解它是如何工作的。考虑文档中的示例:
import numpy as np # only needed for the example input
import pandas as pd
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])
axs = scatter_matrix(df, alpha=0.2, figsize=(6,6), diagonal='kde')
axs[0,0].get_figure().show() # or import and call matplotlib.pyplot.show
注意底轴和左轴上的标签:这些标签指示输入数据帧的哪些列在给定行/列中相互对比。在第一列曲线图中,x轴对应于df.a
,在第二行曲线图中,y轴对应于df.b
,依此类推(并且在对角线中绘制相应列的密度或直方图)。因此,绘图矩阵中的转置元素对应于x和y数据的交换,即绘图相对于x=y线的反映。如果你仔细观察上面的图,你会发现确实是这样的。
换句话说,您不需要计算来自各个轴的数据,因为您可以直接控制您的输入数据。在非对角轴axs[i,j]
中,x数据由df[df.columns[j]]
给出,y数据由df[df.columns[i]]
给出。这里有一个快速的技巧来帮助可视化订单:
axs = scatter_matrix(df, alpha=0.2, figsize=(6,6), diagonal='kde')
for i in range(axs.shape[0]):
for j in range(axs.shape[1]):
if i == j:
continue
axs[i,j].set_title('x: {}, y: {}'.format(df.columns[j],df.columns[i]),
position=(0.5,0.5))
因此,虽然可以深入挖掘每个AxesSubplot
对象的内部并从中提取数据,但直接使用df
的各个列要简单得多。一个例外是对角线:在内核密度图的情况下(假设diagonal='kde'
关键字被传递给scatter_matrix
),您不能直接访问底层数据。在这种情况下,您可以从对角线AxesSubplots
中提取直线
import matplotlib.pyplot as plt
index = 0
xdat,ydat = axs[index,index].get_lines()[0].get_data() # example for diagonal [0,0]
plt.figure()
plt.plot(xdat,ydat,'-')
plt.xlabel(df.columns[index])
plt.ylabel('density')
https://stackoverflow.com/questions/47967734
复制相似问题