|
|
@ -54,9 +54,13 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
|
|
|
|
float average_window = ctx.Attr<float>("average_window");
|
|
|
|
float average_window = ctx.Attr<float>("average_window");
|
|
|
|
int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
|
|
|
|
int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
|
|
|
|
int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
|
|
|
|
int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
|
|
|
|
PADDLE_ENFORCE_LE(min_average_window, max_average_window,
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
"min_average_window shouldn't be larger than "
|
|
|
|
min_average_window, max_average_window,
|
|
|
|
"max_average_window");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The min_average_window > "
|
|
|
|
|
|
|
|
"max_average_window is not right, min_average_window is %ld, "
|
|
|
|
|
|
|
|
"max_average_window is %ld.",
|
|
|
|
|
|
|
|
min_average_window, max_average_window));
|
|
|
|
|
|
|
|
|
|
|
|
// Get inputs
|
|
|
|
// Get inputs
|
|
|
|
auto* param = ctx.Input<Tensor>("param");
|
|
|
|
auto* param = ctx.Input<Tensor>("param");
|
|
|
|