Description
Hi, I was trying to inject faults with TensorFI2 into different pre-trained Keras models using the ImageNet dataset. I was able to inject faults successfully into VGG16, VGG19, and MobileNet. However, when injecting faults into ResNet50, ResNet101, ResNet152, NASNetMobile, etc., I got the following error:
Traceback (most recent call last):
  File "resnet-imagenet.py", line 42, in <module>
    res = tfi.inject(model=model, x_test=image, confFile=conf)
  File "/home/sabuj/TensorFI2/src/tensorfi2.py", line 42, in __init__
    fiFunc(model, fiConf, **kwargs)
  File "/home/sabuj/TensorFI2/src/tensorfi2.py", line 223, in layer_outputs
    pred = get_pred([fiLayerOutputs])
  File "/home/sabuj/anaconda3/envs/tfi2/lib/python3.7/site-packages/tensorflow/python/keras/backend.py", line 3792, in __call__
    outputs = self._graph_fn(*converted_inputs)
  File "/home/sabuj/anaconda3/envs/tfi2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1605, in __call__
    return self._call_impl(args, kwargs)
  File "/home/sabuj/anaconda3/envs/tfi2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1645, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/home/sabuj/anaconda3/envs/tfi2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/sabuj/anaconda3/envs/tfi2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 598, in call
    ctx=ctx)
  File "/home/sabuj/anaconda3/envs/tfi2/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]
  [[node input_1 (defined at /home/sabuj/TensorFI2/src/tensorfi2.py:222) ]] [Op:__inference_keras_scratch_graph_10471]
The problem is that this error does not occur every time. When the layer chosen for fault injection is early in the network, there is no error. For example, when I injected faults into the first 6 layers of ResNet50, there was no error (to find the cause, I changed the random layer number to a fixed layer number in the TensorFI2 implementation). When injecting faults into later layers, it throws this error most of the time.
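One possible explanation, which is an assumption and not confirmed from the TensorFI2 source: the traceback shows `get_pred([fiLayerOutputs])` being called with a single tensor, so `layer_outputs` mode may run only the layers after the injected one, fed with that one faulty output. That works only if the injected layer's output is the sole tensor the rest of the network needs. In ResNet and NASNet, skip connections let later layers read earlier tensors too, and those trace back to `input_1`, which would explain why Keras asks for that placeholder. The first few ResNet50 layers (input, `conv1_*`, `pool1_*`) form a plain chain before the first residual block, which would be consistent with early layers working. The sketch below checks this "single-tensor cut" property on a hypothetical toy graph (layer names like `block1_main` are made up for illustration):

```python
from typing import Dict, List

# Hypothetical toy graph: each layer maps to the layers it reads from.
# It loosely mimics the start of ResNet50: pool1_pool feeds both the main
# branch and the shortcut of the first residual block, which block1_add joins.
TOY_RESNET: Dict[str, List[str]] = {
    "input_1": [],
    "conv1_conv": ["input_1"],
    "conv1_relu": ["conv1_conv"],
    "pool1_pool": ["conv1_relu"],
    "block1_main": ["pool1_pool"],
    "block1_shortcut": ["pool1_pool"],
    "block1_add": ["block1_main", "block1_shortcut"],
    "predictions": ["block1_add"],
}


def is_valid_cut(graph: Dict[str, List[str]], order: List[str], k: int) -> bool:
    """Return True if layers after position k need, from layers at or before
    position k, only the output of order[k] (order must be topological)."""
    before = set(order[: k + 1])
    cut_layer = order[k]
    for layer in order[k + 1:]:
        for src in graph[layer]:
            # Any other earlier tensor crossing the cut would have to be fed
            # separately, which feeding a single faulty output cannot do.
            if src in before and src != cut_layer:
                return False
    return True


if __name__ == "__main__":
    order = list(TOY_RESNET)  # dicts keep insertion order (Python 3.7+)
    for k, name in enumerate(order):
        print(f"{name:16s} single-tensor cut: {is_valid_cut(TOY_RESNET, order, k)}")
```

In this toy graph, injecting into `block1_main` or `block1_shortcut` fails the check because the other branch still depends on earlier tensors. If this hypothesis holds, a fix would need to feed every tensor that crosses the cut, or apply the fault inside the full model graph instead of re-running only its tail.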
I ran the experiments/layer_outputs/resnet-imagenet.py file for my testing, but simplified it to make testing easier. Here is the code snippet:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
import random
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.resnet import preprocess_input, decode_predictions
import time, sys, math
from src import tensorfi2 as tfi
model = tf.keras.applications.ResNet50(
include_top=True, weights='imagenet', input_tensor=None,
input_shape=None, pooling=None, classes=1000)
model.compile(optimizer='sgd', loss='categorical_crossentropy')
#model.save_weights('h5/vgg16-trained.h5')
path = 'dog.jpg'
image = load_img(path, target_size=(224, 224))
image = img_to_array(image)
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
image = preprocess_input(image)
out = model.predict(image)
label = decode_predictions(out)
label = label[0][0]
print(label)
conf = 'confFiles/sample.yaml'
numFaults = 10
for i in range(numFaults):
    # model.load_weights('h5/resnet-trained.h5')
    image = load_img(path, target_size=(224, 224))
    image = img_to_array(image)
    image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
    image = preprocess_input(image)
    res = tfi.inject(model=model, x_test=image, confFile=conf)
Here, 'dog.jpg' is an arbitrary image I used for testing. The contents of confFiles/sample.yaml are:
Target: layer_outputs
Mode: single
Type: bitflips
Amount: 1
Bit: N
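This config requests a single bit flip (`Type: bitflips`, `Amount: 1`) in one layer's output. As a rough illustration of what one such fault does to a float32 activation, here is a stdlib-only sketch. The bit numbering (0 = lowest mantissa bit, 31 = sign bit) and the reading of `Bit: N` as "pick a random bit" are assumptions, not taken from TensorFI2's code:

```python
import random
import struct
from typing import Optional


def flip_float32_bit(value: float, bit: Optional[int] = None) -> float:
    """Flip one bit of `value` viewed as an IEEE 754 float32.

    bit=0 is the least significant mantissa bit and bit=31 is the sign bit.
    If bit is None, a random bit is chosen (assumed meaning of `Bit: N`).
    """
    if bit is None:
        bit = random.randrange(32)
    if not 0 <= bit < 32:
        raise ValueError("bit must be in [0, 31]")
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    (as_int,) = struct.unpack("<I", struct.pack("<f", value))
    flipped = as_int ^ (1 << bit)
    # Reinterpret the modified bit pattern back as a float32.
    (result,) = struct.unpack("<f", struct.pack("<I", flipped))
    return result


if __name__ == "__main__":
    for b in (0, 23, 30, 31):
        print(b, flip_float32_bit(1.0, b))
```

Flipping a low mantissa bit of 1.0 barely changes it, while flipping bit 30 (the top exponent bit) turns it into infinity, which is one reason a single injected fault can sometimes change a prediction drastically.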
I tried this on both Windows and Ubuntu 20.04, using both conda and pip environments, and found the same issue in every case. Here is my conda environment for reference:
# packages in environment at /home/sabuj/anaconda3/envs/tfi2:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 4.5 1_gnu
_tflow_select 2.1.0 gpu
absl-py 0.15.0 pyhd3eb1b0_0
aiohttp 3.8.1 py37h7f8727e_0
aiosignal 1.2.0 pyhd3eb1b0_0
astor 0.8.1 py37h06a4308_0
astunparse 1.6.3 py_0
async-timeout 4.0.1 pyhd3eb1b0_0
asynctest 0.13.0 py_0
attrs 21.4.0 pyhd3eb1b0_0
blas 1.0 mkl
blinker 1.4 py37h06a4308_0
brotlipy 0.7.0 py37h27cfd23_1003
c-ares 1.18.1 h7f8727e_0
ca-certificates 2021.10.26 h06a4308_2
cachetools 4.2.2 pyhd3eb1b0_0
certifi 2021.10.8 py37h06a4308_2
cffi 1.15.0 py37hd667e15_1
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.0.3 pyhd3eb1b0_0
cryptography 3.4.8 py37hd23ed53_0
cudatoolkit 10.1.243 h6bb024c_0
cudnn 7.6.5 cuda10.1_0
cupti 10.1.168 0
dataclasses 0.8 pyh6d0b6a4_7
frozenlist 1.2.0 py37h7f8727e_0
gast 0.3.3 py_0
google-auth 1.33.0 pyhd3eb1b0_0
google-auth-oauthlib 0.4.4 pyhd3eb1b0_0
google-pasta 0.2.0 pyhd3eb1b0_0
grpcio 1.42.0 py37hce63b2e_0
h5py 2.10.0 py37hd6299e0_1
hdf5 1.10.6 hb1b8bf9_0
idna 3.3 pyhd3eb1b0_0
importlib-metadata 4.8.2 py37h06a4308_0
intel-openmp 2021.4.0 h06a4308_3561
keras-preprocessing 1.1.2 pyhd3eb1b0_0
ld_impl_linux-64 2.35.1 h7274673_9
libffi 3.3 he6710b0_2
libgcc-ng 9.3.0 h5101ec6_17
libgfortran-ng 7.5.0 ha8ba4b0_17
libgfortran4 7.5.0 ha8ba4b0_17
libgomp 9.3.0 h5101ec6_17
libprotobuf 3.19.1 h4ff587b_0
libstdcxx-ng 9.3.0 hd4cf53a_17
markdown 3.3.4 py37h06a4308_0
mkl 2021.4.0 h06a4308_640
mkl-service 2.4.0 py37h7f8727e_0
mkl_fft 1.3.1 py37hd3c417c_0
mkl_random 1.2.2 py37h51133e4_0
multidict 5.2.0 py37h7f8727e_2
ncurses 6.3 h7f8727e_2
numpy 1.21.2 py37h20f2e39_0
numpy-base 1.21.2 py37h79a1101_0
oauthlib 3.1.1 pyhd3eb1b0_0
openssl 1.1.1m h7f8727e_0
opt_einsum 3.3.0 pyhd3eb1b0_1
pillow 9.0.1 pypi_0 pypi
pip 21.2.2 py37h06a4308_0
protobuf 3.19.1 py37h295c915_0
pyasn1 0.4.8 pyhd3eb1b0_0
pyasn1-modules 0.2.8 py_0
pycparser 2.21 pyhd3eb1b0_0
pyjwt 2.1.0 py37h06a4308_0
pyopenssl 21.0.0 pyhd3eb1b0_1
pysocks 1.7.1 py37_1
python 3.7.9 h7579374_0
pyyaml 5.3.1 pypi_0 pypi
readline 8.1.2 h7f8727e_1
requests 2.27.1 pyhd3eb1b0_0
requests-oauthlib 1.3.0 py_0
rsa 4.7.2 pyhd3eb1b0_1
scipy 1.7.3 py37hc147768_0
setuptools 58.0.4 py37h06a4308_0
six 1.16.0 pyhd3eb1b0_0
sqlite 3.37.0 hc218d9a_0
tensorboard 2.4.0 pyhc547734_0
tensorboard-plugin-wit 1.6.0 py_0
tensorflow 2.2.0 gpu_py37h1a511ff_0
tensorflow-base 2.2.0 gpu_py37h8a81be8_0
tensorflow-estimator 2.6.0 pyh7b7c402_0
tensorflow-gpu 2.2.0 h0d30ee6_0
termcolor 1.1.0 py37h06a4308_1
tk 8.6.11 h1ccaba5_0
typing-extensions 3.10.0.2 hd3eb1b0_0
typing_extensions 3.10.0.2 pyh06a4308_0
urllib3 1.26.8 pyhd3eb1b0_0
werkzeug 2.0.2 pyhd3eb1b0_0
wheel 0.37.1 pyhd3eb1b0_0
wrapt 1.13.3 py37h7f8727e_2
xz 5.2.5 h7b6447c_0
yarl 1.6.3 py37h27cfd23_0
zipp 3.7.0 pyhd3eb1b0_0
zlib 1.2.11 h7f8727e_4
Any solution to this error would be very helpful. Thank you.