You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

156 lines
4.7 KiB

10 months ago
  1. from __future__ import division
  2. import ROOT
  3. from ROOT import gStyle
  4. # Some convenience function to easily iterate over the parts of the collections
  5. # Needed if importing this script from another script in case TMultiGraphs are used
  6. # ROOT.SetMemoryPolicy(ROOT.kMemoryStrict)
  7. # Start a bit right of the Yaxis and above the Xaxis to not overlap with the ticks
  8. start, stop = 0.28, 0.52
  9. x_width, y_width = 0.4, 0.2
  10. PLACES = [
  11. (start, stop - y_width, start + x_width, stop), # top left opt
  12. (start, start, start + x_width, start + y_width), # bottom left opt
  13. (stop - x_width, stop - y_width, stop, stop), # top right opt
  14. (stop - x_width, start, stop, start + y_width), # bottom right opt
  15. (stop - x_width, 0.5 - y_width / 2, stop, 0.5 + y_width / 2), # right
  16. (start, 0.5 - y_width / 2, start + x_width, 0.5 + y_width / 2),
  17. ] # left
  18. # Needed if importing this script from another script in case TMultiGraphs are used
  19. # ROOT.SetMemoryPolicy(ROOT.kMemoryStrict)
  20. def transform_to_user(canvas, x1, y1, x2, y2):
  21. """
  22. Transforms from Pad coordinates to User coordinates.
  23. This can probably be replaced by using the built-in conversion commands.
  24. """
  25. xstart = canvas.GetX1()
  26. xlength = canvas.GetX2() - xstart
  27. xlow = xlength * x1 + xstart
  28. xhigh = xlength * x2 + xstart
  29. if canvas.GetLogx():
  30. xlow = 10**xlow
  31. xhigh = 10**xhigh
  32. ystart = canvas.GetY1()
  33. ylength = canvas.GetY2() - ystart
  34. ylow = ylength * y1 + ystart
  35. yhigh = ylength * y2 + ystart
  36. if canvas.GetLogy():
  37. ylow = 10**ylow
  38. yhigh = 10**yhigh
  39. return xlow, ylow, xhigh, yhigh
  40. def overlap_h(hist, x1, y1, x2, y2):
  41. xlow = hist.FindFixBin(x1)
  42. xhigh = hist.FindFixBin(x2)
  43. for i in range(xlow, xhigh + 1):
  44. val = hist.GetBinContent(i)
  45. # Values
  46. if y1 <= val <= y2:
  47. return True
  48. # Errors
  49. if val + hist.GetBinErrorUp(i) > y1 and val - hist.GetBinErrorLow(i) < y2:
  50. return True
  51. return False
  52. def overlap_rect(rect1, rect2):
  53. """Do the two rectangles overlap?"""
  54. if rect1[0] > rect2[2] or rect1[2] < rect2[0]:
  55. return False
  56. if rect1[1] > rect2[3] or rect1[3] < rect2[1]:
  57. return False
  58. return True
  59. def to_list(pointer):
  60. """turns pointer to array into list after checking if pointer is nullptr"""
  61. if len(pointer) == 0:
  62. return []
  63. return list(pointer)
  64. def overlap_g(graph, x1, y1, x2, y2):
  65. x_values = to_list(graph.GetX())
  66. y_values = to_list(graph.GetY())
  67. x_err = to_list(graph.GetEX()) or [0] * len(x_values)
  68. y_err = to_list(graph.GetEY()) or [0] * len(y_values)
  69. for x, ex, y, ey in zip(x_values, x_err, y_values, y_err):
  70. # Could maybe be less conservative
  71. if overlap_rect((x1, y1, x2, y2), (x - ex, y - ey, x + ex, y + ey)):
  72. # print "Overlap with graph", graph.GetName(), "at point", (x, y)
  73. return True
  74. return False
  75. def place_legend(
  76. canvas,
  77. x1=None,
  78. y1=None,
  79. x2=None,
  80. y2=None,
  81. header="LHCb Simulation",
  82. option="lpe",
  83. ):
  84. gStyle.SetFillStyle(0)
  85. gStyle.SetTextSize(0.06)
  86. # If position is specified, use that
  87. if all(x is not None for x in (x1, x2, y1, y2)):
  88. return canvas.BuildLegend(x1, y1, x2, y2, header, option)
  89. # Make sure all objects are correctly registered
  90. canvas.Update()
  91. # Build a list of objects to check for overlaps
  92. objects = []
  93. for x in canvas.GetListOfPrimitives():
  94. if isinstance(x, ROOT.TH1) or isinstance(x, ROOT.TGraph):
  95. objects.append(x)
  96. elif isinstance(x, ROOT.THStack) or isinstance(x, ROOT.TMultiGraph):
  97. objects.extend(x)
  98. for place in PLACES:
  99. place_user = canvas.PadtoU(*place)
  100. # Make sure there are no overlaps
  101. if any(obj.Overlap(*place_user) for obj in objects):
  102. continue
  103. return canvas.BuildLegend(
  104. place[0],
  105. place[1],
  106. place[2],
  107. place[3],
  108. header,
  109. option,
  110. )
  111. # As a fallback, use the default values, taken from TCanvas::BuildLegend
  112. return canvas.BuildLegend(0.4, 0.37, 0.88, 0.68, header, option)
  113. # Monkey patch ROOT objects to make it all work
  114. ROOT.THStack.__iter__ = lambda self: iter(self.GetHists())
  115. ROOT.TMultiGraph.__iter__ = lambda self: iter(self.GetListOfGraphs())
  116. ROOT.TH1.Overlap = overlap_h
  117. ROOT.TGraph.Overlap = overlap_g
  118. ROOT.TPad.PadtoU = transform_to_user
  119. ROOT.TPad.PlaceLegend = place_legend
  120. def set_legend(legend, gr, title, colors, label):
  121. legend.SetTextSize(0.05)
  122. legend.SetFillColor(0)
  123. legend.SetShadowColor(0)
  124. legend.SetBorderSize(0)
  125. legend.SetTextFont(132)
  126. for idx, lab in enumerate(label):
  127. legend.AddEntry(gr[lab], title[lab], "lep").SetTextColor(colors[idx])
  128. return legend