test=develop

fix-readmd
sneaxiy 6 years ago
parent 6f748a035d
commit 84d9300365

@ -131,21 +131,21 @@ template <typename DeviceContext, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using Tensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
auto *grad_var = ctx.InputVar("Grad");
auto *param_out = ctx.Output<Tensor>("ParamOut");
auto *moment_out = ctx.Output<Tensor>("MomentOut");
auto *mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
auto *param_out = ctx.Output<LoDTensor>("ParamOut");
auto *moment_out = ctx.Output<LoDTensor>("MomentOut");
auto *mean_square_out = ctx.Output<LoDTensor>("MeanSquareOut");
auto epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto rho = static_cast<T>(ctx.Attr<float>("decay"));
auto momentum = static_cast<T>(ctx.Attr<float>("momentum"));
bool centered = ctx.Attr<bool>("centered");
auto &p_tensor = *ctx.Input<Tensor>("Param");
auto &ms_tensor = *ctx.Input<Tensor>("MeanSquare");
auto &lr_tensor = *ctx.Input<Tensor>("LearningRate");
auto &mom_tensor = *ctx.Input<Tensor>("Moment");
auto &p_tensor = *ctx.Input<LoDTensor>("Param");
auto &ms_tensor = *ctx.Input<LoDTensor>("MeanSquare");
auto &lr_tensor = *ctx.Input<LoDTensor>("LearningRate");
auto &mom_tensor = *ctx.Input<LoDTensor>("Moment");
PADDLE_ENFORCE_EQ(&p_tensor, param_out,
"Param and ParamOut must be the same Tensor");
@ -157,8 +157,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<DeviceContext>();
size_t limit = static_cast<size_t>(ms_tensor.numel());
if (grad_var->IsType<Tensor>()) {
auto &grad_tensor = grad_var->Get<Tensor>();
if (grad_var->IsType<LoDTensor>()) {
auto &grad_tensor = grad_var->Get<LoDTensor>();
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
auto &place =
@ -176,9 +176,9 @@ class RmspropOpKernel : public framework::OpKernel<T> {
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
if (centered) {
auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
@ -196,8 +196,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
if (centered) {
auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
@ -241,8 +241,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
row_numel, row_count);
if (centered) {
auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save