Merge branch 'feature/make_paddle_support_double' into stable_elemwise_mul

tonyyang-svail-feed-op-desgin
Yu Yang 8 years ago
commit 9a3efb28c0

@ -21,7 +21,7 @@ namespace operators {
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X))) // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X"); const framework::Tensor *X = context.Input<framework::Tensor>("X");
@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
// dX = sigmoid(X) - labels // dX = sigmoid(X) - labels
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel { class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X"); const framework::Tensor *X = context.Input<framework::Tensor>("X");

Loading…
Cancel
Save