diff --git a/ocrd_calamari/recognize.py b/ocrd_calamari/recognize.py
index a0eca4c..817d6d5 100644
--- a/ocrd_calamari/recognize.py
+++ b/ocrd_calamari/recognize.py
@@ -48,6 +48,14 @@ class CalamariRecognize(Processor):
         checkpoints = glob(self.parameter['checkpoint'])
         self.predictor = MultiPredictor(checkpoints=checkpoints)
 
+        self.network_input_channels = self.predictor.predictors[0].network.input_channels
+        #self.network_input_channels = self.predictor.predictors[0].network_params.channels # not used!
+        # binarization = self.predictor.predictors[0].model_params.data_preprocessor.binarization # not used!
+        # self.features = ('' if self.network_input_channels != 1 else
+        #                  'binarized' if binarization != 'GRAY' else
+        #                  'grayscale_normalized')
+        self.features = ''
+
         voter_params = VoterParams()
         voter_params.type = VoterParams.Type.Value(self.parameter['voter'].upper())
         self.voter = voter_from_proto(voter_params)
@@ -69,17 +77,27 @@ class CalamariRecognize(Processor):
             pcgts = page_from_file(self.workspace.download_file(input_file))
             page = pcgts.get_Page()
 
-            page_image, page_xywh, page_image_info = self.workspace.image_from_page(page, page_id)
+            page_image, page_coords, page_image_info = self.workspace.image_from_page(
+                page, page_id, feature_selector=self.features)
 
-            for region in pcgts.get_Page().get_TextRegion():
-                region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh)
+            for region in page.get_TextRegion():
+                region_image, region_coords = self.workspace.image_from_segment(
+                    region, page_image, page_coords, feature_selector=self.features)
 
                 textlines = region.get_TextLine()
                 log.info("About to recognize %i lines of region '%s'", len(textlines), region.id)
 
-                line_images_np = []
-                for (line_no, line) in enumerate(textlines):
-                    line_image, line_coords = self.workspace.image_from_segment(line, region_image, region_xywh)
+                for line in textlines:
+                    log.debug("Recognizing line '%s' in region '%s'", line.id, region.id)
+
+                    line_image, line_coords = self.workspace.image_from_segment(line, region_image, region_coords, feature_selector=self.features)
+                    if ('binarized' not in line_coords['features'] and 'grayscale_normalized' not in line_coords['features'] and self.network_input_channels == 1):
+                        # We cannot use a feature selector for this since we don't
+                        # know whether the model expects (has been trained on)
+                        # binarized or grayscale images; but raw images are likely
+                        # always inadequate:
+                        log.warning("Using raw image for line '%s' in region '%s'", line.id, region.id)
+
                     line_image = line_image if all(line_image.size) else [[0]]
                     line_image_np = np.array(line_image, dtype=np.uint8)
                     line_images_np.append(line_image_np)
diff --git a/test/test_recognize.py b/test/test_recognize.py
index 0a1e558..fdb2679 100644
--- a/test/test_recognize.py
+++ b/test/test_recognize.py
@@ -6,12 +6,12 @@ from lxml import etree
 from glob import glob
 
 import pytest
+import logging
 from ocrd.resolver import Resolver
 
 from ocrd_calamari import CalamariRecognize
 from .base import assets
 
-
 METS_KANT = assets.url_of('kant_aufklaerung_1784-page-region-line-word_glyph/data/mets.xml')
 WORKSPACE_DIR = '/tmp/test-ocrd-calamari'
 CHECKPOINT_DIR = os.path.join(os.getcwd(), 'gt4histocr-calamari1')
@@ -99,6 +99,19 @@ def test_recognize_with_checkpoint_dir(workspace):
         assert "verſchuldeten" in f.read()
 
 
+def test_recognize_should_warn_if_given_rgb_image_and_single_channel_model(workspace, caplog):
+    caplog.set_level(logging.WARNING)
+    CalamariRecognize(
+        workspace,
+        input_file_grp="OCR-D-GT-SEG-LINE",
+        output_file_grp="OCR-D-OCR-CALAMARI-BROKEN",
+        parameter={'checkpoint': CHECKPOINT}
+    ).process()
+
+    interesting_log_messages = [t[2] for t in caplog.record_tuples if "Using raw image" in t[2]]
+    assert len(interesting_log_messages) > 10 # For every line!
+
+
 def test_word_segmentation(workspace):
     CalamariRecognize(
         workspace,