"""
Module to generate and run unit tests
Author: Amir Zeldes
"""
from collections import defaultdict
import unittest
import re, os, sys
from .xrenner_xrenner import Xrenner
from .xrenner_coref import find_antecedent
if sys.version_info[0] < 3:
python_version = 2
else:
python_version = 3
def generate_test(conll_tokens, markables, parse, model="eng", name="test"):
tok_count = len(conll_tokens)
mark_count = 0
ids = []
marks_by_id = {}
# Collect markable groups, assign IDs by extension and count markables
marks_by_group = defaultdict(list)
for mark in markables:
mark_count += 1
# Assign predictable ID of the form start_end
id_ = str(mark.start) + "_" + str(mark.end)
if id_ in ids:
raise("xrenner generated two markables with same extension: tok" + str(mark.start) + ":tok" + str(mark.end))
else:
ids.append(id_)
mark.id = id_
marks_by_id[id_] = mark
marks_by_group[int(mark.group)].append(mark)
group_count = len(marks_by_group)
# Serialize group details
chains = []
for group in sorted(marks_by_group):
chain = sorted(marks_by_group[group],key=lambda x: int(x.id[:x.id.find("_")]))
gid = "g" + chain[0].id
chain_string = " "
for mark in chain:
chain_string += mark.id + " < "
chains.append(chain_string[:-3])
chains.sort(key=lambda x: int(x[2:x.find("_")]))
snippets = []
for chain in chains:
first = chain[2:chain.find("<")-1] if "<" in chain else chain.strip()
snippet = marks_by_id[first].text[:20] + "..." if len(marks_by_id[first].text) > 20 else marks_by_id[first].text
snippets.append(snippet)
zipped = zip(snippets,chains)
output = ""
output += "name:" + name + "\n"
output += "model:" + model + "\n"
output += "toks:" + str(tok_count) + " # " + " ".join(tok.text for tok in conll_tokens[1:4]) + "..." + "\n"
output += "marks:" + str(mark_count) + "\n"
output += "groups:" + str(group_count) + "\n"
output += "chains:" + "\n"
for chain in zipped:
output += " # " + str(chain[0]) + "\n"
output += chain[1] + "\n"
output += "input_data:" + "\n"
output += "\n".join(parse)
output += "\n" + "-"*5 + "\n"
return output
def setUpModule():
# Read test/tests.dat
print("\nxrenner unit tests\n" + "=" * 20 + "\nReading test cases from test/tests.dat")
file = os.path.dirname(os.path.realpath(__file__)) + os.sep + ".." + os.sep + "test" + os.sep + "tests.dat"
test_data = ""
with open(file, 'rb') as f:
test_data = f.read()
# Populate cases with Case objects
global cases
cases = {}
if python_version < 3:
case_list = test_data.split("-----")
else:
case_list = test_data.decode().split("-----")
for case in case_list:
case = case.strip()
if len(case) > 0:
case_to_add = Case(case)
cases[case_to_add.name] = case_to_add
# Initialize an Xrenner object with the language model and assign to module level xrenner variable for all suites
print("Initializing xrenner model 'eng'\n")
global xrenner
xrenner = Xrenner("eng",override="TEST")
xrenner.set_doc_name("test_test")
[docs]class Test1Model(unittest.TestCase):
@classmethod
[docs] def setUpClass(cls):
print("\nTesting model integrity\n" + "-"*30)
global xrenner
cls.xrenner = xrenner
global cases
cls.cases = cases
@classmethod
[docs] def tearDownClass(cls):
global xrenner
cls.xrenner = xrenner
[docs] def test_model_files(self):
print("Checking model files: ")
# Check that all model components were read and are filled as expected
self.assertTrue(len(self.xrenner.lex.coref_rules),"check that coref_rules is full")
self.assertTrue(len(self.xrenner.lex.entities),"check that entities is full")
self.assertTrue(len(self.xrenner.lex.entity_heads),"check that entity_heads is full")
self.assertTrue(len(self.xrenner.lex.pronouns),"check that pronouns is full")
self.assertTrue(len(self.xrenner.lex.filters),"check that filters is full")
# Optional components, should be included in default model
self.assertTrue(len(self.xrenner.lex.names),"check that names is full")
self.assertTrue(len(self.xrenner.lex.stop_list),"check that stop_list is full")
self.assertTrue(len(self.xrenner.lex.open_close_punct),"check that open_close_punct is full")
self.assertTrue(len(self.xrenner.lex.open_close_punct_rev),"check that open_close_punct_rev is full")
self.assertTrue(len(self.xrenner.lex.entity_mods),"check that entity_mods is full")
self.assertTrue(len(self.xrenner.lex.entity_deps),"check that entity_deps is full")
self.assertTrue(len(self.xrenner.lex.hasa),"check that hasa is full")
self.assertTrue(len(self.xrenner.lex.coref),"check that coref is full")
self.assertTrue(len(self.xrenner.lex.numbers),"check that numbers is full")
self.assertTrue(len(self.xrenner.lex.affix_tokens),"check that affix_tokens is full")
self.assertTrue(len(self.xrenner.lex.antonyms),"check that antonyms is full")
self.assertTrue(len(self.xrenner.lex.isa),"check that isa is full")
[docs]class Test2MarkableMethods(unittest.TestCase):
@classmethod
[docs] def setUpClass(cls):
print("\n\nRunning markable method tests\n" + "-"*30)
global xrenner
cls.xrenner = xrenner
global cases
cls.cases = cases
cls.xrenner.lex.filters["remove_singletons"] = False
@classmethod
[docs] def tearDownClass(cls):
global xrenner
cls.xrenner = xrenner
cls.xrenner.lex.filters["remove_singletons"] = True
[docs] def test_name(self):
# Jerry B. Clinton
print("\nRun markable name test: ")
target = self.cases["mark_name_test"]
self.xrenner.analyze(target.parse.split("\n"), "unittest")
markables = self.xrenner.markables
# Check that there are no nested markables in Jerry B. Clinton
self.assertEqual(len(markables),1)
# Check that the name is classified as a person
self.assertEqual(markables[0].entity, "person")
[docs] def test_atomic_mod(self):
# Israel Machines Corp.
# Note that Corp. is an organization-marking atomic flagged modifier of the head
print("\nRun atomic modifier test: ")
target = self.cases["mark_atomic_mod_test"]
self.xrenner.analyze(target.parse.split("\n"), "unittest")
markables = self.xrenner.markables
# Check that there are no nested markables in Israel Machines Corp.
self.assertEqual(len(markables), 1)
# Check that the name is classified as a person
self.assertEqual(markables[0].entity, "organization")
[docs]class Test3CorefMethods(unittest.TestCase):
@classmethod
[docs] def setUpClass(cls):
print("\n\nRunning coref method tests\n" + "-"*30)
global xrenner
cls.xrenner = xrenner
global cases
cls.cases = cases
[docs] def test_cardinality(self):
# I saw two birds . The three birds flew .
print("\nRun cardinality test: ")
target = self.cases["cardinality_test"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(0,result.mark_count,"cardinality test (two birds != the three birds)")
[docs] def test_appos_envelope(self):
# Meet [[Mark Smith] , [the Governor]]. [He] is the best.
print("\nRun apposition envelope test: ")
target = self.cases["appos_envelope"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(target.chains,result.chains,"appos envelope test")
[docs] def test_isa(self):
# I read [the Wall Street Journal]. [That newspaper] is great.
print("\nRun isa test: ")
target = self.cases["isa_test"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(target.chains,result.chains,"isa test (Wall Street Journal <- newspaper)")
[docs] def test_hasa(self):
# The [[CEO] and the taxi driver] ate . [[His] employees] joined them
print("\nRun hasa test: ")
target = self.cases["hasa_test"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(target.chains,result.chains,"hasa test (CEO, taxi driver <- his employees)")
[docs] def test_dynamic_hasa(self):
# Beth was worried about [[Sinead 's] well-being] , and also about Jane . [[Her] well-being] was always a concern .
print("\nRun dynamic hasa test: ")
target = self.cases["dynamic_hasa_test"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(target.chains,result.chains,"dynamic hasa test (Sinead 's <- her)")
[docs] def test_entity_dep(self):
# I have a book , [a dog] and a car. [It] barked.
print("\nRun entity dep test: ")
target = self.cases["entity_dep_test"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(target.chains,result.chains,"entity dep test (a book, a dog <- It barked)")
[docs] def test_affix_morphology(self):
# [A blorker] had a mummelhound in a blargmobile. I saw [the person] .
print("\nRun affix morphology test: ")
target = self.cases["morph_test"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(target.chains,result.chains,"affix morph test (A blorker <- the person)")
[docs] def test_verbal_event_stem(self):
# John [visited] Spain . [The visit] went well .
print("\nRun verbal event coreference test: ")
target = self.cases["verb_test"]
result = Case(self.xrenner.analyze(target.parse.split("\n"),"unittest"))
self.assertEqual(target.chains,result.chains,"verbal event stemming (visited <- the visit )")
[docs]class Case:
def __init__(self, case_string):
params, parse = case_string.split("input_data:")
self.parse = parse.strip()
params = params.replace("\r","")
self.chains = []
chain_mode = False
for line in params.split("\n"):
line = re.sub(r'#.*','',line).strip()
if len(line) > 0:
if chain_mode:
self.chains.append(line)
if ":" in line and not chain_mode and not "options" in line:
key, val = line.split(":")
if key == "name":
self.name = val
elif key == "toks":
self.tok_count = int(val)
elif key == "marks":
self.mark_count = int(val)
elif key == "groups":
self.group_count = int(val)
elif key == "model":
self.model = val
elif key == "chains":
chain_mode = True
def suite():
# Create test suite
test_suite = unittest.TestSuite()
# Add a test case
test_suite.addTest(unittest.makeSuite(Test1Model))
test_suite.addTest(unittest.makeSuite(Test2MarkableMethods))
test_suite.addTest(unittest.makeSuite(Test3CorefMethods))
return test_suite
def can_be_coreferent(mark1, mark2, lex):
"""
Utility function to check whether an xrenner model is capable of finding two markables coreferent
:param mark1: The :class:`.Markable` object to match to mark2
:param mark2: The :class:`.Markable` object to match to mark1
:param lex: the :class:`.LexData` object with gazetteer information and model settings
:return: bool
"""
lex.incompatible_isa_pairs = set([])
lex.incompatible_mod_pairs = set([])
prev_markables = [mark1, mark2]
mark2.sentence.sent_num = 2
antecedent, propagation = find_antecedent(mark2, prev_markables,lex)
return antecedent is not None
if __name__ == '__main__':
xrenner = None
cases = {}
test_runner = unittest.TextTestRunner()
test_runner.run(suite())