|
|
@ -26,30 +26,28 @@
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
void AvgPool::set_padding(const std::string &pad) { this->AddAttr("padding", MakeValue(pad)); }
|
|
|
|
void AvgPool::set_padding(const std::string &pad) { this->AddAttr("padding", MakeValue(pad)); }
|
|
|
|
|
|
|
|
|
|
|
|
void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("ksize", MakeValue(kernel_size)); }
|
|
|
|
std::string AvgPool::get_padding() const {
|
|
|
|
|
|
|
|
auto value_ptr = GetAttr("padding");
|
|
|
|
void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); }
|
|
|
|
return GetValue<std::string>(value_ptr);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> AvgPool::get_strides() const {
|
|
|
|
|
|
|
|
auto value_ptr = GetAttr("strides");
|
|
|
|
|
|
|
|
return GetValue<std::vector<int>>(value_ptr);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("k_size", MakeValue(kernel_size)); }
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> AvgPool::get_kernel_size() const {
|
|
|
|
std::vector<int> AvgPool::get_kernel_size() const {
|
|
|
|
auto value_ptr = GetAttr("ksize");
|
|
|
|
auto value_ptr = GetAttr("k_size");
|
|
|
|
return GetValue<std::vector<int>>(value_ptr);
|
|
|
|
return GetValue<std::vector<int>>(value_ptr);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); }
|
|
|
|
|
|
|
|
|
|
|
|
std::string AvgPool::get_padding() const {
|
|
|
|
std::vector<int> AvgPool::get_strides() const {
|
|
|
|
auto value_ptr = GetAttr("padding");
|
|
|
|
auto value_ptr = GetAttr("strides");
|
|
|
|
return GetValue<std::string>(value_ptr);
|
|
|
|
return GetValue<std::vector<int>>(value_ptr);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AvgPool::Init(const std::vector<int> &kernel_size, const std::vector<int> &stride, const std::string &padding) {
|
|
|
|
void AvgPool::Init(const std::vector<int> &kernel_size, const std::vector<int> &stride, const std::string &padding) {
|
|
|
|
auto prim_name = this->name();
|
|
|
|
auto prim_name = this->name();
|
|
|
|
this->AddAttr("data_format", MakeValue("NCHW"));
|
|
|
|
this->AddAttr("data_format", MakeValue("NCHW"));
|
|
|
|
this->set_padding(CheckAndConvertUtils::CheckString("padding", padding, {"valid", "same"}, prim_name));
|
|
|
|
this->set_padding(CheckAndConvertUtils::CheckString("padding", padding, {"valid", "same"}, prim_name));
|
|
|
|
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("ksize", kernel_size, prim_name, false, true));
|
|
|
|
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("k_size", kernel_size, prim_name, false, true));
|
|
|
|
this->set_strides(CheckAndConvertUtils::CheckPositiveVector("strides", stride, this->name(), false, true));
|
|
|
|
this->set_strides(CheckAndConvertUtils::CheckPositiveVector("strides", stride, this->name(), false, true));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|