!7529 complex arithmetic_simplify

Merge pull request !7529 from zhuxiaochen/1020_allsimplify_1.0
commit 8d39a8a4b2 (pull/7529/MERGE)
committed by mindspore-ci-bot via Gitee

@@ -22,8 +22,8 @@
 #include <tuple>
 #include <vector>
 
-#include "ir/visitor.h"
 #include "base/core_ops.h"
+#include "ir/visitor.h"
 #include "utils/shape_utils.h"
 
 namespace mindspore {
@@ -750,9 +750,18 @@ class PConstant : public PBase<PConstant<T> > {
     if (value->isa<tensor::Tensor>()) {
       tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
       TypeId tensor_type = tensor_ptr->Dtype()->type_id();
+      auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
+      TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
+      ShapeVector tensor_shape = tensor_abstract->shape()->shape();
+      auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
+      size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
       if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) ||
           (tensor_type == TypeId::kNumberTypeFloat64)) {
-        float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
+        float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
+        float *data2 = reinterpret_cast<float *>(new_tensor_ptr->data_c());
+        if (memcpy_s(data2, mem_size, data, mem_size) != 0) {
+          return nullptr;
+        }
         for (int i = 0; i < tensor_ptr->DataSize(); i++) {
           if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
             return nullptr;
@@ -761,7 +770,11 @@ class PConstant : public PBase<PConstant<T> > {
         }
       }
       if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
-        int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
+        int *data = reinterpret_cast<int *>(tensor_ptr->data_c());
+        int *data2 = reinterpret_cast<int *>(new_tensor_ptr->data_c());
+        if (memcpy_s(data2, mem_size, data, mem_size) != 0) {
+          return nullptr;
+        }
         for (int i = 0; i < tensor_ptr->DataSize(); i++) {
           if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
             return nullptr;
@@ -770,7 +783,11 @@ class PConstant : public PBase<PConstant<T> > {
         }
       }
       if (tensor_type == TypeId::kNumberTypeFloat64) {
-        double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
+        double *data = reinterpret_cast<double *>(tensor_ptr->data_c());
+        double *data2 = reinterpret_cast<double *>(new_tensor_ptr->data_c());
+        if (memcpy_s(data2, mem_size, data, mem_size) != 0) {
+          return nullptr;
+        }
         for (int i = 0; i < tensor_ptr->DataSize(); i++) {
           if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
             return nullptr;
@@ -778,7 +795,9 @@ class PConstant : public PBase<PConstant<T> > {
           data2[i] = CalcuConstant(data2[i], calcu_type);
         }
       }
-      return node;
+      auto new_vnode = NewValueNode(new_tensor_ptr);
+      new_vnode->set_abstract(tensor_ptr->ToAbstract());
+      return new_vnode;
     }
     return nullptr;
   }
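
All three type branches above follow the same pattern: copy the constant tensor's buffer into the newly allocated tensor, give up if a reciprocal would hit a zero element, and otherwise transform the copy in place with CalcuConstant. A minimal standalone sketch of that copy-then-transform guard, not part of the diff (FoldReciprocal is a hypothetical name; std::memcpy and std::optional stand in for memcpy_s and the nullptr return):

#include <cstring>
#include <iostream>
#include <optional>
#include <vector>

// Copy the constant's buffer first, then transform the copy element-wise.
// Returns std::nullopt when a reciprocal would divide by zero, mirroring the
// "return nullptr" guard in the diff.
std::optional<std::vector<float>> FoldReciprocal(const std::vector<float> &src) {
  std::vector<float> dst(src.size());
  std::memcpy(dst.data(), src.data(), src.size() * sizeof(float));
  for (float &v : dst) {
    if (v == 0.0f) {
      return std::nullopt;  // cannot fold 1/0; the caller keeps the original node
    }
    v = 1.0f / v;
  }
  return dst;
}

int main() {
  auto folded = FoldReciprocal({2.0f, 4.0f, 0.5f});
  if (folded) {
    for (float v : *folded) {
      std::cout << v << " ";  // prints: 0.5 0.25 2
    }
    std::cout << "\n";
  }
}
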
@@ -1005,6 +1024,14 @@ BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
       return rep; \
     } \
   }
+
+#define MATCH_REPLACE_LAMBDA_FLAG(OrigNode, CaptureNode, Lambda, Flag) \
+  if ((CaptureNode).TryCapture(OrigNode)) { \
+    auto rep = (Lambda)(Flag); \
+    if (rep != nullptr) { \
+      return rep; \
+    } \
+  }
 
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
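
The new MATCH_REPLACE_LAMBDA_FLAG follows the same capture-then-rewrite shape as the macro whose tail is visible just above, except that the replacement lambda is invoked with an extra flag argument, so one rewrite body can be parameterized per call site. A self-contained sketch of the expansion, not part of the diff (NodePtr, DummyPattern and Simplify are illustrative stand-ins, not MindSpore APIs):

#include <iostream>
#include <memory>
#include <string>

// Stand-in types for this illustration; in MindSpore the capture object is a
// pattern from pattern_matcher.h and the return type is AnfNodePtr.
using NodePtr = std::shared_ptr<std::string>;

struct DummyPattern {
  bool TryCapture(const NodePtr &node) const { return node != nullptr; }
};

// The macro added by this diff.
#define MATCH_REPLACE_LAMBDA_FLAG(OrigNode, CaptureNode, Lambda, Flag) \
  if ((CaptureNode).TryCapture(OrigNode)) { \
    auto rep = (Lambda)(Flag); \
    if (rep != nullptr) { \
      return rep; \
    } \
  }

NodePtr Simplify(const NodePtr &node, bool fold_enabled) {
  DummyPattern pattern;
  // Unlike a flag-less lambda, this one receives the extra argument, so the
  // same rewrite body can be enabled or disabled per call site.
  auto rewrite = [&](bool flag) -> NodePtr {
    return flag ? std::make_shared<std::string>("simplified(" + *node + ")") : nullptr;
  };
  MATCH_REPLACE_LAMBDA_FLAG(node, pattern, rewrite, fold_enabled);
  return node;  // no rewrite applied
}

int main() {
  auto n = std::make_shared<std::string>("reciprocal(x)");
  std::cout << *Simplify(n, true) << "\n";   // simplified(reciprocal(x))
  std::cout << *Simplify(n, false) << "\n";  // reciprocal(x)
}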

@@ -20,7 +20,8 @@ from mindspore import Tensor
 from mindspore.nn import Cell
 import mindspore.ops.operations as P
 
-context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
+context.set_context(mode=context.GRAPH_MODE,
+                    enable_graph_kernel=True, device_target="GPU")
 
 
 class Net(Cell):
@@ -33,6 +34,8 @@ class Net(Cell):
         self.sqrt = P.Sqrt()
         self.pow = P.Pow()
         self.neg = P.Neg()
+        self.reducemin = P.ReduceMin()
+        self.reshape = P.Reshape()
 
     def construct(self, x, y):
         add_res1 = self.add(x, 4)
@@ -42,7 +45,9 @@ class Net(Cell):
         div_res = self.div(mul_res, self.sqrt(mul_res))
         pow_res = self.pow(y, 2)
         neg_res = self.neg(self.neg(pow_res))
-        return self.add(div_res, neg_res)
+        add_res3 = self.add(neg_res, div_res)
+        resh_res = self.reshape(add_res3, (2, 12, 3))
+        return self.reducemin(resh_res, 1)
 
 
 @pytest.mark.level0
@@ -58,10 +63,12 @@ def test_basic():
     div_res = np.sqrt(mul_res)
     pow_res = input_y * input_y
     neg_res = pow_res
-    expect = div_res + neg_res
+    add_res3 = neg_res + div_res
+    expect = np.min(add_res3, (1, 2))
 
     net = Net()
     result = net(Tensor(input_x), Tensor(input_y))
-    res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
+    res = np.allclose(expect, result.asnumpy(), rtol=1.e-4,
+                      atol=1.e-7, equal_nan=True)
 
     assert res
