25 changes: 17 additions & 8 deletions gapic/cli/generate.py
@@ -23,6 +23,7 @@
 from gapic import generator
 from gapic.schema import api
 from gapic.utils import Options
+from gapic.utils.cache import generation_cache_context


 @click.command()
@@ -56,15 +57,23 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None:
         [p.package for p in req.proto_file if p.name in req.file_to_generate]
     ).rstrip(".")

-    # Build the API model object.
-    # This object is a frozen representation of the whole API, and is sent
-    # to each template in the rendering step.
-    api_schema = api.API.build(req.proto_file, opts=opts, package=package)
+    # Create the generation cache context.
+    # This provides the shared storage for the @cached_proto_context decorator.
+    # 1. Performance: Memoizes `with_context` calls, speeding up generation significantly.
+    # 2. Safety: The decorator uses this storage to "pin" Proto objects in memory.
+    #    This prevents Python's Garbage Collector from deleting objects created during
+    #    `API.build` while `Generator.get_response` is still using their IDs.
+    #    (See `gapic.utils.cache.cached_proto_context` for the specific pinning logic).
+    with generation_cache_context():
+        # Build the API model object.
+        # This object is a frozen representation of the whole API, and is sent
+        # to each template in the rendering step.
+        api_schema = api.API.build(req.proto_file, opts=opts, package=package)

-    # Translate into a protobuf CodeGeneratorResponse; this reads the
-    # individual templates and renders them.
-    # If there are issues, error out appropriately.
-    res = generator.Generator(opts).get_response(api_schema, opts)
+        # Translate into a protobuf CodeGeneratorResponse; this reads the
+        # individual templates and renders them.
+        # If there are issues, error out appropriately.
+        res = generator.Generator(opts).get_response(api_schema, opts)

     # Output the serialized response.
     output.write(res.SerializeToString())
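The comment block in this hunk is the crux of the change: `API.build` and `get_response` must run inside the same `with` block so that memoized `with_context` results stay valid across both steps, and every pinned object is released once the response is serialized. A minimal usage sketch of the same pattern, assuming only the utilities added in this PR (`Node` is a made-up stand-in for the real Proto wrappers):

from gapic.utils import cached_proto_context
from gapic.utils.cache import generation_cache_context

class Node:
    @cached_proto_context
    def with_context(self, *, collisions):
        print("computed")  # runs once per (id(self), frozenset(collisions)) key
        return sorted(collisions)

with generation_cache_context():
    n = Node()
    n.with_context(collisions={"x"})  # miss: prints "computed"
    n.with_context(collisions={"x"})  # hit: returns the cached value, no print
# Exiting the block deletes the cache, unpinning every object it referenced.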
3 changes: 3 additions & 0 deletions gapic/schema/metadata.py
@@ -35,6 +35,7 @@
 from gapic.schema import imp
 from gapic.schema import naming
 from gapic.utils import cached_property
+from gapic.utils import cached_proto_context
 from gapic.utils import RESERVED_NAMES

 # This class is a minor hack to optimize Address's __eq__ method.
@@ -359,6 +360,7 @@ def resolve(self, selector: str) -> str:
             return f'{".".join(self.package)}.{selector}'
         return selector

+    @cached_proto_context
     def with_context(self, *, collisions: Set[str]) -> "Address":
         """Return a derivative of this address with the provided context.

@@ -398,6 +400,7 @@ def doc(self):
             return "\n\n".join(self.documentation.leading_detached_comments)
         return ""

+    @cached_proto_context
     def with_context(self, *, collisions: Set[str]) -> "Metadata":
         """Return a derivative of this metadata with the provided context.

8 changes: 8 additions & 0 deletions gapic/schema/wrappers.py
@@ -67,6 +67,7 @@

 from gapic import utils
 from gapic.schema import metadata
+from gapic.utils import cached_proto_context
 from gapic.utils import uri_sample
 from gapic.utils import make_private

@@ -410,6 +411,7 @@ def type(self) -> Union["MessageType", "EnumType", "PrimitiveType"]:
             "This code should not be reachable; please file a bug."
         )

+    @cached_proto_context
     def with_context(
         self,
         *,
@@ -805,6 +807,7 @@ def get_field(
         # message.
         return cursor.message.get_field(*field_path[1:], collisions=collisions)

+    @cached_proto_context
     def with_context(
         self,
         *,
@@ -937,6 +940,7 @@ def ident(self) -> metadata.Address:
         """Return the identifier data to be used in templates."""
         return self.meta.address

+    @cached_proto_context
     def with_context(self, *, collisions: Set[str]) -> "EnumType":
         """Return a derivative of this enum with the provided context.

@@ -1058,6 +1062,7 @@ class ExtendedOperationInfo:
     request_type: MessageType
     operation_type: MessageType

+    @cached_proto_context
     def with_context(
         self,
         *,
@@ -1127,6 +1132,7 @@ class OperationInfo:
     response_type: MessageType
     metadata_type: MessageType

+    @cached_proto_context
     def with_context(
         self,
         *,
@@ -1937,6 +1943,7 @@ def void(self) -> bool:
         """Return True if this method has no return value, False otherwise."""
         return self.output.ident.proto == "google.protobuf.Empty"

+    @cached_proto_context
     def with_context(
         self,
         *,
@@ -2357,6 +2364,7 @@ def operation_polling_method(self) -> Optional[Method]:
     def is_internal(self) -> bool:
         return any(m.is_internal for m in self.methods.values())

+    @cached_proto_context
     def with_context(
         self,
         *,
2 changes: 2 additions & 0 deletions gapic/utils/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 from gapic.utils.cache import cached_property
+from gapic.utils.cache import cached_proto_context
 from gapic.utils.case import to_snake_case
 from gapic.utils.case import to_camel_case
 from gapic.utils.checks import is_msg_field_pb
@@ -34,6 +35,7 @@

 __all__ = (
     "cached_property",
+    "cached_proto_context",
     "convert_uri_fieldnames",
     "doc",
     "empty",
90 changes: 90 additions & 0 deletions gapic/utils/cache.py
@@ -13,6 +13,8 @@
 # limitations under the License.

 import functools
+import contextlib
+import threading


 def cached_property(fx):
@@ -43,3 +45,91 @@ def inner(self):
         return self._cached_values[fx.__name__]

     return property(inner)


+# Thread-local storage for the simple cache dictionary.
+# This ensures that parallel generation tasks (if any) do not corrupt each other's cache.
+_thread_local = threading.local()
+
+
+@contextlib.contextmanager
+def generation_cache_context():
+    """Context manager to explicitly manage the lifecycle of the generation cache.
+
+    This manager initializes a fresh dictionary in thread-local storage when entering
+    the context and strictly deletes it when exiting.
+
+    **Memory Management:**
+    The cache stores strong references to Proto objects to "pin" them in memory
+    (see `cached_proto_context`). It is critical that this context manager deletes
+    the dictionary in the `finally` block. Deleting the dictionary breaks the
+    reference chain, allowing Python's Garbage Collector to finally free all the
+    large Proto objects that were pinned during generation.
+    """
+    # Initialize the cache as a standard dictionary.
+    _thread_local.cache = {}

[Contributor comment] nit: it might be better to use a more unique name here, to avoid collisions

+    try:
+        yield
+    finally:
+        # Delete the dictionary to free all memory and pinned objects.
+        # This is essential to prevent memory leaks in long-running processes.
+        del _thread_local.cache


+def cached_proto_context(func):
+    """Decorator to memoize `with_context` calls based on object identity and collisions.
+
+    This mechanism provides a significant performance boost by preventing
+    redundant recalculations of naming collisions during template rendering.
+
+    Since the Proto wrapper objects are unhashable (mutable), we use `id(self)` as
+    the primary cache key. Normally, this is dangerous: if the object is garbage
+    collected, Python might reuse its memory address for a *new* object, leading to
+    a cache collision (the "Zombie ID" bug).
+
+    To prevent this, this decorator stores the value as a tuple: `(result, self)`.
+    By keeping a reference to `self` in the cache value, we "pin" the object in
+    memory. This forces the Garbage Collector to keep the object alive, guaranteeing
+    that `id(self)` remains unique for the entire lifespan of the `generation_cache_context`.
+
+    Args:
+        func (Callable): The function to decorate (usually `with_context`).
+
+    Returns:
+        Callable: The wrapped function with caching and pinning logic.
+    """
+
+    @functools.wraps(func)
+    def wrapper(self, *, collisions, **kwargs):
+        # 1. Check for active cache (returns None if context is not active)
+        context_cache = getattr(_thread_local, "cache", None)
+
+        # If we are not inside a generation_cache_context (e.g. unit tests),
+        # bypass the cache entirely.
+        if context_cache is None:
+            return func(self, collisions=collisions, **kwargs)
+
+        # 2. Create the cache key.
+        # We use frozenset for collisions to make it hashable.
+        # We use id(self) because 'self' is not hashable.
+        collisions_key = frozenset(collisions) if collisions else None
+        key = (id(self), collisions_key)

[Contributor comment] Seeing this, my first thought was that this cached state should probably be associated with each instance instead of being global, since it's keyed by the instance anyway. But looking at the code more, it looks like each self refers to a different dataclass, so it wouldn't be easy to add this state to each of them. So this looks like the cleaner solution to me.


+        # 3. Check the cache.
+        if key in context_cache:
+            # The cache stores (result, pinned_object). We return just the result.
+            return context_cache[key][0]
+
+        # 4. Execute the actual function.
+        # Recursive `with_context` calls re-enter this wrapper and share the
+        # same thread-local cache.
+        result = func(self, collisions=collisions, **kwargs)
+
+        # 5. Update the cache and pin the object.
+        # We store (result, self). The reference to 'self' prevents garbage collection,
+        # ensuring that 'id(self)' cannot be reused for a new object while this
+        # cache entry exists.
+        context_cache[key] = (result, self)
+        return result
+
+    return wrapper
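The "Zombie ID" hazard described in the docstring is easy to reproduce in isolation. A standalone sketch, independent of the generator code, showing that CPython may hand a new object the address of a collected one; this is exactly the aliasing that an unpinned id()-keyed cache would suffer:

class Proto:
    pass

a = Proto()
stale_key = id(a)
del a           # nothing pins `a`, so CPython frees it immediately
b = Proto()     # the allocator may reuse the freed slot for `b`
print(stale_key == id(b))  # frequently True: an id()-keyed cache would now
                           # serve `a`'s stale result for the unrelated `b`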
40 changes: 40 additions & 0 deletions tests/unit/utils/test_cache.py
@@ -31,3 +31,43 @@ def bar(self):
     assert foo.call_count == 1
     assert foo.bar == 42
     assert foo.call_count == 1


+def test_cached_proto_context():
+    class Foo:
+        def __init__(self):
+            self.call_count = 0
+
+        # We define a signature that matches the real Proto.with_context
+        # to ensure arguments are propagated correctly.
+        @cache.cached_proto_context
+        def with_context(self, collisions, *, skip_fields=False, visited_messages=None):
+            self.call_count += 1
+            return f"val-{self.call_count}"
+
+    foo = Foo()
+
+    # 1. Test bypass (no context).
+    # The cache is not active, so every call increments the counter.
+    assert foo.with_context(collisions={"a"}) == "val-1"
+    assert foo.with_context(collisions={"a"}) == "val-2"
+
+    # 2. Test context activation.
+    with cache.generation_cache_context():
+        # Reset the counter to make tracking easier.
+        foo.call_count = 0
+
+        # A. Basic cache hit.
+        assert foo.with_context(collisions={"a"}) == "val-1"
+        assert foo.with_context(collisions={"a"}) == "val-1"  # Hit
+        assert foo.call_count == 1
+
+        # B. Collision difference.
+        # Changing collisions creates a new key.
+        assert foo.with_context(collisions={"b"}) == "val-2"
+        assert foo.call_count == 2
+
+    # 3. Context cleared.
+    # Everything should be forgotten now.
+    assert getattr(cache._thread_local, "cache", None) is None
+    assert foo.with_context(collisions={"a"}) == "val-3"
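The new test exercises memoization and bypass but not the pinning behavior itself. A sketch of an additional check one could write, assuming the module layout from this PR and using weakref to observe when the instance is actually collected:

def test_context_pins_objects():
    import gc
    import weakref

    class Foo:
        @cache.cached_proto_context
        def with_context(self, *, collisions):
            return len(collisions)

    with cache.generation_cache_context():
        foo = Foo()
        foo.with_context(collisions={"a"})
        ref = weakref.ref(foo)
        del foo
        gc.collect()
        # The (result, self) tuple in the cache still pins the instance.
        assert ref() is not None
    # Exiting the context deleted the cache and released the pin.
    gc.collect()
    assert ref() is None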