|
|
|
@ -31,7 +31,7 @@ inline int get_output_size(int img_size, int block_size, int stride,
|
|
|
|
|
return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class BlockExpandKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
@ -71,8 +71,9 @@ class BlockExpandKernel : public framework::OpKernel<T> {
|
|
|
|
|
img_channels, block_height,
|
|
|
|
|
block_width});
|
|
|
|
|
|
|
|
|
|
math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f;
|
|
|
|
|
f(ctx.device_context(), src, dilations, strides, paddings, &dst);
|
|
|
|
|
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
f(dev_ctx, src, dilations, strides, paddings, &dst);
|
|
|
|
|
}
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
|
|
|
|
@ -87,7 +88,7 @@ class BlockExpandKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class BlockExpandGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
@ -98,7 +99,8 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
d_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto x_v = framework::EigenVector<T>::Flatten(*d_x);
|
|
|
|
|
x_v.device(ctx.GetEigenDevice<Place>()) = x_v.constant(0.0);
|
|
|
|
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
x_v.device(place) = x_v.constant(0.0);
|
|
|
|
|
|
|
|
|
|
auto in_dim = in->dims();
|
|
|
|
|
int batch_size = in_dim[0];
|
|
|
|
@ -131,8 +133,9 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
const Tensor src = d_out->Slice(i, i + 1).Resize(
|
|
|
|
|
{output_height, output_width, img_channels, block_height,
|
|
|
|
|
block_width});
|
|
|
|
|
math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f;
|
|
|
|
|
f(ctx.device_context(), src, dilations, strides, paddings, &dst);
|
|
|
|
|
math::Col2ImFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
f(dev_ctx, src, dilations, strides, paddings, &dst);
|
|
|
|
|
}
|
|
|
|
|
d_out->Resize(d_out_dims);
|
|
|
|
|
}
|
|
|
|
|