Skip to content

【WIP】fleet support dygraph in mnist/resnet/transformer#4811

Open
danleifeng wants to merge 2 commits into PaddlePaddle:release/1.8 from
danleifeng:dygraph_fleet_dev
Open

【WIP】fleet support dygraph in mnist/resnet/transformer#4811
danleifeng wants to merge 2 commits into PaddlePaddle:release/1.8 from
danleifeng:dygraph_fleet_dev

Conversation

@danleifeng
Copy link

@danleifeng danleifeng commented Aug 21, 2020

main change:

from paddle.distributed import fleet

fleet.init(is_collective=True)
adam = fleet.distributed_optimizer(adam)
# call after distributed_optimizer so as to apply dist_strategy
mnist = fleet.build_distributed_model(mnist)

Sample code:

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
from paddle.distributed import fleet

class LinearNet(nn.Layer):
    """A small two-layer MLP: 10 features -> 10 hidden units -> 1 output."""

    def __init__(self):
        super(LinearNet, self).__init__()
        # NOTE: attribute names are kept stable — they determine the
        # parameter names in the saved state dict.
        self._linear1 = nn.Linear(10, 10)
        self._linear2 = nn.Linear(10, 1)

    def forward(self, x):
        hidden = self._linear1(x)
        return self._linear2(hidden)

def train():
    """Run a single data-parallel training step of LinearNet via fleet."""
    # 1. enable dynamic (dygraph) mode
    paddle.disable_static()

    # 2. create the model, loss function, and optimizer
    model = LinearNet()
    mse = nn.MSELoss()
    optimizer = opt.Adam(
        learning_rate=0.001, parameters=model.parameters())

    # 3. get data_parallel model using fleet
    fleet.init(is_collective=True)
    optimizer = fleet.distributed_optimizer(optimizer)
    # call after distributed_optimizer so as to apply dist_strategy
    parallel_model = fleet.build_distributed_model(model)

    # 4. one forward / backward / update step on random data
    batch = paddle.randn([10, 10], 'float32')
    predictions = parallel_model(batch)
    targets = paddle.randn([10, 1], 'float32')
    step_loss = mse(predictions, targets)

    # scale the loss for data parallelism, backprop, then
    # all-reduce gradients across trainers
    step_loss = parallel_model.scale_loss(step_loss)
    step_loss.backward()
    parallel_model.apply_collective_grads()

    optimizer.step()
    optimizer.clear_grad()

if __name__ == '__main__':
    train()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant