|
|
@ -58,7 +58,11 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens
|
|
|
|
MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
|
|
|
|
MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
|
|
|
|
return Status::SUCCESS;
|
|
|
|
return Status::SUCCESS;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
|
|
|
|
if (layout_transfer_) {
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "invalid origin tensor layout " << this->OriginToString();
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
|
|
|
|
|
|
|
|
}
|
|
|
|
return Status::FAILED;
|
|
|
|
return Status::FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -90,7 +94,11 @@ bool TensorLayout::IsValidTensorLayout() const {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
|
|
|
|
if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
|
|
|
|
MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
|
|
|
|
if (layout_transfer_) {
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
|
|
|
|
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
@ -214,6 +222,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDevice
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
TensorLayout tensor_layout_new;
|
|
|
|
TensorLayout tensor_layout_new;
|
|
|
|
|
|
|
|
tensor_layout_new.set_layout_transfer(true);
|
|
|
|
Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
|
|
|
|
Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
|
|
|
|
if (status != Status::SUCCESS) {
|
|
|
|
if (status != Status::SUCCESS) {
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
@ -391,9 +400,9 @@ TensorLayout TensorLayout::SqueezeShape() const {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TensorLayout TensorLayout::TransferRepeatLayout() const {
|
|
|
|
TensorLayout TensorLayout::TransferRepeatLayout() const {
|
|
|
|
Shape dev_mat(device_arrangement_.array());
|
|
|
|
Shape dev_mat(device_arrangement_origin_.array());
|
|
|
|
Shape tensor_map(tensor_map_.GetDimSize(), -1);
|
|
|
|
Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
|
|
|
|
Shape tensor_shape(tensor_shape_.array());
|
|
|
|
Shape tensor_shape(tensor_shape_origin_.array());
|
|
|
|
TensorLayout repeat;
|
|
|
|
TensorLayout repeat;
|
|
|
|
repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
|
|
|
|
repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
|
|
|
|
return repeat;
|
|
|
|
return repeat;
|
|
|
|