00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef _adevs_rk45_improved_h_
00021 #define _adevs_rk45_improved_h_
00022 #include "adevs_dess.h"
00023 #include <cmath>
00024 #include <algorithm>
00025 #include <iostream>
00026 #include <cassert>
00027
00028 namespace adevs
00029 {
00030
00051 template <class X> class rk45_improved: public DESS<X>
00052 {
00053 public:
00061 rk45_improved(int num_state_vars, double h_max, double err_tol,
00062 int zero_crossing_funcs, double event_tol = 1E-12);
00066 void init(int i, double q0) { q[i] = q0; }
00070 const double* getStateVars() const { return q; }
00074 int getNumStateVars() const { return num_state_vars; }
00079 virtual void der_func(const double* q, double* dq) = 0;
00086 virtual void state_event_func(const double* q, double* z) = 0;
00091 virtual double time_event_func(const double* q) = 0;
00106 virtual void discrete_action(double* q, const Bag<X>& xb, const bool* event_flags) = 0;
00111 virtual void discrete_output(const double* q, Bag<X>& yb, const bool* event_flags) = 0;
00120 virtual void state_changed(const double* q){}
00121
00122 void evolve_func(double h);
00123
00124 double next_event_func(bool& is_event);
00125
00126 void discrete_action_func(const Bag<X>& xb);
00127
00128 void discrete_output_func(Bag<X>& yb);
00129
00130 void state_changed() { state_changed(q); }
00132 ~rk45_improved();
00133
00134 private:
00135
00136
00137 const double h_max, err_tol, event_tol;
00138
00139
00140 double *q, *dq, *t, *k[6], *q_tmp, *es, *en;
00141
00142 bool* event_indicator;
00143
00144 double h_cur;
00145
00146 bool keep_q_tmp;
00147
00148 const int num_state_vars, zero_funcs;
00149
00150
00151
00152
00153
00154 double ode_step(double *qq, double step);
00155 };
00156
00157 template <class X>
00158 rk45_improved<X>::rk45_improved(int num_state_vars, double h_max, double err_tol, int zero_funcs, double event_tol):
00159 DESS<X>(),
00160 h_max(h_max),
00161 err_tol(err_tol),
00162 event_tol(event_tol),
00163 h_cur(h_max),
00164 keep_q_tmp(false),
00165 num_state_vars(num_state_vars),
00166 zero_funcs(zero_funcs)
00167 {
00168 q = new double[num_state_vars];
00169 dq = new double[num_state_vars];
00170 t = new double[num_state_vars];
00171 q_tmp = new double[num_state_vars];
00172 for (int i = 0; i < 6; i++)
00173 k[i] = new double[num_state_vars];
00174 en = new double[zero_funcs];
00175 es = new double[zero_funcs];
00176 event_indicator = new bool[zero_funcs+1];
00177 }
00178
00179 template <class X>
00180 rk45_improved<X>::~rk45_improved()
00181 {
00182 delete [] q;
00183 delete [] dq;
00184 delete [] t;
00185 delete [] q_tmp;
00186 for (int i = 0; i < 6; i++)
00187 delete [] k[i];
00188 delete [] es;
00189 delete [] en;
00190 delete [] event_indicator;
00191 }
00192
00193 template <class X>
00194 void rk45_improved<X>::evolve_func(double h)
00195 {
00196
00197 if (h == h_cur)
00198 {
00199
00200 if (keep_q_tmp)
00201 {
00202 for (int i = 0; i < num_state_vars; i++)
00203 q[i] = q_tmp[i];
00204 }
00205
00206 else ode_step(q,h);
00207 }
00208
00209
00210 else
00211 {
00212
00213 event_indicator[zero_funcs] = false;
00214
00215 state_event_func(q,es);
00216
00217 ode_step(q,h);
00218
00219 state_event_func(q,en);
00220
00221 for (int i = 0; i < zero_funcs; i++)
00222 {
00223 event_indicator[i] = (es[i]*en[i] < 0.0) || (fabs(en[i]) <= event_tol);
00224 }
00225 }
00226
00227 keep_q_tmp = false;
00228 }
00229
00230 template <class X>
00231 double rk45_improved<X>::next_event_func(bool& is_event)
00232 {
00233
00234 keep_q_tmp = true;
00235
00236 double time_event = time_event_func(q);
00237
00238 h_cur *= 1.2;
00239 if (h_cur > h_max) h_cur = h_max;
00240 if (h_cur > time_event) h_cur = time_event;
00241 for ( ; ; )
00242 {
00243
00244 for (int i = 0; i < num_state_vars; i++) q_tmp[i] = q[i];
00245 double err = ode_step(q_tmp,h_cur);
00246 if (err <= err_tol) break;
00247
00248 double h_next = 0.8*pow(err_tol*h_cur*h_cur*h_cur*h_cur*h_cur,0.2)/fabs(err);
00249 if (h_next > h_cur) h_next = 0.8*h_cur;
00250 h_cur = h_next;
00251 }
00252
00253
00254
00255
00256 state_event_func(q,es);
00257
00258 while(true)
00259 {
00260
00261 state_event_func(q_tmp,en);
00262
00263 bool found_state_event = false;
00264 double h_next = h_cur;
00265 for (int i = 0; i < zero_funcs; i++)
00266 {
00267 bool sign_change = (es[i]*en[i] < 0.0);
00268 bool tolerance_met = event_indicator[i] = (fabs(en[i]) <= event_tol);
00269 if (tolerance_met) found_state_event = true;
00270
00271
00272
00273 if (sign_change && !tolerance_met)
00274 {
00275 double t_cross = h_cur*es[i]/(es[i]-en[i]);
00276 assert(t_cross < h_cur);
00277 assert(t_cross > 0.0);
00278 if (t_cross < h_next) h_next = t_cross;
00279 }
00280 }
00281
00282
00283 if (h_next == h_cur)
00284 {
00285
00286 event_indicator[zero_funcs] = (h_next >= time_event);
00287
00288 is_event = found_state_event || event_indicator[zero_funcs];
00289
00290 return h_cur;
00291 }
00292
00293 assert(h_next < h_cur);
00294 h_cur = h_next;
00295 for (int i = 0; i < num_state_vars; i++) q_tmp[i] = q[i];
00296 ode_step(q_tmp,h_cur);
00297 }
00298 }
00299
00300 template <class X>
00301 void rk45_improved<X>::discrete_action_func(const Bag<X>& xb)
00302 {
00303
00304 h_cur = h_max;
00305 keep_q_tmp = false;
00306
00307 discrete_action(q,xb,event_indicator);
00308 }
00309
00310 template <class X>
00311 void rk45_improved<X>::discrete_output_func(Bag<X>& yb)
00312 {
00313 discrete_output(q,yb,event_indicator);
00314 }
00315
00316 template <class X>
00317 double rk45_improved<X>::ode_step(double*qq, double step)
00318 {
00319 if (step == 0.0)
00320 {
00321 return 0.0;
00322 }
00323
00324 der_func(qq,dq);
00325 for (int j = 0; j < num_state_vars; j++)
00326 k[0][j] = step*dq[j];
00327
00328 for (int j = 0; j < num_state_vars; j++)
00329 t[j] = qq[j] + 0.5*k[0][j];
00330 der_func(t,dq);
00331 for (int j = 0; j < num_state_vars; j++)
00332 k[1][j] = step*dq[j];
00333
00334 for (int j = 0; j < num_state_vars; j++)
00335 t[j] = qq[j] + 0.25*(k[0][j]+k[1][j]);
00336 der_func(t,dq);
00337 for (int j = 0; j < num_state_vars; j++)
00338 k[2][j] = step*dq[j];
00339
00340 for (int j = 0; j < num_state_vars; j++)
00341 t[j] = qq[j] - k[1][j] + 2.0*k[2][j];
00342 der_func(t,dq);
00343 for (int j = 0; j < num_state_vars; j++)
00344 k[3][j] = step*dq[j];
00345
00346 for (int j = 0; j < num_state_vars; j++)
00347 t[j] = qq[j] + (7.0/27.0)*k[0][j] + (10.0/27.0)*k[1][j] + (1.0/27.0)*k[3][j];
00348 der_func(t,dq);
00349 for (int j = 0; j < num_state_vars; j++)
00350 k[4][j] = step*dq[j];
00351
00352 for (int j = 0; j < num_state_vars; j++)
00353 t[j] = qq[j] + (28.0/625.0)*k[0][j] - 0.2*k[1][j] + (546.0/625.0)*k[2][j]
00354 + (54.0/625.0)*k[3][j] - (378.0/625.0)*k[4][j];
00355 der_func(t,dq);
00356 for (int j = 0 ; j < num_state_vars; j++)
00357 k[5][j] = step*dq[j];
00358
00359 double err = 0.0;
00360 for (int j = 0; j < num_state_vars; j++)
00361 {
00362 qq[j] += (1.0/24.0)*k[0][j] + (5.0/48.0)*k[3][j] +
00363 (27.0/56.0)*k[4][j] + (125.0/336.0)*k[5][j];
00364 err = std::max(err,
00365 fabs(k[0][j]/8.0+2.0*k[2][j]/3.0+k[3][j]/16.0-27.0*k[4][j]/56.0
00366 -125.0*k[5][j]/336.0));
00367 }
00368 return err;
00369 }
00370
00371 }
00372
00373 #endif