Skip to content

Commit a0aa21a

Browse files
committed
fix: change temperature type from int to *float32
1 parent b1e93d1 commit a0aa21a

File tree

4 files changed

+59
-5
lines changed

4 files changed

+59
-5
lines changed

request/request.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type ChatCompletionsRequest struct {
2323
Stop []string `json:"stop,omitempty"`
2424
Stream bool `json:"stream,omitempty"`
2525
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
26-
Temperature int `json:"temperature,omitempty"`
26+
Temperature *float32 `json:"temperature,omitempty"`
2727
TopP *float32 `json:"top_p,omitempty"`
2828
Tools *[]Tool `json:"tools,omitempty"`
2929
ToolChoice any `json:"tool_choice,omitempty"`
@@ -92,3 +92,7 @@ type ToolChoiceNamed struct {
9292
type ToolChoiceFunction struct {
9393
Name string `json:"name"`
9494
}
95+
96+
func ToPtr[T any](v T) *T {
97+
return &v
98+
}

request/request_test/validator_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TestValidateChatCompletionsRequest(t *testing.T) {
2929
StreamOptions: &request.StreamOptions{
3030
IncludeUsage: true,
3131
},
32-
Temperature: 2,
32+
Temperature: request.ToPtr(float32(0.2)),
3333
// TopP: nil, // TODO: VN -- pass non nil
3434
}
3535
err := request.ValidateChatCompletionsRequest(req)

request/validator.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,15 @@ func validateMultipleFields(req *ChatCompletionsRequest) error {
110110
return fmt.Errorf("err: presence_penalty is invalid; it should be number between -2 and 2")
111111
}
112112

113-
if !(req.Temperature >= 0 && req.Temperature <= 2) {
114-
return fmt.Errorf("err: temperature is invalid; it should be number between 0 and 2")
113+
if req.Temperature != nil {
114+
if !(*req.Temperature >= 0 && *req.Temperature <= 2) {
115+
return fmt.Errorf("err: temperature is invalid; the valid range of temperature is [0, 2.0]")
116+
}
115117
}
116118

117119
if req.TopP != nil {
118120
if !(*req.TopP > 0 && *req.TopP <= 1) {
119-
return fmt.Errorf("err: invalid top_p value; the valid range of top_p is (0, 1.0]")
121+
return fmt.Errorf("err: top_p is invalid; the valid range of top_p is (0, 1.0]")
120122
}
121123
}
122124

request/validator_unit_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,51 @@ func TestValidateStreamOptions(t *testing.T) {
178178
assert.NoError(t, err)
179179
})
180180
}
181+
182+
func TestValidateMultipleFields(t *testing.T) {
183+
t.Run("no err for valid temprature", func(t *testing.T) {
184+
req := &ChatCompletionsRequest{
185+
Temperature: nil,
186+
}
187+
err := validateMultipleFields(req)
188+
assert.NoError(t, err)
189+
190+
req = &ChatCompletionsRequest{
191+
Temperature: ToPtr(float32(0)),
192+
}
193+
err = validateMultipleFields(req)
194+
assert.NoError(t, err)
195+
196+
req = &ChatCompletionsRequest{
197+
Temperature: ToPtr(float32(0.1)),
198+
}
199+
err = validateMultipleFields(req)
200+
assert.NoError(t, err)
201+
202+
req = &ChatCompletionsRequest{
203+
Temperature: ToPtr(float32(1.9)),
204+
}
205+
err = validateMultipleFields(req)
206+
assert.NoError(t, err)
207+
208+
req = &ChatCompletionsRequest{
209+
Temperature: ToPtr(float32(2.0)),
210+
}
211+
err = validateMultipleFields(req)
212+
assert.NoError(t, err)
213+
})
214+
215+
t.Run("err for invalid temprature", func(t *testing.T) {
216+
req := &ChatCompletionsRequest{
217+
Temperature: ToPtr(float32(-0.1)),
218+
}
219+
err := validateMultipleFields(req)
220+
assert.Error(t, err)
221+
222+
req = &ChatCompletionsRequest{
223+
Temperature: ToPtr(float32(2.1)),
224+
}
225+
err = validateMultipleFields(req)
226+
assert.Error(t, err)
227+
})
228+
}

0 commit comments

Comments
 (0)