|
|
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <thrust/device_vector.h>
|
|
|
|
|
#include "paddle/fluid/framework/tensor_util.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
#include "paddle/fluid/operators/slice_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/cuda_device_function.h"
|
|
|
|
|
#include "paddle/fluid/platform/cuda_primitives.h"
|
|
|
|
|
#include "paddle/fluid/platform/float16.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -94,17 +94,22 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext,
|
|
|
|
|
dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1);
|
|
|
|
|
dim3 threads(PADDLE_CUDA_NUM_THREADS);
|
|
|
|
|
auto stream = ctx.cuda_device_context().stream();
|
|
|
|
|
|
|
|
|
|
auto out_shape = framework::vectorize<int64_t>(out_dims);
|
|
|
|
|
thrust::device_vector<int64_t> out_dims_vec(out_shape.begin(),
|
|
|
|
|
out_shape.end());
|
|
|
|
|
auto in_shape = framework::vectorize<int64_t>(in_dims);
|
|
|
|
|
thrust::device_vector<int64_t> in_dims_vec(in_shape.begin(),
|
|
|
|
|
in_shape.end());
|
|
|
|
|
thrust::device_vector<int64_t> offsets_vec(offsets.begin(), offsets.end());
|
|
|
|
|
const int64_t* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
|
|
|
|
|
const int64_t* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
|
|
|
|
|
const int64_t* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());
|
|
|
|
|
const std::vector<int64_t> out_shape =
|
|
|
|
|
framework::vectorize<int64_t>(out_dims);
|
|
|
|
|
const std::vector<int64_t> in_shape =
|
|
|
|
|
framework::vectorize<int64_t>(in_dims);
|
|
|
|
|
|
|
|
|
|
framework::Tensor out_dims_tensor;
|
|
|
|
|
framework::Tensor in_dims_tensor;
|
|
|
|
|
framework::Tensor offsets_tensor;
|
|
|
|
|
framework::TensorFromVector(out_shape, ctx.device_context(),
|
|
|
|
|
&out_dims_tensor);
|
|
|
|
|
framework::TensorFromVector(in_shape, ctx.device_context(),
|
|
|
|
|
&in_dims_tensor);
|
|
|
|
|
framework::TensorFromVector(offsets, ctx.device_context(), &offsets_tensor);
|
|
|
|
|
const int64_t* out_dims_ptr = out_dims_tensor.data<int64_t>();
|
|
|
|
|
const int64_t* in_dims_ptr = in_dims_tensor.data<int64_t>();
|
|
|
|
|
const int64_t* offsets_ptr = offsets_tensor.data<int64_t>();
|
|
|
|
|
|
|
|
|
|
switch (rank) {
|
|
|
|
|
case 1:
|
|
|
|
|