|  |  |  | @ -37,12 +37,38 @@ class UnsqueezeOp : public framework::OperatorWithKernel { | 
			
		
	
		
			
				
					|  |  |  |  |     PADDLE_ENFORCE_LE(x_dims.size(), 6, | 
			
		
	
		
			
				
					|  |  |  |  |                       "Invalid dimensions, the rank of Input(X) " | 
			
		
	
		
			
				
					|  |  |  |  |                       "should be in the range of [1, 6] (Eigen limit)"); | 
			
		
	
		
			
				
					|  |  |  |  |     auto out_dims = GetOutputShape(axes, x_dims); | 
			
		
	
		
			
				
					|  |  |  |  |     ctx->SetOutputDim("Out", out_dims); | 
			
		
	
		
			
				
					|  |  |  |  |     if (x_dims[0] == out_dims[0]) { | 
			
		
	
		
			
				
					|  |  |  |  |       // Only pass LoD when the first dimension of output and Input(X)
 | 
			
		
	
		
			
				
					|  |  |  |  |       // are the same.
 | 
			
		
	
		
			
				
					|  |  |  |  |       ctx->ShareLoD("X", "Out"); | 
			
		
	
		
			
				
					|  |  |  |  |     if (!axes.empty()) { | 
			
		
	
		
			
				
					|  |  |  |  |       auto out_dims = GetOutputShape(axes, x_dims); | 
			
		
	
		
			
				
					|  |  |  |  |       ctx->SetOutputDim("Out", out_dims); | 
			
		
	
		
			
				
					|  |  |  |  |       if (x_dims[0] == out_dims[0]) { | 
			
		
	
		
			
				
					|  |  |  |  |         // Only pass LoD when the first dimension of output and Input(X)
 | 
			
		
	
		
			
				
					|  |  |  |  |         // are the same.
 | 
			
		
	
		
			
				
					|  |  |  |  |         ctx->ShareLoD("X", "Out"); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } else if (ctx->HasInputs("AxesTensorList")) { | 
			
		
	
		
			
				
					|  |  |  |  |       auto AxesTensorList = ctx->Inputs("AxesTensorList"); | 
			
		
	
		
			
				
					|  |  |  |  |       int output_size = x_dims.size() + static_cast<int>(AxesTensorList.size()); | 
			
		
	
		
			
				
					|  |  |  |  |       PADDLE_ENFORCE_LE(output_size, 6, | 
			
		
	
		
			
				
					|  |  |  |  |                         "The output tensor's rank should be less than 6."); | 
			
		
	
		
			
				
					|  |  |  |  |       std::vector<int> vec_out_dims(output_size, -1); | 
			
		
	
		
			
				
					|  |  |  |  |       ctx->SetOutputDim("Out", framework::make_ddim(vec_out_dims)); | 
			
		
	
		
			
				
					|  |  |  |  |     } else if (ctx->HasInput("AxesTensor")) { | 
			
		
	
		
			
				
					|  |  |  |  |       auto axes_dims = ctx->GetInputDim("AxesTensor"); | 
			
		
	
		
			
				
					|  |  |  |  |       PADDLE_ENFORCE_EQ( | 
			
		
	
		
			
				
					|  |  |  |  |           axes_dims.size(), 1, | 
			
		
	
		
			
				
					|  |  |  |  |           "Input(AxesTensor)'s dimension of Op(unsqueeze) must be 1. " | 
			
		
	
		
			
				
					|  |  |  |  |           "But received AxesTensor's shape = [%s], " | 
			
		
	
		
			
				
					|  |  |  |  |           "AxesTensor's dimension = %d.", | 
			
		
	
		
			
				
					|  |  |  |  |           axes_dims, axes_dims.size()); | 
			
		
	
		
			
				
					|  |  |  |  |       PADDLE_ENFORCE_GE(axes_dims[0], 0, | 
			
		
	
		
			
				
					|  |  |  |  |                         "Input(AxesTensor)'s shape must be known. But received " | 
			
		
	
		
			
				
					|  |  |  |  |                         "AxesTensor's shape = [%s]", | 
			
		
	
		
			
				
					|  |  |  |  |                         axes_dims); | 
			
		
	
		
			
				
					|  |  |  |  |       int output_size = x_dims.size() + static_cast<int>(axes_dims[0]); | 
			
		
	
		
			
				
					|  |  |  |  |       PADDLE_ENFORCE_LE(output_size, 6, | 
			
		
	
		
			
				
					|  |  |  |  |                         "The output tensor's rank should be less than 6."); | 
			
		
	
		
			
				
					|  |  |  |  |       std::vector<int> vec_out_dims(output_size, -1); | 
			
		
	
		
			
				
					|  |  |  |  |       ctx->SetOutputDim("Out", framework::make_ddim(vec_out_dims)); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -83,19 +109,46 @@ class UnsqueezeOp : public framework::OperatorWithKernel { | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     return framework::make_ddim(output_shape); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |  protected: | 
			
		
	
		
			
				
					|  |  |  |  |   framework::OpKernelType GetExpectedKernelType( | 
			
		
	
		
			
				
					|  |  |  |  |       const framework::ExecutionContext &ctx) const override { | 
			
		
	
		
			
				
					|  |  |  |  |     return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), | 
			
		
	
		
			
				
					|  |  |  |  |                                    ctx.device_context()); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   framework::OpKernelType GetKernelTypeForVar( | 
			
		
	
		
			
				
					|  |  |  |  |       const std::string &var_name, const framework::Tensor &tensor, | 
			
		
	
		
			
				
					|  |  |  |  |       const framework::OpKernelType &expected_kernel_type) const override { | 
			
		
	
		
			
				
					|  |  |  |  |     if (var_name == "AxesTensor" || var_name == "AxesTensorList") { | 
			
		
	
		
			
				
					|  |  |  |  |       return expected_kernel_type; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     return framework::OpKernelType(expected_kernel_type.data_type_, | 
			
		
	
		
			
				
					|  |  |  |  |                                    tensor.place(), tensor.layout()); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | }; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { | 
			
		
	
		
			
				
					|  |  |  |  |  public: | 
			
		
	
		
			
				
					|  |  |  |  |   void Make() override { | 
			
		
	
		
			
				
					|  |  |  |  |     AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); | 
			
		
	
		
			
				
					|  |  |  |  |     AddInput("AxesTensor", | 
			
		
	
		
			
				
					|  |  |  |  |              "(Tensor<int32>, optional). The dimensions to be inserted. " | 
			
		
	
		
			
				
					|  |  |  |  |              "If it exists, it will replace Attr(axes).") | 
			
		
	
		
			
				
					|  |  |  |  |         .AsDispensable(); | 
			
		
	
		
			
				
					|  |  |  |  |     AddInput( | 
			
		
	
		
			
				
					|  |  |  |  |         "AxesTensorList", | 
			
		
	
		
			
				
					|  |  |  |  |         "(vector<Tensor<int32>>, optional). The dimensions to be inserted. " | 
			
		
	
		
			
				
					|  |  |  |  |         "If it exists, it will replace Attr(axes)." | 
			
		
	
		
			
				
					|  |  |  |  |         "The shape of the element in vector must be [1].") | 
			
		
	
		
			
				
					|  |  |  |  |         .AsDuplicable() | 
			
		
	
		
			
				
					|  |  |  |  |         .AsDispensable(); | 
			
		
	
		
			
				
					|  |  |  |  |     AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); | 
			
		
	
		
			
				
					|  |  |  |  |     AddAttr<std::vector<int>>("axes", | 
			
		
	
		
			
				
					|  |  |  |  |                               "(std::vector<int>). List of integers," | 
			
		
	
		
			
				
					|  |  |  |  |                               " indicating the dimensions to be inserted") | 
			
		
	
		
			
				
					|  |  |  |  |         .SetDefault({}) | 
			
		
	
		
			
				
					|  |  |  |  |         .AddCustomChecker([](const std::vector<int> &axes) { | 
			
		
	
		
			
				
					|  |  |  |  |           PADDLE_ENFORCE_EQ(!axes.empty(), true, | 
			
		
	
		
			
				
					|  |  |  |  |                             "Invalid axes, The unsqueeze axes is empty."); | 
			
		
	
		
			
				
					|  |  |  |  |           // Validity Check: axes dims (<6).
 | 
			
		
	
		
			
				
					|  |  |  |  |           PADDLE_ENFORCE_LT(static_cast<int>(axes.size()), 6, | 
			
		
	
		
			
				
					|  |  |  |  |                             "Invalid dimensions, dynamic dimensions should be " | 
			
		
	
	
		
			
				
					|  |  |  | @ -136,28 +189,12 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel { | 
			
		
	
		
			
				
					|  |  |  |  | // will be used in unsqueeze_grad, in this way, the framework can reuse
 | 
			
		
	
		
			
				
					|  |  |  |  | // the memory of X immediately the unsqueeze2_op is finished.
 | 
			
		
	
		
			
				
					|  |  |  |  | // Considering compatibility issues, we could not fix unsqueeze2_op
 | 
			
		
	
		
			
				
					|  |  |  |  | class Unsqueeze2Op : public framework::OperatorWithKernel { | 
			
		
	
		
			
				
					|  |  |  |  | class Unsqueeze2Op : public UnsqueezeOp { | 
			
		
	
		
			
				
					|  |  |  |  |  public: | 
			
		
	
		
			
				
					|  |  |  |  |   using framework::OperatorWithKernel::OperatorWithKernel; | 
			
		
	
		
			
				
					|  |  |  |  |   using UnsqueezeOp::UnsqueezeOp; | 
			
		
	
		
			
				
					|  |  |  |  |   void InferShape(framework::InferShapeContext *ctx) const override { | 
			
		
	
		
			
				
					|  |  |  |  |     PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, | 
			
		
	
		
			
				
					|  |  |  |  |                       "Input(X) of Unsqueeze operator should not be null."); | 
			
		
	
		
			
				
					|  |  |  |  |     PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, | 
			
		
	
		
			
				
					|  |  |  |  |                       "Output(Out) of Unsqueeze operator should not be null."); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes"); | 
			
		
	
		
			
				
					|  |  |  |  |     UnsqueezeOp::InferShape(ctx); | 
			
		
	
		
			
				
					|  |  |  |  |     const auto &x_dims = ctx->GetInputDim("X"); | 
			
		
	
		
			
				
					|  |  |  |  |     // Validity Check: input tensor dims (<6).
 | 
			
		
	
		
			
				
					|  |  |  |  |     PADDLE_ENFORCE_LE(x_dims.size(), 6, | 
			
		
	
		
			
				
					|  |  |  |  |                       "Invalid dimensions, the rank of Input(X) " | 
			
		
	
		
			
				
					|  |  |  |  |                       "should be in the range of [1, 6] (Eigen limit)"); | 
			
		
	
		
			
				
					|  |  |  |  |     auto out_dims = UnsqueezeOp::GetOutputShape(axes, x_dims); | 
			
		
	
		
			
				
					|  |  |  |  |     ctx->SetOutputDim("Out", out_dims); | 
			
		
	
		
			
				
					|  |  |  |  |     if (x_dims[0] == out_dims[0]) { | 
			
		
	
		
			
				
					|  |  |  |  |       // Only pass LoD when the first dimension of output and Input(X)
 | 
			
		
	
		
			
				
					|  |  |  |  |       // are the same.
 | 
			
		
	
		
			
				
					|  |  |  |  |       ctx->ShareLoD("X", "Out"); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     PADDLE_ENFORCE_EQ( | 
			
		
	
		
			
				
					|  |  |  |  |         ctx->HasOutput("XShape"), true, | 
			
		
	
	
		
			
				
					|  |  |  | @ -252,12 +289,11 @@ REGISTER_OP_CPU_KERNEL( | 
			
		
	
		
			
				
					|  |  |  |  |     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>); | 
			
		
	
		
			
				
					|  |  |  |  | REGISTER_OP_CPU_KERNEL( | 
			
		
	
		
			
				
					|  |  |  |  |     unsqueeze2, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, float>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, double>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, int>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::Unsqueeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>); | 
			
		
	
		
			
				
					|  |  |  |  |     unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>); | 
			
		
	
		
			
				
					|  |  |  |  | REGISTER_OP_CPU_KERNEL( | 
			
		
	
		
			
				
					|  |  |  |  |     unsqueeze2_grad, | 
			
		
	
		
			
				
					|  |  |  |  |     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>, | 
			
		
	
	
		
			
				
					|  |  |  | 
 |