Make fuse_all_reduce_op_pass support mix_precision (#17652)

sum_op
chengduo 6 years ago committed by gongweibao
parent 55baeceddb
commit 7453857324

@@ -58,15 +58,15 @@ constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
typedef std::string FusedOptType;
constexpr char kFusedOptType[] = "fused_opt_type";
typedef std::string FusedGrads;
typedef std::vector<std::string> FusedGrads;
constexpr char kFusedGrads[] = "fused_gradients";
typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndGrads[] = "params_grads";
typedef std::vector<std::vector<std::pair<std::string, std::string>>>
GroupGradsAndParams;
constexpr char kGroupGradsAndParams[] = "group_grads_params";
GroupParamsAndGrads;
constexpr char kGroupParamsAndGrads[] = "group_params_grads";
} // namespace details
} // namespace framework
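
With mixed precision in the graph, FP16 and FP32 gradients can no longer share a single fused buffer, which is why FusedGrads changes from a single name to a vector of names (one fused variable per dtype group) and the grouping structure is reordered into GroupParamsAndGrads. A minimal sketch of the grouping idea, using plain std types and a hypothetical helper rather than the pass's real data structures:

#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: partition gradient names by dtype (represented here
// as an int tag) so that each dtype group can be coalesced into its own
// fused variable, yielding one entry per group in the FusedGrads vector.
std::vector<std::vector<std::string>> GroupGradsByDtype(
    const std::vector<std::pair<std::string, int>> &grads_with_dtype) {
  std::map<int, std::vector<std::string>> by_dtype;
  for (const auto &g : grads_with_dtype) {
    by_dtype[g.second].push_back(g.first);
  }
  std::vector<std::vector<std::string>> groups;
  groups.reserve(by_dtype.size());
  for (auto &kv : by_dtype) {
    groups.push_back(std::move(kv.second));
  }
  return groups;
}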

@@ -101,10 +101,17 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
"this pass.");
}
auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
PADDLE_ENFORCE_NE(fused_grad.size(), 0,
"The fused gradient should not be empty.");
PADDLE_ENFORCE_EQ(fused_grad.size(), 1,
"Because the dtype of those gradients "
"is not unified, so the number of fused gradients is "
"more than one, but it is not supported currently.");
auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
auto iter =
std::find(fused_vars.begin(), fused_vars.end(), fused_grad.front());
PADDLE_ENFORCE(iter != fused_vars.end(), "Cannot find the fused_grad.");
fused_vars_name[kGrad] = fused_grad;
fused_vars_name[kGrad] = fused_grad.front();
// Sort the parameters and auxiliary variables according
// to parameters' name to make variables' name correspond correctly.
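
The optimizer-op fusion path, by contrast, still assumes a single fused gradient buffer: when the gradients span more than one dtype group it refuses to fuse, which is exactly what the PADDLE_ENFORCE_EQ above checks. Sketched as a standalone predicate (hypothetical helper, not the pass's API):

#include <string>
#include <vector>

// Hypothetical pre-condition mirroring the check above: optimizer-op fusion
// only proceeds when all gradients were coalesced into exactly one fused
// buffer, i.e. they all share the same dtype.
bool CanFuseOptimizerOps(const std::vector<std::string> &fused_grads) {
  return fused_grads.size() == 1;
}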

@@ -30,7 +30,6 @@ class FuseAllReduceOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
@@ -38,38 +37,17 @@ class FuseAllReduceOpPass : public ir::Pass {
&Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
#endif
std::unordered_set<std::string> grads;
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
size_t num_of_all_reduce = params_grads.size();
std::unordered_set<std::string> grads;
grads.reserve(num_of_all_reduce);
for (auto p_g : params_grads) {
grads.insert(p_g.second);
}
size_t num_place = places.size();
std::unordered_map<std::string, ir::Node *> all_reduce_ops;
all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) {
if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
&node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) {
auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
// The inputs' name should be the same.
auto &grad_name = inputs[0]->name();
for (size_t i = 1; i < inputs.size(); ++i) {
PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name,
"The input name should be the same.");
}
PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0));
all_reduce_ops.emplace(grad_name, node);
}
}
}
std::unordered_map<std::string, Node *> all_reduce_ops =
GetAllReduceOps(result, places, grads);
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
if (all_reduce_ops.size() == 0) {
@@ -82,16 +60,16 @@ class FuseAllReduceOpPass : public ir::Pass {
"it is not supported currently.");
VLOG(10) << "Insert fused_all_reduce";
auto &group_grads_params =
graph->Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
auto &group_params_grads =
graph->Get<details::GroupParamsAndGrads>(details::kGroupParamsAndGrads);
for (auto &group_g_p : group_grads_params) {
size_t group_size = group_g_p.size();
for (auto &group_p_g : group_params_grads) {
size_t group_size = group_p_g.size();
PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0));
std::vector<ir::Node *> group_all_reduce_ops;
group_all_reduce_ops.reserve(group_size);
for (auto &g_p : group_g_p) {
group_all_reduce_ops.emplace_back(all_reduce_ops.at(g_p.first));
for (auto &p_g : group_p_g) {
group_all_reduce_ops.emplace_back(all_reduce_ops.at(p_g.second));
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
InsertFusedAllReduce(places, local_scopes, group_size,
@@ -103,6 +81,35 @@ class FuseAllReduceOpPass : public ir::Pass {
}
}
std::unordered_map<std::string, Node *> GetAllReduceOps(
const Graph &result, const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &grads) const {
size_t num_place = places.size();
std::unordered_map<std::string, Node *> all_reduce_ops;
all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) {
if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
&node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) {
auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
// The inputs' name should be the same.
auto &grad_name = inputs[0]->name();
for (size_t i = 1; i < inputs.size(); ++i) {
PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name,
"The input name should be the same.");
}
PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0));
all_reduce_ops.emplace(grad_name, node);
}
}
}
return all_reduce_ops;
}
void InsertFusedAllReduce(const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const size_t num_of_all_reduce,
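
In fuse_all_reduce_op_pass the scan for existing all_reduce ops is factored out into GetAllReduceOps, and the grouping loop now walks params-and-grads pairs, keying the lookup by the gradient name (p_g.second). The data flow, roughly, over plain std types (OpNode stands in for ir::Node*, so this is only an illustration):

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using OpNode = int;  // stand-in for ir::Node*, illustration only

// For every dtype group, look up the all_reduce op node of each gradient;
// the ops of one group are later replaced by a single fused all_reduce.
std::vector<std::vector<OpNode>> CollectFusableGroups(
    const std::vector<std::vector<std::pair<std::string, std::string>>>
        &group_params_grads,
    const std::unordered_map<std::string, OpNode> &all_reduce_ops) {
  std::vector<std::vector<OpNode>> groups;
  groups.reserve(group_params_grads.size());
  for (const auto &group : group_params_grads) {
    std::vector<OpNode> ops;
    ops.reserve(group.size());
    for (const auto &p_g : group) {
      // p_g.first is the parameter name, p_g.second its gradient name.
      ops.push_back(all_reduce_ops.at(p_g.second));
    }
    groups.push_back(std::move(ops));
  }
  return groups;
}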

@@ -227,8 +227,11 @@ REGISTER_OPERATOR(alloc_continuous_space,
paddle::operators::AllocContinuousSpaceOp,
paddle::operators::AllocContinuousSpaceOpMaker);
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(
alloc_continuous_space,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext,
plat::float16>,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext, int>,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext, float>,
ops::AllocContinuousSpaceKernel<paddle::platform::CPUDeviceContext,
@@ -237,6 +240,8 @@ REGISTER_OP_CPU_KERNEL(
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL(
alloc_continuous_space,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext, int>,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext, float>,
ops::AllocContinuousSpaceKernel<paddle::platform::CUDADeviceContext,
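
Registering float16 kernels for alloc_continuous_space lets the fused gradient buffer itself be FP16. Conceptually the op lays the per-parameter tensors out back-to-back in one buffer of a single dtype; a simplified sketch of the offset planning (ignoring the real kernel's alignment and padding handling):

#include <cstddef>
#include <vector>

// Given the element count of each tensor, compute its element offset inside
// the fused buffer; the buffer's total length is the sum of all counts.
std::vector<std::size_t> PlanFusedOffsets(
    const std::vector<std::size_t> &numels) {
  std::vector<std::size_t> offsets(numels.size(), 0);
  std::size_t running = 0;
  for (std::size_t i = 0; i < numels.size(); ++i) {
    offsets[i] = running;
    running += numels[i];
  }
  return offsets;
}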

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/sgd_op.h"
#include <string>
namespace paddle {
namespace operators {
@@ -46,6 +46,17 @@ class SGDOp : public framework::OperatorWithKernel {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (var_name == "LearningRate") {
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class SGDOpInferVarType : public framework::VarTypeInference {
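
The new GetKernelTypeForVar keeps LearningRate in its own dtype and place instead of transforming it to the kernel's expected type, so an FP16 SGD step can still read an FP32 learning rate. A simplified element-wise sketch of that update (SgdUpdate is illustrative only, not the actual Paddle kernel):

#include <cstddef>

// Param and Grad may be float16 (T) while the learning-rate tensor stays
// float32; the scalar is cast to T once before the update loop.
template <typename T>
void SgdUpdate(T *param, const T *grad, const float *learning_rate,
               std::size_t n) {
  const T lr = static_cast<T>(learning_rate[0]);
  for (std::size_t i = 0; i < n; ++i) {
    param[i] -= lr * grad[i];
  }
}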

@@ -46,7 +46,7 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(
tensor_out_ptr + index,
-1.0 * learning_rate[0] * selected_rows_ptr[index]);
-static_cast<T>(1.0) * learning_rate[0] * selected_rows_ptr[index]);
}
}
}
@@ -122,5 +122,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel<float>,
ops::SGDOpCUDAKernel<double>);
ops::SGDOpCUDAKernel<double>,
ops::SGDOpCUDAKernel<plat::float16>);
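
In the sparse CUDA kernel the literal -1.0 is a double, so with T = float16 the product would not stay in T; building the factor with static_cast<T>(1.0) keeps the arithmetic in the kernel's element type, and the float16 SGD CUDA kernel is registered alongside float and double. The pattern in isolation (illustrative helper only):

// Keep the scale factor in T so the expression stays in half precision for
// T = float16 instantiations instead of promoting to double.
template <typename T>
T NegativeScale(T learning_rate, T grad) {
  return -static_cast<T>(1.0) * learning_rate * grad;
}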

@@ -0,0 +1,91 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid.core as core
import math
import os
import sys
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from simple_nets import init_data
from parallel_executor_test_base import TestParallelExecutorBase
batch_size = 12
img_shape = [1, 28, 28]
def loss_net(hidden, label):
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
def conv_net(use_feed):
img = fluid.layers.data(name='image', shape=img_shape, dtype='float16')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_1 = fluid.layers.cast(conv_pool_1, np.float32)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu")
hidden = fluid.layers.cast(conv_pool_2, np.float32)
return loss_net(hidden, label)
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(learning_rate=learning_rate)
return optimizer
class TestResnet(TestParallelExecutorBase):
def check_model(self, use_cuda):
img, label = init_data(
batch_size=batch_size, img_shape=img_shape, label_range=9)
img = np.float16(img).view(np.uint16)
feed_dict = {"image": img, "label": label}
TestParallelExecutorBase.check_network_convergence(
conv_net,
feed_dict=feed_dict,
iter=10,
use_cuda=use_cuda,
fuse_all_reduce_ops=True,
optimizer=_optimizer)
def test_model(self):
if core.is_compiled_with_cuda():
self.check_model(True)
if __name__ == '__main__':
unittest.main()
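
The new unit test builds a small convolutional network whose first block runs in float16 and is then cast back to float32, and checks convergence with fuse_all_reduce_ops=True on CUDA. Feeding the images as np.float16(img).view(np.uint16), as far as I can tell, reflects how fluid accepted float16 input data at the time: the tensor's raw half-precision storage is passed as uint16.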