|
|
|
@ -27,9 +27,10 @@ template <class T>
|
|
|
|
|
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
platform::CPUPlace, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::Tensor& im, framework::Tensor& col,
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
const framework::Tensor& im, framework::Tensor& col,
|
|
|
|
|
int stride_height, int stride_width, int padding_height,
|
|
|
|
|
int padding_width, const platform::DeviceContext& context) {
|
|
|
|
|
int padding_width) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
|
|
|
|
@ -79,9 +80,9 @@ template <class T>
|
|
|
|
|
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
|
|
|
|
|
platform::CPUPlace, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::Tensor& im, const framework::Tensor& col,
|
|
|
|
|
int stride_height, int stride_width, int padding_height,
|
|
|
|
|
int padding_width, const platform::DeviceContext& context) {
|
|
|
|
|
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
|
|
|
|
|
const framework::Tensor& col, int stride_height,
|
|
|
|
|
int stride_width, int padding_height, int padding_width) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
int input_channels = im.dims()[0];
|
|
|
|
@ -137,9 +138,10 @@ template <class T>
|
|
|
|
|
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
platform::CPUPlace, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::Tensor& im, framework::Tensor& col,
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
const framework::Tensor& im, framework::Tensor& col,
|
|
|
|
|
int stride_height, int stride_width, int padding_height,
|
|
|
|
|
int padding_width, const platform::DeviceContext& context) {
|
|
|
|
|
int padding_width) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
int input_channels = im.dims()[0];
|
|
|
|
@ -197,9 +199,9 @@ template <class T>
|
|
|
|
|
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
|
|
|
|
|
platform::CPUPlace, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::Tensor& im, const framework::Tensor& col,
|
|
|
|
|
int stride_height, int stride_width, int padding_height,
|
|
|
|
|
int padding_width, const platform::DeviceContext& context) {
|
|
|
|
|
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
|
|
|
|
|
const framework::Tensor& col, int stride_height,
|
|
|
|
|
int stride_width, int padding_height, int padding_width) {
|
|
|
|
|
PADDLE_ENFORCE(im.dims().size() == 3);
|
|
|
|
|
PADDLE_ENFORCE(col.dims().size() == 5);
|
|
|
|
|
int input_channels = im.dims()[0];
|
|
|
|
|