register fp16 kernel, test=develop (#25630)

fix_copy_if_different
Zhang Ting 5 years ago committed by GitHub
parent c4192a8030
commit a1350744eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,8 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/shape_op.h" #include "paddle/fluid/operators/shape_op.h"
REGISTER_OP_CUDA_KERNEL(shape, paddle::operators::ShapeKernel<int>, REGISTER_OP_CUDA_KERNEL(
paddle::operators::ShapeKernel<int32_t>, shape, paddle::operators::ShapeKernel<int>,
paddle::operators::ShapeKernel<int64_t>, paddle::operators::ShapeKernel<int32_t>,
paddle::operators::ShapeKernel<float>, paddle::operators::ShapeKernel<int64_t>,
paddle::operators::ShapeKernel<double>); paddle::operators::ShapeKernel<float>,
paddle::operators::ShapeKernel<double>,
paddle::operators::ShapeKernel<paddle::platform::float16>);

@ -11101,7 +11101,7 @@ def shape(input):
input.shape = [3, 2] input.shape = [3, 2]
Args: Args:
input (Variable): The input can be N-D Tensor or SelectedRows with data type float32, float64, int32, int64. input (Variable): The input can be N-D Tensor or SelectedRows with data type float16, float32, float64, int32, int64.
If input variable is type of SelectedRows, returns the shape of it's inner tensor. If input variable is type of SelectedRows, returns the shape of it's inner tensor.
Returns: Returns:
@ -11124,8 +11124,9 @@ def shape(input):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([ 3, 100, 100], dtype=int32)] print(res) # [array([ 3, 100, 100], dtype=int32)]
""" """
check_variable_and_dtype(input, 'input', check_variable_and_dtype(
['float32', 'float64', 'int32', 'int64'], 'shape') input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
'shape')
helper = LayerHelper('shape', **locals()) helper = LayerHelper('shape', **locals())
out = helper.create_variable_for_type_inference(dtype='int32') out = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op( helper.append_op(

Loading…
Cancel
Save