diff --git a/charon/pkgs/maven.py b/charon/pkgs/maven.py index 02692e8..070a310 100644 --- a/charon/pkgs/maven.py +++ b/charon/pkgs/maven.py @@ -32,7 +32,7 @@ META_FILE_FAILED, MAVEN_METADATA_TEMPLATE, ARCHETYPE_CATALOG_TEMPLATE, ARCHETYPE_CATALOG_FILENAME, PACKAGE_TYPE_MAVEN) -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union from jinja2 import Template from datetime import datetime from zipfile import ZipFile, BadZipFile @@ -217,7 +217,8 @@ def parse_gavs(pom_paths: List[str], root="/") -> Dict[str, Dict[str, List[str]] return gavs -def gen_meta_file(group_id, artifact_id: str, versions: list, root="/", digest=True) -> List[str]: +def gen_meta_file(group_id, artifact_id: str, + versions: list, root="/", do_digest=True) -> List[str]: content = MavenMetadata( group_id, artifact_id, versions ).generate_meta_file_content() @@ -229,7 +230,7 @@ def gen_meta_file(group_id, artifact_id: str, versions: list, root="/", digest=T meta_files.append(final_meta_path) except FileNotFoundError as e: raise e - if digest: + if do_digest: meta_files.extend(__gen_all_digest_files(final_meta_path)) return meta_files @@ -782,7 +783,7 @@ def _merge_directories_with_rename(src_dir: str, dest_dir: str, root: str): _handle_archetype_catalog_merge(src_file, dest_file) merged_count += 1 logger.debug("Merged archetype catalog: %s -> %s", src_file, dest_file) - if os.path.exists(dest_file): + elif os.path.exists(dest_file): duplicated_count += 1 logger.debug("Duplicated: %s, skipped", dest_file) else: @@ -1303,8 +1304,8 @@ def __wildcard_metadata_paths(paths: List[str]) -> List[str]: new_paths.append(path[:-len(".xml")] + ".*") elif path.endswith(".md5")\ or path.endswith(".sha1")\ - or path.endswith(".sha128")\ - or path.endswith(".sha256"): + or path.endswith(".sha256")\ + or path.endswith(".sha512"): continue else: new_paths.append(path) @@ -1313,7 +1314,7 @@ def __wildcard_metadata_paths(paths: List[str]) -> List[str]: class VersionCompareKey: 'Used as key function for version sorting' - def __init__(self, obj): + def __init__(self, obj: str): self.obj = obj def __lt__(self, other): @@ -1344,36 +1345,61 @@ def __compare(self, other) -> int: big = max(len(xitems), len(yitems)) for i in range(big): try: - xitem = xitems[i] + xitem: Union[str, int] = xitems[i] except IndexError: return -1 try: - yitem = yitems[i] + yitem: Union[str, int] = yitems[i] except IndexError: return 1 - if xitem.isnumeric() and yitem.isnumeric(): + if (isinstance(xitem, str) and isinstance(yitem, str) and + xitem.isnumeric() and yitem.isnumeric()): xitem = int(xitem) yitem = int(yitem) - elif xitem.isnumeric() and not yitem.isnumeric(): + elif (isinstance(xitem, str) and xitem.isnumeric() and + (not isinstance(yitem, str) or not yitem.isnumeric())): return 1 - elif not xitem.isnumeric() and yitem.isnumeric(): - return -1 - if xitem > yitem: - return 1 - elif xitem < yitem: + elif (isinstance(yitem, str) and yitem.isnumeric() and + (not isinstance(xitem, str) or not xitem.isnumeric())): return -1 + # At this point, both are the same type (both int or both str) + if isinstance(xitem, int) and isinstance(yitem, int): + if xitem > yitem: + return 1 + elif xitem < yitem: + return -1 + elif isinstance(xitem, str) and isinstance(yitem, str): + if xitem > yitem: + return 1 + elif xitem < yitem: + return -1 else: continue return 0 -class ArchetypeCompareKey(VersionCompareKey): - 'Used as key function for GAV sorting' - def __init__(self, gav): - super().__init__(gav.version) +class ArchetypeCompareKey: + def __init__(self, gav: ArchetypeRef): self.gav = gav - # pylint: disable=unused-private-member + def __lt__(self, other): + return self.__compare(other) < 0 + + def __gt__(self, other): + return self.__compare(other) > 0 + + def __le__(self, other): + return self.__compare(other) <= 0 + + def __ge__(self, other): + return self.__compare(other) >= 0 + + def __eq__(self, other): + return self.__compare(other) == 0 + + def __hash__(self): + return self.gav.__hash__() + def __compare(self, other) -> int: x = self.gav.group_id + ":" + self.gav.artifact_id y = other.gav.group_id + ":" + other.gav.artifact_id diff --git a/charon/utils/files.py b/charon/utils/files.py index ccad3e2..dca7144 100644 --- a/charon/utils/files.py +++ b/charon/utils/files.py @@ -17,7 +17,9 @@ import os import hashlib import errno -from typing import List, Tuple +import tempfile +import shutil +from typing import List, Tuple, Optional from charon.constants import MANIFEST_SUFFIX @@ -32,24 +34,37 @@ class HashType(Enum): def get_hash_type(type_str: str) -> HashType: """Get hash type from string""" - if type_str.lower() == "md5": + type_str_low = type_str.lower() + if type_str_low == "md5": return HashType.MD5 - elif type_str.lower() == "sha1": + elif type_str_low == "sha1": return HashType.SHA1 - elif type_str.lower() == "sha256": + elif type_str_low == "sha256": return HashType.SHA256 - elif type_str.lower() == "sha512": + elif type_str_low == "sha512": return HashType.SHA512 else: raise ValueError("Unsupported hash type: {}".format(type_str)) -def overwrite_file(file_path: str, content: str): - if not os.path.isfile(file_path): - with open(file_path, mode="a", encoding="utf-8"): - pass - with open(file_path, mode="w", encoding="utf-8") as f: - f.write(content) +def overwrite_file(file_path: str, content: str) -> None: + parent_dir: Optional[str] = os.path.dirname(file_path) + if parent_dir: + if not os.path.exists(parent_dir): + os.makedirs(parent_dir, exist_ok=True) + else: + parent_dir = None # None explicitly means current directory for tempfile + + # Write to temporary file first, then atomically rename + fd, temp_path = tempfile.mkstemp(dir=parent_dir, text=True) + try: + with os.fdopen(fd, 'w', encoding='utf-8') as f: + f.write(content) + shutil.move(temp_path, file_path) + except Exception: + if os.path.exists(temp_path): + os.unlink(temp_path) + raise def read_sha1(file: str) -> str: @@ -97,7 +112,6 @@ def digest_content(content: str, hash_type=HashType.SHA1) -> str: def _hash_object(hash_type: HashType): - hash_obj = None if hash_type == HashType.SHA1: hash_obj = hashlib.sha1() elif hash_type == HashType.SHA256: @@ -107,7 +121,7 @@ def _hash_object(hash_type: HashType): elif hash_type == HashType.SHA512: hash_obj = hashlib.sha512() else: - raise Exception("Error: Unknown hash type for digesting.") + raise ValueError("Error: Unknown hash type for digesting.") return hash_obj @@ -116,14 +130,8 @@ def write_manifest(paths: List[str], root: str, product_key: str) -> Tuple[str, manifest_path = os.path.join(root, manifest_name) artifacts = [] for path in paths: - if path.startswith(root): - path = path[len(root):] - if path.startswith("/"): - path = path[1:] - artifacts.append(path) - - if not os.path.isfile(manifest_path): - with open(manifest_path, mode="a", encoding="utf-8"): - pass + rel_path = os.path.relpath(path, root) + artifacts.append(rel_path) + overwrite_file(manifest_path, '\n'.join(artifacts)) return manifest_name, manifest_path