diff --git a/README.md b/README.md
index a94e34bbd..c0fc28026 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,9 @@ In this repository, the network implementation can be found in
diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@
 librosa>=0.4.3
+Pillow>=2.2.1
diff --git a/test/test_model.py b/test/test_model.py
index a19de6d57..1cf517224 100644
--- a/test/test_model.py
+++ b/test/test_model.py
@@ -9,7 +9,7 @@
 # import librosa
 
 from wavenet import (WaveNetModel, time_to_batch, batch_to_time, causal_conv,
-                     optimizer_factory, mu_law_decode)
+                     optimizer_factory, mu_law_decode, mu_law_encode)
 
 SAMPLE_RATE_HZ = 2000.0  # Hz
 TRAIN_ITERATIONS = 400
@@ -144,7 +144,8 @@ def testEndToEndTraining(self):
         # plt.show()
 
         audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
-        loss = self.net.loss(audio_tensor)
+        encode_output = mu_law_encode(audio_tensor, QUANTIZATION_CHANNELS)
+        loss = self.net.loss(encode_output)
         optimizer = optimizer_factory[self.optimizer_type](
             learning_rate=self.learning_rate, momentum=self.momentum)
         trainable = tf.trainable_variables()
diff --git a/train.py b/train.py
index 7310903b7..dda512f26 100644
--- a/train.py
+++ b/train.py
@@ -17,10 +17,10 @@
 import tensorflow as tf
 from tensorflow.python.client import timeline
 
-from wavenet import WaveNetModel, AudioReader, optimizer_factory
+from wavenet import WaveNetModel, FileReader, optimizer_factory
 
 BATCH_SIZE = 1
-DATA_DIRECTORY = './VCTK-Corpus'
+DATA_DIRECTORY = './data'
 LOGDIR_ROOT = './logdir'
 CHECKPOINT_EVERY = 50
 NUM_STEPS = int(1e5)
@@ -45,9 +45,9 @@ def _str_to_bool(s):
     parser = argparse.ArgumentParser(description='WaveNet example network')
     parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
-                        help='How many wav files to process at once.')
+                        help='How many raw files to process at once.')
     parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,
-                        help='The directory containing the VCTK corpus.')
+                        help='The directory containing the training data.')
     parser.add_argument('--store_metadata', type=bool, default=False,
                         help='Whether to store advanced debugging information '
                         '(execution time, memory consumption) for use with '
@@ -202,19 +202,19 @@ def main():
     # Create coordinator.
     coord = tf.train.Coordinator()
 
-    # Load raw waveform from VCTK corpus.
+    # Load raw waveform files.
     with tf.name_scope('create_inputs'):
-        # Allow silence trimming to be skipped by specifying a threshold near
-        # zero.
-        silence_threshold = args.silence_threshold if args.silence_threshold > \
-                            EPSILON else None
-        reader = AudioReader(
-            args.data_dir,
-            coord,
-            sample_rate=wavenet_params['sample_rate'],
-            sample_size=args.sample_size,
-            silence_threshold=args.silence_threshold)
-        audio_batch = reader.dequeue(args.batch_size)
+        reader = FileReader(
+            args.data_dir,
+            coord,
+            sample_rate=wavenet_params['sample_rate'],
+            sample_size=args.sample_size,
+            silence_threshold=args.silence_threshold,
+            quantization_channels=wavenet_params['quantization_channels'],
+            pattern=wavenet_params['file_ext'],
+            EPSILON=EPSILON,
+            raw_type=wavenet_params['raw_type'])
+        input_batch = reader.dequeue(args.batch_size)
 
     # Create network.
     net = WaveNetModel(
@@ -231,7 +231,7 @@ def main():
         histograms=args.histograms)
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
-    loss = net.loss(audio_batch, args.l2_regularization_strength)
+    loss = net.loss(input_batch, args.l2_regularization_strength)
     optimizer = optimizer_factory[args.optimizer](
         learning_rate=args.learning_rate,
         momentum=args.momentum)
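Note the shape of this change: `WaveNetModel.loss()` no longer mu-law encodes its input (the hunk removing that code from `wavenet/model.py` appears further down), so callers such as the test above must quantize the waveform themselves. As a reference for what `mu_law_encode` computes, here is a minimal NumPy sketch of the companding transform; the name `mu_law_encode_np` is ours, not part of the repository:

```python
import numpy as np

def mu_law_encode_np(audio, quantization_channels=256):
    '''Map float amplitudes in [-1, 1] to integer bins in [0, mu].'''
    mu = quantization_channels - 1
    # Companding: f(x) = sign(x) * ln(1 + mu*|x|) / ln(1 + mu)
    magnitude = np.log1p(mu * np.minimum(np.abs(audio), 1.0)) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    # Shift [-1, 1] onto the integer bins.
    return ((signal + 1) / 2 * mu + 0.5).astype(np.int32)

print(mu_law_encode_np(np.array([-1.0, 0.0, 1.0])))  # -> [  0 128 255]
```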
diff --git a/wavenet/__init__.py b/wavenet/__init__.py
index 33004507b..690e17688 100644
--- a/wavenet/__init__.py
+++ b/wavenet/__init__.py
@@ -1,4 +1,7 @@
 from .model import WaveNetModel
 from .audio_reader import AudioReader
-from .ops import (mu_law_encode, mu_law_decode, time_to_batch,
-                  batch_to_time, causal_conv, optimizer_factory)
+from .text_reader import TextReader
+from .image_reader import ImageReader
+from .ops import (FileReader, mu_law_encode, mu_law_decode, time_to_batch,
+                  batch_to_time, causal_conv, optimizer_factory, write_output,
+                  create_seed_audio)
diff --git a/wavenet/audio_reader.py b/wavenet/audio_reader.py
index a1e6f477f..49ccc70a9 100644
--- a/wavenet/audio_reader.py
+++ b/wavenet/audio_reader.py
@@ -2,13 +2,13 @@
 import os
 import re
 import threading
-
 import librosa
 import numpy as np
 import tensorflow as tf
+from .ops import mu_law_encode
 
 
-def find_files(directory, pattern='*.wav'):
+def find_files(directory, pattern):
     '''Recursively finds all files matching the pattern.'''
     files = []
     for root, dirnames, filenames in os.walk(directory):
@@ -17,9 +17,9 @@ def find_files(directory, pattern='*.wav'):
     return files
 
 
-def load_generic_audio(directory, sample_rate):
+def load_generic_audio(directory, sample_rate, pattern):
     '''Generator that yields audio waveforms from the directory.'''
-    files = find_files(directory)
+    files = find_files(directory, pattern)
     for filename in files:
         audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
         audio = audio.reshape(-1, 1)
@@ -59,8 +59,12 @@ def __init__(self,
                  sample_rate,
                  sample_size=None,
                  silence_threshold=None,
-                 queue_size=256):
+                 quantization_channels=256,
+                 queue_size=256,
+                 pattern='*.wav'):
         self.audio_dir = audio_dir
+        self.pattern = pattern
+        self.quantization_channels = quantization_channels
         self.sample_rate = sample_rate
         self.coord = coord
         self.sample_size = sample_size
@@ -73,21 +77,26 @@ def __init__(self,
         self.enqueue = self.queue.enqueue([self.sample_placeholder])
 
         # TODO Find a better way to check this.
-        # Checking inside the AudioReader's thread makes it hard to terminate
-        # the execution of the script, so we do it in the constructor for now.
-        if not find_files(audio_dir):
+        # Checking inside the AudioReader's thread makes it
+        # hard to terminate the execution of the script, so
+        # we do it in the constructor for now.
+        if not find_files(audio_dir, self.pattern):
             raise ValueError("No audio files found in '{}'.".format(audio_dir))
 
     def dequeue(self, num_elements):
         output = self.queue.dequeue_many(num_elements)
-        return output
+        # We mu-law encode and quantize the input waveform.
+        encode_output = mu_law_encode(output, self.quantization_channels)
+        return encode_output
 
     def thread_main(self, sess):
         buffer_ = np.array([])
         stop = False
         # Go through the dataset multiple times
         while not stop:
-            iterator = load_generic_audio(self.audio_dir, self.sample_rate)
+            iterator = load_generic_audio(self.audio_dir,
+                                          self.sample_rate,
+                                          self.pattern)
             for audio, filename in iterator:
                 if self.coord.should_stop():
                     stop = True
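With `dequeue()` now applying `mu_law_encode`, the tensor a session fetches from `AudioReader` is already an `int32` class index per sample. A minimal driver might look like the sketch below; the `./data` path and the hyperparameter values are placeholders, not repository defaults:

```python
import tensorflow as tf
from wavenet import AudioReader

coord = tf.train.Coordinator()
reader = AudioReader('./data', coord,
                     sample_rate=16000,
                     sample_size=100000,
                     silence_threshold=0.3,
                     quantization_channels=256,
                     pattern='*.wav')
batch = reader.dequeue(1)  # already mu-law encoded to int32 bins

with tf.Session() as sess:
    reader.start_threads(sess)   # background thread fills the queue
    encoded = sess.run(batch)    # shape (1, n_samples, 1)
    coord.request_stop()
```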
diff --git a/wavenet/image_reader.py b/wavenet/image_reader.py
new file mode 100644
index 000000000..d279d4b47
--- /dev/null
+++ b/wavenet/image_reader.py
@@ -0,0 +1,90 @@
+import fnmatch
+import os
+import threading
+
+import numpy as np
+import tensorflow as tf
+from PIL import Image
+
+
+def find_files(directory, pattern='*.jpg'):
+    '''Recursively finds all files matching the pattern.'''
+    files = []
+    for root, dirnames, filenames in os.walk(directory):
+        for filename in fnmatch.filter(filenames, pattern):
+            files.append(os.path.join(root, filename))
+    return files
+
+
+def _read_image(filename):
+    return Image.open(filename).convert('L')
+
+
+def load_generic_image(directory, pattern):
+    '''Generator that yields image data from the directory.'''
+    files = find_files(directory, pattern)
+    for filename in files:
+        pic = _read_image(filename)
+        pic = pic.resize((64, 64), Image.ANTIALIAS)
+        img = np.array(pic, dtype='float32')
+        img = img.reshape(-1, 1)
+        yield img, filename
+
+
+class ImageReader(object):
+    '''Generic background image reader that preprocesses image files
+    and enqueues them into a TensorFlow queue.'''
+
+    def __init__(self,
+                 image_dir,
+                 coord,
+                 sample_size=None,
+                 queue_size=256,
+                 pattern='*.jpg'):
+        self.image_dir = image_dir
+        self.pattern = pattern
+        self.coord = coord
+        self.sample_size = sample_size
+        self.threads = []
+        self.sample_placeholder = tf.placeholder(dtype=tf.float32, shape=None)
+        self.queue = tf.PaddingFIFOQueue(queue_size,
+                                         ['float32'],
+                                         shapes=[(None, 1)])
+        self.enqueue = self.queue.enqueue([self.sample_placeholder])
+
+    def dequeue(self, num_elements):
+        output = self.queue.dequeue_many(num_elements)
+        # Pixel intensities already lie in [0, 255], so a plain cast
+        # yields valid quantization bins.
+        encode_output = tf.cast(output, tf.int32)
+        return encode_output
+
+    def thread_main(self, sess):
+        buffer_ = np.array([])
+        stop = False
+        # Go through the dataset multiple times
+        while not stop:
+            iterator = load_generic_image(self.image_dir, self.pattern)
+            for image, filename in iterator:
+                if self.coord.should_stop():
+                    stop = True
+                    break
+                if self.sample_size:
+                    # Cut samples into fixed size pieces
+                    buffer_ = np.append(buffer_, image)
+                    while len(buffer_) > self.sample_size:
+                        piece = np.reshape(buffer_[:self.sample_size], [-1, 1])
+                        sess.run(self.enqueue,
+                                 feed_dict={self.sample_placeholder: piece})
+                        buffer_ = buffer_[self.sample_size:]
+                else:
+                    sess.run(self.enqueue,
+                             feed_dict={self.sample_placeholder: image})
+
+    def start_threads(self, sess, n_threads=1):
+        for _ in range(n_threads):
+            thread = threading.Thread(target=self.thread_main, args=(sess,))
+            thread.daemon = True  # Thread will close when parent quits.
+            thread.start()
+            self.threads.append(thread)
+        return self.threads
diff --git a/wavenet/model.py b/wavenet/model.py
index 071e4738f..078b23caf 100644
--- a/wavenet/model.py
+++ b/wavenet/model.py
@@ -1,6 +1,6 @@
 import tensorflow as tf
 
-from .ops import causal_conv, mu_law_encode
+from .ops import causal_conv
 
 
 def create_variable(name, shape):
@@ -476,10 +476,6 @@ def loss(self,
         The variables are all scoped to the given name.
         '''
         with tf.name_scope(name):
-            # We mu-law encode and quantize the input audioform.
-            input_batch = mu_law_encode(input_batch,
-                                        self.quantization_channels)
-
             encoded = self._one_hot(input_batch)
             if self.scalar_input:
                 network_input = tf.reshape(
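`ImageReader` flattens every picture into the same `(samples, 1)` column layout the audio pipeline uses, which is what later lets `write_img()` in `wavenet/ops.py` rebuild a 64x64 grid. A round-trip sketch, assuming a hypothetical `example.jpg`:

```python
import numpy as np
from PIL import Image

# Forward pass of the reader: grayscale, 64x64, flatten to a column.
pic = Image.open('example.jpg').convert('L').resize((64, 64), Image.ANTIALIAS)
flat = np.array(pic, dtype='float32').reshape(-1, 1)
assert flat.shape == (64 * 64, 1)  # 4096 pixel "samples"

# Inverse (what write_img does): reshape back into an image grid.
restored = Image.fromarray(flat.reshape(64, 64).astype('uint8'))
restored.convert('RGB').save('restored.jpg')
```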
diff --git a/wavenet/ops.py b/wavenet/ops.py
index 682dee7a5..5c5bf81ac 100644
--- a/wavenet/ops.py
+++ b/wavenet/ops.py
@@ -1,6 +1,87 @@
 from __future__ import division
-
+import librosa
+import numpy as np
 import tensorflow as tf
+from PIL import Image
+
+
+def FileReader(data_dir, coord, sample_rate, sample_size,
+               silence_threshold, quantization_channels,
+               pattern, EPSILON=0.001, raw_type="Audio"):
+    # The readers are imported here rather than at module level to
+    # avoid a circular import: audio_reader itself imports from ops.
+    from .audio_reader import AudioReader
+    from .text_reader import TextReader
+    from .image_reader import ImageReader
+    if raw_type == "Audio":
+        # Allow silence trimming to be skipped by specifying a
+        # threshold near zero.
+        silence_threshold = silence_threshold if silence_threshold > \
+            EPSILON else None
+        reader = AudioReader(data_dir, coord, sample_rate=sample_rate,
+                             sample_size=sample_size,
+                             silence_threshold=silence_threshold,
+                             quantization_channels=quantization_channels,
+                             pattern=pattern)
+    elif raw_type == "Text":
+        reader = TextReader(data_dir, coord,
+                            sample_size=sample_size,
+                            pattern=pattern)
+    elif raw_type == "Image":
+        reader = ImageReader(data_dir, coord,
+                             sample_size=sample_size,
+                             pattern=pattern)
+    else:
+        raise ValueError("Unknown raw_type: {}".format(raw_type))
+    return reader
+
+
+def write_output(waveform, filename, sample_rate, raw_type="Audio"):
+    if raw_type == "Image":
+        write_img(waveform, filename)
+    elif raw_type == "Text":
+        write_text(waveform, filename)
+    else:
+        write_wav(waveform, sample_rate, filename)
+
+
+def write_img(waveform, filename):
+    # Drop the trailing sample and restore the 64x64 grayscale grid.
+    img = np.array(waveform[:-1], dtype='uint8')
+    img = img.reshape(64, 64)
+    new_img = Image.fromarray(img)
+    new_img = new_img.convert('RGB')
+    new_img.save(filename)
+    print('Updated image file at {}'.format(filename))
+
+
+def write_text(waveform, filename):
+    # Map each predicted sample back to its character.
+    chars = [chr(sample) for sample in waveform]
+    print('Prediction is: ', ''.join(chars))
+    y = np.array(chars)
+    np.savetxt(filename, y.reshape(1, y.shape[0]),
+               delimiter="", newline="\n", fmt="%s")
+    print('Updated text file at {}'.format(filename))
+
+
+def write_wav(waveform, sample_rate, filename):
+    y = np.array(waveform)
+    librosa.output.write_wav(filename, y, sample_rate)
+    print('Updated wav file at {}'.format(filename))
+
+
+def create_seed_audio(filename,
+                      sample_rate,
+                      quantization_channels,
+                      window_size=8000,
+                      silence_threshold=0.1):
+    # Deferred import, again to avoid the ops <-> audio_reader cycle.
+    from .audio_reader import trim_silence
+    audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
+    audio = trim_silence(audio, silence_threshold)
+    quantized = mu_law_encode(audio, quantization_channels)
+    cut_index = tf.cond(tf.size(quantized) < tf.constant(window_size),
+                        lambda: tf.size(quantized),
+                        lambda: tf.constant(window_size))
+    return quantized[:cut_index]
 
 
 def create_adam_optimizer(learning_rate, momentum):
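The text path treats Unicode code points as quantization bins: `TextReader` (next file) feeds `ord()` values into the queue, and `write_text()` above maps predictions back through `chr()`. A self-contained round-trip sketch:

```python
import numpy as np

# TextReader's encoding: one code point per "sample", as a float column.
text = list("hello")
codes = np.array([ord(c) for c in text], dtype='float32').reshape(-1, 1)

# write_text's decoding: cast back to int and map through chr().
decoded = ''.join(chr(int(v)) for v in codes.flatten())
assert decoded == 'hello'
```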
diff --git a/wavenet/text_reader.py b/wavenet/text_reader.py
new file mode 100644
index 000000000..ce29dffb9
--- /dev/null
+++ b/wavenet/text_reader.py
@@ -0,0 +1,89 @@
+import fnmatch
+import os
+import threading
+
+import numpy as np
+import tensorflow as tf
+
+
+def find_files(directory, pattern):
+    '''Recursively finds all files matching the pattern.'''
+    files = []
+    for root, dirnames, filenames in os.walk(directory):
+        for filename in fnmatch.filter(filenames, pattern):
+            files.append(os.path.join(root, filename))
+    return files
+
+
+def _read_text(filename):
+    with tf.gfile.GFile(filename, "r") as f:
+        return list(f.read().decode("utf-8").replace("\n", ""))
+
+
+def load_generic_text(directory, pattern):
+    '''Generator that yields raw text from the directory.'''
+    files = find_files(directory, pattern)
+    for filename in files:
+        # Encode each character as its Unicode code point.
+        text = [ord(char) for char in _read_text(filename)]
+        text = np.array(text, dtype='float32')
+        text = text.reshape(-1, 1)
+        yield text, filename
+
+
+class TextReader(object):
+    '''Generic background text reader that preprocesses text files
+    and enqueues them into a TensorFlow queue.'''
+
+    def __init__(self,
+                 text_dir,
+                 coord,
+                 sample_size=None,
+                 queue_size=256,
+                 pattern='*.txt'):
+        self.text_dir = text_dir
+        self.pattern = pattern
+        self.coord = coord
+        self.sample_size = sample_size
+        self.threads = []
+        self.sample_placeholder = tf.placeholder(dtype=tf.float32, shape=None)
+        self.queue = tf.PaddingFIFOQueue(queue_size,
+                                         ['float32'],
+                                         shapes=[(None, 1)])
+        self.enqueue = self.queue.enqueue([self.sample_placeholder])
+
+    def dequeue(self, num_elements):
+        output = self.queue.dequeue_many(num_elements)
+        # Code points are already integer bins; a plain cast suffices.
+        encode_output = tf.cast(output, tf.int32)
+        return encode_output
+
+    def thread_main(self, sess):
+        buffer_ = np.array([])
+        stop = False
+        # Go through the dataset multiple times
+        while not stop:
+            iterator = load_generic_text(self.text_dir, self.pattern)
+            for text, filename in iterator:
+                if self.coord.should_stop():
+                    stop = True
+                    break
+                if self.sample_size:
+                    # Cut samples into fixed size pieces
+                    buffer_ = np.append(buffer_, text)
+                    while len(buffer_) > self.sample_size:
+                        piece = np.reshape(buffer_[:self.sample_size], [-1, 1])
+                        sess.run(self.enqueue,
+                                 feed_dict={self.sample_placeholder: piece})
+                        buffer_ = buffer_[self.sample_size:]
+                else:
+                    sess.run(self.enqueue,
+                             feed_dict={self.sample_placeholder: text})
+
+    def start_threads(self, sess, n_threads=1):
+        for _ in range(n_threads):
+            thread = threading.Thread(target=self.thread_main, args=(sess,))
+            thread.daemon = True  # Thread will close when parent quits.
+            thread.start()
+            self.threads.append(thread)
+        return self.threads
diff --git a/wavenet_params.json b/wavenet_params.json
index ab63f509f..90c63155d 100644
--- a/wavenet_params.json
+++ b/wavenet_params.json
@@ -12,5 +12,7 @@
     "skip_channels": 512,
     "use_biases": true,
     "scalar_input": false,
-    "initial_filter_width": 32
+    "initial_filter_width": 32,
+    "raw_type": "Audio",
+    "file_ext": "*.wav"
 }
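With `raw_type` and `file_ext` in `wavenet_params.json`, switching modalities becomes a configuration change. The sketch below shows how these parameters reach the `FileReader` factory, mirroring the call in `train.py`; the data path and size values are placeholders:

```python
import json
import tensorflow as tf
from wavenet import FileReader

with open('wavenet_params.json', 'r') as f:
    params = json.load(f)

coord = tf.train.Coordinator()
# E.g. set "raw_type": "Text" and "file_ext": "*.txt" to train on text.
reader = FileReader('./data', coord,
                    sample_rate=params['sample_rate'],
                    sample_size=100000,
                    silence_threshold=0.3,
                    quantization_channels=params['quantization_channels'],
                    pattern=params['file_ext'],
                    EPSILON=0.001,
                    raw_type=params['raw_type'])
input_batch = reader.dequeue(1)
```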