Skip to content

Commit 0d17f1d

Browse files
committed
PolicyCache
1 parent d6d790e commit 0d17f1d

File tree

2 files changed

+294
-1
lines changed

2 files changed

+294
-1
lines changed

stytch/src/main/kotlin/com/stytch/java/common/PolicyCache.kt

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package com.stytch.java.common
22

33
import com.stytch.java.b2b.api.rbac.RBAC
4+
import com.stytch.java.b2b.api.rbacorganizations.Organizations
5+
import com.stytch.java.b2b.models.rbac.OrgPolicy
46
import com.stytch.java.b2b.models.rbac.Policy
57
import com.stytch.java.b2b.models.rbac.PolicyRequest
8+
import com.stytch.java.b2b.models.rbac.PolicyRole
9+
import com.stytch.java.b2b.models.rbacorganizations.GetOrgPolicyRequest
10+
import com.stytch.java.b2b.models.rbacorganizations.GetOrgPolicyResponse
611
import com.stytch.java.b2b.models.sessions.AuthorizationCheck
712
import kotlinx.coroutines.CoroutineScope
813
import kotlinx.coroutines.Job
@@ -22,13 +27,20 @@ public class PermissionException(
2227
authorizationCheck: AuthorizationCheck,
2328
) : RuntimeException("Permission denied for request $authorizationCheck")
2429

30+
private data class CachedOrgPolicy(
31+
val orgPolicy: OrgPolicy,
32+
val lastUpdate: Instant,
33+
)
34+
2535
internal class PolicyCache(
2636
private val client: RBAC,
2737
coroutineScope: CoroutineScope,
38+
private val organizations: Organizations? = null,
2839
) {
2940
private val job = SupervisorJob(coroutineScope.coroutineContext[Job])
3041
private val scope = CoroutineScope(coroutineScope.coroutineContext + job)
3142
private var cachedPolicy: Policy? = null
43+
private val cachedOrgPolicies: MutableMap<String, CachedOrgPolicy> = mutableMapOf()
3244
private var policyLastUpdate: Instant? = null
3345
private var backgroundRefreshStarted = false
3446

@@ -53,6 +65,34 @@ internal class PolicyCache(
5365
return cachedPolicy ?: throw Exception("Error fetching the policy")
5466
}
5567

68+
private fun getOrgPolicy(
69+
orgId: String,
70+
invalidate: Boolean = false,
71+
): OrgPolicy? {
72+
val cached = cachedOrgPolicies[orgId]
73+
val isMissing = cached == null
74+
val isStale = cached != null && Duration.between(cached.lastUpdate, Instant.now()).seconds > CACHE_TTL_SECONDS
75+
76+
if (invalidate || isMissing || isStale) {
77+
refreshOrgPolicy(orgId)
78+
}
79+
80+
return cachedOrgPolicies[orgId]?.orgPolicy
81+
}
82+
83+
private fun refreshOrgPolicy(orgId: String) {
84+
val orgs = organizations ?: client.organizations
85+
when (val result = orgs.getOrgPolicyCompletable(GetOrgPolicyRequest(orgId)).get()) {
86+
is StytchResult.Success<GetOrgPolicyResponse> -> {
87+
result.value.orgPolicy?.let { orgPolicy ->
88+
cachedOrgPolicies[orgId] = CachedOrgPolicy(orgPolicy, Instant.now())
89+
}
90+
}
91+
92+
else -> {}
93+
}
94+
}
95+
5696
private fun refreshPolicy() {
5797
when (val result = client.policyCompletable(PolicyRequest()).get()) {
5898
is StytchResult.Success -> {
@@ -69,6 +109,10 @@ internal class PolicyCache(
69109
while (isActive) {
70110
delay(REFRESH_INTERVAL_MS)
71111
refreshPolicy()
112+
// Refresh all cached org policies
113+
cachedOrgPolicies.keys.toList().forEach { orgId ->
114+
refreshOrgPolicy(orgId)
115+
}
72116
}
73117
}
74118
}
@@ -90,8 +134,17 @@ internal class PolicyCache(
90134
throw TenancyException(subjectOrgId, authorizationCheck.organizationId)
91135
}
92136
val policy = getPolicy()
137+
val orgPolicy = getOrgPolicy(subjectOrgId)
138+
139+
// Combine roles from both global policy and org-specific policy
140+
val allRoles: List<PolicyRole> =
141+
buildList {
142+
addAll(policy.roles)
143+
orgPolicy?.roles?.let { addAll(it) }
144+
}
145+
93146
val hasMatchingActionAndResource =
94-
policy.roles
147+
allRoles
95148
.filter { it.roleId in subjectRoles }
96149
.flatMap { it.permissions }
97150
.filter {

stytch/src/test/kotlin/com/stytch/java/common/PolicyCacheTest.kt

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package com.stytch.java.common
22

33
import com.stytch.java.b2b.api.rbac.RBAC
4+
import com.stytch.java.b2b.api.rbacorganizations.Organizations
5+
import com.stytch.java.b2b.models.rbac.OrgPolicy
46
import com.stytch.java.b2b.models.rbac.Policy
57
import com.stytch.java.b2b.models.rbac.PolicyResource
68
import com.stytch.java.b2b.models.rbac.PolicyResponse
79
import com.stytch.java.b2b.models.rbac.PolicyRole
810
import com.stytch.java.b2b.models.rbac.PolicyRolePermission
911
import com.stytch.java.b2b.models.rbac.PolicyScope
1012
import com.stytch.java.b2b.models.rbac.PolicyScopePermission
13+
import com.stytch.java.b2b.models.rbacorganizations.GetOrgPolicyResponse
1114
import com.stytch.java.b2b.models.sessions.AuthorizationCheck
1215
import io.mockk.every
1316
import io.mockk.mockk
@@ -109,6 +112,35 @@ private val policy =
109112
),
110113
)
111114

115+
private val orgPolicy =
116+
OrgPolicy(
117+
roles =
118+
listOf(
119+
PolicyRole(
120+
roleId = "org_admin",
121+
description = "Organization-specific admin",
122+
permissions =
123+
listOf(
124+
PolicyRolePermission(
125+
resourceId = "baz",
126+
actions = listOf("*"),
127+
),
128+
),
129+
),
130+
PolicyRole(
131+
roleId = "org_reader",
132+
description = "Organization-specific reader",
133+
permissions =
134+
listOf(
135+
PolicyRolePermission(
136+
resourceId = "baz",
137+
actions = listOf("read"),
138+
),
139+
),
140+
),
141+
),
142+
)
143+
112144
internal class PolicyCacheTest {
113145
private lateinit var rbac: RBAC
114146
private val testScope = CoroutineScope(Dispatchers.Unconfined)
@@ -125,6 +157,14 @@ internal class PolicyCacheTest {
125157
policy = policy,
126158
),
127159
)
160+
every { organizations.getOrgPolicyCompletable(any()).get() } returns
161+
StytchResult.Success(
162+
GetOrgPolicyResponse(
163+
statusCode = 200,
164+
requestId = "",
165+
orgPolicy = orgPolicy,
166+
),
167+
)
128168
}
129169
}
130170

@@ -312,4 +352,204 @@ internal class PolicyCacheTest {
312352
// Cancel the background refresh job
313353
policyCache.cancelBackgroundRefresh()
314354
}
355+
356+
@Test
357+
fun `succeeds when subject has matching org policy role`() {
358+
val policyCache = PolicyCache(rbac, testScope)
359+
policyCache.performAuthorizationCheck(
360+
subjectRoles = listOf("org_admin"),
361+
subjectOrgId = "my-org",
362+
authorizationCheck =
363+
AuthorizationCheck(
364+
organizationId = "my-org",
365+
resourceId = "baz",
366+
action = "write",
367+
),
368+
)
369+
}
370+
371+
@Test
372+
fun `succeeds when subject has org-specific role with read permission`() {
373+
val policyCache = PolicyCache(rbac, testScope)
374+
policyCache.performAuthorizationCheck(
375+
subjectRoles = listOf("org_reader"),
376+
subjectOrgId = "my-org",
377+
authorizationCheck =
378+
AuthorizationCheck(
379+
organizationId = "my-org",
380+
resourceId = "baz",
381+
action = "read",
382+
),
383+
)
384+
}
385+
386+
@Test(expected = PermissionException::class)
387+
fun `throws PermissionException when org role does not have matching action`() {
388+
val policyCache = PolicyCache(rbac, testScope)
389+
policyCache.performAuthorizationCheck(
390+
subjectRoles = listOf("org_reader"),
391+
subjectOrgId = "my-org",
392+
authorizationCheck =
393+
AuthorizationCheck(
394+
organizationId = "my-org",
395+
resourceId = "baz",
396+
action = "write",
397+
),
398+
)
399+
}
400+
401+
@Test
402+
fun `fetches org policy on first authorization check for an org`() {
403+
val rbacOrgMock =
404+
mockk<Organizations>(relaxed = true, relaxUnitFun = true) {
405+
every { getOrgPolicyCompletable(any()).get() } returns
406+
StytchResult.Success(
407+
GetOrgPolicyResponse(
408+
statusCode = 200,
409+
requestId = "",
410+
orgPolicy = orgPolicy,
411+
),
412+
)
413+
}
414+
415+
val policyCache = PolicyCache(rbac, testScope, rbacOrgMock)
416+
417+
// First call should fetch the org policy
418+
policyCache.performAuthorizationCheck(
419+
subjectRoles = listOf("org_admin"),
420+
subjectOrgId = "my-org",
421+
authorizationCheck =
422+
AuthorizationCheck(
423+
organizationId = "my-org",
424+
resourceId = "baz",
425+
action = "read",
426+
),
427+
)
428+
429+
verify(exactly = 1) { rbacOrgMock.getOrgPolicyCompletable(any()) }
430+
}
431+
432+
@Test
433+
fun `uses cached org policy on subsequent authorization checks for same org`() {
434+
val callCount = AtomicInteger(0)
435+
val rbacOrgMock =
436+
mockk<Organizations>(relaxed = true, relaxUnitFun = true) {
437+
every { getOrgPolicyCompletable(any()).get() } answers {
438+
callCount.incrementAndGet()
439+
StytchResult.Success(
440+
GetOrgPolicyResponse(
441+
statusCode = 200,
442+
requestId = "",
443+
orgPolicy = orgPolicy,
444+
),
445+
)
446+
}
447+
}
448+
449+
val policyCache = PolicyCache(rbac, testScope, rbacOrgMock)
450+
451+
// First call fetches
452+
policyCache.performAuthorizationCheck(
453+
subjectRoles = listOf("org_admin"),
454+
subjectOrgId = "my-org",
455+
authorizationCheck =
456+
AuthorizationCheck(
457+
organizationId = "my-org",
458+
resourceId = "baz",
459+
action = "read",
460+
),
461+
)
462+
463+
// Second call should use cache
464+
policyCache.performAuthorizationCheck(
465+
subjectRoles = listOf("org_admin"),
466+
subjectOrgId = "my-org",
467+
authorizationCheck =
468+
AuthorizationCheck(
469+
organizationId = "my-org",
470+
resourceId = "baz",
471+
action = "write",
472+
),
473+
)
474+
475+
// Should only have called the API once (second call used cache)
476+
assertEquals(1, callCount.get())
477+
}
478+
479+
@Test
480+
fun `fetches separate org policy for different organizations`() {
481+
val callCount = AtomicInteger(0)
482+
val rbacOrgMock =
483+
mockk<Organizations>(relaxed = true, relaxUnitFun = true) {
484+
every { getOrgPolicyCompletable(any()).get() } answers {
485+
callCount.incrementAndGet()
486+
StytchResult.Success(
487+
GetOrgPolicyResponse(
488+
statusCode = 200,
489+
requestId = "",
490+
orgPolicy = orgPolicy,
491+
),
492+
)
493+
}
494+
}
495+
496+
val policyCache = PolicyCache(rbac, testScope, rbacOrgMock)
497+
498+
// First call for org1
499+
policyCache.performAuthorizationCheck(
500+
subjectRoles = listOf("org_admin"),
501+
subjectOrgId = "org1",
502+
authorizationCheck =
503+
AuthorizationCheck(
504+
organizationId = "org1",
505+
resourceId = "baz",
506+
action = "read",
507+
),
508+
)
509+
510+
// Second call for org2 - should fetch again
511+
policyCache.performAuthorizationCheck(
512+
subjectRoles = listOf("org_admin"),
513+
subjectOrgId = "org2",
514+
authorizationCheck =
515+
AuthorizationCheck(
516+
organizationId = "org2",
517+
resourceId = "baz",
518+
action = "read",
519+
),
520+
)
521+
522+
// Should have called the API twice (once per org)
523+
assertEquals(2, callCount.get())
524+
}
525+
526+
@Test
527+
fun `succeeds when combining global and org policy permissions`() {
528+
// Subject has global_reader (can read foo) and org_admin (can do anything on baz)
529+
val policyCache = PolicyCache(rbac, testScope)
530+
531+
// Should succeed using global policy
532+
policyCache.performAuthorizationCheck(
533+
subjectRoles = listOf("global_reader", "org_admin"),
534+
subjectOrgId = "my-org",
535+
authorizationCheck =
536+
AuthorizationCheck(
537+
organizationId = "my-org",
538+
resourceId = "foo",
539+
action = "read",
540+
),
541+
)
542+
543+
// Should succeed using org policy
544+
policyCache.performAuthorizationCheck(
545+
subjectRoles = listOf("global_reader", "org_admin"),
546+
subjectOrgId = "my-org",
547+
authorizationCheck =
548+
AuthorizationCheck(
549+
organizationId = "my-org",
550+
resourceId = "baz",
551+
action = "delete",
552+
),
553+
)
554+
}
315555
}

0 commit comments

Comments
 (0)