add expand_as op, test=develop (#20565)
* add expand_as op, test=develop
* add expand_as op, test=develop
* add expand_as op, test=develop
* add nn.py, test=develop
* delete paddle_enforce, test=develop
parent 5c41805dc9
commit 2ff18e537f
--- /dev/null
+++ b/paddle/fluid/operators/expand_as_op.cc
@@ -0,0 +1,127 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/expand_as_op.h"
#include <memory>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class ExpandAsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true);
    PADDLE_ENFORCE_EQ(ctx->HasInput("target_tensor"), true);
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true);
    auto x_dims = ctx->GetInputDim("X");
    auto target_tensor_dims = ctx->GetInputDim("target_tensor");
    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()),
                      target_tensor_dims.size(),
                      "The rank of Input(target_tensor) must be equal "
                      "to the rank of Input(X).");
    PADDLE_ENFORCE_LE(x_dims.size(), 6,
                      "The rank of Input(X) must not be greater than 6.");
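    // Note: this sets only a rank-correct placeholder (all zeros) for Out;
    // the real output shape is computed at runtime in ExpandAsKernel from
    // the shape of Input(target_tensor).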
    std::vector<int64_t> out_shape(x_dims.size());
    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
  }
};

class ExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
             "X is the input to be expanded.");
    AddOutput("Out",
              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
              "The rank of Output(Out) is the same as that of Input(X). After "
              "expanding, the size of each dimension of Output(Out) is equal "
              "to the size of the corresponding dimension of "
              "Input(target_tensor).");
    AddInput("target_tensor",
             "(Tensor) The tensor that provides, per dimension, the target "
             "size that Input(X) is expanded to.");
    AddComment(R"DOC(
ExpandAs operator tiles the input by a per-dimension expand factor. The
factors are not given as an attribute; they are inferred from the shape of
the tensor provided via Input(target_tensor), so 'target_tensor' must have
the same rank as X, and each of its dimensions must be a multiple of the
corresponding dimension of X. The rank of X should be in [1, 6]. A usage
case follows:

Input(X) is a 3-D tensor with shape [2, 3, 1]:
        [
           [[1], [2], [3]],
           [[4], [5], [6]]
        ]
target_tensor's shape: [2, 6, 2]
Output(Out) is a 3-D tensor with shape [2, 6, 2]:
        [
            [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
            [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
        ]
)DOC");
  }
};

class ExpandAsGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true);
    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true);

    auto x_dims = ctx->GetInputDim("X");
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }
};

class ExpandAsGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("expand_as_grad");
    op->SetInput("X", Input("X"));
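    // target_tensor must be forwarded to the grad op because
    // ExpandAsGradKernel recomputes bcast_dims from its shape
    // (see expand_as_op.h).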
    op->SetInput("target_tensor", Input("target_tensor"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

// DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_as, ops::ExpandAsOp, ops::ExpandAsOpMaker,
                  ops::ExpandAsGradOpDescMaker);
REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp);
REGISTER_OP_CPU_KERNEL(
    expand_as, ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
    expand_as_grad,
    ops::ExpandAsGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ExpandAsGradKernel<paddle::platform::CPUDeviceContext, double>);
--- /dev/null
+++ b/paddle/fluid/operators/expand_as_op.cu
@@ -0,0 +1,22 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_as_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    expand_as, ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
    expand_as_grad,
    ops::ExpandAsGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ExpandAsGradKernel<paddle::platform::CUDADeviceContext, double>);
--- /dev/null
+++ b/paddle/fluid/operators/expand_as_op.h
@@ -0,0 +1,185 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#define MAX_RANK_SUPPORTED 6

#define EXPAND_AS_TEMPLATE(z, n, data) \
  case n + 1: {                        \
    ExpandAs<n + 1>(context);          \
    break;                             \
  }
#define REP_EXPAND_AS_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_AS_TEMPLATE, ~)
#define COND(n)                                               \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define EXPAND_AS_GRAD_CASE(n)                                       \
  case n: {                                                          \
    ExpandAsBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                           \
  }
#define EXPAND_AS_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_AS_GRAD_CASE(n), )
#define REP_EXPAND_AS_GRAD_TEMPLATE(n) \
  BOOST_PP_REPEAT(n, EXPAND_AS_GRAD_TEMPLATE, ~)
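
// A note on the generated switch cases (derived from the code below):
// ExpandAsGradKernel folds the pair (reshape_size, reduce_size) into a
// single switch value
//   dims = reshape_size * MAX_RANK_SUPPORTED + reduce_size
//          - MAX_RANK_SUPPORTED - 1
//        = (reshape_size - 1) * 6 + (reduce_size - 1),
// and ExpandAsBackward decodes it as reshape_size = Dims / 6 + 1 and
// reduce_size = Dims % 6 + 1. For example, reshape_size = 3 and
// reduce_size = 2 give dims = 13, which decodes back to 13 / 6 + 1 = 3 and
// 13 % 6 + 1 = 2. COND(n) keeps only the cases with
// reshape_size >= reduce_size, since at most one reduce axis is recorded
// per reshape axis.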

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class ExpandAsKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto rank = context.Input<Tensor>("X")->dims().size();
    switch (rank) {
      REP_EXPAND_AS_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_THROW(
            "Only tensors with a rank between 1 and 6 are supported.");
    }
  }

 protected:
  template <int Rank>
  void ExpandAs(const framework::ExecutionContext& context) const {
    auto* in0 = context.Input<Tensor>("X");
    auto in_dims = in0->dims();
    auto* target_tensor = context.Input<Tensor>("target_tensor");
    auto* out0 = context.Output<Tensor>("Out");
    Eigen::DSizes<int, Rank> bcast_dims;
    int bcast_dims_remainder = 0;
    auto x_dims = in0->dims();
    auto y_dims = target_tensor->dims();
    for (int i = 0; i < y_dims.size(); ++i) {
      PADDLE_ENFORCE_NE(x_dims[i], 0,
                        "Input(X) must not have a dimension of size 0.");
      bcast_dims[i] = y_dims[i] / x_dims[i];
      bcast_dims_remainder += y_dims[i] % x_dims[i];
    }
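    // e.g. for the DOC example above: X has shape [2, 3, 1] and
    // target_tensor has shape [2, 6, 2], so bcast_dims is {1, 2, 2} and
    // every remainder is 0.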
    PADDLE_ENFORCE_EQ(bcast_dims_remainder, 0,
                      "Input(X) cannot be broadcast to the shape of "
                      "Input(target_tensor); every dimension of target_tensor "
                      "must be a multiple of the corresponding one of X.");
    framework::DDim out_dims(in_dims);
    for (size_t i = 0; i < bcast_dims.size(); ++i) {
      out_dims[i] *= bcast_dims[i];
    }

    out0->Resize(out_dims);
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    y.device(place) = x.broadcast(bcast_dims);
  }
};

template <typename DeviceContext, typename T>
class ExpandAsGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* target_tensor = context.Input<Tensor>("target_tensor");
    auto x_dims = in0->dims();
    auto y_dims = target_tensor->dims();
    std::vector<int> bcast_dims;
    for (int i = 0; i < y_dims.size(); ++i) {
      bcast_dims.push_back(y_dims[i] / x_dims[i]);
    }
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < bcast_dims.size(); ++i) {
      if (bcast_dims[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(bcast_dims[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(bcast_dims[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }
    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
    // No reduction is needed when nothing was tiled; just copy the
    // gradient through.
    if (reduce_dims_vec.size() == 0) {
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
      out0->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
                            out0);
    } else {
      switch (dims) {
        REP_EXPAND_AS_GRAD_TEMPLATE(72)
        default:
          PADDLE_THROW(
              "Only tensors with a rank between 1 and 6 are supported.");
      }
    }
  }

 protected:
  template <int Dims>
  void ExpandAsBackward(const framework::ExecutionContext& context,
                        const std::vector<int>& reshape_dims_vec,
                        const std::vector<int>& reduce_dims_vec) const {
    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reduce dimensions.");
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
    x_grad.device(
        *context.template device_context<DeviceContext>().eigen_device()) =
        out_grad.reshape(reshape_dims)
            .sum(reduce_dims)
            .reshape(x_grad.dimensions());
  }
};

}  // namespace operators
}  // namespace paddle
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_expand_as_op.py
@@ -0,0 +1,130 @@
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid


def bcast(x, target_tensor):
    x_dims = x.shape
    y_dims = target_tensor.shape
    bcast_dims = []
    for i in range(len(x_dims)):
        bcast_dims.append(int(y_dims[i] / x_dims[i]))
    bcast_dims = np.array(bcast_dims).astype("int64")
    return bcast_dims
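

# An illustrative NumPy reference (not part of the operator or these tests):
# the forward pass tiles x by bcast(x, target_tensor), and the backward pass
# undoes the tiling by reshaping each tiled axis into (repeat, original) and
# summing over the repeat axes, mirroring the reshape-and-sum trick in
# ExpandAsGradKernel. The helper names expand_as_ref and expand_as_grad_ref
# are made up here for the sketch.
def expand_as_ref(x, target_tensor):
    return np.tile(x, bcast(x, target_tensor))


def expand_as_grad_ref(x, target_tensor, dout):
    reshape_dims = []
    reduce_axes = []
    for repeat, dim in zip(bcast(x, target_tensor), x.shape):
        if repeat > 1:
            reduce_axes.append(len(reshape_dims))
            reshape_dims.append(repeat)
        reshape_dims.append(dim)
    grad = dout.reshape(reshape_dims).sum(axis=tuple(reduce_axes))
    return grad.reshape(x.shape)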


class TestExpandAsOpRank1(OpTest):
    def setUp(self):
        self.op_type = "expand_as"
        x = np.random.rand(12).astype("float64")
        target_tensor = np.random.rand(24).astype("float64")
        self.inputs = {'X': x, 'target_tensor': target_tensor}
        self.attrs = {}
        bcast_dims = bcast(x, target_tensor)
        output = np.tile(self.inputs['X'], bcast_dims)
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestExpandAsOpRank2(OpTest):
    def setUp(self):
        self.op_type = "expand_as"
        x = np.random.rand(2, 3).astype("float64")
        target_tensor = np.random.rand(4, 6).astype("float64")
        self.inputs = {'X': x, 'target_tensor': target_tensor}
        self.attrs = {}
        bcast_dims = bcast(x, target_tensor)
        output = np.tile(self.inputs['X'], bcast_dims)
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestExpandAsOpRank3(OpTest):
    def setUp(self):
        self.op_type = "expand_as"
        x = np.random.rand(2, 3, 3).astype("float64")
        target_tensor = np.random.rand(4, 6, 6).astype("float64")
        self.inputs = {'X': x, 'target_tensor': target_tensor}
        self.attrs = {}
        bcast_dims = bcast(x, target_tensor)
        output = np.tile(self.inputs['X'], bcast_dims)
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestExpandAsOpRank4(OpTest):
    def setUp(self):
        self.op_type = "expand_as"
        x = np.random.rand(1, 1, 3, 16).astype("float64")
        target_tensor = np.random.rand(4, 6, 6, 32).astype("float64")
        self.inputs = {'X': x, 'target_tensor': target_tensor}
        self.attrs = {}
        bcast_dims = bcast(x, target_tensor)
        output = np.tile(self.inputs['X'], bcast_dims)
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


# Test python API
class TestExpandAPI(OpTest):
    def test_api(self):
        input1 = np.random.random([12, 14]).astype("float32")
        input2 = np.random.random([48, 14]).astype("float32")
        x = fluid.layers.data(
            name='x', shape=[12, 14], append_batch_size=False, dtype="float32")

        y = fluid.layers.data(
            name='target_tensor',
            shape=[48, 14],
            append_batch_size=False,
            dtype="float32")

        out_1 = fluid.layers.expand_as(x, target_tensor=y)

        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1 = exe.run(fluid.default_main_program(),
                        feed={"x": input1,
                              "target_tensor": input2},
                        fetch_list=[out_1])
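        # 48 / 12 == 4 and 14 / 14 == 1, so expand_as tiles x four times
        # along axis 0, which is exactly np.tile(input1, (4, 1)).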
        assert np.array_equal(res_1[0], np.tile(input1, (4, 1)))


if __name__ == "__main__":
    unittest.main()