|
|
@ -127,12 +127,6 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
AddOutput("Out", "(Tensor). The output tensor of reshape operator.");
|
|
|
|
AddOutput("Out", "(Tensor). The output tensor of reshape operator.");
|
|
|
|
AddAttr<std::vector<int>>(
|
|
|
|
AddAttr<std::vector<int>>(
|
|
|
|
"shape", "(std::vector<int>) Target shape of reshape operator.");
|
|
|
|
"shape", "(std::vector<int>) Target shape of reshape operator.");
|
|
|
|
AddAttr<bool>("inplace",
|
|
|
|
|
|
|
|
"(default: false) Change the source tensor's shape without "
|
|
|
|
|
|
|
|
"memory copy. When Attr(inplace) is set true, the output "
|
|
|
|
|
|
|
|
"tensor shares memory with Input(X), otherwise, a new output "
|
|
|
|
|
|
|
|
"tensor is created, and its data are copied from Input(x).")
|
|
|
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
AddComment(R"DOC(
|
|
|
|
Reshape Operator.
|
|
|
|
Reshape Operator.
|
|
|
|
|
|
|
|
|
|
|
@ -233,16 +227,9 @@ class ReshapeKernel {
|
|
|
|
"sequence_reshape op.");
|
|
|
|
"sequence_reshape op.");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool inplace = ctx.Attr<bool>("inplace");
|
|
|
|
out->mutable_data(ctx.GetPlace(), in->type());
|
|
|
|
|
|
|
|
framework::TensorCopySync(*in, ctx.GetPlace(), out);
|
|
|
|
out->Resize(out_dims);
|
|
|
|
out->Resize(out_dims);
|
|
|
|
if (!inplace) {
|
|
|
|
|
|
|
|
out->mutable_data(ctx.GetPlace(), in->type());
|
|
|
|
|
|
|
|
framework::TensorCopySync(*in, ctx.GetPlace(), out);
|
|
|
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
out->ShareDataWith(*in);
|
|
|
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -251,19 +238,11 @@ class ReshapeGradKernel {
|
|
|
|
void operator()(const framework::ExecutionContext &ctx) const {
|
|
|
|
void operator()(const framework::ExecutionContext &ctx) const {
|
|
|
|
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
|
|
|
auto in_dims = d_x->dims();
|
|
|
|
|
|
|
|
|
|
|
|
d_x->mutable_data(ctx.GetPlace(), d_out->type());
|
|
|
|
d_x->mutable_data(ctx.GetPlace(), d_out->type());
|
|
|
|
bool inplace = ctx.Attr<bool>("inplace");
|
|
|
|
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
|
|
|
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
auto in_dims = d_x->dims();
|
|
|
|
|
|
|
|
if (!inplace) {
|
|
|
|
|
|
|
|
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
|
|
|
|
|
|
|
|
ctx.device_context().Wait();
|
|
|
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
d_x->ShareDataWith(*d_out);
|
|
|
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|