From 5e0a70e6ab41cd9967e981e7c2cf0fbb0a9a853a Mon Sep 17 00:00:00 2001 From: Andrei Popescu Date: Wed, 23 Aug 2023 22:46:53 +0300 Subject: [PATCH] Added explicit parameters args for model init --- example.py | 33 ++++++++ pydiar/models/binary_key/__init__.py | 77 +++++++++++++------ .../models/binary_key/diarizationFunctions.py | 15 +--- pydiar/util/sad.py | 5 +- 4 files changed, 90 insertions(+), 40 deletions(-) create mode 100644 example.py diff --git a/example.py b/example.py new file mode 100644 index 0000000..bcea993 --- /dev/null +++ b/example.py @@ -0,0 +1,33 @@ +from pydiar.models import BinaryKeyDiarizationModel, Segment +from pydiar.util.misc import optimize_segments +from pydub import AudioSegment +import numpy as np + +if __name__ == "__main__": + INPUT_FILE = "./" + OUTLIER_FILE = "./" + + sample_rate = 32000 + audio = AudioSegment.from_wav(INPUT_FILE) + audio = audio.set_frame_rate(sample_rate) + audio = audio.set_channels(1) + + outlier = AudioSegment.from_wav(OUTLIER_FILE) + outlier = outlier.set_frame_rate(sample_rate) + outlier = outlier.set_channels(1) + + # combine audio and outlier + audio = audio + outlier + + diarization_model = BinaryKeyDiarizationModel(clustering_selection_max_speakers=2) + + segments = diarization_model.diarize( + sample_rate, np.array(audio.get_array_of_samples()) + ) + optimized_segments = optimize_segments(segments) + + unique_speakers = set() + for segment in optimized_segments: + unique_speakers.add(segment.speaker_id) + + print(f"Number of speakers: {len(unique_speakers)}") diff --git a/pydiar/models/binary_key/__init__.py b/pydiar/models/binary_key/__init__.py index afaf258..6470bbd 100644 --- a/pydiar/models/binary_key/__init__.py +++ b/pydiar/models/binary_key/__init__.py @@ -27,37 +27,67 @@ class BinaryKeyDiarizationModel(DiarizationModel): This implementation is heavily based on https://github.com/josepatino/pyBK """ - def __init__(self): - self.__init_parameters() + def __init__( + self, + framelength=0.025, + frameshift=0.01, + nfilters=30, + ncoeff=30, + segment_length=100, + segment_shift=100, + segment_overlap=100, + kbm_max_window_shift=50, + kbm_window_length=200, + kbm_min_gaussians=1024, + kbm_size_rel=0.1, + top_gaussians_per_frame=5, + initial_clusters=16, + bk_one_percentage=0.2, + clustering_metric="cosine", + clustering_selection_metric="cosine", + clustering_selection_max_speakers=16, + resegmentation_model_size=6, + resegmentation_nb_iter=10, + resegmentation_smooth_win=100, + ): + self.FRAMELENGTH = framelength + self.FRAMESHIFT = frameshift + self.NFILTERS = nfilters + self.NCOEFF = ncoeff - def __init_parameters(self): - self.FRAMELENGTH = 0.025 - self.FRAMESHIFT = 0.01 - self.NFILTERS = 30 - self.NCOEFF = 30 + self.SEGMENT_LENGTH = segment_length + self.SEGMENT_SHIFT = segment_shift + self.SEGMENT_OVERLAP = segment_overlap - self.SEGMENT_LENGTH = 100 - self.SEGMENT_SHIFT = 100 - self.SEGMENT_OVERLAP = 100 + self.KBM_MAX_WINDOW_SHIFT = kbm_max_window_shift + self.KBM_WINDOW_LENGTH = kbm_window_length + self.KBM_MIN_GAUSSIANS = kbm_min_gaussians - self.KBM_MAX_WINDOW_SHIFT = 50 - self.KBM_WINDOW_LENGTH = 200 - self.KBM_MIN_GAUSSIANS = 1024 + self.KBM_SIZE_REL = kbm_size_rel - self.KBM_SIZE_REL = 0.1 + self.TOP_GAUSSIANS_PER_FRAME = top_gaussians_per_frame - self.TOP_GAUSSIANS_PER_FRAME = 5 + self.INITIAL_CLUSTERS = initial_clusters + self.BK_ONE_PERCENTAGE = bk_one_percentage - self.INITIAL_CLUSTERS = 16 - self.BK_ONE_PERCENTAGE = 0.2 + if clustering_metric != "cosine": + raise ValueError( + "Only `cosine` distance is supported for clustering metric." + ) + + self.CLUSTERING_METRIC = clustering_metric + + if clustering_selection_metric != "cosine": + raise ValueError( + "Only `cosine` distance is supported for clustering selection metric." + ) - self.CLUSTERING_METRIC = "cosine" self.CLUSTERING_SELECTION_METRIC = "cosine" - self.CLUSTERING_SELECTION_MAX_SPEAKERS = 16 + self.CLUSTERING_SELECTION_MAX_SPEAKERS = clustering_selection_max_speakers - self.RESEGMENTATION_MODEL_SIZE = 6 - self.RESEGMENTATION_NB_ITER = 10 - self.RESEGMENTATION_SMOOTH_WIN = 100 + self.RESEGMENTATION_MODEL_SIZE = resegmentation_model_size + self.RESEGMENTATION_NB_ITER = resegmentation_nb_iter + self.RESEGMENTATION_SMOOTH_WIN = resegmentation_smooth_win def _extract_features(self, sample_rate, signal): """ @@ -167,7 +197,6 @@ def _binary_processing( initialClustering, self.CLUSTERING_METRIC, ) - logging.info("Finding best clustering") bestClusteringID = getBestClustering( @@ -178,7 +207,9 @@ def _binary_processing( k, self.CLUSTERING_SELECTION_MAX_SPEAKERS, ).astype(int) + best_clustering = finalClusteringTable[:, bestClusteringID] + logging.info( f"Best: {bestClusteringID} with " f"{np.size(np.unique(best_clustering), 0)} clusters" diff --git a/pydiar/models/binary_key/diarizationFunctions.py b/pydiar/models/binary_key/diarizationFunctions.py index 234f76b..e170355 100644 --- a/pydiar/models/binary_key/diarizationFunctions.py +++ b/pydiar/models/binary_key/diarizationFunctions.py @@ -251,6 +251,7 @@ def performClustering( ####### highest similarity, creating a new signature for the resulting ####### cluster ####### 4. Back to 1 if #clusters > 1 + for k in range(N_init): ####### 1. Data reassignment. Calculate the similarity between the current segment with all clusters and assign it to the one which maximizes ####### the similarity. Finally re-calculate binaryKeys for all cluster @@ -299,15 +300,6 @@ def performClustering( location = np.nanargmax(clusterSimilarityMatrix) R, C = np.unravel_index(location, (N_init, N_init)) ### Then we merge clusters R and C - # logging.info('Merging clusters',R+1,'and',C+1,'with a similarity score of',np.around(value,decimals=4)) - logging.info( - "Merging clusters", - "%3s" % str(R + 1), - "and", - "%3s" % str(C + 1), - "with a similarity score of", - np.around(value, decimals=4), - ) activeClusters[0, C] = 0 ### 3. Save the resulting clustering and go back to 1 if the number of clusters >1 mergingClusteringIndices = np.where(clusteringTable[:, k] == C + 1) @@ -386,7 +378,6 @@ def binaryKeySimilarity_cdist(clusteringMetric, bkT1, cvT1, bkT2, cvT2): def getBestClustering( bestClusteringMetric, bkT, cvT, clusteringTable, n, maxNrSpeakers ): - wss = np.zeros([1, n]) overallMean = np.mean(cvT, 0) if bestClusteringMetric == "cosine": @@ -443,15 +434,12 @@ def getBestClustering( vecToLine = vecFromFirst - vecFromFirstParallel distToLine = np.sqrt(np.sum(np.square(vecToLine), axis=1)) bestClusteringID = allCoord[np.argmax(distToLine)][0] - print(allCoord, distToLine, bestClusteringID) # Select best clustering that matches max speaker limit nrSpeakersPerSolution = np.zeros((clusteringTable.shape[1])) for k in np.arange(clusteringTable.shape[1]): nrSpeakersPerSolution[k] = np.size(np.unique(clusteringTable[:, k])) firstAllowedClustering = np.min(np.where(nrSpeakersPerSolution <= maxNrSpeakers)) - print(f"{nrSpeakersPerSolution=}") - print(f"{firstAllowedClustering=}") # Note: clusters are ordered from most clusters to least, so this selects the bestClusteringID # unless it has more than maxNrSpeakers nodes, in which case it selects firstAllowedClustering bestClusteringID = np.maximum( @@ -472,7 +460,6 @@ def performResegmentation( smoothWin, numberOfSpeechFeatures, ): - np.random.seed(0) changePoints, segBeg, segEnd, nSegs = unravelMask(mask) diff --git a/pydiar/util/sad.py b/pydiar/util/sad.py index f01272b..c212a42 100644 --- a/pydiar/util/sad.py +++ b/pydiar/util/sad.py @@ -8,7 +8,6 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0): - # from librosa.core import resample # from librosa.util import frame """Voice activity detection. @@ -55,7 +54,7 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0): # check data if data.dtype.kind == "i": - if data.max() > 2 ** 15 - 1 or data.min() < -(2 ** 15): + if data.max() > 2**15 - 1 or data.min() < -(2**15): raise ValueError( "when data type is int, data must be -32768 < data < 32767." ) @@ -65,7 +64,7 @@ def py_webrtcvad(data, fs, fs_vad, hoplength=30, vad_mode=0): if np.abs(data).max() >= 1: data = data / np.abs(data).max() * 0.9 warnings.warn("input data was rescaled.") - data = (data * 2 ** 15).astype("f") + data = (data * 2**15).astype("f") else: raise ValueError("data dtype must be int or float.")