StRoot  1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
StPmdDiscriminatorNN.cxx
1 /****************************************************
2  *
3  * $Id: StPmdDiscriminatorNN.cxx,v 1.4 2007/04/26 04:13:46 perev Exp $
4  *
5  * Author: Subhasis Chattopadhyay
6  *
7  ******************************************************
8  *
9  * Description: Class for photon/hadron discrimination of PMD
10  * clusters using a neural network (StPmdNeuNet).
11  *
12  ******************************************************
13  *
14  * $Log: StPmdDiscriminatorNN.cxx,v $
15  * Revision 1.4 2007/04/26 04:13:46 perev
16  * Remove StBFChain dependency
17  *
18  * Revision 1.3 2003/09/02 17:58:48 perev
19  * gcc 3.2 updates + WarnOff
20  *
21  * Revision 1.2 2003/08/04 18:53:44 perev
22  * warnOff
23  *
24  * Revision 1.1 2003/05/29 13:21:05 subhasis
25  * NN discriminator
26  *
27  *
28  ******************************************************/
29 
30 #include<Stiostream.h>
31 #include"Stiostream.h"
32 #include<assert.h>
33 #include<math.h>
34 #include"TROOT.h"
35 #include<TRandom.h>
36 #include<TBrowser.h>
37 #include<TPad.h>
38 #include<StMessMgr.h>
39 #include<TFile.h>
40 
41 #include "StPmdUtil/StPmdGeom.h"
42 #include "StPmdUtil/StPmdDetector.h"
43 #include "StPmdDiscriminatorMaker.h"
44 #include "StPmdDiscriminatorNN.h"
45 #include "StPmdUtil/StPmdClusterCollection.h"
46 #include "StPmdUtil/StPmdCluster.h"
47 #include "StPmdNeuNet.h"
48 #include "StEventTypes.h"
49 #include "StNNCluster.h"
50 
51 ClassImp(StPmdDiscriminatorNN)
52  ofstream fileo("test_NN",ios::app);
53 
54 Int_t npmdvalue=0;
55 Int_t ncpvvalue=0;
56 Int_t Trained=0;
57 Int_t NTrain=0;
58 
60 {
61  mClContainer=cl_con;
62  mApplyFlagNN=0;
63  cout<<"inside discNN, size**"<<cl_con.size()<<" "<<mClContainer.size()<<endl;
64 }
65 
66 
67 StPmdDiscriminatorNN::StPmdDiscriminatorNN(StPmdDetector* pmd_det, StPmdDetector* cpv_det)
68 {
69  m_PmdDet=pmd_det;
70  m_CpvDet=cpv_det;
71 }
72 
73 StPmdDiscriminatorNN::~StPmdDiscriminatorNN()
74 {
75  //destructor
76 }
77 
78 void StPmdDiscriminatorNN::Discriminate()
79 {
80  StPmdNeuNet *sneu=new StPmdNeuNet("for PMD",4,"4",1);
81  sneu->setDiscMaker(m_DiscMaker);
82  sneu->SetLearnParam(0.2); // the learning parameter (<1)
83  sneu->SetInitParam(-2,2); // bounds for the initialisation of weights
84  sneu ->SetUseBiases();
85  sneu->Init(); // initialisation of the kernel
86  sneu->PrintS(); // printing of network structure
87  Int_t NNSize=mClContainer.size();
88  sneu->SetNTrainEvents(NNSize);
89  sneu->SetArraySize(NNSize);
90  cout<<"nTrainevts "<<sneu->GetNTrainEvents()<<endl;
91  Input(sneu);
92  cout<<"No of INputs **"<<NTrain<<"NNFlag "<<mApplyFlagNN<<endl;
93 
94  if(mApplyFlagNN!=1){
95  if(Trained!=1){
96  sneu->TrainNCycles(100);
97  Trained=1;
98  sneu->Export("NNTrain.out"); // printing of network structure
99  }
100  }
101 
102  cout<<" NNTrain to be imported "<<endl;
103  sneu->Import("NNTrain.out"); // printing of network structure
104  if(mApplyFlagNN==1)
105  {
106  Float_t Teach[20000];
107  Float_t Value[20000];
108  for(Int_t i=0;i<20000;i++){Teach[i]=999; Value[i]=0;}
109 
110  cout<<"NN Valid() Called "<<endl;
111  sneu->ApplyWeights(Teach,Value);
112  for(Int_t i=0;i<20000;i++){
113  if(Teach[i]!=999)fileo<<Teach[i]<<" "<<Value[i]<<endl;
114  }
115  }
116 }
117 
118 void StPmdDiscriminatorNN::Input(StPmdNeuNet* sneu)
119 {
120  TFile * file=new TFile("nninput.root","RECREATE");
121  m_NNedep_ph=new TH1F("nnedp_ph","(ph) PMD edep",100,0.,1000.);
122  m_NNncell_ph=new TH1F("nn_ncell_ph","(ph) PMD ncell",100,0.,20.);
123  m_NNsigma_ph=new TH1F("nn_sigma_ph","(ph) PMD sigma",100,0.,20.);
124  m_NNedep_cpv_ph=new TH1F("nnedp_cpv_ph","(ph) CPV edep",100,0.,100.);
125  m_NNedep_had=new TH1F("nnedp_had","(had) PMD edep",100,0.,1000.);
126  m_NNncell_had=new TH1F("nn_ncell_had","(had) PMD ncell",100,0.,20.);
127  m_NNsigma_had=new TH1F("nn_sigma_had","(had) PMD sigma",100,0.,20.);
128  m_NNedep_cpv_had=new TH1F("nnedp_cpv_had","(had) CPV edep",100,0.,100.);
129 
130 
131  Int_t totno = 0,totcpvno = 0;
132  Float_t aveEnergy = 0., aveNcell = 0., aveSigma = 0.,aveCpvEnergy = 0;
133  Float_t totEnergy = 0., totNcell = 0., totSigma = 0.,totCpvEnergy = 0;
134  // cout<<"CONTAINER SIZE ********* "<<mClContainer.size()<<endl;
135  for(UInt_t i=0;i<mClContainer.size();i++)
136  {
137  StPhmdCluster *cl1=(StPhmdCluster*)(mClContainer[i]->PmdCluster());
138  StPhmdCluster *cl2=(StPhmdCluster*)(mClContainer[i]->CpvCluster());
140 
141  totEnergy = totEnergy + cl1->energy();
142  totNcell = totNcell + cl1->numberOfCells();
143  totSigma = totSigma + cl1->sigma();
144  totno++;
145  if(cl2){
146  totCpvEnergy = totCpvEnergy + cl2->energy();
147  totcpvno++;
148  }
149  }
150  aveEnergy = totEnergy/totno;
151  aveNcell = totNcell/totno;
152  aveSigma = totSigma/totno;
153  aveCpvEnergy = totCpvEnergy/totcpvno;
154 
155  if(mApplyFlagNN==1){
156  // Put avergaes by hand
157 
158  /*Single particle
159  aveEnergy =0.000253643 ;
160  aveNcell =2.54512;
161  aveSigma =0.467905;
162  aveCpvEnergy =1.8775e-05;
163  */
164  // AuAu
165  aveEnergy =0.000157085 ;
166  aveNcell =2.20981;
167  aveSigma =0.476074;
168  aveCpvEnergy =1.26994e-05;
169 
170  }
171 
172  fileo<<" AveEnergy "<<aveEnergy<<" AveNcell "<<aveNcell<<" aveSigma "<<aveSigma<<" aveCpvEne "<<aveCpvEnergy<<" TotNo "<<totno<<" TotCPvNo "<<totcpvno<<endl;
173 
174  Float_t outEnergy,outNcell,outSigma,outCpvEnergy;
175 
176  for(UInt_t i=0;i<mClContainer.size();i++)
177  {
178  StPhmdCluster *cl1=(StPhmdCluster*)(mClContainer[i]->PmdCluster());
179  StPhmdCluster *cl2=(StPhmdCluster*)(mClContainer[i]->CpvCluster());
180  Float_t target;
181  if(cl1->mcPid()==1)target=1.;
182  if(cl1->mcPid()==8)target=0.;
183  sneu->fillArrayOut(target,i,0);
184 
185 //VP Int_t sm=cl1->module();
186  Float_t energy=cl1->energy();
187  InputRange(energy,aveEnergy,outEnergy);
188  sneu->FillArray(i,0,outEnergy);
189 
190  Int_t ncell=cl1->numberOfCells();
191  InputRange(ncell,aveNcell,outNcell);
192 
193  sneu->FillArray(i,1,outNcell);
194 
195  Float_t sigma=cl1->sigma();
196  InputRange(sigma,aveSigma,outSigma);
197 
198  sneu->FillArray(i,2,outSigma);
199 
200  Float_t cpv_energy=0.;
201  if(cl2){
202  cpv_energy=cl2->energy();
203  InputRange(cpv_energy,aveCpvEnergy,outCpvEnergy);
204 
205  sneu->FillArray(i,3,outCpvEnergy);
206  }
207 
208  if(target==1)m_NNncell_ph->Fill(Float_t(ncell));
209  if(target==0)m_NNncell_had->Fill(Float_t(ncell));
210  if(target==1)m_NNsigma_ph->Fill(sigma);
211  if(target==0)m_NNsigma_had->Fill(sigma);
212  if(target==1&&cpv_energy>0)m_NNedep_cpv_ph->Fill(cpv_energy);
213  if(target==0&&cpv_energy>0)m_NNedep_cpv_had->Fill(cpv_energy);
214  if(target==1)m_NNedep_ph->Fill(energy);
215  if(target==0)m_NNedep_had->Fill(energy);
216  NTrain++;
217  }
218  cout<<"In Input **"<<NTrain<<endl;
219  m_NNedep_ph->Write();
220  m_NNncell_ph->Write();
221  m_NNsigma_ph->Write();
222  m_NNedep_cpv_ph->Write();
223 
224  m_NNedep_had->Write();
225  m_NNncell_had->Write();
226  m_NNsigma_had->Write();
227  m_NNedep_cpv_had->Write();
228  file->Close();
229 }
230 
231 void StPmdDiscriminatorNN::setFormula()
232 {
233  cout<<"In the Neural Network"<<npmdvalue<<ncpvvalue<<endl;
234 }
235 
236 //function used for scaling the input variables to (-1,1)
237 
238 Float_t StPmdDiscriminatorNN::InputRange(Float_t Input,Float_t aveInput, Float_t& Output)
239 {
240  Float_t fx;
241  Float_t ax = 1.;
242  if(aveInput !=0.){
243  if((Input/aveInput)<10.){
244  fx = (2./(1. + exp(-ax*Input/aveInput))) - 1;
245  }
246  else
247  {
248  fx = (2./(1. + exp(-10.))) - 1.;
249  }
250  Output = 1 - 2. * fx;
251  }
252  return Output;
253 }
254 
255 
virtual Double_t ApplyWeights(Float_t *, Float_t *)
virtual void SetLearnParam(Double_t learnParam=0.2, Double_t fse=0., Double_t mu=0.)
virtual void Init()
virtual void SetInitParam(Float_t lowerInitWeight=-1., Float_t upperInitWeight=1.)
Sets the initialisation parameters : max and min weights.
virtual void TrainNCycles(Int_t nCycles=10)
virtual void Export(const Text_t *fileName="exportNN.dat")
virtual void Import(const Text_t *fileName="exportNN.dat")
virtual void PrintS()
prints structure of network on screen
StPhmdCluster * PmdCluster()
destructor
Definition: StNNCluster.h:65