mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-13 01:13:54 +02:00
fix model loading in mb_ro and ocr
This commit is contained in:
parent
2035b07b55
commit
218a95e6a0
4 changed files with 14 additions and 11 deletions
|
|
@ -65,14 +65,14 @@ class Eynollah_ocr:
|
||||||
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
|
self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size
|
||||||
|
|
||||||
if tr_ocr:
|
if tr_ocr:
|
||||||
self.model_zoo.load_model('trocr_processor')
|
self.model_zoo.load_models('trocr_processor')
|
||||||
self.model_zoo.load_model('ocr', 'tr')
|
self.model_zoo.load_models(['ocr', 'tr'])
|
||||||
self.model_zoo.get('ocr').to(self.device)
|
self.model_zoo.get('ocr').to(self.device)
|
||||||
else:
|
else:
|
||||||
self.model_zoo.load_model('ocr', '')
|
self.model_zoo.load_models('ocr')
|
||||||
self.model_zoo.load_model('num_to_char')
|
self.model_zoo.load_models('num_to_char')
|
||||||
self.model_zoo.load_model('characters')
|
self.model_zoo.load_models('characters')
|
||||||
self.end_character = len(self.model_zoo.get('characters', list)) + 2
|
self.end_character = len(self.model_zoo.get('characters')) + 2
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import statistics
|
||||||
|
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.models import Model
|
|
||||||
|
|
||||||
from .model_zoo import EynollahModelZoo
|
from .model_zoo import EynollahModelZoo
|
||||||
from .utils.resize import resize_image
|
from .utils.resize import resize_image
|
||||||
|
|
@ -50,7 +49,7 @@ class machine_based_reading_order_on_layout:
|
||||||
except:
|
except:
|
||||||
self.logger.warning("no GPU device available")
|
self.logger.warning("no GPU device available")
|
||||||
|
|
||||||
self.model_zoo.load_model('reading_order')
|
self.model_zoo.load_models('reading_order')
|
||||||
|
|
||||||
def read_xml(self, xml_file):
|
def read_xml(self, xml_file):
|
||||||
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8'))
|
||||||
|
|
@ -676,7 +675,7 @@ class machine_based_reading_order_on_layout:
|
||||||
tot_counter += 1
|
tot_counter += 1
|
||||||
batch.append(j)
|
batch.append(j)
|
||||||
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
|
if tot_counter % inference_bs == 0 or tot_counter == len(ij_list):
|
||||||
y_pr = self.model_zoo.get('reading_order', Model).predict(input_1 , verbose='0')
|
y_pr = self.model_zoo.get('reading_order').predict(input_1 , verbose='0')
|
||||||
for jb, j in enumerate(batch):
|
for jb, j in enumerate(batch):
|
||||||
if y_pr[jb][0]>=0.5:
|
if y_pr[jb][0]>=0.5:
|
||||||
post_list.append(j)
|
post_list.append(j)
|
||||||
|
|
|
||||||
BIN
src/eynollah/model_zoo/.nfs00000002feddea7d00000031
Normal file
BIN
src/eynollah/model_zoo/.nfs00000002feddea7d00000031
Normal file
Binary file not shown.
|
|
@ -94,8 +94,12 @@ class EynollahModelZoo:
|
||||||
elif model_category.endswith('_patched'):
|
elif model_category.endswith('_patched'):
|
||||||
load_args[0] = model_category[:-8]
|
load_args[0] = model_category[:-8]
|
||||||
load_kwargs["patched"] = True
|
load_kwargs["patched"] = True
|
||||||
ret[model_category] = Predictor(self.logger, self)
|
spec = self.specs.get(model_category, load_args[1] if len(load_args) > 1 else '')
|
||||||
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
|
if spec.type in ['Keras'] and spec.category != 'ocr':
|
||||||
|
ret[model_category] = Predictor(self.logger, self)
|
||||||
|
ret[model_category].load_model(*load_args, **load_kwargs, device=device)
|
||||||
|
else:
|
||||||
|
ret[model_category] = self.load_model(*load_args, **load_kwargs, device=device)
|
||||||
self._loaded.update(ret)
|
self._loaded.update(ret)
|
||||||
return self._loaded
|
return self._loaded
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue