fix gpu shape bug

pull/14118/head
yeyunpeng2020 4 years ago
parent 6e997ad3fc
commit 12ea873ee0

@ -55,13 +55,12 @@ int FillOpenCLKernel::RunShape() {
auto tensor_shape = in_tensors_[0]->shape();
void *tensor_shape_data = tensor_shape.data();
for (int i = 0; i < tensor_shape.size(); ++i) {
fill_value.s[0] = reinterpret_cast<float *>(tensor_shape_data)[i];
size_t index = static_cast<size_t>(i);
auto src_origin = cl::array<cl::size_type, 3U>{0, index, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
fill_value.s[i] = reinterpret_cast<float *>(tensor_shape_data)[i];
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK;
}

@ -23,3 +23,5 @@ landmark
PoseNet_dla_17_x512
age_new
plat_isface
Q_hand_0812.pb
Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid.pb

@ -190,6 +190,7 @@ generateOpsList
getCommonFile
# get src/ops
getOpsFile "Registry\(schema::PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/ops" "prototype" &
getOpsFile "REG_POPULATE\(PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/ops" "prototype" &
getOpsFile "REG_INFER\(.*?, PrimType_" "${MINDSPORE_HOME}/mindspore/lite/nnacl/infer" "prototype" &
getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat32, PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/runtime/kernel/arm" "kNumberTypeFloat32" &
getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat16, PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/runtime/kernel/arm" "kNumberTypeFloat16" &

Loading…
Cancel
Save