|
|
|
@ -30,7 +30,7 @@ class GetPlacesOp : public framework::OperatorBase {
|
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
|
|
const platform::DeviceContext &dev_ctx) const override {
|
|
|
|
|
std::string device_type = Attr<std::string>("device_type");
|
|
|
|
|
auto trainer_count = Attr<int>("trainer_count");
|
|
|
|
|
auto device_count = Attr<int>("device_count");
|
|
|
|
|
|
|
|
|
|
auto out_var_name = Output("Out");
|
|
|
|
|
auto *out_var = scope.FindVar(out_var_name);
|
|
|
|
@ -38,18 +38,18 @@ class GetPlacesOp : public framework::OperatorBase {
|
|
|
|
|
out_var_name);
|
|
|
|
|
|
|
|
|
|
auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
|
|
|
|
|
places.resize(trainer_count);
|
|
|
|
|
places.resize(device_count);
|
|
|
|
|
if (device_type == "CUDA") {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
PADDLE_ENFORCE_LT(trainer_count, platform::GetCUDADeviceCount());
|
|
|
|
|
for (int i = 0; i < trainer_count; i++) {
|
|
|
|
|
PADDLE_ENFORCE_LT(device_count, platform::GetCUDADeviceCount());
|
|
|
|
|
for (int i = 0; i < device_count; i++) {
|
|
|
|
|
places.emplace_back(platform::GPUPlace(i));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
|
|
|
|
|
#endif
|
|
|
|
|
} else if (device_type == "CPU") {
|
|
|
|
|
for (int i = 0; i < trainer_count; i++) {
|
|
|
|
|
for (int i = 0; i < device_count; i++) {
|
|
|
|
|
places.emplace_back(platform::CPUPlace());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -62,7 +62,7 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddOutput("Out", "vector of Place");
|
|
|
|
|
AddAttr<int>("trainer_count", "(int)trainer count").SetDefault(1);
|
|
|
|
|
AddAttr<int>("device_count", "(int)device count").SetDefault(1);
|
|
|
|
|
AddAttr<std::string>("device_type",
|
|
|
|
|
"(string), deivce type can be \"CPU\" and \"CUDA\"")
|
|
|
|
|
.InEnum({"CPU", "CUDA"});
|
|
|
|
|