# Source code for reco.gen_job

# Public API of this module. NOTE(review): ``create_parallel_job`` is not
# exported here — confirm it is meant to stay module-internal.
__all__  = [
    "Parallel",
    "merge_args",
    "update_args",
]


import random
import joblib

from math               import ceil
from typing             import List, Callable, Union
from pprint             import pprint

from reco import chunks, merge, merge_args_from_file, update_args_from_file, append_index_to_file, check_file_exists




def merge_args(parser):
    """Register the job-generation command-line options on *parser*.

    Adds event-selection, output, seeding, parallelism and run-control
    flags, then delegates to :func:`merge_args_from_file` so file-based
    options are appended as well.

    Returns the value of ``merge_args_from_file(parser)``.
    """
    parser.add_argument(
        '-e', '--event-numbers',
        action='store', dest='event_numbers',
        required=False, type=str, default=None,
        help="The event number list separated by ','. e.g. --event-numbers '0,1,2,3'")
    parser.add_argument(
        '-o', '--output-file',
        action='store', dest='output_file',
        required=True,
        help="The output file.")
    parser.add_argument(
        '-s', '--seed',
        action='store', dest='seed',
        required=False, type=int, default=0,
        help="The pythia seed (zero is the clock system)")
    parser.add_argument(
        '--nov', '--number-of-events',
        action='store', dest='number_of_events',
        required=False, type=int, default=-1,
        help="The total number of events to run.")
    parser.add_argument(
        '--events-per-job',
        action='store', dest='events_per_job',
        required=False, type=int, default=-1,
        help="The number of events per job")
    parser.add_argument(
        '-nt', '--number-of-threads',
        action='store', dest='number_of_threads',
        required=False, type=int, default=1,
        help="The number of threads")
    # Boolean run-control switches.
    parser.add_argument(
        '-m', '--merge',
        action='store_true', dest='merge',
        required=False,
        help='Merge all files.')
    parser.add_argument(
        '--overwrite',
        action='store_true', dest='overwrite',
        required=False,
        help='Rerun all jobs.')
    parser.add_argument(
        '--dry-run',
        action='store_true', dest='dry_run',
        required=False,
        help='Perform a dry run without executing jobs.')
    return merge_args_from_file(parser)
def update_args(args):
    """Apply file-based option updates to the parsed *args* namespace.

    Thin wrapper that forwards to :func:`update_args_from_file` and
    returns its result unchanged.
    """
    updated = update_args_from_file(args)
    return updated
class Parallel:
    """Split an event-generation run into chunks and execute them in parallel.

    The total workload (either an explicit event-number list or a count of
    events) is partitioned into per-job chunks.  Each chunk is generated into
    its own indexed output file; optionally all partial files are merged into
    the final output file afterwards.
    """

    def __init__(self,
                 output_file: str,
                 seed: int = 512,
                 number_of_threads: int = 1,
                 number_of_events: int = -1,
                 events_per_job: int = -1,
                 event_numbers: Union[str, List[int], None] = None,
                 merge: bool = False,
                 ntuple_name: str = "particles",
                 overwrite: bool = False,
                 dry_run: bool = False,
                 ):
        # Explicit event numbers (comma-separated string or list) take
        # precedence over number_of_events in build_plan().
        self.event_numbers = event_numbers
        self.number_of_events = number_of_events
        self.number_of_threads = number_of_threads
        self.events_per_job = events_per_job
        self.merge_files = merge
        self.ntuple_name = ntuple_name
        self.output_file = output_file
        # BUG FIX: was hard-coded ``self.overwrite = True`` (the argument was
        # commented out), which ignored the caller's choice and forced every
        # job to rerun even when its output file already existed.
        self.overwrite = overwrite
        self.seed = seed
        self.dry_run = dry_run

    def __call__(self, function: Callable, **args):
        """Run *function* once per planned job, then optionally merge outputs.

        *function* must accept ``events``, ``output_file`` and ``seed``
        keyword arguments; any extra ``**args`` are forwarded to it.
        """
        plan = self.build_plan()
        pprint(plan)
        # Only schedule jobs whose output is missing/invalid, unless the
        # caller asked to overwrite everything.
        jobs = []
        for output_file, config in plan.items():
            if not check_file_exists(output_file, self.ntuple_name) or self.overwrite:
                jobs.append((output_file, config['evt'], config['seed']))
        pprint(args)
        if not self.dry_run:
            pool = joblib.Parallel(n_jobs=self.number_of_threads,
                                   backend='multiprocessing')
            pool(joblib.delayed(function)(events=events,
                                          output_file=output_file,
                                          seed=seed,
                                          **args)
                 for output_file, events, seed in jobs)
            files = list(plan.keys())
            # NOTE(review): a single partial file is always merged (renamed
            # into the final output) even without --merge — confirm intended.
            if self.merge_files or len(files) == 1:
                merge(self.output_file, files)

    def get_events_per_job(self):
        """Return the chunk size: explicit ``events_per_job`` if set (>= 0),
        otherwise the total event count divided evenly across threads."""
        if self.events_per_job < 0:
            return ceil(self.number_of_events / self.number_of_threads)
        else:
            return self.events_per_job

    def build_plan(self):
        """Build the job plan: ``{output_file: {"evt": [...], "seed": int}}``.

        Event numbers come from ``self.event_numbers`` when given (string or
        list), otherwise from ``range(self.number_of_events)``.  Seeds are
        drawn deterministically from ``random`` seeded with ``self.seed``.
        """
        def get_random_seed() -> int:
            return random.randint(1, 900000000)

        if self.event_numbers:
            event_numbers = ([int(event_number) for event_number in self.event_numbers.split(",")]
                             if type(self.event_numbers) is str else self.event_numbers)
            # BUG FIX: was ``len(event_numbers_list)`` — an undefined name
            # that raised NameError whenever explicit event numbers were used.
            self.number_of_events = len(event_numbers)
            event_numbers = chunks(event_numbers, self.get_events_per_job())
        else:
            event_numbers = chunks(list(range(self.number_of_events)),
                                   self.get_events_per_job())
        plan = {}
        random.seed(self.seed)
        for idx, events in enumerate(event_numbers):
            output_file = append_index_to_file(self.output_file, idx)
            plan[output_file] = {"evt": events, "seed": get_random_seed()}
        return plan
def create_parallel_job(args):
    """Build a :class:`Parallel` runner from a parsed argument namespace."""
    options = dict(
        event_numbers=args.event_numbers,
        output_file=args.output_file,
        number_of_threads=args.number_of_threads,
        number_of_events=args.number_of_events,
        events_per_job=args.events_per_job,
        merge=args.merge,
        overwrite=args.overwrite,
        seed=args.seed,
        dry_run=args.dry_run,
    )
    return Parallel(**options)