Paddle Distributed Training Extended

FleetX

FleetX is an extension package for Paddle's High-Level Distributed Training API, paddle.distributed.fleet. As cloud services grow rapidly, distributed training of deep learning models is becoming an everyday approach for both applications and research. FleetX aims to let Paddle users run distributed training on the cloud as easily as running code in a notebook.


Installation

pip install fleet-x
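
FleetX runs on top of PaddlePaddle, so the framework itself has to be installed first. A minimal sketch, assuming you pick the package that matches your hardware (the exact version requirement may differ between FleetX releases):

# CPU-only machines
pip install paddlepaddle

# machines with CUDA GPUs
pip install paddlepaddle-gpu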

A Distributed Resnet50 Training Example

import os
import paddle
import paddle.distributed.fleet as fleet
import fleetx as X

# FleetX applications build static-graph programs
paddle.enable_static()

# fleet-x: parse default training configs and build a predefined ResNet50 model
configs = X.parse_train_configs()
model = X.applications.Resnet50()
loader = model.load_imagenet_from_file("/pathto/imagenet/train.txt")

# paddle optimizer definition
optimizer = paddle.optimizer.Momentum(learning_rate=configs.lr, momentum=configs.momentum)

# paddle distributed training code here
fleet.init(is_collective=True)
optimizer = fleet.distributed_optimizer(optimizer)

optimizer.minimize(model.loss)

# create an executor on the GPU assigned to this process and run the startup program
# (FLAGS_selected_gpus is set by fleetrun for each trainer process)
place = paddle.CUDAPlace(int(os.environ.get("FLAGS_selected_gpus", "0")))
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())

epoch = 10
for e in range(epoch):
    for data in loader():
        cost_val = exe.run(paddle.static.default_main_program(),
                           feed=data,
                           fetch_list=[model.loss.name])

How to launch your task

  • Multiple GPU cards on a single node
fleetrun --gpus 0,1,2,3,4,5,6,7 resnet50_app.py
  • Multiple GPU cards on multiple nodes
fleetrun --gpus 0,1,2,3,4,5,6,7 --endpoints="xx.xx.xx.xx:8585,yy.yy.yy.yy:9696" resnet50_app.py
  • Run on Baidu Cloud
fleetrun --conf config.yml resnet50_app.py
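
Once a task is launched with fleetrun, each process can query its own rank and the total number of trainers through the fleet API. A minimal sketch of such a check, assuming a collective job like the ResNet50 example above:

import paddle.distributed.fleet as fleet

fleet.init(is_collective=True)
# worker_index() is this process's rank, worker_num() the total number of trainers
print("running as worker %d of %d" % (fleet.worker_index(), fleet.worker_num()))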

Multi-slot DNN CTR model

import os
import paddle
import paddle.distributed.fleet as fleet
import fleetx as X

# FleetX applications build static-graph programs
paddle.enable_static()

# fleet-x: parse default training configs and build a predefined multi-slot CTR model
configs = X.parse_train_configs()
model = X.applications.MultiSlotCTR()
loader = model.load_multislot_from_file("/pathto/ctr_data/train.txt")

# paddle optimizer definition
optimizer = paddle.optimizer.SGD(learning_rate=configs.lr)

# paddle distributed training code here; fleet.init() without arguments uses parameter-server mode
fleet.init()
optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(model.loss)

if fleet.is_server():
    # parameter servers hold and update the distributed parameters
    fleet.init_server()
    fleet.run_server()
else:
    # workers read data, run forward/backward and send gradients to the servers
    fleet.init_worker()
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())
    epoch = 10
    for e in range(epoch):
        for data in loader():
            cost_val = exe.run(paddle.static.default_main_program(),
                               feed=data,
                               fetch_list=[model.loss.name])
    fleet.stop_worker()
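
Parameter-server tasks are launched with fleetrun as well, but with server and worker counts instead of GPU ids. A hedged sketch, where ctr_app.py is a hypothetical script name and the exact flag names may vary across Paddle versions:

fleetrun --server_num 2 --worker_num 4 ctr_app.py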