Speed/sequence op1 (#9217)

* "add functors"

* "remove old code"

* "fix"

* "fix ci"

* "add details"

* "fix ci"

* "fix ci"

* "fix ci"

* "fix ci"

* "remove unused code"
Branch: helinwang-patch-1
Authored by dzhwinter, committed by GitHub
Parent: d21ab2e2ba
Commit: 8425c2c859

@@ -19,8 +19,17 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
 template <typename T>
-class MaxSeqPoolFunctor<platform::CPUDeviceContext, T> {
+class MaxSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& input, framework::Tensor* output,
@@ -60,7 +69,7 @@ class MaxSeqPoolFunctor<platform::CPUDeviceContext, T> {
 };
 
 template <typename T>
-class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, T> {
+class MaxSeqPoolGradFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& out_grad,
@@ -93,10 +102,101 @@ class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, T> {
   }
 };
 
-template class MaxSeqPoolFunctor<platform::CPUDeviceContext, float>;
-template class MaxSeqPoolFunctor<platform::CPUDeviceContext, double>;
-template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, float>;
-template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, double>;
+template <typename T>
+class SequencePoolFunctor<platform::CPUDeviceContext, T> {
+ public:
+  /* max pool has index output */
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::string pooltype, const framework::LoDTensor& input,
+                  framework::Tensor* output,
+                  framework::Tensor* index = nullptr) {
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolFunctor<T> max_pool;
+      max_pool(context, input, output, index);
+      return;
+    }
+    auto lod = input.lod()[0];
+    auto& place = *context.eigen_device();
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      Tensor in_t =
+          input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
+      Tensor out_t = output->Slice(i, i + 1);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      int64_t w = input.numel() / input.dims()[0];
+      auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
+      auto out_e = EigenVector<T>::Flatten(out_t);
+      if (pooltype == "AVERAGE") {
+        out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
+      } else if (pooltype == "SUM") {
+        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
+      } else if (pooltype == "SQRT") {
+        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
+                              std::sqrt(static_cast<T>(h));
+      } else if (pooltype == "LAST") {
+        out_e.device(place) = in_e.chip(h - 1, 0);
+      } else if (pooltype == "FIRST") {
+        out_e.device(place) = in_e.chip(0, 0);
+      } else {
+        PADDLE_THROW("unsupported pooling pooltype");
+      }
+    }
+  }
+};
+
+template <typename T>
+class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::string pooltype, const framework::Tensor& out_grad,
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr) {
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolGradFunctor<T> max_pool_grad;
+      max_pool_grad(context, out_grad, *index, in_grad);
+      return;
+    }
+    if (pooltype == "LAST" || pooltype == "FIRST") {
+      // set X@Grad be zero at first when pooltype is LAST/FIRST
+      math::SetConstant<platform::CPUDeviceContext, T> functor;
+      functor(context, in_grad, 0);
+    }
+    auto lod = in_grad->lod()[0];
+    auto& place = *context.eigen_device();
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
+                                   static_cast<int>(lod[i + 1]));
+      auto out_g_t = out_grad.Slice(i, i + 1);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      int64_t w = in_grad->numel() / in_grad->dims()[0];
+      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
+      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
+      auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
+      Eigen::DSizes<int, 2> bcast(h, 1);
+      if (pooltype == "AVERAGE") {
+        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
+      } else if (pooltype == "SUM") {
+        in_g_e.device(place) = (out_g_e).broadcast(bcast);
+      } else if (pooltype == "SQRT") {
+        in_g_e.device(place) =
+            (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
+      } else if (pooltype == "LAST") {
+        in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
+      } else if (pooltype == "FIRST") {
+        in_g_e.chip(0, 0).device(place) = out_g_e_v;
+      } else {
+        PADDLE_THROW("unsupported pooling pooltype");
+      }
+    }
+  }
+};
+
+template class SequencePoolFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolFunctor<platform::CPUDeviceContext, double>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
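
For reference, the non-MAX branches of the CPU functor above reduce rows [lod[i], lod[i+1]) of the input to a single output row per sequence. Below is a minimal standalone sketch of those semantics; the name SeqPoolRef is illustrative, and plain loops stand in for the Eigen expressions, so it has no Paddle or Eigen dependency.

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

// Reference for the non-MAX pooltypes: pool rows [lod[i], lod[i+1]) of `in`
// (one inner vector per row, all of width w) into one output row per sequence.
void SeqPoolRef(const std::string& pooltype, const std::vector<size_t>& lod,
                const std::vector<std::vector<float>>& in,
                std::vector<std::vector<float>>* out) {
  size_t w = in[0].size();
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    size_t begin = lod[i], end = lod[i + 1];
    float h = static_cast<float>(end - begin);
    std::vector<float> o(w, 0.f);
    if (pooltype == "LAST") {
      o = in[end - 1];
    } else if (pooltype == "FIRST") {
      o = in[begin];
    } else {  // AVERAGE, SUM and SQRT all start from the column-wise sum
      for (size_t r = begin; r < end; ++r)
        for (size_t c = 0; c < w; ++c) o[c] += in[r][c];
      if (pooltype == "AVERAGE")
        for (float& v : o) v /= h;
      else if (pooltype == "SQRT")
        for (float& v : o) v /= std::sqrt(h);
    }
    out->push_back(o);
  }
}

int main() {
  // Two sequences of lengths 2 and 3 over width-2 rows: lod = {0, 2, 5}.
  std::vector<std::vector<float>> x = {{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}};
  std::vector<std::vector<float>> out;
  SeqPoolRef("AVERAGE", {0, 2, 5}, x, &out);
  std::printf("%g %g\n", out[0][0], out[0][1]);  // prints 2 3
  std::printf("%g %g\n", out[1][0], out[1][1]);  // prints 7 8
  return 0;
}

Compiled and run, it prints the per-sequence column means, matching what the AVERAGE branch of the functor computes.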

File diff suppressed because it is too large.

@@ -21,23 +21,23 @@ namespace paddle {
 namespace operators {
 namespace math {
 
 #define FLT_MAX __FLT_MAX__
 
 template <typename DeviceContext, typename T>
-class MaxSeqPoolFunctor {
+class SequencePoolFunctor {
  public:
-  void operator()(const DeviceContext& context,
+  /* max pool has index output */
+  void operator()(const DeviceContext& context, const std::string pooltype,
                   const framework::LoDTensor& input, framework::Tensor* output,
-                  framework::Tensor* index);
+                  framework::Tensor* index = nullptr);
 };
 
-template <typename DeviceContext, class T>
-class MaxSeqPoolGradFunctor {
+template <typename DeviceContext, typename T>
+class SequencePoolGradFunctor {
  public:
-  void operator()(const DeviceContext& context,
+  void operator()(const DeviceContext& context, const std::string pooltype,
                   const framework::Tensor& out_grad,
-                  const framework::Tensor& index,
-                  framework::LoDTensor* in_grad);
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr);
 };
 
 }  // namespace math
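
These declarations give the whole public surface: one forward and one backward functor, each taking the pooltype string plus an optional index tensor that only the MAX path uses. For the non-MAX pooltypes the backward rule is that each sequence's row of Out@Grad is broadcast over the sequence's h input rows with the same scaling the forward reduction used. A compile-ready sketch of that rule (illustrative name SeqPoolGradRef, no framework dependency; it mirrors the .broadcast(bcast) and .chip(...) branches in sequence_pooling.cc above):

#include <cmath>
#include <string>
#include <vector>

std::vector<std::vector<float>> SeqPoolGradRef(
    const std::string& pooltype, const std::vector<size_t>& lod, size_t w,
    const std::vector<std::vector<float>>& out_grad) {
  // zero-init plays the role of SetConstant for the LAST/FIRST cases
  std::vector<std::vector<float>> in_grad(lod.back(),
                                          std::vector<float>(w, 0.f));
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    size_t begin = lod[i], end = lod[i + 1];
    float h = static_cast<float>(end - begin);
    if (pooltype == "LAST") {
      in_grad[end - 1] = out_grad[i];  // only the last row receives gradient
    } else if (pooltype == "FIRST") {
      in_grad[begin] = out_grad[i];  // only the first row receives gradient
    } else {
      float scale = pooltype == "AVERAGE" ? 1.f / h
                    : pooltype == "SQRT"  ? 1.f / std::sqrt(h)
                                          : 1.f;  // SUM
      for (size_t r = begin; r < end; ++r)
        for (size_t c = 0; c < w; ++c)
          in_grad[r][c] = out_grad[i][c] * scale;  // broadcast over h rows
    }
  }
  return in_grad;
}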

@@ -23,12 +23,6 @@ namespace operators {
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
 class SequencePoolKernel : public framework::OpKernel<T> {
@@ -37,11 +31,13 @@ class SequencePoolKernel : public framework::OpKernel<T> {
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<Tensor>("Out");
     std::string pooltype = context.Attr<std::string>("pooltype");
-
+    Tensor* index = nullptr;
+    if (pooltype == "MAX") {
+      index = context.Output<Tensor>("MaxIndex");
+    }
     auto dims = in->dims();
     auto lod = in->lod();
-    int64_t w = in->numel() / dims[0];
 
     // InferShape by lod
     PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
     PADDLE_ENFORCE_GE(
@@ -50,45 +46,14 @@
         "The first dimension of Input(X) must be large than batch size.");
     dims[0] = lod[0].size() - 1;
     out->Resize({dims});
-    auto lod_level_0 = lod[0];
 
     out->mutable_data<T>(context.GetPlace());
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    if (pooltype == "MAX") {
-      math::MaxSeqPoolFunctor<DeviceContext, T> max_pool;
-      auto* index = context.Output<Tensor>("MaxIndex");
-      index->Resize({dims});
-      index->mutable_data<int>(context.GetPlace());
-      max_pool(dev_ctx, *in, out, index);
-      return;
-    }
-
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
-      Tensor in_t = in->Slice(static_cast<int>(lod_level_0[i]),
-                              static_cast<int>(lod_level_0[i + 1]));
-      Tensor out_t = out->Slice(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod_level_0[i + 1] - lod_level_0[i]);
-      auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
-      auto out_e = EigenVector<T>::Flatten(out_t);
-
-      if (pooltype == "AVERAGE") {
-        out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
-      } else if (pooltype == "SUM") {
-        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
-      } else if (pooltype == "SQRT") {
-        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
-                              std::sqrt(static_cast<T>(h));
-      } else if (pooltype == "LAST") {
-        out_e.device(place) = in_e.chip(h - 1, 0);
-      } else if (pooltype == "FIRST") {
-        out_e.device(place) = in_e.chip(0, 0);
-      } else {
-        PADDLE_THROW("unsupported pooling pooltype");
-      }
-    }
+    math::SequencePoolFunctor<DeviceContext, T> pool;
+    pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
+         index);
   }
 };
@@ -96,58 +61,17 @@ template <typename DeviceContext, typename T>
 class SequencePoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* in = context.Input<LoDTensor>("X");
     auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
     std::string pooltype = context.Attr<std::string>("pooltype");
-
-    auto dims = in->dims();
-    auto lod = in->lod()[0];
-    int64_t w = in->numel() / dims[0];
-
-    in_g->mutable_data<T>(context.GetPlace());
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-
+    const Tensor* index = nullptr;
     if (pooltype == "MAX") {
-      math::MaxSeqPoolGradFunctor<DeviceContext, T> max_pool_grad;
-      auto* index = context.Input<Tensor>("MaxIndex");
-      max_pool_grad(dev_ctx, *out_g, *index, in_g);
-      return;
-    }
-
-    if (pooltype == "LAST" || pooltype == "FIRST") {
-      // set X@Grad be zero at first when pooltype is LAST/FIRST
-      math::SetConstant<DeviceContext, T> functor;
-      functor(dev_ctx, in_g, 0);
-    }
-
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
-      auto in_g_t =
-          in_g->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
-      auto out_g_t = out_g->Slice(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
-      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
-      auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
-      Eigen::DSizes<int, 2> bcast(h, 1);
-
-      if (pooltype == "AVERAGE") {
-        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
-      } else if (pooltype == "SUM") {
-        in_g_e.device(place) = (out_g_e).broadcast(bcast);
-      } else if (pooltype == "SQRT") {
-        in_g_e.device(place) =
-            (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
-      } else if (pooltype == "LAST") {
-        in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
-      } else if (pooltype == "FIRST") {
-        in_g_e.chip(0, 0).device(place) = out_g_e_v;
-      } else {
-        PADDLE_THROW("unsupported pooling pooltype");
-      }
+      index = context.Input<Tensor>("MaxIndex");
     }
+    in_g->mutable_data<T>(context.GetPlace());
+    math::SequencePoolGradFunctor<DeviceContext, T> pool;
+    pool(context.template device_context<DeviceContext>(), pooltype, *out_g,
+         in_g, index);
   }
 };
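
With the loops gone, both kernels just look up the optional MaxIndex tensor and delegate to the math functors. The MaxSeqPoolFunctor and MaxSeqPoolGradFunctor bodies are elided from this diff, so the following is an assumed sketch, with illustrative names, of the standard max-pool-with-index scheme that the /* max pool has index */ comments refer to, not the PR's actual implementation:

#include <vector>

// Forward (assumed scheme): per column, keep the max over rows [begin, end)
// and record the winning row in `index`, i.e. the MaxIndex output.
void MaxPoolFwd(const std::vector<std::vector<float>>& in, size_t begin,
                size_t end, std::vector<float>* out, std::vector<int>* index) {
  size_t w = in[begin].size();
  *out = in[begin];
  index->assign(w, static_cast<int>(begin));
  for (size_t r = begin + 1; r < end; ++r)
    for (size_t c = 0; c < w; ++c)
      if (in[r][c] > (*out)[c]) {
        (*out)[c] = in[r][c];
        (*index)[c] = static_cast<int>(r);
      }
}

// Backward (assumed scheme): route each column's output gradient to the
// recorded argmax row; every other row of X@Grad stays zero.
void MaxPoolBwd(const std::vector<float>& out_grad,
                const std::vector<int>& index,
                std::vector<std::vector<float>>* in_grad) {
  for (size_t c = 0; c < out_grad.size(); ++c)
    (*in_grad)[index[c]][c] += out_grad[c];
}

This is also why the kernels can hand the functors a null index for every pooltype except MAX: no other branch needs per-column bookkeeping between forward and backward.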

@@ -49,6 +49,61 @@ class TestSeqAvgPool(OpTest):
         self.check_grad(["X"], "Out")
 
 
+class TestSeqSumPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "SUM"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x.sum(axis=0)
+
+
+class TestSeqMaxPool(TestSeqAvgPool):
+    def set_data(self):
+        self.op_type = 'sequence_pool'
+        x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
+        lod = [[0, 4, 5, 8, 13]]
+        for i in range(4):
+            l = lod[0][i + 1] - lod[0][i]
+            x[lod[0][i] + np.random.randint(l), :] += 2.0
+        self.inputs = {'X': (x, lod)}
+
+        out = np.zeros((4, 23)).astype('float32')
+        self.outputs = {'Out': out}
+        return x, lod, out
+
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "MAX"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = np.amax(sub_x, axis=0)
+
+
+class TestSeqSqrtPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "SQRT"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            len = lod[0][i + 1] - lod[0][i]
+            out[i] = sub_x.sum(axis=0) / np.sqrt(len)
+
+
+class TestSeqLastPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "LAST"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x[-1, :]
+
+
+class TestSeqFirstPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "FIRST"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x[0, :]
+
+
 class TestSeqAvgPool2D(TestSeqAvgPool):
     def set_data(self):
         self.op_type = 'sequence_pool'
@@ -68,14 +123,6 @@ class TestSeqAvgPool2D(TestSeqAvgPool):
         out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
 
 
-class TestSeqSumPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "SUM"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x.sum(axis=0)
-
-
 class TestSeqSumPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "SUM"}
@@ -84,15 +131,6 @@ class TestSeqSumPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
 
 
-class TestSeqSqrtPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "SQRT"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            len = lod[0][i + 1] - lod[0][i]
-            out[i] = sub_x.sum(axis=0) / np.sqrt(len)
-
-
 class TestSeqSqrtPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "SQRT"}
@@ -108,28 +146,6 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
         self.check_grad(["X"], "Out", max_relative_error=0.06)
 
 
-class TestSeqMaxPool(TestSeqAvgPool):
-    def set_data(self):
-        self.op_type = 'sequence_pool'
-        x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
-        lod = [[0, 4, 5, 8, 13]]
-        for i in range(4):
-            l = lod[0][i + 1] - lod[0][i]
-            x[lod[0][i] + np.random.randint(l), :] += 2.0
-        self.inputs = {'X': (x, lod)}
-
-        out = np.zeros((4, 23)).astype('float32')
-        self.outputs = {'Out': out}
-        return x, lod, out
-
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "MAX"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = np.amax(sub_x, axis=0)
-
-
 class TestSeqMaxPool2D(TestSeqAvgPool2D):
     def set_data(self):
         self.op_type = 'sequence_pool'
@@ -151,14 +167,6 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
 
 
-class TestSeqLastPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "LAST"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x[-1, :]
-
-
 class TestSeqLastPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "LAST"}
@@ -167,14 +175,6 @@ class TestSeqLastPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(sub_x[-1, :], (3, 17))
 
 
-class TestSeqFirstPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "FIRST"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x[0, :]
-
-
 class TestSeqFirstPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "FIRST"}
