mirror of
				https://github.com/qurator-spk/modstool.git
				synced 2025-11-04 03:14:14 +01:00 
			
		
		
		
	🤓 Add type annotations (and related changes)
This commit is contained in:
		
							parent
							
								
									44550ff926
								
							
						
					
					
						commit
						580442a4c9
					
				
					 2 changed files with 32 additions and 27 deletions
				
			
		| 
						 | 
				
			
			@ -1,3 +1,5 @@
 | 
			
		|||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from itertools import groupby
 | 
			
		||||
import re
 | 
			
		||||
import warnings
 | 
			
		||||
| 
						 | 
				
			
			@ -24,40 +26,40 @@ ns = {
 | 
			
		|||
class TagGroup:
 | 
			
		||||
    """Helper class to simplify the parsing and checking of MODS metadata"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, tag, group: List[ET.Element]):
 | 
			
		||||
    def __init__(self, tag, group: List[ET._Element]):
 | 
			
		||||
        self.tag = tag
 | 
			
		||||
        self.group = group
 | 
			
		||||
 | 
			
		||||
    def to_xml(self):
 | 
			
		||||
    def to_xml(self) -> str:
 | 
			
		||||
        return '\n'.join(str(ET.tostring(e), 'utf-8').strip() for e in self.group)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return f"TagGroup with content:\n{self.to_xml()}"
 | 
			
		||||
 | 
			
		||||
    def is_singleton(self):
 | 
			
		||||
    def is_singleton(self) -> TagGroup:
 | 
			
		||||
        if len(self.group) != 1:
 | 
			
		||||
            raise ValueError('More than one instance: {}'.format(self))
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def has_no_attributes(self):
 | 
			
		||||
    def has_no_attributes(self) -> TagGroup:
 | 
			
		||||
        return self.has_attributes({})
 | 
			
		||||
 | 
			
		||||
    def has_attributes(self, attrib):
 | 
			
		||||
    def has_attributes(self, attrib) -> TagGroup:
 | 
			
		||||
        if not isinstance(attrib, Sequence):
 | 
			
		||||
            attrib = [attrib]
 | 
			
		||||
        if not all(e.attrib in attrib for e in self.group):
 | 
			
		||||
            raise ValueError('One or more element has unexpected attributes: {}'.format(self))
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def ignore_attributes(self):
 | 
			
		||||
    def ignore_attributes(self) -> TagGroup:
 | 
			
		||||
        # This serves as documentation for now.
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def sort(self, key=None, reverse=False):
 | 
			
		||||
    def sort(self, key=None, reverse=False) -> TagGroup:
 | 
			
		||||
        self.group = sorted(self.group, key=key, reverse=reverse)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def text(self, separator='\n'):
 | 
			
		||||
    def text(self, separator='\n') -> str:
 | 
			
		||||
        t = ''
 | 
			
		||||
        for e in self.group:
 | 
			
		||||
            if t != '':
 | 
			
		||||
| 
						 | 
				
			
			@ -66,13 +68,13 @@ class TagGroup:
 | 
			
		|||
                t += e.text
 | 
			
		||||
        return t
 | 
			
		||||
 | 
			
		||||
    def text_set(self):
 | 
			
		||||
    def text_set(self) -> set:
 | 
			
		||||
        return {e.text for e in self.group}
 | 
			
		||||
 | 
			
		||||
    def descend(self, raise_errors):
 | 
			
		||||
    def descend(self, raise_errors) -> dict:
 | 
			
		||||
        return _to_dict(self.is_singleton().group[0], raise_errors)
 | 
			
		||||
 | 
			
		||||
    def filter(self, cond, warn=None):
 | 
			
		||||
    def filter(self, cond, warn=None) -> TagGroup:
 | 
			
		||||
        new_group = []
 | 
			
		||||
        for e in self.group:
 | 
			
		||||
            if cond(e):
 | 
			
		||||
| 
						 | 
				
			
			@ -82,7 +84,7 @@ class TagGroup:
 | 
			
		|||
                    warnings.warn('Filtered {} element ({})'.format(self.tag, warn))
 | 
			
		||||
        return TagGroup(self.tag, new_group)
 | 
			
		||||
 | 
			
		||||
    def force_singleton(self, warn=True):
 | 
			
		||||
    def force_singleton(self, warn=True) -> TagGroup:
 | 
			
		||||
        if len(self.group) == 1:
 | 
			
		||||
            return self
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -93,7 +95,7 @@ class TagGroup:
 | 
			
		|||
    RE_ISO8601_DATE = r'^\d{2}(\d{2}|XX)(-\d{2}-\d{2})?$'  # Note: Includes non-specific century dates like '18XX'
 | 
			
		||||
    RE_GERMAN_DATE = r'^(?P<dd>\d{2})\.(?P<mm>\d{2})\.(?P<yyyy>\d{4})$'
 | 
			
		||||
 | 
			
		||||
    def fix_date(self):
 | 
			
		||||
    def fix_date(self) -> TagGroup:
 | 
			
		||||
 | 
			
		||||
        for e in self.group:
 | 
			
		||||
            if e.attrib.get('encoding') == 'w3cdtf':
 | 
			
		||||
| 
						 | 
				
			
			@ -103,6 +105,9 @@ class TagGroup:
 | 
			
		|||
 | 
			
		||||
        new_group = []
 | 
			
		||||
        for e in self.group:
 | 
			
		||||
            if e.text is None:
 | 
			
		||||
                warnings.warn('Empty date')
 | 
			
		||||
                continue
 | 
			
		||||
            if e.attrib.get('encoding') == 'iso8601' and re.match(self.RE_ISO8601_DATE, e.text):
 | 
			
		||||
                new_group.append(e)
 | 
			
		||||
            elif re.match(self.RE_ISO8601_DATE, e.text):
 | 
			
		||||
| 
						 | 
				
			
			@ -131,7 +136,7 @@ class TagGroup:
 | 
			
		|||
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def fix_event_type(self):
 | 
			
		||||
    def fix_event_type(self) -> TagGroup:
 | 
			
		||||
        # According to MODS-AP 2.3.1, every originInfo should have its eventType set.
 | 
			
		||||
        # Fix this for special cases.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +166,7 @@ class TagGroup:
 | 
			
		|||
                    pass
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def fix_script_term(self):
 | 
			
		||||
    def fix_script_term(self) -> TagGroup:
 | 
			
		||||
        for e in self.group:
 | 
			
		||||
            # MODS-AP 2.3.1 is not clear about this, but it looks like that this should be lower case.
 | 
			
		||||
            if e.attrib['authority'] == 'ISO15924':
 | 
			
		||||
| 
						 | 
				
			
			@ -169,7 +174,7 @@ class TagGroup:
 | 
			
		|||
                warnings.warn('Changed scriptTerm authority to lower case')
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def merge_sub_tags_to_set(self):
 | 
			
		||||
    def merge_sub_tags_to_set(self) -> dict:
 | 
			
		||||
        from .mods4pandas import mods_to_dict
 | 
			
		||||
        value = {}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -189,7 +194,7 @@ class TagGroup:
 | 
			
		|||
            value[sub_tag] = s
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def attributes(self):
 | 
			
		||||
    def attributes(self) -> dict[str, str]:
 | 
			
		||||
        """
 | 
			
		||||
        Return a merged dict of all attributes of the tag group.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -204,7 +209,7 @@ class TagGroup:
 | 
			
		|||
                attrib[a_localname] = v
 | 
			
		||||
        return attrib
 | 
			
		||||
 | 
			
		||||
    def subelement_counts(self):
 | 
			
		||||
    def subelement_counts(self) -> dict[str, int]:
 | 
			
		||||
        counts = {}
 | 
			
		||||
        for e in self.group:
 | 
			
		||||
            for x in e.iter():
 | 
			
		||||
| 
						 | 
				
			
			@ -213,7 +218,7 @@ class TagGroup:
 | 
			
		|||
                counts[key] = counts.get(key, 0) + 1
 | 
			
		||||
        return counts
 | 
			
		||||
 | 
			
		||||
    def xpath_statistics(self, xpath_expr, namespaces):
 | 
			
		||||
    def xpath_statistics(self, xpath_expr, namespaces) -> dict[str, float]:
 | 
			
		||||
        """
 | 
			
		||||
        Extract values and calculate statistics
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -235,7 +240,7 @@ class TagGroup:
 | 
			
		|||
            statistics[f'{xpath_expr}-max'] = np.max(values)
 | 
			
		||||
        return statistics
 | 
			
		||||
 | 
			
		||||
    def xpath_count(self, xpath_expr, namespaces):
 | 
			
		||||
    def xpath_count(self, xpath_expr, namespaces) -> dict[str, int]:
 | 
			
		||||
        """
 | 
			
		||||
        Count all elements matching xpath_expr
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -279,7 +284,7 @@ def _to_dict(root, raise_errors):
 | 
			
		|||
        raise ValueError(f"Unknown namespace {root_name.namespace}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def flatten(d: MutableMapping, parent='', separator='_'):
 | 
			
		||||
def flatten(d: MutableMapping, parent='', separator='_') -> dict:
 | 
			
		||||
    """
 | 
			
		||||
    Flatten the given nested dict.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -301,13 +306,13 @@ def flatten(d: MutableMapping, parent='', separator='_'):
 | 
			
		|||
    return dict(items)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def valid_column_key(k):
 | 
			
		||||
    if re.match("^[a-zA-Z0-9 _@/:\[\]-]+$", k):
 | 
			
		||||
def valid_column_key(k) -> bool:
 | 
			
		||||
    if re.match(r'^[a-zA-Z0-9 _@/:\[\]-]+$', k):
 | 
			
		||||
        return True
 | 
			
		||||
    else:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
def column_names_csv(columns):
 | 
			
		||||
def column_names_csv(columns) -> str:
 | 
			
		||||
    """
 | 
			
		||||
    Format Column names (identifiers) as a comma-separated list.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -376,7 +376,7 @@ def pages_to_dict(mets, raise_errors=True) -> List[Dict]:
 | 
			
		|||
@click.option('--output', '-o', 'output_file', type=click.Path(), help='Output Parquet file',
 | 
			
		||||
              default='mods_info_df.parquet', show_default=True)
 | 
			
		||||
@click.option('--output-page-info', type=click.Path(), help='Output page info Parquet file')
 | 
			
		||||
def process(mets_files: List[str], output_file: str, output_page_info: str):
 | 
			
		||||
def process(mets_files: list[str], output_file: str, output_page_info: str):
 | 
			
		||||
    """
 | 
			
		||||
    A tool to convert the MODS metadata in INPUT to a pandas DataFrame.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -389,7 +389,7 @@ def process(mets_files: List[str], output_file: str, output_page_info: str):
 | 
			
		|||
    """
 | 
			
		||||
 | 
			
		||||
    # Extend file list if directories are given
 | 
			
		||||
    mets_files_real = []
 | 
			
		||||
    mets_files_real: list[str] = []
 | 
			
		||||
    for m in mets_files:
 | 
			
		||||
        if os.path.isdir(m):
 | 
			
		||||
            logger.info('Scanning directory {}'.format(m))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue