Example code (示例代码):

# coding: utf8

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

def matrix_to_frame(matrix, decimals=2):
    """Convert a 2-D array-like into a DataFrame of rounded values for plotting.

    Args:
        matrix: Any 2-D array-like (nested lists, ``np.ndarray``, ...).
        decimals: Number of decimal places passed to ``np.round``.

    Returns:
        A ``pd.DataFrame`` with the same shape as ``matrix``.

    Raises:
        ValueError: If ``matrix`` is not two-dimensional.
    """
    values = np.asarray(matrix)
    if values.ndim != 2:
        # A heatmap needs a rectangular grid; fail early with a clear message.
        raise ValueError(f"expected a 2-D matrix, got shape {values.shape}")
    return pd.DataFrame(np.round(values, decimals))


def plot_heatmap(matrix, *, title="title", xlabel="xlabel", ylabel="ylabel",
                 vmin=0, vmax=1, cmap="YlGnBu", figsize=(9, 9), fontsize=18,
                 annot=False, ax=None):
    """Draw ``matrix`` as a seaborn heatmap and return the Axes it is drawn on.

    Args:
        matrix: 2-D array-like of values to plot.
        title, xlabel, ylabel: Text for the Axes title and axis labels.
        vmin, vmax: Colour-scale limits forwarded to ``sns.heatmap``.
        cmap: Colormap name forwarded to ``sns.heatmap``.
        figsize: Figure size used only when ``ax`` is None.
        fontsize: Font size for the title and axis labels.
        annot: Whether to write the (rounded) values inside the cells.
        ax: Existing matplotlib Axes to draw on; a new figure is created
            when None.

    Returns:
        The matplotlib Axes containing the heatmap.

    Raises:
        ValueError: If ``matrix`` is not two-dimensional.
    """
    frame = matrix_to_frame(matrix)
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        frame,
        ax=ax,  # draw on this Axes explicitly instead of the implicit "current" one
        annot=annot,
        vmin=vmin,
        vmax=vmax,
        xticklabels=False,
        yticklabels=False,
        square=True,
        cmap=cmap,
    )
    ax.set_title(title, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    return ax


if __name__ == "__main__":
    # Demo data; replace with your own 2-D matrix (values within [vmin, vmax]).
    plot_heatmap(np.random.rand(4, 3))
    plt.show()

More references (更多参考): https://www.jianshu.com/p/e195a09a8ca9