Add `get_repeated_stratified_kfold()` and `check_stratified_proportions()` to `feature_data` class
`kfold_repeated_stratified()` is a user function added to the `feature_data` class. The `check_stratified_proportions()` function is called when `print_out=True` is passed to `kfold_repeated_stratified()`.
Use case:
```python
>>> from research_tools import feature_data
>>> from research_tools import feature_groups
>>> base_dir_data = 'I:/Shared drives/NSF STTR Phase I – Potato Remote Sensing/Historical Data/Rosen Lab/Small Plot Data/Data'
>>> feat_data_cs = feature_data(base_dir_data)
>>> group_feats = feature_groups.cs_test2
>>> feat_data_cs.get_feat_group_X_y(group_feats)
>>> cv_rep_strat = feat_data_cs.kfold_repeated_stratified(print_out=True)
```
Output:
```
Number of splits: 4
Number of repetitions: 3
The number of observations in each cross-validation dataset are listed below.
The key represents the <stratify_train> ID, and the value represents the number of observations used from that stratify ID
Total number of observations: 386

K-fold train set:
Number of observations: 290
{0: 14, 1: 14, 2: 14, 3: 14, 4: 14, 5: 15, 6: 14, 7: 15, 8: 14, 9: 16, 10: 16, 11: 17, 12: 16, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 14, 2: 14, 3: 13, 4: 14, 5: 14, 6: 14, 7: 14, 8: 15, 9: 17, 10: 16, 11: 17, 12: 17, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 15, 1: 14, 2: 14, 3: 13, 4: 15, 5: 14, 6: 14, 7: 14, 8: 14, 9: 17, 10: 17, 11: 16, 12: 17, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 15, 2: 15, 3: 14, 4: 14, 5: 14, 6: 15, 7: 14, 8: 14, 9: 16, 10: 17, 11: 16, 12: 16, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 14, 2: 14, 3: 14, 4: 14, 5: 15, 6: 14, 7: 15, 8: 14, 9: 16, 10: 16, 11: 17, 12: 16, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 14, 2: 14, 3: 13, 4: 14, 5: 14, 6: 14, 7: 14, 8: 15, 9: 17, 10: 16, 11: 17, 12: 17, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 15, 1: 14, 2: 14, 3: 13, 4: 15, 5: 14, 6: 14, 7: 14, 8: 14, 9: 17, 10: 17, 11: 16, 12: 17, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 15, 2: 15, 3: 14, 4: 14, 5: 14, 6: 15, 7: 14, 8: 14, 9: 16, 10: 17, 11: 16, 12: 16, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 14, 2: 14, 3: 14, 4: 14, 5: 15, 6: 14, 7: 15, 8: 14, 9: 16, 10: 16, 11: 17, 12: 16, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 14, 2: 14, 3: 13, 4: 14, 5: 14, 6: 14, 7: 14, 8: 15, 9: 17, 10: 16, 11: 17, 12: 17, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 15, 1: 14, 2: 14, 3: 13, 4: 15, 5: 14, 6: 14, 7: 14, 8: 14, 9: 17, 10: 17, 11: 16, 12: 17, 13: 24, 14: 24, 15: 24, 16: 24}
{0: 14, 1: 15, 2: 15, 3: 14, 4: 14, 5: 14, 6: 15, 7: 14, 8: 14, 9: 16, 10: 17, 11: 16, 12: 16, 13: 24, 14: 24, 15: 24, 16: 24}

K-fold validation set:
Number of observations: 96
{0: 5, 1: 5, 2: 5, 3: 4, 4: 5, 5: 4, 6: 5, 7: 4, 8: 5, 9: 6, 10: 6, 11: 5, 12: 6, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 5, 2: 5, 3: 5, 4: 5, 5: 5, 6: 5, 7: 5, 8: 4, 9: 5, 10: 6, 11: 5, 12: 5, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 4, 1: 5, 2: 5, 3: 5, 4: 4, 5: 5, 6: 5, 7: 5, 8: 5, 9: 5, 10: 5, 11: 6, 12: 5, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 4, 2: 4, 3: 4, 4: 5, 5: 5, 6: 4, 7: 5, 8: 5, 9: 6, 10: 5, 11: 6, 12: 6, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 5, 2: 5, 3: 4, 4: 5, 5: 4, 6: 5, 7: 4, 8: 5, 9: 6, 10: 6, 11: 5, 12: 6, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 5, 2: 5, 3: 5, 4: 5, 5: 5, 6: 5, 7: 5, 8: 4, 9: 5, 10: 6, 11: 5, 12: 5, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 4, 1: 5, 2: 5, 3: 5, 4: 4, 5: 5, 6: 5, 7: 5, 8: 5, 9: 5, 10: 5, 11: 6, 12: 5, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 4, 2: 4, 3: 4, 4: 5, 5: 5, 6: 4, 7: 5, 8: 5, 9: 6, 10: 5, 11: 6, 12: 6, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 5, 2: 5, 3: 4, 4: 5, 5: 4, 6: 5, 7: 4, 8: 5, 9: 6, 10: 6, 11: 5, 12: 6, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 5, 2: 5, 3: 5, 4: 5, 5: 5, 6: 5, 7: 5, 8: 4, 9: 5, 10: 6, 11: 5, 12: 5, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 4, 1: 5, 2: 5, 3: 5, 4: 4, 5: 5, 6: 5, 7: 5, 8: 5, 9: 5, 10: 5, 11: 6, 12: 5, 13: 8, 14: 8, 15: 8, 16: 8}
{0: 5, 1: 4, 2: 4, 3: 4, 4: 5, 5: 5, 6: 4, 7: 5, 8: 5, 9: 6, 10: 5, 11: 6, 12: 6, 13: 8, 14: 8, 15: 8, 16: 8}
```
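As a consistency check on the printed counts, the short script below re-adds the first repetition's four train and four validation dictionaries. The second and third repetitions print identical dictionaries. Values are copied from the output, with each dict rewritten as a list indexed by stratify ID. Three things hold. Every fold's train and validation counts add up to the same per-stratum totals. The four validation folds together cover each stratum exactly once. The folds split the 386 observations 289/97, 289/97, 290/96 and 290/96. The single `Number of observations` line in each section therefore appears to describe only folds 3 and 4.

```python
# Per-stratum counts from the first repetition of the output above.
# List index = stratify ID (0-16); one list per fold.
train = [
    [14, 14, 14, 14, 14, 15, 14, 15, 14, 16, 16, 17, 16, 24, 24, 24, 24],
    [14, 14, 14, 13, 14, 14, 14, 14, 15, 17, 16, 17, 17, 24, 24, 24, 24],
    [15, 14, 14, 13, 15, 14, 14, 14, 14, 17, 17, 16, 17, 24, 24, 24, 24],
    [14, 15, 15, 14, 14, 14, 15, 14, 14, 16, 17, 16, 16, 24, 24, 24, 24],
]
val = [
    [5, 5, 5, 4, 5, 4, 5, 4, 5, 6, 6, 5, 6, 8, 8, 8, 8],
    [5, 5, 5, 5, 5, 5, 5, 5, 4, 5, 6, 5, 5, 8, 8, 8, 8],
    [4, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 6, 5, 8, 8, 8, 8],
    [5, 4, 4, 4, 5, 5, 4, 5, 5, 6, 5, 6, 6, 8, 8, 8, 8],
]

# Observations per stratify ID: train + validation must match in every fold.
totals = [t + v for t, v in zip(train[0], val[0])]
for tr, va in zip(train, val):
    assert [t + v for t, v in zip(tr, va)] == totals

# Within one repetition, each observation is validated exactly once.
assert [sum(col) for col in zip(*val)] == totals

print("total observations:", sum(totals))            # 386
print("train sizes:", [sum(t) for t in train])       # [289, 289, 290, 290]
print("validation sizes:", [sum(v) for v in val])    # [97, 97, 96, 96]
```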