// Paddle/paddle/operators/get_places_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/gpu_info.h"
#endif
namespace paddle {
namespace operators {
/// Operator that writes a list of places (std::vector<platform::Place>)
/// into its "Out" variable; the list is intended for parallel execution.
///
/// Attributes read in Run():
///   device_type  - "CPU" or "CUDA"; selects the kind of place produced.
///   device_count - number of places to produce; must be non-negative.
///                  For "CUDA" it must not exceed the number of visible
///                  CUDA devices, and places 0 .. device_count-1 are used.
///
/// Any previous contents of "Out" are discarded.
class GetPlacesOp : public framework::OperatorBase {
 public:
  GetPlacesOp(const std::string &type, const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    std::string device_type = Attr<std::string>("device_type");
    auto device_count = Attr<int>("device_count");
    // A negative count would wrap around when converted to size_t below.
    PADDLE_ENFORCE(device_count >= 0,
                   "device_count must be non-negative, but got %d",
                   device_count);

    auto out_var_name = Output("Out");
    auto *out_var = scope.FindVar(out_var_name);
    PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
                   out_var_name);
    auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
    // Start from an empty list and append exactly device_count places.
    // (resize() followed by emplace_back() would leave device_count
    // default-constructed places in front of the requested ones, and a
    // reused output variable would keep stale entries.)
    places.clear();
    places.reserve(static_cast<size_t>(device_count));

    if (device_type == "CUDA") {
#ifdef PADDLE_WITH_CUDA
      const int cuda_device_count = platform::GetCUDADeviceCount();
      // Device ids 0 .. device_count-1 must exist, so requesting every
      // visible device (device_count == cuda_device_count) is valid.
      PADDLE_ENFORCE_LE(device_count, cuda_device_count,
                        "Requested %d CUDA devices, but only %d are available",
                        device_count, cuda_device_count);
      for (int i = 0; i < device_count; ++i) {
        places.emplace_back(platform::GPUPlace(i));
      }
#else
      PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif
    } else if (device_type == "CPU") {
      for (int i = 0; i < device_count; ++i) {
        places.emplace_back(platform::CPUPlace());
      }
    }
  }
};
/// Declares the interface of the get_places operator: a single output
/// ("Out") plus the "device_count" and "device_type" attributes.
class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  GetPlacesOpProtoMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "vector of Place");
    AddAttr<int>("device_count", "(int) device count").SetDefault(1);

    // InEnum restricts the accepted values to the two listed strings.
    AddAttr<std::string>("device_type",
                         "(string) device type, can be \"CPU\" or \"CUDA\"")
        .InEnum({"CPU", "CUDA"});

    AddComment(R"DOC(

Returns a list of places based on flags. The list will be used for parallel execution.

)DOC");
  }
};
} // namespace operators
} // namespace paddle
// Short alias for the operator namespace used in the registration below.
namespace ops = paddle::operators;

// Register GetPlacesOp together with its proto maker under the op type
// name "get_places".
REGISTER_OPERATOR(get_places,
                  ops::GetPlacesOp,
                  ops::GetPlacesOpProtoMaker);