|
|
|
@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto param = ctx.Input<Tensor>("param");
|
|
|
|
|
auto grad = ctx.Input<Tensor>("grad");
|
|
|
|
|
auto param_out = ctx.Output<Tensor>(0);
|
|
|
|
|
auto param_out = ctx.Output<Tensor>("param_out");
|
|
|
|
|
float lr = ctx.op_.GetAttr<float>("learning_rate");
|
|
|
|
|
|
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|