delete useless code for allreduce

pull/3288/head
yangzhenzhang 5 years ago
parent b13c7a3d48
commit e6cef98e95

@ -100,11 +100,6 @@ class AllReduce(PrimitiveWithInfer):
self.add_prim_attr('fusion', 0) self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0) self.add_prim_attr('index', 0)
def vm_impl(self, x):
"""Implement by vm mode."""
x = x.asnumpy()
return Tensor(x)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape

@ -294,9 +294,6 @@ TEST_F(TestStepParallel, CreatOpInstance) {
ASSERT_TRUE(allreduce_ptr); ASSERT_TRUE(allreduce_ptr);
if (nullptr != allreduce_ptr) { if (nullptr != allreduce_ptr) {
MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name(); MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
if (!allreduce_ptr->HasComputeFunction()) {
MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented";
}
std::vector<py::object> arglist; std::vector<py::object> arglist;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist), (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arglist),

Loading…
Cancel
Save