Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PageXML: keep SubElement NS prefixes consistent when writing #342

Merged
merged 4 commits
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions calamari_ocr/ocr/dataset/datareader/pagexml/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,11 @@ def _bounding_rect_from_points(points: List[Tuple[int, int]]) -> Tuple[int, int,
def _coords_for_rectangle(x, y, width, height):
return f"{int(x)},{int(y)} {int(x+width)},{int(y)} {int(x+width)},{int(y+height)} {int(x)},{int(y+height)}"

@staticmethod
def _make_subelement(parent, tag, attrib=None):
    """Create a child of *parent* carrying the parent's namespace prefix.

    Qualifies *tag* with the namespace URI that *parent*'s own prefix maps
    to (Clark notation ``{uri}tag``) and propagates the parent's nsmap, so
    serialized sub-elements keep a consistent prefix instead of falling
    back to a generated one.
    """
    ns_uri = parent.nsmap.get(parent.prefix, "")
    qualified_tag = f"{{{ns_uri}}}{tag}"
    return etree.SubElement(parent, qualified_tag, attrib=attrib, nsmap=parent.nsmap)

def _store_old_word(self, word_xml, ns):
word_xml.set("id", f"{word_xml.get('id')}_old")

Expand All @@ -496,11 +501,11 @@ def _store_glyph(self, glyph, word_id, word_xml, line_x, line_y, line_height, gl

glyph_xml = word_xml.find(f'./ns:Glyph[@id="{glyph_id}"]', namespaces=ns)
if glyph_xml is None:
glyph_xml = etree.SubElement(word_xml, "Glyph", attrib={"id": glyph_id})
glyph_xml = self._make_subelement(word_xml, "Glyph", attrib={"id": glyph_id})

coords_xml = glyph_xml.find("./ns:Coords", namespaces=ns)
if coords_xml is None:
coords_xml = etree.SubElement(glyph_xml, "Coords")
coords_xml = self._make_subelement(glyph_xml, "Coords")

glyph_x, glyph_y = glyph.global_start + line_x, line_y
glyph_width, glyph_height = glyph.global_end - glyph.global_start, line_height
Expand All @@ -513,15 +518,15 @@ def _store_glyph(self, glyph, word_id, word_xml, line_x, line_y, line_height, gl

textequiv_xml = glyph_xml.find(f'./ns:TextEquiv[@index="{glyph_index}"]', namespaces=ns)
if textequiv_xml is None:
textequiv_xml = etree.SubElement(glyph_xml, "TextEquiv")
textequiv_xml = self._make_subelement(glyph_xml, "TextEquiv")
textequiv_xml.set("index", str(glyph_index))

if self.params.output_confidences:
textequiv_xml.set("conf", str(confidence))

u_xml = textequiv_xml.find("./ns:Unicode", namespaces=ns)
if u_xml is None:
u_xml = etree.SubElement(textequiv_xml, "Unicode")
u_xml = self._make_subelement(textequiv_xml, "Unicode")
u_xml.text = char

def _store_words(self, words, line_xml, line_coords, ns) -> float:
Expand Down Expand Up @@ -554,11 +559,11 @@ def _store_words(self, words, line_xml, line_coords, ns) -> float:
word_xml = line_xml.find(f'./ns:Word[@id="{word_id}"]', namespaces=ns)
if word_xml is None:
# no word with this id, create a new word element
word_xml = etree.SubElement(line_xml, "Word", attrib={"id": word_id})
word_xml = self._make_subelement(line_xml, "Word", attrib={"id": word_id})

coords_xml = word_xml.find("./ns:Coords", namespaces=ns)
if coords_xml is None:
coords_xml = etree.SubElement(word_xml, "Coords")
coords_xml = self._make_subelement(word_xml, "Coords")

word_text = ""
word_confidence = 1
Expand All @@ -575,15 +580,15 @@ def _store_words(self, words, line_xml, line_coords, ns) -> float:

textequiv_xml = word_xml.find(f'./ns:TextEquiv[@index="{self.params.text_index}"]', namespaces=ns)
if textequiv_xml is None:
textequiv_xml = etree.SubElement(word_xml, "TextEquiv")
textequiv_xml = self._make_subelement(word_xml, "TextEquiv")
textequiv_xml.set("index", str(self.params.text_index))

if self.params.output_confidences:
textequiv_xml.set("conf", str(word_confidence))

u_xml = textequiv_xml.find("./ns:Unicode", namespaces=ns)
if u_xml is None:
u_xml = etree.SubElement(textequiv_xml, "Unicode")
u_xml = self._make_subelement(textequiv_xml, "Unicode")
u_xml.text = word_text

word_x, word_y = word[0].global_start + line_x, line_y
Expand Down
6 changes: 6 additions & 0 deletions calamari_ocr/ocr/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __init__(self, params: ModelParams, name="CalamariGraph", **kwargs):
self.reshape = ToInputDimsLayerParams(dims=3).create()
self.logits = KL.Dense(params.classes, name="logits")
self.softmax = KL.Softmax(name="softmax")
self.temperature = (
tf.constant(params.temperature, dtype=tf.float32, name="temperature") if params.temperature > 0 else None
)

def build_graph(self, inputs, training=None):
params: ModelParams = self._params
Expand Down Expand Up @@ -90,6 +93,9 @@ def build_graph(self, inputs, training=None):
blank_last_softmax = self.softmax(blank_last_logits)

logits = tf.roll(blank_last_logits, shift=1, axis=-1)
if self.temperature != None:
logits = tf.divide(logits, self.temperature) ### TEST scale, seems to work...

softmax = tf.nn.softmax(logits)

greedy_decoded = ctc.ctc_greedy_decoder(
Expand Down
1 change: 1 addition & 0 deletions calamari_ocr/ocr/model/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ModelParams(ModelBaseParams):
classes: int = -1
ctc_merge_repeated: bool = True
ensemble: int = 0 # For usage with the ensemble-model graph
temperature: float = field(default=-1, metadata=pai_meta(help="Value to divide logits by (temperature scaling)."))
masking_mode: int = False # This parameter is for evaluation only and should not be used in production

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions calamari_ocr/ocr/predict/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def from_checkpoint(params: PredictorParams, checkpoint: str, auto_update_checkp
predictor = Predictor(params, scenario.create_data())
predictor.set_model(
keras.models.load_model(
ckpt.ckpt_path + ".h5",
ckpt.ckpt_path,
custom_objects=CalamariScenario.model_cls().all_custom_objects(),
)
)
Expand All @@ -50,11 +50,12 @@ def from_paths(

DeviceConfig(predictor_params.device)
checkpoints = [SavedCalamariModel(ckpt, auto_update=auto_update_checkpoints) for ckpt in checkpoints]

multi_predictor = super(MultiPredictor, cls).from_paths(
[ckpt.json_path for ckpt in checkpoints],
predictor_params,
CalamariScenario,
model_paths=[ckpt.ckpt_path + ".h5" for ckpt in checkpoints],
model_paths=[ckpt.ckpt_path for ckpt in checkpoints],
predictor_args={"voter_params": voter_params},
)

Expand Down
34 changes: 34 additions & 0 deletions calamari_ocr/ocr/savedmodel/migrations/version5to6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
import os

from tensorflow import keras

from calamari_ocr.ocr.scenario import CalamariScenario
from calamari_ocr.ocr.training.params import TrainerParams

logger = logging.getLogger(__name__)


def update_model(params: dict, path: str):
    """Migrate a Calamari checkpoint from saved-model version 5 to 6.

    Rebuilds the prediction graph from the stored trainer params, loads the
    legacy HDF5 weights from ``path + ".h5"``, and re-saves the model in the
    TensorFlow SavedModel format at ``path``, deleting the old ``.h5`` file
    only after the converted model has been verified to load.

    Args:
        params: Checkpoint dict parseable as ``TrainerParams``.
        path: Checkpoint path prefix (without the ``.h5`` suffix).
    """
    # Use lazy %-style args so formatting is deferred to the logging framework.
    logger.info("Updating model at %s", path)

    trainer_params = TrainerParams.from_dict(params)
    scenario = CalamariScenario(trainer_params.scenario)
    inputs = scenario.data.create_input_layers()
    outputs = scenario.graph.predict(inputs)
    pred_model = keras.models.Model(inputs, outputs)
    # Version-5 checkpoints store their weights in HDF5 next to the json.
    pred_model.load_weights(path + ".h5")

    tmp_path = path + ".tmp"
    logger.info("Writing converted model at %s", tmp_path)
    pred_model.save(tmp_path, include_optimizer=False)

    # Sanity check: make sure the converted model is loadable before the
    # original .h5 file is removed.
    logger.info("Attempting to load converted model at %s", tmp_path)
    keras.models.load_model(
        tmp_path,
        custom_objects=CalamariScenario.model_cls().all_custom_objects(),
    )

    logger.info("Replacing old model at %s.h5", path)
    os.remove(path + ".h5")
    os.rename(tmp_path, path)
    logger.info("New model successfully written")
    keras.backend.clear_session()
8 changes: 7 additions & 1 deletion calamari_ocr/ocr/savedmodel/saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class SavedCalamariModel:
VERSION = 5
VERSION = 6

def __init__(self, json_path: str, auto_update=True, dry_run=False):
self.json_path = json_path if json_path.endswith(".json") else json_path + ".json"
Expand Down Expand Up @@ -98,6 +98,12 @@ def _single_upgrade(self):
update_model(self.dict, self.ckpt_path)
self.version = 5

elif self.version == 5:
from calamari_ocr.ocr.savedmodel.migrations.version5to6 import update_model

update_model(self.dict, self.ckpt_path)
self.version = 6

self._update_json_version()

def _update_json_version(self):
Expand Down
2 changes: 1 addition & 1 deletion calamari_ocr/ocr/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def default_params(cls):
scenario_params = super(CalamariScenarioBase, cls).default_params()
scenario_params.export_serve = True
scenario_params.export_net_config = False
scenario_params.default_serve_dir = "best.ckpt.h5"
scenario_params.default_serve_dir = "best.ckpt"
scenario_params.scenario_params_filename = "scenario_params.json" # should never be written!
scenario_params.trainer_params_filename = "best.ckpt.json"
return scenario_params
Expand Down
2 changes: 1 addition & 1 deletion calamari_ocr/ocr/training/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class TrainerParams(AIPTrainerParams[CalamariScenarioParams, CalamariDefaultTrai
)

def __post_init__(self):
self.scenario.default_serve_dir = f"{self.best_model_prefix}.ckpt.h5"
self.scenario.default_serve_dir = f"{self.best_model_prefix}.ckpt"
self.scenario.trainer_params_filename = f"{self.best_model_prefix}.ckpt.json"
self.early_stopping.best_model_name = ""

Expand Down
2 changes: 1 addition & 1 deletion calamari_ocr/ocr/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, params: TrainerParams, scenario, restore=False):
self._params.warmstart.model,
auto_update=self._params.auto_upgrade_checkpoints,
)
self._params.warmstart.model = self.checkpoint.ckpt_path + ".h5"
self._params.warmstart.model = self.checkpoint.ckpt_path
self._params.warmstart.trim_graph_name = False
network = self.checkpoint.trainer_params.network
if self._params.network != network:
Expand Down
4 changes: 2 additions & 2 deletions calamari_ocr/scripts/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def split(args):
ckpt = SavedCalamariModel(args.model)
keras_model = keras.models.load_model(
ckpt.ckpt_path + ".h5",
ckpt.ckpt_path,
custom_objects={
"Graph": Graph,
"EnsembleGraph": EnsembleGraph,
Expand Down Expand Up @@ -62,7 +62,7 @@ def extract_keras_model(i):
path = os.path.join(ckpt.dirname, f"{ckpt.basename}_split_{i}.ckpt")
with open(path + ".json", "w") as f:
json.dump(ckpt_dict, f, indent=2)
split_model.save(path + ".h5")
split_model.save(path)
print(f"Saved {i + 1}/{len(split_models)}")


Expand Down
Loading
Loading