StratifiedKFold是一种将数据集中每一类样本数据按均等方式拆分的方法。


下面以几个示例代码来介绍。


实例代码如下:

from sklearn.model_selection import StratifiedKFold
import numpy as np

X = np.ones(10)

y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]

skf = StratifiedKFold(n_splits=2)

for train, test in skf.split(X, y):
    print("%s %s" % (train, test))


输出如下:

[2 3 7 8 9] [0 1 4 5 6]
[0 1 4 5 6] [2 3 7 8 9]


其中0,1,2...9被均匀地分配到了测试集中。


n_splits=3时,输出如下:

[2 3 6 7 8 9] [0 1 4 5]
[0 1 3 4 5 8 9] [2 6 7]
[0 1 2 4 5 6 7] [3 8 9]


n_splits=4时,输出如下:

[1 2 3 6 7 8 9] [0 4 5]
[0 2 3 4 5 8 9] [1 6 7]
[0 1 3 4 5 6 7 9] [2 8]
[0 1 2 4 5 6 7 8] [3 9]



当n小于2或者大于最少样本数(此处为4)时,会报错。


参考:

http://scikit-learn.org/stable/modules/cross_validation.html