|  |  |  | @ -14,6 +14,9 @@ limitations under the License. */ | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #pragma once | 
			
		
	
		
			
				
					|  |  |  |  | #include <set> | 
			
		
	
		
			
				
					|  |  |  |  | #include <string> | 
			
		
	
		
			
				
					|  |  |  |  | #include <vector> | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | #include "paddle/fluid/framework/eigen.h" | 
			
		
	
		
			
				
					|  |  |  |  | #include "paddle/fluid/framework/op_registry.h" | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -36,11 +39,11 @@ class ChunkEvalKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |   }; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   void GetSegments(const int64_t* label, int length, | 
			
		
	
		
			
				
					|  |  |  |  |                    std::vector<Segment>& segments, int num_chunk_types, | 
			
		
	
		
			
				
					|  |  |  |  |                    std::vector<Segment>* segments, int num_chunk_types, | 
			
		
	
		
			
				
					|  |  |  |  |                    int num_tag_types, int other_chunk_type, int tag_begin, | 
			
		
	
		
			
				
					|  |  |  |  |                    int tag_inside, int tag_end, int tag_single) const { | 
			
		
	
		
			
				
					|  |  |  |  |     segments.clear(); | 
			
		
	
		
			
				
					|  |  |  |  |     segments.reserve(length); | 
			
		
	
		
			
				
					|  |  |  |  |     segments->clear(); | 
			
		
	
		
			
				
					|  |  |  |  |     segments->reserve(length); | 
			
		
	
		
			
				
					|  |  |  |  |     int chunk_start = 0; | 
			
		
	
		
			
				
					|  |  |  |  |     bool in_chunk = false; | 
			
		
	
		
			
				
					|  |  |  |  |     int tag = -1; | 
			
		
	
	
		
			
				
					|  |  |  | @ -58,7 +61,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |             i - 1,        // end
 | 
			
		
	
		
			
				
					|  |  |  |  |             prev_type, | 
			
		
	
		
			
				
					|  |  |  |  |         }; | 
			
		
	
		
			
				
					|  |  |  |  |         segments.push_back(segment); | 
			
		
	
		
			
				
					|  |  |  |  |         segments->push_back(segment); | 
			
		
	
		
			
				
					|  |  |  |  |         in_chunk = false; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |       if (ChunkBegin(prev_tag, prev_type, tag, type, other_chunk_type, | 
			
		
	
	
		
			
				
					|  |  |  | @ -73,7 +76,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |           length - 1,   // end
 | 
			
		
	
		
			
				
					|  |  |  |  |           type, | 
			
		
	
		
			
				
					|  |  |  |  |       }; | 
			
		
	
		
			
				
					|  |  |  |  |       segments.push_back(segment); | 
			
		
	
		
			
				
					|  |  |  |  |       segments->push_back(segment); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -177,8 +180,8 @@ class ChunkEvalKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |     for (int i = 0; i < num_sequences; ++i) { | 
			
		
	
		
			
				
					|  |  |  |  |       int seq_length = lod[0][i + 1] - lod[0][i]; | 
			
		
	
		
			
				
					|  |  |  |  |       EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length, | 
			
		
	
		
			
				
					|  |  |  |  |                  output_segments, label_segments, *num_infer_chunks_data, | 
			
		
	
		
			
				
					|  |  |  |  |                  *num_label_chunks_data, *num_correct_chunks_data, | 
			
		
	
		
			
				
					|  |  |  |  |                  &output_segments, &label_segments, num_infer_chunks_data, | 
			
		
	
		
			
				
					|  |  |  |  |                  num_label_chunks_data, num_correct_chunks_data, | 
			
		
	
		
			
				
					|  |  |  |  |                  num_chunk_types, num_tag_types, other_chunk_type, tag_begin, | 
			
		
	
		
			
				
					|  |  |  |  |                  tag_inside, tag_end, tag_single, excluded_chunk_types); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
	
		
			
				
					|  |  |  | @ -197,10 +200,10 @@ class ChunkEvalKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   void EvalOneSeq(const int64_t* output, const int64_t* label, int length, | 
			
		
	
		
			
				
					|  |  |  |  |                   std::vector<Segment>& output_segments, | 
			
		
	
		
			
				
					|  |  |  |  |                   std::vector<Segment>& label_segments, | 
			
		
	
		
			
				
					|  |  |  |  |                   int64_t& num_output_segments, int64_t& num_label_segments, | 
			
		
	
		
			
				
					|  |  |  |  |                   int64_t& num_correct, int num_chunk_types, int num_tag_types, | 
			
		
	
		
			
				
					|  |  |  |  |                   std::vector<Segment>* output_segments, | 
			
		
	
		
			
				
					|  |  |  |  |                   std::vector<Segment>* label_segments, | 
			
		
	
		
			
				
					|  |  |  |  |                   int64_t* num_output_segments, int64_t* num_label_segments, | 
			
		
	
		
			
				
					|  |  |  |  |                   int64_t* num_correct, int num_chunk_types, int num_tag_types, | 
			
		
	
		
			
				
					|  |  |  |  |                   int other_chunk_type, int tag_begin, int tag_inside, | 
			
		
	
		
			
				
					|  |  |  |  |                   int tag_end, int tag_single, | 
			
		
	
		
			
				
					|  |  |  |  |                   const std::set<int>& excluded_chunk_types) const { | 
			
		
	
	
		
			
				
					|  |  |  | @ -209,25 +212,29 @@ class ChunkEvalKernel : public framework::OpKernel<T> { | 
			
		
	
		
			
				
					|  |  |  |  |     GetSegments(label, length, label_segments, num_chunk_types, num_tag_types, | 
			
		
	
		
			
				
					|  |  |  |  |                 other_chunk_type, tag_begin, tag_inside, tag_end, tag_single); | 
			
		
	
		
			
				
					|  |  |  |  |     size_t i = 0, j = 0; | 
			
		
	
		
			
				
					|  |  |  |  |     while (i < output_segments.size() && j < label_segments.size()) { | 
			
		
	
		
			
				
					|  |  |  |  |       if (output_segments[i] == label_segments[j] && | 
			
		
	
		
			
				
					|  |  |  |  |           excluded_chunk_types.count(output_segments[i].type) != 1) { | 
			
		
	
		
			
				
					|  |  |  |  |         ++num_correct; | 
			
		
	
		
			
				
					|  |  |  |  |     while (i < output_segments->size() && j < label_segments->size()) { | 
			
		
	
		
			
				
					|  |  |  |  |       if (output_segments->at(i) == label_segments->at(j) && | 
			
		
	
		
			
				
					|  |  |  |  |           excluded_chunk_types.count(output_segments->at(i).type) != 1) { | 
			
		
	
		
			
				
					|  |  |  |  |         ++(*num_correct); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |       if (output_segments[i].end < label_segments[j].end) { | 
			
		
	
		
			
				
					|  |  |  |  |       if (output_segments->at(i).end < label_segments->at(j).end) { | 
			
		
	
		
			
				
					|  |  |  |  |         ++i; | 
			
		
	
		
			
				
					|  |  |  |  |       } else if (output_segments[i].end > label_segments[j].end) { | 
			
		
	
		
			
				
					|  |  |  |  |       } else if (output_segments->at(i).end > label_segments->at(j).end) { | 
			
		
	
		
			
				
					|  |  |  |  |         ++j; | 
			
		
	
		
			
				
					|  |  |  |  |       } else { | 
			
		
	
		
			
				
					|  |  |  |  |         ++i; | 
			
		
	
		
			
				
					|  |  |  |  |         ++j; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     for (auto& segment : label_segments) { | 
			
		
	
		
			
				
					|  |  |  |  |       if (excluded_chunk_types.count(segment.type) != 1) ++num_label_segments; | 
			
		
	
		
			
				
					|  |  |  |  |     for (auto& segment : (*label_segments)) { | 
			
		
	
		
			
				
					|  |  |  |  |       if (excluded_chunk_types.count(segment.type) != 1) { | 
			
		
	
		
			
				
					|  |  |  |  |         ++(*num_label_segments); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     for (auto& segment : output_segments) { | 
			
		
	
		
			
				
					|  |  |  |  |       if (excluded_chunk_types.count(segment.type) != 1) ++num_output_segments; | 
			
		
	
		
			
				
					|  |  |  |  |     for (auto& segment : (*output_segments)) { | 
			
		
	
		
			
				
					|  |  |  |  |       if (excluded_chunk_types.count(segment.type) != 1) { | 
			
		
	
		
			
				
					|  |  |  |  |         ++(*num_output_segments); | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | }; | 
			
		
	
	
		
			
				
					|  |  |  | 
 |