|  |  |  | @ -1073,8 +1073,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> { | 
			
		
	
		
			
				
					|  |  |  |  |             typename dX> | 
			
		
	
		
			
				
					|  |  |  |  |   void operator()(Device d, X x, Out out, dOut dout, dX dx) const { | 
			
		
	
		
			
				
					|  |  |  |  |     auto temp1 = | 
			
		
	
		
			
				
					|  |  |  |  |         static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>(); | 
			
		
	
		
			
				
					|  |  |  |  |     auto temp2 = (out >= static_cast<T>(0)).template cast<T>(); | 
			
		
	
		
			
				
					|  |  |  |  |         static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>(); | 
			
		
	
		
			
				
					|  |  |  |  |     auto temp2 = (out > static_cast<T>(0)).template cast<T>(); | 
			
		
	
		
			
				
					|  |  |  |  |     dx.device(d) = dout * (temp1 + temp2).template cast<T>(); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -1418,11 +1418,11 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> { | 
			
		
	
		
			
				
					|  |  |  |  |       auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); | 
			
		
	
		
			
				
					|  |  |  |  |       auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); | 
			
		
	
		
			
				
					|  |  |  |  |       auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); | 
			
		
	
		
			
				
					|  |  |  |  |       ddout.device(*d) = | 
			
		
	
		
			
				
					|  |  |  |  |           ddx * | 
			
		
	
		
			
				
					|  |  |  |  |           ((out >= static_cast<T>(0)).template cast<T>() + | 
			
		
	
		
			
				
					|  |  |  |  |            static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>()) | 
			
		
	
		
			
				
					|  |  |  |  |               .template cast<T>(); | 
			
		
	
		
			
				
					|  |  |  |  |       ddout.device(*d) = ddx * | 
			
		
	
		
			
				
					|  |  |  |  |                          ((out > static_cast<T>(0)).template cast<T>() + | 
			
		
	
		
			
				
					|  |  |  |  |                           static_cast<T>(alpha) * | 
			
		
	
		
			
				
					|  |  |  |  |                               (out <= static_cast<T>(0)).template cast<T>()) | 
			
		
	
		
			
				
					|  |  |  |  |                              .template cast<T>(); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  |   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } | 
			
		
	
	
		
			
				
					|  |  |  | 
 |