Skip to content

Commit 7b0e945

Browse files
authored
Merge pull request #54 from tomsch420/main
Fixed bug in getting relationships from parent alternatively mapped objects.
2 parents 54c9907 + f56a49e commit 7b0e945

File tree

4 files changed

+69
-10
lines changed

4 files changed

+69
-10
lines changed

src/krrood/ormatic/dao.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import sqlalchemy.orm
1313
from sqlalchemy import Column
1414
from sqlalchemy.orm import MANYTOONE, MANYTOMANY, ONETOMANY, RelationshipProperty
15+
from sqlalchemy.util import ReadOnlyProperties
1516
from typing_extensions import (
1617
Type,
1718
get_args,
@@ -23,6 +24,7 @@
2324
Optional,
2425
List,
2526
Iterable,
27+
Tuple,
2628
)
2729

2830
from ..utils import recursive_subclasses
@@ -527,17 +529,42 @@ def to_dao_if_subclass_of_alternative_mapping(
527529
setattr(self, prop.key, getattr(obj, prop.key))
528530

529531
# split relationships in relationships by parent and relationships by child
530-
all_relationships = mapper.relationships
531-
relationships_of_parent = parent_mapper.relationships
532-
relationships_of_this_table = [
533-
r for r in all_relationships if r not in relationships_of_parent
534-
]
532+
relationships_of_parent, relationships_of_this_table = (
533+
self.partition_parent_child_relationships(parent_mapper, mapper)
534+
)
535535

536-
for relationship in relationships_of_parent:
537-
setattr(self, relationship.key, getattr(parent_dao, relationship.key))
536+
# get relationships from parent dao
537+
self.get_relationships_from(parent_dao, relationships_of_parent, state)
538538

539+
# get relationships from the current table
539540
self.get_relationships_from(obj, relationships_of_this_table, state)
540541

542+
def partition_parent_child_relationships(
543+
self, parent: sqlalchemy.orm.Mapper, child: sqlalchemy.orm.Mapper
544+
) -> Tuple[
545+
List[RelationshipProperty[Any]],
546+
List[RelationshipProperty[Any]],
547+
]:
548+
"""
549+
Partition the relationships by parent-only and child-only relationships.
550+
551+
:param parent: The parent mapper to extract relationships from
552+
:param child: The child mapper to extract relationships from
553+
:return: A tuple of the relationships that are only in the parent and the relationships that are only in the child
554+
"""
555+
all_relationships = child.relationships
556+
relationships_of_parent = parent.relationships
557+
relationship_names_of_parent = list(
558+
map(lambda x: x.key, relationships_of_parent)
559+
)
560+
561+
relationships_of_child = list(
562+
filter(
563+
lambda x: x.key not in relationship_names_of_parent, all_relationships
564+
)
565+
)
566+
return relationships_of_parent, relationships_of_child
567+
541568
def get_columns_from(self, obj: T, columns: Iterable[Column]) -> None:
542569
"""
543570
Retrieves and assigns values from specified columns of a given object.

test/dataset/example_classes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,8 @@ def create_instance(cls, obj: T):
510510
class ParentAlternativelyMapped:
511511
base_attribute: float = 0
512512

513+
entities: List[Entity] = field(default_factory=list)
514+
513515

514516
@dataclass
515517
class ChildLevel1NormallyMapped(ParentAlternativelyMapped):
@@ -524,15 +526,19 @@ class ChildLevel2NormallyMapped(ChildLevel1NormallyMapped):
524526
@dataclass
525527
class ParentAlternativelyMappedMapping(AlternativeMapping[ParentAlternativelyMapped]):
526528
derived_attribute: str
529+
entities: List[Entity]
527530

528531
@classmethod
529532
def create_instance(cls, obj: T) -> Self:
530-
return cls(str(obj.base_attribute))
533+
return cls(str(obj.base_attribute), obj.entities)
531534

532535
def create_from_dao(self) -> T:
533536
raise NotImplementedError
534537

535538

539+
# %% Function like classes for testing
540+
541+
536542
@dataclass
537543
class CallableWrapper:
538544
func: FunctionType
@@ -560,6 +566,9 @@ class UUIDWrapper:
560566
other_identifications: List[uuid.UUID] = field(default_factory=list)
561567

562568

569+
# %% Test JSON serialization in ORM classes
570+
571+
563572
@dataclass
564573
class JSONSerializableClass(SubclassJSONSerializer):
565574
a: float = 0.0

test/dataset/ormatic_interface.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ class Base(DeclarativeBase):
4444

4545

4646
# Association tables for many-to-many relationships
47+
parentalternativelymappedmappingdao_entities_association = Table(
48+
"parentalternativelymappedmappingdao_entities_association",
49+
Base.metadata,
50+
Column(
51+
"parentalternativelymappedmappingdao_id",
52+
ForeignKey("ParentAlternativelyMappedMappingDAO.database_id"),
53+
),
54+
Column("customentitydao_id", ForeignKey("CustomEntityDAO.database_id")),
55+
)
4756
alternativemappingaggregatordao_entities1_association = Table(
4857
"alternativemappingaggregatordao_entities1_association",
4958
Base.metadata,
@@ -333,6 +342,12 @@ class ParentAlternativelyMappedMappingDAO(
333342
String(255), nullable=False, use_existing_column=True
334343
)
335344

345+
entities: Mapped[typing.List[CustomEntityDAO]] = relationship(
346+
"CustomEntityDAO",
347+
secondary="parentalternativelymappedmappingdao_entities_association",
348+
cascade="save-update, merge",
349+
)
350+
336351
__mapper_args__ = {
337352
"polymorphic_on": "polymorphic_type",
338353
"polymorphic_identity": "ParentAlternativelyMappedMappingDAO",

test/test_ormatic/test_interface.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,10 +475,18 @@ def test_inheritance_mapper_args(session, database):
475475

476476

477477
def test_to_dao_alternatively_mapped_parent(session, database):
478-
ch2 = ChildLevel2NormallyMapped(1, 2, 3)
478+
ch2 = ChildLevel2NormallyMapped(1, [Entity("a")], 2, 3)
479479
ch2_dao = to_dao(ch2)
480480

481-
assert ch2_dao == ChildLevel2NormallyMappedDAO("1", 2, 3)
481+
assert isinstance(ch2_dao.entities[0], CustomEntityDAO)
482+
assert ch2_dao.entities == [CustomEntityDAO(overwritten_name="a")]
483+
484+
assert ch2_dao == ChildLevel2NormallyMappedDAO(
485+
derived_attribute="1",
486+
entities=[CustomEntityDAO(overwritten_name="a")],
487+
level_one_attribute=2,
488+
level_two_attribute=3,
489+
)
482490

483491

484492
def test_callable_alternative_mapping():

0 commit comments

Comments
 (0)