|
|
|
@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_Transpose;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
int TransposeCPUKernel::Init() {
|
|
|
|
|
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
|
|
|
|
|
num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H]));
|
|
|
|
|
thread_h_num_ = MSMIN(thread_num_, num_unit_);
|
|
|
|
|
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -41,9 +37,13 @@ int TransposeCPUKernel::Init() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int TransposeCPUKernel::ReSize() {
|
|
|
|
|
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_);
|
|
|
|
|
num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H]));
|
|
|
|
|
thread_h_num_ = MSMIN(thread_num_, num_unit_);
|
|
|
|
|
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
|
|
|
|
|
|
|
|
|
|
auto &inTensor = in_tensors_.front();
|
|
|
|
|
auto &outTensor = out_tensors_.front();
|
|
|
|
|
auto param = reinterpret_cast<TransposeParameter *>(op_parameter_);
|
|
|
|
|
auto in_shape = inTensor->shape();
|
|
|
|
|
auto out_shape = outTensor->shape();
|
|
|
|
|
param->strides_[param->num_axes_ - 1] = 1;
|
|
|
|
|