From 2eeaee6be2403c4d0b87b5004af788ec6b5c841f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADt=20Ku=C4=8Dera?= Date: Thu, 22 Feb 2024 02:36:55 +0100 Subject: [PATCH] Improve compare script --- exec/compare.py | 177 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 175 insertions(+), 2 deletions(-) diff --git a/exec/compare.py b/exec/compare.py index 228e106e..651ae693 100755 --- a/exec/compare.py +++ b/exec/compare.py @@ -8,9 +8,15 @@ """ import argparse +from sys import exit from ROOT import ( # pylint: disable=import-error TH1, + TH2, + TH3, + THnSparse, + # RooUnfoldResponse, + TAxis, TCanvas, TColor, TFile, @@ -21,6 +27,159 @@ # import itertools +def msg_err(message : str): + """ Print error message """ + print(f"Error: {message}") + + +def msg_fatal(message : str): + """ Print error message and exit """ + print(f"Fatal: {message}") + exit(1) + + +def are_valid(*objects) -> bool: + """Check whether objects exist""" + result = True + for i, o in enumerate(objects): + if not o: + msg_err(f"Bad object {i}") + result = False + return result + + +def are_same_axes(axis1, axis2) -> bool: + """ Tell whether two axes are same. """ + if not are_valid(axis1, axis2): + msg_fatal("Bad input objects") + return False + # Check classes + for i, o in enumerate((axis1, axis2)): + if not isinstance(o, TAxis): + msg_fatal(f"Object {i} is not an axis") + return False + # Check number of bins + n_bins1, n_bins2 = axis1.GetNbins(), axis2.GetNbins() + if n_bins1 != n_bins2: + return False + # Check bin arrays + array1 = [axis1.GetBinLowEdge(i + 1) for i in range(n_bins1 + 1)] + array2 = [axis2.GetBinLowEdge(i + 1) for i in range(n_bins2 + 1)] + if array1 != array2: + return False + return True + + +def get_object_type(obj) -> int: + """Return histogram degree""" + # for num, tp in zip((5, 4, 3, 2, 1), (RooUnfoldResponse, THnSparse, TH3, TH2, TH1)): + for num, tp in zip((4, 3, 2, 1), (THnSparse, TH3, TH2, TH1)): + if isinstance(obj, tp): + return num + return 0 + + +def are_same_histograms(his1 : TH1, his2 : TH1) -> bool: + """ Tell whether two histograms are same. """ + if not are_valid(his1, his2): + msg_fatal("Bad input objects") + return False + # Compare number of entries + if his1.GetEntries() != his2.GetEntries(): + print(f"Different number of entries {his1.GetEntries()} vs {his2.GetEntries()}") + return False; + # Compare axes + for ax1, ax2 in zip((his1.GetXaxis(), his1.GetYaxis(), his1.GetZaxis()), (his2.GetXaxis(), his2.GetYaxis(), his2.GetZaxis())): + if not are_same_axes(ax1, ax2): + print("Different axes") + return False + # Compare bin counts and errors (include under/overflow bins) + for bin_z in range(his1.GetNbinsZ() + 2): + for bin_y in range(his1.GetNbinsY() + 2): + for bin_x in range(his1.GetNbinsX() + 2): + bin = his1.GetBin(bin_x, bin_y, bin_z) + if his1.GetBinContent(bin) != his2.GetBinContent(bin) or his1.GetBinError(bin) != his2.GetBinError(bin): + print(f"Different bin {bin} content: {his1.GetBinContent(bin)} ± {his1.GetBinError(bin)} vs {his2.GetBinContent(bin)} ± {his2.GetBinError(bin)}") + return False + return True + + +def are_same_thnspare(his1 : THnSparse, his2 : THnSparse) -> bool: + """ Tell whether two THnSparse objects are same. """ + if not are_valid(his1, his2): + msg_fatal("Bad input objects") + return False + # Compare number of dimensions + if his1.GetNdimensions() != his2.GetNdimensions(): + return False + # Compare number of entries + if his1.GetEntries() != his2.GetEntries(): + return False + # Compare number of filled bins + if his1.GetNbins() != his2.GetNbins(): + return False + # Compare axes + for iAx in range(his1.GetNdimensions()): + if not are_same_axes(his1.GetAxis(iAx), his2.GetAxis(iAx)): + return False + # Compare bin content + for iBin in range(his1.GetNbins()): + if his1.GetBinContent(iBin) != his2.GetBinContent(iBin) or his1.GetBinError(iBin) != his2.GetBinError(iBin): + return False + return True + + +# def are_same_response(his1 : RooUnfoldResponse, his2 : RooUnfoldResponse) -> bool: +# """ Tell whether two RooUnfoldResponse objects are same. """ +# if not are_valid(his1, his2): +# msg_fatal("Bad input objects") +# return False +# # Compare number of dimensions +# if his1.GetDimensionMeasured() != his2.GetDimensionMeasured() or his1.GetDimensionTruth() != his2.GetDimensionTruth(): +# return False +# # Compare number of bins +# if his1.GetNbinsMeasured() != his2.GetNbinsMeasured() or his1.GetNbinsTruth() != his2.GetNbinsTruth(): +# return False +# # Compare axes and bin content +# if not are_same_histograms(his1.Hfakes(), his2.Hfakes()): +# return False +# if not are_same_histograms(his1.Hmeasured(), his2.Hmeasured()): +# return False +# if not are_same_histograms(his1.Htruth(), his2.Htruth()): +# return False +# if not are_same_histograms(his1.Hresponse(), his2.Hresponse()): +# return False +# return True + + +def are_same_objects(obj1, obj2) -> bool: + """ Tell whether two histogram-like objects are same. """ + if not are_valid(obj1, obj2): + msg_fatal("Bad input objects") + return False + # Compare types + if type(obj1) != type(obj2): + print(f"Different types {type(obj1)} {type(obj2)}") + return False + # Get ROOT types + list_type = [-1, -2] + for i, o in enumerate((obj1, obj2)): + list_type[i] = get_object_type(o) + # Compare ROOT types (is it not covered by type(obj)?) + if list_type[0] != list_type[1]: + print(f"Different types {list_type[0]} {list_type[1]}") + return False + type_obj = list_type[0] + # Compare supported ROOT objects + if type_obj == 0: + msg_fatal(f"Objects have an unsupported type {type(obj1)}.") + return False + # elif type_obj == 5: + # return are_same_response(obj1, obj2) + elif type_obj == 4: + return are_same_thnspare(obj1, obj2) + return are_same_histograms(obj1, obj2) + def compare(dict_obj, add_leg_title=True, normalize=True): print("Comparing") @@ -29,6 +188,20 @@ def compare(dict_obj, add_leg_title=True, normalize=True): dict_colors = {} dict_markers = {} dict_list_canvas = {} + + # Explicit comparison + list_files = list(dict_obj.keys()) + name_file_0 = list_files[0] + name_file_1 = list_files[1] + for key_obj in dict_obj[name_file_0]: + obj_0 = dict_obj[name_file_0][key_obj] + obj_1 = dict_obj[name_file_1][key_obj] + name_his = obj_0.GetName() + if are_same_objects(obj_0, obj_1): + print(f"Objects {name_his} are same {obj_0.GetEntries()}") + else: + print(f"Objects {name_his} are different") + for key_file in dict_obj: print("Entry", len(dict_colors), key_file) dict_colors[key_file] = TColor.GetColor(list_colors[len(dict_colors)]) @@ -50,7 +223,7 @@ def compare(dict_obj, add_leg_title=True, normalize=True): else: opt += "same" dict_list_canvas[key_obj][0].cd() - print(f'Drawing {obj.GetName()} with opt "{opt}" on canvas {gPad.GetName()}') + # print(f'Drawing {obj.GetName()} with opt "{opt}" on canvas {gPad.GetName()}') obj.SetLineColor(dict_colors[key_file]) obj.SetMarkerStyle(dict_markers[key_file]) obj.SetMarkerColor(dict_colors[key_file]) @@ -64,7 +237,7 @@ def compare(dict_obj, add_leg_title=True, normalize=True): # Ratio if not is_first_file: dict_list_canvas[key_obj][1].cd() - print(f'Drawing {obj.GetName()} with opt "{opt}" on canvas {gPad.GetName()}') + # print(f'Drawing {obj.GetName()} with opt "{opt}" on canvas {gPad.GetName()}') # line_1 = TLine(obj.GetXaxis().GetXmin(), 1, obj.GetXaxis().GetXmax(), 1) obj_ratio = obj.Clone(f"{obj.GetName()}_ratio") obj_ratio.Divide(dict_obj[key_file_first][key_obj])