!5432 Mindspore parallel supports all elementary-wise operators

Merge pull request !5432 from yihuaijie/master
pull/5432/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit be606ba8f5

@ -90,6 +90,7 @@ REGISTER(TensorAddInfo);
REGISTER(BiasAddInfo); REGISTER(BiasAddInfo);
REGISTER(MulInfo); REGISTER(MulInfo);
REGISTER(DivInfo); REGISTER(DivInfo);
REGISTER(ModInfo);
REGISTER(RealDivInfo); REGISTER(RealDivInfo);
REGISTER(PowInfo); REGISTER(PowInfo);
REGISTER(ExpInfo); REGISTER(ExpInfo);
@ -117,15 +118,56 @@ REGISTER(MaximumInfo);
REGISTER(MinimumInfo); REGISTER(MinimumInfo);
REGISTER(CastInfo); REGISTER(CastInfo);
REGISTER(GreaterInfo); REGISTER(GreaterInfo);
REGISTER(GreaterEqualInfo);
REGISTER(LessEqualInfo);
REGISTER(LessInfo);
REGISTER(ApproximateEqualInfo);
REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo); REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
REGISTER(AssignSubInfo); REGISTER(AssignSubInfo);
REGISTER(FloorModInfo);
REGISTER(AssignInfo);
REGISTER(AssignAddInfo);
REGISTER(Atan2Info);
REGISTER(DivNoNanInfo);
REGISTER(LogicalAndInfo);
REGISTER(LogicalOrInfo);
REGISTER(EluInfo);
REGISTER(ReLUInfo); REGISTER(ReLUInfo);
REGISTER(ReLU6Info);
REGISTER(ReLUV2Info);
REGISTER(SoftplusInfo);
REGISTER(SoftsignInfo);
REGISTER(GatherV2Info); REGISTER(GatherV2Info);
REGISTER(SparseGatherV2Info); REGISTER(SparseGatherV2Info);
REGISTER(SqrtInfo); REGISTER(SqrtInfo);
REGISTER(SigmoidInfo); REGISTER(SigmoidInfo);
REGISTER(GetNextInfo); REGISTER(GetNextInfo);
REGISTER(NegInfo); REGISTER(NegInfo);
REGISTER(AbsInfo);
REGISTER(AcoshInfo);
REGISTER(AsinInfo);
REGISTER(AsinhInfo);
REGISTER(AtanInfo);
REGISTER(AtanhInfo);
REGISTER(CeilInfo);
REGISTER(CoshInfo);
REGISTER(Expm1Info);
REGISTER(Log1pInfo);
REGISTER(SinInfo);
REGISTER(SinhInfo);
REGISTER(TanInfo);
REGISTER(RsqrtInfo);
REGISTER(InvInfo);
REGISTER(ReciprocalInfo);
REGISTER(RoundInfo);
REGISTER(FloorInfo);
REGISTER(SignInfo);
REGISTER(ErfInfo);
REGISTER(ErfcInfo);
REGISTER(ZerosLikeInfo);
REGISTER(OnesLikeInfo);
REGISTER(BesselI0eInfo);
REGISTER(BesselI1eInfo);
REGISTER(BatchMatMulInfo); REGISTER(BatchMatMulInfo);
REGISTER(ExpandDimsInfo); REGISTER(ExpandDimsInfo);
REGISTER(SqueezeInfo); REGISTER(SqueezeInfo);

@ -131,6 +131,13 @@ class LogSoftmaxInfo : public Softmax {
~LogSoftmaxInfo() override = default; ~LogSoftmaxInfo() override = default;
}; };
class EluInfo : public ActivationOther {
public:
EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~EluInfo() override = default;
};
class ReLUInfo : public ActivationOther { class ReLUInfo : public ActivationOther {
public: public:
ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
@ -139,6 +146,38 @@ class ReLUInfo : public ActivationOther {
~ReLUInfo() override = default; ~ReLUInfo() override = default;
}; };
class ReLU6Info : public ActivationOther {
public:
ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ReLU6Info() override = default;
};
class ReLUV2Info : public ActivationOther {
public:
ReLUV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ReLUV2Info() override = default;
};
class SoftsignInfo : public ActivationOther {
public:
SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SoftsignInfo() override = default;
};
class SoftplusInfo : public ActivationOther {
public:
SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SoftplusInfo() override = default;
};
class CastInfo : public ActivationOther { class CastInfo : public ActivationOther {
public: public:
CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,

@ -82,6 +82,13 @@ class DivInfo : public ArithmeticBase {
~DivInfo() override = default; ~DivInfo() override = default;
}; };
class ModInfo : public ArithmeticBase {
public:
ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~ModInfo() override = default;
};
class RealDivInfo : public ArithmeticBase { class RealDivInfo : public ArithmeticBase {
public: public:
RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
@ -98,6 +105,14 @@ class FloorDivInfo : public ArithmeticBase {
~FloorDivInfo() override = default; ~FloorDivInfo() override = default;
}; };
class FloorModInfo : public ArithmeticBase {
public:
FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~FloorModInfo() override = default;
};
class PowInfo : public ArithmeticBase { class PowInfo : public ArithmeticBase {
public: public:
PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
@ -105,20 +120,28 @@ class PowInfo : public ArithmeticBase {
~PowInfo() override = default; ~PowInfo() override = default;
}; };
class GreaterInfo : public ArithmeticBase { class AssignSubInfo : public ArithmeticBase {
public:
AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~AssignSubInfo() override = default;
};
class AssignInfo : public ArithmeticBase {
public: public:
GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs) const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterInfo() override = default; ~AssignInfo() override = default;
}; };
class AssignSubInfo : public ArithmeticBase { class AssignAddInfo : public ArithmeticBase {
public: public:
AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs) const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~AssignSubInfo() override = default; ~AssignAddInfo() override = default;
}; };
// All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. // All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label.
@ -129,6 +152,38 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase {
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~SigmoidCrossEntropyWithLogitsInfo() override = default; ~SigmoidCrossEntropyWithLogitsInfo() override = default;
}; };
class Atan2Info : public ArithmeticBase {
public:
Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~Atan2Info() override = default;
};
class DivNoNanInfo : public ArithmeticBase {
public:
DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~DivNoNanInfo() override = default;
};
class LogicalAndInfo : public ArithmeticBase {
public:
LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LogicalAndInfo() override = default;
};
class LogicalOrInfo : public ArithmeticBase {
public:
LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LogicalOrInfo() override = default;
};
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

@ -36,6 +36,14 @@ class EqualInfo : public ArithmeticBase {
~EqualInfo() override = default; ~EqualInfo() override = default;
}; };
class ApproximateEqualInfo : public ArithmeticBase {
public:
ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~ApproximateEqualInfo() override = default;
};
class NotEqualInfo : public ArithmeticBase { class NotEqualInfo : public ArithmeticBase {
public: public:
NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
@ -59,6 +67,38 @@ class MinimumInfo : public ArithmeticBase {
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MinimumInfo() override = default; ~MinimumInfo() override = default;
}; };
class GreaterInfo : public ArithmeticBase {
public:
GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterInfo() override = default;
};
class GreaterEqualInfo : public ArithmeticBase {
public:
GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterEqualInfo() override = default;
};
class LessInfo : public ArithmeticBase {
public:
LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LessInfo() override = default;
};
class LessEqualInfo : public ArithmeticBase {
public:
LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LessEqualInfo() override = default;
};
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

@ -63,6 +63,202 @@ class LogicalNotInfo : public ActivationOther {
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {} : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~LogicalNotInfo() override = default; ~LogicalNotInfo() override = default;
}; };
class AbsInfo : public ActivationOther {
public:
AbsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AbsInfo() override = default;
};
class SignInfo : public ActivationOther {
public:
SignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SignInfo() override = default;
};
class FloorInfo : public ActivationOther {
public:
FloorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~FloorInfo() override = default;
};
class RoundInfo : public ActivationOther {
public:
RoundInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~RoundInfo() override = default;
};
class ReciprocalInfo : public ActivationOther {
public:
ReciprocalInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ReciprocalInfo() override = default;
};
class InvInfo : public ActivationOther {
public:
InvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~InvInfo() override = default;
};
class RsqrtInfo : public ActivationOther {
public:
RsqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~RsqrtInfo() override = default;
};
class TanInfo : public ActivationOther {
public:
TanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~TanInfo() override = default;
};
class SinInfo : public ActivationOther {
public:
SinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SinInfo() override = default;
};
class SinhInfo : public ActivationOther {
public:
SinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SinhInfo() override = default;
};
class Log1pInfo : public ActivationOther {
public:
Log1pInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~Log1pInfo() override = default;
};
class Expm1Info : public ActivationOther {
public:
Expm1Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~Expm1Info() override = default;
};
class CoshInfo : public ActivationOther {
public:
CoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~CoshInfo() override = default;
};
class CeilInfo : public ActivationOther {
public:
CeilInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~CeilInfo() override = default;
};
class AtanhInfo : public ActivationOther {
public:
AtanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AtanhInfo() override = default;
};
class AtanInfo : public ActivationOther {
public:
AtanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AtanInfo() override = default;
};
class AsinInfo : public ActivationOther {
public:
AsinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AsinInfo() override = default;
};
class AsinhInfo : public ActivationOther {
public:
AsinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AsinhInfo() override = default;
};
class AcoshInfo : public ActivationOther {
public:
AcoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AcoshInfo() override = default;
};
class ErfInfo : public ActivationOther {
public:
ErfInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ErfInfo() override = default;
};
class ErfcInfo : public ActivationOther {
public:
ErfcInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ErfcInfo() override = default;
};
class ZerosLikeInfo : public ActivationOther {
public:
ZerosLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ZerosLikeInfo() override = default;
};
class OnesLikeInfo : public ActivationOther {
public:
OnesLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~OnesLikeInfo() override = default;
};
class BesselI0eInfo : public ActivationOther {
public:
BesselI0eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~BesselI0eInfo() override = default;
};
class BesselI1eInfo : public ActivationOther {
public:
BesselI1eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~BesselI1eInfo() override = default;
};
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -98,6 +98,126 @@ def test_matmul_not_equal():
compile_net(net, x, y, b) compile_net(net, x, y, b)
def test_matmul_approximateEqual():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.approximateEqual = P.ApproximateEqual(tolerance=0.5).set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.approximateEqual(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_greater():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.greater = P.Greater().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.greater(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_greaterEqual():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.greaterEqual = P.GreaterEqual().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.greaterEqual(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_less():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.less = P.Less().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.less(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_lessEqual():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.lessEqual = P.LessEqual().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.lessEqual(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_not_equal_repeated_calculation(): def test_matmul_not_equal_repeated_calculation():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, strategy1, strategy2): def __init__(self, strategy1, strategy2):

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save