|
|
|
@ -19,16 +19,16 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class FillZerosLike : public framework::OperatorWithKernel {
|
|
|
|
|
class FillZerosLikeOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(
|
|
|
|
|
const std::vector<const framework::Tensor *> &inputs,
|
|
|
|
|
const std::vector<framework::Tensor *> &outputs) const override {
|
|
|
|
|
PADDLE_ENFORCE(inputs.size() == 1,
|
|
|
|
|
"Input size of FillZerosLike must be one.");
|
|
|
|
|
"Input size of FillZerosLikeOp must be one.");
|
|
|
|
|
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one.");
|
|
|
|
|
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
|
|
|
|
|
"Outputs of FillZerosLike must all be set.");
|
|
|
|
|
"Outputs of FillZerosLikeOp must all be set.");
|
|
|
|
|
outputs[0]->Resize(inputs[0]->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -44,7 +44,7 @@ public:
|
|
|
|
|
Fill up a vriable with zeros.
|
|
|
|
|
|
|
|
|
|
The output will have the same size with input.
|
|
|
|
|
)DOC")
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
@ -53,6 +53,6 @@ The output will have the same size with input.
|
|
|
|
|
REGISTER_OP(fill_zeros_like,
|
|
|
|
|
paddle::operators::FillZerosLikeOp,
|
|
|
|
|
paddle::operators::FillZerosLikeOpMaker);
|
|
|
|
|
EGISTER_OP_CPU_KERNEL(
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
fill_zeros_like,
|
|
|
|
|
paddle::operators::FillZerosLikeKernal<paddle::platform::CPUPlace, float>);
|
|
|
|
|
paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|