|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/math/math_function.h"
|
|
|
|
|
#include "paddle/framework/data_type.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -233,6 +234,52 @@ void gemv<platform::CPUPlace, double>(const platform::DeviceContext& context,
|
|
|
|
|
|
|
|
|
|
template struct SetConstant<platform::CPUPlace, float>;
|
|
|
|
|
|
|
|
|
|
struct TensorSetConstant {
|
|
|
|
|
TensorSetConstant(framework::Tensor* tensor, float value)
|
|
|
|
|
: tensor_(tensor), value_(value) {}
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()() const {
|
|
|
|
|
auto cpu = platform::CPUPlace();
|
|
|
|
|
auto* begin = tensor_->mutable_data<T>(cpu);
|
|
|
|
|
std::fill(begin, begin + tensor_->numel(), static_cast<T>(value_));
|
|
|
|
|
}
|
|
|
|
|
framework::Tensor* tensor_;
|
|
|
|
|
float value_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void set_constant_with_place<platform::CPUPlace>(
|
|
|
|
|
const platform::DeviceContext& context, framework::Tensor* tensor,
|
|
|
|
|
float value) {
|
|
|
|
|
framework::VisitDataType(framework::ToDataType(tensor->type()),
|
|
|
|
|
TensorSetConstant(tensor, value));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
|
|
|
|
|
TensorSetConstantWithPlace(const platform::DeviceContext& context,
|
|
|
|
|
framework::Tensor* tensor, float value)
|
|
|
|
|
: context_(context), tensor_(tensor), value_(value) {}
|
|
|
|
|
|
|
|
|
|
template <typename Place>
|
|
|
|
|
void operator()(Place place) const {
|
|
|
|
|
set_constant_with_place<Place>(context_, tensor_, value_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const platform::DeviceContext& context_;
|
|
|
|
|
framework::Tensor* tensor_;
|
|
|
|
|
float value_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void set_constant(const platform::DeviceContext& context,
|
|
|
|
|
framework::Tensor* tensor, float value) {
|
|
|
|
|
TensorSetConstantWithPlace func(context, tensor, value);
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
tensor->place().apply_visitor(func);
|
|
|
|
|
#else
|
|
|
|
|
func(platform::CPUPlace());
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|