Merge remote-tracking branch 'ups/develop' into feature/op/fusion_expand_concat_fc

createGenDocLib
tensor-tang 7 years ago
commit a481c5e98c

File diff suppressed because one or more lines are too long

@ -64,6 +64,7 @@ static DataTypeMap* InitDataTypeMap() {
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
RegType(int8_t, proto::VarType::INT8);
#undef RegType
return retv;

@ -54,6 +54,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
case proto::VarType::INT16:
visitor.template operator()<int16_t>();
break;
case proto::VarType::INT8:
visitor.template operator()<int8_t>();
break;
default:
PADDLE_THROW("Not supported %d", type);
}

@ -107,6 +107,7 @@ message VarType {
// Tensor<size_t> is used in C++.
SIZE_T = 19;
UINT8 = 20;
INT8 = 21;
// Other types that may need additional descriptions
LOD_TENSOR = 7;

@ -177,6 +177,9 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
out_row, out_col, output->data<T>());
}
// Wait() must be called because `inputs_data` may be destructed before
// kernel ends
context.Wait();
}
};
@ -252,6 +255,9 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
input.data<T>(), in_row, in_col, dev_outs_col_data,
static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
}
// Wait() must be called because `outputs_data` may be destructed before
// kernel ends
context.Wait();
}
};

@ -41,7 +41,8 @@ template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);

@ -33,10 +33,11 @@ template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>;
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);

@ -0,0 +1,124 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename DeviceContext, typename T, size_t D>
void PadFunction(const framework::ExecutionContext& context,
const std::vector<int>& pads, const framework::Tensor& src,
T pad_value, framework::Tensor* out) {
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = pads[i * 2];
paddings[i].second = pads[i * 2 + 1];
}
auto src_tensor = EigenTensor<T, D>::From(src);
auto out_tensor = EigenTensor<T, D>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_tensor.device(place) = src_tensor.pad(paddings, pad_value);
}
template <typename DeviceContext, typename T, size_t D>
void PadGradFunction(const framework::ExecutionContext& context,
const std::vector<int>& pads, const framework::Tensor& src,
framework::Tensor* d_out) {
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = -pads[i * 2];
paddings[i].second = -pads[i * 2 + 1];
}
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
auto src_tensor = EigenTensor<T, D>::From(src);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
d_out_tensor.device(place) = src_tensor.pad(paddings, 0);
}
template <typename DeviceContext, typename T>
void PaddingFunctor(int rank, const framework::ExecutionContext& context,
const std::vector<int>& pads, T pad_value,
const framework::Tensor& src, framework::Tensor* out) {
switch (rank) {
case 1:
PadFunction<DeviceContext, T, 1>(context, pads, src, pad_value, out);
break;
case 2:
PadFunction<DeviceContext, T, 2>(context, pads, src, pad_value, out);
break;
case 3:
PadFunction<DeviceContext, T, 3>(context, pads, src, pad_value, out);
break;
case 4:
PadFunction<DeviceContext, T, 4>(context, pads, src, pad_value, out);
break;
case 5:
PadFunction<DeviceContext, T, 5>(context, pads, src, pad_value, out);
break;
case 6:
PadFunction<DeviceContext, T, 6>(context, pads, src, pad_value, out);
break;
default:
PADDLE_THROW(
"PadOp only support tensors with no more than 6 dimensions.");
}
}
template <typename DeviceContext, typename T>
void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
const std::vector<int>& pads,
const framework::Tensor& src, framework::Tensor* out) {
switch (rank) {
case 1:
PadGradFunction<DeviceContext, T, 1>(context, pads, src, out);
break;
case 2:
PadGradFunction<DeviceContext, T, 2>(context, pads, src, out);
break;
case 3:
PadGradFunction<DeviceContext, T, 3>(context, pads, src, out);
break;
case 4:
PadGradFunction<DeviceContext, T, 4>(context, pads, src, out);
break;
case 5:
PadGradFunction<DeviceContext, T, 5>(context, pads, src, out);
break;
case 6:
PadGradFunction<DeviceContext, T, 6>(context, pads, src, out);
break;
default:
PADDLE_THROW(
"PadOp only support tensors with no more than 6 dimensions.");
}
}
} // namespace math
} // namespace operators
} // namespace paddle

@ -0,0 +1,196 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pad_constant_like_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class PadConstantLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of PadConstantLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of PadConstantLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PadConstantLikeOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dim.size(), y_dim.size(),
"The dimention of X and Y should be the same.");
for (int i = 0; i < x_dim.size(); ++i) {
PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]);
}
ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class PadConstantLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input of pad_constant_like op. "
"The input should be a k-D tensor(k > 0 and k < 7)");
AddInput("Y",
"The input of pad_constant_like op. "
"The input should be a k-D tensor(k > 0 and k < 7)");
AddOutput("Out",
"The output of pad_constant_like op. "
"A tensor with the same shape as X.");
AddAttr<float>("pad_value",
"(float, default 0.0) "
"The value to fill the padded areas.")
.SetDefault(0.0f);
AddComment(R"DOC(
PadConstantLikeOp Operator.
Pad input(Y) with a pad_value, the number of values padded to the edges of each
axis is specified by the difference of the shape of X and Y.
((0, shape_x_0 - shape_y_0), (0, shape_x_n - shape_y_n)) unique pad widths for
each axis.
The input should be a k-D tensor(k > 0 and k < 7). As an example:
case1:
Given:
X = [[1, 2],
[3, 4],
[1, 2],
[3, 4]]],
X.shape = (4, 2)
Y = [[5, 6],
[7, 8]],
Y.shape = (2, 2)
And
pad_value = 0,
Return:
Out = [[5, 6],
[7, 8],
[0, 0],
[0, 0]]
Out.shape = (4, 2)
case2:
Given:
X = [[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]],
[[[18, 19, 20],
[21, 22, 23]],
[[24, 25, 26],
[27, 28, 29]],
[[30, 31, 32],
[33, 34, 35]]]]
X.shape = (2, 3, 2, 3)
Y = [[[[35, 36, 37]],
[[38, 39, 40]],
[[41, 42, 43]]]]
Y.shape = (1, 3, 1, 3)
And
pad_value = -1,
Return:
Out = [[[[35, 36, 37],
[-1, -1, -1]],
[[38, 39, 40],
[-1, -1, -1]],
[[41, 42, 43],
[-1, -1, -1]]],
[[[-1, -1, -1],
[-1, -1, -1]],
[[-1, -1, -1],
[-1, -1, -1]],
[[-1, -1, -1],
[-1, -1, -1]]]]
Out.shape = (2, 3, 2, 3)
)DOC");
}
};
class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto y_dim = ctx->GetInputDim("Y");
auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(dout_dim.size(), y_dim.size(),
"The dimention of X and Y should be the same.");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dim);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
for (int i = 0; i < y_dim.size(); ++i) {
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]);
}
}
}
};
class PadConstantLikeOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *bind = new framework::OpDesc();
bind->SetType("pad_constant_like_grad");
bind->SetInput("Y", Input("Y"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
bind->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(bind);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(pad_constant_like, ops::PadConstantLikeOp,
ops::PadConstantLikeOpMaker, ops::PadConstantLikeOpGradMaker);
REGISTER_OPERATOR(pad_constant_like_grad, ops::PadConstantLikeOpGrad);
REGISTER_OP_CPU_KERNEL(
pad_constant_like,
ops::PadConstantLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadConstantLikeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
pad_constant_like_grad,
ops::PadConstantLikeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadConstantLikeGradKernel<paddle::platform::CPUDeviceContext, double>);

@ -0,0 +1,27 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/pad_constant_like_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
pad_constant_like,
ops::PadConstantLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::PadConstantLikeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
pad_constant_like_grad,
ops::PadConstantLikeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::PadConstantLikeGradKernel<paddle::platform::CUDADeviceContext,
double>);

@ -0,0 +1,93 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/padding.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class PadConstantLikeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto in_x = context.Input<framework::Tensor>("X");
auto in_y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
if (in_x->dims() == in_y->dims()) {
// TensorCopy(in_y, context.GetPlace(), context, out);
out->ShareDataWith(*in_y);
return;
}
T pad_value = context.Attr<T>("pad_value");
out->mutable_data<T>(context.GetPlace());
int rank = context.Input<framework::Tensor>("X")->dims().size();
std::vector<int> pads(rank * 2, 0);
for (int j = 0; j < rank; ++j) {
pads[j * 2] = 0;
pads[j * 2 + 1] = static_cast<int>(in_x->dims()[j] - in_y->dims()[j]);
}
math::PaddingFunctor<DeviceContext, T>(rank, context, pads, pad_value,
*in_y, out);
}
};
template <typename DeviceContext, typename T>
class PadConstantLikeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto in_y = context.Input<framework::Tensor>("Y");
auto in_dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_y = context.Output<framework::Tensor>(framework::GradVarName("Y"));
if (d_y == nullptr) {
return;
}
if (in_dout->dims() == in_y->dims()) {
// TensorCopy(in_dout, context.GetPlace(), context, d_y);
d_y->ShareDataWith(*in_dout);
return;
}
d_y->mutable_data<T>(context.GetPlace());
int rank = in_dout->dims().size();
std::vector<int> pads(static_cast<size_t>(rank) * 2, 0);
for (int j = 0; j < rank; ++j) {
pads[j * 2] = 0;
pads[j * 2 + 1] = static_cast<int>(in_dout->dims()[j] - in_y->dims()[j]);
}
math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *in_dout,
d_y);
}
};
} // namespace operators
} // namespace paddle

@ -18,117 +18,44 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/padding.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename DeviceContext, typename T, size_t D>
void PadFunction(const framework::ExecutionContext& context) {
auto pads = context.Attr<std::vector<int>>("paddings");
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = pads[i * 2];
paddings[i].second = pads[i * 2 + 1];
}
T pad_value = context.Attr<T>("pad_value");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto x_tensor = EigenTensor<T, D>::From(*x);
auto out_tensor = EigenTensor<T, D>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_tensor.device(place) = x_tensor.pad(paddings, pad_value);
}
template <typename DeviceContext, typename T>
class PadKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
switch (rank) {
case 1:
PadFunction<DeviceContext, T, 1>(context);
break;
case 2:
PadFunction<DeviceContext, T, 2>(context);
break;
case 3:
PadFunction<DeviceContext, T, 3>(context);
break;
case 4:
PadFunction<DeviceContext, T, 4>(context);
break;
case 5:
PadFunction<DeviceContext, T, 5>(context);
break;
case 6:
PadFunction<DeviceContext, T, 6>(context);
break;
default:
PADDLE_THROW(
"PadOp only support tensors with no more than 6 dimensions.");
}
auto pads = context.Attr<std::vector<int>>("paddings");
T pad_value = context.Attr<T>("pad_value");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
int rank = x->dims().size();
math::PaddingFunctor<DeviceContext, T>(rank, context, pads, pad_value, *x,
out);
}
};
template <typename DeviceContext, typename T, size_t D>
void PadGradFunction(const framework::ExecutionContext& context) {
auto pads = context.Attr<std::vector<int>>("paddings");
Eigen::array<std::pair<int, int>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = -pads[i * 2];
paddings[i].second = -pads[i * 2 + 1];
}
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace());
auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0);
}
}
template <typename DeviceContext, typename T>
class PadGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t rank =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) {
case 1:
PadGradFunction<DeviceContext, T, 1>(context);
break;
case 2:
PadGradFunction<DeviceContext, T, 2>(context);
break;
case 3:
PadGradFunction<DeviceContext, T, 3>(context);
break;
case 4:
PadGradFunction<DeviceContext, T, 4>(context);
break;
case 5:
PadGradFunction<DeviceContext, T, 5>(context);
break;
case 6:
PadGradFunction<DeviceContext, T, 6>(context);
break;
default:
PADDLE_THROW(
"PadOp only support tensors with no more than 6 dimensions.");
auto pads = context.Attr<std::vector<int>>("paddings");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x == nullptr) {
return;
}
d_x->mutable_data<T>(context.GetPlace());
int rank = d_out->dims().size();
math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *d_out,
d_x);
}
};

@ -234,6 +234,7 @@ void BindVarDsec(pybind11::module *m) {
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", pd::proto::VarType::BOOL)
.value("UINT8", pd::proto::VarType::UINT8)
.value("INT8", pd::proto::VarType::INT8)
.value("INT16", pd::proto::VarType::INT16)
.value("INT32", pd::proto::VarType::INT32)
.value("INT64", pd::proto::VarType::INT64)

@ -130,6 +130,7 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCPUTensorSetFromArray<bool>)
.def("set", PyCPUTensorSetFromArray<uint16_t>)
.def("set", PyCPUTensorSetFromArray<uint8_t>)
.def("set", PyCPUTensorSetFromArray<int8_t>)
#ifdef PADDLE_WITH_CUDA
.def("set", PyCUDATensorSetFromArray<float>)
.def("set", PyCUDATensorSetFromArray<int>)
@ -138,6 +139,7 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCUDATensorSetFromArray<bool>)
.def("set", PyCUDATensorSetFromArray<uint16_t>)
.def("set", PyCUDATensorSetFromArray<uint8_t>)
.def("set", PyCUDATensorSetFromArray<int8_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<float>)
.def("set", PyCUDAPinnedTensorSetFromArray<int>)
.def("set", PyCUDAPinnedTensorSetFromArray<double>)
@ -145,6 +147,7 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCUDAPinnedTensorSetFromArray<bool>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint8_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<int8_t>)
#endif
.def("shape", [](Tensor &self) { return vectorize(self.dims()); })
.def("_set_float_element", TensorSetElement<float>)

@ -97,7 +97,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) {
auto buffer_info =
details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool,
uint8_t, platform::float16>()(tensor);
uint8_t, int8_t, platform::float16>()(tensor);
return buffer_info;
}

@ -95,6 +95,8 @@ def convert_np_dtype_to_dtype_(np_dtype):
return core.VarDesc.VarType.INT16
elif dtype == np.uint8:
return core.VarDesc.VarType.UINT8
elif dtype == np.int8:
return core.VarDesc.VarType.INT8
else:
raise ValueError("Not supported numpy dtype %s" % dtype)

@ -0,0 +1,69 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestPadOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "pad_constant_like"
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'Y': np.random.random(self.y_shape).astype("float32")
}
self.attrs = {}
self.attrs['pad_value'] = self.pad_value
self.outputs = {
'Out': np.pad(self.inputs['Y'],
self.paddings,
mode='constant',
constant_values=self.pad_value)
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Y'], 'Out', max_relative_error=0.006)
def initTestCase(self):
self.x_shape = (16, 16)
self.y_shape = (3, 16)
self.pad_value = 0.1
self.paddings = [(0, 13), (0, 0)]
class TestCase1(TestPadOp):
def initTestCase(self):
self.x_shape = (4, 3, 4, 4)
self.y_shape = (2, 3, 4, 4)
self.paddings = [(0, 2), (0, 0), (0, 0), (0, 0)]
self.pad_value = 0.5
class TestCase2(TestPadOp):
def initTestCase(self):
self.x_shape = (4, 3, 4, 4)
self.y_shape = (2, 3, 2, 4)
self.paddings = [(0, 2), (0, 0), (0, 2), (0, 0)]
self.pad_value = 0.5
if __name__ == '__main__':
unittest.main()

@ -59,6 +59,27 @@ class TestTensor(unittest.TestCase):
self.assertAlmostEqual(1.0, tensor_array_2[3, 9])
self.assertAlmostEqual(2.0, tensor_array_2[19, 11])
def test_int8_tensor(self):
scope = core.Scope()
var = scope.var("int8_tensor")
cpu_tensor = var.get_tensor()
tensor_array = numpy.random.randint(
-127, high=128, size=[100, 200], dtype=numpy.int8)
place = core.CPUPlace()
cpu_tensor.set(tensor_array, place)
cpu_tensor_array_2 = numpy.array(cpu_tensor)
self.assertAlmostEqual(cpu_tensor_array_2.all(), tensor_array.all())
if core.is_compiled_with_cuda():
cuda_tensor = var.get_tensor()
tensor_array = numpy.random.randint(
-127, high=128, size=[100, 200], dtype=numpy.int8)
place = core.CUDAPlace(0)
cuda_tensor.set(tensor_array, place)
cuda_tensor_array_2 = numpy.array(cuda_tensor)
self.assertAlmostEqual(cuda_tensor_array_2.all(),
tensor_array.all())
def test_int_lod_tensor(self):
place = core.CPUPlace()
scope = core.Scope()

@ -31,7 +31,8 @@ class TestVariable(unittest.TestCase):
self.assertEqual(DT.INT16, convert("int16"))
self.assertEqual(DT.INT64, convert("int64"))
self.assertEqual(DT.BOOL, convert("bool"))
self.assertRaises(ValueError, lambda: convert("int8"))
self.assertEqual(DT.INT8, convert("int8"))
self.assertEqual(DT.UINT8, convert("uint8"))
def test_var(self):
b = default_main_program().current_block()

Loading…
Cancel
Save