!13727 remove parameters of fuction in CheckAndConvertUtils

From: @lianliguang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @ginfung
pull/13727/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b1043bcf55

@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@ -36,8 +36,8 @@ PadMode AvgPool::get_pad_mode() const {
return PadMode(GetValue<int64_t>(value_ptr)); return PadMode(GetValue<int64_t>(value_ptr));
} }
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), this->AddAttr(kKernelSize,
false, true))); MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
} }
std::vector<int64_t> AvgPool::get_kernel_size() const { std::vector<int64_t> AvgPool::get_kernel_size() const {
@ -45,8 +45,7 @@ std::vector<int64_t> AvgPool::get_kernel_size() const {
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
void AvgPool::set_strides(const std::vector<int64_t> &strides) { void AvgPool::set_strides(const std::vector<int64_t> &strides) {
this->AddAttr(kStrides, this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
} }
std::vector<int64_t> AvgPool::get_strides() const { std::vector<int64_t> AvgPool::get_strides() const {

@ -93,8 +93,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
w_out = floor(w_out); w_out = floor(w_out);
} }
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
primitive->AddAttr(kPadList, primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name)));
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true)));
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
if (format == NHWC) { if (format == NHWC) {
out_shape = {x_shape[0], h_out, w_out, out_channel}; out_shape = {x_shape[0], h_out, w_out, out_channel};
@ -144,11 +143,11 @@ void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
} }
void Conv2D::set_stride(const std::vector<int64_t> &stride) { void Conv2D::set_stride(const std::vector<int64_t> &stride) {
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
} }
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
} }
void Conv2D::set_pad_mode(const PadMode &pad_mode) { void Conv2D::set_pad_mode(const PadMode &pad_mode) {
@ -166,7 +165,7 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
void Conv2D::set_pad(const std::vector<int64_t> &pad) { void Conv2D::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
} }
void Conv2D::set_mode(int64_t mode) { void Conv2D::set_mode(int64_t mode) {

@ -111,7 +111,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) { void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
} }
void Conv2dTranspose::set_mode(int64_t mode) { void Conv2dTranspose::set_mode(int64_t mode) {

@ -35,13 +35,13 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false); auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
if (strides[0] != strides[1]) { if (strides[0] != strides[1]) {
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
<< ", width " << strides[1]; << ", width " << strides[1];
} }
this->set_stride(strides); this->set_stride(strides);
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false); auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
if (dilations[0] != dilations[1]) { if (dilations[0] != dilations[1]) {
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
<< ", width " << dilations[1]; << ", width " << dilations[1];
@ -57,7 +57,7 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
} else { } else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
} }
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
this->set_out_channel( this->set_out_channel(
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));

@ -30,13 +30,13 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false); auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
if (strides[0] != strides[1]) { if (strides[0] != strides[1]) {
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
<< ", width " << strides[1]; << ", width " << strides[1];
} }
this->set_stride(strides); this->set_stride(strides);
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false); auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
if (dilations[0] != dilations[1]) { if (dilations[0] != dilations[1]) {
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
<< ", width " << dilations[1]; << ", width " << dilations[1];
@ -52,7 +52,7 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve
} else { } else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
} }
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
this->set_out_channel( this->set_out_channel(
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));

@ -105,11 +105,11 @@ void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_siz
} }
void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) { void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) {
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
} }
void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) { void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) {
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
} }
void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
@ -127,7 +127,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) { void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
} }
void Conv2DBackpropInput::set_mode(int64_t mode) { void Conv2DBackpropInput::set_mode(int64_t mode) {

@ -36,8 +36,8 @@ PadMode MaxPool::get_pad_mode() const {
return PadMode(GetValue<int64_t>(value_ptr)); return PadMode(GetValue<int64_t>(value_ptr));
} }
void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), this->AddAttr(kKernelSize,
false, true))); MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
} }
std::vector<int64_t> MaxPool::get_kernel_size() const { std::vector<int64_t> MaxPool::get_kernel_size() const {
@ -45,8 +45,7 @@ std::vector<int64_t> MaxPool::get_kernel_size() const {
return GetValue<std::vector<int64_t>>(value_ptr); return GetValue<std::vector<int64_t>>(value_ptr);
} }
void MaxPool::set_strides(const std::vector<int64_t> &strides) { void MaxPool::set_strides(const std::vector<int64_t> &strides) {
this->AddAttr(kStrides, this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
} }
std::vector<int64_t> MaxPool::get_strides() const { std::vector<int64_t> MaxPool::get_strides() const {

@ -330,24 +330,10 @@ bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, cons
std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name, std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
const std::vector<int64_t> &arg_value, const std::vector<int64_t> &arg_value,
const std::string &prim_name, bool allow_four, const std::string &prim_name) {
bool ret_four) {
auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void {
std::ostringstream buffer;
buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two ";
if (allow_four) {
buffer << "or four ";
}
buffer << " positive int64_t numbers , but got [";
for (auto item : arg_value) {
buffer << item << ",";
}
buffer << "]";
MS_EXCEPTION(ValueError) << buffer.str();
};
for (auto item : arg_value) { for (auto item : arg_value) {
if (item < 0) { if (item < 0) {
raise_message(); MS_EXCEPTION(ValueError) << "For " << prim_name << " attr " << arg_name << " should be a positive vector";
} }
} }
return arg_value; return arg_value;

@ -162,8 +162,7 @@ const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeT
class CheckAndConvertUtils { class CheckAndConvertUtils {
public: public:
static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value, static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value,
const std::string &prim_name, bool allow_four = false, const std::string &prim_name);
bool ret_four = false);
static std::string CheckString(const std::string &arg_name, const std::string &arg_value, static std::string CheckString(const std::string &arg_name, const std::string &arg_value,
const std::set<std::string> &check_list, const std::string &prim_name); const std::set<std::string> &check_list, const std::string &prim_name);

Loading…
Cancel
Save