diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 3c918e5..1b49077 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -65,14 +65,14 @@ class Eynollah_ocr: self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size if tr_ocr: - self.model_zoo.load_model('trocr_processor') - self.model_zoo.load_model('ocr', 'tr') + self.model_zoo.load_models('trocr_processor') + self.model_zoo.load_models(['ocr', 'tr']) self.model_zoo.get('ocr').to(self.device) else: - self.model_zoo.load_model('ocr', '') - self.model_zoo.load_model('num_to_char') - self.model_zoo.load_model('characters') - self.end_character = len(self.model_zoo.get('characters', list)) + 2 + self.model_zoo.load_models('ocr') + self.model_zoo.load_models('num_to_char') + self.model_zoo.load_models('characters') + self.end_character = len(self.model_zoo.get('characters')) + 2 @property def device(self): diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index 22fe97b..b0b5910 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -19,7 +19,6 @@ import statistics os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 import tensorflow as tf -from tensorflow.keras.models import Model from .model_zoo import EynollahModelZoo from .utils.resize import resize_image @@ -50,7 +49,7 @@ class machine_based_reading_order_on_layout: except: self.logger.warning("no GPU device available") - self.model_zoo.load_model('reading_order') + self.model_zoo.load_models('reading_order') def read_xml(self, xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) @@ -676,7 +675,7 @@ class machine_based_reading_order_on_layout: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0') + y_pr = self.model_zoo.get('reading_order').predict(input_1 , verbose='0') for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) diff --git a/src/eynollah/model_zoo/.nfs00000002feddea7d00000031 b/src/eynollah/model_zoo/.nfs00000002feddea7d00000031 new file mode 100644 index 0000000..c7dd87d Binary files /dev/null and b/src/eynollah/model_zoo/.nfs00000002feddea7d00000031 differ diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index fffd389..9611388 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -94,8 +94,12 @@ class EynollahModelZoo: elif model_category.endswith('_patched'): load_args[0] = model_category[:-8] load_kwargs["patched"] = True - ret[model_category] = Predictor(self.logger, self) - ret[model_category].load_model(*load_args, **load_kwargs, device=device) + spec = self.specs.get(model_category, load_args[1] if len(load_args) > 1 else '') + if spec.type in ['Keras'] and spec.category != 'ocr': + ret[model_category] = Predictor(self.logger, self) + ret[model_category].load_model(*load_args, **load_kwargs, device=device) + else: + ret[model_category] = self.load_model(*load_args, **load_kwargs, device=device) self._loaded.update(ret) return self._loaded