|
|
|
@ -75,7 +75,8 @@ class Im2ColFunctor {
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
const framework::Tensor& im, framework::Tensor& col,
|
|
|
|
|
int stride_height, int stride_width, int padding_up,
|
|
|
|
|
int padding_down, int padding_left, int padding_right);
|
|
|
|
|
int padding_down, int padding_left = 0,
|
|
|
|
|
int padding_right = 0);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <ColFormat Format, typename Place, typename T>
|
|
|
|
@ -84,7 +85,7 @@ class Col2ImFunctor {
|
|
|
|
|
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
|
|
|
|
|
const framework::Tensor& col, int stride_height,
|
|
|
|
|
int stride_width, int padding_up, int padding_down,
|
|
|
|
|
int padding_left, int padding_right);
|
|
|
|
|
int padding_left = 0, int padding_right = 0);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|