|
|
|
@ -14,8 +14,9 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
@ -75,13 +76,28 @@ class FillConstantKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto data_type =
|
|
|
|
|
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
|
|
|
|
|
auto value = ctx.Attr<float>("value");
|
|
|
|
|
auto str_value = ctx.Attr<std::string>("str_value");
|
|
|
|
|
auto float_value = ctx.Attr<float>("value");
|
|
|
|
|
auto force_cpu = ctx.Attr<bool>("force_cpu");
|
|
|
|
|
|
|
|
|
|
framework::Tensor *tensor = nullptr;
|
|
|
|
|
|
|
|
|
|
framework::Variable *out_var = ctx.OutputVar("Out");
|
|
|
|
|
|
|
|
|
|
T value;
|
|
|
|
|
if (str_value.empty()) {
|
|
|
|
|
value = static_cast<T>(float_value);
|
|
|
|
|
} else {
|
|
|
|
|
std::stringstream convert_stream(str_value);
|
|
|
|
|
if (std::is_same<int64_t, T>::value) {
|
|
|
|
|
int64_t tmp_value;
|
|
|
|
|
convert_stream >> tmp_value;
|
|
|
|
|
value = static_cast<T>(tmp_value);
|
|
|
|
|
} else {
|
|
|
|
|
double tmp_value;
|
|
|
|
|
convert_stream >> tmp_value;
|
|
|
|
|
value = static_cast<T>(tmp_value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto shape = GetShape(ctx);
|
|
|
|
|
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
@ -96,15 +112,23 @@ class FillConstantKernel : public framework::OpKernel<T> {
|
|
|
|
|
"supports SelectedRows and LoDTensor");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (force_cpu) {
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(ctx.GetPlace());
|
|
|
|
|
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
|
|
|
|
|
if (cpu_place) {
|
|
|
|
|
tensor->mutable_data(platform::CPUPlace(), data_type);
|
|
|
|
|
} else {
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> functor;
|
|
|
|
|
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
|
|
|
|
|
tensor, static_cast<T>(value));
|
|
|
|
|
}
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (!cpu_place) {
|
|
|
|
|
tensor->mutable_data(ctx.GetPlace(), data_type);
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, T> functor;
|
|
|
|
|
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
|
|
|
|
|
tensor, static_cast<T>(value));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(ctx.GetPlace());
|
|
|
|
|
math::set_constant(dev_ctx, tensor, value);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|