|
|
|
@ -15,7 +15,6 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
#include "paddle/fluid/platform/device_context.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -41,19 +40,33 @@ class FillConstantOp : public framework::OperatorBase {
|
|
|
|
|
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
|
|
|
|
|
auto value = Attr<float>("value");
|
|
|
|
|
auto force_cpu = Attr<bool>("force_cpu");
|
|
|
|
|
auto &out =
|
|
|
|
|
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
|
|
|
|
|
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
|
|
|
|
|
|
|
|
|
|
framework::Tensor *tensor = nullptr;
|
|
|
|
|
|
|
|
|
|
auto &out_var = *scope.FindVar(Output("Out"));
|
|
|
|
|
|
|
|
|
|
if (out_var.IsType<framework::LoDTensor>()) {
|
|
|
|
|
tensor = out_var.GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
|
|
|
|
|
} else if (out_var.IsType<framework::SelectedRows>()) {
|
|
|
|
|
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
|
|
|
|
|
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"fill constant op's output only"
|
|
|
|
|
"supports SelectedRows and LoDTensor");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (force_cpu) {
|
|
|
|
|
auto cpu = platform::CPUPlace();
|
|
|
|
|
out.mutable_data(cpu, framework::ToTypeIndex(data_type));
|
|
|
|
|
tensor->mutable_data(cpu, framework::ToTypeIndex(data_type));
|
|
|
|
|
} else {
|
|
|
|
|
out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
|
|
|
|
|
tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(dev_place);
|
|
|
|
|
math::set_constant(dev_ctx, &out, value);
|
|
|
|
|
math::set_constant(dev_ctx, tensor, value);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|