Skip to content

Commit 744f773

Browse files
committed
rest: apply columnCast for Prefer: return=representation header
Signed-off-by: hmoazzem <[email protected]>
1 parent df645a4 commit 744f773

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

pkg/rest/query.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,14 @@ func buildInsertQuery(table schema.Table, data map[string]any, headers *Headers)
417417

418418
// Add RETURNING clause only if requested
419419
if headers != nil && headers.Prefer != nil && headers.Prefer.WantsRepresentation() {
420-
query.WriteString(" RETURNING *")
420+
query.WriteString(" RETURNING ")
421+
422+
returningCols := make([]string, 0, len(table.Columns))
423+
for _, col := range table.Columns {
424+
returningCols = append(returningCols, columnCastExpression(col.Name, col.DataType))
425+
}
426+
427+
query.WriteString(strings.Join(returningCols, ", "))
421428
}
422429

423430
return query.String(), args, nil
@@ -514,7 +521,14 @@ func buildUpdateQuery(table schema.Table, data map[string]any, params QueryParam
514521

515522
// Add RETURNING clause only if requested (to return the updated row/s)
516523
if headers != nil && headers.Prefer != nil && headers.Prefer.WantsRepresentation() {
517-
query.WriteString(" RETURNING *")
524+
query.WriteString(" RETURNING ")
525+
526+
returningCols := make([]string, 0, len(table.Columns))
527+
for _, col := range table.Columns {
528+
returningCols = append(returningCols, columnCastExpression(col.Name, col.DataType))
529+
}
530+
531+
query.WriteString(strings.Join(returningCols, ", "))
518532
}
519533

520534
return query.String(), args, nil
@@ -585,7 +599,14 @@ func buildDeleteQuery(table schema.Table, params QueryParams, headers *Headers)
585599
}
586600
// Add RETURNING clause only if requested (to return the deleted row/s)
587601
if headers != nil && headers.Prefer != nil && headers.Prefer.WantsRepresentation() {
588-
query.WriteString(" RETURNING *")
602+
query.WriteString(" RETURNING ")
603+
604+
returningCols := make([]string, 0, len(table.Columns))
605+
for _, col := range table.Columns {
606+
returningCols = append(returningCols, columnCastExpression(col.Name, col.DataType))
607+
}
608+
609+
query.WriteString(strings.Join(returningCols, ", "))
589610
}
590611

591612
return query.String(), args, nil

0 commit comments

Comments
 (0)