|
|
|
@ -67,19 +67,8 @@ int64_t windowed_output_size(int64_t input_size, int64_t ksize, int64_t stride,
|
|
|
|
|
return output;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape,
|
|
|
|
|
const std::vector<int64_t> &k_size, const std::vector<int64_t> &stride,
|
|
|
|
|
const PadMode pad_mode, const TypeId x_dtype) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum || stride.size() != kShapeDimNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4.";
|
|
|
|
|
}
|
|
|
|
|
int64_t pad_top, pad_bottom, pad_left, pad_right;
|
|
|
|
|
int64_t h_output = windowed_output_size(x_shape[2], k_size[2], stride[2], pad_mode, &pad_top, &pad_bottom);
|
|
|
|
|
int64_t w_output = windowed_output_size(x_shape[3], k_size[3], stride[3], pad_mode, &pad_left, &pad_right);
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<float>> GetAssistInputMatrix(const std::vector<int64_t> &x_shape, int64_t pad_top,
|
|
|
|
|
int64_t pad_bottom, int64_t pad_left, int64_t pad_right) {
|
|
|
|
|
// `assist_input_matrix` is a 2d matrix with input_shape after padding,
|
|
|
|
|
// the value of element which is padded is 0, else are 1.
|
|
|
|
|
// For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize,
|
|
|
|
@ -102,6 +91,22 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std
|
|
|
|
|
assist_input_matrix.emplace_back(tmp_one_vector);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return assist_input_matrix;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape,
|
|
|
|
|
const std::vector<int64_t> &k_size, const std::vector<int64_t> &stride,
|
|
|
|
|
const PadMode pad_mode, const TypeId x_dtype) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum || stride.size() != kShapeDimNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4.";
|
|
|
|
|
}
|
|
|
|
|
int64_t pad_top, pad_bottom, pad_left, pad_right;
|
|
|
|
|
int64_t h_output = windowed_output_size(x_shape[2], k_size[2], stride[2], pad_mode, &pad_top, &pad_bottom);
|
|
|
|
|
int64_t w_output = windowed_output_size(x_shape[3], k_size[3], stride[3], pad_mode, &pad_left, &pad_right);
|
|
|
|
|
auto assist_input_matrix = GetAssistInputMatrix(x_shape, pad_top, pad_bottom, pad_left, pad_right);
|
|
|
|
|
|
|
|
|
|
// calculate output
|
|
|
|
|
std::vector<float> hw_output(h_output * w_output, 0.0);
|
|
|
|
|