|
|
|
@ -38,13 +38,11 @@ struct ScaleLossGradFunctor {
|
|
|
|
|
float coeff_;
|
|
|
|
|
Tensor *out_;
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
OpHandleBase *op_handle_;
|
|
|
|
|
proto::VarType::Type out_dtype_;
|
|
|
|
|
platform::DeviceContext *ctx_;
|
|
|
|
|
|
|
|
|
|
ScaleLossGradFunctor(float coeff, Tensor *out, platform::Place place,
|
|
|
|
|
OpHandleBase *op_handle, proto::VarType::Type dtype,
|
|
|
|
|
platform::DeviceContext *ctx)
|
|
|
|
|
proto::VarType::Type dtype, platform::DeviceContext *ctx)
|
|
|
|
|
: coeff_(coeff), out_(out), place_(place), out_dtype_(dtype), ctx_(ctx) {}
|
|
|
|
|
|
|
|
|
|
template <typename OutT>
|
|
|
|
@ -76,11 +74,11 @@ void ScaleLossGradOpHandle::RunImpl() {
|
|
|
|
|
tensor->Resize(make_ddim({1}));
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_,
|
|
|
|
|
ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_,
|
|
|
|
|
this->dev_ctxes_.at(place_));
|
|
|
|
|
this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
|
|
|
|
|
#else
|
|
|
|
|
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_, nullptr);
|
|
|
|
|
ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_, nullptr);
|
|
|
|
|
framework::VisitDataType(out_dtype_, func);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|