boudinfl/pke

Divide by zero when training Kea model

Closed this issue · 4 comments

I am following the example in pke/examples but get a divide-by-zero error:

C:\Users\ADMIN\Anaconda3\lib\site-packages\sklearn\naive_bayes.py:465: RuntimeWarning: divide by zero encountered in log
self.class_log_prior_ = (np.log(self.class_count_) -

`
import logging

import pke

# path to the collection of documents
input_dir = 'input_dir/'

# path to the reference file
reference_file = 'gold-annotation.txt'

# path to the df file
df_file = "df.tsv.gz"
logging.info('Loading df counts from {}'.format(df_file))
df_counts = pke.load_document_frequency_file(input_file=df_file,
                                             delimiter='\t')

# path to the model, saved as a pickle
output_mdl = "model.pickle"

pke.train_supervised_model(input_dir=input_dir,
                           reference_file=reference_file,
                           model_file=output_mdl,
                           extension='xml',
                           language='en',
                           normalization=None,
                           df=df_counts,
                           model=pke.supervised.Kea(),
                           sep_doc_id=' : ',
                           sep_ref_keyphrases=',',
                           normalize_reference=False)
`

Looking at the error message, it seems that there are no positive (or no negative) examples when training the naive Bayes model. Can you check the reference_file and verify that at least some candidates from the documents appear in the references?
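One way to run that check without touching pke is to compare the document ids found in the reference file against the ids derived from the input files. The following is only a minimal sketch: `parse_reference_ids` and `unmatched_documents` are hypothetical helper names (not part of pke), and the sample ids are illustrative. It assumes each reference line has the form `doc_id : keyphrase,keyphrase`, matching `sep_doc_id=' : '` above.

```python
def parse_reference_ids(lines, sep_doc_id=' : '):
    """Collect the document ids that appear in reference-file lines.

    Hypothetical helper: assumes one document per line, with the id
    separated from its keyphrases by sep_doc_id.
    """
    ids = set()
    for line in lines:
        line = line.strip()
        if not line:
            continue
        doc_id, _, _ = line.partition(sep_doc_id)
        ids.add(doc_id)
    return ids


def unmatched_documents(doc_ids, reference_ids):
    """Return the document ids that have no entry in the references."""
    return sorted(set(doc_ids) - set(reference_ids))


# Illustrative data only: one reference line and two candidate doc ids.
reference_lines = ['C-41 : adapt resourc manag,distribut real-time embed system']
reference_ids = parse_reference_ids(reference_lines)

matching = unmatched_documents(['C-41'], reference_ids)
mismatching = unmatched_documents(['train\\C-41'], reference_ids)
print(matching)
print(mismatching)
```

If every document id ends up in the unmatched list, no candidate can be labelled as a positive example, which is consistent with an empty class count and the warning shown above.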

Here is an example line from gold-annotation.txt: C-41 : adapt resourc manag,distribut real-time embed system, ...

Why do tokens like "resourc" and "distribut" occur? I think they cause the error above. Shouldn't they be "resource" or "distribute"?

I found the problem: it comes from how the doc_id is extracted.
On Windows, the doc_id is returned in the format "input_dir\file_name".

# train\C-42.txt => Windows
# train/C-42.txt => other

doc_id1 = '.'.join(r'train\C-42.txt'.split('/')[-1].split('.')[0:-1])  # 'train\\C-42' (directory prefix is kept)
doc_id2 = '.'.join('train/C-42.txt'.split('/')[-1].split('.')[0:-1])  # 'C-42'

Indeed, that was the problem I was facing when running this on my Windows machine.
Line 202 of /site-packages/pke/utils.py needs to be replaced by the following code, which guarantees that it runs properly regardless of the machine on which it is executed:

doc_id = '.'.join(os.path.basename(input_file).split('.')[0:-1])
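To see why this helps, the two approaches can be compared on any machine. This is a small sketch, not pke code: `doc_id_split` and `doc_id_basename` are hypothetical names, and `ntpath` / `posixpath` are used explicitly because `os.path` is `ntpath` on Windows and `posixpath` elsewhere.

```python
import ntpath
import posixpath


def doc_id_split(path):
    """Approach that splits on '/' only, so backslashes are ignored."""
    return '.'.join(path.split('/')[-1].split('.')[0:-1])


def doc_id_basename(path, pathmod):
    """Approach using basename from the given path module."""
    return '.'.join(pathmod.basename(path).split('.')[0:-1])


windows_path = r'train\C-42.txt'
posix_path = 'train/C-42.txt'

# Splitting on '/' leaves the directory prefix in the Windows doc_id.
split_on_windows_path = doc_id_split(windows_path)

# basename with the platform's path module strips the directory on both.
basename_on_windows = doc_id_basename(windows_path, ntpath)
basename_on_posix = doc_id_basename(posix_path, posixpath)

print(split_on_windows_path)
print(basename_on_windows)
print(basename_on_posix)
```

With the basename approach, the doc_id no longer carries the directory prefix on Windows, so it can match the ids in the reference file.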