# Source: apostrophe/uberwriter/pressagio/tests/test_predictor.py
# (144 lines, 4.6 KiB, Python)

# -*- coding: utf-8 -*-
#
# Poio Tools for Linguists
#
# Copyright (C) 2009-2013 Poio Project
# Author: Peter Bouda <pbouda@cidles.eu>
# URL: <http://media.cidles.eu/poio/>
# For license information, see LICENSE
from __future__ import absolute_import, unicode_literals
import os
try:
import configparser
except ImportError:
import ConfigParser as configparser
import pressagio.predictor
import pressagio.tokenizer
import pressagio.dbconnector
import pressagio.context_tracker
import pressagio.callback
class TestSuggestion:
    """Tests for ``pressagio.predictor.Suggestion``."""

    def setup_method(self):
        """Create a fresh ``Suggestion("Test", 0.3)`` before every test.

        This is pytest's xunit-style per-test hook.  The nose-style name
        ``setup`` was deprecated in pytest 7.2 and is no longer called
        from pytest 8.0 on, which would leave ``self.suggestion`` unset.
        """
        self.suggestion = pressagio.predictor.Suggestion("Test", 0.3)

    def test_probability(self):
        """A value assigned to ``probability`` is returned when read back."""
        self.suggestion.probability = 0.1
        assert self.suggestion.probability == 0.1
class TestPrediction:
    """Tests for ``pressagio.predictor.Prediction``."""

    def setup_method(self):
        """Create an empty ``Prediction`` before every test.

        This is pytest's xunit-style per-test hook.  The nose-style name
        ``setup`` was deprecated in pytest 7.2 and is no longer called
        from pytest 8.0 on, which would leave ``self.prediction`` unset.
        """
        self.prediction = pressagio.predictor.Prediction()

    def _assert_ranking(self, expected):
        """Assert the leading entries of ``self.prediction``.

        :param expected: sequence of ``(word, probability)`` pairs; the
            pair at position ``i`` is compared with ``self.prediction[i]``.
        """
        for index, (word, probability) in enumerate(expected):
            assert self.prediction[index].word == word
            assert self.prediction[index].probability == probability

    def test_add_suggestion(self):
        """Added suggestions are expected in descending probability order."""
        self.prediction.add_suggestion(pressagio.predictor.Suggestion(
            "Test", 0.3))
        self._assert_ranking([("Test", 0.3)])

        # A lower probability is expected to land behind the first entry.
        self.prediction.add_suggestion(pressagio.predictor.Suggestion(
            "Test2", 0.2))
        self._assert_ranking([("Test", 0.3), ("Test2", 0.2)])

        # A higher probability is expected to move to the front.
        self.prediction.add_suggestion(pressagio.predictor.Suggestion(
            "Test3", 0.6))
        self._assert_ranking([("Test3", 0.6), ("Test", 0.3), ("Test2", 0.2)])

        # Empty the prediction again, as the original test did.
        self.prediction[:] = []

    def test_suggestion_for_token(self):
        """``suggestion_for_token`` returns the suggestion stored for a word."""
        self.prediction.add_suggestion(pressagio.predictor.Suggestion(
            "Token", 0.8))
        assert self.prediction.suggestion_for_token("Token").probability == 0.8

        # Empty the prediction again, as the original test did.
        self.prediction[:] = []
class StringStreamCallback(pressagio.callback.Callback):
    """Callback used by the tests that carries a value in ``stream``."""

    def __init__(self, stream):
        """Initialise the base callback, then remember *stream*.

        :param stream: value stored unchanged as ``self.stream``.
        """
        # Explicit base-class call (rather than ``super()``) keeps the
        # behaviour of the original on both Python 2 and Python 3.
        base = pressagio.callback.Callback
        base.__init__(self)
        self.stream = stream
class TestSmoothedNgramPredictor:
    """Tests for the first predictor of a registry built from
    ``test_data/profile_smoothedngram.ini``.
    """

    def setup_method(self):
        """Build the n-gram database and the predictor registry.

        This is pytest's xunit-style per-test hook.  The nose-style name
        ``setup`` was deprecated in pytest 7.2 and is no longer called
        from pytest 8.0 on, which would leave the fixtures unset.
        """
        test_data = os.path.abspath(os.path.join(
            os.path.dirname(__file__), 'test_data'))
        self.dbfilename = os.path.join(test_data, 'test.db')
        self.infile = os.path.join(test_data, 'der_linksdenker.txt')

        # Fill the database with 1-, 2- and 3-grams from the input text.
        for cardinality in range(1, 4):
            ngram_map = pressagio.tokenizer.forward_tokenize_file(
                self.infile, cardinality, False)
            pressagio.dbconnector.insert_ngram_map_sqlite(
                ngram_map, cardinality, self.dbfilename, False)

        config_file = os.path.join(test_data, 'profile_smoothedngram.ini')
        config = configparser.ConfigParser()
        config.read(config_file)
        # Point the profile at the database that was just generated.
        config.set("Database", "database", self.dbfilename)

        self.predictor_registry = pressagio.predictor.PredictorRegistry(config)
        self.callback = StringStreamCallback("")

        # The tracker is never used directly by the tests; it is built
        # only for whatever wiring its constructor performs.
        # NOTE(review): presumably it attaches itself to the registry;
        # confirm in pressagio.context_tracker.
        pressagio.context_tracker.ContextTracker(
            config, self.predictor_registry, self.callback)

    @staticmethod
    def _predict_words(predictor, count=6):
        """Ask *predictor* for *count* predictions and return their words.

        Asserts that exactly *count* predictions come back.

        :param predictor: object providing ``predict(count, None)``.
        :param count: number of predictions to request.
        :returns: list with the ``word`` of every prediction, in order.
        """
        predictions = predictor.predict(count, None)
        assert len(predictions) == count
        return [prediction.word for prediction in predictions]

    def test_predict(self):
        """Predictions change as ``self.callback.stream`` changes."""
        predictor = self.predictor_registry[0]

        # Initial callback stream is the empty string.
        words = self._predict_words(predictor)
        for expected in ("er", "der", "die", "und", "nicht"):
            assert expected in words

        self.callback.stream = "d"
        words = self._predict_words(predictor)
        for expected in ("der", "die", "das", "da", "Der"):
            assert expected in words

        self.callback.stream = "de"
        words = self._predict_words(predictor)
        for expected in ("der", "Der", "dem", "den", "des"):
            assert expected in words

    def teardown_method(self):
        """Close the predictor's database and delete the database file.

        pytest's xunit-style counterpart of the nose-style ``teardown``.
        """
        if self.predictor_registry[0].db:
            self.predictor_registry[0].db.close_database()
        del self.predictor_registry[0]
        if os.path.isfile(self.dbfilename):
            os.remove(self.dbfilename)