conv2d bp filter input primitive stride 2 and adapt to 4

pull/915/head
zhaozhenlong 5 years ago
parent cbb4136b62
commit 3209687b9a

@ -736,7 +736,7 @@ class OpAdapter : public BaseOpAdapter {
return static_cast<int64_t>(GetValue<int>(value));
}
// specialization for int to Vector
// specialization for int or tuple broadcast to Vector
static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name,
const AnyTraits<std::vector<int64_t>> anyTraitsInt) {
return ConvertAnyUtil(value, name, anyTraitsInt);

@ -35,14 +35,20 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor
std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
const AnyTraits<std::vector<int64_t>>) {
MS_EXCEPTION_IF_NULL(value);
int64_t data = GetValue<int>(value);
std::vector<int64_t> list;
int size = 2; // 2 int in list
if (name == "pad") {
size = 4; // 4 int in list
list = TransformUtil::ConvertIntToList(data, size);
list[0] = 1;
list[1] = 1;
if (!value->isa<ValueSequeue>()) {
MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name();
}
auto vec = value->cast<ValueSequeuePtr>();
list.push_back(1);
list.push_back(1);
for (auto &it : vec->value()) {
list.push_back(static_cast<int64_t>(GetValue<int>(it)));
}
} else {
list = TransformUtil::ConvertIntToList(data, size);
}

@ -733,7 +733,7 @@ INPUT_ATTR_MAP(Conv2DBackpropInputD) = {
{3, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv2DBackpropInputD) = {
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"group", ATTR_DESC(groups, AnyTraits<int>())},
@ -746,7 +746,7 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = {
{3, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv2DBackpropFilterD) = {
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"group", ATTR_DESC(groups, AnyTraits<int>())},

@ -174,9 +174,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode)
self.pad = pad
if isinstance(stride, tuple) and len(stride) == 2:
self.stride = stride
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
if isinstance(stride, tuple) and len(stride) == 4:
self.stride = (stride[2], stride[3])
self.add_prim_attr('stride', self.stride)
self.dilation = dilation
self.group = group
self.add_prim_attr('data_format', "NCHW")

@ -1084,7 +1084,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name)
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False)
self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation)

Loading…
Cancel
Save