@@ -1230,7 +1230,11 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
                 << MakeValue(slice_shape)->ToString();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  // Don't modify it in-place, as the pointer of this AbstractValue may be used as a cache key in StaticAnalysis.
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  parameter->set_abstract(cloned_abstract);
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
@@ -1330,7 +1334,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
       cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
       MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
       MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
-      cloned_parameter_node->abstract()->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
+      MS_EXCEPTION_IF_NULL(cloned_abstract);
+      cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      cloned_parameter_node->set_abstract(cloned_abstract);
       MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                    << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
                    << ", clone index is: " << cloned_index;
@@ -1743,7 +1750,10 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
     auto slice_shape = loss_grad_layout.slice_shape().array();
     std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
     MS_EXCEPTION_IF_NULL(parallel_shape);
-    abstract->set_shape(parallel_shape);
+    auto cloned_abstract = abstract->Clone();
+    MS_EXCEPTION_IF_NULL(cloned_abstract);
+    cloned_abstract->set_shape(parallel_shape);
+    sens_tensor_node->set_abstract(cloned_abstract);
     auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
     sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
     return;
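All three hunks apply the same pattern: clone the node's abstract, set the sliced (parallel) shape on the clone, and attach the clone back, instead of calling set_shape() on the abstract that is already shared. A minimal standalone sketch of that pattern follows; Shape, Abstract, and Parameter here are hypothetical simplified stand-ins, not the real MindSpore classes, and the sketch only illustrates why in-place mutation is unsafe when another holder (for example, a static-analysis cache) still references the original abstract.

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical simplified stand-ins for the abstract/shape classes touched by the diff.
struct Shape {
  std::vector<int64_t> dims;
};

struct Abstract {
  std::shared_ptr<Shape> shape;
  // Clone() returns an independent copy, so edits to the clone cannot leak to other holders.
  std::shared_ptr<Abstract> Clone() const { return std::make_shared<Abstract>(*this); }
  void set_shape(const std::shared_ptr<Shape> &s) { shape = s; }
};

struct Parameter {
  std::shared_ptr<Abstract> abstract;
  void set_abstract(const std::shared_ptr<Abstract> &a) { abstract = a; }
};

int main() {
  auto original = std::make_shared<Abstract>();
  original->set_shape(std::make_shared<Shape>(Shape{{8, 16}}));
  Parameter parameter{original};

  // Another component (e.g. an analysis cache) still holds the original abstract.
  std::shared_ptr<Abstract> cached = original;

  // Clone-then-set, mirroring the diff: only the parameter sees the sliced shape.
  auto cloned_abstract = parameter.abstract->Clone();
  cloned_abstract->set_shape(std::make_shared<Shape>(Shape{{4, 16}}));  // sliced along dim 0
  parameter.set_abstract(cloned_abstract);

  std::cout << "cached dim 0:    " << cached->shape->dims[0] << "\n";              // still 8
  std::cout << "parameter dim 0: " << parameter.abstract->shape->dims[0] << "\n";  // now 4
  return 0;
}
```

In the sketch, the clone shares the original Shape object until set_shape swaps the pointer on the clone alone, so the cached holder keeps seeing the unsliced shape while the parameter carries the sliced one.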