# 示例代码 (Example code)
# coding: utf8
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
def plot_heatmap(matrix, *, title='title', xlabel='xlabel', ylabel='ylabel',
                 vmin=0, vmax=1, cmap="YlGnBu", figsize=(9, 9), fontsize=18,
                 annot=False, show=True):
    """Draw a 2-D matrix as a seaborn heatmap.

    The defaults reproduce the original example: a 9x9-inch square-cell
    heatmap with the "YlGnBu" colormap anchored at 0..1, no tick labels,
    and 18pt placeholder title/axis labels.

    Args:
        matrix: 2-D array-like of numeric values (anything ``np.asarray``
            accepts).
        title, xlabel, ylabel: Text for the axes title and axis labels.
        vmin, vmax: Values that anchor the low/high ends of the colormap.
        cmap: Matplotlib/seaborn colormap name or object.
        figsize: Figure size in inches, passed to ``plt.subplots``.
        fontsize: Font size used for the title and both axis labels.
        annot: If True, write each cell's value (rounded to 2 decimals).
        show: If True, call ``plt.show()`` before returning.

    Returns:
        The ``(fig, ax)`` pair the heatmap was drawn on, so callers can
        further customise or save the figure.
    """
    data = np.asarray(matrix)
    fig, ax = plt.subplots(figsize=figsize)
    # Draw on the axes created above explicitly rather than relying on
    # pyplot's implicit "current axes" state.
    sns.heatmap(pd.DataFrame(np.round(data, 2)), ax=ax, annot=annot,
                vmax=vmax, vmin=vmin, xticklabels=False, yticklabels=False,
                square=True, cmap=cmap)
    ax.set_title(title, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    if show:
        plt.show()
    return fig, ax


if __name__ == "__main__":
    # Demo with random data in [0, 1); the original script referenced an
    # undefined name `matrix` here and so always failed with NameError.
    plot_heatmap(np.random.rand(4, 3))
# 更多参考 (More references): https://www.jianshu.com/p/e195a09a8ca9