@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License . */
# include <set>
# include <vector>
# include "paddle/math/Vector.h"
@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator {
std : : vector < Segment > labelSegments_ ;
std : : vector < Segment > outputSegments_ ;
std : : set < int > excludedChunkTypes_ ;
public :
virtual void init ( const EvaluatorConfig & config ) {
@ -105,6 +107,10 @@ public:
}
CHECK ( config . has_num_chunk_types ( ) ) < < " Missing num_chunk_types in config " ;
otherChunkType_ = numChunkTypes_ = config . num_chunk_types ( ) ;
// the chunks of types in excludedChunkTypes_ will not be counted
auto & tmp = config . excluded_chunk_types ( ) ;
excludedChunkTypes_ . insert ( tmp . begin ( ) , tmp . end ( ) ) ;
}
virtual void start ( ) {
@ -157,6 +163,7 @@ public:
size_t i = 0 , j = 0 ;
while ( i < outputSegments_ . size ( ) & & j < labelSegments_ . size ( ) ) {
if ( outputSegments_ [ i ] = = labelSegments_ [ j ] ) {
if ( excludedChunkTypes_ . count ( outputSegments_ [ i ] . type ) ! = 1 )
+ + numCorrect_ ;
}
if ( outputSegments_ [ i ] . end < labelSegments_ [ j ] . end ) {
@ -168,8 +175,12 @@ public:
+ + j ;
}
}
numLabelSegments_ + = labelSegments_ . size ( ) ;
numOutputSegments_ + = outputSegments_ . size ( ) ;
for ( auto & segment : labelSegments_ ) {
if ( excludedChunkTypes_ . count ( segment . type ) ! = 1 ) + + numLabelSegments_ ;
}
for ( auto & segment : outputSegments_ ) {
if ( excludedChunkTypes_ . count ( segment . type ) ! = 1 ) + + numOutputSegments_ ;
}
}
void getSegments ( int * label , int length , std : : vector < Segment > & segments ) {