|
|
@ -114,7 +114,12 @@ class ExpandGradOp : public framework::OperatorWithKernel {
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("expand_times");
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("expand_times");
|
|
|
|
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < expand_times.size(); ++i) {
|
|
|
|
size_t start_pos = 0u;
|
|
|
|
|
|
|
|
if (!ctx->IsRuntime()) {
|
|
|
|
|
|
|
|
start_pos = 1u;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = start_pos; i < expand_times.size(); ++i) {
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
|
|
|
|
"Each dimension size of Input(Out@GRAD) should be "
|
|
|
|
"Each dimension size of Input(Out@GRAD) should be "
|
|
|
|
"equal to multiplication of crroresponding dimension "
|
|
|
|
"equal to multiplication of crroresponding dimension "
|
|
|
|