You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							98 lines
						
					
					
						
							3.0 KiB
						
					
					
				
			
		
		
	
	
							98 lines
						
					
					
						
							3.0 KiB
						
					
					
				| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License. */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "paddle/math/Matrix.h"
 | |
| 
 | |
| namespace paddle {
 | |
| 
 | |
| class LinearChainCRF {
 | |
| public:
 | |
|   /**
 | |
|    * The size of para must be \f$(numClasses + 2) * numClasses\f$.
 | |
|    * The first numClasses values of para are for starting weights (\f$a\f$).
 | |
|    * The next numClasses values of para are for ending weights (\f$b\f$),
 | |
|    * The remaning values are for transition weights (\f$w\f$).
 | |
|    *
 | |
|    * The probability of a state sequence s of length \f$L\f$ is defined as:
 | |
|    * \f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
 | |
|    *                  + \sum_{l=1}^L x_{s_l}
 | |
|    *                  + \sum_{l=2}^L w_{s_{l-1},s_l})\f$
 | |
|    * where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over
 | |
|    * all possible
 | |
|    * sequences is \f$1\f$, and \f$x\f$ is the input feature to the CRF.
 | |
|    */
 | |
|   LinearChainCRF(int numClasses, real* para);
 | |
| 
 | |
|   /**
 | |
|    * Calculate the negative log likelihood of s given x.
 | |
|    * The size of x must be length * numClasses. Each consecutive numClasses
 | |
|    * values are the features for one time step.
 | |
|    */
 | |
|   real forward(real* x, int* s, int length);
 | |
| 
 | |
|   /**
 | |
|    * Calculate the gradient with respect to x, a, b, and w.
 | |
|    * backward() can only be called after a corresponding call to forward() with
 | |
|    * the same x, s and length.
 | |
|    * The gradient with respect to a, b, and w will not be calculated if
 | |
|    * needWGrad is false.
 | |
|    * @note Please call getWGrad() and getXGrad() to get the gradient with
 | |
|    * respect to (a, b, w) and x respectively.
 | |
|    */
 | |
|   void backward(real* x, int* s, int length, bool needWGrad);
 | |
| 
 | |
|   /**
 | |
|    * Find the most probable sequence given x. The result will be stored in s.
 | |
|    */
 | |
|   void decode(real* x, int* s, int length);
 | |
| 
 | |
|   /*
 | |
|    * Return the gradient with respect to (a, b, w). It can only be called after
 | |
|    * a corresponding call to backward().
 | |
|    */
 | |
|   MatrixPtr getWGrad() { return matWGrad_; }
 | |
| 
 | |
|   /*
 | |
|    * Return the gradient with respect to x. It can only be called after a
 | |
|    * corresponding call to backward().
 | |
|    */
 | |
|   MatrixPtr getXGrad() { return matGrad_; }
 | |
| 
 | |
| protected:
 | |
|   int numClasses_;
 | |
|   MatrixPtr a_;
 | |
|   MatrixPtr b_;
 | |
|   MatrixPtr w_;
 | |
|   MatrixPtr matWGrad_;
 | |
|   MatrixPtr da_;
 | |
|   MatrixPtr db_;
 | |
|   MatrixPtr dw_;
 | |
|   MatrixPtr ones_;
 | |
| 
 | |
|   MatrixPtr expX_;
 | |
|   MatrixPtr matGrad_;
 | |
|   MatrixPtr alpha_;
 | |
|   MatrixPtr beta_;
 | |
|   MatrixPtr maxX_;
 | |
|   MatrixPtr expW_;
 | |
| 
 | |
|   // track_(k,i) = j means that the best sequence at time k for class i comes
 | |
|   // from the sequence at time k-1 for class j
 | |
|   IVectorPtr track_;
 | |
| };
 | |
| 
 | |
| }  // namespace paddle
 |