"polish code based on comment"

fix-typo
dzhwinter 8 years ago
parent 6f009cf8ba
commit 71305e5f90

@ -290,12 +290,12 @@ class ExecutionContext {
return device_context_; return device_context_;
} }
//! Get variables vector with same input name. //! Get actual name vector for this input.
const std::vector<std::string>& Inputs(const std::string& name) const { const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name); return op_.Inputs(name);
} }
//! Get variables vector with same output name. //! Get actual name vector for this output.
const std::vector<std::string>& Outputs(const std::string& name) const { const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name); return op_.Outputs(name);
} }

@ -30,6 +30,11 @@ class NCCLInitOp : public framework::OperatorBase {
"Can not find variable '%s' in the scope.", name); "Can not find variable '%s' in the scope.", name);
std::vector<int> gpus = Attr<std::vector<int>>("gpus"); std::vector<int> gpus = Attr<std::vector<int>>("gpus");
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty."); PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
if (scope.FindVar(name) == nullptr) {
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
}
platform::Communicator *comm = platform::Communicator *comm =
scope.FindVar(name)->GetMutable<platform::Communicator>(); scope.FindVar(name)->GetMutable<platform::Communicator>();
comm->InitAll(gpus); comm->InitAll(gpus);

@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include <functional> #include <functional>
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
@ -60,7 +59,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
} else if (reduction == "ncclProd") { } else if (reduction == "ncclProd") {
reduction_op_ = ncclProd; reduction_op_ = ncclProd;
} else { } else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
@ -113,7 +112,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
} else if (reduction == "ncclProd") { } else if (reduction == "ncclProd") {
reduction_op_ = ncclProd; reduction_op_ = ncclProd;
} else { } else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
int root = ctx.Attr<int>("root"); int root = ctx.Attr<int>("root");

@ -12,8 +12,6 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
@ -193,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
} }
} }
// ncclAReduceOp with desc // ncclReduceOp with desc
TEST_F(NCCLTester, ncclReduceOp) { TEST_F(NCCLTester, ncclReduceOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind); std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 0; const int kRoot = 0;
@ -201,7 +199,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
op2->SetInput("X", {"st"}); op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"}); op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"}); op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", {kRoot}); op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes; std::vector<f::Scope *> dev_scopes;
@ -241,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
} }
} }
// // ncclBcastOp with desc // ncclBcastOp with desc
TEST_F(NCCLTester, ncclBcastOp) { TEST_F(NCCLTester, ncclBcastOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind); std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 5; const int kRoot = 5;
@ -249,7 +247,7 @@ TEST_F(NCCLTester, ncclBcastOp) {
op2->SetInput("X", {"st"}); op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"}); op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"}); op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", {kRoot}); op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes; std::vector<f::Scope *> dev_scopes;

Loading…
Cancel
Save