StRoot  1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
BDTCriteria.h
1 #ifndef BDT_Criteria_h
2 #define BDT_Criteria_h
3 
4 #include "TMVA/Reader.h"
5 #include "Criteria/ICriterion.h"
6 
7 
10 class BDTCrit2 : public KiTrack::ICriterion{
11 
12 public:
13 
14 
15 
16  BDTCrit2 ( float scoreMin , float scoreMax ){
17  _scoreMax = scoreMax;
18  _scoreMin = scoreMin;
19 
20  _name = "Crit2_BDT";
21  _type = "2Hit";
22 
23  _saveValues = false;
24 
25 
26  }
27 
28 
29  float EvalDeltaPhi( KiTrack::IHit*a, KiTrack::IHit*b ){
30  // TODO: work on branchless version?
31  float ax = a->getX();
32  float ay = a->getY();
33  float bx = b->getX();
34  float by = b->getY();
35 
36  float phia = atan2( ay, ax );
37  float phib = atan2( by, bx );
38  float deltaPhi = phia - phib;
39 
40  if (deltaPhi > M_PI) deltaPhi -= 2*M_PI; //to the range from -pi to pi
41  if (deltaPhi < -M_PI) deltaPhi += 2*M_PI; //to the range from -pi to pi
42 
43  if (( by*by + bx*bx < 0.0001 )||( ay*ay + ax*ax < 0.0001 )) deltaPhi = 0.; // In case one of the hits is too close to the origin
44 
45  deltaPhi = 180.*fabs( deltaPhi ) / M_PI;
46  return deltaPhi;
47  }
48  float EvalDeltaRho( KiTrack::IHit*a, KiTrack::IHit*b ){
49  float ax = a->getX();
50  float ay = a->getY();
51  float bx = b->getX();
52  float by = b->getY();
53 
54  float rhoA = sqrt( ax*ax + ay*ay );
55  float rhoB = sqrt( bx*bx + by*by );
56 
57  float deltaRho = rhoA - rhoB;
58  return deltaRho;
59  }
60 
61  float EvalRZRatio( KiTrack::IHit*a, KiTrack::IHit*b ){
62  float ax = a->getX();
63  float ay = a->getY();
64  float az = a->getZ();
65 
66  float bx = b->getX();
67  float by = b->getY();
68  float bz = b->getZ();
69 
70  // the square is used, because it is faster to calculate with the squares than with sqrt, which takes some time!
71  double ratioSquared = 0.;
72  if ( az-bz != 0. )
73  ratioSquared = ( (ax-bx)*(ax-bx) + (ay-by)*(ay-by) + (az-bz)*(az-bz) ) / ( (az-bz) * ( az-bz ) );
74 
75  return sqrt(ratioSquared);
76  }
77 
78  float EvalStraightTrackRatio( KiTrack::IHit*a, KiTrack::IHit*b ){
79  float ax = a->getX();
80  float ay = a->getY();
81  float az = a->getZ();
82 
83  float bx = b->getX();
84  float by = b->getY();
85  float bz = b->getZ();
86 
87  //the distance to (0,0) in the xy plane
88  double rhoASquared = ax*ax + ay*ay;
89  double rhoBSquared = bx*bx + by*by;
90 
91  double ratioSquared = 0;
92  if( (rhoBSquared >0.) && ( az != 0. ) ){ //prevent division by 0
93  // the square is used, because it is faster to calculate with the squares than with sqrt, which takes some time!
94  ratioSquared = ( ( rhoASquared * ( bz*bz ) ) / ( rhoBSquared * ( az*az ) ) );
95  }
96 
97  return sqrt( ratioSquared );
98  }
99 
100  virtual bool areCompatible( KiTrack::Segment* parent , KiTrack::Segment* child ){
101 
102 
103  if ( reader == nullptr ){
104  BDTCrit2::reader = new TMVA::Reader("!Color:!Silent");
105 
106  // setup the inputs
107  BDTCrit2::reader->AddVariable("Crit2_RZRatio", &BDTCrit2::Crit2_RZRatio);
108  BDTCrit2::reader->AddVariable("Crit2_DeltaRho", &BDTCrit2::Crit2_DeltaRho);
109  BDTCrit2::reader->AddVariable("Crit2_DeltaPhi", &BDTCrit2::Crit2_DeltaPhi);
110  BDTCrit2::reader->AddVariable("Crit2_StraightTrackRatio", &BDTCrit2::Crit2_StraightTrackRatio);
111 
112  BDTCrit2::reader->BookMVA("BDT method", "bdt2-Copy1.xml");
113  }
114 
115  if (( parent->getHits().size() == 1 )&&( child->getHits().size() == 1 )){
116  } //a criterion for 1-segments
117  else {
118  std::stringstream s;
119  s << "Crit2_BDT::This criterion needs 2 segments with 1 hit each, passed was a "
120  << parent->getHits().size() << " hit segment (parent) and a "
121  << child->getHits().size() << " hit segment (child).";
122 
123  throw KiTrack::BadSegmentLength( s.str() );
124  }
125 
126  KiTrack::IHit* a = parent->getHits()[0];
127  KiTrack::IHit* b = child-> getHits()[0];
128 
129 
130  //first check, if the distance to (0,0) rises --> such a combo could not reach the IP
131 
132 
133  // compute input values
134  BDTCrit2::Crit2_DeltaPhi = EvalDeltaPhi( a, b );
135  BDTCrit2::Crit2_DeltaRho = EvalDeltaRho( a, b );
136  BDTCrit2::Crit2_RZRatio = EvalRZRatio( a, b );
137  BDTCrit2::Crit2_StraightTrackRatio = EvalStraightTrackRatio( a, b );
138 
139  float score = BDTCrit2::reader->EvaluateMVA("BDT method");
140 
141  if (_saveValues){
142  _map_name_value["Crit2_BDT"] = score;
143  _map_name_value["Crit2_BDT_DeltaPhi"] = BDTCrit2::Crit2_DeltaPhi;
144  _map_name_value["Crit2_BDT_DeltaRho"] = BDTCrit2::Crit2_DeltaRho;
145  _map_name_value["Crit2_BDT_RZRatio"] = BDTCrit2::Crit2_RZRatio;
146  _map_name_value["Crit2_BDT_StraightTrackRatio"] = BDTCrit2::Crit2_StraightTrackRatio;
147  }
148 
149  if ( score < _scoreMin || score > _scoreMax ) return false;
150  return true;
151  }
152 
153  virtual ~BDTCrit2(){};
154 
155 
156 private:
157 
158  float _scoreMin{};
159  float _scoreMax{};
160  static TMVA::Reader *reader;
161  // values input to BDT
162  static float Crit2_RZRatio, Crit2_DeltaRho, Crit2_DeltaPhi, Crit2_StraightTrackRatio;
163 
164 
165 
166 };
167 
168 
169 
170 
171 class BDTCrit3 : public KiTrack::ICriterion{
172 
173 public:
174 
175  BDTCrit3 ( float scoreMin , float scoreMax ){
176  _scoreMax = scoreMax;
177  _scoreMin = scoreMin;
178 
179  _name = "Crit3_BDT";
180  _type = "3Hit";
181 
182  _saveValues = false;
183 
184 
185  }
186 
187 
188  float Eval3DAngle(KiTrack::IHit*a, KiTrack::IHit*b, KiTrack::IHit*c ){
189  return 0;
190  }
191 
192  float Eval2DAngle(KiTrack::IHit*a, KiTrack::IHit*b, KiTrack::IHit*c ){
193  return 0;
194  }
195 
196  float EvalChangeRZRatio(KiTrack::IHit*a, KiTrack::IHit*b, KiTrack::IHit*c ){
197  return 0;
198  }
199 
200 
201  virtual bool areCompatible( KiTrack::Segment* parent , KiTrack::Segment* child ){
202 
203  if ( reader == nullptr ){
204  BDTCrit3::reader = new TMVA::Reader("!Color:!Silent");
205 
206  // setup the inputs
207  BDTCrit3::reader->AddVariable("Crit3_ChangeRZRatio", &BDTCrit3::Crit3_ChangeRZRatio);
208  BDTCrit3::reader->AddVariable("Crit3_3DAngle", &BDTCrit3::Crit3_3DAngle);
209  BDTCrit3::reader->AddVariable("Crit3_2DAngle", &BDTCrit3::Crit3_2DAngle);
210 
211  BDTCrit3::reader->BookMVA("BDT method", "bdt2-Copy1.xml");
212  }
213 
214  if (( parent->getHits().size() == 2 )&&( child->getHits().size() == 2 )){
215  } //a criterion for 1-segments
216  else {
217  std::stringstream s;
218  s << "Crit3_BDT::This criterion needs 2 segments with 1 hit each, passed was a "
219  << parent->getHits().size() << " hit segment (parent) and a "
220  << child->getHits().size() << " hit segment (child).";
221 
222  throw KiTrack::BadSegmentLength( s.str() );
223  }
224 
225  KiTrack::IHit* a = child->getHits()[0];
226  KiTrack::IHit* b = child->getHits()[1];
227  KiTrack::IHit* c = parent-> getHits()[1];
228 
229  // compute input values
230  BDTCrit3::Crit3_2DAngle = Eval2DAngle( a, b, c );
231  BDTCrit3::Crit3_3DAngle = Eval3DAngle( a, b, c );
232  BDTCrit3::Crit3_ChangeRZRatio = EvalChangeRZRatio( a, b, c );
233 
234  float score = BDTCrit3::reader->EvaluateMVA("BDT3 method");
235 
236  if (_saveValues){
237  _map_name_value["Crit3_BDT"] = score;
238  _map_name_value["Crit3_BDT_2DAngle"] = BDTCrit3::Crit3_2DAngle;
239  _map_name_value["Crit3_BDT_3DAngle"] = BDTCrit3::Crit3_3DAngle;
240  _map_name_value["Crit3_BDT_ChangeRZRatio"] = BDTCrit3::Crit3_ChangeRZRatio;
241 
242  }
243 
244  if ( score < _scoreMin || score > _scoreMax ) return false;
245  return true;
246  }
247 
248  virtual ~BDTCrit3(){};
249 
250 private:
251 
252  float _scoreMin{};
253  float _scoreMax{};
254  static TMVA::Reader *reader;
255  // values input to BDT
256  static float Crit3_ChangeRZRatio, Crit3_3DAngle, Crit3_2DAngle;
257 
258 };
259 
260 #endif