From 532357507e7d377297cd639d66f12a5fe698c4c5 Mon Sep 17 00:00:00 2001 From: bplunkett-stripe Date: Tue, 6 Jan 2026 23:22:40 -0800 Subject: [PATCH] Table priv support --- README.md | 1 - cmd/pg-schema-diff/apply_cmd.go | 2 +- .../privilege_cases_test.go | 230 ++++++++++++++++++ internal/queries/queries.sql | 42 ++++ internal/queries/queries.sql.go | 80 ++++++ internal/schema/schema.go | 77 +++++- internal/schema/schema_test.go | 12 +- pkg/diff/plan.go | 4 + pkg/diff/plan_generator.go | 24 +- pkg/diff/privilege_sql_generator.go | 92 +++++++ pkg/diff/sql_generator.go | 44 ++++ scripts/lint/multiline_sql_strings_lint.go | 6 +- 12 files changed, 607 insertions(+), 7 deletions(-) create mode 100644 internal/migration_acceptance_tests/privilege_cases_test.go create mode 100644 pkg/diff/privilege_sql_generator.go diff --git a/README.md b/README.md index 8f136d8..67a8aba 100644 --- a/README.md +++ b/README.md @@ -204,7 +204,6 @@ Unsupported: <= 13 are not supported. Use at your own risk. # Unsupported migrations An abridged list of unsupported migrations: -- Privileges (Planned) - Types (Only enums are currently supported) - Renaming. The diffing library relies on names to identify the old and new versions of a table, index, etc. If you rename an object, it will be treated as a drop and an add diff --git a/cmd/pg-schema-diff/apply_cmd.go b/cmd/pg-schema-diff/apply_cmd.go index 4d8876b..c7cbbc8 100644 --- a/cmd/pg-schema-diff/apply_cmd.go +++ b/cmd/pg-schema-diff/apply_cmd.go @@ -150,7 +150,7 @@ func runPlan(ctx context.Context, cmd *cobra.Command, connConfig *pgx.ConnConfig return fmt.Errorf("setting lock timeout: %w", err) } if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil { - return fmt.Errorf("executing migration statement. the database maybe be in a dirty state: %s: %w", stmt, err) + return fmt.Errorf("executing migration statement. the database maybe be in a dirty state: %s: %w", stmt.DDL, err) } cmdPrintf(cmd, "Finished executing statement. Duration: %s\n", time.Since(start)) } diff --git a/internal/migration_acceptance_tests/privilege_cases_test.go b/internal/migration_acceptance_tests/privilege_cases_test.go new file mode 100644 index 0000000..27bc3f8 --- /dev/null +++ b/internal/migration_acceptance_tests/privilege_cases_test.go @@ -0,0 +1,230 @@ +package migration_acceptance_tests + +import ( + "testing" + + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var privilegeAcceptanceTestCases = []acceptanceTestCase{ + { + name: "no-op", + roles: []string{ + "app_user", + }, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user; + `, + }, + expectEmptyPlan: true, + }, + { + name: "Grant multiple privileges to role", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + `CREATE TABLE foobar(id INT);`, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT, INSERT, UPDATE, DELETE ON foobar TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Revoke privilege from role", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user; + `, + }, + newSchemaDDL: []string{ + `CREATE TABLE foobar(id INT);`, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Grant WITH GRANT OPTION", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + `CREATE TABLE foobar(id INT);`, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user WITH GRANT OPTION; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Change GRANT OPTION (recreates privilege)", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user WITH GRANT OPTION; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Remove GRANT OPTION (recreates privilege)", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user WITH GRANT OPTION; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Grant on new table (no hazards since table is new)", + roles: []string{"app_user"}, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user; + `, + }, + // No hazards expected since table is brand new + }, + { + name: "Drop table with privileges (only DeletesData hazard)", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + GRANT SELECT ON foobar TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Grant on non-public schema table", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE SCHEMA app_schema; + CREATE TABLE app_schema.foobar(id INT); + `, + }, + newSchemaDDL: []string{ + ` + CREATE SCHEMA app_schema; + CREATE TABLE app_schema.foobar(id INT); + GRANT SELECT ON app_schema.foobar TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Grant on partitioned parent table", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category_1'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category_1'); + GRANT SELECT ON foobar TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Privilege on new partition (not implemented)", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category'); + GRANT SELECT ON foobar_1 TO app_user; + `, + }, + expectedPlanErrorIs: diff.ErrNotImplemented, + }, + { + name: "Add privilege on existing partition (not implemented)", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category'); + GRANT SELECT ON foobar_1 TO app_user; + `, + }, + expectedPlanErrorIs: diff.ErrNotImplemented, + }, +} + +func TestPrivilegeCases(t *testing.T) { + runTestCases(t, privilegeAcceptanceTestCases) +} diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index 9316e07..d8397c8 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -624,3 +624,45 @@ WHERE AND depend.objid = c.oid AND depend.deptype = 'e' ); + +-- name: GetTablePrivileges :many +WITH parsed_acl AS ( + SELECT + c.oid AS table_oid, + c.relname AS table_name, + n.nspname AS table_schema_name, + c.relowner AS owner_oid, + (ACLEXPLODE(c.relacl)).grantee AS grantee_oid, + (ACLEXPLODE(c.relacl)).privilege_type AS privilege_type, + (ACLEXPLODE(c.relacl)).is_grantable AS is_grantable + FROM pg_catalog.pg_class AS c + INNER JOIN pg_catalog.pg_namespace AS n ON c.relnamespace = n.oid + WHERE + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname !~ '^pg_toast' + AND n.nspname !~ '^pg_temp' + AND (c.relkind = 'r' OR c.relkind = 'p') + AND c.relacl IS NOT null + -- Exclude tables owned by extensions + AND NOT EXISTS ( + SELECT depend.objid + FROM pg_catalog.pg_depend AS depend + WHERE + depend.classid = 'pg_class'::REGCLASS + AND depend.objid = c.oid + AND depend.deptype = 'e' + ) +) + +SELECT + pa.table_name::TEXT, + pa.table_schema_name::TEXT, + COALESCE(grantee_role.rolname, '')::TEXT AS grantee, + pa.privilege_type::TEXT AS privilege, + pa.is_grantable +FROM parsed_acl AS pa +LEFT JOIN pg_catalog.pg_roles AS grantee_role + ON pa.grantee_oid = grantee_role.oid +-- Exclude privileges granted to the table owner (these are implicit) +WHERE pa.grantee_oid != pa.owner_oid OR pa.grantee_oid = 0 +ORDER BY pa.table_schema_name, pa.table_name, grantee, pa.privilege_type; diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index 09ff092..4315102 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -1004,6 +1004,86 @@ func (q *Queries) GetSequences(ctx context.Context) ([]GetSequencesRow, error) { return items, nil } +const getTablePrivileges = `-- name: GetTablePrivileges :many +WITH parsed_acl AS ( + SELECT + c.oid AS table_oid, + c.relname AS table_name, + n.nspname AS table_schema_name, + c.relowner AS owner_oid, + (ACLEXPLODE(c.relacl)).grantee AS grantee_oid, + (ACLEXPLODE(c.relacl)).privilege_type AS privilege_type, + (ACLEXPLODE(c.relacl)).is_grantable AS is_grantable + FROM pg_catalog.pg_class AS c + INNER JOIN pg_catalog.pg_namespace AS n ON c.relnamespace = n.oid + WHERE + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname !~ '^pg_toast' + AND n.nspname !~ '^pg_temp' + AND (c.relkind = 'r' OR c.relkind = 'p') + AND c.relacl IS NOT null + -- Exclude tables owned by extensions + AND NOT EXISTS ( + SELECT depend.objid + FROM pg_catalog.pg_depend AS depend + WHERE + depend.classid = 'pg_class'::REGCLASS + AND depend.objid = c.oid + AND depend.deptype = 'e' + ) +) + +SELECT + pa.table_name::TEXT, + pa.table_schema_name::TEXT, + COALESCE(grantee_role.rolname, '')::TEXT AS grantee, + pa.privilege_type::TEXT AS privilege, + pa.is_grantable +FROM parsed_acl AS pa +LEFT JOIN pg_catalog.pg_roles AS grantee_role + ON pa.grantee_oid = grantee_role.oid +WHERE pa.grantee_oid != pa.owner_oid OR pa.grantee_oid = 0 +ORDER BY pa.table_schema_name, pa.table_name, grantee, pa.privilege_type +` + +type GetTablePrivilegesRow struct { + PaTableName string + PaTableSchemaName string + Grantee string + Privilege string + IsGrantable interface{} +} + +// Exclude privileges granted to the table owner (these are implicit) +func (q *Queries) GetTablePrivileges(ctx context.Context) ([]GetTablePrivilegesRow, error) { + rows, err := q.db.QueryContext(ctx, getTablePrivileges) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetTablePrivilegesRow + for rows.Next() { + var i GetTablePrivilegesRow + if err := rows.Scan( + &i.PaTableName, + &i.PaTableSchemaName, + &i.Grantee, + &i.Privilege, + &i.IsGrantable, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getTables = `-- name: GetTables :many SELECT c.oid, diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 3fc16e9..b3945d5 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -132,6 +132,9 @@ func normalizeTable(t Table) Table { normPolicies = append(normPolicies, p) } t.Policies = normPolicies + + t.Privileges = sortSchemaObjectsByName(t.Privileges) + return t } @@ -214,6 +217,7 @@ type Table struct { Columns []Column CheckConstraints []CheckConstraint Policies []Policy + Privileges []TablePrivilege ReplicaIdentity ReplicaIdentity RLSEnabled bool RLSForced bool @@ -238,6 +242,24 @@ func (t Table) IsPartition() bool { return t.ParentTable != nil } +// TablePrivilege represents a privilege granted on a table +type TablePrivilege struct { + // Grantee is the role that has the privilege. Empty string means PUBLIC. + Grantee string + // Privilege is the type of privilege (SELECT, INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER) + Privilege string + // IsGrantable indicates if the grantee can grant this privilege to others (WITH GRANT OPTION) + IsGrantable bool +} + +func (p TablePrivilege) GetName() string { + grantee := p.Grantee + if grantee == "" { + grantee = "PUBLIC" + } + return fmt.Sprintf("%s:%s", grantee, p.Privilege) +} + type ColumnIdentityType string const ( @@ -926,12 +948,21 @@ func (s *schemaFetcher) fetchTables(ctx context.Context) ([]Table, error) { policiesByTable[p.table.GetFQEscapedName()] = append(policiesByTable[p.table.GetFQEscapedName()], p.policy) } + privileges, err := s.fetchPrivileges(ctx) + if err != nil { + return nil, fmt.Errorf("fetchPrivileges(): %w", err) + } + privilegesByTable := make(map[string][]TablePrivilege) + for _, p := range privileges { + privilegesByTable[p.table.GetFQEscapedName()] = append(privilegesByTable[p.table.GetFQEscapedName()], p.privilege) + } + goroutineRunner := s.goroutineRunnerFactory() var tableFutures []concurrent.Future[Table] for _, _rawTable := range rawTables { rawTable := _rawTable // Capture loop variables for go routine tableFuture, err := concurrent.SubmitFuture(ctx, goroutineRunner, func() (Table, error) { - return s.buildTable(ctx, rawTable, checkConsByTable, policiesByTable) + return s.buildTable(ctx, rawTable, checkConsByTable, policiesByTable, privilegesByTable) }) if err != nil { return nil, fmt.Errorf("starting table future: %w", err) @@ -959,6 +990,7 @@ func (s *schemaFetcher) buildTable( table queries.GetTablesRow, checkConsByTable map[string][]CheckConstraint, policiesByTable map[string][]Policy, + privilegesByTable map[string][]TablePrivilege, ) (Table, error) { rawColumns, err := s.q.GetColumnsForTable(ctx, table.Oid) if err != nil { @@ -1023,6 +1055,7 @@ func (s *schemaFetcher) buildTable( Columns: columns, CheckConstraints: checkConsByTable[schemaQualifiedName.GetFQEscapedName()], Policies: policiesByTable[schemaQualifiedName.GetFQEscapedName()], + Privileges: privilegesByTable[schemaQualifiedName.GetFQEscapedName()], ReplicaIdentity: ReplicaIdentity(table.ReplicaIdentity), RLSEnabled: table.RlsEnabled, RLSForced: table.RlsForced, @@ -1343,6 +1376,11 @@ type policyAndTable struct { table SchemaQualifiedName } +type privilegeAndTable struct { + privilege TablePrivilege + table SchemaQualifiedName +} + func (s *schemaFetcher) fetchPolicies(ctx context.Context) ([]policyAndTable, error) { rawPolicies, err := s.q.GetPolicies(ctx) if err != nil { @@ -1379,6 +1417,43 @@ func (s *schemaFetcher) fetchPolicies(ctx context.Context) ([]policyAndTable, er return policies, nil } +func (s *schemaFetcher) fetchPrivileges(ctx context.Context) ([]privilegeAndTable, error) { + rawPrivileges, err := s.q.GetTablePrivileges(ctx) + if err != nil { + return nil, fmt.Errorf("GetTablePrivileges: %w", err) + } + + var privileges []privilegeAndTable + for _, rp := range rawPrivileges { + // Handle the is_grantable field which may be returned as interface{} + isGrantable := false + if rp.IsGrantable != nil { + if b, ok := rp.IsGrantable.(bool); ok { + isGrantable = b + } + } + + privileges = append(privileges, privilegeAndTable{ + privilege: TablePrivilege{ + Grantee: rp.Grantee, + Privilege: rp.Privilege, + IsGrantable: isGrantable, + }, + table: buildNameFromUnescaped(rp.PaTableName, rp.PaTableSchemaName), + }) + } + + privileges = filterSliceByName( + privileges, + func(p privilegeAndTable) SchemaQualifiedName { + return p.table + }, + s.nameFilter, + ) + + return privileges, nil +} + func (s *schemaFetcher) fetchTriggers(ctx context.Context) ([]Trigger, error) { rawTriggers, err := s.q.GetTriggers(ctx) if err != nil { diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 7f98e3b..c73bb17 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -234,8 +234,12 @@ var ( -- Add a column with a default to test HasMissingValOptimization ALTER TABLE schema_2.foo ADD COLUMN added_col TEXT DEFAULT 'some_default'; + + -- Add table privileges to test they are fetched correctly + GRANT SELECT ON schema_2.foo TO some_role_1; + GRANT INSERT ON schema_2.foo TO some_role_2 WITH GRANT OPTION; `}, - expectedHash: "fdff644bbabb9fc", + expectedHash: "4c2174e2cac3956b", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -311,6 +315,10 @@ var ( Columns: []string{"version"}, }, }, + Privileges: []TablePrivilege{ + {Grantee: "some_role_2", Privilege: "INSERT", IsGrantable: true}, + {Grantee: "some_role_1", Privilege: "SELECT", IsGrantable: false}, + }, ReplicaIdentity: ReplicaIdentityIndex, RLSEnabled: true, }, @@ -583,7 +591,7 @@ var ( ALTER TABLE foo_fk_1 ADD CONSTRAINT foo_fk_1_fk FOREIGN KEY (author, content) REFERENCES foo_1 (author, content) NOT VALID; `}, - expectedHash: "301808413c59ab76", + expectedHash: "32c5a9c52dcfb15e", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, diff --git a/pkg/diff/plan.go b/pkg/diff/plan.go index cb7c5ef..78db0c4 100644 --- a/pkg/diff/plan.go +++ b/pkg/diff/plan.go @@ -46,6 +46,10 @@ type Statement struct { LockTimeout time.Duration // The hazards this statement poses Hazards []MigrationHazard + // SkipValidation indicates that this statement should be skipped during plan validation against a temporary + // database instance. This is useful for statements that depend on entities (like roles) that won't exist + // in the temp DB. + SkipValidation bool } func (s Statement) MarshalJSON() ([]byte, error) { diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go index f5d8484..2c2c92b 100644 --- a/pkg/diff/plan_generator.go +++ b/pkg/diff/plan_generator.go @@ -279,7 +279,24 @@ func schemaFromTempDb(ctx context.Context, db *tempdb.Database, plan *planOption return schema.GetSchema(ctx, db.ConnPool, append(plan.getSchemaOpts, db.ExcludeMetadataOptions...)...) } +// clearTablePrivileges returns a copy of the schema with all table privileges cleared. +// This is used during plan validation because privilege statements are skipped (roles don't exist in temp DB). +func clearTablePrivileges(s schema.Schema) schema.Schema { + tables := make([]schema.Table, len(s.Tables)) + for i, t := range s.Tables { + t.Privileges = nil + tables[i] = t + } + s.Tables = tables + return s +} + func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schema, planOptions *planOptions) error { + // Clear privileges from both schemas since privilege statements are skipped during validation + // (roles don't exist in temp DB). We make copies to avoid modifying the original schemas. + migratedSchema = clearTablePrivileges(migratedSchema) + targetSchema = clearTablePrivileges(targetSchema) + toTargetSchemaStmts, err := generateMigrationStatements(migratedSchema, targetSchema, planOptions) if err != nil { return fmt.Errorf("building schema diff between migrated database and new schema: %w", err) @@ -316,8 +333,13 @@ func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, stat // must be executed within its own transaction block. Postgres will error if you try to set a TRANSACTION-level // timeout for it. SESSION-level statement_timeouts are respected by `ADD INDEX CONCURRENTLY` for _, stmt := range statements { + if stmt.SkipValidation { + // Skip statements that cannot be validated in temp DB (e.g., GRANT/REVOKE which reference roles + // that don't exist in the temp DB) + continue + } if _, err := conn.ExecContext(ctx, stmt.ToSQL()); err != nil { - return fmt.Errorf("executing migration statement: %s: %w", stmt, err) + return fmt.Errorf("executing migration statement: %s: %w", stmt.DDL, err) } } return nil diff --git a/pkg/diff/privilege_sql_generator.go b/pkg/diff/privilege_sql_generator.go new file mode 100644 index 0000000..d851ca7 --- /dev/null +++ b/pkg/diff/privilege_sql_generator.go @@ -0,0 +1,92 @@ +package diff + +import ( + "fmt" + + "github.com/stripe/pg-schema-diff/internal/schema" +) + +var ( + migrationHazardPrivilegeGranted = MigrationHazard{ + Type: MigrationHazardTypeAuthzUpdate, + Message: "Granting privileges could allow unauthorized access to data.", + } + migrationHazardPrivilegeRevoked = MigrationHazard{ + Type: MigrationHazardTypeAuthzUpdate, + Message: "Revoking privileges could cause queries to fail if not correctly configured.", + } +) + +type privilegeSQLVertexGenerator struct { + tableName schema.SchemaQualifiedName +} + +func newPrivilegeSQLVertexGenerator(tableName schema.SchemaQualifiedName) sqlVertexGenerator[schema.TablePrivilege, privilegeDiff] { + return legacyToNewSqlVertexGenerator[schema.TablePrivilege, privilegeDiff](&privilegeSQLVertexGenerator{ + tableName: tableName, + }) +} + +func (psg *privilegeSQLVertexGenerator) Add(p schema.TablePrivilege) ([]Statement, error) { + grantee := p.Grantee + if grantee == "" { + grantee = "PUBLIC" + } else { + grantee = schema.EscapeIdentifier(grantee) + } + + ddl := fmt.Sprintf("GRANT %s ON %s TO %s", p.Privilege, psg.tableName.GetFQEscapedName(), grantee) + if p.IsGrantable { + ddl += " WITH GRANT OPTION" + } + + return []Statement{{ + DDL: ddl, + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardPrivilegeGranted}, + SkipValidation: true, + }}, nil +} + +func (psg *privilegeSQLVertexGenerator) Delete(p schema.TablePrivilege) ([]Statement, error) { + grantee := p.Grantee + if grantee == "" { + grantee = "PUBLIC" + } else { + grantee = schema.EscapeIdentifier(grantee) + } + + ddl := fmt.Sprintf("REVOKE %s ON %s FROM %s", p.Privilege, psg.tableName.GetFQEscapedName(), grantee) + + return []Statement{{ + DDL: ddl, + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardPrivilegeRevoked}, + SkipValidation: true, + }}, nil +} + +func (psg *privilegeSQLVertexGenerator) Alter(diff privilegeDiff) ([]Statement, error) { + // Privileges don't support ALTER - if IsGrantable changes, we need to recreate + // (handled via requiresRecreation in buildTableDiff) + // This should not normally be called since only IsGrantable can change and that + // triggers recreation. + return nil, nil +} + +func (psg *privilegeSQLVertexGenerator) GetSQLVertexId(p schema.TablePrivilege, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("privilege", fmt.Sprintf("%s.%s", psg.tableName.GetFQEscapedName(), p.GetName()), diffType) +} + +func (psg *privilegeSQLVertexGenerator) GetAddAlterDependencies(newPriv, _ schema.TablePrivilege) ([]dependency, error) { + // Ensure delete runs before add/alter (for recreate scenarios) + return []dependency{ + mustRun(psg.GetSQLVertexId(newPriv, diffTypeDelete)).before(psg.GetSQLVertexId(newPriv, diffTypeAddAlter)), + }, nil +} + +func (psg *privilegeSQLVertexGenerator) GetDeleteDependencies(_ schema.TablePrivilege) ([]dependency, error) { + return nil, nil +} diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index e28d948..6c31e22 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -108,11 +108,16 @@ type ( oldAndNew[schema.CheckConstraint] } + privilegeDiff struct { + oldAndNew[schema.TablePrivilege] + } + tableDiff struct { oldAndNew[schema.Table] columnsDiff listDiff[schema.Column, columnDiff] checkConstraintDiff listDiff[schema.CheckConstraint, checkConstraintDiff] policiesDiff listDiff[schema.Policy, policyDiff] + privilegesDiff listDiff[schema.TablePrivilege, privilegeDiff] } indexDiff struct { @@ -424,6 +429,19 @@ func buildTableDiff(oldTable, newTable schema.Table, _, _ int) (diff tableDiff, } + privilegesDiff, err := diffLists( + oldTable.Privileges, + newTable.Privileges, + func(old, new schema.TablePrivilege, _, _ int) (privilegeDiff, bool, error) { + // Recreate the privilege if IsGrantable changes + recreate := old.IsGrantable != new.IsGrantable + return privilegeDiff{oldAndNew[schema.TablePrivilege]{old: old, new: new}}, recreate, nil + }, + ) + if err != nil { + return tableDiff{}, false, fmt.Errorf("diffing privileges: %w", err) + } + return tableDiff{ oldAndNew: oldAndNew[schema.Table]{ old: oldTable, @@ -432,6 +450,7 @@ func buildTableDiff(oldTable, newTable schema.Table, _, _ int) (diff tableDiff, columnsDiff: columnsDiff, checkConstraintDiff: checkConsDiff, policiesDiff: policiesDiff, + privilegesDiff: privilegesDiff, }, false, nil } @@ -790,6 +809,9 @@ func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { if len(table.Policies) > 0 { return nil, fmt.Errorf("policies on partitions: %w", ErrNotImplemented) } + if len(table.Privileges) > 0 { + return nil, fmt.Errorf("privileges on partitions: %w", ErrNotImplemented) + } // We attach the partitions separately. So the partition must have all the same check constraints // as the original table table.CheckConstraints = append(table.CheckConstraints, t.tablesInNewSchemaByName[table.ParentTable.GetName()].CheckConstraints...) @@ -863,6 +885,16 @@ func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { stmts = append(stmts, stripMigrationHazards(forceRLSForTable(table))...) } + privilegeGenerator := &privilegeSQLVertexGenerator{tableName: table.SchemaQualifiedName} + for _, privilege := range table.Privileges { + addPrivilegeStmts, err := privilegeGenerator.Add(privilege) + if err != nil { + return nil, fmt.Errorf("generating add privilege statements for privilege %s: %w", privilege.GetName(), err) + } + // Remove hazards from statements since the table is brand new + stmts = append(stmts, stripMigrationHazards(addPrivilegeStmts...)...) + } + return stmts, nil } @@ -1003,6 +1035,13 @@ func (t *tableSQLVertexGenerator) alterBaseTable(diff tableDiff) ([]Statement, e } partialGraph = concatPartialGraphs(partialGraph, policiesPartialGraph) + privilegeGenerator := newPrivilegeSQLVertexGenerator(diff.new.SchemaQualifiedName) + privilegesPartialGraph, err := generatePartialGraph(privilegeGenerator, diff.privilegesDiff) + if err != nil { + return nil, fmt.Errorf("resolving privilege sql: %w", err) + } + partialGraph = concatPartialGraphs(partialGraph, privilegesPartialGraph) + graph, err := graphFromPartials(partialGraph) if err != nil { return nil, fmt.Errorf("converting to graph") @@ -1033,6 +1072,11 @@ func (t *tableSQLVertexGenerator) alterPartition(diff tableDiff) ([]Statement, e // _independent_ of how it is ordered. return nil, fmt.Errorf("policies on partitions: %w", ErrNotImplemented) } + if !diff.privilegesDiff.isEmpty() { + // Privilege diffing on individual partitions cannot be supported until where a SQL statement is generated is + // _independent_ of how it is ordered. + return nil, fmt.Errorf("privileges on partitions: %w", ErrNotImplemented) + } var alteredParentColumnsByName map[string]columnDiff if parentDiff, ok := t.tableDiffsByName[diff.new.ParentTable.GetName()]; ok { diff --git a/scripts/lint/multiline_sql_strings_lint.go b/scripts/lint/multiline_sql_strings_lint.go index 169c02c..821e569 100644 --- a/scripts/lint/multiline_sql_strings_lint.go +++ b/scripts/lint/multiline_sql_strings_lint.go @@ -119,7 +119,11 @@ func processFile(filePath string, fix bool) (bool, error) { } if err := writer.Flush(); err != nil { - return false, fmt.Errorf("lush buffer: %w", err) + return false, fmt.Errorf("flush buffer: %w", err) + } + + if !fix { + return changesRequired, nil } return changesRequired, os.Rename(tempFile.Name(), filePath)