small AverageOptimizer enhance. (#11761)

* small AverageOptimizer enhance.

* clean

* clean
ce-debug
Xin Pan 7 years ago committed by whs
parent 19e877ffdb
commit 2ecc56226d

@ -19,28 +19,28 @@ namespace operators {
template <>
void GetAccumulators<paddle::platform::CPUDeviceContext>(
const framework::ExecutionContext& ctx, int64_t* num_updates_,
int64_t* num_accumulates_, int64_t* old_num_accumulates_) {
const framework::ExecutionContext& ctx, int64_t* num_updates,
int64_t* num_accumulates, int64_t* old_num_accumulates) {
auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
*old_num_accumulates_ = in_old_num_accumulates->data<int64_t>()[0];
*num_accumulates_ = in_num_accumulates->data<int64_t>()[0];
*num_updates_ = in_num_updates->data<int64_t>()[0];
*old_num_accumulates = in_old_num_accumulates->data<int64_t>()[0];
*num_accumulates = in_num_accumulates->data<int64_t>()[0];
*num_updates = in_num_updates->data<int64_t>()[0];
}
template <>
void SetAccumulators<paddle::platform::CPUDeviceContext>(
const framework::ExecutionContext& ctx, int64_t num_updates_,
int64_t num_accumulates_, int64_t old_num_accumulates_) {
const framework::ExecutionContext& ctx, int64_t num_updates,
int64_t num_accumulates, int64_t old_num_accumulates) {
auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates_;
out_num_accumulates->data<int64_t>()[0] = num_accumulates_;
out_num_updates->data<int64_t>()[0] = num_updates_;
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates;
out_num_accumulates->data<int64_t>()[0] = num_accumulates;
out_num_updates->data<int64_t>()[0] = num_updates;
}
class AverageAccumulatesOp : public framework::OperatorWithKernel {
@ -177,7 +177,7 @@ class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
AverageAccumulates Operator.
Accumulate the sum of parameter whtin sliding window. The size of sliding window is
Accumulate the sum of parameter within sliding window. The size of sliding window is
determined by 'average_window', 'max_average_window' and 'min_average_window'.
Memory was shared by Input(in_sum_1) and Output(out_sum_1) which acts as an accumulator 'sum_1'.
'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' were the same as 'sum_1'.

@ -54,8 +54,9 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
float average_window = ctx.Attr<float>("average_window");
int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
min_average_window =
std::min<int64_t>(min_average_window, max_average_window);
PADDLE_ENFORCE_LE(min_average_window, max_average_window,
"min_average_window shouldn't be larger than "
"max_average_window");
// Get inputs
auto* param = ctx.Input<Tensor>("param");

Loading…
Cancel
Save