# -*- coding: utf-8 -*- # # Poio Tools for Linguists # # Copyright (C) 2009-2013 Poio Project # Author: Peter Bouda # URL: # For license information, see LICENSE """ Classes for predictors and to handle suggestions and predictions. """ from __future__ import absolute_import, unicode_literals import os try: import configparser except ImportError: import ConfigParser as configparser from . import dbconnector #import pressagio.observer MIN_PROBABILITY = 0.0 MAX_PROBABILITY = 1.0 class SuggestionException(Exception): pass class UnknownCombinerException(Exception): pass class PredictorRegistryException(Exception): pass class Suggestion(object): """ Class for a simple suggestion, consists of a string and a probility for that string. """ def __init__(self, word, probability): print("I am a suggetsion") self.word = word self._probability = probability def __eq__(self, other): if self.word == other.word and self.probability == other.probability: return True return False def __lt__(self, other): if self.probability < other.probability: return True if self.probability == other.probability: return self.word < other.word return False def __repr__(self): return "Word: {0} - Probability: {1}".format( self.word, self.probability) def probability(): doc = "The probability property." def fget(self): return self._probability def fset(self, value): if value < MIN_PROBABILITY or value > MAX_PROBABILITY: raise SuggestionException("Probability is too high or too low.") self._probability = value def fdel(self): del self._probability return locals() probability = property(**probability()) class Prediction(list): """ Class for predictions from predictors. """ def __init__(self): pass def __eq__(self, other): if self is other: return True if len(self) != len(other): return False for i, s in enumerate(other): if not s == self[i]: return False return True def suggestion_for_token(self, token): for s in self: if s.word == token: return s def add_suggestion(self, suggestion): if len(self) == 0: self.append(suggestion) else: i = 0 while i < len(self) and suggestion < self[i]: i += 1 self.insert(i, suggestion) class PredictorActivator(object): """ PredictorActivator starts the execution of the active predictors, monitors their execution and collects the predictions returned, or terminates a predictor's execution if it execedes its maximum prediction time. The predictions returned by the individual predictors are combined into a single prediction by the active Combiner. """ def __init__(self, config, registry, context_tracker): self.config = config self.registry = registry self.context_tracker = context_tracker #self.dispatcher = pressagio.observer.Dispatcher(self) self.predictions = [] self.combiner = None self.max_partial_prediction_size = int(config.get( "Selector", "suggestions")) self.predict_time = None self._combination_policy = None def combination_policy(): doc = "The combination_policy property." def fget(self): return self._combination_policy def fset(self, value): self._combination_policy = value if value.lower() == "meritocracy": self.combiner = pressagio.combiner.MeritocracyCombiner() else: raise UnknownCombinerException() def fdel(self): del self._combination_policy return locals() combination_policy = property(**combination_policy()) def predict(self, multiplier = 1, prediction_filter = None): self.predictions[:] = [] for predictor in self.registry: self.predictions.append(predictor.predict( self.max_partial_prediction_size * multiplier, prediction_filter)) result = self.combiner.combine(self.predictions) return result class PredictorRegistry(list): #pressagio.observer.Observer, """ Manages instantiation and iteration through predictors and aids in generating predictions and learning. PredictorRegitry class holds the active predictors and provides the interface required to obtain an iterator to the predictors. The standard use case is: Predictor obtains an iterator from PredictorRegistry and invokes the predict() or learn() method on each Predictor pointed to by the iterator. Predictor registry should eventually just be a simple wrapper around plump. """ def __init__(self, config, dbconnection = None): self.config = config self.dbconnection = dbconnection self._context_tracker = None self.set_predictors() def context_tracker(): doc = "The context_tracker property." def fget(self): return self._context_tracker def fset(self, value): if self._context_tracker is not value: self._context_tracker = value self[:] = [] self.set_predictors() def fdel(self): del self._context_tracker return locals() context_tracker = property(**context_tracker()) def set_predictors(self): if (self.context_tracker): self[:] = [] for predictor in self.config.get("PredictorRegistry", "predictors")\ .split(): self.add_predictor(predictor) def add_predictor(self, predictor_name): predictor = None if self.config.get(predictor_name, "predictor_class") == \ "SmoothedNgramPredictor": predictor = SmoothedNgramPredictor(self.config, self.context_tracker, predictor_name, dbconnection = self.dbconnection) if predictor: self.append(predictor) def close_database(self): for predictor in self: predictor.close_database() class Predictor(object): """ Base class for predictors. """ def __init__(self, config, context_tracker, predictor_name, short_desc = None, long_desc = None): self.short_description = short_desc self.long_description = long_desc self.context_tracker = context_tracker self.name = predictor_name self.config = config def token_satifies_filter(token, prefix, token_filter): if token_filter: for char in token_filter: candidate = prefix + char if token.startswith(candidate): return True return False class SmoothedNgramPredictor(Predictor): #, pressagio.observer.Observer """ Calculates prediction from n-gram model in sqlite database. You have to create a database with the script `text2ngram` first. """ def __init__(self, config, context_tracker, predictor_name, short_desc = None, long_desc = None, dbconnection = None): Predictor.__init__(self, config, context_tracker, predictor_name, short_desc, long_desc) self.db = None self.dbconnection = dbconnection self.cardinality = None self.learn_mode_set = False self.dbclass = None self.dbuser = None self.dbpass = None self.dbhost = None self.dbport = None self._database = None self._deltas = None self._learn_mode = None self.config = config self.name = predictor_name self.context_tracker = context_tracker self._read_config() ################################################## Properties def deltas(): doc = "The deltas property." def fget(self): return self._deltas def fset(self, value): self._deltas = [] # make sure that values are floats for i, d in enumerate(value): self._deltas.append(float(d)) self.cardinality = len(value) self.init_database_connector_if_ready() def fdel(self): del self._deltas return locals() deltas = property(**deltas()) def learn_mode(): doc = "The learn_mode property." def fget(self): return self._learn_mode def fset(self, value): self._learn_mode = value self.learn_mode_set = True self.init_database_connector_if_ready() def fdel(self): del self._learn_mode return locals() learn_mode = property(**learn_mode()) def database(): doc = "The database property." def fget(self): return self._database def fset(self, value): self._database = value self.dbclass = self.config.get("Database", "class") if self.dbclass == "PostgresDatabaseConnector": self.dbuser = self.config.get("Database", "user") self.dbpass = self.config.get("Database", "password") self.dbhost = self.config.get("Database", "host") self.dbport = self.config.get("Database", "port") self.dblowercase = self.config.getboolean("Database", "lowercase_mode") self.dbnormalize = self.config.getboolean("Database", "normalize_mode") self.init_database_connector_if_ready() def fdel(self): del self._database return locals() database = property(**database()) #################################################### Methods def init_database_connector_if_ready(self): if self.database and len(self.database) > 0 and \ self.cardinality and self.cardinality > 0 and \ self.learn_mode_set: if self.dbclass == "SqliteDatabaseConnector": self.db = dbconnector.SqliteDatabaseConnector( self.database, self.cardinality) #, self.learn_mode elif self.dbclass == "PostgresDatabaseConnector": self.db = dbconnector.PostgresDatabaseConnector( self.database, self.cardinality, self.dbhost, self.dbport, self.dbuser, self.dbpass, self.dbconnection) self.db.lowercase = self.dblowercase self.db.normalize = self.dbnormalize self.db.open_database() def ngram_to_string(self, ngram): "|".join(ngram) def predict(self, max_partial_prediction_size, filter): print("SmoothedNgramPredictor Predicting") print(filter) tokens = [""] * self.cardinality prediction = Prediction() for i in range(self.cardinality): tokens[self.cardinality - 1 - i] = self.context_tracker.token(i) prefix_completion_candidates = [] for k in reversed(range(self.cardinality)): if len(prefix_completion_candidates) >= max_partial_prediction_size: break prefix_ngram = tokens[(len(tokens) - k - 1):] partial = None if not filter: partial = self.db.ngram_like_table(prefix_ngram, max_partial_prediction_size - \ len(prefix_completion_candidates)) else: partial = db.ngram_like_table_filtered(prefix_ngram, filter, max_partial_prediction_size - \ len(prefix_completion_candidates)) print((partial)) for p in partial: if len(prefix_completion_candidates) > \ max_partial_prediction_size: break candidate = p[-2] # ??? if candidate not in prefix_completion_candidates: prefix_completion_candidates.append(candidate) # smoothing unigram_counts_sum = self.db.unigram_counts_sum() for j, candidate in enumerate(prefix_completion_candidates): #if j >= max_partial_prediction_size: # break tokens[self.cardinality - 1] = candidate probability = 0 for k in range(self.cardinality): numerator = self._count(tokens, 0, k + 1) denominator = unigram_counts_sum if numerator > 0: denominator = self._count(tokens, -1, k) frequency = 0 if denominator > 0: frequency = float(numerator) / denominator probability += self.deltas[k] * frequency if probability > 0: prediction.add_suggestion(Suggestion(tokens[self.cardinality - 1], probability)) return(prediction) def close_database(self): self.db.close_database() ################################################ Private methods def _read_config(self): self.database = self.config.get("Database", "database") self.deltas = self.config.get(self.name, "deltas").split() self.learn_mode = self.config.get(self.name, "learn") def _count(self, tokens, offset, ngram_size): result = 0 if (ngram_size > 0): ngram = \ tokens[len(tokens) - ngram_size + offset:\ len(tokens) + offset] result = self.db.ngram_count(ngram) else: result = self.db.unigram_counts_sum() return result