99from equation .Term import Vertex , Term
1010from equation .closing import can_be_closed , replace_with_closures
1111from model_params .cmodel import CModel
12- from model_params .helpers import dynamically_relevant , Coupling , coupling_types
12+ from model_params .helpers import dynamically_relevant , Coupling , coupling_types , parse_coupling_descriptor
1313
1414t = sym .symbols ('t' )
1515
@@ -136,6 +136,8 @@ def add_terms(v: Vertex, term: Term, transition: tuple, model: CModel, neighbour
136136 neighbours_of_v_not_in_tuple = list (set (neighbours_of_v ) - set (copy .deepcopy (term ).node_list ()))
137137 neighbours_of_v_in_tuple = list (set (copy .deepcopy (term ).vertices ) - {v })
138138 terms = []
139+ _ , _ , _ , target_state = parse_coupling_descriptor (transition [1 ])
140+ target_state = target_state or v .state
139141 # Examples relate to vanilla SIR model
140142 if transition [0 ] == Coupling .NEIGHBOUR_ENTER :
141143 # E.G. I state requires something in S coming into contact with something in I
@@ -144,7 +146,7 @@ def add_terms(v: Vertex, term: Term, transition: tuple, model: CModel, neighbour
144146 elif transition [0 ] == Coupling .NEIGHBOUR_EXIT :
145147 # Converse of neighbour enter - S contacts I, transitions to I
146148 neighbour_exit (model , neighbours_of_v_in_tuple , neighbours_of_v_not_in_tuple , transition [2 ],
147- sym .symbols (f'{ transition [3 ]} ' ), copy .deepcopy (term ), terms , transition , v , transition [ 1 ][ - 1 ] )
149+ sym .symbols (f'{ transition [3 ]} ' ), copy .deepcopy (term ), terms , transition , v , target_state )
148150 elif transition [0 ] == Coupling .ISOLATED_ENTER :
149151 # E.G. I -> R without input of neighbours
150152 isolated_enter (transition [2 ], sym .symbols (f'{ transition [3 ]} ' ), copy .deepcopy (term ), terms , transition , v )
@@ -161,9 +163,10 @@ def find_neighbour_entries(model, neighbours_of_v_in_tuple, neighbours_of_v_not_
161163 # e.g. v is in state I, so change v to S and each neighbour in turn to I
162164 rate_of_transition = transition [2 ]
163165 symbol_for_rate_of_transition = sym .symbols (f'{ transition [3 ]} ' )
164- v_transitions_to = transition [1 ].split (':' )[1 ][- 1 ]
165- v_starts_as = transition [1 ].split (':' )[1 ][0 ] # the state v would be in before transitioning to current state
166- other_state_for_neighbours = transition [1 ].split ('*' )[1 ][0 ]
166+ _ , neighbour_state , source_state , target_state = parse_coupling_descriptor (transition [1 ])
167+ other_state_for_neighbours = neighbour_state or v .state
168+ v_starts_as = source_state or v .state # the state v would be in before transitioning to current state
169+ v_transitions_to = target_state or v .state
167170 # First, look at all external vertices that could lead to entering this state
168171 for n in neighbours_of_v_not_in_tuple :
169172 # Make sure new term contains all same terms as original
@@ -182,7 +185,8 @@ def find_neighbour_entries(model, neighbours_of_v_in_tuple, neighbours_of_v_not_
182185def neighbour_exit (model , neighbours_of_v_in_tuple , neighbours_of_v_not_in_tuple , rate_of_transition ,
183186 symbol_for_rate_of_transition , term_clone , terms , transition , v , v_transitions_to ):
184187 # e.g. v in state S so can exit if any neighbour in I
185- other_state_for_neighbours = transition [1 ].split ('*' )[1 ][0 ]
188+ _ , neighbour_state , _ , _ = parse_coupling_descriptor (transition [1 ])
189+ other_state_for_neighbours = neighbour_state or v .state
186190 for n in neighbours_of_v_not_in_tuple :
187191 vertices = set (term_clone .vertices ).union ({v , Vertex (other_state_for_neighbours , n )})
188192 term = Term (list (vertices )).function ()(sym .symbols ('t' ))
@@ -203,7 +207,8 @@ def isolated_exit(rate_of_transition, symbol_for_rate_of_transition, term_clone,
203207
204208def isolated_enter (rate_of_transition , symbol_for_rate_of_transition , term_clone , terms , transition , v ):
205209 # e.g. v in state R, gets there through recovery after being in I
206- other_state_for_v = transition [1 ].split (':' )[1 ][0 ]
210+ _ , _ , source_state , _ = parse_coupling_descriptor (transition [1 ])
211+ other_state_for_v = source_state or v .state
207212 vertices = set (term_clone .vertices ).union ({Vertex (other_state_for_v , v .node )})
208213 term = Term (list (vertices )).function ()(sym .symbols ('t' ))
209214 return append_term_to_terms (rate_of_transition , symbol_for_rate_of_transition , term , terms )
0 commit comments