在实际训练模型时,我们需要将样本数据随机分割为训练集和测试集,其中训练集用来训练模型,测试集用来验证模型的实际效果。


下面介绍使用sklearn开发时如何划分数据集。


为了演示分割效果,我们需要生成一些数据集,生成数据集的代码如下:

import numpy as np

sample_number = 10
feature_number = 6

X = np.random.rand(sample_number, feature_number)
y = np.random.randint(0, 2, sample_number)

print(X)
print(y)


输出结果如下:

[[ 0.66130409  0.97676746  0.71803508  0.99730168  0.08229832  0.86526473]
 [ 0.30021108  0.06706006  0.98098814  0.68709696  0.48207104  0.71118232]
 [ 0.65834414  0.28948205  0.26636723  0.10397637  0.30376824  0.08447934]
 [ 0.61166203  0.76414791  0.53293265  0.72645187  0.83650127  0.80324535]
 [ 0.3168741   0.20964209  0.1050409   0.96667613  0.76779378  0.11599263]
 [ 0.72638776  0.26922676  0.34266498  0.24431192  0.79987585  0.9393169 ]
 [ 0.15487034  0.35483264  0.83128334  0.48502717  0.12895195  0.82433822]
 [ 0.72285669  0.54464814  0.75142994  0.96835823  0.87040141  0.97489327]
 [ 0.81949806  0.39009266  0.57812486  0.96307663  0.39236456  0.20011817]
 [ 0.93098734  0.0461967   0.27630984  0.34199953  0.21883858  0.07402595]]
[1 1 0 1 0 1 1 1 1 1]


下面介绍如何分割数据集,分割数据集的代码如下,其中random_state是随机种子。

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42)

print(X_train)
print(X_test)
print(y_train)
print(y_test)


输出结果如下:

[[ 0.66130409  0.97676746  0.71803508  0.99730168  0.08229832  0.86526473]
 [ 0.72285669  0.54464814  0.75142994  0.96835823  0.87040141  0.97489327]
 [ 0.65834414  0.28948205  0.26636723  0.10397637  0.30376824  0.08447934]
 [ 0.93098734  0.0461967   0.27630984  0.34199953  0.21883858  0.07402595]
 [ 0.3168741   0.20964209  0.1050409   0.96667613  0.76779378  0.11599263]
 [ 0.61166203  0.76414791  0.53293265  0.72645187  0.83650127  0.80324535]
 [ 0.15487034  0.35483264  0.83128334  0.48502717  0.12895195  0.82433822]]
[[ 0.81949806  0.39009266  0.57812486  0.96307663  0.39236456  0.20011817]
 [ 0.30021108  0.06706006  0.98098814  0.68709696  0.48207104  0.71118232]
 [ 0.72638776  0.26922676  0.34266498  0.24431192  0.79987585  0.9393169 ]]
[1 1 0 1 0 1 1]
[1 1 1]