00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef _adevs_rk45_improved_h_
00021 #define _adevs_rk45_improved_h_
00022 #include "adevs_dess.h"
00023 #include <cmath>
00024 #include <algorithm>
00025 #include <iostream>
00026 #include <cassert>
00027
00028 namespace adevs
00029 {
00030
00052 template <class X> class rk45_improved: public DESS<X>
00053 {
00054 public:
00062 rk45_improved(int num_state_vars, double h_max, double err_tol,
00063 int zero_crossing_funcs, double event_tol = 1E-12);
00067 void init(int i, double q0) { q[i] = q0; }
00071 const double* getStateVars() const { return q; }
00075 int getNumStateVars() const { return num_state_vars; }
00080 virtual void der_func(const double* q, double* dq) = 0;
00087 virtual void state_event_func(const double* q, double* z) = 0;
00092 virtual double time_event_func(const double* q) = 0;
00107 virtual void discrete_action(double* q, const Bag<X>& xb, const bool* event_flags) = 0;
00112 virtual void discrete_output(const double* q, Bag<X>& yb, const bool* event_flags) = 0;
00121 virtual void state_changed(const double* q){}
00122
00123 void evolve_func(double h);
00124
00125 double next_event_func(bool& is_event);
00126
00127 void discrete_action_func(const Bag<X>& xb);
00128
00129 void discrete_output_func(Bag<X>& yb);
00130
00131 void state_changed() { state_changed(q); }
00133 ~rk45_improved();
00134
00135 private:
00136
00137
00138 const double h_max, err_tol, event_tol;
00139
00140
00141 double *q, *dq, *t, *k[6], *q_tmp, *es, *en;
00142
00143 bool* event_indicator;
00144
00145 double h_cur;
00146
00147 bool keep_q_tmp;
00148
00149 const int num_state_vars, zero_funcs;
00150
00151
00152
00153
00154
00155 double ode_step(double *qq, double step);
00156 };
00157
00158 template <class X>
00159 rk45_improved<X>::rk45_improved(int num_state_vars, double h_max, double err_tol, int zero_funcs, double event_tol):
00160 DESS<X>(),
00161 h_max(h_max),
00162 err_tol(err_tol),
00163 event_tol(event_tol),
00164 h_cur(h_max),
00165 keep_q_tmp(false),
00166 num_state_vars(num_state_vars),
00167 zero_funcs(zero_funcs)
00168 {
00169 q = new double[num_state_vars];
00170 dq = new double[num_state_vars];
00171 t = new double[num_state_vars];
00172 q_tmp = new double[num_state_vars];
00173 for (int i = 0; i < 6; i++)
00174 k[i] = new double[num_state_vars];
00175 en = new double[zero_funcs];
00176 es = new double[zero_funcs];
00177 event_indicator = new bool[zero_funcs+1];
00178 }
00179
00180 template <class X>
00181 rk45_improved<X>::~rk45_improved()
00182 {
00183 delete [] q;
00184 delete [] dq;
00185 delete [] t;
00186 delete [] q_tmp;
00187 for (int i = 0; i < 6; i++)
00188 delete [] k[i];
00189 delete [] es;
00190 delete [] en;
00191 delete [] event_indicator;
00192 }
00193
00194 template <class X>
00195 void rk45_improved<X>::evolve_func(double h)
00196 {
00197
00198 if (h == h_cur)
00199 {
00200
00201 if (keep_q_tmp)
00202 {
00203 for (int i = 0; i < num_state_vars; i++)
00204 q[i] = q_tmp[i];
00205 }
00206
00207 else ode_step(q,h);
00208 }
00209
00210
00211 else
00212 {
00213
00214 event_indicator[zero_funcs] = false;
00215
00216 state_event_func(q,es);
00217
00218 ode_step(q,h);
00219
00220 state_event_func(q,en);
00221
00222 for (int i = 0; i < zero_funcs; i++)
00223 {
00224 event_indicator[i] = (es[i]*en[i] < 0.0) || (fabs(en[i]) <= event_tol);
00225 }
00226 }
00227
00228 keep_q_tmp = false;
00229 }
00230
00231 template <class X>
00232 double rk45_improved<X>::next_event_func(bool& is_event)
00233 {
00234
00235 keep_q_tmp = true;
00236
00237 double time_event = time_event_func(q);
00238
00239 h_cur *= 1.2;
00240 if (h_cur > h_max) h_cur = h_max;
00241 if (h_cur > time_event) h_cur = time_event;
00242 for ( ; ; )
00243 {
00244
00245 for (int i = 0; i < num_state_vars; i++) q_tmp[i] = q[i];
00246 double err = ode_step(q_tmp,h_cur);
00247 if (err <= err_tol) break;
00248
00249 double h_next = 0.8*pow(err_tol*h_cur*h_cur*h_cur*h_cur*h_cur,0.2)/fabs(err);
00250 if (h_next >= h_cur) h_next = 0.8*h_cur;
00251 h_cur = h_next;
00252 }
00253
00254
00255
00256
00257 state_event_func(q,es);
00258
00259 while(true)
00260 {
00261
00262 state_event_func(q_tmp,en);
00263
00264 bool found_state_event = false;
00265 double h_next = h_cur;
00266 for (int i = 0; i < zero_funcs; i++)
00267 {
00268 bool sign_change = (es[i]*en[i] < 0.0);
00269 bool tolerance_met = event_indicator[i] = (fabs(en[i]) <= event_tol);
00270 if (tolerance_met) found_state_event = true;
00271
00272
00273
00274 if (sign_change && !tolerance_met)
00275 {
00276 double t_cross = h_cur*es[i]/(es[i]-en[i]);
00277 assert(t_cross >= 0.0);
00278 if (t_cross < h_next) h_next = t_cross;
00279 }
00280 }
00281
00282
00283 if (h_next == h_cur)
00284 {
00285
00286 event_indicator[zero_funcs] = (h_next >= time_event);
00287
00288 is_event = found_state_event || event_indicator[zero_funcs];
00289
00290 return h_cur;
00291 }
00292
00293 assert(h_next < h_cur);
00294 h_cur = h_next;
00295 for (int i = 0; i < num_state_vars; i++) q_tmp[i] = q[i];
00296 ode_step(q_tmp,h_cur);
00297 }
00298 }
00299
00300 template <class X>
00301 void rk45_improved<X>::discrete_action_func(const Bag<X>& xb)
00302 {
00303
00304 h_cur = h_max;
00305 keep_q_tmp = false;
00306
00307 discrete_action(q,xb,event_indicator);
00308 }
00309
00310 template <class X>
00311 void rk45_improved<X>::discrete_output_func(Bag<X>& yb)
00312 {
00313 discrete_output(q,yb,event_indicator);
00314 }
00315
00316 template <class X>
00317 double rk45_improved<X>::ode_step(double*qq, double step)
00318 {
00319 if (step == 0.0)
00320 {
00321 return 0.0;
00322 }
00323
00324 der_func(qq,dq);
00325 for (int j = 0; j < num_state_vars; j++)
00326 k[0][j] = step*dq[j];
00327
00328 for (int j = 0; j < num_state_vars; j++)
00329 t[j] = qq[j] + 0.5*k[0][j];
00330 der_func(t,dq);
00331 for (int j = 0; j < num_state_vars; j++)
00332 k[1][j] = step*dq[j];
00333
00334 for (int j = 0; j < num_state_vars; j++)
00335 t[j] = qq[j] + 0.25*(k[0][j]+k[1][j]);
00336 der_func(t,dq);
00337 for (int j = 0; j < num_state_vars; j++)
00338 k[2][j] = step*dq[j];
00339
00340 for (int j = 0; j < num_state_vars; j++)
00341 t[j] = qq[j] - k[1][j] + 2.0*k[2][j];
00342 der_func(t,dq);
00343 for (int j = 0; j < num_state_vars; j++)
00344 k[3][j] = step*dq[j];
00345
00346 for (int j = 0; j < num_state_vars; j++)
00347 t[j] = qq[j] + (7.0/27.0)*k[0][j] + (10.0/27.0)*k[1][j] + (1.0/27.0)*k[3][j];
00348 der_func(t,dq);
00349 for (int j = 0; j < num_state_vars; j++)
00350 k[4][j] = step*dq[j];
00351
00352 for (int j = 0; j < num_state_vars; j++)
00353 t[j] = qq[j] + (28.0/625.0)*k[0][j] - 0.2*k[1][j] + (546.0/625.0)*k[2][j]
00354 + (54.0/625.0)*k[3][j] - (378.0/625.0)*k[4][j];
00355 der_func(t,dq);
00356 for (int j = 0 ; j < num_state_vars; j++)
00357 k[5][j] = step*dq[j];
00358
00359 double err = 0.0;
00360 for (int j = 0; j < num_state_vars; j++)
00361 {
00362 qq[j] += (1.0/24.0)*k[0][j] + (5.0/48.0)*k[3][j] +
00363 (27.0/56.0)*k[4][j] + (125.0/336.0)*k[5][j];
00364 err = std::max(err,
00365 fabs(k[0][j]/8.0+2.0*k[2][j]/3.0+k[3][j]/16.0-27.0*k[4][j]/56.0
00366 -125.0*k[5][j]/336.0));
00367 }
00368 return err;
00369 }
00370
00371 }
00372
00373 #endif