Revert "fix model loading in mb_ro and ocr"

This reverts commit 218a95e6a0.
This commit is contained in:
Robert Sachunsky 2026-05-19 03:32:19 +02:00
parent 1df32eba87
commit a1449da1d1
4 changed files with 11 additions and 14 deletions

View file

@ -65,14 +65,14 @@ class Eynollah_ocr:
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
if tr_ocr: if tr_ocr:
self.model_zoo.load_models('trocr_processor') self.model_zoo.load_model('trocr_processor')
self.model_zoo.load_models(['ocr', 'tr']) self.model_zoo.load_model('ocr', 'tr')
self.model_zoo.get('ocr').to(self.device) self.model_zoo.get('ocr').to(self.device)
else: else:
self.model_zoo.load_models('ocr') self.model_zoo.load_model('ocr', '')
self.model_zoo.load_models('num_to_char') self.model_zoo.load_model('num_to_char')
self.model_zoo.load_models('characters') self.model_zoo.load_model('characters')
self.end_character = len(self.model_zoo.get('characters')) + 2 self.end_character = len(self.model_zoo.get('characters', list)) + 2
@property @property
def device(self): def device(self):

View file

@ -19,6 +19,7 @@ import statistics
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import Model
from .model_zoo import EynollahModelZoo from .model_zoo import EynollahModelZoo
from .utils.resize import resize_image from .utils.resize import resize_image
@ -49,7 +50,7 @@ class machine_based_reading_order_on_layout:
except: except:
self.logger.warning("no GPU device available") self.logger.warning("no GPU device available")
self.model_zoo.load_models('reading_order') self.model_zoo.load_model('reading_order')
def read_xml(self, xml_file): def read_xml(self, xml_file):
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
@ -675,7 +676,7 @@ class machine_based_reading_order_on_layout:
tot_counter += 1 tot_counter += 1
batch.append(j) batch.append(j)
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
y_pr = self.model_zoo.get('reading_order').predict(input_1 , verbose='0') y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0')
for jb, j in enumerate(batch): for jb, j in enumerate(batch):
if y_pr[jb][0]>=0.5: if y_pr[jb][0]>=0.5:
post_list.append(j) post_list.append(j)

View file

@ -94,12 +94,8 @@ class EynollahModelZoo:
elif model_category.endswith('_patched'): elif model_category.endswith('_patched'):
load_args[0] = model_category[:-8] load_args[0] = model_category[:-8]
load_kwargs["patched"] = True load_kwargs["patched"] = True
spec = self.specs.get(model_category, load_args[1] if len(load_args) > 1 else '') ret[model_category] = Predictor(self.logger, self)
if spec.type in ['Keras'] and spec.category != 'ocr': ret[model_category].load_model(*load_args, **load_kwargs, device=device)
ret[model_category] = Predictor(self.logger, self)
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
else:
ret[model_category] = self.load_model(*load_args, **load_kwargs, device=device)
self._loaded.update(ret) self._loaded.update(ret)
return self._loaded return self._loaded