CIL-Project

ETHz CIL 2023 Collaborative Filtering

1. Environment setup

Create environment and install dependencies

conda create --name cil python=3.9
conda activate cil
pip install -r requirements.txt

Download dataset

Create a directory /data. Then put data_train.csv and sampleSubmission.csv inside.
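The step above can be sketched as follows (assuming the repository root as the working directory; the CSV files come from the Kaggle competition page):

```shell
# Create the data directory at the repository root
mkdir -p data
# Then place the competition files inside, e.g.:
# mv ~/Downloads/data_train.csv ~/Downloads/sampleSubmission.csv data/
```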

2. Kaggle result reproduction

Run the following commands:

python cross_validation.py config_submission_cv1.yaml
python cross_validation.py config_submission_cv2.yaml
python train.py config_submission_ensemble.yaml

The result will be in the directory /output/submission with the name ensemble_GradientBoost_['bfm_op_rk32_iter1000_cv10', 'bfm_reg_rk16_iter1000_cv10']_10.csv.
As running cross validation can take a long time, we provide 10-fold full prediction results for BFM+reg+rank16+iters1000 and BFM+oprobit+rank32+iters1000 at this link. Download them and move all txt files to the directory /output/data_ensemble. The submission results can then be generated with the command:

python train.py config_submission_ensemble.yaml

3. Implementation details

3.1 Train a single model

Check the settings in config.yaml:

experiment_args/model_name: Model name.
experiment_args/generate_submissions: True: Use the entire dataset to generate submissions; False: Split the dataset for validation.
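For example, the relevant part of config.yaml might look like this (the key names come from this README; the model name "bfm" is only an illustrative placeholder):

```yaml
experiment_args:
  model_name: "bfm"            # placeholder; use a model name supported by the repo
  generate_submissions: False  # split the dataset for validation instead
```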

Run the following command:

python train.py config.yaml

Please note that when training an NCF model, a crash might occur randomly due to either a zero-shaped tensor or a segmentation fault. If that happens, simply rerun the training.

3.2 Apply cross validation to a single model

Check the settings in config_cv.yaml:

experiment_args/model_name: Model name.
experiment_args/generate_submissions: False.
experiment_args/save_full_pred: False.
ensemble_args/fold_number: Number of folds.

Run the following command for cross validation:

python cross_validation.py config_cv.yaml
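Conceptually, the cross validation run evaluates the configured model on each of the folds in turn. A minimal NumPy-only sketch of that loop (with the fold mean standing in for the real model's predictions) looks like:

```python
import numpy as np

rng = np.random.default_rng(0)
ratings = rng.integers(1, 6, size=100).astype(float)  # stand-in rating vector

n_folds = 10
indices = rng.permutation(len(ratings))
folds = np.array_split(indices, n_folds)

rmses = []
for k in range(n_folds):
    val_idx = folds[k]
    train_idx = np.concatenate([folds[j] for j in range(n_folds) if j != k])
    # A real run would fit the configured model on the train split here;
    # we use the train-split mean as a trivial stand-in predictor.
    pred = ratings[train_idx].mean()
    rmses.append(np.sqrt(np.mean((ratings[val_idx] - pred) ** 2)))

print(f"mean CV RMSE over {n_folds} folds: {np.mean(rmses):.3f}")
```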

3.3 Apply grid search to a single model

Check the settings in config_cv.yaml:

experiment_args/model_name: Model name.
experiment_args/generate_submissions: False.
experiment_args/save_full_pred: False.
ensemble_args/fold_number: Number of folds.

Check the settings in grid_search.py:

Modify the parameters in function grid_search.

Run the following command for grid search:

python grid_search.py config_cv.yaml
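The search itself is an exhaustive sweep over parameter combinations. A rough sketch of that idea (the real grid lives in grid_search.py's grid_search function; the parameter names rank and iters and the scoring are placeholders):

```python
from itertools import product

# Hypothetical parameter grid; edit the real one inside grid_search.py.
grid = {"rank": [8, 16, 32], "iters": [300, 1000]}

best = None
for rank, iters in product(grid["rank"], grid["iters"]):
    # A real run would launch cross validation with these settings and
    # record the mean validation RMSE; here we fake a score to minimize.
    score = 1.0 / (rank * iters)  # placeholder objective
    if best is None or score < best[0]:
        best = (score, {"rank": rank, "iters": iters})

print("best parameters:", best[1])
```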

3.4 Ensemble

3.4.1 Save cross validation results

Check the settings in config_cv.yaml:

experiment_args/model_name: Model name.
experiment_args/model_instance_name: A prefix for saving prediction filenames.
experiment_args/save_full_pred: True. The prediction values of fold x will be saved to the path ensemble_args/data_ensemble + experiment_args/model_instance_name + "_fold_{fold number}_train/test.txt". The train/test suffix in the file name refers to predictions for the ids provided in data_train.csv and sampleSubmission.csv respectively.
ensemble_args/fold_number: Number of folds.
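Under this naming convention, the files a K-fold run produces can be enumerated as below (a sketch only; 0-based fold indices are an assumption and may differ from the actual code):

```python
from pathlib import Path

def fold_prediction_paths(data_ensemble_dir, instance_name, fold_number):
    """Enumerate per-fold prediction files following the naming convention above.

    Assumes 0-based fold indices, which is an assumption, not confirmed code.
    """
    base = Path(data_ensemble_dir)
    return [
        base / f"{instance_name}_fold_{k}_{split}.txt"
        for k in range(fold_number)
        for split in ("train", "test")
    ]

paths = fold_prediction_paths("output/data_ensemble", "bfm_reg_rk16_iter1000_cv10", 10)
```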

Run the following command for cross validation:

python cross_validation.py config_cv.yaml

3.4.2 Run ensemble

Check the settings in config.yaml:

experiment_args/model_name: "ensemble".
ensemble_args/fold_number: Fold number; should match the one used when generating the cross validation results.
ensemble_args/regressor: Regressor type for blending: "linear", "SGD", "BayesianRidge", or "GradientBoost".
ensemble_args/models: List of model instances used for blending. The K-fold prediction results are saved in the format "[prefix]_fold_x_train/test.txt"; enter the prefix string here.

If you want to apply weighted sampling, additionally make sure that:
  1. cv_args/weight_entries: "True"
  2. cv_args/sample_proportion: Proportion of training data sampled for each fold
  3. ensemble_args/fold_number: A large number of folds (recommended)
  4. (recommended) If possible, give the model you are running a simple structure, e.g., low rank

Run the following command for ensemble:

python train.py config.yaml
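As a rough illustration of what regressor blending does (not the project's actual code: scikit-learn's GradientBoostingRegressor stands in for the "GradientBoost" option, and the per-model fold predictions are faked):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
true_ratings = rng.uniform(1.0, 5.0, size=200)
# Stand-ins for two models' out-of-fold predictions
# (the *_fold_x_train.txt files in a real run).
model_a = true_ratings + rng.normal(0.0, 0.5, size=200)
model_b = true_ratings + rng.normal(0.0, 0.7, size=200)
X = np.column_stack([model_a, model_b])

# Fit the blender on the out-of-fold predictions against the true ratings,
# then apply it to predictions (here: the same X, for brevity).
blender = GradientBoostingRegressor(random_state=0)
blender.fit(X, true_ratings)
blended = blender.predict(X)
```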

Code References

  1. MyFM library: https://github.com/tohtsky/myFM
  2. Surprise library: https://surpriselib.com/
  3. Microsoft Recommenders library: https://github.com/microsoft/recommenders/