Improve expand as (#26290)

align expand_as op to expand.
revert-24895-update_cub
lilong12 5 years ago committed by GitHub
parent 5fdec3ed35
commit 638bbb6153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,150 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h"
#include <memory>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class ExpandAsV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2");
OP_INOUT_CHECK(ctx->HasInput("target_tensor"), "Input", "target_tensor",
"ExpandAsV2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandAsV2");
auto x_dims = ctx->GetInputDim("X");
auto target_tensor_dims = ctx->GetInputDim("target_tensor");
PADDLE_ENFORCE_GE(
target_tensor_dims.size(), static_cast<size_t>(x_dims.size()),
platform::errors::InvalidArgument(
"The rank of Input(target_tensor) must be greater than or equal "
"to the rank of Input(X). But received Input(X): input "
"rank %u, input shape [%s]; received Input(target_tensor): "
"input rank %u, input shape [%s].",
x_dims.size(), x_dims, target_tensor_dims.size(),
target_tensor_dims));
PADDLE_ENFORCE_LE(
target_tensor_dims.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of Input(target_tensor) must not be less than or equal "
"to %d. But received: input rank %u, input shape [%s].",
MAX_RANK_SUPPORTED, x_dims.size(), x_dims));
std::vector<int64_t> out_shape(target_tensor_dims.size());
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}
};
class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"The rank of Output(Out) have the same with Input(X). "
"After expanding, size of each dimension of Output(Out) is equal "
"to size of the corresponding dimension of Input(X) multiplying "
"the corresponding value given by Attr(expand_times).");
AddInput("target_tensor", "Expand tensor's shape for each dimension.");
AddComment(R"DOC(
Expand the input by given times number. You should set times
number for each dimension by providing tensor 'expend_tensor'. The rank of X
should be in [1, 6]. Please note that size of 'expend_tensor' must be the same
with X's rank. Following is a using case:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
[
[[1], [2], [3]],
[[4], [5], [6]]
]
target_tensors'shape: [2, 6, 2]
Output(Out) is a 3-D tensor with shape [2, 6, 2]:
[
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
]
)DOC");
}
};
class ExpandAsV2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2Grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "ExpandAsV2Grad");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class ExpandAsV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("expand_as_v2_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("target_tensor", this->Input("target_tensor"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandAsV2GradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_as_v2, ops::ExpandAsV2Op, ops::ExpandAsV2OpMaker,
ops::ExpandAsV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandAsV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(expand_as_v2_grad, ops::ExpandAsV2GradOp,
ops::ExpandAsV2GradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
expand_as_v2,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandAsV2Kernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
expand_as_v2_grad,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CPUDeviceContext, double>);

@ -0,0 +1,26 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
expand_as_v2,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandAsV2Kernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
expand_as_v2_grad,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>);

@ -0,0 +1,214 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#define MAX_RANK_SUPPORTED 6
#define EXPAND_AS_TEMPLATE(z, n, data) \
case n + 1: { \
ExpandAs<n + 1>(context); \
break; \
}
#define REP_EXPAND_AS_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_AS_TEMPLATE, ~)
#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define EXPAND_AS_GRAD_CASE(n) \
case n: { \
ExpandAsBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
break; \
}
#define EXPAND_AS_GRAD_TEMPLATE(z, n, data) \
BOOST_PP_IF(COND(n), EXPAND_AS_GRAD_CASE(n), )
#define REP_EXPAND_AS_GRAD_TEMPLATE(n) \
BOOST_PP_REPEAT(n, EXPAND_AS_GRAD_TEMPLATE, ~)
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class ExpandAsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rank = context.Input<Tensor>("X")->dims().size();
auto* target_tensor = context.Input<Tensor>("target_tensor");
auto target_rank = target_tensor->dims().size();
PADDLE_ENFORCE_GE(target_rank, rank,
platform::errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be greater than or equal to "
"the rank (%d) of the input 'x'.",
target_rank, rank));
PADDLE_ENFORCE_GE(rank, 1, platform::errors::InvalidArgument(
"The rank (%d) of the input 'x' for "
"expand_as_v2 op must be positive.",
rank));
PADDLE_ENFORCE_LE(target_rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank (%d) of the input 'target_tensor' for "
"expand_as_v2 op must be less than or equal to %d.",
target_rank, MAX_RANK_SUPPORTED));
switch (target_rank) { REP_EXPAND_AS_TEMPLATE(MAX_RANK_SUPPORTED) }
}
protected:
template <int Rank>
void ExpandAs(const framework::ExecutionContext& context) const {
auto* in0 = context.Input<Tensor>("X");
auto in_dims = in0->dims();
auto* target_tensor = context.Input<Tensor>("target_tensor");
auto vec_in_dims = framework::vectorize<int>(in_dims);
auto target_shape = framework::vectorize<int>(target_tensor->dims());
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(target_shape[i], 0,
platform::errors::InvalidArgument(
"The value of target shape cannot be zero."));
if (vec_in_dims[i] != 1) {
PADDLE_ENFORCE_EQ(
vec_in_dims[i], target_shape[i],
platform::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in "
"target tensor for expand_as_v2 op.",
vec_in_dims[i], target_shape[i]));
repeat_times[i] = 1;
} else {
repeat_times[i] = target_shape[i];
}
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
framework::DDim new_in_dims = framework::make_ddim(vec_in_dims);
framework::DDim out_dims = framework::make_ddim(target_shape);
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(*in0, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims);
}
};
template <typename DeviceContext, typename T>
class ExpandAsV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
auto* target_tensor = context.Input<Tensor>("target_tensor");
auto x_dims = in0->dims();
auto target_shape = target_tensor->dims();
auto vec_in_dims = framework::vectorize<int>(x_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
repeat_times[i] = target_shape[i] / vec_in_dims[i];
}
std::vector<int> reshape_dims_vec;
std::vector<int> reduce_dims_vec;
for (size_t i = 0; i < repeat_times.size(); ++i) {
reduce_dims_vec.push_back(reshape_dims_vec.size());
reshape_dims_vec.push_back(repeat_times[i]);
reshape_dims_vec.push_back(vec_in_dims[i]);
}
int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;
break;
}
}
// no need reduce, just copy
if (just_copy) {
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
out0);
} else {
PADDLE_ENFORCE_GE(dims, 1,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, dims));
switch (dims) { REP_EXPAND_AS_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) }
}
}
protected:
template <int Dims>
void ExpandAsBackward(const framework::ExecutionContext& context,
const std::vector<int>& reshape_dims_vec,
const std::vector<int>& reduce_dims_vec) const {
size_t reshape_size = reshape_dims_vec.size();
size_t reduce_size = reduce_dims_vec.size();
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
}
};
} // namespace operators
} // namespace paddle

@ -101,6 +101,7 @@ from .tensor.logic import equal_all #DEFINE_ALIAS
from .tensor.manipulation import cast #DEFINE_ALIAS
from .tensor.manipulation import concat #DEFINE_ALIAS
from .tensor.manipulation import expand #DEFINE_ALIAS
from .tensor.manipulation import broadcast_to #DEFINE_ALIAS
from .tensor.manipulation import expand_as #DEFINE_ALIAS
from .tensor.manipulation import tile #DEFINE_ALIAS
from .tensor.manipulation import flatten #DEFINE_ALIAS

@ -0,0 +1,121 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
class TestExpandAsOpRank1(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(100).astype("float64")
target_tensor = np.random.rand(2, 100).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [2, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandAsOpRank2(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(10, 12).astype("float64")
target_tensor = np.random.rand(10, 12).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandAsOpRank3(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(2, 3, 20).astype("float64")
target_tensor = np.random.rand(2, 3, 20).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [1, 1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestExpandAsOpRank4(OpTest):
def setUp(self):
self.op_type = "expand_as_v2"
x = np.random.rand(1, 1, 7, 16).astype("float64")
target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
self.inputs = {'X': x, 'target_tensor': target_tensor}
self.attrs = {}
bcast_dims = [4, 6, 1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
# Test python API
class TestExpandAPI(unittest.TestCase):
def test_api(self):
input1 = np.random.random([12, 14]).astype("float32")
input2 = np.random.random([2, 12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
y = fluid.layers.data(
name='target_tensor',
shape=[2, 12, 14],
append_batch_size=False,
dtype="float32")
out_1 = paddle.expand_as(x, y=y)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1 = exe.run(fluid.default_main_program(),
feed={"x": input1,
"target_tensor": input2},
fetch_list=[out_1])
assert np.array_equal(res_1[0], np.tile(input1, (2, 1, 1)))
if __name__ == "__main__":
unittest.main()

@ -193,7 +193,7 @@ class TestExpandV2Error(unittest.TestCase):
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, paddle.tensor.expand, x2, shape)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
x3.stop_gradient = True
x3.stop_gradient = False
self.assertRaises(ValueError, paddle.tensor.expand, x3, shape)

@ -22,7 +22,7 @@ import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
# Situation 1: repeat_times is a list(without tensor)
# Situation 1: repeat_times is a list (without tensor)
class TestTileOpRank1(OpTest):
def setUp(self):
self.op_type = "tile"
@ -81,7 +81,7 @@ class TestTileOpRank4(TestTileOpRank1):
self.repeat_times = (3, 2, 1, 2)
# Situation 2: repeat_times is a list(with tensor)
# Situation 2: repeat_times is a list (with tensor)
class TestTileOpRank1_tensor_attr(OpTest):
def setUp(self):
self.op_type = "tile"
@ -162,7 +162,7 @@ class TestTileOpInteger(OpTest):
self.op_type = "tile"
self.inputs = {
'X': np.random.randint(
10, size=(2, 4, 5)).astype("int32")
10, size=(4, 4, 5)).astype("int32")
}
self.attrs = {'repeat_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
@ -211,38 +211,30 @@ class TestTileError(unittest.TestCase):
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, paddle.tile, x2, repeat_times)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
x3.stop_gradient = True
x3.stop_gradient = False
self.assertRaises(ValueError, paddle.tile, x3, repeat_times)
# Test python API
class TestTileAPI(unittest.TestCase):
def test_api(self):
input = np.random.random([12, 14]).astype("float32")
x = fluid.layers.data(
name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
repeat_times = fluid.layers.data(
name="repeat_times", shape=[2], append_batch_size=False)
out_1 = paddle.tile(x, repeat_times=[2, 3])
out_2 = paddle.tile(x, repeat_times=[positive_2, 3])
out_3 = paddle.tile(x, repeat_times=repeat_times)
g0 = fluid.backward.calc_gradient(out_2, x)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
feed={
"x": input,
"repeat_times":
np.array([1, 3]).astype("int32")
},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.tile(input, (2, 3)))
assert np.array_equal(res_2, np.tile(input, (2, 3)))
assert np.array_equal(res_3, np.tile(input, (1, 3)))
with fluid.dygraph.guard():
np_x = np.random.random([12, 14]).astype("float32")
x = paddle.to_variable(np_x)
positive_2 = np.array([2]).astype("int32")
positive_2 = paddle.to_variable(positive_2)
repeat_times = np.array([2, 3]).astype("int32")
repeat_times = paddle.to_variable(repeat_times)
out_1 = paddle.tile(x, repeat_times=[2, 3])
out_2 = paddle.tile(x, repeat_times=[positive_2, 3])
out_3 = paddle.tile(x, repeat_times=repeat_times)
assert np.array_equal(out_1.numpy(), np.tile(np_x, (2, 3)))
assert np.array_equal(out_2.numpy(), np.tile(np_x, (2, 3)))
assert np.array_equal(out_3.numpy(), np.tile(np_x, (2, 3)))
if __name__ == "__main__":

@ -74,6 +74,7 @@ from .logic import equal_all #DEFINE_ALIAS
from .manipulation import cast #DEFINE_ALIAS
from .manipulation import concat #DEFINE_ALIAS
from .manipulation import expand #DEFINE_ALIAS
from .manipulation import broadcast_to #DEFINE_ALIAS
from .manipulation import expand_as #DEFINE_ALIAS
from .manipulation import tile #DEFINE_ALIAS
from .manipulation import flatten #DEFINE_ALIAS

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save