55Note: This file must be formatted using the Black Python formatter.
66"""
77
8- import os . path
8+ import pathlib
99import subprocess
1010import sys
1111from typing import Required , TypedDict , List , Callable , Optional
@@ -41,7 +41,7 @@ def missing_module(module_name: str) -> None:
4141 .decode ("utf-8" )
4242 .strip ()
4343)
44- build_dir = os . path . join (gitroot , "mad-generation-build" )
44+ build_dir = pathlib . Path (gitroot , "mad-generation-build" )
4545
4646
4747# A project to generate models for
@@ -86,10 +86,10 @@ def clone_project(project: Project) -> str:
8686 git_tag = project .get ("git-tag" )
8787
8888 # Determine target directory
89- target_dir = os . path . join ( build_dir , name )
89+ target_dir = build_dir / name
9090
9191 # Clone only if directory doesn't already exist
92- if not os . path . exists (target_dir ):
92+ if not target_dir . exists ():
9393 if git_tag :
9494 print (f"Cloning { name } from { repo_url } at tag { git_tag } " )
9595 else :
@@ -191,10 +191,10 @@ def build_database(
191191 name = project ["name" ]
192192
193193 # Create database directory path
194- database_dir = os . path . join ( build_dir , f"{ name } -db" )
194+ database_dir = build_dir / f"{ name } -db"
195195
196196 # Only build the database if it doesn't already exist
197- if not os . path . exists (database_dir ):
197+ if not database_dir . exists ():
198198 print (f"Building CodeQL database for { name } ..." )
199199 extractor_options = [option for x in extractor_options for option in ("-O" , x )]
200200 try :
@@ -241,7 +241,11 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
241241 generator .with_summaries = should_generate_summaries (project )
242242 generator .threads = args .codeql_threads
243243 generator .ram = args .codeql_ram
244- generator .setenvironment (database = database_dir , folder = name )
244+ if config .get ("single-file" , False ):
245+ generator .single_file = name
246+ else :
247+ generator .folder = name
248+ generator .setenvironment (database = database_dir )
245249 generator .run ()
246250
247251
@@ -312,20 +316,14 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
312316 if response .status_code != 200 :
313317 print (f"Failed to download file. Status code: { response .status_code } " )
314318 sys .exit (1 )
315- target_zip = os . path . join ( build_dir , zipName )
319+ target_zip = build_dir / zipName
316320 with open (target_zip , "wb" ) as file :
317321 for chunk in response .iter_content (chunk_size = 8192 ):
318322 file .write (chunk )
319323 print (f"Download complete: { target_zip } " )
320324 return target_zip
321325
322326
323- def remove_extension (filename : str ) -> str :
324- while "." in filename :
325- filename , _ = os .path .splitext (filename )
326- return filename
327-
328-
329327def pretty_name_from_artifact_name (artifact_name : str ) -> str :
330328 return artifact_name .split ("___" )[1 ]
331329
@@ -399,19 +397,17 @@ def download_and_decompress(analyzed_database: dict) -> str:
399397 # The database is in a zip file, which contains a tar.gz file with the DB
400398 # First we open the zip file
401399 with zipfile .ZipFile (artifact_zip_location , "r" ) as zip_ref :
402- artifact_unzipped_location = os . path . join ( build_dir , artifact_name )
400+ artifact_unzipped_location = build_dir / artifact_name
403401 # clean up any remnants of previous runs
404402 shutil .rmtree (artifact_unzipped_location , ignore_errors = True )
405403 # And then we extract it to build_dir/artifact_name
406404 zip_ref .extractall (artifact_unzipped_location )
407405 # And then we extract the language tar.gz file inside it
408- artifact_tar_location = os .path .join (
409- artifact_unzipped_location , f"{ language } .tar.gz"
410- )
406+ artifact_tar_location = artifact_unzipped_location / f"{ language } .tar.gz"
411407 with tarfile .open (artifact_tar_location , "r:gz" ) as tar_ref :
412408 # And we just untar it to the same directory as the zip file
413409 tar_ref .extractall (artifact_unzipped_location )
414- ret = os . path . join ( artifact_unzipped_location , language )
410+ ret = artifact_unzipped_location / language
415411 print (f"Decompression complete: { ret } " )
416412 return ret
417413
@@ -431,8 +427,16 @@ def download_and_decompress(analyzed_database: dict) -> str:
431427 return [(project_map [n ], r ) for n , r in zip (analyzed_databases , results )]
432428
433429
434- def get_mad_destination_for_project (config , name : str ) -> str :
435- return os .path .join (config ["destination" ], name )
430+ def clean_up_mad_destination_for_project (config , name : str ):
431+ target = pathlib .Path (config ["destination" ], name )
432+ if config .get ("single-file" , False ):
433+ target = target .with_suffix (".model.yml" )
434+ if target .exists ():
435+ print (f"Deleting existing MaD file at { target } " )
436+ target .unlink ()
437+ elif target .exists ():
438+ print (f"Deleting existing MaD directory at { target } " )
439+ shutil .rmtree (target , ignore_errors = True )
436440
437441
438442def get_strategy (config ) -> str :
@@ -454,8 +458,7 @@ def main(config, args) -> None:
454458 language = config ["language" ]
455459
456460 # Create build directory if it doesn't exist
457- if not os .path .exists (build_dir ):
458- os .makedirs (build_dir )
461+ build_dir .mkdir (parents = True , exist_ok = True )
459462
460463 database_results = []
461464 match get_strategy (config ):
@@ -475,7 +478,7 @@ def main(config, args) -> None:
475478 if args .pat is None :
476479 print ("ERROR: --pat argument is required for DCA strategy" )
477480 sys .exit (1 )
478- if not os . path .exists (args . pat ):
481+ if not args . pat .exists ():
479482 print (f"ERROR: Personal Access Token file '{ pat } ' does not exist." )
480483 sys .exit (1 )
481484 with open (args .pat , "r" ) as f :
@@ -499,12 +502,9 @@ def main(config, args) -> None:
499502 )
500503 sys .exit (1 )
501504
502- # Delete the MaD directory for each project
503- for project , database_dir in database_results :
504- mad_dir = get_mad_destination_for_project (config , project ["name" ])
505- if os .path .exists (mad_dir ):
506- print (f"Deleting existing MaD directory at { mad_dir } " )
507- subprocess .check_call (["rm" , "-rf" , mad_dir ])
505+ # clean up existing MaD data for the projects
506+ for project , _ in database_results :
507+ clean_up_mad_destination_for_project (config , project ["name" ])
508508
509509 for project , database_dir in database_results :
510510 if database_dir is not None :
@@ -514,7 +514,10 @@ def main(config, args) -> None:
514514if __name__ == "__main__" :
515515 parser = argparse .ArgumentParser ()
516516 parser .add_argument (
517- "--config" , type = str , help = "Path to the configuration file." , required = True
517+ "--config" ,
518+ type = pathlib .Path ,
519+ help = "Path to the configuration file." ,
520+ required = True ,
518521 )
519522 parser .add_argument (
520523 "--dca" ,
@@ -525,7 +528,7 @@ def main(config, args) -> None:
525528 )
526529 parser .add_argument (
527530 "--pat" ,
528- type = str ,
531+ type = pathlib . Path ,
529532 help = "Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)" ,
530533 )
531534 parser .add_argument (
@@ -544,7 +547,7 @@ def main(config, args) -> None:
544547
545548 # Load config file
546549 config = {}
547- if not os . path .exists (args . config ):
550+ if not args . config .exists ():
548551 print (f"ERROR: Config file '{ args .config } ' does not exist." )
549552 sys .exit (1 )
550553 try :
0 commit comments