diff --git a/Readme.md b/Readme.md index 79177e2..e10de90 100644 --- a/Readme.md +++ b/Readme.md @@ -12,8 +12,18 @@ A CNN based pytorch implementation on facial expression recognition (FER2013 and - sklearn (plot confusion matrix) ## Visualize for a test image by a pre-trained model ## -- Firstly, download the pre-trained model from https://drive.google.com/open?id=1Oy_9YmpkSKX1Q8jkOhJbz3Mc7qjyISzU (or https://pan.baidu.com/s/1w2TAWzaqh8YvT-1I6rojAg, key:skg2) and then put it in the "FER2013_VGG19" folder; Next, Put the test image (rename as 1.jpg) into the "images" folder, then -- python visualize.py +- Firstly, download the pre-trained model from https://drive.google.com/open?id=1Oy_9YmpkSKX1Q8jkOhJbz3Mc7qjyISzU (or https://pan.baidu.com/s/1w2TAWzaqh8YvT-1I6rojAg, key:skg2) and then put it in the "FER2013_VGG19" folder; Next, Put the test image (rename as 1.jpg) into the "images" folder, then, if you use a GPU, run + +``` +python visualize.py +``` + +or, if you use a CPU, run + +``` +python visualize.py --cpu +``` + ## FER2013 Dataset ## - Dataset from https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge/data diff --git a/images/results/.gitkeep b/images/results/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/visualize.py b/visualize.py index 5133971..bd43aeb 100644 --- a/visualize.py +++ b/visualize.py @@ -10,12 +10,18 @@ import torch.nn.functional as F import os from torch.autograd import Variable +import argparse import transforms as transforms from skimage import io from skimage.transform import resize from models import * +parser = argparse.ArgumentParser(description='demo script') +parser.add_argument('--cpu', help="use this option if you run this script on cpu", action='store_true') +args = parser.parse_args() +print(args) + cut_size = 44 transform_test = transforms.Compose([ @@ -39,15 +45,23 @@ def rgb2gray(rgb): class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral'] net = 
VGG('VGG19') -checkpoint = torch.load(os.path.join('FER2013_VGG19', 'PrivateTest_model.t7')) + +if args.cpu: + checkpoint = torch.load(os.path.join('FER2013_VGG19', 'PrivateTest_model.t7'), map_location='cpu') +else: + checkpoint = torch.load(os.path.join('FER2013_VGG19', 'PrivateTest_model.t7')) net.load_state_dict(checkpoint['net']) -net.cuda() + net.eval() ncrops, c, h, w = np.shape(inputs) inputs = inputs.view(-1, c, h, w) -inputs = inputs.cuda() + +if not args.cpu: + net.cuda() + inputs = inputs.cuda() + inputs = Variable(inputs, volatile=True) outputs = net(inputs)