From ebe988cfff18b7849551aa7e1acc6ac26fd2d21f Mon Sep 17 00:00:00 2001 From: Mike Gerber Date: Wed, 4 Jun 2025 21:10:10 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20Restore=20types=20before=20savin?= =?UTF-8?q?g=20as=20Parquet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- check_dtypes.py | 10 ---------- src/mods4pandas/alto4pandas.py | 5 ++--- src/mods4pandas/lib.py | 21 +++++++++++++++++++++ src/mods4pandas/mods4pandas.py | 11 +++-------- 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/check_dtypes.py b/check_dtypes.py index 3024d9b..cbdfd70 100644 --- a/check_dtypes.py +++ b/check_dtypes.py @@ -2,17 +2,7 @@ import pandas as pd import re -# Fix mods_info = pd.read_parquet("mods_info_df.parquet") -for c in mods_info.columns: - if c.endswith("-count"): - mods_info[c] = mods_info[c].astype('Int64') - - -# Tmp to parquet -mods_info.to_parquet("tmp.parquet") -mods_info = pd.read_parquet("tmp.parquet") - # Check EXPECTED_TYPES = { diff --git a/src/mods4pandas/alto4pandas.py b/src/mods4pandas/alto4pandas.py index 77e23e2..0739f35 100755 --- a/src/mods4pandas/alto4pandas.py +++ b/src/mods4pandas/alto4pandas.py @@ -19,7 +19,7 @@ import pandas as pd import numpy as np from tqdm import tqdm -from .lib import TagGroup, sorted_groupby, flatten, ns, insert_into_db +from .lib import TagGroup, convert_db_to_parquet, sorted_groupby, flatten, ns, insert_into_db logger = logging.getLogger('alto4pandas') @@ -188,9 +188,8 @@ def process(alto_files: List[str], output_file: str): import traceback; traceback.print_exc() # Convert the alto_info SQL to a pandas DataFrame - alto_info_df = pd.read_sql_query("SELECT * FROM alto_info", con, index_col="alto_file") logger.info('Writing DataFrame to {}'.format(output_file)) - alto_info_df.to_parquet(output_file) + convert_db_to_parquet(con, "alto_info", "alto_file", output_file) def main(): diff --git a/src/mods4pandas/lib.py b/src/mods4pandas/lib.py index 082ed9a..32f717a 100644 --- a/src/mods4pandas/lib.py +++ b/src/mods4pandas/lib.py @@ -355,3 +355,24 @@ def insert_into_db(con, table, d: Dict): def insert_into_db_multiple(con, table, ld: List[Dict]): for d in ld: insert_into_db(con, table, d) + +def convert_db_to_parquet(con, table, index_col, output_file): + df = pd.read_sql_query(f"SELECT * FROM {table}", con, index_col) + + # Convert Python column type into Pandas type + for c in df.columns: + column_type = current_columns_types[table][c] + + if column_type == "str": + continue + elif column_type == "int": + df[c] = df[c].astype("Int64") + elif column_type == "float64": + df[c] = df[c].astype("Float64") + elif column_type == "set": + # TODO WIP + continue + else: + raise NotImplementedError(f"Column type {column_type} not implemented yet.") + + df.to_parquet(output_file) \ No newline at end of file diff --git a/src/mods4pandas/mods4pandas.py b/src/mods4pandas/mods4pandas.py index 30d7c22..2da7c80 100755 --- a/src/mods4pandas/mods4pandas.py +++ b/src/mods4pandas/mods4pandas.py @@ -18,7 +18,7 @@ import click import pandas as pd from tqdm import tqdm -from .lib import sorted_groupby, TagGroup, ns, flatten, insert_into_db, insert_into_db_multiple +from .lib import convert_db_to_parquet, sorted_groupby, TagGroup, ns, flatten, insert_into_db, insert_into_db_multiple, current_columns_types @@ -457,16 +457,11 @@ def process(mets_files: List[str], output_file: str, output_page_info: str): except Exception as e: logger.exception('Exception in {}'.format(mets_file)) - # Convert the mods_info SQL to a pandas DataFrame - mods_info_df = pd.read_sql_query("SELECT * FROM mods_info", con, index_col="recordInfo_recordIdentifier") logger.info('Writing DataFrame to {}'.format(output_file)) - mods_info_df.to_parquet(output_file) - + convert_db_to_parquet(con, "mods_info", "recordInfo_recordIdentifier", output_file) if output_page_info: - # Convert page_info SQL to a pandas DataFrama - page_info_df = pd.read_sql_query("SELECT * FROM page_info", con_page_info, index_col=["ppn", "ID"]) logger.info('Writing DataFrame to {}'.format(output_page_info)) - page_info_df.to_parquet(output_page_info) + convert_db_to_parquet(con_page_info, "page_info", ["ppn", "ID"], output_page_info) def main():