parent
d4142d682d
commit
bf5d21770a
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
constexpr size_t UNSORTEDSEGMENTOP_INPUTS_SIZE = 2;
|
||||
constexpr size_t UNSORTEDSEGMENTOP_OUTPUTS_SIZE = 1;
|
||||
class UnsortedSegmentOpInfo : public OperatorInfo {
|
||||
public:
|
||||
UnsortedSegmentOpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
|
||||
~UnsortedSegmentOpInfo() override = default;
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
|
||||
Status GenerateStrategies(int32_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferForwardCommunication() override;
|
||||
Status InferMirrorOps() override { return SUCCESS; }
|
||||
Status InferTensorInfo() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status GetAttrs() override;
|
||||
|
||||
private:
|
||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
};
|
||||
|
||||
class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo {
|
||||
public:
|
||||
UnsortedSegmentSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) {}
|
||||
~UnsortedSegmentSumInfo() override = default;
|
||||
};
|
||||
|
||||
class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
|
||||
public:
|
||||
UnsortedSegmentMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
|
||||
~UnsortedSegmentMinInfo() override = default;
|
||||
|
||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
|
||||
protected:
|
||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
};
|
||||
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_
|
@ -0,0 +1,71 @@
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore.ops as P
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, vectors, index):
|
||||
predict = self.network(vectors, index)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, vectors, index):
|
||||
return grad_all(self.network)(vectors, index)
|
||||
|
||||
|
||||
def test_auto_parallel_unsortedsegmentmin():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, num_segments):
|
||||
super().__init__()
|
||||
self.merge_op = P.UnsortedSegmentMin()
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, vectors, index):
|
||||
out = self.merge_op(vectors, index, self.num_segments)
|
||||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
|
||||
x = Tensor(np.random.rand(16, 16, 32, 64), dtype=ms.float32)
|
||||
indices = Tensor(np.random.randint(16, size=(16,)), ms.int32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net(16)))
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
_executor.compile(net, x, indices)
|
@ -0,0 +1,71 @@
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore.ops as P
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, vectors, index):
|
||||
predict = self.network(vectors, index)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, vectors, index):
|
||||
return grad_all(self.network)(vectors, index)
|
||||
|
||||
|
||||
def test_auto_parallel_unsortedsegmentsum():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, num_segments):
|
||||
super().__init__()
|
||||
self.merge_op = P.UnsortedSegmentSum()
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, vectors, index):
|
||||
out = self.merge_op(vectors, index, self.num_segments)
|
||||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
|
||||
x = Tensor(np.random.rand(16, 16, 32, 64), dtype=ms.float32)
|
||||
indices = Tensor(np.random.randint(16, size=(16, 16)))
|
||||
|
||||
net = GradWrap(NetWithLoss(Net(16)))
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
_executor.compile(net, x, indices)
|
@ -0,0 +1,161 @@
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2, num_segments):
|
||||
super(Net, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.merge_op = P.UnsortedSegmentMin().shard((strategy1, strategy2))
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, vectors, segment_ids):
|
||||
predict = self.merge_op(vectors, segment_ids, self.num_segments)
|
||||
return predict
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
return grad_all(self.network)(x, y)
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
predict = self.network(x, y)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
if auto:
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
else:
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_slice_1d():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
x = Tensor(np.ones(8), ms.float32)
|
||||
y = Tensor(np.ones(8), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (8,)
|
||||
strategy2 = (8,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_no_slice_1d():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
x = Tensor(np.ones(8), ms.float32)
|
||||
y = Tensor(np.ones(8), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (1,)
|
||||
strategy2 = (1,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_index_slice_2d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||
y = Tensor(np.arange(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (4, 1)
|
||||
strategy2 = (4,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_vector_slice_2d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||
y = Tensor(np.ones(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (1, 4)
|
||||
strategy2 = (1,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_vector_slice_3d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8, 8)), ms.float32)
|
||||
y = Tensor(np.ones(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (1, 2, 2)
|
||||
strategy2 = (1,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_index_vector_slice_2d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||
y = Tensor(np.ones(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (2, 2)
|
||||
strategy2 = (2,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_index_vector_slice_3d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 4, 8)), ms.float32)
|
||||
y = Tensor(np.ones((4)), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (2, 1, 2)
|
||||
strategy2 = (2,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_float16():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 4, 8)), ms.float16)
|
||||
y = Tensor(np.ones((4)), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (2, 1, 2)
|
||||
strategy2 = (2,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
def test_unsortedsegmentmin_model_parallel_int32():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 4, 8)), ms.int32)
|
||||
y = Tensor(np.ones((4)), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (2, 1, 2)
|
||||
strategy2 = (2,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
@ -0,0 +1,153 @@
|
||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2, num_segments):
|
||||
super(Net, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2))
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, vectors, segment_ids):
|
||||
predict = self.merge_op(vectors, segment_ids, self.num_segments)
|
||||
return predict
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
return grad_all(self.network)(x, y)
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
predict = self.network(x, y)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
if auto:
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
else:
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_slice_1d():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
x = Tensor(np.ones(8), ms.float32)
|
||||
y = Tensor(np.ones(8), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (8,)
|
||||
strategy2 = (8,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_no_slice_1d():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
x = Tensor(np.ones(8), ms.float32)
|
||||
y = Tensor(np.ones(8), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (1,)
|
||||
strategy2 = (1,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_index_slice_2d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||
y = Tensor(np.arange(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (4, 1)
|
||||
strategy2 = (4,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_index_slice_3d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 4, 8)), ms.float32)
|
||||
y = Tensor(np.ones((4, 4)), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (2, 2, 1)
|
||||
strategy2 = (2, 2)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_vector_slice_2d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||
y = Tensor(np.ones(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (1, 4)
|
||||
strategy2 = (1,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_vector_slice_3d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8, 8)), ms.float32)
|
||||
y = Tensor(np.ones(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (1, 2, 2)
|
||||
strategy2 = (1,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_index_vector_slice_2d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 8)), ms.float32)
|
||||
y = Tensor(np.ones(4), ms.int32)
|
||||
num_segments = 4
|
||||
strategy1 = (2, 2)
|
||||
strategy2 = (2,)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 4, 8)), ms.float32)
|
||||
y = Tensor(np.ones((4, 4)), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (2, 1, 2)
|
||||
strategy2 = (2, 1)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
Loading…
Reference in new issue