supports collective communication training (#18175)

* fix redundant code in prepare context, optimize executor by caching create_variables
test=develop

* supports collective training in executor

* make fetch_list runnable with variables, add more unittests for use_program_cache
test=develop

* fix comment
test=develop

* use unique name for nccl_id

* supports output to stream in program_to_code

* insert sync_comm_stream before regularization; add skip_op_callstack capability in program_to_code

* set op role in collective training

* add collective op role

* remove orig file

* add build optimizer by strategy

* add collective strategy

* refine collective strategy

* add multi-process role maker

* refine strategy building factory so that we can easily plug in more strategies

* scale loss grad in collective sgd transpiler

* add support for distributed fc

* code format

* revert some features for dist fc

* add support for distributed fc training

* test=develop
add collective op unittest standard

* test=develop
remove the test_collective directory

* remove slicegather test

* code format for reducescatter

* update attr of shard_index_op

* Modify macro nccl_helper

* remove test without distribute

* macro collective_helper

* macro update

* test=develop
update to support python3.5

* test=develop change gpu memory use to 0.1 when testing

* test=develop
update ut equal func

* test=develop
set flags to 1.5

* test=develop fix pickle dump on py35

* test=develop
fix divide in slice and add sync_comm_stream
update atol and rtol to 1e-05
rm shard_index op and test
modify read input from file to read from memory
remove origin_program in framework and add i/o in c_sync_calc_stream

* test=develop update unittest sync operator I/O
Authored by HaoRen 6 years ago; committed by Yi Liu (commit b7128bac5f, parent 9252e8fa08)

@@ -179,7 +179,7 @@ if(WITH_DISTRIBUTE)
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
@@ -73,6 +74,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize) |
static_cast<int>(OpRole::kLRSched),
static_cast<int>(OpRole::kCollective),
static_cast<int>(OpRole::kNotSpecified)})
.SetDefault(static_cast<int>(OpRole::kNotSpecified));
AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),

@@ -34,6 +34,9 @@ enum class OpRole {
kDist = 0x0008,
// Tag all learning rate scheduler operators.
kLRSched = 0x0010,
// Collective role is for all collective operators and other operators used
// for collective training
kCollective = 0x0020,
kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and

@@ -21,6 +21,7 @@ add_subdirectory(jit)
if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
add_subdirectory(distributed_ops)
add_subdirectory(collective)
endif()
add_subdirectory(reader)

@@ -0,0 +1,39 @@
include(operators)
set(COLLECTIVE_DEPS "")
if(WITH_GRPC)
set(COLLECTIVE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(COLLECTIVE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY})
find_library(RDMACM_LIBRARY NAMES rdmacm)
ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY})
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} ibverbs rdmacm)
endif()
endif()
set(COLLECTIVE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
list(REMOVE_DUPLICATES OPS)
foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS})
endforeach()
register_operators(EXCLUDES c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
if(WITH_GPU AND NOT WIN32)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common)
endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")

@@ -0,0 +1,89 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include <future> // NOLINT
#include <memory>
#include <ostream>
namespace paddle {
namespace operators {
class CAllGatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SyncFCGather op should not be null.");
int nranks = ctx->Attrs().Get<int>("nranks");
PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2");
framework::DDim dim = ctx->GetInputDim("X");
dim[0] = dim[0] * nranks;
ctx->SetOutputDim("Out", dim);
}
};
class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be allgather");
AddOutput("Out", "(Tensor) the allgather result");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddAttr<int>("nranks",
"Total trainer count of the distributed training job");
AddComment(R"DOC(
***CAllGather Operator***
Each rank receives the concatenation of all ranks' tensors, in rank order.
Call NCCL collective AllGather internally. See https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/api/colls.html#c.ncclAllGather
)DOC");
}
};
class CAllGatherOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
retv->SetType("c_reducescatter");
retv->SetInput("X", OutputGrad("Out"));
retv->SetOutput("Out", InputGrad("X"));
retv->SetAttrMap(Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker,
ops::CAllGatherOpMaker);
REGISTER_OP_CPU_KERNEL(
c_allgather, ops::CAllGatherOpKernel<plat::CPUDeviceContext, float>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, double>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, plat::float16>);
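
For readers new to these semantics: the InferShape above multiplies dim[0] by nranks, and the registered gradient is c_reducescatter. A minimal standalone C++ sketch of what allgather computes, with plain vectors standing in for tensors and NCCL (all names below are hypothetical):

#include <cassert>
#include <iostream>
#include <vector>

// Simulates ncclAllGather: every rank contributes its local slice and every
// rank receives the concatenation of all slices in rank order, so the leading
// dimension grows by a factor of nranks (matching InferShape above).
std::vector<std::vector<float>> AllGather(
    const std::vector<std::vector<float>>& local) {
  std::vector<float> gathered;
  for (const auto& slice : local)
    gathered.insert(gathered.end(), slice.begin(), slice.end());
  // Every rank gets an identical copy of the gathered buffer.
  return std::vector<std::vector<float>>(local.size(), gathered);
}

int main() {
  // Two ranks, each holding a tensor whose dim[0] is 2.
  std::vector<std::vector<float>> local = {{1, 2}, {3, 4}};
  auto out = AllGather(local);
  assert(out[0].size() == local[0].size() * local.size());  // dim[0] * nranks
  for (float v : out[0]) std::cout << v << " ";  // prints: 1 2 3 4
  std::cout << "\n";
  return 0;
}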

@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allgather, ops::CAllGatherOpKernel<plat::CUDADeviceContext, float>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, double>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, plat::float16>);

@@ -0,0 +1,75 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CAllGatherOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllGatherOp can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle

@@ -0,0 +1,83 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
class CAllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor), tensor to be allreduced.");
AddOutput("Out", "(Tensor) the allreduced result.");
AddAttr<int>("reduce_type", "(int default 0) determin the reduce type.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
***CAllReduce Operator***
Call NCCL collective AllReduce internally. Note that this op must be used when one
thread is managing one GPU device.
For speed reasons, reduce_type should be an integer:
0: sum
1: prod
2: max
3: min
If input and output are the same variable, in-place allreduce will be used.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_allreduce, ops::CAllReduceOp,
ops::CAllReduceOpMaker);
REGISTER_OP_CPU_KERNEL(
c_allreduce, ops::CAllReduceOpKernel<plat::CPUDeviceContext, float>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, double>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, int>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, plat::float16>);
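
The reduce_type attribute documented above selects one of four elementwise reductions. A self-contained C++ sketch of that mapping and of the buffer every rank holds after an allreduce; this is a simulation only, not Paddle or NCCL code, and every name is hypothetical:

#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <vector>

// Mirrors the reduce_type attribute above: 0 sum, 1 prod, 2 max, 3 min.
std::function<float(float, float)> ReduceFn(int reduce_type) {
  switch (reduce_type) {
    case 0: return std::plus<float>();
    case 1: return std::multiplies<float>();
    case 2: return [](float a, float b) { return std::max(a, b); };
    case 3: return [](float a, float b) { return std::min(a, b); };
    default: assert(false && "unknown reduce_type"); return {};
  }
}

// Simulates allreduce: every rank ends up with the elementwise reduction
// of all ranks' buffers.
std::vector<float> AllReduce(const std::vector<std::vector<float>>& local,
                             int reduce_type) {
  auto fn = ReduceFn(reduce_type);
  std::vector<float> result = local[0];
  for (size_t r = 1; r < local.size(); ++r)
    for (size_t i = 0; i < result.size(); ++i)
      result[i] = fn(result[i], local[r][i]);
  return result;
}

int main() {
  std::vector<std::vector<float>> local = {{1, 5}, {3, 2}};
  auto sum = AllReduce(local, 0);   // {4, 7}
  auto maxv = AllReduce(local, 2);  // {3, 5}
  std::cout << sum[0] << " " << sum[1] << " | "
            << maxv[0] << " " << maxv[1] << "\n";
  return 0;
}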

@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce, ops::CAllReduceOpKernel<plat::CUDADeviceContext, float>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, double>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, plat::float16>);

@@ -0,0 +1,86 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CAllReduceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int reduce_type = ctx.Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, red_type, comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle

@@ -0,0 +1,74 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
namespace paddle {
namespace operators {
class CBroadcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be broadcasted.");
AddOutput("Out", "(Tensor) the result of broadcast.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("root", "(int default 0) root id for broadcasting.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
***CBroadcast Operator***
Call ncclBcast internally.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_broadcast, ops::CBroadcastOp,
ops::CBroadcastOpMaker);
REGISTER_OP_CPU_KERNEL(
c_broadcast, ops::CBroadcastOpKernel<plat::CPUDeviceContext, float>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, double>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, int>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, plat::float16>);

@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_broadcast, ops::CBroadcastOpKernel<plat::CUDADeviceContext, float>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, double>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, int>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, plat::float16>);

@@ -0,0 +1,92 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CBroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CBroadcastOp can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int numel = x->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
int root = ctx.Attr<int>("root");
int nranks = comm->nranks();
PADDLE_ENFORCE(root >= 0 && root < nranks,
"Expected root in range of [0,%d),but get %d", nranks, root);
if (root == comm->rank()) {
PADDLE_ENFORCE(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
// TODO(liuyi05): check inplace
framework::TensorCopy(
*static_cast<const framework::Tensor*>(x), place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<framework::Tensor*>(out));
}
} else {
PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data<T>(place),
numel, dtype, root,
comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved "
<< framework::product(out->dims());
}
out->Resize(x->dims());
out->set_lod(x->lod());
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
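
The kernel above takes two paths: the root rank broadcasts its buffer (copying into Out when Out is a distinct variable) while every other rank receives into Out. A tiny in-process simulation of that semantics, with hypothetical names and plain vectors in place of NCCL:

#include <cassert>
#include <iostream>
#include <vector>

// Simulates ncclBcast: the root's buffer overwrites every other rank's
// buffer; the root keeps its own data, mirroring the TensorCopy branch
// in the kernel above.
void Broadcast(std::vector<std::vector<float>>* bufs, int root) {
  // Same range check as the kernel's enforce on root.
  assert(root >= 0 && root < static_cast<int>(bufs->size()));
  for (size_t r = 0; r < bufs->size(); ++r)
    if (static_cast<int>(r) != root) (*bufs)[r] = (*bufs)[root];
}

int main() {
  std::vector<std::vector<float>> bufs = {{7, 8}, {0, 0}, {0, 0}};
  Broadcast(&bufs, /*root=*/0);
  for (auto& b : bufs) std::cout << b[0] << "," << b[1] << "  ";  // 7,8 thrice
  std::cout << "\n";
  return 0;
}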

@@ -0,0 +1,86 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CCommInitOp : public framework::OperatorBase {
public:
CCommInitOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"CCommInitOp can run on gpu place only.");
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(var);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id, boost::get<platform::CUDAPlace>(place).device,
rid);
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
class CCommInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CCommInit operator
Initialize the collective communication context within this trainer
)DOC");
AddAttr<int>("nranks", "(int) The number of ranks of distributed trainers");
AddAttr<int>("rank",
"(int) The rank of the trainer in distributed training.");
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_comm_init, ops::CCommInitOp, ops::CCommInitOpMaker);

@@ -0,0 +1,146 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CGenNCCLIdOp : public framework::OperatorBase {
public:
CGenNCCLIdOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
if (rank == 0) {
GenerateAndSend(&local_scope, dev_ctx);
} else {
GetIdByServer(&local_scope, dev_ctx);
}
scope.DeleteScope(&local_scope);
}
private:
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string var_name = Output("Out");
auto var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, var_name);
}
client->Wait();
for (auto& ep : endpoint_list) {
client->AsyncSendBatchBarrier(ep);
}
client->Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a weird crash.
distributed::RequestSendHandler rpc_h(true);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "got nccl id and stop server...";
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
}
};
class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CGenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
AddAttr<std::string>("endpoint",
"(string), e.g. 127.0.0.1:6175 "
"current listen endpoint");
AddAttr<std::vector<std::string>>(
"other_endpoints",
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
"list of other trainer endpoints")
.SetDefault({});
AddAttr<int>("rank",
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_gen_nccl_id, ops::CGenNCCLIdOp, ops::CGenNCCLIdOpMaker);
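
The operator above is a rendezvous: rank 0 calls ncclGetUniqueId and pushes the id to every other endpoint over RPC, while each other rank runs a short-lived RPC server until the id arrives. A standalone C++ sketch that mimics the handshake in-process, with one promise per rank standing in for the gRPC server (all names hypothetical):

#include <future>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

using UniqueId = std::string;  // stand-in for ncclUniqueId

int main() {
  const int nranks = 3;
  // One promise/future pair per rank plays the role of its RPC server.
  std::vector<std::promise<UniqueId>> inbox(nranks);
  std::vector<std::future<UniqueId>> arrived;
  for (int r = 0; r < nranks; ++r) arrived.push_back(inbox[r].get_future());

  std::vector<std::thread> ranks;
  for (int rank = 0; rank < nranks; ++rank) {
    ranks.emplace_back([&, rank] {
      if (rank == 0) {
        UniqueId id = "nccl-id-42";       // GenerateAndSend: create once...
        for (int peer = 1; peer < nranks; ++peer)
          inbox[peer].set_value(id);      // ...then send to every endpoint.
      } else {
        // GetIdByServer: block until trainer 0 delivers the id, then proceed.
        UniqueId id = arrived[rank].get();
        std::cout << "rank " << rank << " got " << id << "\n";
      }
    });
  }
  for (auto& t : ranks) t.join();
  return 0;
}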

@@ -0,0 +1,93 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include <future> // NOLINT
#include <memory>
#include <ostream>
namespace paddle {
namespace operators {
class CReduceScatterOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
int nranks = ctx->Attrs().Get<int>("nranks");
framework::DDim dim = ctx->GetInputDim("X");
if (dim[0] > 0 || dim[0] < -1) {
PADDLE_ENFORCE(dim[0] % nranks == 0,
"dim[0] (%d) is not divisible by nranks(%d)", dim[0],
nranks);
dim[0] /= nranks;
}
ctx->SetOutputDim("Out", dim);
}
};
class CReduceScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be allgather");
AddOutput("Out", "(Tensor) the allgather result");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<int>("nranks",
"Total trainer count of the distributed training job")
.SetDefault(1);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
***CReduceScatter Operator***
Call NCCL collective ReduceScatter internally.
)DOC");
}
};
class CReduceScatterOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
retv->SetType("c_allgather");
retv->SetInput("X", OutputGrad("Out"));
retv->SetOutput("Out", InputGrad("X"));
retv->SetAttrMap(Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_reducescatter, ops::CReduceScatterOp,
ops::CReduceScatterOpMaker);
REGISTER_OP_CPU_KERNEL(
c_reducescatter, ops::CReduceScatterOpKernel<plat::CPUDeviceContext, float>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, double>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, int>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, plat::float16>);
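
As the InferShape above enforces, dim[0] must be divisible by nranks, and each rank receives one reduced chunk; note c_reducescatter is also registered as the gradient of c_allgather. A minimal C++ simulation of the sum variant, with hypothetical names and vectors in place of tensors:

#include <cassert>
#include <iostream>
#include <vector>

// Simulates ncclReduceScatter with sum: elementwise-sum all ranks' buffers,
// then hand rank r the r-th contiguous chunk. The input's leading dimension
// must be divisible by nranks, matching the InferShape check above.
std::vector<std::vector<float>> ReduceScatterSum(
    const std::vector<std::vector<float>>& local) {
  const size_t nranks = local.size();
  const size_t n = local[0].size();
  assert(n % nranks == 0 && "dim[0] must be divisible by nranks");
  std::vector<float> total(n, 0.f);
  for (const auto& buf : local)
    for (size_t i = 0; i < n; ++i) total[i] += buf[i];
  const size_t chunk = n / nranks;
  std::vector<std::vector<float>> out(nranks);
  for (size_t r = 0; r < nranks; ++r)
    out[r].assign(total.begin() + r * chunk, total.begin() + (r + 1) * chunk);
  return out;
}

int main() {
  std::vector<std::vector<float>> local = {{1, 2, 3, 4}, {10, 20, 30, 40}};
  auto out = ReduceScatterSum(local);
  std::cout << out[0][0] << "," << out[0][1] << " | "
            << out[1][0] << "," << out[1][1] << "\n";  // 11,22 | 33,44
  return 0;
}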

@@ -0,0 +1,26 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_reducescatter,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, float>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, double>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, int>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, plat::float16>);

@@ -0,0 +1,75 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CReduceScatterOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();
auto out_dims = in->dims();
out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place);
int64_t recv_numel = in->numel() / nranks;
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
int dtype = platform::ToNCCLDataType(in->type());
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclReduceScatter(
send_buff, recv_buff, recv_numel, static_cast<ncclDataType_t>(dtype),
ncclSum, comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle

@@ -0,0 +1,76 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace paddle {
namespace operators {
class CSyncCalcStreamOp : public framework::OperatorBase {
public:
CSyncCalcStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"Sync stream op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
cudaError_t e_sync = cudaStreamSynchronize(dev_ctx->stream());
if (e_sync != 0) {
LOG(FATAL) << "Fail to sync cuda stream: " << cudaGetErrorString(e_sync);
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of last param need to sync");
AddOutput("Out", "(Tensor) Dependency of last param need to sync");
AddComment(R"DOC(
***Sync Operator***
Call cuda stream synchronize.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_sync_calc_stream, ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);

@@ -0,0 +1,77 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CSyncCommStreamOp : public framework::OperatorBase {
public:
CSyncCommStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"Sync stream op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int ring_id = Attr<int>("ring_id");
auto stream = platform::NCCLCommContext::Instance().Get(ring_id)->stream();
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
LOG(FATAL) << "Fail to sync nccl stream: " << cudaGetErrorString(e_sync);
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of last param need to sync");
AddOutput("Out", "(Tensor) Dependency of last param need to sync");
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddComment(R"DOC(
***Sync Operator***
Call nccl stream synchronize.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_sync_comm_stream, ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);

@@ -21,7 +21,6 @@ endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
list(REMOVE_DUPLICATES OPS)

@@ -86,6 +86,11 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
"An integer to specify the data type of one-hot "
"vector. The default value is FP32.")
.SetDefault(paddle::framework::proto::VarType::FP32);
AddAttr<bool>("allow_out_of_range",
"If it is set true and the input data is out of range, "
"the output tensor will be filled zeros. The default value "
"is false.")
.SetDefault(false);
AddComment(R"DOC(
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this

@@ -24,7 +24,7 @@ template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
const int64_t numel, const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
}
}

@@ -25,10 +25,16 @@ struct OneHotOpFunctor {
framework::LoDTensor* out_;
int depth_;
const DeviceContext& ctx_;
bool allow_out_of_range_;
OneHotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
int depth, const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
int depth, const DeviceContext& ctx,
bool allow_out_of_range = false)
: in_(in),
out_(out),
depth_(depth),
ctx_(ctx),
allow_out_of_range_(allow_out_of_range) {}
template <typename OutT>
void apply() const {
@@ -37,13 +43,21 @@ struct OneHotOpFunctor {
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
math::set_constant(ctx_, out_, 0.0);
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(p_in_data[i], 0,
"Illegal index value, should be at least 0.");
PADDLE_ENFORCE_LT(p_in_data[i], depth_,
"Illegal index value, should be less than depth (%d).",
depth_);
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
if (allow_out_of_range_) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth_) {
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(p_in_data[i], 0,
"Illegal index value, should be at least 0.");
PADDLE_ENFORCE_LT(
p_in_data[i], depth_,
"Illegal index value, should be less than depth (%d).", depth_);
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
}
};
@@ -57,6 +71,7 @@ class OneHotKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int depth = context.Attr<int>("depth");
bool allow_out_of_range = context.Attr<bool>("allow_out_of_range");
if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<Tensor>("depth_tensor");
auto* depth_data = depth_tensor->data<int32_t>();
@@ -71,7 +86,8 @@ class OneHotKernel : public framework::OpKernel<T> {
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
in, out, depth, context.template device_context<DeviceContext>(),
allow_out_of_range));
}
};
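
To make the new attribute concrete: with allow_out_of_range set, an out-of-range index now yields an all-zero row instead of tripping the enforce, matching both the CPU functor branch and the guarded CUDA FillOutputKernel above. A small standalone C++ sketch of that behavior (a hypothetical helper, not the Paddle API):

#include <iostream>
#include <stdexcept>
#include <vector>

// Sketch of the CPU functor's new branch: an in-range index sets one slot
// per row; an out-of-range index either leaves the row as zeros
// (allow_out_of_range == true) or raises an error, as the enforce does.
std::vector<float> OneHot(const std::vector<int>& in, int depth,
                          bool allow_out_of_range) {
  std::vector<float> out(in.size() * depth, 0.f);
  for (size_t i = 0; i < in.size(); ++i) {
    if (in[i] >= 0 && in[i] < depth) {
      out[i * depth + in[i]] = 1.f;
    } else if (!allow_out_of_range) {
      throw std::out_of_range("index must be in [0, depth)");
    }  // else: leave the row as zeros
  }
  return out;
}

int main() {
  auto ok = OneHot({1, 5}, /*depth=*/4, /*allow_out_of_range=*/true);
  // Row 0: 0 1 0 0; row 1 (index 5 is out of range): 0 0 0 0
  for (size_t i = 0; i < ok.size(); ++i)
    std::cout << ok[i] << ((i + 1) % 4 ? ' ' : '\n');
  return 0;
}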
