Skip to content

Commit f572e1d

Browse files
authored
Merge pull request #141 from BerkeleyLab/fix-restart
Fix training restarts
2 parents 194a893 + 6de177c commit f572e1d

File tree

7 files changed

+89
-65
lines changed

7 files changed

+89
-65
lines changed

cloud-microphysics/app/train-cloud-microphysics.f90

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,32 @@ subroutine read_train_write(training_configuration, base_name, plot_unit, previo
225225
real(rkind), parameter :: keep = 0.01
226226
real(rkind), allocatable :: cost(:)
227227
real(rkind), allocatable :: harvest(:)
228-
integer i, batch, lon, lat, level, time, network_unit, io_status, final_step, epoch
228+
integer i, batch, lon, lat, level, time, network_unit, io_status, epoch
229229
integer(int64) start_training, finish_training
230230

231231
open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')
232232

233+
if (.not. allocated(end_step)) end_step = t_end
234+
235+
print *,"Defining tensors from time step", start_step, "through", end_step, "with strides of", stride
236+
237+
! The following temporary copies are required by gfortran bug 100650 and possibly 49324
238+
! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
239+
inputs = [( [( [( [( &
240+
tensor_t( &
241+
[ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time), &
242+
qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qr_in(lon,lat,level,time), qs_in(lon,lat,level,time) &
243+
] &
244+
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]
245+
246+
outputs = [( [( [( [( &
247+
tensor_t( &
248+
[dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), &
249+
dqs_dt(lon,lat,level,time) &
250+
] &
251+
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]
252+
253+
read_or_initialize_engine: &
233254
if (io_status==0) then
234255
print *,"Reading network from file " // network_file
235256
trainable_engine = trainable_engine_t(inference_engine_t(file_t(string_t(network_file))))
@@ -240,68 +261,48 @@ subroutine read_train_write(training_configuration, base_name, plot_unit, previo
240261
initialize_network: &
241262
block
242263
character(len=len('YYYYMMDD')) date
243-
type(tensor_range_t) input_range, output_range
244264

245265
call date_and_time(date)
266+
246267
print *,"Calculating input tensor component ranges."
247-
input_range = tensor_range_t( &
268+
associate(input_range => tensor_range_t( &
248269
layer = "inputs", &
249270
minima = [minval(pressure_in), minval(potential_temperature_in), minval(temperature_in), &
250271
minval(qv_in), minval(qc_in), minval(qr_in), minval(qs_in)], &
251272
maxima = [maxval(pressure_in), maxval(potential_temperature_in), maxval(temperature_in), &
252273
maxval(qv_in), maxval(qc_in), maxval(qr_in), maxval(qs_in)] &
253-
)
254-
print *,"Calculating output tensor component ranges."
255-
output_range = tensor_range_t( &
256-
layer = "outputs", &
257-
minima = [minval(dpt_dt), minval(dqv_dt), minval(dqc_dt), minval(dqr_dt), minval(dqs_dt)], &
258-
maxima = [maxval(dpt_dt), maxval(dqv_dt), maxval(dqc_dt), maxval(dqr_dt), maxval(dqs_dt)] &
259-
)
260-
print *,"Initializing a new network"
261-
262-
associate(activation => training_configuration%differentiable_activation_strategy())
263-
associate( &
264-
model_name => string_t("Simple microphysics"), &
265-
author => string_t("Inference Engine"), &
266-
date_string => string_t(date), &
267-
activation_name => activation%function_name(), &
268-
residual_network => string_t(trim(merge("true ", "false", training_configuration%skip_connections()))) &
269-
)
270-
trainable_engine = trainable_engine_t( &
271-
training_configuration, perturbation_magnitude=0.05, &
272-
metadata = [model_name, author, date_string, activation_name, residual_network], &
273-
input_range = input_range, output_range = output_range &
274-
)
274+
))
275+
print *,"Calculating output tensor component ranges."
276+
associate(output_range => tensor_range_t( &
277+
layer = "outputs", &
278+
minima = [minval(dpt_dt), minval(dqv_dt), minval(dqc_dt), minval(dqr_dt), minval(dqs_dt)], &
279+
maxima = [maxval(dpt_dt), maxval(dqv_dt), maxval(dqc_dt), maxval(dqr_dt), maxval(dqs_dt)] &
280+
))
281+
associate(activation => training_configuration%differentiable_activation_strategy())
282+
associate( &
283+
model_name => string_t("Simple microphysics"), &
284+
author => string_t("Inference Engine"), &
285+
date_string => string_t(date), &
286+
activation_name => activation%function_name(), &
287+
residual_network => string_t(trim(merge("true ", "false", training_configuration%skip_connections()))) &
288+
)
289+
trainable_engine = trainable_engine_t( &
290+
training_configuration, perturbation_magnitude=0.05, &
291+
metadata = [model_name, author, date_string, activation_name, residual_network], &
292+
input_range = input_range, output_range = output_range &
293+
)
294+
end associate
295+
end associate
275296
end associate
276297
end associate
277298
end block initialize_network
278-
end if
279-
280-
if (.not. allocated(end_step)) end_step = t_end
281-
282-
print *,"Defining tensors from time step", start_step, "through", end_step, "with strides of", stride
283-
284-
! The following temporary copies are required by gfortran bug 100650 and possibly 49324
285-
! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
286-
inputs = [( [( [( [( &
287-
tensor_t( &
288-
[ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time), &
289-
qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qr_in(lon,lat,level,time), qs_in(lon,lat,level,time) &
290-
] &
291-
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]
292-
293-
outputs = [( [( [( [( &
294-
tensor_t( &
295-
[dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), &
296-
dqs_dt(lon,lat,level,time) &
297-
] &
298-
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = start_step, end_step, stride)]
299+
end if read_or_initialize_engine
299300

300301
print *,"Normalizing input tensors"
301-
inputs = input_range%map_to_training_range(inputs)
302+
inputs = trainable_engine%map_to_input_training_range(inputs)
302303

303304
print *,"Normalizing output tensors"
304-
outputs = output_range%map_to_training_range(outputs)
305+
outputs = trainable_engine%map_to_output_training_range(outputs)
305306

306307
print *, "Eliminating",int(100*(1.-keep)),"% of the grid points that have all-zero time derivatives"
307308

src/inference_engine/inference_engine_m_.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ module inference_engine_m_
4343
end type
4444

4545
type exchange_t
46+
type(tensor_range_t) input_range_, output_range_
4647
type(string_t) metadata_(size(key))
4748
real(rkind), allocatable :: weights_(:,:,:), biases_(:,:)
4849
integer, allocatable :: nodes_(:)

src/inference_engine/inference_engine_s.F90 renamed to src/inference_engine/inference_engine_s.f90

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111
use sourcery_formats_m, only : separated_values
1212
implicit none
1313

14-
#ifndef NO_EXTRAPOLATION
15-
#define NO_EXTRAPOLATION .false.
16-
#endif
17-
1814
interface assert_consistency
1915
procedure inference_engine_consistency
2016
procedure difference_consistency
@@ -23,6 +19,8 @@
2319
contains
2420

2521
module procedure to_exchange
22+
exchange%input_range_ = self%input_range_
23+
exchange%output_range_ = self%output_range_
2624
exchange%metadata_ = self%metadata_
2725
exchange%weights_ = self%weights_
2826
exchange%biases_ = self%biases_
@@ -38,8 +36,6 @@
3836

3937
call assert_consistency(self)
4038

41-
if (NO_EXTRAPOLATION) call assert(self%input_range_%in_range(inputs), "inference_engine_s(infer): inputs in range")
42-
4339
associate(w => self%weights_, b => self%biases_, n => self%nodes_, output_layer => ubound(self%nodes_,1))
4440

4541
allocate(a(maxval(n), input_layer:output_layer))
@@ -61,8 +57,6 @@
6157

6258
end associate
6359

64-
if (NO_EXTRAPOLATION) call assert(self%output_range_%in_range(outputs), "inference_engine_s(infer): outputs in range")
65-
6660
end procedure
6761

6862
pure subroutine inference_engine_consistency(self)
@@ -228,7 +222,7 @@ pure subroutine set_activation_strategy(inference_engine)
228222
end associate
229223
end block
230224

231-
inference_engine = hidden_layers%inference_engine(metadata, output_layer)
225+
inference_engine = hidden_layers%inference_engine(metadata, output_layer, input_range, output_range)
232226

233227
call set_activation_strategy(inference_engine)
234228
call assert_consistency(inference_engine)

src/inference_engine/layer_m.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module layer_m
55
use sourcery_string_m, only : string_t
66
use kind_parameters_m, only : rkind
77
use inference_engine_m_, only : inference_engine_t
8+
use tensor_range_m, only : tensor_range_t
89
implicit none
910

1011
private
@@ -39,11 +40,12 @@ recursive module function construct_layer(layer_lines, start) result(layer)
3940

4041
interface
4142

42-
module function inference_engine(hidden_layers, metadata, output_layer) result(inference_engine_)
43+
module function inference_engine(hidden_layers, metadata, output_layer, input_range, output_range) result(inference_engine_)
4344
implicit none
4445
class(layer_t), intent(in), target :: hidden_layers
4546
type(layer_t), intent(in), target :: output_layer
4647
type(string_t), intent(in) :: metadata(:)
48+
type(tensor_range_t), intent(in) :: input_range, output_range
4749
type(inference_engine_t) inference_engine_
4850
end function
4951

src/inference_engine/layer_s.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102

103103
end do loop_over_output_neurons
104104

105-
inference_engine_ = inference_engine_t(metadata, weights, biases, nodes)
105+
inference_engine_ = inference_engine_t(metadata, weights, biases, nodes, input_range, output_range)
106106
end block
107107
end associate
108108
end associate

src/inference_engine/trainable_engine_m.F90

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ module trainable_engine_m
3232
procedure :: num_inputs
3333
procedure :: num_outputs
3434
procedure :: to_inference_engine
35-
procedure :: map_to_training_range
36-
procedure :: map_from_training_range
35+
procedure :: map_to_input_training_range
36+
procedure :: map_from_input_training_range
37+
procedure :: map_to_output_training_range
38+
procedure :: map_from_output_training_range
3739
end type
3840

3941
integer, parameter :: input_layer = 0
@@ -123,14 +125,28 @@ module function to_inference_engine(self) result(inference_engine)
123125
type(inference_engine_t) :: inference_engine
124126
end function
125127

126-
elemental module function map_to_training_range(self, tensor) result(normalized_tensor)
128+
elemental module function map_to_input_training_range(self, tensor) result(normalized_tensor)
127129
implicit none
128130
class(trainable_engine_t), intent(in) :: self
129131
type(tensor_t), intent(in) :: tensor
130132
type(tensor_t) normalized_tensor
131133
end function
132134

133-
elemental module function map_from_training_range(self, tensor) result(unnormalized_tensor)
135+
elemental module function map_from_input_training_range(self, tensor) result(unnormalized_tensor)
136+
implicit none
137+
class(trainable_engine_t), intent(in) :: self
138+
type(tensor_t), intent(in) :: tensor
139+
type(tensor_t) unnormalized_tensor
140+
end function
141+
142+
elemental module function map_to_output_training_range(self, tensor) result(normalized_tensor)
143+
implicit none
144+
class(trainable_engine_t), intent(in) :: self
145+
type(tensor_t), intent(in) :: tensor
146+
type(tensor_t) normalized_tensor
147+
end function
148+
149+
elemental module function map_from_output_training_range(self, tensor) result(unnormalized_tensor)
134150
implicit none
135151
class(trainable_engine_t), intent(in) :: self
136152
type(tensor_t), intent(in) :: tensor

src/inference_engine/trainable_engine_s.F90

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
module procedure construct_from_inference_engine
2626

2727
associate(exchange => inference_engine%to_exchange())
28+
trainable_engine%input_range_ = exchange%input_range_
29+
trainable_engine%output_range_ = exchange%output_range_
2830
trainable_engine%metadata_ = exchange%metadata_
2931
trainable_engine%w = exchange%weights_
3032
trainable_engine%b = exchange%biases_
@@ -304,11 +306,19 @@ pure function e(j,n) result(unit_vector)
304306

305307
end procedure
306308

307-
module procedure map_to_training_range
309+
module procedure map_to_input_training_range
308310
normalized_tensor = self%input_range_%map_to_training_range(tensor)
309311
end procedure
310312

311-
module procedure map_from_training_range
313+
module procedure map_from_input_training_range
314+
unnormalized_tensor = self%input_range_%map_from_training_range(tensor)
315+
end procedure
316+
317+
module procedure map_to_output_training_range
318+
normalized_tensor = self%output_range_%map_to_training_range(tensor)
319+
end procedure
320+
321+
module procedure map_from_output_training_range
312322
unnormalized_tensor = self%output_range_%map_from_training_range(tensor)
313323
end procedure
314324

0 commit comments

Comments
 (0)