fix errors in sequence_slice_op

scopeFix
fengjiayi 7 years ago
parent baa9f50da5
commit bf99396a04

@ -66,13 +66,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(),
&offset_cpu);
framework::TensorCopySync(*offset, platform::CPUPlace(), &offset_cpu);
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(),
&length_cpu);
framework::TensorCopySync(*length, platform::CPUPlace(), &length_cpu);
length_data = length_cpu.data<int64_t>();
}
@ -127,13 +125,11 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(),
&offset_cpu);
framework::TensorCopySync(*offset, platform::CPUPlace(), &offset_cpu);
offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(),
&length_cpu);
framework::TensorCopySync(*length, platform::CPUPlace(), &length_cpu);
length_data = length_cpu.data<int64_t>();
}

Loading…
Cancel
Save