adevs
|
00001 /*************** 00002 Copyright (C) 2000-2006 by James Nutaro 00003 00004 This library is free software; you can redistribute it and/or 00005 modify it under the terms of the GNU Lesser General Public 00006 License as published by the Free Software Foundation; either 00007 version 2 of the License, or (at your option) any later version. 00008 00009 This library is distributed in the hope that it will be useful, 00010 but WITHOUT ANY WARRANTY; without even the implied warranty of 00011 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00012 Lesser General Public License for more details. 00013 00014 You should have received a copy of the GNU Lesser General Public 00015 License along with this library; if not, write to the Free Software 00016 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 00017 00018 Bugs, comments, and questions can be sent to nutaro@gmail.com 00019 ***************/ 00020 #ifndef _adevs_rk45_improved_h_ 00021 #define _adevs_rk45_improved_h_ 00022 #include "adevs_dess.h" 00023 #include <cmath> 00024 #include <algorithm> 00025 #include <iostream> 00026 #include <cassert> 00027 00028 namespace adevs 00029 { 00030 00052 template <class X> class rk45_improved: public DESS<X> 00053 { 00054 public: 00062 rk45_improved(int num_state_vars, double h_max, double err_tol, 00063 int zero_crossing_funcs, double event_tol = 1E-12); 00067 void init(int i, double q0) { q[i] = q0; } 00071 const double* getStateVars() const { return q; } 00075 int getNumStateVars() const { return num_state_vars; } 00080 virtual void der_func(const double* q, double* dq) = 0; 00087 virtual void state_event_func(const double* q, double* z) = 0; 00092 virtual double time_event_func(const double* q) = 0; 00107 virtual void discrete_action(double* q, const Bag<X>& xb, const bool* event_flags) = 0; 00112 virtual void discrete_output(const double* q, Bag<X>& yb, const bool* event_flags) = 0; 00121 virtual void state_changed(const double* q){} 00122 // Implementation of the DESS evolve_func method 00123 void evolve_func(double h); 00124 // Implementation of the DESS next_event_func method 00125 double next_event_func(bool& is_event); 00126 // Implementation of the DESS discrete_action_func method 00127 void discrete_action_func(const Bag<X>& xb); 00128 // Implementation of the DESS dscrete_output_func method 00129 void discrete_output_func(Bag<X>& yb); 00130 // Implementation of the DESS state changed method 00131 void state_changed() { state_changed(q); } 00133 ~rk45_improved(); 00134 00135 private: 00136 // Maximum step size, trunc. err. tolerance, and 00137 // state event detection time tolerance 00138 const double h_max, err_tol, event_tol; 00139 // Scratch variables to implement the integration 00140 // and event detection scheme 00141 double *q, *dq, *t, *k[6], *q_tmp, *es, *en; 00142 // Event indicator flags 00143 bool* event_indicator; 00144 // Current time step selection 00145 double h_cur; 00146 // Does q_tmp hold the ODE solution at h_cur? 00147 bool keep_q_tmp; 00148 // Number of state variables and level crossing functions 00149 const int num_state_vars, zero_funcs; 00150 /* 00151 * This makes an appropriate sized step and stores the 00152 * result in qq. The method returns truncation error 00153 * estimate. 00154 */ 00155 double ode_step(double *qq, double step); 00156 }; 00157 00158 template <class X> 00159 rk45_improved<X>::rk45_improved(int num_state_vars, double h_max, double err_tol, int zero_funcs, double event_tol): 00160 DESS<X>(), 00161 h_max(h_max), 00162 err_tol(err_tol), 00163 event_tol(event_tol), 00164 h_cur(h_max), 00165 keep_q_tmp(false), 00166 num_state_vars(num_state_vars), 00167 zero_funcs(zero_funcs) 00168 { 00169 q = new double[num_state_vars]; 00170 dq = new double[num_state_vars]; 00171 t = new double[num_state_vars]; 00172 q_tmp = new double[num_state_vars]; 00173 for (int i = 0; i < 6; i++) 00174 k[i] = new double[num_state_vars]; 00175 en = new double[zero_funcs]; 00176 es = new double[zero_funcs]; 00177 event_indicator = new bool[zero_funcs+1]; 00178 } 00179 00180 template <class X> 00181 rk45_improved<X>::~rk45_improved() 00182 { 00183 delete [] q; 00184 delete [] dq; 00185 delete [] t; 00186 delete [] q_tmp; 00187 for (int i = 0; i < 6; i++) 00188 delete [] k[i]; 00189 delete [] es; 00190 delete [] en; 00191 delete [] event_indicator; 00192 } 00193 00194 template <class X> 00195 void rk45_improved<X>::evolve_func(double h) 00196 { 00197 // If this is an internal event 00198 if (h == h_cur) 00199 { 00200 // if q_tmp is ok, then just copy q_tmp to q 00201 if (keep_q_tmp) 00202 { 00203 for (int i = 0; i < num_state_vars; i++) 00204 q[i] = q_tmp[i]; 00205 } 00206 // otherwise advance the solution 00207 else ode_step(q,h); 00208 } 00209 // If this is an external event, then check for state events and 00210 // advance the solution 00211 else 00212 { 00213 // This is not a time event 00214 event_indicator[zero_funcs] = false; 00215 // Calculate the state event function at q 00216 state_event_func(q,es); 00217 // Advance the solution 00218 ode_step(q,h); 00219 // Calculate the state event function now 00220 state_event_func(q,en); 00221 // Look for zero crossings in the internval 00222 for (int i = 0; i < zero_funcs; i++) 00223 { 00224 event_indicator[i] = (es[i]*en[i] < 0.0) || (fabs(en[i]) <= event_tol); 00225 } 00226 } 00227 // Invalidate q_tmp 00228 keep_q_tmp = false; 00229 } 00230 00231 template <class X> 00232 double rk45_improved<X>::next_event_func(bool& is_event) 00233 { 00234 // q_tmp will hold the solution at h_cur when this method returns 00235 keep_q_tmp = true; 00236 // Get the next time event 00237 double time_event = time_event_func(q); 00238 /* Look for the largest allowable integration step */ 00239 h_cur *= 1.2; // Try a larger step size this time through 00240 if (h_cur > h_max) h_cur = h_max; // Limit to h_max 00241 if (h_cur > time_event) h_cur = time_event; // Limit to the next event time 00242 for ( ; ; ) 00243 { 00244 // Advance the solution by h_cur 00245 for (int i = 0; i < num_state_vars; i++) q_tmp[i] = q[i]; 00246 double err = ode_step(q_tmp,h_cur); 00247 if (err <= err_tol) break; // Check the error, exit if ok 00248 // Reduce the step size otherwise 00249 double h_next = 0.8*pow(err_tol*h_cur*h_cur*h_cur*h_cur*h_cur,0.2)/fabs(err); 00250 if (h_next >= h_cur) h_next = 0.8*h_cur; 00251 h_cur = h_next; 00252 } 00253 // q_tmp now stores the state variables at the end of the time step 00254 // and q has the variables at the start of the time step. Now we 00255 // look for state events. The array es contains the state event 00256 // functions at the start of the interval. 00257 state_event_func(q,es); 00258 // Look for the next zero crossing 00259 while(true) 00260 { 00261 // Compute the state event values at the end of the interval 00262 state_event_func(q_tmp,en); 00263 // Look for a zero-crossing 00264 bool found_state_event = false; 00265 double h_next = h_cur; 00266 for (int i = 0; i < zero_funcs; i++) 00267 { 00268 bool sign_change = (es[i]*en[i] < 0.0); 00269 bool tolerance_met = event_indicator[i] = (fabs(en[i]) <= event_tol); 00270 if (tolerance_met) found_state_event = true; 00271 // Estimate the time to cross zero and remember the 00272 // smallest such time (if we actual found an event, then h_cur 00273 // is the crossing time). 00274 if (sign_change && !tolerance_met) 00275 { 00276 double t_cross = h_cur*es[i]/(es[i]-en[i]); 00277 assert(t_cross >= 0.0); 00278 if (t_cross < h_next) h_next = t_cross; 00279 } 00280 } 00281 // If the next event time is equal to the current event time, then 00282 // we found the next event crossing or there wasn't one. 00283 if (h_next == h_cur) 00284 { 00285 // Is this a time event? 00286 event_indicator[zero_funcs] = (h_next >= time_event); 00287 // Are there any state or time events? 00288 is_event = found_state_event || event_indicator[zero_funcs]; 00289 // Done, return the next event time 00290 return h_cur; 00291 } 00292 // If a crossing was found adjust the time step and try again 00293 assert(h_next < h_cur); 00294 h_cur = h_next; 00295 for (int i = 0; i < num_state_vars; i++) q_tmp[i] = q[i]; 00296 ode_step(q_tmp,h_cur); 00297 } 00298 } 00299 00300 template <class X> 00301 void rk45_improved<X>::discrete_action_func(const Bag<X>& xb) 00302 { 00303 // Reset the integrator step size and invalidate q_tmp 00304 h_cur = h_max; 00305 keep_q_tmp = false; 00306 // Compute the discrete action 00307 discrete_action(q,xb,event_indicator); 00308 } 00309 00310 template <class X> 00311 void rk45_improved<X>::discrete_output_func(Bag<X>& yb) 00312 { 00313 discrete_output(q,yb,event_indicator); 00314 } 00315 00316 template <class X> 00317 double rk45_improved<X>::ode_step(double*qq, double step) 00318 { 00319 if (step == 0.0) 00320 { 00321 return 0.0; 00322 } 00323 // Compute k1 00324 der_func(qq,dq); 00325 for (int j = 0; j < num_state_vars; j++) 00326 k[0][j] = step*dq[j]; 00327 // Compute k2 00328 for (int j = 0; j < num_state_vars; j++) 00329 t[j] = qq[j] + 0.5*k[0][j]; 00330 der_func(t,dq); 00331 for (int j = 0; j < num_state_vars; j++) 00332 k[1][j] = step*dq[j]; 00333 // Compute k3 00334 for (int j = 0; j < num_state_vars; j++) 00335 t[j] = qq[j] + 0.25*(k[0][j]+k[1][j]); 00336 der_func(t,dq); 00337 for (int j = 0; j < num_state_vars; j++) 00338 k[2][j] = step*dq[j]; 00339 // Compute k4 00340 for (int j = 0; j < num_state_vars; j++) 00341 t[j] = qq[j] - k[1][j] + 2.0*k[2][j]; 00342 der_func(t,dq); 00343 for (int j = 0; j < num_state_vars; j++) 00344 k[3][j] = step*dq[j]; 00345 // Compute k5 00346 for (int j = 0; j < num_state_vars; j++) 00347 t[j] = qq[j] + (7.0/27.0)*k[0][j] + (10.0/27.0)*k[1][j] + (1.0/27.0)*k[3][j]; 00348 der_func(t,dq); 00349 for (int j = 0; j < num_state_vars; j++) 00350 k[4][j] = step*dq[j]; 00351 // Compute k6 00352 for (int j = 0; j < num_state_vars; j++) 00353 t[j] = qq[j] + (28.0/625.0)*k[0][j] - 0.2*k[1][j] + (546.0/625.0)*k[2][j] 00354 + (54.0/625.0)*k[3][j] - (378.0/625.0)*k[4][j]; 00355 der_func(t,dq); 00356 for (int j = 0 ; j < num_state_vars; j++) 00357 k[5][j] = step*dq[j]; 00358 // Compute next state and maximum approx. error 00359 double err = 0.0; 00360 for (int j = 0; j < num_state_vars; j++) 00361 { 00362 qq[j] += (1.0/24.0)*k[0][j] + (5.0/48.0)*k[3][j] + 00363 (27.0/56.0)*k[4][j] + (125.0/336.0)*k[5][j]; 00364 err = std::max(err, 00365 fabs(k[0][j]/8.0+2.0*k[2][j]/3.0+k[3][j]/16.0-27.0*k[4][j]/56.0 00366 -125.0*k[5][j]/336.0)); 00367 } 00368 return err; 00369 } 00370 00371 } // end of namespace 00372 00373 #endif