|
|
|
@ -34,15 +34,18 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
|
|
|
|
|
std::vector<int64_t> shape_int64(shape.size(), 0);
|
|
|
|
|
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
|
|
|
|
|
[](int a) { return static_cast<int64_t>(a); });
|
|
|
|
|
auto dims = framework::make_ddim(shape_int64);
|
|
|
|
|
auto output_dim = framework::make_ddim(shape_int64);
|
|
|
|
|
|
|
|
|
|
int dim_idx = ctx->Attrs().Get<int>("dim_idx");
|
|
|
|
|
PADDLE_ENFORCE_GE(dim_idx, 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), dim_idx);
|
|
|
|
|
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), dim_idx);
|
|
|
|
|
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
|
|
|
|
|
PADDLE_ENFORCE_GE(input_dim_idx, 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
|
|
|
|
|
|
|
|
|
|
dims[dim_idx] = ctx->GetInputDim("Input")[dim_idx];
|
|
|
|
|
ctx->SetOutputDim("Out", dims);
|
|
|
|
|
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
|
|
|
|
|
PADDLE_ENFORCE_GE(output_dim_idx, 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
|
|
|
|
|
|
|
|
|
|
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
|
|
|
|
|
ctx->SetOutputDim("Out", output_dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -69,8 +72,11 @@ class FillConstantBatchSizeLikeOpMaker
|
|
|
|
|
"(Tensor) Tensor of specified shape will be filled "
|
|
|
|
|
"with the specified value");
|
|
|
|
|
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
|
|
|
|
|
AddAttr<int>("dim_idx",
|
|
|
|
|
"(int, default 0) The index of batch size dimension")
|
|
|
|
|
AddAttr<int>("input_dim_idx",
|
|
|
|
|
"(int, default 0) the index of input's batch size dimension")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
AddAttr<int>("output_dim_idx",
|
|
|
|
|
"(int, default 0) the index of output's batch size dimension")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
AddAttr<float>("value", "(float, default 0) The value to be filled")
|
|
|
|
|
.SetDefault(0.0f);
|
|
|
|
@ -86,9 +92,10 @@ Fill up a variable with specified constant value.
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(fill_constant_batch_size_like,
|
|
|
|
|
ops::FillConstantBatchSizeLikeOp,
|
|
|
|
|
ops::FillConstantBatchSizeLikeOpMaker);
|
|
|
|
|
REGISTER_OPERATOR(fill_constant_batch_size_like,
|
|
|
|
|
ops::FillConstantBatchSizeLikeOp,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::FillConstantBatchSizeLikeOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
fill_constant_batch_size_like,
|
|
|
|
|
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
|