Replace functor by function.

revert-3824-remove_grad_op_type
dangqingqing 8 years ago
parent 70285cce32
commit 8f6c8780a5

@ -21,19 +21,18 @@ namespace operators {
using Tensor = framework::Tensor;
// Returns log(x) clamped to a large finite magnitude so downstream
// arithmetic never sees +/-inf (e.g. log(0) == -inf when a probability
// in X is exactly zero).
//
// Callable from both host and device code.  kApproInf (1e20) stands in
// for infinity: +inf -> +1e20, -inf -> -1e20, all finite results pass
// through unchanged.
template <typename T>
__host__ __device__ T clipping_log(const T x) {
  // Compile-time precondition: the INFINITY comparisons below are only
  // meaningful for floating-point types (same check style as
  // tolerable_value in the CPU path).
  static_assert(std::is_floating_point<T>::value,
                "clipping_log works only on floating-point types.");
  const T kApproInf = 1e20;
  T v = log(x);  // resolves to the device overload in __device__ code
  if (v == INFINITY) {
    return kApproInf;
  }
  if (v == -INFINITY) {
    return -kApproInf;
  }
  return v;
}
// Computes per-row cross-entropy loss: Y[i] = -log(X[i, label[i]]).
//
// X is an N x D row-major matrix of class probabilities, label holds one
// class index per row, Y receives one loss value per row.  Uses a
// grid-stride loop, so any 1-D launch configuration covers all N rows.
//
// NOTE(review): the trailing parameters (N, D) were elided by the diff
// hunk and are reconstructed from the body's usage — confirm against the
// full file.
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    // Guard against out-of-range labels before indexing into X.
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    Y[i] = -clipping_log(X[i * D + label[i]]);
  }
}

@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
T tolerable_value(const T x) {
inline T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, "
"double and double double.");

@ -65,7 +65,7 @@ class OpTestMeta(type):
expect = self.outputs[out_name]
self.assertTrue(
numpy.allclose(
actual, expect, atol=1e-04),
actual, expect, atol=1e-05),
"output name: " + out_name + "has diff")
obj.test_all = test_all

Loading…
Cancel
Save