-
Notifications
You must be signed in to change notification settings - Fork 0
/
annotate_helper.py
189 lines (153 loc) · 6.78 KB
/
annotate_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import re
import json
import spacy
from pathlib import Path
import uuid
from colorama import Fore
class AnnotateHelper:
    """Interactive helper for building a list of labelled entity spans over a document's text."""

    def __init__(self, raw_text):
        """
        Assists in helping to annotate the text of a document.
        :param raw_text: str. A string representation of the text for a document.
        """
        self.raw_text = raw_text
        # Each entity is a dict with the keys: id, text, start_idx, end_idx, label.
        self.entities = []

    def drop_entity(self, entity_id):
        """
        Removes every entity whose ``id`` equals ``entity_id`` from the entity list.
        :param entity_id: str. The ``id`` of the entity to remove.
        :return: None. Edits the self.entities list in place.
        """
        removed = [e for e in self.entities if e["id"] == entity_id]
        if not removed:
            print("No entity was found with id: {}".format(entity_id))
            return
        # Rebuild via slice assignment instead of calling list.remove() inside a loop over the
        # same list (which skips the element following each removal); the slice keeps the
        # list object identity for any outside references.
        self.entities[:] = [e for e in self.entities if e["id"] != entity_id]
        for entity in removed:
            print("The following entity was removed: ")
            print(entity)

    def suggest_entities(self, ner_model):
        """
        Applies an existing spaCy NER model to the document and adds every detected entity to the entity list.
        :param ner_model: Name of, or path to, a model loadable with ``spacy.load``.
        :return: list of dict. The entities detected by this call (they are also appended to self.entities).
        """
        nlp = spacy.load(ner_model)
        doc = nlp(self.raw_text)
        identified_entities = []
        for ent in doc.ents:
            entity = {
                "id": uuid.uuid4().hex,
                "text": self.raw_text[ent.start_char:ent.end_char],
                "start_idx": ent.start_char,
                "end_idx": ent.end_char,
                "label": ent.label_,
            }
            print(entity["id"], entity["text"], entity["label"])
            identified_entities.append(entity)
            self.entities.append(entity)
        return identified_entities

    def _find_phrase_spans(self, phrase, match_case=True, regex=False):
        """
        Return the (start_idx, end_idx) character offsets of every non-empty match of ``phrase`` in the text.
        :param phrase: str. Literal text to find (or a regular expression when ``regex`` is True).
        :param match_case: bool. Whether matching is case-sensitive.
        :param regex: bool. Treat ``phrase`` as a regular expression instead of literal text.
        :return: list of (int, int) tuples, in document order.
        """
        pattern = phrase if regex else re.escape(phrase)
        # IGNORECASE keeps offsets aligned with self.raw_text; lower()-ing both strings can change
        # string length for some Unicode characters and would also corrupt regex escapes.
        flags = 0 if match_case else re.IGNORECASE
        # Zero-width matches (possible with regex=True) are not meaningful entities.
        return [m.span() for m in re.finditer(pattern, self.raw_text, flags) if m.end() > m.start()]

    def get_entity_span(self, phrase, match_case=True, label=None, context_len=30, regex=False):
        """
        Searches the text for an occurrence of the specified entity. For each occurrence of the text the context is
        printed and you are prompted if you want to add the text as an entity.
        :param phrase: str. The phrase to search for (matched literally unless ``regex`` is True).
        :param match_case: bool. (optional) Whether the matching should be case-sensitive.
        :param label: str. (optional) Name of the label to assign to the entity. If set then the same label will be
            assigned to each occurrence of the phrase in the text. Otherwise user will have the option to set the
            label manually.
        :param context_len: int. Default=30. Number of begin/end characters surrounding the phrase to include to
            determine context of the phrase.
        :param regex: bool. Default=False. Interpret ``phrase`` as a regular expression.
        :return: None. Accepted matches are added to self.entities.
        """
        if not phrase:
            print("The phrase to search for cannot be empty.")
            return
        spans = self._find_phrase_spans(phrase, match_case=match_case, regex=regex)
        if not spans:
            print("There were no matches for the text provided. Check spelling or try not matching on case.")
            return
        print("There were {} matching phrases in the text.\n".format(len(spans)))
        for start_idx, end_idx in spans:
            matched_text = self.raw_text[start_idx:end_idx]
            entity = {
                "id": str(uuid.uuid4()),
                # Record the text as it appears in the document, which may differ in case from the query.
                "text": matched_text,
                "start_idx": start_idx,
                "end_idx": end_idx,
                "label": label,
            }
            subset_start = max(start_idx - context_len, 0)
            subset_end = end_idx + context_len
            highlight_text = Fore.RED + matched_text + Fore.RESET
            print("".join([self.raw_text[subset_start:start_idx], highlight_text, self.raw_text[end_idx:subset_end]]))
            add_entity = input("Do you want to add this to your entity list? (y/n) ")
            if add_entity.strip().lower() == "y":
                print(self.add_entity(entity))
            else:
                print("Okay let's try the next one...\n")

    def add_entity(self, entity):
        """
        Adds an entity to the document entities list, prompting for a label when it has none.
        An entity is a duplicate when another entity has the same start_idx, end_idx and label
        (the random ``id`` is ignored, otherwise duplicates could never be detected).
        :param entity: dict. An entity with keys id, text, start_idx, end_idx and label (label may be None).
        :return: str. A status message saying whether the entity was added or already exists.
        """
        if entity["label"] is None:
            label = input("What should the label be for this entity? ")
            while label.strip() == '':
                print('Label cannot be an empty string')
                label = input("What should the label be for this entity? ")
            entity["label"] = label.strip().upper()
        key = (entity["start_idx"], entity["end_idx"], entity["label"])
        if any((e["start_idx"], e["end_idx"], e["label"]) == key for e in self.entities):
            return "Entity already exists.\n"
        self.entities.append(entity)
        return "ENTITY ADDED!\n"

    def view_annotated_document(self, focus_entity=None):
        """
        Print the document with entities highlighted by colour.
        :param focus_entity: list of str (or a single str). (optional) Labels to highlight; entities with other
            labels are printed as plain text. If None, every label is highlighted.
        :return: None. Prints a colour key followed by the colour-coded document.
        """
        colors = ['RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA']
        labels = {e["label"] for e in self.entities}
        if focus_entity is not None:
            if isinstance(focus_entity, str):
                focus_entity = [focus_entity]
            focus = set(focus_entity)
            labels = {lab for lab in labels if lab in focus}
        # Sort so colours are assigned deterministically; cycle the palette when labels outnumber colours.
        color_map = {
            lab: getattr(Fore, colors[i % len(colors)])
            for i, lab in enumerate(sorted(labels, key=str))
        }
        print("COLOR-ENTITY KEY: ")
        for lab, code in color_map.items():
            print(code + str(lab) + Fore.RESET)
        print("\n\n")
        pieces = []
        track_idx = 0
        for e in sorted(self.entities, key=lambda d: d["start_idx"]):
            code = color_map.get(e["label"])
            if code is None:
                continue  # label not in focus: leave the span as plain text
            # Clip spans overlapping an already-highlighted one so no text is printed twice.
            start_idx = max(e["start_idx"], track_idx)
            end_idx = e["end_idx"]
            if end_idx <= start_idx:
                continue
            pieces.append(self.raw_text[track_idx:start_idx])
            pieces.append(code + self.raw_text[start_idx:end_idx] + Fore.RESET)
            track_idx = end_idx
        # Always emit the remainder (the whole text when nothing was highlighted).
        pieces.append(self.raw_text[track_idx:])
        print("".join(pieces))

    def save(self, save_path):
        """
        Saves the annotator's state (raw_text and entities) as JSON to ``save_path``, asking for confirmation
        before overwriting an existing file.
        :param save_path: str or Path. Destination file path.
        :return: None.
        """
        path = Path(save_path)
        if path.exists():
            overwrite = input("File already exists. Would you like to overwrite existing file? (y/n)")
            if overwrite.strip().lower() != "y":
                print("File was not saved.")
                return
        # Use a context manager so the file handle is flushed and closed.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.__dict__, f)