fix expand_dims infershape

pull/9075/head
yao_yf 4 years ago
parent 2c40e98070
commit 3be87bf352

@ -481,7 +481,7 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim
auto padding_attr = primitive->GetAttr("paddings");
MS_EXCEPTION_IF_NULL(padding_attr);
if (!padding_attr->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "paddings is not a ValueTuple";
MS_LOG(EXCEPTION) << "Paddings is not a ValueTuple";
}
std::vector<ValuePtr> paddings = padding_attr->cast<ValueTuplePtr>()->value();
std::vector<std::vector<int64_t>> paddings_vec;
@ -498,7 +498,7 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim
size_t length = paddings_vec.size();
for (size_t i = 0; i < length; ++i) {
if (paddings_vec[i].size() != 2) {
MS_LOG(EXCEPTION) << "paddings 's second dim size is not 2";
MS_LOG(EXCEPTION) << "Paddings 's second dim size is not 2";
}
result_shp.push_back(input_shp[i] + paddings_vec[i][0] + paddings_vec[i][1]);
}

@ -500,19 +500,11 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
auto axis = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(axis);
std::vector<int64_t> shape;
std::vector<int64_t> x_shape = x->shape()->shape();
shape.insert(shape.end(), x_shape.begin(), x_shape.end());
auto axis_value = axis->BuildValue();
if (!axis_value->isa<tensor::Tensor>()) {
MS_LOG(EXCEPTION) << axis_value << " axis_value should be tensor, but got " << axis_value->type_name();
}
auto axis_tensor = axis_value->cast<tensor::TensorPtr>();
int value = *(static_cast<int *>(axis_tensor->data_c()));
auto axis = primitive->GetAttr("axis");
auto value = GetValue<int64_t>(axis);
if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value
<< " and input_x.dim is" << x_shape.size();

@ -159,7 +159,7 @@ class EmbeddingLookup(Cell):
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
and the exceeding part will be filled with 0 in the output. Input_indices must only be a 2d tensor in
this interface.
this interface when run in semi auto parallel/auto parallel mode.
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
@ -310,7 +310,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
Specifies the weights of elements of the input_indices. The lookout vector will multiply with
the input_values. Type is Float32.
- **field_ids** (Tensor) - The shape of tensor is :math:`(batch_size, seq_length)`.
Specifies the field id of elements of the input_indices. Type is Type is Int16, Int32.
Specifies the field id of elements of the input_indices. Type is Int16, Int32.
Outputs:
Tensor, the shape of tensor is :math:`(batch_size, field_size, embedding_size)`. Type is Float32.

Loading…
Cancel
Save