diff --git a/.vscode/settings.json b/.vscode/settings.json
index de288e1..74a2cbb 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,3 +1,8 @@
 {
-    "python.formatting.provider": "black"
+    "python.formatting.provider": "black",
+    "python.testing.pytestArgs": [
+        "."
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
 }
\ No newline at end of file
diff --git a/README-DEV.md b/README-DEV.md
index 134e784..33da234 100644
--- a/README-DEV.md
+++ b/README-DEV.md
@@ -1,5 +1,5 @@
 ```
-pip install -r requirements-test.txt
+pip install -r requirements-dev.txt
 ```
 
 To run tests:
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..e63c022
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,9 @@
+pytest
+pytest-profiling
+
+ipython
+
+mypy
+types-lxml
+types-tqdm
+pandas-stubs
diff --git a/requirements-test.txt b/requirements-test.txt
deleted file mode 100644
index 6f0f369..0000000
--- a/requirements-test.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-pytest
-pytest-profiling
diff --git a/src/mods4pandas/alto4pandas.py b/src/mods4pandas/alto4pandas.py
index 1508150..359a26e 100755
--- a/src/mods4pandas/alto4pandas.py
+++ b/src/mods4pandas/alto4pandas.py
@@ -5,6 +5,8 @@ import os
 import re
 import warnings
 import sys
+import contextlib
+import sqlite3
 from xml.dom.expatbuilder import Namespaces
 from lxml import etree as ET
 from itertools import groupby
@@ -13,11 +15,16 @@ from typing import List
 from collections.abc import MutableMapping, Sequence
 
 import click
-import pandas as pd
 import numpy as np
 from tqdm import tqdm
 
-from .lib import TagGroup, sorted_groupby, flatten, ns
+from .lib import TagGroup, convert_db_to_parquet, sorted_groupby, flatten, ns, insert_into_db
+
+with warnings.catch_warnings():
+    # Filter warnings on WSL
+    if "Microsoft" in os.uname().release:
+        warnings.simplefilter("ignore")
+    import pandas as pd
 
 
 logger = logging.getLogger('alto4pandas')
@@ -74,12 +81,20 @@ def alto_to_dict(alto, raise_errors=True):
             value[localname] = TagGroup(tag, group).is_singleton().has_no_attributes().descend(raise_errors)
         elif localname == 'fileName':
             value[localname] = TagGroup(tag, group).is_singleton().has_no_attributes().text()
+        elif localname == 'fileIdentifier':
+            value[localname] = TagGroup(tag, group).is_singleton().has_no_attributes().text()
         elif localname == 'Layout':
             value[localname] = TagGroup(tag, group).is_singleton().has_no_attributes().descend(raise_errors)
         elif localname == 'Page':
             value[localname] = {}
             value[localname].update(TagGroup(tag, group).is_singleton().attributes())
+            for attr in ("WIDTH", "HEIGHT"):
+                if attr in value[localname]:
+                    try:
+                        value[localname][attr] = int(value[localname][attr])
+                    except ValueError:
+                        del value[localname][attr]
             value[localname].update(TagGroup(tag, group).subelement_counts())
             value[localname].update(TagGroup(tag, group).xpath_statistics("//alto:String/@WC", namespaces))
@@ -121,30 +136,43 @@ def walk(m):
 
 
 @click.command()
 @click.argument('alto_files', type=click.Path(exists=True), required=True, nargs=-1)
-@click.option('--output', '-o', 'output_file', type=click.Path(), help='Output pickle file',
-              default='alto_info_df.pkl', show_default=True)
-@click.option('--output-csv', type=click.Path(), help='Output CSV file')
-@click.option('--output-xlsx', type=click.Path(), help='Output Excel .xlsx file')
-def process(alto_files: List[str], output_file: str, output_csv: str, output_xlsx: str):
+@click.option('--output', '-o', 'output_file', type=click.Path(), help='Output Parquet file',
+              default='alto_info_df.parquet', show_default=True)
+def process_command(alto_files: List[str], output_file: str):
     """
     A tool to convert the ALTO metadata in INPUT to a pandas DataFrame.
 
     INPUT is assumed to be a ALTO document. INPUT may optionally be a directory. The tool then reads
     all files in the directory.
 
-    alto4pandas writes two output files: A pickled pandas DataFrame and a CSV file with all conversion warnings.
+    alto4pandas writes multiple output files:
+    - A Parquet DataFrame
+    - A SQLite database
+    - and a CSV file with all conversion warnings.
     """
+    process(alto_files, output_file)
+
+def process(alto_files: List[str], output_file: str):
     # Extend file list if directories are given
     alto_files_real = []
     for m in alto_files:
         for x in walk(m):
             alto_files_real.append(x)
 
+    # Prepare output files
+    with contextlib.suppress(FileNotFoundError):
+        os.remove(output_file)
+    output_file_sqlite3 = output_file + ".sqlite3"
+    with contextlib.suppress(FileNotFoundError):
+        os.remove(output_file_sqlite3)
+
+    logger.info('Writing SQLite DB to {}'.format(output_file_sqlite3))
+    con = sqlite3.connect(output_file_sqlite3)
+
     # Process ALTO files
     with open(output_file + '.warnings.csv', 'w') as csvfile:
         csvwriter = csv.writer(csvfile)
-        alto_info = []
         logger.info('Processing ALTO files')
         for alto_file in tqdm(alto_files_real, leave=False):
             try:
@@ -160,7 +188,9 @@ def process(alto_files: List[str], output_file: str, output_csv: str, output_xls
                 d['alto_file'] = alto_file
                 d['alto_xmlns'] = ET.QName(alto).namespace
 
-                alto_info.append(d)
+                # Save
+                insert_into_db(con, "alto_info", d)
+                con.commit()
 
                 if caught_warnings:
                     # PyCharm thinks caught_warnings is not Iterable:
@@ -171,25 +201,9 @@ def process(alto_files: List[str], output_file: str, output_csv: str, output_xls
             logger.error('Exception in {}: {}'.format(alto_file, e))
             import traceback; traceback.print_exc()
 
-    # Convert the alto_info List[Dict] to a pandas DataFrame
-    columns = []
-    for m in alto_info:
-        for c in m.keys():
-            if c not in columns:
-                columns.append(c)
-    data = [[m.get(c) for c in columns] for m in alto_info]
-    index = [m['alto_file'] for m in alto_info]  # TODO use ppn + page?
-    alto_info_df = pd.DataFrame(data=data, index=index, columns=columns)
-
-    # Pickle the DataFrame
+    # Convert the alto_info SQL to a pandas DataFrame
    logger.info('Writing DataFrame to {}'.format(output_file))
-    alto_info_df.to_pickle(output_file)
-    if output_csv:
-        logger.info('Writing CSV to {}'.format(output_csv))
-        alto_info_df.to_csv(output_csv)
-    if output_xlsx:
-        logger.info('Writing Excel .xlsx to {}'.format(output_xlsx))
-        alto_info_df.to_excel(output_xlsx)
+    convert_db_to_parquet(con, "alto_info", "alto_file", output_file)
 
 
 def main():
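For orientation (added note, not part of the diff): after this change a single alto4pandas run leaves behind three artifacts — the Parquet DataFrame, the intermediate SQLite database next to it, and the conversion-warnings CSV. A minimal sketch of loading them, assuming the default `--output` value:

```python
import sqlite3
import pandas as pd

# Parquet DataFrame, indexed by alto_file; "-count" columns come back as nullable Int64
alto_info = pd.read_parquet("alto_info_df.parquet")

# Conversion warnings, written row by row as [alto_file, warning] without a header
warnings_log = pd.read_csv("alto_info_df.parquet.warnings.csv", names=["alto_file", "warning"])

# Raw SQLite table that insert_into_db() filled, one row per ALTO document
con = sqlite3.connect("alto_info_df.parquet.sqlite3")
n_documents = con.execute("SELECT COUNT(*) FROM alto_info").fetchone()[0]
```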
diff --git a/src/mods4pandas/lib.py b/src/mods4pandas/lib.py
index d2e1f8f..68050b1 100644
--- a/src/mods4pandas/lib.py
+++ b/src/mods4pandas/lib.py
@@ -1,12 +1,21 @@
+from __future__ import annotations
+
+import ast
 from itertools import groupby
 import re
 import warnings
-from typing import List, Sequence, MutableMapping, Dict
+import os
+from typing import Any, List, Sequence, MutableMapping, Dict
+from collections import defaultdict
 
-import pandas as pd
 import numpy as np
 from lxml import etree as ET
 
+with warnings.catch_warnings():
+    # Filter warnings on WSL
+    if "Microsoft" in os.uname().release:
+        warnings.simplefilter("ignore")
+    import pandas as pd
 
 __all__ = ["ns"]
@@ -23,40 +32,40 @@ ns = {
 
 class TagGroup:
     """Helper class to simplify the parsing and checking of MODS metadata"""
 
-    def __init__(self, tag, group: List[ET.Element]):
+    def __init__(self, tag, group: List[ET._Element]):
         self.tag = tag
         self.group = group
 
-    def to_xml(self):
+    def to_xml(self) -> str:
         return '\n'.join(str(ET.tostring(e), 'utf-8').strip() for e in self.group)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"TagGroup with content:\n{self.to_xml()}"
 
-    def is_singleton(self):
+    def is_singleton(self) -> TagGroup:
         if len(self.group) != 1:
             raise ValueError('More than one instance: {}'.format(self))
         return self
 
-    def has_no_attributes(self):
+    def has_no_attributes(self) -> TagGroup:
         return self.has_attributes({})
 
-    def has_attributes(self, attrib):
+    def has_attributes(self, attrib) -> TagGroup:
         if not isinstance(attrib, Sequence):
             attrib = [attrib]
         if not all(e.attrib in attrib for e in self.group):
             raise ValueError('One or more element has unexpected attributes: {}'.format(self))
         return self
 
-    def ignore_attributes(self):
+    def ignore_attributes(self) -> TagGroup:
         # This serves as documentation for now.
         return self
 
-    def sort(self, key=None, reverse=False):
+    def sort(self, key=None, reverse=False) -> TagGroup:
         self.group = sorted(self.group, key=key, reverse=reverse)
         return self
 
-    def text(self, separator='\n'):
+    def text(self, separator='\n') -> str:
         t = ''
         for e in self.group:
             if t != '':
@@ -65,13 +74,13 @@ class TagGroup:
             t += e.text
         return t
 
-    def text_set(self):
+    def text_set(self) -> set:
         return {e.text for e in self.group}
 
-    def descend(self, raise_errors):
+    def descend(self, raise_errors) -> dict:
         return _to_dict(self.is_singleton().group[0], raise_errors)
 
-    def filter(self, cond, warn=None):
+    def filter(self, cond, warn=None) -> TagGroup:
         new_group = []
         for e in self.group:
             if cond(e):
@@ -81,7 +90,7 @@ class TagGroup:
                     warnings.warn('Filtered {} element ({})'.format(self.tag, warn))
         return TagGroup(self.tag, new_group)
 
-    def force_singleton(self, warn=True):
+    def force_singleton(self, warn=True) -> TagGroup:
         if len(self.group) == 1:
             return self
         else:
@@ -92,7 +101,7 @@
     RE_ISO8601_DATE = r'^\d{2}(\d{2}|XX)(-\d{2}-\d{2})?$'  # Note: Includes non-specific century dates like '18XX'
     RE_GERMAN_DATE = r'^(?P<dd>\d{2})\.(?P<mm>\d{2})\.(?P<yyyy>\d{4})$'
 
-    def fix_date(self):
+    def fix_date(self) -> TagGroup:
 
         for e in self.group:
             if e.attrib.get('encoding') == 'w3cdtf':
@@ -102,15 +111,17 @@
         new_group = []
         for e in self.group:
+            if e.text is None:
+                warnings.warn('Empty date')
+                continue
             if e.attrib.get('encoding') == 'iso8601' and re.match(self.RE_ISO8601_DATE, e.text):
                 new_group.append(e)
             elif re.match(self.RE_ISO8601_DATE, e.text):
                 warnings.warn('Added iso8601 encoding to date {}'.format(e.text))
                 e.attrib['encoding'] = 'iso8601'
                 new_group.append(e)
-            elif re.match(self.RE_GERMAN_DATE, e.text):
+            elif m := re.match(self.RE_GERMAN_DATE, e.text):
                 warnings.warn('Converted date {} to iso8601 encoding'.format(e.text))
-                m = re.match(self.RE_GERMAN_DATE, e.text)
                 e.text = '{}-{}-{}'.format(m.group('yyyy'), m.group('mm'), m.group('dd'))
                 e.attrib['encoding'] = 'iso8601'
                 new_group.append(e)
@@ -130,7 +141,7 @@ class TagGroup:
 
         return self
 
-    def fix_event_type(self):
+    def fix_event_type(self) -> TagGroup:
         # According to MODS-AP 2.3.1, every originInfo should have its eventType set.
         # Fix this for special cases.
 
@@ -160,7 +171,7 @@ class TagGroup:
                 pass
         return self
 
-    def fix_script_term(self):
+    def fix_script_term(self) -> TagGroup:
         for e in self.group:
             # MODS-AP 2.3.1 is not clear about this, but it looks like that this should be lower case.
             if e.attrib['authority'] == 'ISO15924':
@@ -168,7 +179,7 @@ class TagGroup:
                 warnings.warn('Changed scriptTerm authority to lower case')
         return self
 
-    def merge_sub_tags_to_set(self):
+    def merge_sub_tags_to_set(self) -> dict:
         from .mods4pandas import mods_to_dict
 
         value = {}
@@ -188,7 +199,7 @@ class TagGroup:
             value[sub_tag] = s
         return value
 
-    def attributes(self):
+    def attributes(self) -> dict[str, str]:
         """
         Return a merged dict of all attributes of the tag group.
 
@@ -203,8 +214,8 @@ class TagGroup:
                 attrib[a_localname] = v
         return attrib
 
-    def subelement_counts(self):
-        counts = {}
+    def subelement_counts(self) -> dict[str, int]:
+        counts: dict[str, int] = {}
         for e in self.group:
             for x in e.iter():
                 tag = ET.QName(x.tag).localname
@@ -212,19 +223,21 @@ class TagGroup:
                 counts[key] = counts.get(key, 0) + 1
         return counts
 
-    def xpath_statistics(self, xpath_expr, namespaces):
+    def xpath_statistics(self, xpath_expr, namespaces) -> dict[str, float]:
         """
         Extract values and calculate statistics
 
         Extract values using the given XPath expression, convert them to float and return descriptive
         statistics on the values.
         """
-        values = []
-        for e in self.group:
-            r = e.xpath(xpath_expr, namespaces=namespaces)
-            values += r
-        values = np.array([float(v) for v in values])
+        def xpath_values():
+            values = []
+            for e in self.group:
+                r = e.xpath(xpath_expr, namespaces=namespaces)
+                values += r
+            return np.array([float(v) for v in values])
+        values = xpath_values()
 
         statistics = {}
         if values.size > 0:
             statistics[f'{xpath_expr}-mean'] = np.mean(values)
@@ -234,7 +247,7 @@ class TagGroup:
             statistics[f'{xpath_expr}-max'] = np.max(values)
         return statistics
 
-    def xpath_count(self, xpath_expr, namespaces):
+    def xpath_count(self, xpath_expr, namespaces) -> dict[str, int]:
         """
         Count all elements matching xpath_expr
         """
@@ -278,13 +291,13 @@ def _to_dict(root, raise_errors):
         raise ValueError(f"Unknown namespace {root_name.namespace}")
 
 
-def flatten(d: MutableMapping, parent='', separator='_'):
+def flatten(d: MutableMapping, parent='', separator='_') -> dict:
     """
     Flatten the given nested dict.
 
     It is assumed that d maps strings to either another dictionary (similarly structured) or some other value.
     """
-    items = []
+    items: list[Any] = []
 
     for k, v in d.items():
         if parent:
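A quick illustration (added for context, not part of the diff) of the underscore-joined keys flatten() produces — these later become column names such as Layout_Page_WIDTH; the sample values are made up:

```python
from mods4pandas.lib import flatten

nested = {"Layout": {"Page": {"WIDTH": 5964, "HEIGHT": 8209}}, "fileName": "00000017.xml"}

assert flatten(nested) == {
    "Layout_Page_WIDTH": 5964,
    "Layout_Page_HEIGHT": 8209,
    "fileName": "00000017.xml",
}
```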
@@ -300,31 +313,79 @@
     return dict(items)
 
 
-def dicts_to_df(data_list: List[Dict], *, index_column) -> pd.DataFrame:
-    """
-    Convert the given list of dicts to a Pandas DataFrame.
-
-    The keys of the dicts make the columns.
-    """
-
-    # Build columns from keys
-    columns = []
-    for m in data_list:
-        for c in m.keys():
-            if c not in columns:
-                columns.append(c)
-
-    # Build data table
-    data = [[m.get(c) for c in columns] for m in data_list]
-
-    # Build index
-    if isinstance(index_column, str):
-        index = [m[index_column] for m in data_list]
-    elif isinstance(index_column, tuple):
-        index = [[m[c] for m in data_list] for c in index_column]
-        index = pd.MultiIndex.from_arrays(index, names=index_column)
+def valid_column_key(k) -> bool:
+    if re.match(r'^[a-zA-Z0-9 _@/:\[\]-]+$', k):
+        return True
     else:
-        raise ValueError(f"index_column must")
+        return False
 
-    df = pd.DataFrame(data=data, index=index, columns=columns)
-    return df
+def column_names_csv(columns) -> str:
+    """
+    Format Column names (identifiers) as a comma-separated list.
+
+    This uses double quotes per SQL standard.
+    """
+    return ",".join('"' + c + '"' for c in columns)
+
+current_columns: dict[str, list] = defaultdict(list)
+current_columns_types: dict[str, dict] = defaultdict(dict)
+
+def insert_into_db(con, table, d: Dict):
+    """Insert the values from the dict into the table, creating columns if necessary"""
+
+    # Create table if necessary
+    if not current_columns[table]:
+        for k in d.keys():
+            assert valid_column_key(k), f'"{k}" is not a valid column name'
+            current_columns[table].append(k)
+        con.execute(f"CREATE TABLE {table} ({column_names_csv(current_columns[table])})")
+
+    # Add columns if necessary
+    for k in d.keys():
+        if not k in current_columns[table]:
+            assert valid_column_key(k), f'"{k}" is not a valid column name'
+            current_columns[table].append(k)
+            con.execute(f'ALTER TABLE {table} ADD COLUMN "{k}"')
+
+    # Save types
+    for k in d.keys():
+        if k not in current_columns_types[table]:
+            current_columns_types[table][k] = type(d[k]).__name__
+
+    # Insert
+    # Unfortunately, Python3's sqlite3 does not like named placeholders with spaces, so we
+    # have to use qmark style here.
+    columns = d.keys()
+    con.execute(
+        f"INSERT INTO {table}"
+        f"( {column_names_csv(columns)} )"
+        "VALUES"
+        f"( {','.join('?' for c in columns)} )",
+        [str(d[c]) for c in columns]
+    )
+
+def insert_into_db_multiple(con, table, ld: List[Dict]):
+    for d in ld:
+        insert_into_db(con, table, d)
+
+def convert_db_to_parquet(con, table, index_col, output_file):
+    df = pd.read_sql_query(f"SELECT * FROM {table}", con, index_col)
+
+    # Convert Python column type into Pandas type
+    for c in df.columns:
+        column_type = current_columns_types[table][c]
+
+        if column_type == "str":
+            continue
+        elif column_type == "int":
+            df[c] = df[c].astype("Int64")
+        elif column_type == "float64":
+            df[c] = df[c].astype("Float64")
+        elif column_type == "bool":
+            df[c] = df[c].map({"True": True, "False": False}).astype("boolean")
+        elif column_type == "set":
+            df[c] = df[c].apply(lambda s: list(ast.literal_eval(s)) if s else None)
+        else:
+            raise NotImplementedError(f"Column {c}: type {column_type} not implemented yet.")
+
+    df.to_parquet(output_file)
\ No newline at end of file
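To make the new persistence path easier to follow, here is a standalone sketch (added for illustration, not code from the repository) of the pattern insert_into_db() implements: every value is stored as text in a single table, and the schema is widened with ALTER TABLE whenever a document contributes a previously unseen key. The real implementation additionally validates column names and records each value's Python type so convert_db_to_parquet() can restore proper dtypes later.

```python
import sqlite3

def csv_quoted(columns) -> str:
    """Double-quote identifiers and join them, like column_names_csv() above."""
    return ",".join('"' + c + '"' for c in columns)

con = sqlite3.connect(":memory:")
seen: list[str] = []

def insert(table: str, d: dict) -> None:
    if not seen:                           # first document defines the initial schema
        seen.extend(d)
        con.execute(f"CREATE TABLE {table} ({csv_quoted(seen)})")
    for k in d:                            # later documents may widen the schema
        if k not in seen:
            seen.append(k)
            con.execute(f'ALTER TABLE {table} ADD COLUMN "{k}"')
    cols = list(d)
    con.execute(
        f"INSERT INTO {table} ({csv_quoted(cols)}) VALUES ({','.join('?' for _ in cols)})",
        [str(d[c]) for c in cols],         # every value is stored as text
    )

insert("alto_info", {"alto_file": "a.xml", "String-count": 4})
insert("alto_info", {"alto_file": "b.xml", "TextLine-count": 7})  # adds a new column
```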
diff --git a/src/mods4pandas/mods4pandas.py b/src/mods4pandas/mods4pandas.py
index ef24d36..669c1e0 100755
--- a/src/mods4pandas/mods4pandas.py
+++ b/src/mods4pandas/mods4pandas.py
@@ -1,21 +1,29 @@
 #!/usr/bin/env python3
+import contextlib
 import csv
 import logging
 import os
 import re
+import sqlite3
 import warnings
+import sys
 from lxml import etree as ET
 from itertools import groupby
 from operator import attrgetter
 from typing import Dict, List
+from collections import defaultdict
 from collections.abc import MutableMapping, Sequence
 
 import click
-import pandas as pd
 from tqdm import tqdm
 
-from .lib import sorted_groupby, TagGroup, ns, flatten, dicts_to_df
+from .lib import convert_db_to_parquet, sorted_groupby, TagGroup, ns, flatten, insert_into_db, insert_into_db_multiple, current_columns_types
+
+with warnings.catch_warnings():
+    # Filter warnings on WSL
+    if "Microsoft" in os.uname().release:
+        warnings.simplefilter("ignore")
+    import pandas as pd
 
 
 logger = logging.getLogger('mods4pandas')
@@ -273,7 +281,7 @@ def pages_to_dict(mets, raise_errors=True) -> List[Dict]:
         # This is expected in a multivolume work or periodical!
         if any(
             structMap_LOGICAL.find(f'./mets:div[@TYPE="{t}"]', ns) is not None
-            for t in ["multivolume_work", "MultivolumeWork", "periodical"]
+            for t in ["multivolume_work", "MultivolumeWork", "multivolume_manuscript", "periodical"]
         ):
             return []
         else:
@@ -319,6 +327,8 @@ def pages_to_dict(mets, raise_errors=True) -> List[Dict]:
             assert file_ is not None
             fileGrp_USE = file_.getparent().attrib.get("USE")
             file_FLocat_href = (file_.xpath('mets:FLocat/@xlink:href', namespaces=ns) or [None])[0]
+            if file_FLocat_href is not None:
+                file_FLocat_href = str(file_FLocat_href)
             page_dict[f"fileGrp_{fileGrp_USE}_file_FLocat_href"] = file_FLocat_href
 
         def get_struct_log(*, to_phys):
@@ -358,9 +368,9 @@ def pages_to_dict(mets, raise_errors=True) -> List[Dict]:
 
         # Populate structure type indicator variables
         for struct_div in struct_divs:
-            type_ = struct_div.attrib.get("TYPE")
+            type_ = struct_div.attrib.get("TYPE").lower()
             assert type_
-            page_dict[f"structMap-LOGICAL_TYPE_{type_}"] = 1
+            page_dict[f"structMap-LOGICAL_TYPE_{type_}"] = True
 
         result.append(page_dict)
 
@@ -372,7 +382,7 @@ def pages_to_dict(mets, raise_errors=True) -> List[Dict]:
 @click.option('--output', '-o', 'output_file', type=click.Path(), help='Output Parquet file',
               default='mods_info_df.parquet', show_default=True)
 @click.option('--output-page-info', type=click.Path(), help='Output page info Parquet file')
-def process(mets_files: List[str], output_file: str, output_page_info: str):
+def process_command(mets_files: list[str], output_file: str, output_page_info: str):
     """
     A tool to convert the MODS metadata in INPUT to a pandas DataFrame.
 
@@ -383,9 +393,11 @@ def process(mets_files: List[str], output_file: str, output_page_info: str):
     Per-page information (e.g. structure information) can be output to a separate Parquet file.
     """
+    process(mets_files, output_file, output_page_info)
 
+def process(mets_files: list[str], output_file: str, output_page_info: str):
     # Extend file list if directories are given
-    mets_files_real = []
+    mets_files_real: list[str] = []
     for m in mets_files:
         if os.path.isdir(m):
             logger.info('Scanning directory {}'.format(m))
@@ -394,13 +406,29 @@
         else:
             mets_files_real.append(m)
 
+
+    # Prepare output files
+    with contextlib.suppress(FileNotFoundError):
+        os.remove(output_file)
+    output_file_sqlite3 = output_file + ".sqlite3"
+    with contextlib.suppress(FileNotFoundError):
+        os.remove(output_file_sqlite3)
+
+    logger.info('Writing SQLite DB to {}'.format(output_file_sqlite3))
+    con = sqlite3.connect(output_file_sqlite3)
+
+    if output_page_info:
+        output_page_info_sqlite3 = output_page_info + ".sqlite3"
+        logger.info('Writing SQLite DB to {}'.format(output_page_info_sqlite3))
+        with contextlib.suppress(FileNotFoundError):
+            os.remove(output_page_info_sqlite3)
+        con_page_info = sqlite3.connect(output_page_info_sqlite3)
+
     # Process METS files
     with open(output_file + '.warnings.csv', 'w') as csvfile:
         csvwriter = csv.writer(csvfile)
-        mods_info = []
-        page_info = []
         logger.info('Processing METS files')
-        for mets_file in tqdm(mets_files_real, leave=False):
+        for mets_file in tqdm(mets_files_real, leave=True):
             try:
                 root = ET.parse(mets_file).getroot()
                 mets = root  # XXX .find('mets:mets', ns) does not work here
@@ -419,13 +447,15 @@
                 # "meta"
                 d['mets_file'] = mets_file
 
+                # Save
+                insert_into_db(con, "mods_info", d)
+                con.commit()
+
                 # METS - per-page
                 if output_page_info:
                     page_info_doc: list[dict] = pages_to_dict(mets, raise_errors=True)
-
-                mods_info.append(d)
-                if output_page_info:
-                    page_info.extend(page_info_doc)
+                    insert_into_db_multiple(con_page_info, "page_info", page_info_doc)
+                    con_page_info.commit()
 
                 if caught_warnings:
                     # PyCharm thinks caught_warnings is not Iterable:
@@ -433,22 +463,13 @@
                     for caught_warning in caught_warnings:
                         csvwriter.writerow([mets_file, caught_warning.message])
             except Exception as e:
-                logger.error('Exception in {}: {}'.format(mets_file, e))
-                #import traceback; traceback.print_exc()
+                logger.exception('Exception in {}'.format(mets_file))
 
-    # Convert the mods_info List[Dict] to a pandas DataFrame
-    mods_info_df = dicts_to_df(mods_info, index_column="recordInfo_recordIdentifier")
-
-    # Save the DataFrame
     logger.info('Writing DataFrame to {}'.format(output_file))
-    mods_info_df.to_parquet(output_file)
-
-    # Convert page_info
+    convert_db_to_parquet(con, "mods_info", "recordInfo_recordIdentifier", output_file)
     if output_page_info:
-        page_info_df = dicts_to_df(page_info, index_column=("ppn", "ID"))
-        # Save the DataFrame
-        logger.info('Writing DataFrame to {}'.format(output_page_info))
-        page_info_df.to_parquet(output_page_info)
+        logger.info('Writing DataFrame to {}'.format(output_page_info))
+        convert_db_to_parquet(con_page_info, "page_info", ["ppn", "ID"], output_page_info)
 
 
 def main():
@@ -457,7 +478,7 @@ def main():
     for prefix, uri in ns.items():
         ET.register_namespace(prefix, uri)
 
-    process()
+    process_command()
 
 
 if __name__ == '__main__':
diff --git a/src/mods4pandas/tests/data/alto/PPN1844793923/00000017.xml b/src/mods4pandas/tests/data/alto/PPN1844793923/00000017.xml
new file mode 100644
index 0000000..7f658fa
--- /dev/null
+++ b/src/mods4pandas/tests/data/alto/PPN1844793923/00000017.xml
@@ -0,0 +1,663 @@
+[663 added lines: an ALTO XML test fixture for PPN1844793923, page 00000017. The XML markup is not recoverable here; the surviving values are the MeasurementUnit "pixel", the fileName "16_b079a_default.jpg" and the fileIdentifier https://content.staatsbibliothek-berlin.de/dc/1844793923-0017/full/full/0/default.jpg.]
diff --git a/src/mods4pandas/tests/test_alto.py b/src/mods4pandas/tests/test_alto.py
index 827bc7a..adf931f 100644
--- a/src/mods4pandas/tests/test_alto.py
+++ b/src/mods4pandas/tests/test_alto.py
@@ -1,9 +1,13 @@
+from pathlib import Path
+import re
 from lxml import etree as ET
+import pandas as pd
 
-from mods4pandas.alto4pandas import alto_to_dict
+from mods4pandas.alto4pandas import alto_to_dict, process
 from mods4pandas.lib import flatten
 
+TESTS_DATA_DIR = Path(__file__).parent / "data"
 
 def dict_fromstring(x):
     return flatten(alto_to_dict(ET.fromstring(x)))
@@ -79,3 +83,50 @@ def test_String_TAGREF_counts():
     """)
     assert d['Layout_Page_//alto:String[@TAGREFS]-count'] == 3
     assert d['Layout_Page_String-count'] == 4
+
+
+def test_dtypes(tmp_path):
+    alto_dir = (TESTS_DATA_DIR / "alto").absolute().as_posix()
+    alto_info_df_parquet = (tmp_path / "test_dtypes_alto_info.parquet").as_posix()
+    process([alto_dir], alto_info_df_parquet)
+    alto_info_df = pd.read_parquet(alto_info_df_parquet)
+
+    EXPECTED_TYPES = {
+        r"Description_.*": ("object", ["str", "NoneType"]),
+        r"Layout_Page_ID": ("object", ["str", "NoneType"]),
+        r"Layout_Page_PHYSICAL_(IMG|IMAGE)_NR": ("object", ["str", "NoneType"]),
+        r"Layout_Page_PROCESSING": ("object", ["str", "NoneType"]),
+        r"Layout_Page_QUALITY": ("object", ["str", "NoneType"]),
+        r"Layout_Page_//alto:String/@WC-.*": ("Float64", None),
+        r".*-count": ("Int64", None),
+        r"alto_xmlns": ("object", ["str", "NoneType"]),
+
+        r"Layout_Page_(WIDTH|HEIGHT)": ("Int64", None),
+    }
+    def expected_types(c):
+        """Return the expected types for column c."""
+        for r, types in EXPECTED_TYPES.items():
+            if re.fullmatch(r, c):
+                edt = types[0]
+                einner_types = types[1]
+                if einner_types:
+                    einner_types = set(einner_types)
+                return edt, einner_types
+        return None, None
+
+    def check_types(df):
+        """Check the types of the DataFrame df."""
+        for c in df.columns:
+            dt = df.dtypes[c]
+            edt, einner_types = expected_types(c)
+            print(c, dt, edt)
+
+            assert edt is not None, f"No expected dtype known for column {c} (got {dt})"
+            assert dt == edt, f"Unexpected dtype {dt} for column {c} (expected {edt})"
+
+            if edt == "object":
+                inner_types = set(type(v).__name__ for v in df[c])
+                assert all(it in einner_types for it in inner_types), \
+                    f"Unexpected inner types {inner_types} for column {c} (expected {einner_types})"
+
+    check_types(alto_info_df)
\ No newline at end of file
diff --git a/src/mods4pandas/tests/test_mods4pandas.py b/src/mods4pandas/tests/test_mods4pandas.py
index f9a98d7..0707a74 100644
--- a/src/mods4pandas/tests/test_mods4pandas.py
+++ b/src/mods4pandas/tests/test_mods4pandas.py
@@ -1,10 +1,14 @@
+from pathlib import Path
+import re
 from lxml import etree as ET
+import pandas as pd
 import pytest
 
-from mods4pandas.mods4pandas import mods_to_dict
+from mods4pandas.mods4pandas import mods_to_dict, process
 from mods4pandas.lib import flatten
 
+TESTS_DATA_DIR = Path(__file__).parent / "data"
 
 def dict_fromstring(x):
     """Helper function to parse a MODS XML string to a flattened dict"""
@@ -151,3 +155,68 @@ def test_relatedItem():
     """)
 
     assert d['relatedItem-original_recordInfo_recordIdentifier-dnb-ppn'] == '1236513355'
+
+def test_dtypes(tmp_path):
+    mets_files = [p.absolute().as_posix() for p in (TESTS_DATA_DIR / "mets-mods").glob("*.xml")]
+    mods_info_df_parquet = (tmp_path / "test_dtypes_mods_info.parquet").as_posix()
+    page_info_df_parquet = (tmp_path / "test_dtypes_page_info.parquet").as_posix()
+    process(mets_files, mods_info_df_parquet, page_info_df_parquet)
+    mods_info_df = pd.read_parquet(mods_info_df_parquet)
+    page_info_df = pd.read_parquet(page_info_df_parquet)
+
+    EXPECTED_TYPES = {
+        # mods_info
+
+        r"mets_file": ("object", ["str"]),
+        r"titleInfo_title": ("object", ["str"]),
+        r"titleInfo_subTitle": ("object", ["str", "NoneType"]),
+        r"titleInfo_partName": ("object", ["str", "NoneType"]),
+        r"identifier-.*": ("object", ["str", "NoneType"]),
+        r"location_.*": ("object", ["str", "NoneType"]),
+        r"name\d+_.*roleTerm": ("object", ["ndarray", "NoneType"]),
+        r"name\d+_.*": ("object", ["str", "NoneType"]),
+        r"relatedItem-.*_recordInfo_recordIdentifier": ("object", ["str", "NoneType"]),
+        r"typeOfResource": ("object", ["str", "NoneType"]),
+        r"accessCondition-.*": ("object", ["str", "NoneType"]),
+        r"originInfo-.*": ("object", ["str", "NoneType"]),
+
+        r".*-count": ("Int64", None),
+
+        r"genre-.*": ("object", ["ndarray", "NoneType"]),
+        r"subject-.*": ("object", ["ndarray", "NoneType"]),
+        r"language_.*Term": ("object", ["ndarray", "NoneType"]),
+        r"classification-.*": ("object", ["ndarray", "NoneType"]),
+
+        # page_info
+
+        r"fileGrp_.*_file_FLocat_href": ("object", ["str", "NoneType"]),
+        r"structMap-LOGICAL_TYPE_.*": ("boolean", None),
+    }
+    def expected_types(c):
+        """Return the expected types for column c."""
+        for r, types in EXPECTED_TYPES.items():
+            if re.fullmatch(r, c):
+                edt = types[0]
+                einner_types = types[1]
+                if einner_types:
+                    einner_types = set(einner_types)
+                return edt, einner_types
+        return None, None
+
+    def check_types(df):
+        """Check the types of the DataFrame df."""
+        for c in df.columns:
+            dt = df.dtypes[c]
+            edt, einner_types = expected_types(c)
+            print(c, dt, edt)
+
+            assert edt is not None, f"No expected dtype known for column {c} (got {dt})"
+            assert dt == edt, f"Unexpected dtype {dt} for column {c} (expected {edt})"
+
+            if edt == "object":
+                inner_types = set(type(v).__name__ for v in df[c])
+                assert all(it in einner_types for it in inner_types), \
+                    f"Unexpected inner types {inner_types} for column {c} (expected {einner_types})"
+
+    check_types(mods_info_df)
+    check_types(page_info_df)
\ No newline at end of file
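Closing usage note (added, not part of the diff): with the Click wrappers split off into process_command, both converters can now be driven directly from Python, which is exactly what the new test_dtypes tests do. The paths below are illustrative; directory arguments are expanded to the files they contain, as on the command line:

```python
from mods4pandas.alto4pandas import process as alto_process
from mods4pandas.mods4pandas import process as mods_process

alto_process(["/data/alto"], "alto_info_df.parquet")
mods_process(["/data/mets-mods"], "mods_info_df.parquet", "page_info_df.parquet")
```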