Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8e1176c
BCDA-9287: create struct to wrap db connections
michaeljvaldes Jul 22, 2025
bd4a313
Pass db connections explicitly from cli to service
michaeljvaldes Jul 23, 2025
ca1b0f1
Remove unintended change to middleware
michaeljvaldes Jul 23, 2025
bb1fd60
Separate connection and pool
michaeljvaldes Jul 23, 2025
6aa2fa2
Pass connection through routers and services instead of combined struct
michaeljvaldes Jul 23, 2025
cab0a71
Add connection as dependency of data router
michaeljvaldes Jul 23, 2025
0d37ba1
Pass connection as argument to token job middleware
michaeljvaldes Jul 23, 2025
ed1783c
Remove connection global from cli
michaeljvaldes Jul 24, 2025
8272973
Define db pool as dependency for apis and workers
michaeljvaldes Jul 24, 2025
e30cd6e
Remove connection globals from several tests
michaeljvaldes Jul 30, 2025
82c98f0
Pass connection to health.go
michaeljvaldes Jul 30, 2025
134d8fb
Refactor connection globals in cclf
michaeljvaldes Jul 30, 2025
607ce18
Refactor connection global in admin create group lambda
michaeljvaldes Jul 30, 2025
2749820
Refactor db globals in optout lambda
michaeljvaldes Jul 30, 2025
db96700
Refactor globals in ratelimit middleware
michaeljvaldes Jul 30, 2025
3256641
Refactor db connection in worker and river
michaeljvaldes Jul 30, 2025
ae8dc98
Refactor db connections in cleanup worker
michaeljvaldes Jul 30, 2025
486f0c4
Inject provider as dependency
michaeljvaldes Jul 31, 2025
c896f87
Refactor provider-related tests
michaeljvaldes Aug 4, 2025
7376ff3
Merge branch 'main' into mvaldes/BCDA-9287-Connection
michaeljvaldes Aug 4, 2025
29ba3cc
Remove database connection globals
michaeljvaldes Aug 5, 2025
b00818b
Rename connection to db
michaeljvaldes Aug 5, 2025
1237279
Remove comments
michaeljvaldes Aug 5, 2025
beb66ca
Pass repo as argument in cli function
michaeljvaldes Aug 5, 2025
a54157f
Pass globals as arguments in cli functions
michaeljvaldes Aug 5, 2025
49b753b
Run test scenarios as separate tests
michaeljvaldes Aug 6, 2025
16dd65f
Rename blacklist to denylist in a few places
michaeljvaldes Aug 6, 2025
cfb33ef
Add comment for refactoring todo
michaeljvaldes Aug 6, 2025
01bd874
Merge branch 'main' into mvaldes/BCDA-9287-Connection
michaeljvaldes Aug 6, 2025
bd29e55
Add unit test for pool connection
michaeljvaldes Aug 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions bcda/api/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

"github.com/CMSgov/bcda-app/bcda/auth"
"github.com/CMSgov/bcda-app/bcda/constants"
"github.com/CMSgov/bcda-app/bcda/database"
"github.com/CMSgov/bcda-app/bcda/models"
"github.com/CMSgov/bcda-app/bcda/models/postgres"
responseutils "github.com/CMSgov/bcda-app/bcda/responseutils"
Expand All @@ -35,6 +34,7 @@ import (
"github.com/CMSgov/bcda-app/conf"
"github.com/CMSgov/bcda-app/log"
m "github.com/CMSgov/bcda-app/middleware"
pgxv5Pool "github.com/jackc/pgx/v5/pgxpool"
)

type Handler struct {
Expand All @@ -53,8 +53,7 @@ type Handler struct {
// Needed to have access to the repository/db for lookup needed in the bulkRequest.
// TODO (BCDA-3412): Remove this reference once we've captured all of the necessary
// logic into a service method.
r models.Repository
db *sql.DB
r models.Repository
}

type fhirResponseWriter interface {
Expand All @@ -63,14 +62,14 @@ type fhirResponseWriter interface {
JobsBundle(context.Context, http.ResponseWriter, []*models.Job, string)
}

func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string) *Handler {
return newHandler(dataTypes, basePath, apiVersion, database.Connection)
func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, db *sql.DB, pool *pgxv5Pool.Pool) *Handler {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see all handlers and APIvX passing around pool, but I dont see it being used anywhere? Is this preparation for swapping out db? Or did I miss it somewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think the handler passes it to its enqueuer, which then passes it to its river client. As far as I can tell, I only replaced existing instances of the global pool with the pool passed as a dependency.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a PR comment on enqueue.go to make it a little clearer.

return newHandler(dataTypes, basePath, apiVersion, db, pool)
}

func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, db *sql.DB) *Handler {
func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, db *sql.DB, pool *pgxv5Pool.Pool) *Handler {
h := &Handler{JobTimeout: time.Hour * time.Duration(utils.GetEnvInt("ARCHIVE_THRESHOLD_HR", 24))}

h.Enq = queueing.NewEnqueuer()
h.Enq = queueing.NewEnqueuer(db, pool)

cfg, err := service.LoadConfig()
if err != nil {
Expand All @@ -81,7 +80,7 @@ func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersi
}

repository := postgres.NewRepository(db)
h.db, h.r = db, repository
h.r = repository
h.Svc = service.NewService(repository, cfg, basePath)

h.supportedDataTypes = dataTypes
Expand Down
55 changes: 29 additions & 26 deletions bcda/api/requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
fhirmodelv2CR "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/bundle_and_contained_resource_go_proto"
fhircodesv1 "github.com/google/fhir/go/proto/google/fhir/proto/stu3/codes_go_proto"
fhirmodelsv1 "github.com/google/fhir/go/proto/google/fhir/proto/stu3/resources_go_proto"
pgxv5Pool "github.com/jackc/pgx/v5/pgxpool"
)

const apiVersionOne = "v1"
Expand All @@ -67,6 +68,8 @@ type RequestsTestSuite struct {

db *sql.DB

pool *pgxv5Pool.Pool

acoID uuid.UUID

resourceType map[string]service.DataType
Expand All @@ -79,9 +82,11 @@ func TestRequestsTestSuite(t *testing.T) {
func (s *RequestsTestSuite) SetupSuite() {
// See testdata/acos.yml
s.acoID = uuid.Parse("ba21d24d-cd96-4d7d-a691-b0e8c88e67a5")
s.db, _ = databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true)
db, _ := databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true)
s.db = db
s.pool = database.ConnectPool()
tf, err := testfixtures.New(
testfixtures.Database(s.db),
testfixtures.Database(db),
testfixtures.Dialect("postgres"),
testfixtures.Directory("testdata/"),
)
Expand Down Expand Up @@ -137,7 +142,7 @@ func (s *RequestsTestSuite) TestRunoutEnabled() {
mockSvc := &service.MockService{}
mockAco := service.ACOConfig{Data: []string{"adjudicated"}}
mockSvc.On("GetACOConfigForID", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockAco, true)
h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.db)
h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.db, s.pool)
h.Svc = mockSvc
enqueuer := queueing.NewMockEnqueuer(s.T())
h.Enq = enqueuer
Expand Down Expand Up @@ -239,7 +244,7 @@ func (s *RequestsTestSuite) TestJobsStatusV1() {
"Patient": {},
"Coverage": {},
"ExplanationOfBenefit": {},
}, fhirPath, apiVersion, s.db)
}, fhirPath, apiVersion, s.db, s.pool)
h.Svc = mockSvc

rr := httptest.NewRecorder()
Expand Down Expand Up @@ -353,7 +358,7 @@ func (s *RequestsTestSuite) TestJobsStatusV2() {
"Patient": {},
"Coverage": {},
"ExplanationOfBenefit": {},
}, v2BasePath, apiVersionTwo, s.db)
}, v2BasePath, apiVersionTwo, s.db, s.pool)
if tt.useMock {
h.Svc = mockSvc
}
Expand Down Expand Up @@ -472,7 +477,7 @@ func (s *RequestsTestSuite) TestAttributionStatus() {
fhirPath := "/" + apiVersion + "/fhir"

resourceMap := s.resourceType
h := newHandler(resourceMap, fhirPath, apiVersion, s.db)
h := newHandler(resourceMap, fhirPath, apiVersion, s.db, s.pool)
h.Svc = mockSvc

rr := httptest.NewRecorder()
Expand Down Expand Up @@ -563,7 +568,11 @@ func (s *RequestsTestSuite) TestDataTypeAuthorization() {
"ClaimResponse": {Adjudicated: false, PartiallyAdjudicated: true},
}

h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo)
h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.db, s.pool)
r := models.NewMockRepository(s.T())
r.On("CreateJob", mock.Anything, mock.Anything).Return(uint(4), nil)
h.r = r

h.supportedDataTypes = dataTypeMap
client.SetLogger(log.API) // Set logger so we don't get errors later
jsonBytes, _ := json.Marshal("{}")
Expand Down Expand Up @@ -647,7 +656,7 @@ func (s *RequestsTestSuite) TestRequests() {
fhirPath := "/" + apiVersion + "/fhir"
resourceMap := s.resourceType

h := newHandler(resourceMap, fhirPath, apiVersion, s.db)
h := newHandler(resourceMap, fhirPath, apiVersion, s.db, s.pool)

// Test Group and Patient
// Patient, Coverage, and ExplanationOfBenefit
Expand Down Expand Up @@ -777,7 +786,7 @@ func (s *RequestsTestSuite) TestJobStatusErrorHandling() {

for _, tt := range tests {
s.T().Run(tt.testName, func(t *testing.T) {
h := newHandler(resourceMap, basePath, apiVersion, s.db)
h := newHandler(resourceMap, basePath, apiVersion, s.db, s.pool)
if tt.useMockService {
mockSrv := service.MockService{}
timestp := time.Now()
Expand Down Expand Up @@ -851,7 +860,7 @@ func (s *RequestsTestSuite) TestJobStatusProgress() {
apiVersion := apiVersionTwo
requestUrl := v2JobRequestUrl
resourceMap := s.resourceType
h := newHandler(resourceMap, basePath, apiVersion, s.db)
h := newHandler(resourceMap, basePath, apiVersion, s.db, s.pool)

req := httptest.NewRequest("GET", requestUrl, nil)
rctx := chi.NewRouteContext()
Expand Down Expand Up @@ -900,7 +909,7 @@ func (s *RequestsTestSuite) TestDeleteJob() {

for _, tt := range tests {
s.T().Run(tt.name, func(t *testing.T) {
handler := newHandler(s.resourceType, basePath, apiVersion, s.db)
handler := newHandler(s.resourceType, basePath, apiVersion, s.db, s.pool)

if tt.useMockService {
mockSrv := service.MockService{}
Expand Down Expand Up @@ -960,7 +969,7 @@ func (s *RequestsTestSuite) TestJobFailedStatus() {

for _, tt := range tests {
s.T().Run(tt.name, func(t *testing.T) {
h := newHandler(resourceMap, tt.basePath, tt.version, s.db)
h := newHandler(resourceMap, tt.basePath, tt.version, s.db, s.pool)
mockSrv := service.MockService{}
timestp := time.Now()
mockSrv.On("GetJobAndKeys", testUtils.CtxMatcher, uint(1)).Return(
Expand Down Expand Up @@ -1018,7 +1027,7 @@ func (s *RequestsTestSuite) TestGetResourceTypes() {
{"CT000000", "v2", []string{"Patient", "ExplanationOfBenefit", "Coverage", "Claim", "ClaimResponse"}},
}
for _, test := range testCases {
h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.db)
h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.db, s.pool)
rp := middleware.RequestParameters{
Version: test.apiVersion,
ResourceTypes: []string{},
Expand Down Expand Up @@ -1051,23 +1060,17 @@ func TestBulkRequest_Integration(t *testing.T) {

client.SetLogger(log.API) // Set logger so we don't get errors later

h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo)
db := database.Connect()
pool := database.ConnectPool()
h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, db, pool)

cfg, err := database.LoadConfig()
if err != nil {
t.FailNow()
}
d, err := database.CreatePgxv5DB(cfg)
if err != nil {
t.FailNow()
}
driver := riverpgxv5.New(d)
driver := riverpgxv5.New(pool)
// start from clean river_job slate
_, err = driver.GetExecutor().Exec(context.Background(), `delete from river_job`)
_, err := driver.GetExecutor().Exec(context.Background(), `delete from river_job`)
assert.Nil(t, err)

acoID := "A0002"
repo := postgres.NewRepository(h.db)
repo := postgres.NewRepository(db)

// our DB is not always cleaned up properly so sometimes this record exists when this test runs and sometimes it doesnt
repo.CreateACO(context.Background(), models.ACO{CMSID: &acoID, UUID: uuid.NewUUID()}) // nolint:errcheck
Expand Down Expand Up @@ -1205,7 +1208,7 @@ func (s *RequestsTestSuite) TestValidateResources() {
"Patient": {},
"Coverage": {},
"ExplanationOfBenefit": {},
}, fhirPath, apiVersion, s.db)
}, fhirPath, apiVersion, s.db, s.pool)
err := h.validateResources([]string{"Vegetable"}, "1234")
assert.Contains(s.T(), err.Error(), "invalid resource type")
}
Expand Down
57 changes: 33 additions & 24 deletions bcda/api/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package v1
import (
"bytes"
"compress/gzip"
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand All @@ -23,23 +24,31 @@ import (
"github.com/CMSgov/bcda-app/bcda/servicemux"
"github.com/CMSgov/bcda-app/conf"
"github.com/CMSgov/bcda-app/log"
pgxv5Pool "github.com/jackc/pgx/v5/pgxpool"
)

var h *api.Handler
type ApiV1 struct {
db *sql.DB
handler *api.Handler
provider auth.Provider
healthChecker health.HealthChecker
}

func init() {
func NewApiV1(db *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) *ApiV1 {
resources, ok := service.GetDataTypes([]string{
"Patient",
"Coverage",
"ExplanationOfBenefit",
"Observation",
}...)

if ok {
h = api.NewHandler(resources, "/v1/fhir", "v1")
} else {
if !ok {
panic("Failed to configure resource DataTypes")
}

hc := health.NewHealthChecker(db)
h := api.NewHandler(resources, "/v1/fhir", "v1", db, pool)
return &ApiV1{db: db, handler: h, provider: provider, healthChecker: hc}
}

/*
Expand All @@ -64,8 +73,8 @@ Responses:
429: tooManyRequestsResponse
500: errorResponse
*/
func BulkPatientRequest(w http.ResponseWriter, r *http.Request) {
h.BulkPatientRequest(w, r)
func (a ApiV1) BulkPatientRequest(w http.ResponseWriter, r *http.Request) {
a.handler.BulkPatientRequest(w, r)
}

/*
Expand All @@ -92,8 +101,8 @@ func BulkPatientRequest(w http.ResponseWriter, r *http.Request) {
429: tooManyRequestsResponse
500: errorResponse
*/
func BulkGroupRequest(w http.ResponseWriter, r *http.Request) {
h.BulkGroupRequest(w, r)
func (a ApiV1) BulkGroupRequest(w http.ResponseWriter, r *http.Request) {
a.handler.BulkGroupRequest(w, r)
}

/*
Expand Down Expand Up @@ -122,8 +131,8 @@ Responses:
410: goneResponse
500: errorResponse
*/
func JobStatus(w http.ResponseWriter, r *http.Request) {
h.JobStatus(w, r)
func (a ApiV1) JobStatus(w http.ResponseWriter, r *http.Request) {
a.handler.JobStatus(w, r)
}

/*
Expand Down Expand Up @@ -162,8 +171,8 @@ Responses:
410: goneResponse
500: errorResponse
*/
func JobsStatus(w http.ResponseWriter, r *http.Request) {
h.JobsStatus(w, r)
func (a ApiV1) JobsStatus(w http.ResponseWriter, r *http.Request) {
a.handler.JobsStatus(w, r)
}

type gzipResponseWriter struct {
Expand Down Expand Up @@ -204,8 +213,8 @@ Responses:
410: goneResponse
500: errorResponse
*/
func DeleteJob(w http.ResponseWriter, r *http.Request) {
h.DeleteJob(w, r)
func (a ApiV1) DeleteJob(w http.ResponseWriter, r *http.Request) {
a.handler.DeleteJob(w, r)
}

/*
Expand All @@ -229,8 +238,8 @@ Responses:
200: AttributionFileStatusResponse
404: notFoundResponse
*/
func AttributionStatus(w http.ResponseWriter, r *http.Request) {
h.AttributionStatus(w, r)
func (a ApiV1) AttributionStatus(w http.ResponseWriter, r *http.Request) {
a.handler.AttributionStatus(w, r)
}

/*
Expand Down Expand Up @@ -360,7 +369,7 @@ Responses:

200: MetadataResponse
*/
func Metadata(w http.ResponseWriter, r *http.Request) {
func (a ApiV1) Metadata(w http.ResponseWriter, r *http.Request) {
dt := time.Now()

scheme := "http"
Expand Down Expand Up @@ -390,7 +399,7 @@ Responses:

200: VersionResponse
*/
func GetVersion(w http.ResponseWriter, r *http.Request) {
func (a ApiV1) GetVersion(w http.ResponseWriter, r *http.Request) {
respMap := make(map[string]string)
respMap["version"] = constants.Version
respBytes, err := json.Marshal(respMap)
Expand All @@ -409,11 +418,11 @@ func GetVersion(w http.ResponseWriter, r *http.Request) {
}
}

func HealthCheck(w http.ResponseWriter, r *http.Request) {
func (a ApiV1) HealthCheck(w http.ResponseWriter, r *http.Request) {
m := make(map[string]string)

dbStatus, dbOK := health.IsDatabaseOK()
ssasStatus, ssasOK := health.IsSsasOK()
dbStatus, dbOK := a.healthChecker.IsDatabaseOK()
ssasStatus, ssasOK := a.healthChecker.IsSsasOK()

m["database"] = dbStatus
m["ssas"] = ssasStatus
Expand Down Expand Up @@ -454,10 +463,10 @@ Responses:

200: AuthResponse
*/
func GetAuthInfo(w http.ResponseWriter, r *http.Request) {
func (a ApiV1) GetAuthInfo(w http.ResponseWriter, r *http.Request) {
respMap := make(map[string]string)
respMap["auth_provider"] = auth.GetProviderName()
version, err := auth.GetProvider().GetVersion()
version, err := a.provider.GetVersion()
if err == nil {
respMap["version"] = version
} else {
Expand Down
Loading