Support a tensor as Pow's second input and fix a bug in the bprop of Pow

pull/188/head
buxue 5 years ago
parent 7cec28526a
commit 5841fe010e

@@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
   ~FloorDivInfo() override = default;
 };
+class PowInfo : public ArithmeticBase {
+ public:
+  PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+  ~PowInfo() override = default;
+};
 class GreaterInfo : public ArithmeticBase {
  public:
   GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,

@@ -1,47 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "parallel/ops_info/elementary_function_info.h"
-namespace mindspore {
-namespace parallel {
-Status PowInfo::InferMirrorOps() {
-  mirror_ops_.clear();
-  Shape tensor_map = inputs_tensor_map_[0];
-  std::vector<Group> group;
-  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Create group failed.";
-    return FAILED;
-  }
-  OperatorVector mirror_op;
-  OperatorVector op_for_value;
-  if (group.empty()) {
-    MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
-    return SUCCESS;
-  } else {
-    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
-    mirror_ops_.push_back(mirror_op);
-    mirror_ops_.push_back(op_for_value);
-    std::string group_name = group[0].name();
-    MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
-  }
-  return SUCCESS;
-}
-}  // namespace parallel
-}  // namespace mindspore

@@ -27,16 +27,6 @@
 namespace mindspore {
 namespace parallel {
-class PowInfo : public ActivationOther {
- public:
-  PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
-  ~PowInfo() override = default;
- protected:
-  Status InferMirrorOps() override;
-};
 class ExpInfo : public ActivationOther {
  public:
   ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)

@@ -58,7 +58,7 @@ class _PoolNd(Cell):
         pass
     def extend_repr(self):
-        return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__)
+        return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
 class MaxPool2d(_PoolNd):

@@ -336,14 +336,13 @@ def get_bprop_log(self):
 @bprop_getters.register(P.Pow)
 def get_bprop_pow(self):
     """Grad definition for `Pow` operation."""
-    pow_ = P.Pow()
-    cast = P.Cast()
-    dtype = P.DType()
+    pow_op = P.Pow()
+    ln = P.Log()
     def bprop(x, power, out, dout):
-        g = cast(F.tuple_to_array((power,)), dtype(x)) * pow_(x, power-1.0)
-        dx = g * dout
-        return dx, 0
+        dx = power * pow_op(x, power - 1.0) * dout
+        dpower = pow_op(x, power) * ln(x) * dout
+        return dx, dpower
     return bprop
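For out = x ** power, calculus gives dx = power * x**(power - 1) * dout and dpower = x**power * ln(x) * dout; the old bprop dropped the exponent gradient entirely, returning 0 for it. A quick finite-difference check of both formulas in plain NumPy (no MindSpore required; note x > 0, since the dpower term involves ln(x)):

```python
import numpy as np

# Central-difference check of the two gradient formulas used in the new bprop.
x, power, eps = 3.0, 2.5, 1e-6
dx_analytic = power * x ** (power - 1.0)      # d(x**p)/dx
dpower_analytic = x ** power * np.log(x)      # d(x**p)/dp
dx_numeric = ((x + eps) ** power - (x - eps) ** power) / (2 * eps)
dpower_numeric = (x ** (power + eps) - x ** (power - eps)) / (2 * eps)
assert np.isclose(dx_analytic, dx_numeric)
assert np.isclose(dpower_analytic, dpower_numeric)
```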

@@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
         axis = self.axis
         x_rank = len(x_shape)
         validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
-        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
+        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
     def infer_dtype(self, x_dtype):
@@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
         axis = self.axis
         x_rank = len(x_shape)
         validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
-        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
+        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
     def infer_dtype(self, x_dtype):

(File diff suppressed because it is too large.)

@@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
         Primitive.__init__(self, name)
         self.set_prim_type(prim_type.py_infer_shape)
-    def prim_name(self):
-        return self.__class__.__name__
     def _clone(self):
         """
         Deeply clones the primitive object.
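The removed prim_name() helper only echoed the class name; Primitive already records the operator name it was constructed with, so the infer methods above can pass self.name instead. A minimal sketch, assuming a MindSpore environment:

```python
from mindspore.ops import operations as P

# Primitive instances already carry their operator name.
print(P.Pow().name)  # 'Pow' -- the value the infer methods now pass to helpers
```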

@@ -19,7 +19,7 @@
 #include <vector>
 #include "common/common_test.h"
 #include "parallel/strategy.h"
-#include "parallel/ops_info/elementary_function_info.h"
+#include "parallel/ops_info/arithmetic_info.h"
 #include "parallel/device_manager.h"
 #include "parallel/step_parallel.h"
@@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
   std::unordered_map<std::string, ValuePtr> attr;
-  Shapes inputs_shape = {{32, 64, 128}};
+  Shapes inputs_shape = {{32, 64, 128}, {32, 64, 128}};
   Shapes outputs_shape = {{32, 64, 128}};
   pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr);
 }
 TEST_F(TestPowInfo, InferDevMatrixShape1) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   pow->Init(strategy);
@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
 }
 TEST_F(TestPowInfo, InferSliceShape1) {
-  std::vector<Dimensions> str = {{2, 4, 8}};
+  std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);
   pow->Init(strategy);
@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
 }
 TEST_F(TestPowInfo, GetTensorLayout1) {
-  std::vector<Dimensions> str = {{2, 4, 8}};
+  std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);
   pow->Init(strategy);
@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
 }
 TEST_F(TestPowInfo, GetForwardOp1) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   pow->Init(strategy);
@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
 }
 TEST_F(TestPowInfo, GetMirrorOPs1) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   pow->Init(strategy);
@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
 }
 TEST_F(TestPowInfo, CheckStrategy2) {
-  std::vector<Dimensions> inputs = {{2, 4, 8, 16}};
+  std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = pow->Init(strategy);
@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
 }
 TEST_F(TestPowInfo, CheckStrategy3) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
   Status ret = pow->Init(strategy);

@@ -82,9 +82,10 @@ def test_sqrt():
 def test_pow():
     """ test_pow """
     input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
+    power = Tensor(np.array(3.0, np.int64))
     testpow = P.Pow()
     expect = np.array([[8, 8], [27, 27]])
-    result = testpow(input_tensor, 3.0)
+    result = testpow(input_tensor, power)
     assert np.all(result.asnumpy() == expect)
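Since the exponent may now be a tensor, it can also vary per element rather than being a single scalar. A sketch of an elementwise variant of the same check, assuming a runnable MindSpore install:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

pow_op = P.Pow()
x = Tensor(np.array([[2.0, 2.0], [3.0, 3.0]], np.float32))
exp = Tensor(np.array([[3.0, 1.0], [2.0, 0.0]], np.float32))  # one exponent per element
assert np.all(pow_op(x, exp).asnumpy() == np.array([[8.0, 2.0], [9.0, 1.0]]))
```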

@@ -224,11 +224,15 @@ test_case_math_ops = [
         'block': P.Minimum(),
         'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
         'desc_bprop': [[2, 3, 3, 5]]}),
-    ('Pow', {
+    ('Pow_0', {
         'block': P.Pow(),
         'desc_const': [2.0],
        'desc_inputs': [[2, 3, 3, 5]],
         'desc_bprop': [[2, 3, 3, 5]]}),
+    ('Pow_1', {
+        'block': P.Pow(),
+        'desc_inputs': [[3, 5], [2, 3, 3, 5]],
+        'desc_bprop': [[2, 3, 3, 5]]}),
     ('Exp', {
         'block': P.Exp(),
         'desc_inputs': [[2, 3]],
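The new Pow_1 case feeds inputs of shapes [3, 5] and [2, 3, 3, 5], exercising NumPy-style broadcasting between base and exponent. The expected shape behavior, shown with plain NumPy:

```python
import numpy as np

x = np.ones((3, 5))
power = np.full((2, 3, 3, 5), 2.0)
out = np.power(x, power)  # base broadcast against the higher-rank exponent
print(out.shape)          # (2, 3, 3, 5), matching 'desc_bprop' above
```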

@@ -59,7 +59,7 @@ def test_matmul_pow():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     strategy1 = ((2, 2), (2, 2))
-    strategy2 = ((4, 2), )
+    strategy2 = ((4, 2), ())
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
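With two modeled inputs, a semi-auto-parallel strategy for Pow now needs one tuple per input; a scalar exponent gets the empty tuple (), as in strategy2 above. A hedged sketch of how such a strategy is attached to the primitive, assuming the set_strategy API these parallel tests build on:

```python
from mindspore.ops import operations as P

# Shard the base over a 4x2 device arrangement; the scalar exponent stays unsharded.
pow_op = P.Pow().set_strategy(((4, 2), ()))
```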

@@ -117,6 +117,7 @@ def vm_impl_pow(self):
     """Generate vm_impl function for Pow."""
     def vm_impl(x, y):
         x = x.asnumpy()
+        y = y.asnumpy()
         res = vm.power(x, y)
         return Tensor(res)
     return vm_impl
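Converting y with asnumpy() as well lets the NumPy-backed vm.power handle scalar-shaped and full-shaped exponents through the same path, since np.power broadcasts both. The underlying NumPy behavior:

```python
import numpy as np

x = np.array([[2.0, 2.0], [3.0, 3.0]])
print(np.power(x, np.array(3.0)))                        # scalar-shaped exponent
print(np.power(x, np.array([[1.0, 2.0], [3.0, 0.5]])))   # elementwise exponents
```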
