Skip to content

Commit 1a90529

Browse files
Merge pull request #57 from renan-siqueira/develop
Merge develop into main
2 parents 475eb64 + ba9ad9f commit 1a90529

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

src/app/training.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def train_model(
8585
lambda_gp = 10 # Default value to WGAN-GP
8686

8787
fixed_noise = Variable(torch.randn(sample_size, z_dim, 1, 1)).to(device)
88-
88+
fid_score = 'Unavailable'
89+
8990
for epoch in range(last_epoch, num_epochs + 1):
9091
start_time = time.time()
9192

@@ -118,8 +119,9 @@ def train_model(
118119
fake_images = generator(z)
119120
real_images = images
120121

121-
# FID Test
122-
fid_score = calculate_fid(real_images, fake_images, inception_model)
122+
if inception_model:
123+
# FID Test
124+
fid_score = calculate_fid(real_images, fake_images, inception_model)
123125

124126
outputs = discriminator(fake_images).squeeze()
125127
g_loss = -torch.mean(outputs)

src/modules/run_training.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ def main(params, path_data, path_dataset, path_train_params):
2323

2424
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2525

26-
# Frechet Inception Distance (FID)
27-
inception_model = models.inception_v3(weights='Inception_V3_Weights.DEFAULT', transform_input=False, init_weights=False).to(device)
28-
inception_model = inception_model.eval()
26+
inception_model = None
27+
if params['image_size'] == 128:
28+
# Frechet Inception Distance (FID)
29+
inception_model = models.inception_v3(weights='Inception_V3_Weights.DEFAULT', transform_input=False, init_weights=False).to(device)
30+
inception_model = inception_model.eval()
2931

3032
generator = Generator(params["z_dim"], params["channels_img"], params["features_g"], img_size=params['image_size']).to(device)
3133
generator.apply(weights_init)

0 commit comments

Comments
 (0)