StratifiedKFold是一种将数据集中每一类样本数据按均等方式拆分的方法。
下面以几个示例代码来介绍。
实例代码如下:
from sklearn.model_selection import StratifiedKFold import numpy as np X = np.ones(10) y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1] skf = StratifiedKFold(n_splits=2) for train, test in skf.split(X, y): print("%s %s" % (train, test))
输出如下:
[2 3 7 8 9] [0 1 4 5 6] [0 1 4 5 6] [2 3 7 8 9]
其中0,1,2...9被均匀地分配到了测试集中。
n_splits=3时,输出如下:
[2 3 6 7 8 9] [0 1 4 5] [0 1 3 4 5 8 9] [2 6 7] [0 1 2 4 5 6 7] [3 8 9]
n_splits=4时,输出如下:
[1 2 3 6 7 8 9] [0 4 5] [0 2 3 4 5 8 9] [1 6 7] [0 1 3 4 5 6 7 9] [2 8] [0 1 2 4 5 6 7 8] [3 9]
当n小于2或者大于最少样本数(此处为4)时,会报错。
参考:
http://scikit-learn.org/stable/modules/cross_validation.html