|
|
|
@ -80,24 +80,31 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
const std::string op_name = primitive->name();
|
|
|
|
|
CheckArgsSize(op_name, args_spec_list, 2);
|
|
|
|
|
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_x);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_x->shape());
|
|
|
|
|
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(x);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(x->shape());
|
|
|
|
|
ShapeVector x_shape = x->shape()->shape();
|
|
|
|
|
ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape();
|
|
|
|
|
ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape();
|
|
|
|
|
|
|
|
|
|
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(y);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(y->shape());
|
|
|
|
|
ShapeVector y_shape = y->shape()->shape();
|
|
|
|
|
ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape();
|
|
|
|
|
ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
|
|
|
|
|
|
|
|
|
|
auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_y);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_y->shape());
|
|
|
|
|
|
|
|
|
|
auto x_shape = input_x->shape()->shape();
|
|
|
|
|
auto y_shape = input_y->shape()->shape();
|
|
|
|
|
auto out_shape = BroadcastShape(x_shape, y_shape);
|
|
|
|
|
if (out_shape.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
|
|
|
|
|
<< args_spec_list[1]->ToString();
|
|
|
|
|
}
|
|
|
|
|
auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
|
|
|
|
|
auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
|
|
|
|
|
|
|
|
|
|
auto output_type = std::make_shared<Bool>();
|
|
|
|
|
auto ret = std::make_shared<AbstractTensor>(output_type, out_shape);
|
|
|
|
|
auto ret =
|
|
|
|
|
std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|