StRoot  1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
StPmdNeuNet.h
1 
4 //
6 // Neural Network classes :
7 // TNNFormula
8 // TNNTree
9 // TNNKernel
10 // TNNControlE
11 // TNNUtils
12 // J.P. Ernenwein (rnenwein@in2p3.fr)
14 
15 
17 /*
18 #ifndef ROOT_TNamed
19 #include "TNamed.h"
20 #endif
21 #ifndef ROOT_TROOT
22 #include "TROOT.h"
23 #endif
24 #ifndef ROOT_TTree
25 #include "TTree.h"
26 #endif
27 #ifndef ROOT_TString
28 #include "TString.h"
29 #endif
30 //#ifndef ROOT_TFormula
31 //#include "TFormula.h"
32 //#endif
33 //#ifndef ROOT_TTreeFormula
34 //#include "TTreeFormula.h"
35 //#endif
36 #ifndef ROOT_TCanvas
37 #include "TCanvas.h"
38 #endif
39 */
40 #include "TFrame.h"
41 #include "TStringLong.h"
42 #include "TFile.h"
43 #include "TText.h"
44 #include "TDatime.h"
45 #include "TRandom.h"
46 #include "TPad.h"
47 #include "math.h"
48 #include "stdlib.h"
49 #include "Stiostream.h"
50 #include "StPmdDiscriminatorMaker.h"
51 
53 
54 class StPmdNeuNet : public TNamed
55 {
56 
57 
58  private:
59  Int_t fNHiddL; // number of hidden layers
60  Float_t **fValues;
61  Double_t **fErrors;
62  Double_t **fBiases;
63  Int_t *fNUnits;
64  Double_t ***fW;
65 
66  Int_t fNTrainEvents; // number of events for training
67  Int_t fNValidEvents; // number of events for validation
68 // TNNTree *fValidTree; // validation tree
69  Double_t fLearnParam; // learning parameter
70  Float_t fLowerInitWeight; // minimum weight for initialisation
71  Float_t fUpperInitWeight; // maximum weight for initialisation
72  Float_t **fArrayOut;
73  Float_t *fTeach;
74  Float_t **fArrayIn;
75  Int_t *fEventsList;
76  Int_t fNTrainCycles; // Number of training cycles done
77  Double_t fUseBiases; // flag for use of biases or not (1=use, 0=no use)
78  TRandom fRandom; // Random object used in initialisation and mixing
79  Int_t fNWeights; // number of weights in neural network
80  Double_t fMu; // backpropagation momentum parameter
81  Double_t fFlatSE; // Flat Spot elimination paramater
82  Double_t ***fDW;
83  Double_t **fDB;
84 
85 
86  void GetArrayEvt(Int_t iEvent)
87  {
88  Int_t l;
89  for(l=0;l<fNUnits[0];l++)fValues[0][l]=fArrayIn[iEvent][l];
90  for(l=0;l<fNUnits[fNHiddL+1];l++)fTeach[l]=fArrayOut[iEvent][l];
91  };
92  void LearnBackward(); // gradient retropropagation (updates of biases and weights)
93  void Forward(); // do a simple forward propagation
94  Double_t Error();// compute the error between forward propagation and teaching
95  Double_t ErrorO();// compute the error between forward propagation and teaching
96  void Error(const char*, const char*, ...) const{}//WarnOff
97  void FreeVW();
98  void ZeroAll();
99  void AllocateVW(Int_t nInput, const Text_t *hidden, Int_t nOutput);
100  void SetHidden(const Text_t *ttext);
101  Float_t Alea();
102  void DeleteArray();
103 
104  protected:
105  virtual Double_t Sigmoide(Double_t x)
106  {
107  if(x> 10.) return 0.99999; // probability MUST be < 1
108  if(x<-10.) return 0.;
109  return (1./(1.+exp(-x)));
110  };
111  virtual Double_t SigPrim(Double_t x){return (x*(1.-x));};
112  StPmdDiscriminatorMaker * m_DiscMaker;
113 
114  public:
115  StPmdNeuNet();
116  StPmdNeuNet(const Text_t *name, Int_t nInput=5, const Text_t *hidden="6:7:8", Int_t nOutput=4);
117  void setDiscMaker(StPmdDiscriminatorMaker*);
118  virtual ~StPmdNeuNet(); // destructor
119  virtual void SetKernel(Int_t nInput, const Text_t *hidden, Int_t nOutput);
120  virtual void SetLearnParam(Double_t learnParam=0.2,Double_t fse=0.,Double_t mu=0.);
121  virtual void SetInitParam(Float_t lowerInitWeight=-1., Float_t upperInitWeight=1.);
122  virtual void Init(); // init biases and weights
123  virtual void PrintS(); // print structure of network
124  virtual void Mix(); // mix the events before learning
125  virtual Double_t TrainOneCycle(); // one loop on internal events = one cycle
126  virtual void ResetCycles(){fNTrainCycles=0;};
127  virtual void Export(const Text_t *fileName="exportNN.dat");
128  virtual void Import(const Text_t *fileName="exportNN.dat");
129  virtual void SetUseBiases(Bool_t trueForUse=1){fUseBiases=(Double_t)trueForUse;};
130  virtual void SetRandomSeed(UInt_t seed=0){fRandom.SetSeed(seed);};
131  virtual UInt_t GetRandomSeed(){return fRandom.GetSeed();};
132  virtual Bool_t IsTrained(){return fNTrainCycles;};
133  virtual Int_t GetNTrainCycles(){return fNTrainCycles;};
134  virtual Int_t GetNTrainEvents(){return fNTrainEvents;};
135  virtual void SetNTrainEvents(Int_t nevt){fNTrainEvents = nevt;};
136  virtual Int_t GetNValidEvents(){return fNValidEvents;};
137  virtual void SetArraySize(Int_t s=0);
138  virtual void FillArray(Int_t,Int_t,Float_t);
139  virtual void Fill(Int_t iev=0)
140  {
141  Int_t i;
142  for(i=0;i<fNUnits[0];i++)fArrayIn[iev][i]=fValues[0][i];
143  for(i=0;i<fNUnits[fNHiddL+1];i++)fArrayOut[iev][i]=fTeach[i];
144  }
145  virtual Float_t* GetInputAdr(){return fValues[0];};
146  virtual void SetInput(Float_t v,Int_t i){fValues[0][i]=v;};
147  virtual Int_t GetNInput(){return fNUnits[0];};
148  virtual Int_t GetNOutput(){return fNUnits[fNHiddL+1];};
149  virtual Float_t GetOutput(Int_t unit=0){return fValues[fNHiddL+1][unit];};
150  virtual Float_t* GetOutputAdr(){return fValues[fNHiddL+1];};
151  virtual Float_t* GetTeachAdr(){return fTeach;};
152  virtual void SetTeach(Float_t v,Int_t i){fTeach[i]=v;};
153  virtual void fillArrayOut(Float_t v,Int_t i,Int_t l){fArrayOut[i][l]=v;};
154  virtual Double_t GoThrough(){Forward();return ErrorO();};
155  virtual Float_t GetSumO()
156  {
157  Int_t i; Float_t s=0.;
158  for(i=0;i<fNUnits[fNHiddL+1];i++)s+=fValues[fNHiddL+1][i];
159  return s;
160  };
161 
162  void PrintTrain()
163  {
164  cout<<"Units** "<<fNUnits[fNHiddL+1]<<endl;
165 
166  Int_t l;
167  for(l=0;l<fNUnits[fNHiddL+1];l++){
168  cout<<"teach "<<fTeach[l]<<"Value "<<fValues[fNHiddL+1][l]<<endl;
169  }
170  }
171 
172  // virtual void SetTrainTree(TNNTree *t);
173 // virtual void SetValidTree(TNNTree *t);
174  virtual Double_t Valid();
175 // virtual void TrainNCycles(TNNControlE *conte, Int_t period=5, Int_t nCycles=10);
176  virtual void TrainNCycles(Int_t nCycles=10);
177  virtual Int_t GetNWeights()
178  {
179  if(!fNUnits)return 0;
180  Int_t n=0;
181  for(Int_t i=0;i<fNHiddL+1;i++)
182  {
183  n+=fNUnits[i]*fNUnits[i+1];
184  }
185  return n;
186  };
187 
188  virtual Double_t ApplyWeights(Float_t*,Float_t*); // one loop on internal events = one cycle
189 
190  ClassDef(StPmdNeuNet,1)
191 
192 };
193 
194 inline void StPmdNeuNet::setDiscMaker(StPmdDiscriminatorMaker* disc){m_DiscMaker=disc;}
virtual Double_t ApplyWeights(Float_t *, Float_t *)
virtual Double_t Valid()
Definition: FJcore.h:367
virtual void SetLearnParam(Double_t learnParam=0.2, Double_t fse=0., Double_t mu=0.)
virtual void Init()
StPmdNeuNet()
Constructor with no parameters (its purpose is not documented).
Definition: StPmdNeuNet.cxx:42
virtual void SetInitParam(Float_t lowerInitWeight=-1., Float_t upperInitWeight=1.)
Sets the initialisation parameters: the minimum and maximum initial weights.
virtual void TrainNCycles(Int_t nCycles=10)
virtual void Export(const Text_t *fileName="exportNN.dat")
virtual void Import(const Text_t *fileName="exportNN.dat")
virtual void PrintS()
Prints the structure of the network on screen.
virtual Double_t TrainOneCycle()
virtual void Mix()