HMLP: High-performance Machine Learning Primitives
root.hpp
1 #ifndef ROOT_HPP
2 #define ROOT_HPP
3 
4 #include <assert.h>
5 #include <typeinfo>
6 #include <algorithm>
7 #include <random>
8 #include <limits>
9 #include <cstddef>
10 #include <math.h>
11 
12 namespace hmlp
13 {
14 namespace root
15 {
16 
/**
 *  @brief Abstract driver shared by the one-dimensional root finders in
 *         this file. A concrete solver supplies Initialize() and Iterate();
 *         Solve() runs them until Terminate() reports convergence.
 *
 *  Every estimate is a std::pair<T, T> holding ( x, f( x ) ) in
 *  ( first, second ).
 *
 *  @tparam T floating point type of the abscissa and of the function value.
 */
template<typename T>
class RootFinderBase
{
  public:

    RootFinderBase() {};

    /** Polymorphic base: make destruction through a base pointer well defined. */
    virtual ~RootFinderBase() {};

    /** @brief Produce the first estimate from x_lower, x_upper (and x). */
    virtual std::pair<T, T> Initialize() = 0;

    /** @brief Advance the solver by one step and return the new estimate. */
    virtual std::pair<T, T> Iterate() = 0;

    /**
     *  @brief Search for a root starting from [ user_x_lower, user_x_upper ].
     *
     *  Stores the interval, sets x to its midpoint, calls Initialize() once
     *  and then Iterate() at least once, repeating until Terminate() is true.
     *  A progress line is printed every 50 iterations.
     *
     *  @param user_x_lower lower end of the starting interval.
     *  @param user_x_upper upper end of the starting interval.
     *  @return the last estimate ( x, f( x ) ) produced by Iterate().
     */
    virtual std::pair<T, T> Solve( T user_x_lower, T user_x_upper )
    {
      x_lower = user_x_lower;
      x_upper = user_x_upper;
      x = ( x_lower + x_upper ) / 2.0;

      auto root = Initialize();
      auto previous = root;

      size_t iter = 0;
      do
      {
        /**
         *  Assign to the outer variable. Re-declaring it here would create
         *  a loop-local copy that is out of scope in the while-condition,
         *  so Terminate() would always see the Initialize() estimate.
         */
        previous = root;
        root = Iterate();
        iter ++;
        if ( !( iter % 50 ) )
        {
          /** %zu matches size_t; the cast makes %E valid for any T. */
          printf( "iter %4zu f(x) %E\n", iter, static_cast<double>( root.second ) );
        }
      } while ( !Terminate( iter, root, previous ) );

      return root;
    };

    /** Current abscissa; Solve() sets it to the interval midpoint. */
    T x = 0.0;

    /** Lower end of the search interval. */
    T x_lower = 0.0;

    /** Upper end of the search interval. */
    T x_upper = 0.0;

    /**
     *  @brief Enable the iteration-count and relative-change stopping rules.
     *
     *  @param niter maximum number of iterations.
     *  @param tol   threshold on | ( f_now - f_previous ) / f_previous |.
     */
    void SetupTerminationCriteria( size_t niter, T tol )
    {
      this->use_defined_termination_criteria = true;
      this->niter = niter;
      this->tol = tol;
    };

    /**
     *  @brief Whether iter has reached the configured iteration limit.
     *
     *  Uses the same comparison as Terminate() ( iter >= niter ).
     */
    bool ReachMaximumIteration( size_t iter )
    {
      return ( iter >= niter );
    };

    /**
     *  @brief Decide whether Solve() should stop.
     *
     *  Stops when | f | is below machine epsilon of T, or when the estimate
     *  is identical to the previous one (no further progress is possible).
     *  When SetupTerminationCriteria() has been called, it also stops once
     *  iter >= niter or the relative change of f drops below tol.
     *
     *  @param iter     number of Iterate() calls performed so far.
     *  @param now      latest estimate ( x, f( x ) ).
     *  @param previous estimate before the latest Iterate() call.
     *  @return true if the iteration should stop.
     */
    bool Terminate( size_t iter, std::pair<T, T> &now, std::pair<T, T> &previous )
    {
      if ( std::fabs( now.second ) < std::numeric_limits<T>::epsilon() )
      {
        return true;
      }

      /**
       *  A step that reproduces the previous estimate exactly cannot make
       *  progress; without this check the loop would never end when no
       *  user criteria are configured.
       */
      if ( now == previous )
      {
        return true;
      }

      bool doterminate = false;
      if ( use_defined_termination_criteria )
      {
        if ( iter >= niter ) doterminate = true;
        if ( std::fabs( ( now.second - previous.second ) / previous.second ) < tol )
          doterminate = true;
      }
      return doterminate;
    };

  private:

    /** Set by SetupTerminationCriteria(); gates the use of niter and tol. */
    bool use_defined_termination_criteria = false;

    /** Iteration limit (only enforced once the criteria are enabled). */
    size_t niter = 10;

    /** Relative-change tolerance (only enforced once the criteria are enabled). */
    T tol = 1E-5;
};
94 
95 
99 template<typename FUNC, typename T>
100 class Bisection : public RootFinderBase<T>
101 {
102  public:
103 
104  Bisection( FUNC *func )
105  {
106  this->func = func;
107  };
108 
109  std::pair<T, T> Initialize()
110  {
111  T &x_lower = this->x_lower;
112  T &x_upper = this->x_upper;
113 
115  f_lower = func->F( x_lower );
116  f_upper = func->F( x_upper );
117 
119  if ( ( f_lower < 0.0 && f_upper < 0.0 ) || ( f_lower > 0.0 && f_upper > 0.0 ) )
120  {
121  printf( "endpoints do not straddle y = 0\n" );
122  exit( 1 );
123  }
124 
126  std::pair<T, T> root( ( x_lower + x_upper ) / 2.0, ( f_lower + f_upper ) / 2.0 );
127 
128  return root;
129  };
130 
131 
132  std::pair<T, T> Iterate()
133  {
134  T x_bisect, f_bisect;
135 
136  T &x_lower = this->x_lower;
137  T &x_upper = this->x_upper;
138 
139  if ( f_lower == 0.0 ) return std::pair<T, T>( x_lower, f_lower );
140  if ( f_upper == 0.0 ) return std::pair<T, T>( x_upper, f_upper );
141 
143  x_bisect = ( x_lower + x_upper ) / 2.0;
144  f_bisect = func->F( x_bisect );
145  if ( f_bisect == 0.0 )
146  {
147  x_lower = x_bisect;
148  x_upper = x_bisect;
149  return std::pair<T, T>( x_bisect, f_bisect );
150  }
151 
153  if ( ( f_lower > 0.0 && f_bisect < 0.0 ) || ( f_lower < 0.0 && f_bisect > 0.0 ) )
154  {
155  x_upper = x_bisect;
156  f_upper = f_bisect;
157  return std::pair<T, T>( 0.5 * ( x_lower + x_upper ), f_upper );
158  }
159  else
160  {
161  x_lower = x_bisect;
162  f_lower = f_bisect;
163  return std::pair<T, T>( 0.5 * ( x_lower + x_upper ), f_lower );
164  }
165  };
167  private:
168 
169  FUNC *func = NULL;
170 
172  T f_lower = 0.0;
173 
175  T f_upper = 0.0;
176 
177 };
180 template<typename FUNC, typename T>
181 class Newton : public RootFinderBase<T>
182 {
183  public:
184 
185  Newton( FUNC *func )
186  {
187  this->func = func;
188  };
189 
190  std::pair<T, T> Initialize()
191  {
192  T x_bisect, f_bisect, df_bisec;
193 
194  T &x_lower = this->x_lower;
195  T &x_upper = this->x_upper;
196 
197  if ( !func->HasdF() )
198  {
199  printf( "Newton(): no first order derivity provided\n" );
200  exit( 1 );
201  };
202 
203  x_bisect = ( x_lower + x_upper ) / 2.0;
204  f = func->F( x_bisect );
205  df = func->dF( x_bisect );
206 
208  std::pair<T, T> root( x_bisect, f );
209 
210  return root;
211  };
212 
213  std::pair<T, T> Iterate()
214  {
215  auto &x = this->x;
216 
217  if ( df == 0.0 )
218  {
219  printf( "Newton(): derivative is zero\n" );
220  return std::pair<T, T>( x, f );
221  }
222 
224  x -= ( f / df );
225 
226  if ( func->HasFdF() )
227  {
228  auto fdf = func->FdF( x );
229  f = fdf.first;
230  df = fdf.second;
231  }
232  else
233  {
234  f = func->F( x );
235  df = func->dF( x );
236  }
237 
238  if ( f != f )
239  {
240  printf( "Newton(): f( x ) is infinite\n" );
241  }
242 
243  if ( df != df )
244  {
245  printf( "Newton(): f'( x ) is infinite\n" );
246  }
247 
248  return std::pair<T, T>( x, f );
249  };
250 
251  private:
252 
253  FUNC *func = NULL;
254 
255  T f = 0.0;
256 
257  T df = 0.0;
258 
259 };
262 };
263 };
265 #endif
This is not thread safe.
Definition: root.hpp:100
std::pair< T, T > Initialize()
Definition: root.hpp:109
std::pair< T, T > Iterate()
Definition: root.hpp:132
Definition: root.hpp:181
std::pair< T, T > Initialize()
Definition: root.hpp:190
std::pair< T, T > Iterate()
Definition: root.hpp:213
Definition: root.hpp:18
virtual std::pair< T, T > Solve(T user_x_lower, T user_x_upper)
Definition: root.hpp:28
Definition: gofmm.hpp:83