|
|
|
@ -53,7 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
void Make() final {
|
|
|
|
|
AddInput("Input",
|
|
|
|
|
"(Tensor) Tensor "
|
|
|
|
|
"whose input_dim_idx'th dimension specifies the batch_size");
|
|
|
|
@ -67,7 +67,11 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddAttr<int>("output_dim_idx",
|
|
|
|
|
"(int, default 0) The index of output's batch size dimension")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
Apply();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual void Apply() = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|