"fix create output variable bug"

fix-typo
Dong Zhihong 8 years ago
parent 61c1b0469a
commit 4e165f4ea3

@ -114,6 +114,9 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Output(Out) of Bcast op output should not be NULL");
int root = ctx->Attrs().Get<int>("root");
PADDLE_ENFORCE(root != -1, "Bcast root must be set.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");

@ -54,12 +54,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
ctx.device_context())
.stream();
// device id
int device_id =
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(device_id);
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << " invoke allreduce. send " << ins[i]->numel() << " recv "
VLOG(1) << "gpu : "
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
@ -68,7 +68,8 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << " finished allreduce. send " << ins[i]->numel() << " recv "
VLOG(1) << "gpu : "
<< " finished allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
}
}
@ -91,9 +92,8 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
ctx.device_context())
.stream();
// device id
int device_id =
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(device_id);
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
@ -102,20 +102,20 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
root = hasher(ins_names[i]) % comm->comms_.size();
}
T* recvbuffer = nullptr;
if (root == device_id) {
if (root == gpu_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
}
VLOG(1) << " invoke reduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, ncclSum, root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << " finished reduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
}
}
};
@ -135,33 +135,37 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
ctx.device_context())
.stream();
// device id
int device_id =
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(device_id);
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
if (idx == root) {
auto ins = ctx.MultiInput<LoDTensor>("X");
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << " invoke Bcast. send " << ins[i]->numel();
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send "
<< ins[i]->numel();
VLOG(1) << " before ncclBcast";
PADDLE_ENFORCE(platform::dynload::ncclBcast(
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream));
VLOG(1) << " after ncclBcast";
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << " finished Bcast.";
VLOG(1) << "gpu : " << gpu_id << " finished Bcast.";
}
} else {
auto outs = ctx.MultiOutput<LoDTensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) {
VLOG(1) << " invoke Bcast. recv. ";
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
<< framework::product(outs[i]->dims());
PADDLE_ENFORCE(platform::dynload::ncclBcast(
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << " finished Bcast. recv " << outs[i]->numel();
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
<< outs[i]->numel();
}
}
}

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save