-
Notifications
You must be signed in to change notification settings - Fork 171
Description
Do you have an example script for training the genomics model? I am trying to apply this approach more broadly and am starting by replicating your example, but the weights I end up with are not correct. Any help would be much appreciated!
In case it helps, here is the code I wrote (adapted from yours) to train the model:
#####################################
# Training Genomics Model
#####################################
from __future__ import print_function

import tensorflow
import keras
import numpy as np
from tensorflow.keras.models import model_from_json
import simdna.synthetic as synthetic

# The version strings live in the dunder ``__version__`` attributes; the bare
# ``.version`` names are not version strings (where they exist at all they are
# submodules or functions), so printing them shows an object repr instead.
print("Tensorflow version:", tensorflow.__version__)
# NOTE(review): in the visible script, standalone ``keras`` is only used for
# this version print -- the model itself is built with ``tensorflow.keras``.
# Confirm the two packages are compatible versions if both are installed.
print("Keras version:", keras.__version__)
print("Numpy version:", np.__version__)
#####################################
# Import Model Architecture from the original DeepLift code
#####################################
# Rebuild the network from its JSON description. Only the architecture is read
# here; no weights are loaded anywhere in this section.
keras_model_json = "keras2_conv1d_record_5_model_PQzyq_modelJson.json"
# Use a context manager so the JSON file handle is closed promptly.
with open(keras_model_json) as json_fh:
    keras_model = model_from_json(json_fh.read())
keras_model_config = keras_model.get_config()
# ``from_config`` is a classmethod: call it on the class rather than on a
# throwaway ``Sequential()`` instance (the result is the same model).
# NOTE(review): assumes the JSON describes a Sequential model -- confirm.
model_empty = tensorflow.keras.Sequential.from_config(keras_model_config)
#####################################
# Convert Training Set to One Hot Encoding
#####################################
def one_hot_encode_along_channel_axis(sequence):
    """Return an int8 one-hot matrix of shape ``(len(sequence), 4)``.

    Row ``i`` encodes ``sequence[i]``; the four columns are the base channels
    filled in by ``seq_to_one_hot_fill_in_array`` (A, C, G, T in that order).
    Positions holding ``N``/``n`` stay all-zero; any other unsupported
    character makes the helper raise RuntimeError.
    """
    encoded = np.zeros((len(sequence), 4), dtype=np.int8)
    # Channel axis is 1, i.e. the helper writes encoded[position, base] = 1
    # in place.
    seq_to_one_hot_fill_in_array(encoded, sequence, 1)
    return encoded
def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    """Write a one-hot encoding of ``sequence`` into ``zeros_array`` in place.

    Args:
        zeros_array: pre-allocated array indexed by two coordinates (expected
            to start out all zeros -- only ones are ever written). With
            ``one_hot_axis == 1`` it is indexed ``[position, base]``; with
            ``one_hot_axis == 0`` it is indexed ``[base, position]``.
        sequence: iterable of characters (typically a DNA ``str``); A/C/G/T/N
            are accepted in either case.
        one_hot_axis: 0 or 1, the axis that holds the four base channels.

    Base channels are ordered A=0, C=1, G=2, T=3. ``N``/``n`` leaves the
    position untouched (all zeros). Returns None.

    Raises:
        AssertionError: if ``one_hot_axis`` is not 0/1, or the array's length
            along the position axis differs from ``len(sequence)``.
        RuntimeError: on any other character; positions before it have
            already been written.
    """
    assert one_hot_axis == 0 or one_hot_axis == 1
    # The position axis is whichever one is not the channel axis.
    if one_hot_axis == 0:
        assert zeros_array.shape[1] == len(sequence)
    elif one_hot_axis == 1:
        assert zeros_array.shape[0] == len(sequence)

    # Tuple membership compares with ==, exactly like the explicit checks.
    channel_spellings = (("A", "a"), ("C", "c"), ("G", "g"), ("T", "t"))
    for position, base in enumerate(sequence):
        matches = [idx for idx, spellings in enumerate(channel_spellings)
                   if base in spellings]
        if matches:
            channel = matches[0]
        elif base in ("N", "n"):
            continue  # ambiguous base: leave this position all zeros
        else:
            raise RuntimeError("Unsupported character: " + str(base))

        if one_hot_axis == 0:
            zeros_array[channel, position] = 1
        elif one_hot_axis == 1:
            zeros_array[position, channel] = 1
# Read in the sequences and labels for the ids listed in the id file.
data_filename = "sequences.simdata"
# NOTE(review): the handle name and the comment above say "train", but the id
# list comes from "test.txt". Confirm this is the intended split -- training on
# the test ids leaves no held-out data to evaluate against.
# A context manager ensures the id file is closed after reading.
with open("test.txt", "r") as train_ids_fh:
    ids_to_load = [x.rstrip("\n") for x in train_ids_fh]
# Per the original note, read_simdata_file returns a record exposing ids,
# sequences, embeddings and labels; only .sequences and .labels are used here.
data = synthetic.read_simdata_file(data_filename, ids_to_load=ids_to_load)
onehot_data = np.array([one_hot_encode_along_channel_axis(seq)
                        for seq in data.sequences])
#####################################
# Train Model
#####################################
# Keras' Model.fit defaults to a single epoch, which is far too little for the
# network to converge -- the original call left the saved weights essentially
# untrained. NOTE(review): these values are starting points; tune them (and
# ideally monitor a held-out split) for your data.
NUM_EPOCHS = 20
BATCH_SIZE = 128

# Hand Keras an explicit numeric array rather than raw Python lists.
labels = np.asarray(data.labels, dtype=np.float32)

# NOTE(review): the labels are presumably 0/1 task indicators; for those,
# binary cross-entropy with Adam trains far more reliably than MSE with plain
# SGD. This assumes the model's output layer is sigmoid-activated -- confirm
# against the model JSON.
model_empty.compile(loss="binary_crossentropy", optimizer="adam")
model_empty.fit(onehot_data, labels, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS)
model_empty.save_weights("new_model.h5", save_format='h5')