Skip to content

Commit dc80b6e

Browse files
fix(patch): cherry-pick 37f128a to release/v0.28.0-pr-18478 (#18821)
Co-authored-by: matt korwel <matt.korwel@gmail.com>
1 parent 4fdc047 commit dc80b6e

File tree

6 files changed

+117
-22
lines changed

6 files changed

+117
-22
lines changed

packages/core/src/config/models.test.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest';
88
import {
99
resolveModel,
1010
resolveClassifierModel,
11+
isGemini3Model,
1112
isGemini2Model,
1213
isAutoModel,
1314
getDisplayString,
@@ -25,6 +26,29 @@ import {
2526
DEFAULT_GEMINI_MODEL_AUTO,
2627
} from './models.js';
2728

29+
describe('isGemini3Model', () => {
30+
it('should return true for gemini-3 models', () => {
31+
expect(isGemini3Model('gemini-3-pro-preview')).toBe(true);
32+
expect(isGemini3Model('gemini-3-flash-preview')).toBe(true);
33+
});
34+
35+
it('should return true for aliases that resolve to Gemini 3 when preview is enabled', () => {
36+
expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO, true)).toBe(true);
37+
expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO, true)).toBe(true);
38+
expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true);
39+
});
40+
41+
it('should return false for Gemini 2 models', () => {
42+
expect(isGemini3Model('gemini-2.5-pro')).toBe(false);
43+
expect(isGemini3Model('gemini-2.5-flash')).toBe(false);
44+
expect(isGemini3Model(DEFAULT_GEMINI_MODEL_AUTO)).toBe(false);
45+
});
46+
47+
it('should return false for arbitrary strings', () => {
48+
expect(isGemini3Model('gpt-4')).toBe(false);
49+
});
50+
});
51+
2852
describe('getDisplayString', () => {
2953
it('should return Auto (Gemini 3) for preview auto model', () => {
3054
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');

packages/core/src/config/models.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ export function isPreviewModel(model: string): boolean {
137137
);
138138
}
139139

140+
/**
141+
* Checks if the model, after being passed through resolveModel (so aliases are accepted), is a Gemini 3 model.
142+
*
143+
* @param model The model name to check.
144+
* @returns True if the name resolved using previewFeaturesEnabled (default false) starts with "gemini-3" followed by ".", "-", or end of string.
145+
*/
146+
export function isGemini3Model(
147+
model: string,
148+
previewFeaturesEnabled = false,
149+
): boolean {
150+
const resolved = resolveModel(model, previewFeaturesEnabled);
151+
return /^gemini-3(\.|-|$)/.test(resolved);
152+
}
153+
140154
/**
141155
* Checks if the model is a Gemini 2.x model.
142156
*

packages/core/src/routing/strategies/classifierStrategy.test.ts

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import {
1717
DEFAULT_GEMINI_FLASH_MODEL,
1818
DEFAULT_GEMINI_MODEL,
1919
DEFAULT_GEMINI_MODEL_AUTO,
20+
PREVIEW_GEMINI_MODEL_AUTO,
2021
} from '../../config/models.js';
2122
import { promptIdContext } from '../../utils/promptIdContext.js';
2223
import type { Content } from '@google/genai';
@@ -50,8 +51,8 @@ describe('ClassifierStrategy', () => {
5051
modelConfigService: {
5152
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
5253
},
53-
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
54-
getPreviewFeatures: () => false,
54+
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
55+
getPreviewFeatures: vi.fn().mockReturnValue(false),
5556
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
5657
} as unknown as Config;
5758
mockBaseLlmClient = {
@@ -61,8 +62,9 @@ describe('ClassifierStrategy', () => {
6162
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
6263
});
6364

64-
it('should return null if numerical routing is enabled', async () => {
65+
it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
6566
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
67+
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
6668

6769
const decision = await strategy.route(
6870
mockContext,
@@ -74,6 +76,24 @@ describe('ClassifierStrategy', () => {
7476
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
7577
});
7678

79+
it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
80+
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
81+
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
82+
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
83+
reasoning: 'test',
84+
model_choice: 'flash',
85+
});
86+
87+
const decision = await strategy.route(
88+
mockContext,
89+
mockConfig,
90+
mockBaseLlmClient,
91+
);
92+
93+
expect(decision).not.toBeNull();
94+
expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
95+
});
96+
7797
it('should call generateJson with the correct parameters', async () => {
7898
const mockApiResponse = {
7999
reasoning: 'Simple task',

packages/core/src/routing/strategies/classifierStrategy.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import type {
1212
RoutingDecision,
1313
RoutingStrategy,
1414
} from '../routingStrategy.js';
15-
import { resolveClassifierModel } from '../../config/models.js';
15+
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
1616
import { createUserContent, Type } from '@google/genai';
1717
import type { Config } from '../../config/config.js';
1818
import {
@@ -133,7 +133,11 @@ export class ClassifierStrategy implements RoutingStrategy {
133133
): Promise<RoutingDecision | null> {
134134
const startTime = Date.now();
135135
try {
136-
if (await config.getNumericalRoutingEnabled()) {
136+
const model = context.requestedModel ?? config.getModel();
137+
if (
138+
(await config.getNumericalRoutingEnabled()) &&
139+
isGemini3Model(model, config.getPreviewFeatures())
140+
) {
137141
return null;
138142
}
139143

@@ -164,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
164168
const reasoning = routerResponse.reasoning;
165169
const latencyMs = Date.now() - startTime;
166170
const selectedModel = resolveClassifierModel(
167-
context.requestedModel ?? config.getModel(),
171+
model,
168172
routerResponse.model_choice,
169173
config.getPreviewFeatures(),
170174
);

packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ import type { RoutingContext } from '../routingStrategy.js';
1010
import type { Config } from '../../config/config.js';
1111
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
1212
import {
13-
DEFAULT_GEMINI_FLASH_MODEL,
14-
DEFAULT_GEMINI_MODEL,
13+
PREVIEW_GEMINI_FLASH_MODEL,
14+
PREVIEW_GEMINI_MODEL,
1515
DEFAULT_GEMINI_MODEL_AUTO,
16+
DEFAULT_GEMINI_MODEL,
17+
PREVIEW_GEMINI_MODEL_AUTO,
1618
} from '../../config/models.js';
1719
import { promptIdContext } from '../../utils/promptIdContext.js';
1820
import type { Content } from '@google/genai';
@@ -46,8 +48,8 @@ describe('NumericalClassifierStrategy', () => {
4648
modelConfigService: {
4749
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
4850
},
49-
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
50-
getPreviewFeatures: () => false,
51+
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
52+
getPreviewFeatures: vi.fn().mockReturnValue(false),
5153
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
5254
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
5355
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
@@ -76,6 +78,32 @@ describe('NumericalClassifierStrategy', () => {
7678
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
7779
});
7880

81+
it('should return null if the model is not a Gemini 3 model', async () => {
82+
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
83+
84+
const decision = await strategy.route(
85+
mockContext,
86+
mockConfig,
87+
mockBaseLlmClient,
88+
);
89+
90+
expect(decision).toBeNull();
91+
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
92+
});
93+
94+
it('should return null if the model is explicitly a Gemini 2 model', async () => {
95+
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL);
96+
97+
const decision = await strategy.route(
98+
mockContext,
99+
mockConfig,
100+
mockBaseLlmClient,
101+
);
102+
103+
expect(decision).toBeNull();
104+
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
105+
});
106+
79107
it('should call generateJson with the correct parameters and wrapped user content', async () => {
80108
const mockApiResponse = {
81109
complexity_reasoning: 'Simple task',
@@ -120,7 +148,7 @@ describe('NumericalClassifierStrategy', () => {
120148
);
121149

122150
expect(decision).toEqual({
123-
model: DEFAULT_GEMINI_FLASH_MODEL,
151+
model: PREVIEW_GEMINI_FLASH_MODEL,
124152
metadata: {
125153
source: 'NumericalClassifier (Control)',
126154
latencyMs: expect.any(Number),
@@ -146,7 +174,7 @@ describe('NumericalClassifierStrategy', () => {
146174
);
147175

148176
expect(decision).toEqual({
149-
model: DEFAULT_GEMINI_MODEL,
177+
model: PREVIEW_GEMINI_MODEL,
150178
metadata: {
151179
source: 'NumericalClassifier (Control)',
152180
latencyMs: expect.any(Number),
@@ -172,7 +200,7 @@ describe('NumericalClassifierStrategy', () => {
172200
);
173201

174202
expect(decision).toEqual({
175-
model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
203+
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
176204
metadata: {
177205
source: 'NumericalClassifier (Strict)',
178206
latencyMs: expect.any(Number),
@@ -198,7 +226,7 @@ describe('NumericalClassifierStrategy', () => {
198226
);
199227

200228
expect(decision).toEqual({
201-
model: DEFAULT_GEMINI_MODEL,
229+
model: PREVIEW_GEMINI_MODEL,
202230
metadata: {
203231
source: 'NumericalClassifier (Strict)',
204232
latencyMs: expect.any(Number),
@@ -226,7 +254,7 @@ describe('NumericalClassifierStrategy', () => {
226254
);
227255

228256
expect(decision).toEqual({
229-
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
257+
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
230258
metadata: {
231259
source: 'NumericalClassifier (Remote)',
232260
latencyMs: expect.any(Number),
@@ -252,7 +280,7 @@ describe('NumericalClassifierStrategy', () => {
252280
);
253281

254282
expect(decision).toEqual({
255-
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
283+
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
256284
metadata: {
257285
source: 'NumericalClassifier (Remote)',
258286
latencyMs: expect.any(Number),
@@ -278,7 +306,7 @@ describe('NumericalClassifierStrategy', () => {
278306
);
279307

280308
expect(decision).toEqual({
281-
model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30
309+
model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30
282310
metadata: {
283311
source: 'NumericalClassifier (Remote)',
284312
latencyMs: expect.any(Number),
@@ -306,7 +334,7 @@ describe('NumericalClassifierStrategy', () => {
306334
);
307335

308336
expect(decision).toEqual({
309-
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
337+
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
310338
metadata: {
311339
source: 'NumericalClassifier (Control)',
312340
latencyMs: expect.any(Number),
@@ -333,7 +361,7 @@ describe('NumericalClassifierStrategy', () => {
333361
);
334362

335363
expect(decision).toEqual({
336-
model: DEFAULT_GEMINI_FLASH_MODEL,
364+
model: PREVIEW_GEMINI_FLASH_MODEL,
337365
metadata: {
338366
source: 'NumericalClassifier (Control)',
339367
latencyMs: expect.any(Number),
@@ -360,7 +388,7 @@ describe('NumericalClassifierStrategy', () => {
360388
);
361389

362390
expect(decision).toEqual({
363-
model: DEFAULT_GEMINI_MODEL,
391+
model: PREVIEW_GEMINI_MODEL,
364392
metadata: {
365393
source: 'NumericalClassifier (Control)',
366394
latencyMs: expect.any(Number),

packages/core/src/routing/strategies/numericalClassifierStrategy.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import type {
1212
RoutingDecision,
1313
RoutingStrategy,
1414
} from '../routingStrategy.js';
15-
import { resolveClassifierModel } from '../../config/models.js';
15+
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
1616
import { createUserContent, Type } from '@google/genai';
1717
import type { Config } from '../../config/config.js';
1818
import { debugLogger } from '../../utils/debugLogger.js';
@@ -134,10 +134,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
134134
): Promise<RoutingDecision | null> {
135135
const startTime = Date.now();
136136
try {
137+
const model = context.requestedModel ?? config.getModel();
137138
if (!(await config.getNumericalRoutingEnabled())) {
138139
return null;
139140
}
140141

142+
if (!isGemini3Model(model, config.getPreviewFeatures())) {
143+
return null;
144+
}
145+
141146
const promptId = getPromptIdWithFallback('classifier-router');
142147

143148
const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
@@ -177,7 +182,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
177182
);
178183

179184
const selectedModel = resolveClassifierModel(
180-
config.getModel(),
185+
model,
181186
modelAlias,
182187
config.getPreviewFeatures(),
183188
);

0 commit comments

Comments
 (0)