@@ -225,11 +225,32 @@ subroutine read_train_write(training_configuration, base_name, plot_unit, previo
225225 real (rkind), parameter :: keep = 0.01
226226 real (rkind), allocatable :: cost(:)
227227 real (rkind), allocatable :: harvest(:)
228- integer i, batch, lon, lat, level, time, network_unit, io_status, final_step, epoch
228+ integer i, batch, lon, lat, level, time, network_unit, io_status, epoch
229229 integer (int64) start_training, finish_training
230230
231231 open (newunit= network_unit, file= network_file, form= ' formatted' , status= ' old' , iostat= io_status, action= ' read' )
232232
233+ if (.not. allocated (end_step)) end_step = t_end
234+
235+ print * ," Defining tensors from time step" , start_step, " through" , end_step, " with strides of" , stride
236+
237+ ! The following temporary copies are required by gfortran bug 100650 and possibly 49324
238+ ! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
239+ inputs = [( [( [( [( &
240+ tensor_t( &
241+ [ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time), &
242+ qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qr_in(lon,lat,level,time), qs_in(lon,lat,level,time) &
243+ ] &
244+ ), lon = 1 , size (qv_in,1 ))], lat = 1 , size (qv_in,2 ))], level = 1 , size (qv_in,3 ))], time = start_step, end_step, stride)]
245+
246+ outputs = [( [( [( [( &
247+ tensor_t( &
248+ [dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), &
249+ dqs_dt(lon,lat,level,time) &
250+ ] &
251+ ), lon = 1 , size (qv_in,1 ))], lat = 1 , size (qv_in,2 ))], level = 1 , size (qv_in,3 ))], time = start_step, end_step, stride)]
252+
253+ read_or_initialize_engine: &
233254 if (io_status== 0 ) then
234255 print * ," Reading network from file " // network_file
235256 trainable_engine = trainable_engine_t(inference_engine_t(file_t(string_t(network_file))))
@@ -240,68 +261,48 @@ subroutine read_train_write(training_configuration, base_name, plot_unit, previo
240261 initialize_network: &
241262 block
242263 character (len= len (' YYYYMMDD' )) date
243- type (tensor_range_t) input_range, output_range
244264
245265 call date_and_time (date)
266+
246267 print * ," Calculating input tensor component ranges."
247- input_range = tensor_range_t( &
268+ associate( input_range = > tensor_range_t( &
248269 layer = " inputs" , &
249270 minima = [minval (pressure_in), minval (potential_temperature_in), minval (temperature_in), &
250271 minval (qv_in), minval (qc_in), minval (qr_in), minval (qs_in)], &
251272 maxima = [maxval (pressure_in), maxval (potential_temperature_in), maxval (temperature_in), &
252273 maxval (qv_in), maxval (qc_in), maxval (qr_in), maxval (qs_in)] &
253- )
254- print * ," Calculating output tensor component ranges."
255- output_range = tensor_range_t( &
256- layer = " outputs" , &
257- minima = [minval (dpt_dt), minval (dqv_dt), minval (dqc_dt), minval (dqr_dt), minval (dqs_dt)], &
258- maxima = [maxval (dpt_dt), maxval (dqv_dt), maxval (dqc_dt), maxval (dqr_dt), maxval (dqs_dt)] &
259- )
260- print * , " Initializing a new network "
261-
262- associate(activation = > training_configuration % differentiable_activation_strategy())
263- associate( &
264- model_name = > string_t(" Simple microphysics " ), &
265- author = > string_t( " Inference Engine " ), &
266- date_string = > string_t(date), &
267- activation_name = > activation % function_name(), &
268- residual_network = > string_t( trim ( merge ( " true " , " false " , training_configuration % skip_connections()))) &
269- )
270- trainable_engine = trainable_engine_t( &
271- training_configuration, perturbation_magnitude = 0.05 , &
272- metadata = [model_name, author, date_string, activation_name, residual_network], &
273- input_range = input_range, output_range = output_range &
274- )
274+ ))
275+ print * ," Calculating output tensor component ranges."
276+ associate( output_range = > tensor_range_t( &
277+ layer = " outputs" , &
278+ minima = [minval (dpt_dt), minval (dqv_dt), minval (dqc_dt), minval (dqr_dt), minval (dqs_dt)], &
279+ maxima = [maxval (dpt_dt), maxval (dqv_dt), maxval (dqc_dt), maxval (dqr_dt), maxval (dqs_dt)] &
280+ ) )
281+ associate(activation = > training_configuration % differentiable_activation_strategy())
282+ associate( &
283+ model_name = > string_t( " Simple microphysics " ), &
284+ author = > string_t( " Inference Engine " ), &
285+ date_string = > string_t(date ), &
286+ activation_name = > activation % function_name( ), &
287+ residual_network = > string_t(trim ( merge ( " true " , " false " , training_configuration % skip_connections()))) &
288+ )
289+ trainable_engine = trainable_engine_t( &
290+ training_configuration, perturbation_magnitude = 0.05 , &
291+ metadata = [model_name, author, date_string, activation_name, residual_network], &
292+ input_range = input_range, output_range = output_range &
293+ )
294+ end associate
295+ end associate
275296 end associate
276297 end associate
277298 end block initialize_network
278- end if
279-
280- if (.not. allocated (end_step)) end_step = t_end
281-
282- print * ," Defining tensors from time step" , start_step, " through" , end_step, " with strides of" , stride
283-
284- ! The following temporary copies are required by gfortran bug 100650 and possibly 49324
285- ! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
286- inputs = [( [( [( [( &
287- tensor_t( &
288- [ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time), &
289- qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qr_in(lon,lat,level,time), qs_in(lon,lat,level,time) &
290- ] &
291- ), lon = 1 , size (qv_in,1 ))], lat = 1 , size (qv_in,2 ))], level = 1 , size (qv_in,3 ))], time = start_step, end_step, stride)]
292-
293- outputs = [( [( [( [( &
294- tensor_t( &
295- [dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), &
296- dqs_dt(lon,lat,level,time) &
297- ] &
298- ), lon = 1 , size (qv_in,1 ))], lat = 1 , size (qv_in,2 ))], level = 1 , size (qv_in,3 ))], time = start_step, end_step, stride)]
299+ end if read_or_initialize_engine
299300
300301 print * ," Normalizing input tensors"
301- inputs = input_range % map_to_training_range (inputs)
302+ inputs = trainable_engine % map_to_input_training_range (inputs)
302303
303304 print * ," Normalizing output tensors"
304- outputs = output_range % map_to_training_range (outputs)
305+ outputs = trainable_engine % map_to_output_training_range (outputs)
305306
306307 print * , " Eliminating" ,int (100 * (1 .- keep))," % of the grid points that have all-zero time derivatives"
307308
0 commit comments