From d8db405a4c5597314aceba7252715c2b11c8c5cf Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 6 Nov 2019 00:39:05 +0100 Subject: [PATCH] warn if passing raw images to single-channel models --- ocrd_calamari/recognize.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/ocrd_calamari/recognize.py b/ocrd_calamari/recognize.py index 92aa5a4..31a37e1 100644 --- a/ocrd_calamari/recognize.py +++ b/ocrd_calamari/recognize.py @@ -31,6 +31,14 @@ class CalamariRecognize(Processor): checkpoints = glob(self.parameter['checkpoint']) self.predictor = MultiPredictor(checkpoints=checkpoints) + self.input_channels = self.predictor.predictors[0].network.input_channels + #self.input_channels = self.predictor.predictors[0].network_params.channels # not used! + # binarization = self.predictor.predictors[0].model_params.data_preprocessor.binarization # not used! + # self.features = ('' if self.input_channels != 1 else + # 'binarized' if binarization != 'GRAY' else + # 'grayscale_normalized') + self.features = '' + voter_params = VoterParams() voter_params.type = VoterParams.Type.Value(self.parameter['voter'].upper()) self.voter = voter_from_proto(voter_params) @@ -54,17 +62,30 @@ class CalamariRecognize(Processor): pcgts = page_from_file(self.workspace.download_file(input_file)) page = pcgts.get_Page() - page_image, page_xywh, page_image_info = self.workspace.image_from_page(page, page_id) + page_image, page_coords, page_image_info = self.workspace.image_from_page( + page, page_id, feature_selector=self.features) - for region in pcgts.get_Page().get_TextRegion(): - region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh) + for region in page.get_TextRegion(): + region_image, region_coords = self.workspace.image_from_segment( + region, page_image, page_coords, feature_selector=self.features) textlines = region.get_TextLine() log.info("About to recognize %i lines of region '%s'", len(textlines), region.id) - for (line_no, line) in enumerate(textlines): - log.debug("Recognizing line '%s' in region '%s'", line_no, region.id) - - line_image, line_xywh = self.workspace.image_from_segment(line, region_image, region_xywh) + for line in textlines: + log.debug("Recognizing line '%s' in region '%s'", line.id, region.id) + + line_image, line_coords = self.workspace.image_from_segment( + line, region_image, region_coords, feature_selector=self.features) + if ('binarized' not in line_coords['features'] and + 'grayscale_normalized' not in line_coords['features'] and + self.input_channels == 1): + # We cannot use a feature selector for this since we don't + # know whether the model expects (has been trained on) + # binarized or grayscale images; but raw images are likely + # always inadequate: + log.warning("Using raw image for line '%s' in region '%s'", + line.id, region.id) + line_image_np = np.array(line_image, dtype=np.uint8) raw_results = list(self.predictor.predict_raw([line_image_np], progress_bar=False))[0]