Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/whichprovides/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import dataclasses
import functools
import pathlib
import re
import shutil
Expand Down Expand Up @@ -290,6 +291,7 @@ def whichprovides(cls, filepaths: typing.Collection[str]) -> dict[str, ProvidedB
return results


@functools.cache
def _package_providers() -> list[type[PackageProvider]]:
"""Returns a list of package providers sorted in
the order that they should be attempted.
Expand All @@ -305,6 +307,20 @@ def all_subclasses(cls):
return sorted(all_subclasses(PackageProvider), key=lambda p: p._resolve_order)


def _available_package_providers(
_is_available_cache: dict[type[PackageProvider], bool] = {}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand how caching works here.
Should we expect the default argument to be extended in this case ? (i.e. mutable default)
If that's the case, it should work although a bit hard to read/understand (and is that ok in the context of a no-gil build if called in parallel ?)

) -> typing.Generator[type[PackageProvider], None, None]:
"""We use a generator here because PackageProviders might not
all need to be queried for 'is_available()' if 'whichprovides()'
is able to find matches for all file paths.
"""
for package_provider in _package_providers():
if package_provider not in _is_available_cache:
_is_available_cache[package_provider] = package_provider.is_available()
if _is_available_cache[package_provider]:
yield package_provider


def whichprovides(filepath: typing.Union[str, list[str]]) -> dict[str, ProvidedBy]:
"""Return a package URL (PURL) for the package that provides a file"""
if isinstance(filepath, str):
Expand All @@ -318,12 +334,10 @@ def whichprovides(filepath: typing.Union[str, list[str]]) -> dict[str, ProvidedB
str(pathlib.Path(filepath).resolve()): filepath for filepath in filepaths
}
filepath_provided_by: dict[str, ProvidedBy] = {}
for package_provider in _package_providers():
for package_provider in _available_package_providers():
remaining = set(resolved_filepaths) - set(filepath_provided_by)
if not remaining:
break
if not package_provider.is_available():
continue
results = package_provider.whichprovides(remaining)
filepath_provided_by.update(results)

Expand Down