I want to split data in scikit-learn so that the training and test sets contain the same classes


I split my data with scikit-learn's train_test_split function, as shown below:

from sklearn.model_selection import train_test_split
dataset_train, dataset_test = train_test_split(dataset, train_size=0.8)

This divides the dataset into training and test data, but if there are many classes (for example, 100), the two sets can end up containing different classes. For example, the training data might contain all 100 classes while the test data contains only 98.

Because train_test_split simply shuffles the data at random and then splits it, this is especially likely when the classes are imbalanced, i.e. when some classes have very few samples.
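A minimal sketch that reproduces the problem on synthetic data (the labels, feature values, and random_state are made up for illustration). A class with only one sample can only land on one side of the split, so every such class is missing from either the training set or the test set.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 90 classes with 20 samples each
# plus 10 rare classes with a single sample each (100 classes in total).
rng = np.random.default_rng(0)
y = np.concatenate([np.repeat(np.arange(90), 20), np.arange(90, 100)])
X = rng.normal(size=(len(y), 4))  # dummy features

# Plain random split, no stratification
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.8, random_state=0
)

all_classes = set(np.unique(y))
train_classes = set(np.unique(y_train))
test_classes = set(np.unique(y_test))
print("classes in train:", len(train_classes))
print("classes in test:", len(test_classes))
# Classes that ended up only in the training set
print("missing from test:", sorted(int(c) for c in all_classes - test_classes))
```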

How can I split the data so that both sets keep all of the classes?

Thank you for your cooperation.

python scikit-learn

2022-09-30 11:10

2 Answers

Use StratifiedShuffleSplit to split the data while keeping the class proportions the same in both sets.
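A minimal sketch of this approach on hypothetical data (100 classes with 10 samples each; the variable names and random_state are assumptions). With n_splits=1, StratifiedShuffleSplit yields a single pair of train/test index arrays, and each class is allocated to both sides in proportion to its size.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical data: 100 classes with 10 samples each
y = np.repeat(np.arange(100), 10)
X = np.arange(len(y)).reshape(-1, 1)  # dummy features

# n_splits=1 produces one 80/20 train/test split
sss = StratifiedShuffleSplit(n_splits=1, train_size=0.8, random_state=0)
train_idx, test_idx = next(sss.split(X, y))

X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# Number of distinct classes on each side
print(len(np.unique(y_train)), len(np.unique(y_test)))  # prints: 100 100
```

Stratification needs a large enough test set: StratifiedShuffleSplit raises a ValueError when the test (or training) portion would contain fewer samples than there are classes, so with 100 classes the test split needs at least 100 samples.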


2022-09-30 11:10

If I understand the question correctly, is this what you mean?

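from sklearn.model_selection import train_test_split
dataset_train, dataset_test = train_test_split(dataset, train_size=0.8, stratify=dataset["label"])  # stratify takes the class labels; "label" is a hypothetical column name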
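A fuller sketch of the stratify approach, assuming dataset is a pandas DataFrame whose class labels are stored in a column named "label" (the column name and the synthetic data are hypothetical). The key point is that stratify receives the labels, not the whole dataset.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical dataset: one feature column and a "label" column with 100 classes
labels = np.repeat(np.arange(100), 10)
dataset = pd.DataFrame({"feature": np.arange(len(labels)), "label": labels})

# Passing the labels to stratify keeps each class's share roughly
# the same in the training and test sets
dataset_train, dataset_test = train_test_split(
    dataset, train_size=0.8, stratify=dataset["label"], random_state=0
)

print(dataset_train["label"].nunique(), dataset_test["label"].nunique())  # prints: 100 100
```

Every class needs at least two samples for this to work: train_test_split raises a ValueError when the least populated class has only one member, so extremely rare classes have to be merged or handled separately.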


2022-09-30 11:10



© 2024 OneMinuteCode. All rights reserved.