diff --git a/xdyna/da.py b/xdyna/da.py index 82ae566..a47166f 100644 --- a/xdyna/da.py +++ b/xdyna/da.py @@ -413,12 +413,12 @@ def add_random_initial(self, *, num_part=5000, min_turns=None): # How to change this? # Allowed on parallel process - def track_job(self, *, npart, tracker, turns=None, resubmit_unfinished=False): + def track_job(self, *, npart, tracker, turns=None, resubmit_unfinished=False, logging=False): if tracker is None: raise ValueError() # Create a job: get job ID and start logging - part_ids = self._create_job(npart, turns, resubmit_unfinished) + part_ids = self._create_job(npart, turns, resubmit_unfinished, logging) job_id = str(self._active_job) # Create initial particles @@ -436,10 +436,12 @@ def track_job(self, *, npart, tracker, turns=None, resubmit_unfinished=False): scale_with_transverse_norm_emitt=self.emittance ) # Track - self._append_job_log('output', datetime.datetime.now().isoformat() + ' Start tracking job ' + str(job_id) + '.') + if logging: + self._append_job_log('output', datetime.datetime.now().isoformat() + ' Start tracking job ' + str(job_id) + '.') tracker.track(particles=part, num_turns=self.this_turns) context.synchronize() - self._append_job_log('output', datetime.datetime.now().isoformat() + ' Done tracking job ' + str(job_id) + '.') + if logging: + self._append_job_log('output', datetime.datetime.now().isoformat() + ' Done tracking job ' + str(job_id) + '.') # Store results part_id = context.nparray_from_context_array(part.particle_id) @@ -471,10 +473,11 @@ def track_job(self, *, npart, tracker, turns=None, resubmit_unfinished=False): full_surv.to_parquet(pf, index=True) self._surv = full_surv - self._update_job_log({ - 'finished_time': datetime.datetime.now().isoformat(), - 'status': 'Finished' - }) + if logging: + self._update_job_log({ + 'finished_time': datetime.datetime.now().isoformat(), + 'status': 'Finished' + }) @@ -483,7 +486,7 @@ def track_job(self, *, npart, tracker, turns=None, resubmit_unfinished=False): # ================================================================= # Allowed on parallel process - def _create_job(self, npart, turns, resubmit_unfinished=False): + def _create_job(self, npart, turns, resubmit_unfinished=False, logging=False): if turns is not None: if self.meta.max_turns is None: self.meta.max_turns = turns @@ -498,7 +501,7 @@ def _create_job(self, npart, turns, resubmit_unfinished=False): if self.this_turns == _DAMetaData._max_turns_default: raise ValueError("Number of tracking turns not set! Cannot track.") # Get job ID - self._active_job = self.meta.new_submission_id() + self._active_job = self.meta.new_submission_id() if logging else 0 with ProtectFile(self.meta.surv_file, 'r+b', wait=0.02) as pf: # Get the first npart particle IDs that are not yet submitted # TODO: this can probably be optimised by only reading last column @@ -523,7 +526,8 @@ def _create_job(self, npart, turns, resubmit_unfinished=False): if len(this_part_ids) == 0: print("No more particles to submit! Exiting...") # TODO: this doesn't work! - self.meta.update_submissions(self._active_job, {'status': 'No submission needed.'}) + if logging: + self.meta.update_submissions(self._active_job, {'status': 'No submission needed.'}) exit() # Otherwise, flag the particles as submitted, before releasing the file again self._surv.loc[this_part_ids, 'submitted'] = True @@ -533,16 +537,17 @@ def _create_job(self, npart, turns, resubmit_unfinished=False): # Reduce dataframe to only those particles in this job self._surv = self._surv.loc[this_part_ids] # Submission info - self._active_job_log = { - 'submission_time': datetime.datetime.now().isoformat(), - 'finished_time': 0, - 'status': 'Running', - 'tracking_turns': self.this_turns, - 'particle_ids': '[' + ', '.join([str(pid) for pid in this_part_ids]) + ']', - 'seed': int(seed), - 'warnings': [], - 'output': [], - } + if logging: + self._active_job_log = { + 'submission_time': datetime.datetime.now().isoformat(), + 'finished_time': 0, + 'status': 'Running', + 'tracking_turns': self.this_turns, + 'particle_ids': '[' + ', '.join([str(pid) for pid in this_part_ids]) + ']', + 'seed': int(seed), + 'warnings': [], + 'output': [], + } return this_part_ids