Skip to content

Commit b9568c4

Browse files
Support where field arg
1 parent 0de34b1 commit b9568c4

File tree

8 files changed

+225
-10
lines changed

8 files changed

+225
-10
lines changed

.github/workflows/pr_agent.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
on:
2+
pull_request:
3+
types: [opened, reopened, ready_for_review]
4+
issue_comment:
5+
jobs:
6+
pr_agent_job:
7+
if: ${{ github.event.sender.type != 'Bot' }}
8+
runs-on: ubuntu-latest
9+
permissions:
10+
issues: write
11+
pull-requests: write
12+
contents: write
13+
name: Run pr agent on every pull request, respond to user comments
14+
steps:
15+
- name: PR Agent action step
16+
id: pragent
17+
uses: qodo-ai/[email protected]
18+
with:
19+
args: '/improve --pr_code_suggestions.commitable_code_suggestions=true'
20+
env:
21+
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
22+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,4 @@ venv.bak/
106106
.mypy_cache/
107107

108108
.idea/
109+
.qodo

agent-coding-standards.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Agent Coding Standards
2+
3+
# ALL CODE MUSE BE UP TO THIS STANDARD
4+
5+
## Follow Clean Code principles (Robert C. Martin) when writing code
6+
7+
- Specifically the correct order of functions
8+
9+
## No comments!
10+
11+
- Most likely there is absolutely no reason to add a comment
12+
- If there is, it's probably a sign that the code is not clear
13+
14+
## Don't use dicts.
15+
16+
- str keys are red
17+
- Use pydantic models instead
18+
19+
## Do not hallucinate
20+
21+
- If you are not sure, ask the user
22+
- If you don't know the library, do the research
23+
24+
## Do not regress when moving code
25+
26+
- Make sure quality of moved code is at least as good as the original
27+
28+
## Don't make me ask you twice
29+
30+
- Follow all these rules
31+
- Every time, all the time
32+
- If it's hard, ask me how to solve it
33+
34+
## Only essential complexity, not accidental
35+
36+
- Use the simplest approach that works
37+
- Question whether each line adds value or just complexity

mockfirestore/collection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from mockfirestore import AlreadyExists
55
from mockfirestore._helpers import generate_random_string, Store, get_by_path, set_by_path, Timestamp
6-
from mockfirestore.query import Query
6+
from mockfirestore.query import Query, FieldFilter
77
from mockfirestore.document import DocumentReference, DocumentSnapshot
88

99

@@ -41,9 +41,10 @@ def add(self, document_data: Dict, document_id: str = None) \
4141
timestamp = Timestamp.from_now()
4242
return timestamp, doc_ref
4343

44-
def where(self, field: str, op: str, value: Any) -> Query:
45-
query = Query(self, field_filters=[(field, op, value)])
46-
return query
44+
def where(self, field: str = None, op: str = None, value: Any = None, *, filter: FieldFilter = None) -> Query:
45+
if filter is not None:
46+
return Query(self, field_filters=[(filter.field, filter.op, filter.value)])
47+
return Query(self, field_filters=[(field, op, value)])
4748

4849
def order_by(self, key: str, direction: Optional[str] = None) -> Query:
4950
query = Query(self, orders=[(key, direction)])
@@ -82,4 +83,4 @@ def list_documents(self, page_size: Optional[int] = None) -> Sequence[DocumentRe
8283
def stream(self, transaction=None) -> Iterable[DocumentSnapshot]:
8384
for key in sorted(get_by_path(self._data, self._path)):
8485
doc_snapshot = self.document(key).get()
85-
yield doc_snapshot
86+
yield doc_snapshot

mockfirestore/query.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
import warnings
22
from itertools import islice, tee
33
from typing import Iterator, Any, Optional, List, Callable, Union
4+
from dataclasses import dataclass
45

56
from mockfirestore.document import DocumentSnapshot
67
from mockfirestore._helpers import T
78

9+
@dataclass
10+
class FieldFilter:
11+
field: str
12+
op: str
13+
value: Any
14+
15+
def __init__(self, field: str, op: str, value: Any):
16+
self.field = field
17+
self.op = op
18+
self.value = value
19+
820

921
class Query:
1022
def __init__(self, parent: 'CollectionReference', projection=None,
@@ -61,8 +73,11 @@ def _add_field_filter(self, field: str, op: str, value: Any):
6173
compare = self._compare_func(op)
6274
self._field_filters.append((field, compare, value))
6375

64-
def where(self, field: str, op: str, value: Any) -> 'Query':
65-
self._add_field_filter(field, op, value)
76+
def where(self, field: str = None, op: str = None, value: Any = None, *, filter: FieldFilter = None) -> 'Query':
77+
if filter is not None:
78+
self._add_field_filter(filter.field, filter.op, filter.value)
79+
else:
80+
self._add_field_filter(field, op, value)
6681
return self
6782

6883
def order_by(self, key: str, direction: Optional[str] = 'ASCENDING') -> 'Query':
@@ -136,4 +151,4 @@ def _compare_func(self, op: str) -> Callable[[T, T], bool]:
136151
elif op == 'array_contains':
137152
return lambda x, y: y in x
138153
elif op == 'array_contains_any':
139-
return lambda x, y: any([val in y for val in x])
154+
return lambda x, y: any([val in y for val in x])

requirements-dev-minimal.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
google-cloud-firestore
1+
google-cloud-firestore

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="mock-firestore",
8-
version="0.11.0",
8+
version="0.12.0",
99
author="Matt Dowds",
1010
description="In-memory implementation of Google Cloud Firestore for use in tests",
1111
long_description=long_description,
@@ -19,6 +19,7 @@
1919
'Programming Language :: Python :: 3.8',
2020
'Programming Language :: Python :: 3.9',
2121
'Programming Language :: Python :: 3.10',
22+
'Programming Language :: Python :: 3.11',
2223
"License :: OSI Approved :: MIT License",
2324
],
2425
)

tests/test_where_field.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from unittest import TestCase
2+
3+
from mockfirestore import MockFirestore
4+
from mockfirestore.query import FieldFilter
5+
6+
7+
class TestWhereField(TestCase):
8+
def test_collection_whereEquals(self):
9+
fs = MockFirestore()
10+
fs._data = {'foo': {
11+
'first': {'valid': True},
12+
'second': {'gumby': False}
13+
}}
14+
15+
docs = list(fs.collection('foo').where(field='valid', op='==', value=True).stream())
16+
self.assertEqual({'valid': True}, docs[0].to_dict())
17+
18+
def test_collection_whereEquals_with_filter(self):
19+
fs = MockFirestore()
20+
fs._data = {'foo': {
21+
'first': {'valid': True},
22+
'second': {'gumby': False}
23+
}}
24+
25+
docs = list(fs.collection('foo').where(filter=FieldFilter('valid', '==', True)).stream())
26+
self.assertEqual({'valid': True}, docs[0].to_dict())
27+
28+
def test_collection_whereNotEquals(self):
29+
fs = MockFirestore()
30+
fs._data = {'foo': {
31+
'first': {'count': 1},
32+
'second': {'count': 5}
33+
}}
34+
35+
docs = list(fs.collection('foo').where('count', '!=', 1).stream())
36+
self.assertEqual({'count': 5}, docs[0].to_dict())
37+
38+
def test_collection_whereLessThan(self):
39+
fs = MockFirestore()
40+
fs._data = {'foo': {
41+
'first': {'count': 1},
42+
'second': {'count': 5}
43+
}}
44+
45+
docs = list(fs.collection('foo').where('count', '<', 5).stream())
46+
self.assertEqual({'count': 1}, docs[0].to_dict())
47+
48+
def test_collection_whereLessThanOrEqual(self):
49+
fs = MockFirestore()
50+
fs._data = {'foo': {
51+
'first': {'count': 1},
52+
'second': {'count': 5}
53+
}}
54+
55+
docs = list(fs.collection('foo').where('count', '<=', 5).stream())
56+
self.assertEqual({'count': 1}, docs[0].to_dict())
57+
self.assertEqual({'count': 5}, docs[1].to_dict())
58+
59+
def test_collection_whereGreaterThan(self):
60+
fs = MockFirestore()
61+
fs._data = {'foo': {
62+
'first': {'count': 1},
63+
'second': {'count': 5}
64+
}}
65+
66+
docs = list(fs.collection('foo').where('count', '>', 1).stream())
67+
self.assertEqual({'count': 5}, docs[0].to_dict())
68+
69+
def test_collection_whereGreaterThanOrEqual(self):
70+
fs = MockFirestore()
71+
fs._data = {'foo': {
72+
'first': {'count': 1},
73+
'second': {'count': 5}
74+
}}
75+
76+
docs = list(fs.collection('foo').where('count', '>=', 1).stream())
77+
self.assertEqual({'count': 1}, docs[0].to_dict())
78+
self.assertEqual({'count': 5}, docs[1].to_dict())
79+
80+
def test_collection_whereMissingField(self):
81+
fs = MockFirestore()
82+
fs._data = {'foo': {
83+
'first': {'count': 1},
84+
'second': {'count': 5}
85+
}}
86+
87+
docs = list(fs.collection('foo').where('no_field', '==', 1).stream())
88+
self.assertEqual(len(docs), 0)
89+
90+
def test_collection_whereNestedField(self):
91+
fs = MockFirestore()
92+
fs._data = {'foo': {
93+
'first': {'nested': {'a': 1}},
94+
'second': {'nested': {'a': 2}}
95+
}}
96+
97+
docs = list(fs.collection('foo').where('nested.a', '==', 1).stream())
98+
self.assertEqual(len(docs), 1)
99+
self.assertEqual({'nested': {'a': 1}}, docs[0].to_dict())
100+
101+
def test_collection_whereIn(self):
102+
fs = MockFirestore()
103+
fs._data = {'foo': {
104+
'first': {'field': 'a1'},
105+
'second': {'field': 'a2'},
106+
'third': {'field': 'a3'},
107+
'fourth': {'field': 'a4'},
108+
}}
109+
110+
docs = list(fs.collection('foo').where('field', 'in', ['a1', 'a3']).stream())
111+
self.assertEqual(len(docs), 2)
112+
self.assertEqual({'field': 'a1'}, docs[0].to_dict())
113+
self.assertEqual({'field': 'a3'}, docs[1].to_dict())
114+
115+
def test_collection_whereArrayContains(self):
116+
fs = MockFirestore()
117+
fs._data = {'foo': {
118+
'first': {'field': ['val4']},
119+
'second': {'field': ['val3', 'val2']},
120+
'third': {'field': ['val3', 'val2', 'val1']}
121+
}}
122+
123+
docs = list(fs.collection('foo').where('field', 'array_contains', 'val1').stream())
124+
self.assertEqual(len(docs), 1)
125+
self.assertEqual(docs[0].to_dict(), {'field': ['val3', 'val2', 'val1']})
126+
127+
def test_collection_whereArrayContainsAny(self):
128+
fs = MockFirestore()
129+
fs._data = {'foo': {
130+
'first': {'field': ['val4']},
131+
'second': {'field': ['val3', 'val2']},
132+
'third': {'field': ['val3', 'val2', 'val1']}
133+
}}
134+
135+
contains_any_docs = list(fs.collection('foo').where('field', 'array_contains_any', ['val1', 'val4']).stream())
136+
self.assertEqual(len(contains_any_docs), 2)
137+
self.assertEqual({'field': ['val4']}, contains_any_docs[0].to_dict())
138+
self.assertEqual({'field': ['val3', 'val2', 'val1']}, contains_any_docs[1].to_dict())

0 commit comments

Comments
 (0)