Skip to content

Commit 777c738

Browse files
authored
Implement embedding extraction for LargeRecordingAnalyzer (#122)
* Implement embedding extraction for LargeRecordingAnalyzer
* Ensure the tests use the expected release of BirdNET for cmdline comparison
1 parent 505c99c commit 777c738

File tree

7 files changed

+79
-37
lines changed

7 files changed

+79
-37
lines changed

.github/workflows/publish.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ jobs:
3535
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
3636
- name: Test with pytest
3737
run: |
38-
pytest -m "not omit_during_ghactions"
38+
git -C tests/BirdNET-Analyzer checkout 98945574c68102ccfac6c3504fcc63e64ed6f9e3
39+
pytest -m "not omit_during_ghactions" --maxfail=1
3940
deploy:
4041
runs-on: ubuntu-latest
4142
needs: [test]

.github/workflows/test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ jobs:
3737
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
3838
- name: Test with pytest
3939
run: |
40-
pytest -m "not omit_during_ghactions"
40+
git -C tests/BirdNET-Analyzer checkout 98945574c68102ccfac6c3504fcc63e64ed6f9e3
41+
pytest -m "not omit_during_ghactions" --maxfail=1

src/birdnetlib/analyzer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,3 +568,27 @@ def analyze_recording(self, recording):
568568

569569
self.results = results
570570
recording.detection_list = self.detections
571+
572+
def extract_embeddings_for_recording(self, recording):
573+
print("extract_embeddings_for_recording", recording.filename)
574+
start = 0
575+
end = recording.sample_secs
576+
results = []
577+
for segment in read_audio_segments(recording.path, sr=48000):
578+
c = segment["segment"]
579+
if len(c) < recording.sample_secs * 48000:
580+
# If below the minimum segment duration, continue.
581+
del c
582+
continue
583+
start = segment["start_sec"]
584+
end = segment["end_sec"]
585+
586+
data = np.array([c], dtype="float32")
587+
e = self._return_embeddings(data)[0].tolist()
588+
results.append({"start_time": start, "end_time": end, "embeddings": e})
589+
590+
# Increment start and end
591+
start += recording.sample_secs - recording.overlap
592+
end = start + recording.sample_secs
593+
594+
self.embeddings = results

src/birdnetlib/main.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,6 @@ def analyze(self):
7373
self.analyzed = True
7474

7575
def extract_embeddings(self):
76-
# Check that analyzer is not LargeRecordingAnalyzer
77-
if isinstance(self.analyzer, LargeRecordingAnalyzer):
78-
raise IncompatibleAnalyzerError(
79-
"LargeRecordingAnalyzer can only be used with the LargeRecording class"
80-
)
81-
8276
# Read and analyze.
8377
self.read_audio_data()
8478
self.analyzer.extract_embeddings_for_recording(self)
@@ -480,9 +474,9 @@ def analyze(self):
480474
self.analyzed = True
481475

482476
def extract_embeddings(self):
483-
raise NotImplementedError(
484-
"Extraction of embeddings is not yet implemented for LargeRecordingAnalyzer. Use Analyzer if possible."
485-
)
477+
self.analyzer.extract_embeddings_for_recording(self)
478+
self.embeddings_list = self.analyzer.embeddings
479+
self.embeddings_extracted = True
486480

487481
def get_extract_array(self, start_sec, end_sec):
488482
# Returns ndarray trimmed for start_sec:end_sec

tests/test_analyzer_versioning.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ def run_before_and_after_tests():
2727
clean_up_temp_installed_versions()
2828
# Restore main branch for BirdNET-Analyzer to origin/main.
2929
birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")
30-
os.system(f"cd {birdnet_analyzer_path}; git clean -fd; git switch main; git status")
30+
os.system(
31+
f"cd {birdnet_analyzer_path}; git clean -fd; git checkout 98945574c68102ccfac6c3504fcc63e64ed6f9e3; git status"
32+
)
3133

3234

3335
def clean_up_temp_installed_versions():

tests/test_buffer_analyzer.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from unittest.mock import patch
1111
import io
1212

13-
def test_without_species_list():
13+
TEST_BN_COMMIT = "98945574c68102ccfac6c3504fcc63e64ed6f9e3"
14+
1415

16+
def test_without_species_list():
1517
# Process file with command line utility, then process with python library and ensure equal commandline_results.
1618

1719
lon = -120.7463
@@ -24,6 +26,9 @@ def test_without_species_list():
2426

2527
# Process using python script as is.
2628
birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")
29+
os.system(
30+
f"cd {birdnet_analyzer_path}; git clean -fd; git checkout {TEST_BN_COMMIT}; git status"
31+
)
2732

2833
cmd = f"python analyze.py --i '{input_path}' --o={output_path} --lat {lat} --lon {lon} --week {week_48} --min_conf {min_conf} --rtype=csv"
2934
print(cmd)
@@ -50,10 +55,10 @@ def test_without_species_list():
5055

5156
# pprint(commandline_results)
5257
assert len(commandline_results) > 0
53-
with open(input_path,'rb') as f:
58+
with open(input_path, "rb") as f:
5459
wav_buffer = f.read()
5560
bytes_buffer = io.BytesIO(wav_buffer)
56-
for rate,buffer in wavutils.bufferwavs(bytes_buffer):
61+
for rate, buffer in wavutils.bufferwavs(bytes_buffer):
5762
analyzer = Analyzer()
5863
recording = RecordingBuffer(
5964
analyzer,
@@ -78,7 +83,6 @@ def test_without_species_list():
7883

7984

8085
def test_with_species_list_path():
81-
8286
# Process file with command line utility, then process with python library and ensure equal commandline_results.
8387

8488
lon = -120.7463
@@ -95,6 +99,9 @@ def test_with_species_list_path():
9599

96100
# Process using python script as is.
97101
birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")
102+
os.system(
103+
f"cd {birdnet_analyzer_path}; git clean -fd; git checkout {TEST_BN_COMMIT}; git status"
104+
)
98105

99106
cmd = f"python analyze.py --i '{input_path}' --o={output_path} --min_conf {min_conf} --slist {custom_list_path} --rtype=csv"
100107
os.system(f"cd {birdnet_analyzer_path}; {cmd}")
@@ -120,10 +127,10 @@ def test_with_species_list_path():
120127

121128
pprint(commandline_results)
122129
assert len(commandline_results) > 0
123-
with open(input_path,'rb') as f:
130+
with open(input_path, "rb") as f:
124131
wav_buffer = f.read()
125132
bytes_buffer = io.BytesIO(wav_buffer)
126-
for rate,buffer in wavutils.bufferwavs(bytes_buffer):
133+
for rate, buffer in wavutils.bufferwavs(bytes_buffer):
127134
analyzer = Analyzer(custom_species_list_path=custom_list_path)
128135
recording = RecordingBuffer(
129136
analyzer,
@@ -139,7 +146,8 @@ def test_with_species_list_path():
139146
pprint(recording.detections)
140147

141148
assert (
142-
commandline_results[0]["common_name"] == recording.detections[0]["common_name"]
149+
commandline_results[0]["common_name"]
150+
== recording.detections[0]["common_name"]
143151
)
144152

145153
commandline_birds = [i["common_name"] for i in commandline_results]
@@ -151,11 +159,11 @@ def test_with_species_list_path():
151159
len(analyzer.custom_species_list) == 41
152160
) # Check that this matches the number printed by the cli version.
153161

154-
with open(input_path,'rb') as f:
162+
with open(input_path, "rb") as f:
155163
wav_buffer = f.read()
156164
bytes_buffer = io.BytesIO(wav_buffer)
157165
# Run a recording without path and throw an error when used with custom species list.
158-
for rate,buffer in wavutils.bufferwavs(bytes_buffer):
166+
for rate, buffer in wavutils.bufferwavs(bytes_buffer):
159167
with pytest.raises(ValueError):
160168
recording = RecordingBuffer(
161169
analyzer,
@@ -170,7 +178,6 @@ def test_with_species_list_path():
170178

171179

172180
def test_with_species_list():
173-
174181
# Process file with command line utility, then process with python library and ensure equal commandline_results.
175182

176183
lon = -120.7463
@@ -187,6 +194,9 @@ def test_with_species_list():
187194

188195
# Process using python script as is.
189196
birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")
197+
os.system(
198+
f"cd {birdnet_analyzer_path}; git clean -fd; git checkout {TEST_BN_COMMIT}; git status"
199+
)
190200

191201
cmd = f"python analyze.py --i '{input_path}' --o={output_path} --min_conf {min_conf} --slist {custom_list_path} --rtype=csv"
192202
os.system(f"cd {birdnet_analyzer_path}; {cmd}")
@@ -257,11 +267,11 @@ def test_with_species_list():
257267
"Zonotrichia albicollis_White-throated Sparrow",
258268
]
259269

260-
with open(input_path,'rb') as f:
270+
with open(input_path, "rb") as f:
261271
wav_buffer = f.read()
262272
bytes_buffer = io.BytesIO(wav_buffer)
263273
# Run a recording without path and throw an error when used with custom species list.
264-
for rate,buffer in wavutils.bufferwavs(bytes_buffer):
274+
for rate, buffer in wavutils.bufferwavs(bytes_buffer):
265275
analyzer = Analyzer(custom_species_list=custom_species_list)
266276
recording = RecordingBuffer(
267277
analyzer,
@@ -277,7 +287,8 @@ def test_with_species_list():
277287
pprint(recording.detections)
278288

279289
assert (
280-
commandline_results[0]["common_name"] == recording.detections[0]["common_name"]
290+
commandline_results[0]["common_name"]
291+
== recording.detections[0]["common_name"]
281292
)
282293

283294
commandline_birds = [i["common_name"] for i in commandline_results]
@@ -290,11 +301,11 @@ def test_with_species_list():
290301
) # Check that this matches the number printed by the cli version.
291302

292303
# Run a recording with lat/lon and throw an error when used with custom species list.
293-
with open(input_path,'rb') as f:
304+
with open(input_path, "rb") as f:
294305
wav_buffer = f.read()
295306
bytes_buffer = io.BytesIO(wav_buffer)
296307
# Run a recording without path and throw an error when used with custom species list.
297-
for rate,buffer in wavutils.bufferwavs(bytes_buffer):
308+
for rate, buffer in wavutils.bufferwavs(bytes_buffer):
298309
with pytest.raises(ValueError):
299310
recording = RecordingBuffer(
300311
analyzer,
@@ -309,7 +320,6 @@ def test_with_species_list():
309320

310321

311322
def test_species_list_calls():
312-
313323
lon = -120.7463
314324
lat = 35.4244
315325
week_48 = 18
@@ -324,11 +334,11 @@ def test_species_list_calls():
324334
"return_predicted_species_list",
325335
wraps=analyzer.return_predicted_species_list,
326336
) as wrapped_return_predicted_species_list:
327-
with open(input_path,'rb') as f:
337+
with open(input_path, "rb") as f:
328338
wav_buffer = f.read()
329339
bytes_buffer = io.BytesIO(wav_buffer)
330340
# Run a recording without path and throw an error when used with custom species list.
331-
for rate,buffer in wavutils.bufferwavs(bytes_buffer):
341+
for rate, buffer in wavutils.bufferwavs(bytes_buffer):
332342
recording = RecordingBuffer(
333343
analyzer,
334344
buffer,
@@ -342,10 +352,10 @@ def test_species_list_calls():
342352
assert wrapped_return_predicted_species_list.call_count == 1
343353

344354
# Second recording with the same position/time should not regenerate the species list.
345-
with open(input_path,'rb') as f:
355+
with open(input_path, "rb") as f:
346356
wav_buffer = f.read()
347357
bytes_buffer = io.BytesIO(wav_buffer)
348-
for rate,buffer in wavutils.bufferwavs(bytes_buffer):
358+
for rate, buffer in wavutils.bufferwavs(bytes_buffer):
349359
recording = RecordingBuffer(
350360
analyzer,
351361
buffer,

tests/test_embeddings.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212

13+
@pytest.mark.omit_during_ghactions
1314
def test_embeddings():
1415
# Process file with command line utility, then process with python library and ensure equal commandline_results.
1516

@@ -68,6 +69,7 @@ def test_embeddings():
6869
)
6970

7071

72+
@pytest.mark.omit_during_ghactions
7173
def test_largefile_embeddings():
7274
# Process file with command line utility, then process with python library and ensure equal commandline_results.
7375

@@ -101,8 +103,6 @@ def test_largefile_embeddings():
101103
# pprint(commandline_results)
102104
assert len(commandline_results) == 40
103105

104-
# TODO: Implement for LargeRecording.
105-
# Confirm that LargeRecording return not implemented.
106106
large_analyzer = LargeRecordingAnalyzer()
107107
recording = LargeRecording(
108108
large_analyzer,
@@ -113,6 +113,16 @@ def test_largefile_embeddings():
113113
min_conf=min_conf,
114114
return_all_detections=True,
115115
)
116-
msg = "Extraction of embeddings is not yet implemented for LargeRecordingAnalyzer. Use Analyzer if possible."
117-
with pytest.raises(NotImplementedError, match=msg):
118-
recording.extract_embeddings()
116+
117+
recording.extract_embeddings()
118+
119+
# Check that birdnetlib results match command line results.
120+
assert len(recording.embeddings) == 40
121+
for idx, i in enumerate(commandline_results):
122+
# Specify the tolerance level
123+
tolerance = 1e-4 # 4 decimal points tolerance between BirdNET and birdnetlib.
124+
125+
# Assert that the arrays are almost equal within the tolerance
126+
assert np.allclose(
127+
i["embeddings"], recording.embeddings[idx]["embeddings"], atol=tolerance
128+
)

0 commit comments

Comments
 (0)