В качестве данных был взят датасет iris из библиотеки scikit-learn, в обучении и валидации выборки формируются следующим образом:
X, y = datasets.load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42
)
X_test, X_val, y_test, y_val = train_test_split(
X, y, test_size=0.5, random_state=42
)
Фиксируются random_state, что позволяет гарантировать, что в скриптах train_model.py и validate_model.py сплиты данных идентичны.
Пайплайн обучения по шагам:
- Ждем пока появится файл кофига в директории
configs
(один конфиг - один ран - одна модель) - Запускаем скрипт обучения
- Для каждого конфига выполняем
- Считываем параметры модели и метрики
- Строим и обучаем модель
- Вычисляем метрики на X_val сплите
- Логируем модель, параметры и значения метрик в MLFlow
- Для каждого конфига выполняем
Airflow model training pipline
Эксперимент MLFlow
Проблемы с которыми столкнулся на этом этапе:
- MLFlow server не мог приконектиться к S3 из-за чего не показывал артифакты эксперимента. Через python API также не получалось их забрать. Решилось через установку boto3 в контейнере с MLFlow server.
Алгоритм валидации (запуск по таймеру):
- Для каждего конфига в папке
configs
- Считываем название модели, параметры эксперимента, метрики
- Находим эксперимент по имени (полагаем что все модели тестируются в рамках решения одной задачи)
- Получаем все раны по experiment-id
- Среди них находим ран с параметрами как в конфиге
- Получаем артифакт
- Вычисляем метрики на X_test сплите и логируем их в тот же ран с постфиксом _test
- Находим в списке ранов лучший по таргет метрики с постфиксом _test (если его нет, то выбираем текущий ран)
- Модель соответсвующую выбранному рану продвигаем в stage: Production
Пайплайн валидации Airflow
Staging MLFlow
Проблемы с которыми столкнулся на этом этапе:
- При попыте через mlflow python API забрать артифакт падало с ошибкой bad credentials (как решил не помню)
- Параметр
artifact_path
. Еслиartifact_path = ""
, то артифакты рана логируются в{run_id}\artifacts
. Если указан, то в{run_id}\artifacts
. Из-за этого проблемы с получением моделей