Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into port_pybind11

revert-12469-sum_op_dim_fix
minqiyang 7 years ago
commit e398348e7e

2
.gitignore vendored

@ -5,6 +5,7 @@ python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/
python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/ python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/
python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/ python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/
*.DS_Store *.DS_Store
*.vs
build/ build/
build_doc/ build_doc/
*.user *.user
@ -15,6 +16,7 @@ build_doc/
.cproject .cproject
.pydevproject .pydevproject
.settings/ .settings/
CMakeSettings.json
Makefile Makefile
.test_env/ .test_env/
third_party/ third_party/

@ -204,11 +204,12 @@ include(external/snappy) # download snappy
include(external/snappystream) include(external/snappystream)
include(external/threadpool) include(external/threadpool)
set(WITH_ANAKIN OFF CACHE STRING "Disable Anakin first, will add it later." FORCE)
if(WITH_GPU) if(WITH_GPU)
include(cuda) include(cuda)
include(tensorrt) include(tensorrt)
include(external/anakin) include(external/anakin)
elseif()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in GPU only now." FORCE)
endif() endif()
include(cudnn) # set cudnn libraries, must before configure include(cudnn) # set cudnn libraries, must before configure

@ -56,6 +56,10 @@ if(NOT CMAKE_CROSSCOMPILING)
set(SIMD_FLAG ${SSE3_FLAG}) set(SIMD_FLAG ${SSE3_FLAG})
endif() endif()
endif() endif()
if(UNIX AND NOT APPLE)
# except apple from nix*Os family
set(LINUX TRUE)
endif(UNIX AND NOT APPLE)
if(NOT WITH_GOLANG) if(NOT WITH_GOLANG)
add_definitions(-DPADDLE_WITHOUT_GOLANG) add_definitions(-DPADDLE_WITHOUT_GOLANG)
@ -104,6 +108,10 @@ if(WITH_GPU)
if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
message(FATAL_ERROR "Anakin needs CUDNN >= 7.0 to compile") message(FATAL_ERROR "Anakin needs CUDNN >= 7.0 to compile")
endif() endif()
set(ENV{CUDNN_INCLUDE_DIR} ${CUDNN_INCLUDE_DIR})
set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY})
message(STATUS "cudnn include header is ${CUDNN_INCLUDE_DIR}/cudnn.h")
message(STATUS "cudnn library is ${CUDNN_LIBRARY}")
endif() endif()
elseif(WITH_AMD_GPU) elseif(WITH_AMD_GPU)
add_definitions(-DPADDLE_WITH_HIP) add_definitions(-DPADDLE_WITH_HIP)

@ -35,9 +35,8 @@ set(ANAKIN_COMPILE_EXTRA_FLAGS
ExternalProject_Add( ExternalProject_Add(
extern_anakin extern_anakin
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(luotao): use PaddlePaddle/Anakin later GIT_REPOSITORY "https://github.com/PaddlePaddle/Anakin"
GIT_REPOSITORY "https://github.com/luotao1/Anakin" GIT_TAG "04256ba78fa3da0beb74e8036c8efd68c12824d6"
GIT_TAG "3957ae9263eaa0b1986758dac60a88852afb09be"
PREFIX ${ANAKIN_SOURCE_DIR} PREFIX ${ANAKIN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DUSE_GPU_PLACE=YES CMAKE_ARGS -DUSE_GPU_PLACE=YES

@ -155,10 +155,11 @@ paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale',
paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))

@ -128,7 +128,8 @@ struct ExtractAttribute {
attr_value = &boost::get<T>(attr); attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, typeid(T).name(), attr.type().name()); attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
} }
return attr_value; return attr_value;
} }
@ -160,7 +161,7 @@ struct ExtractAttribute<bool> {
attr_value = &boost::get<bool>(attr); attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, attr.type().name()); attr_name_, paddle::platform::demangle(attr.type().name()));
} }
return attr_value; return attr_value;
} }
@ -186,7 +187,7 @@ struct ExtractAttribute<int64_t> {
attr_value = &boost::get<int64_t>(attr); attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, attr.type().name()); attr_name_, paddle::platform::demangle(attr.type().name()));
} }
return attr_value; return attr_value;
} }

@ -20,6 +20,9 @@
DEFINE_int32(io_threadpool_size, 100, DEFINE_int32(io_threadpool_size, 100,
"number of threads used for doing IO, default 100"); "number of threads used for doing IO, default 100");
DEFINE_int32(dist_threadpool_size, 0,
"number of threads used for distributed executed.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -35,6 +38,10 @@ void ThreadPool::Init() {
if (threadpool_.get() == nullptr) { if (threadpool_.get() == nullptr) {
// TODO(Yancey1989): specify the max threads number // TODO(Yancey1989): specify the max threads number
int num_threads = std::thread::hardware_concurrency(); int num_threads = std::thread::hardware_concurrency();
if (FLAGS_dist_threadpool_size > 0) {
num_threads = FLAGS_dist_threadpool_size;
VLOG(1) << "set dist_threadpool_size to " << num_threads;
}
PADDLE_ENFORCE_GT(num_threads, 0); PADDLE_ENFORCE_GT(num_threads, 0);
threadpool_.reset(new ThreadPool(num_threads)); threadpool_.reset(new ThreadPool(num_threads));
} }

@ -60,7 +60,7 @@ cc_library(paddle_inference_tensorrt_subgraph_engine
inference_api_test(test_api_tensorrt_subgraph_engine SRC api_tensorrt_subgraph_engine_tester.cc ARGS test_word2vec) inference_api_test(test_api_tensorrt_subgraph_engine SRC api_tensorrt_subgraph_engine_tester.cc ARGS test_word2vec)
endif() endif()
if (WITH_ANAKIN) # only needed in CI if (WITH_ANAKIN AND WITH_GPU) # only needed in CI
# compile the libinference_anakin_api.a and anakin.so. # compile the libinference_anakin_api.a and anakin.so.
nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber) nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber)
#nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin) #nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin)

@ -170,6 +170,9 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op") elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference") message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
elseif(${TARGET} STREQUAL "fc")
# HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
else() else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif() endif()
@ -300,12 +303,6 @@ op_library(channel_recv_op DEPS concurrency)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
# The fully connected layer is deleted when the WITH_MKLDNN flag is OFF
# Because the fully connected layer has only one MKLDNN's operator
if(NOT WITH_MKLDNN)
list(REMOVE_ITEM GENERAL_OPS fc_op)
endif(NOT WITH_MKLDNN)
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
op_library(${src}) op_library(${src})
endforeach() endforeach()

@ -26,6 +26,8 @@ namespace plat = paddle::platform;
act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext, \ act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<float>>, \ ops::grad_functor<float>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \ ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>); ops::grad_functor<double>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);

@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut, template <typename Device, typename X, typename Out, typename dOut,
typename dX> typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const { void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
const Out out_conj = Eigen::numext::conj(out); dx.device(d) = static_cast<T>(0.5) * dout / out;
dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
} }
}; };
@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
typename dX> typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const { void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) * dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor - static_cast<T>(1))); x.pow(static_cast<T>(factor) - static_cast<T>(1));
} }
}; };
@ -863,10 +862,11 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut, template <typename Device, typename X, typename Out, typename dOut,
typename dX> typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const { void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
T b = static_cast<T>(beta);
auto temp1 = static_cast<T>(1) / auto temp1 = static_cast<T>(1) /
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp()); (static_cast<T>(1) + (static_cast<T>(-b) * x).exp());
auto temp2 = temp1 * (static_cast<T>(1) - (beta * out)); auto temp2 = temp1 * (static_cast<T>(1) - (b * out));
dx.device(d) = dout * ((beta * out) + temp2); dx.device(d) = dout * ((b * out) + temp2);
} }
}; };

@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/assign_value_op.h" #include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<int>, REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<int>,
ops::AssignValueKernel<float>); ops::AssignValueKernel<float>,
ops::AssignValueKernel<plat::float16>);

@ -39,6 +39,27 @@ using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024; static_cast<size_t>(1024) * 1024 * 1024;
template <typename T, typename DeviceContext>
// bool EnableFp16(const T& dummy, const DeviceContext& dev_ctx,
bool EnableFp16(const DeviceContext& dev_ctx,
cudnnConvolutionDescriptor_t cudnn_conv_desc) {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Tensor core is supported since the volta GPU and
// is only enabled when input and filter data are float16
if (dev_ctx.GetComputeCapability() >= 70 &&
std::type_index(typeid(T)) ==
std::type_index(typeid(platform::float16))) {
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
return true;
} else {
PADDLE_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
}
#endif
return false;
}
template <typename T> template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> { class CUDNNConvOpKernel : public framework::OpKernel<T> {
public: public:
@ -128,27 +149,14 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
if (EnableFp16<T>(dev_ctx, cudnn_conv_desc)) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Tensor core is supported since the volta GPU and
// is only enabled when input and filter data are float16
if (dev_ctx.GetComputeCapability() >= 70 &&
std::type_index(typeid(T)) ==
std::type_index(typeid(platform::float16))) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
// Currently tensor core is only enabled using this algo
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
} else { } else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
cudnn_conv_desc, CUDNN_DEFAULT_MATH)); handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
} }
#endif
// get workspace size able to allocate // get workspace size able to allocate
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
@ -288,6 +296,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} else { } else {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} }
if (EnableFp16<T>(dev_ctx, cudnn_conv_desc)) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
@ -307,6 +318,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} else { } else {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} }
if (EnableFp16<T>(dev_ctx, cudnn_conv_desc)) {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
}
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
@ -362,7 +376,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<plat::float16>); paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>, paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>); paddle::operators::CUDNNConvGradOpKernel<double>,
paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>, paddle::operators::CUDNNConvOpKernel<float>,
@ -370,4 +385,5 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<plat::float16>); paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>, paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>); paddle::operators::CUDNNConvGradOpKernel<double>,
paddle::operators::CUDNNConvGradOpKernel<plat::float16>)

@ -13,12 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/cross_entropy_op.h" #include "paddle/fluid/operators/cross_entropy_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
using CUDACtx = paddle::platform::CUDADeviceContext; using CUDACtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(cross_entropy, REGISTER_OP_CUDA_KERNEL(cross_entropy,
ops::CrossEntropyOpKernel<CUDACtx, float>, ops::CrossEntropyOpKernel<CUDACtx, float>,
ops::CrossEntropyOpKernel<CUDACtx, double>); ops::CrossEntropyOpKernel<CUDACtx, double>,
REGISTER_OP_CUDA_KERNEL(cross_entropy_grad, ops::CrossEntropyOpKernel<CUDACtx, plat::float16>);
ops::CrossEntropyGradientOpKernel<CUDACtx, float>, REGISTER_OP_CUDA_KERNEL(
ops::CrossEntropyGradientOpKernel<CUDACtx, double>); cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);

@ -190,12 +190,15 @@ bool VariableResponse::ProcSerializedField(
#endif #endif
} }
VLOG(7) << "ProcSerializedField:" << meta_.varname()
<< ", type:" << meta_.type() << std::endl;
framework::DDim dims = GetDims(meta_.dims()); framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) { if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!"); PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!");
if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) { if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
return false; return false;
} }
return true; return true;
} }
@ -206,7 +209,9 @@ bool VariableResponse::ProcSerializedField(
return true; return true;
} }
return true; PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type());
return false;
} }
}; // namespace distributed }; // namespace distributed

@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>, ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>, ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>, ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>); ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);

@ -14,19 +14,24 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise_div_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div, elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad, elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
int64_t>); plat::float16>);

@ -14,19 +14,25 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise_mul_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad, elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext,
int64_t>); int64_t>);

@ -350,7 +350,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
int j = blockIdx.x; int j = blockIdx.x;
int i = threadIdx.x; int i = threadIdx.x;
int tid = threadIdx.x; int tid = threadIdx.x;
T val = 0; T val(0);
do { do {
int x_offset = i * w + j; int x_offset = i * w + j;
@ -418,7 +418,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
int tid = threadIdx.x; int tid = threadIdx.x;
int j = blockIdx.x; int j = blockIdx.x;
T val = 0; T val(0);
int ttid = tid; int ttid = tid;
while (true) { while (true) {

@ -14,19 +14,25 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_sub_op.h" #include "paddle/fluid/operators/elementwise_sub_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_sub, elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad, elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
int64_t>); int64_t>);

@ -125,13 +125,16 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto input = ctx.Input<Tensor>("Input"); auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W"); auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
"Input must be with 2 or 4 dimensions, i.e. NCHW"); "Input must be with 2 or 4 dimensions, i.e. NCHW");
// TODO(intel friends): the native weight format is io,
// but the mkldnn weight format is oihw, which may need be transposed.
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4, PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW"); "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
bool with_bias = ctx.Attr<bool>("bias_attr"); bool with_bias = bias != nullptr;
MKLDNNMD<Tensor> md(input, w, with_bias); MKLDNNMD<Tensor> md(input, w, with_bias);
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd = std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
@ -154,6 +157,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_memory = mem.dst(output_data); auto dst_memory = mem.dst(output_data);
auto src_memory = mem.src(input_data); auto src_memory = mem.src(input_data);
auto weights_memory = mem.weights(w_data); auto weights_memory = mem.weights(w_data);
// TODO(intel friends): bias memory should also be obtain from bias->data()
auto bias_memory = mem.bias(); auto bias_memory = mem.bias();
auto forward = with_bias ? mkldnn::inner_product_forward( auto forward = with_bias ? mkldnn::inner_product_forward(
@ -216,7 +220,8 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
bool with_bias = ctx.Attr<bool>("bias_attr"); auto bias = ctx.Input<Tensor>("Bias");
bool with_bias = bias != nullptr;
MKLDNNMD<Tensor> md(input, w, with_bias); MKLDNNMD<Tensor> md(input, w, with_bias);
MKLDNNMemory mem(&md, mkldnn_engine); MKLDNNMemory mem(&md, mkldnn_engine);

@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h" #include "paddle/fluid/operators/fc_op.h"
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/blas.h"
DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -25,16 +28,24 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
"Out(Output) of Fully Connected should not be null."); "Out(Output) of Fully Connected should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), PADDLE_ENFORCE(ctx->HasInput("W"),
"W(Input) of Fully Connected should not be null."); "W(Input) of Fully Connected should not be null.");
// NCHW
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
// IO, I=C*H*W
auto w_dims = ctx->GetInputDim("W"); auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]}); std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim].");
}
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor."); "Fully Connected input should be 2-D or 4-D tensor.");
PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, "Fully Connected input should be 2-D tensor.");
"Fully Connected input should be 2-D or 4-D tensor."); PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
"Fully Connected input and weigth size do not match.");
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Out"); ctx->ShareLoD("Input", "Out");
@ -42,9 +53,12 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FCOp::GetExpectedKernelType( framework::OpKernelType FCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN}; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout{framework::DataLayout::kMKLDNN}; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library); layout, library);
@ -60,27 +74,39 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->HasOutput(framework::GradVarName("W"))) { if (ctx->HasOutput(framework::GradVarName("W"))) {
ctx->SetOutputDim(framework::GradVarName("W"), w_dims); ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
} }
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Should have bias grad");
auto bias_dims = ctx->GetInputDim("Bias");
ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
}
} }
framework::OpKernelType FCOpGrad::GetExpectedKernelType( framework::OpKernelType FCOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN}; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout{framework::DataLayout::kMKLDNN}; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library); layout, library);
} }
void FCOpMaker::Make() { void FCOpMaker::Make() {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); AddInput("Input",
AddInput("W", "(Tensor), The second input tensor of fc op."); "(Tensor), The input tensor of fully connected operator with format "
"(NCHW). ");
AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
.AsDispensable();
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("bias_attr", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Fully Connected Operator. Fully Connected Operator.
@ -94,9 +120,47 @@ void FCOpMaker::Make() {
)DOC"); )DOC");
} }
template <typename T>
class FCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<Tensor>("Out");
auto in_dims = input->dims();
auto w_dims = w->dims();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0],
static_cast<T>(1), input_data, w_data, static_cast<T>(0),
output_data);
if (bias) {
const T* bias_data = bias->data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int bs = 0; bs < in_dims[0]; bs++) {
blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
output_data + bs * w_dims[1]);
}
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OPERATOR(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, namespace ops = paddle::operators;
REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fc_grad, paddle::operators::FCOpGrad); REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);

@ -12,48 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FillConstantInferShape : public framework::InferShapeBase { class FillConstantOp : public framework::OperatorWithKernel {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillConstantOp should not be null."); "Output(Out) of FillConstantOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape)); ctx->SetOutputDim("Out", framework::make_ddim(shape));
} }
};
class FillConstantOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(Attr<int>("dtype"));
auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu");
auto &out =
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
if (force_cpu) {
auto cpu = platform::CPUPlace();
out.mutable_data(cpu, framework::ToTypeIndex(data_type));
} else {
out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); framework::OpKernelType GetExpectedKernelType(
auto &dev_ctx = *pool.Get(dev_place); const framework::ExecutionContext& ctx) const override {
math::set_constant(dev_ctx, &out, value); return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context());
} }
}; };
@ -87,6 +67,11 @@ Fill up a variable with specified constant value.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, REGISTER_OPERATOR(fill_constant, ops::FillConstantOp, ops::FillConstantOpMaker,
ops::FillConstantInferShape, ops::FillConstantOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant,
ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillConstantOpKernel<paddle::platform::CPUDeviceContext, int64_t>)

@ -0,0 +1,26 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fill_constant,
ops::FillConstantOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillConstantOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillConstantOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillConstantOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillConstantOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>)

@ -0,0 +1,48 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FillConstantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto value = ctx.Attr<float>("value");
auto force_cpu = ctx.Attr<bool>("force_cpu");
auto* out = ctx.Output<framework::Tensor>("Out");
out->Resize(framework::make_ddim(ctx.Attr<std::vector<int>>("shape")));
if (force_cpu) {
auto cpu = platform::CPUPlace();
out->mutable_data(cpu, framework::ToTypeIndex(data_type));
} else {
out->mutable_data(ctx.GetPlace(), framework::ToTypeIndex(data_type));
}
math::set_constant(ctx.template device_context<DeviceContext>(), out,
value);
}
};
} // namespace operators
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save