Skip to content

Commit 799628b

Browse files
authored
Add tests for Spherical Quantizer Type. (Azure#44502)
* Add tests for `spherical` quantizerType. * add cleanup container logic in each tests
1 parent de0f681 commit 799628b

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

sdk/cosmos/azure-cosmos/tests/test_vector_policy.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,47 @@
1010
import test_config
1111
from azure.cosmos import CosmosClient, PartitionKey
1212

13+
VectorPolicyTestData = {
14+
"valid_vector_indexing_policy" : {
15+
"indexing_policy": {
16+
"vectorIndexes": [
17+
{"path": "/vector1", "type": "flat"},
18+
{"path": "/vector2", "type": "quantizedFlat", "quantizerType": "product", "quantizationByteSize": 8},
19+
{"path": "/vector3", "type": "diskANN", "quantizerType": "product", "quantizationByteSize": 8,
20+
"vectorIndexShardKey": ["/city"], "indexingSearchListSize": 50},
21+
{"path": "/vector4", "type": "diskANN", "quantizerType": "spherical", "indexingSearchListSize": 50},
22+
]
23+
},
24+
"vector_embedding_policy": {
25+
"vectorEmbeddings": [
26+
{
27+
"path": "/vector1",
28+
"dataType": "float32",
29+
"dimensions": 256,
30+
"distanceFunction": "euclidean"
31+
},
32+
{
33+
"path": "/vector2",
34+
"dataType": "int8",
35+
"dimensions": 200,
36+
"distanceFunction": "dotproduct"
37+
},
38+
{
39+
"path": "/vector3",
40+
"dataType": "uint8",
41+
"dimensions": 400,
42+
"distanceFunction": "cosine"
43+
},
44+
{
45+
"path": "/vector4",
46+
"dataType": "uint8",
47+
"dimensions": 400,
48+
"distanceFunction": "euclidean"
49+
},
50+
]
51+
}
52+
}
53+
}
1354

1455
@pytest.mark.cosmosSearchQuery
1556
class TestVectorPolicy(unittest.TestCase):
@@ -55,6 +96,21 @@ def test_create_valid_vector_embedding_policy(self):
5596
assert properties["vectorEmbeddingPolicy"]["vectorEmbeddings"][0]["dataType"] == data_type
5697
self.test_db.delete_container('vector_container_' + data_type)
5798

99+
@unittest.skip
100+
def test_create_valid_vector_indexing_policy(self):
101+
test_data = VectorPolicyTestData["valid_vector_indexing_policy"]
102+
indexing_policy = test_data["indexing_policy"]
103+
vector_embedding_policy = test_data["vector_embedding_policy"]
104+
105+
created_container = self.test_db.create_container(
106+
id="container_" + str(uuid.uuid4()),
107+
partition_key=PartitionKey(path="/id"),
108+
vector_embedding_policy=vector_embedding_policy,
109+
indexing_policy=indexing_policy)
110+
properties = created_container.read()
111+
assert properties['indexingPolicy']['vectorIndexes'] == indexing_policy['vectorIndexes']
112+
self.test_db.delete_container(created_container.id)
113+
58114
def test_create_vector_embedding_container(self):
59115
indexing_policy = {
60116
"vectorIndexes": [

sdk/cosmos/azure-cosmos/tests/test_vector_policy_async.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from azure.cosmos import CosmosClient as CosmosSyncClient
1212
from azure.cosmos import PartitionKey
1313
from azure.cosmos.aio import CosmosClient
14-
14+
from test_vector_policy import VectorPolicyTestData
1515

1616
@pytest.mark.cosmosSearchQuery
1717
class TestVectorPolicyAsync(unittest.IsolatedAsyncioTestCase):
@@ -46,6 +46,21 @@ async def asyncSetUp(self):
4646
async def asyncTearDown(self):
4747
await self.client.close()
4848

49+
@unittest.skip
50+
async def test_create_valid_vector_indexing_policy_async(self):
51+
test_data = VectorPolicyTestData["valid_vector_indexing_policy"]
52+
indexing_policy = test_data["indexing_policy"]
53+
vector_embedding_policy = test_data["vector_embedding_policy"]
54+
55+
created_container = await self.test_db.create_container(
56+
id="container_" + str(uuid.uuid4()),
57+
partition_key=PartitionKey(path="/id"),
58+
vector_embedding_policy=vector_embedding_policy,
59+
indexing_policy=indexing_policy)
60+
properties = await created_container.read()
61+
assert properties['indexingPolicy']['vectorIndexes'] == indexing_policy['vectorIndexes']
62+
await self.test_db.delete_container(created_container.id)
63+
4964
async def test_create_valid_vector_embedding_policy_async(self):
5065
# Using valid data types
5166
data_types = ["float32", "float16", "int8", "uint8"]

0 commit comments

Comments
 (0)