Skip to content

Commit 0c252f4

Browse files
authored
Merge branch 'main' into feat/storybook-ci-tests
2 parents 3387520 + 064da0f commit 0c252f4

File tree

14 files changed

+346
-19
lines changed

14 files changed

+346
-19
lines changed

.github/workflows/general.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,8 +734,8 @@ jobs:
734734
# the main 'validate' job)
735735
runs-on: ${{ matrix.replicated && 'namespace-profile-tensorzero-8x16;ephemeral-storage.size-multiplier=2' || 'namespace-profile-tensorzero-8x16' }}
736736
continue-on-error: ${{ matrix.clickhouse_version.allow_failure }}
737-
# This needs to pull from Docker Hub, so skip for external PR CI
738-
if: ${{ (github.event.pull_request.head.repo.full_name == github.repository) || inputs.is_merge_group }}
737+
# This needs to pull from Docker Hub, so only run in the merge queue, or when running on a PR from the main repository
738+
if: ${{ github.event_name == 'merge_group' || (github.event.pull_request.head.repo.full_name == github.repository) }}
739739
strategy:
740740
matrix:
741741
# Only include replicated: true when running in merge queue
@@ -1185,7 +1185,11 @@ jobs:
11851185
]
11861186
runs-on: ubuntu-latest
11871187
steps:
1188+
- name: Print all job results
1189+
run: |
1190+
echo "'needs': ${{ toJson(needs) }}"
1191+
echo "github.event_name: ${{ github.event_name }}"
11881192
# When running in the merge queue, jobs should never be skipped.
11891193
# In PR CI, some jobs may be intentionally skipped (e.g. due to running from a fork, or to save money)
1190-
- if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled' || (github.event_name == 'merge_group' && contains(needs.*.result, 'skipped'))) }}
1194+
- if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || (github.event_name == 'merge_group' && contains(needs.*.result, 'skipped')) }}
11911195
run: exit 1

clients/rust/src/lib.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pub use tensorzero_core::endpoints::inference::{
7171
ChatCompletionInferenceParams, InferenceOutput, InferenceParams, InferenceResponse,
7272
InferenceResponseChunk, InferenceStream,
7373
};
74+
pub use tensorzero_core::endpoints::internal::config::GetConfigResponse;
7475
pub use tensorzero_core::endpoints::object_storage::ObjectResponse;
7576
pub use tensorzero_core::endpoints::stored_inferences::v1::types::{
7677
GetInferencesRequest, GetInferencesResponse, ListInferencesRequest,
@@ -467,6 +468,24 @@ pub trait ClientExt {
467468

468469
fn get_config(&self) -> Result<Arc<Config>, TensorZeroError>;
469470

471+
/// Gets a config snapshot by hash, or the live config if no hash is provided.
472+
///
473+
/// # Arguments
474+
///
475+
/// * `hash` - Optional hash of the config snapshot to retrieve. If `None`, returns the live config.
476+
///
477+
/// # Returns
478+
///
479+
/// A `GetConfigResponse` containing the config snapshot.
480+
///
481+
/// # Errors
482+
///
483+
/// Returns a `TensorZeroError` if the request fails or the config snapshot is not found.
484+
async fn get_config_snapshot(
485+
&self,
486+
hash: Option<&str>,
487+
) -> Result<GetConfigResponse, TensorZeroError>;
488+
470489
#[cfg(any(feature = "e2e_tests", feature = "pyo3"))]
471490
fn get_app_state_data(&self) -> Option<&tensorzero_core::utils::gateway::AppStateData>;
472491
}
@@ -1323,6 +1342,60 @@ impl ClientExt for Client {
13231342
}
13241343
}
13251344

1345+
async fn get_config_snapshot(
1346+
&self,
1347+
hash: Option<&str>,
1348+
) -> Result<GetConfigResponse, TensorZeroError> {
1349+
match self.mode() {
1350+
ClientMode::HTTPGateway(client) => {
1351+
let endpoint = match hash {
1352+
Some(h) => format!("internal/config/{h}"),
1353+
None => "internal/config".to_string(),
1354+
};
1355+
let url = client
1356+
.base_url
1357+
.join(&endpoint)
1358+
.map_err(|e| TensorZeroError::Other {
1359+
source: Error::new(ErrorDetails::InvalidBaseUrl {
1360+
message: format!(
1361+
"Failed to join base URL with /{endpoint} endpoint: {e}"
1362+
),
1363+
})
1364+
.into(),
1365+
})?;
1366+
let builder = client.http_client.get(url);
1367+
Ok(client.send_and_parse_http_response(builder).await?.0)
1368+
}
1369+
ClientMode::EmbeddedGateway { gateway, timeout } => {
1370+
with_embedded_timeout(*timeout, async {
1371+
use tensorzero_core::db::ConfigQueries;
1372+
let snapshot_hash = match hash {
1373+
Some(h) => h.parse().map_err(|_| {
1374+
err_to_http(Error::new(ErrorDetails::ConfigSnapshotNotFound {
1375+
snapshot_hash: h.to_string(),
1376+
}))
1377+
})?,
1378+
None => gateway.handle.app_state.config.hash.clone(),
1379+
};
1380+
let snapshot = gateway
1381+
.handle
1382+
.app_state
1383+
.clickhouse_connection_info
1384+
.get_config_snapshot(snapshot_hash)
1385+
.await
1386+
.map_err(err_to_http)?;
1387+
Ok(GetConfigResponse {
1388+
hash: snapshot.hash.to_string(),
1389+
config: snapshot.config,
1390+
extra_templates: snapshot.extra_templates,
1391+
tags: snapshot.tags,
1392+
})
1393+
})
1394+
.await
1395+
}
1396+
}
1397+
}
1398+
13261399
async fn get_variant_sampling_probabilities(
13271400
&self,
13281401
function_name: &str,

gateway/src/routes/internal.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,13 @@ pub fn build_internal_non_otel_enabled_routes() -> Router<AppStateData> {
9999
"/internal/models/latency",
100100
get(endpoints::internal::models::get_model_latency_handler),
101101
)
102+
// Config snapshot endpoints
103+
.route(
104+
"/internal/config",
105+
get(endpoints::internal::config::get_live_config_handler),
106+
)
107+
.route(
108+
"/internal/config/{hash}",
109+
get(endpoints::internal::config::get_config_by_hash_handler),
110+
)
102111
}

tensorzero-core/src/config/snapshot.rs

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ pub struct ConfigSnapshot {
7676
/// uses the compiled templates in `Config.templates`, not these raw strings.
7777
pub extra_templates: HashMap<String, String>,
7878

79+
/// User-defined tags for categorizing or labeling this config snapshot.
80+
/// Tags are metadata and do not affect the config hash.
81+
pub tags: HashMap<String, String>,
82+
7983
__private: (),
8084
}
8185

@@ -105,7 +109,6 @@ impl SnapshotHash {
105109
}
106110
}
107111

108-
#[cfg(any(test, feature = "e2e_tests"))]
109112
impl std::str::FromStr for SnapshotHash {
110113
type Err = std::convert::Infallible;
111114

@@ -137,11 +140,13 @@ impl ConfigSnapshot {
137140
message: format!("Failed to serialize stored config: {e}"),
138141
})
139142
})?);
143+
140144
let hash = ConfigSnapshot::hash(&stored_config_toml, &extra_templates)?;
141145
Ok(Self {
142146
config: stored_config,
143147
hash,
144148
extra_templates,
149+
tags: HashMap::new(),
145150
__private: (),
146151
})
147152
}
@@ -169,6 +174,7 @@ impl ConfigSnapshot {
169174
config: StoredConfig::default(),
170175
hash: SnapshotHash::new_test(),
171176
extra_templates: HashMap::new(),
177+
tags: HashMap::new(),
172178
__private: (),
173179
}
174180
}
@@ -177,9 +183,14 @@ impl ConfigSnapshot {
177183
///
178184
/// This is used when loading a previously stored config snapshot from ClickHouse.
179185
/// The hash is recomputed from the config and templates to ensure consistency.
186+
///
187+
/// Note: We deserialize as `StoredConfig` (not `UninitializedConfig`) to support
188+
/// backward compatibility with historical snapshots that may contain deprecated
189+
/// fields like `timeouts`.
180190
pub fn from_stored(
181191
config_toml: &str,
182192
extra_templates: HashMap<String, String>,
193+
tags: HashMap<String, String>,
183194
original_hash: &SnapshotHash,
184195
) -> Result<Self, Error> {
185196
let table: toml::Table = config_toml.parse().map_err(|e| {
@@ -189,19 +200,24 @@ impl ConfigSnapshot {
189200
})?;
190201

191202
let sorted_table = prepare_table_for_snapshot(table);
192-
let config = UninitializedConfig::try_from(sorted_table.clone())?;
193-
let hash = ConfigSnapshot::hash(&sorted_table, &extra_templates)?;
194-
if hash != *original_hash {
195-
return Err(Error::new(ErrorDetails::ConfigSnapshotHashMismatch {
196-
expected: original_hash.clone(),
197-
actual: hash.clone(),
198-
}));
199-
}
200203

204+
// Deserialize as StoredConfig to accept deprecated fields (e.g., `timeouts`)
205+
let stored_config: StoredConfig =
206+
serde_path_to_error::deserialize(sorted_table).map_err(|e| {
207+
let path = e.path().clone();
208+
Error::new(ErrorDetails::Config {
209+
message: format!("{}: {}", path, e.into_inner().message()),
210+
})
211+
})?;
212+
213+
// Use the original hash from the database rather than recomputing it.
214+
// Recomputing can produce different hashes due to floating-point serialization
215+
// differences (e.g., 0.2 vs 0.20000000298023224) even when the config is identical.
201216
Ok(Self {
202-
config: config.into(),
203-
hash,
217+
config: stored_config,
218+
hash: original_hash.clone(),
204219
extra_templates,
220+
tags,
205221
__private: (),
206222
})
207223
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
use super::check_column_exists;
2+
use crate::db::clickhouse::ClickHouseConnectionInfo;
3+
use crate::db::clickhouse::migration_manager::migration_trait::Migration;
4+
use crate::error::Error;
5+
use async_trait::async_trait;
6+
7+
const MIGRATION_ID: &str = "0045";
8+
9+
/// This migration adds a `tags` column to the `ConfigSnapshot` table.
10+
pub struct Migration0045<'a> {
11+
pub clickhouse: &'a ClickHouseConnectionInfo,
12+
}
13+
14+
#[async_trait]
15+
impl Migration for Migration0045<'_> {
16+
async fn can_apply(&self) -> Result<(), Error> {
17+
Ok(())
18+
}
19+
20+
async fn should_apply(&self) -> Result<bool, Error> {
21+
Ok(!check_column_exists(self.clickhouse, "ConfigSnapshot", "tags", MIGRATION_ID).await?)
22+
}
23+
24+
async fn apply(&self, _clean_start: bool) -> Result<(), Error> {
25+
let on_cluster_name = self.clickhouse.get_on_cluster_name();
26+
27+
self.clickhouse
28+
.run_query_synchronous_no_params(format!(
29+
"ALTER TABLE ConfigSnapshot{on_cluster_name} ADD COLUMN IF NOT EXISTS tags Map(String, String) DEFAULT map()"
30+
))
31+
.await?;
32+
33+
Ok(())
34+
}
35+
36+
fn rollback_instructions(&self) -> String {
37+
let on_cluster_name = self.clickhouse.get_on_cluster_name();
38+
format!("ALTER TABLE ConfigSnapshot{on_cluster_name} DROP COLUMN tags;")
39+
}
40+
41+
async fn has_succeeded(&self) -> Result<bool, Error> {
42+
Ok(check_column_exists(self.clickhouse, "ConfigSnapshot", "tags", MIGRATION_ID).await?)
43+
}
44+
}

tensorzero-core/src/db/clickhouse/migration_manager/migrations/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub mod migration_0041;
4242
pub mod migration_0042;
4343
pub mod migration_0043;
4444
pub mod migration_0044;
45+
pub mod migration_0045;
4546

4647
/// Returns true if the table exists, false if it does not
4748
/// Errors if the query fails

tensorzero-core/src/db/clickhouse/migration_manager/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ use migrations::migration_0041::Migration0041;
5252
use migrations::migration_0042::Migration0042;
5353
use migrations::migration_0043::Migration0043;
5454
use migrations::migration_0044::Migration0044;
55+
use migrations::migration_0045::Migration0045;
5556
use serde::{Deserialize, Serialize};
5657

5758
/// This must match the number of migrations returned by `make_all_migrations` - the tests
5859
/// will panic if they don't match.
59-
pub const NUM_MIGRATIONS: usize = 38;
60+
pub const NUM_MIGRATIONS: usize = 39;
6061
pub fn get_run_migrations_command() -> String {
6162
let version = env!("CARGO_PKG_VERSION");
6263
format!(
@@ -126,6 +127,7 @@ pub fn make_all_migrations<'a>(
126127
Box::new(Migration0042 { clickhouse }),
127128
Box::new(Migration0043 { clickhouse }),
128129
Box::new(Migration0044 { clickhouse }),
130+
Box::new(Migration0045 { clickhouse }),
129131
];
130132
assert_eq!(
131133
migrations.len(),

tensorzero-core/src/db/clickhouse/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,13 @@ impl ConfigQueries for ClickHouseConnectionInfo {
340340
struct ConfigSnapshotRow {
341341
config: String,
342342
extra_templates: HashMap<String, String>,
343+
#[serde(default)]
344+
tags: HashMap<String, String>,
343345
}
344346

345347
let hash_str = snapshot_hash.to_string();
346348
let query = format!(
347-
"SELECT config, extra_templates \
349+
"SELECT config, extra_templates, tags \
348350
FROM ConfigSnapshot FINAL \
349351
WHERE hash = toUInt256('{hash_str}') \
350352
LIMIT 1 \
@@ -365,7 +367,7 @@ impl ConfigQueries for ClickHouseConnectionInfo {
365367
})
366368
})?;
367369

368-
ConfigSnapshot::from_stored(&row.config, row.extra_templates, &snapshot_hash)
370+
ConfigSnapshot::from_stored(&row.config, row.extra_templates, row.tags, &snapshot_hash)
369371
}
370372
}
371373

tensorzero-core/src/db/postgres/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,6 @@ async fn get_applied_migrations(pool: &PgPool) -> Result<HashSet<i64>, sqlx::Err
259259
Ok(applied_migrations)
260260
}
261261

262-
fn make_migrator() -> sqlx::migrate::Migrator {
262+
pub fn make_migrator() -> sqlx::migrate::Migrator {
263263
migrate!("src/db/postgres/migrations")
264264
}

0 commit comments

Comments
 (0)