From 8e1176ca308443a98147765b8e6406b2f3dbd57f Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Tue, 22 Jul 2025 17:56:49 -0400 Subject: [PATCH 01/28] BCDA-9287: create struct to wrap db connections --- bcda/database/connection.go | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/bcda/database/connection.go b/bcda/database/connection.go index 2ed2b65f8..f17152f89 100644 --- a/bcda/database/connection.go +++ b/bcda/database/connection.go @@ -23,24 +23,37 @@ var ( Pgxv5Pool *pgxv5Pool.Pool ) +type Connections struct { + Connection *sql.DB + QueueConnection *pgx.ConnPool + Pgxv5Pool *pgxv5Pool.Pool +} + func init() { + c := Connect() + Connection = c.Connection + QueueConnection = c.QueueConnection + Pgxv5Pool = c.Pgxv5Pool +} + +func Connect() *Connections { cfg, err := LoadConfig() if err != nil { logrus.Fatalf("Failed to load database config %s", err.Error()) } - Connection, err = createDB(cfg) + conn, err := createDB(cfg) if err != nil { logrus.Fatalf("Failed to create db %s", err.Error()) } - QueueConnection, err = createQueue(cfg) + queue, err := createQueue(cfg) if err != nil { logrus.Fatalf("Failed to create queue %s", err.Error()) } - Pgxv5Pool, err = CreatePgxv5DB(cfg) + pool, err := CreatePgxv5DB(cfg) if err != nil { logrus.Fatalf("Failed to create pgxv5 DB connection %s", err.Error()) } @@ -50,11 +63,13 @@ func init() { startHealthCheck( ctx, - Connection, - QueueConnection, - Pgxv5Pool, + conn, + queue, + pool, time.Duration(cfg.HealthCheckSec)*time.Second, ) + + return &Connections{conn, queue, pool} } func createDB(cfg *Config) (*sql.DB, error) { From bd4a3134d46dfb65d3157d947262ecf04b08034b Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 23 Jul 2025 11:34:55 -0400 Subject: [PATCH 02/28] Pass db connections explicitly from cli to service --- bcda/api/requests.go | 13 +++--- bcda/api/requests_test.go | 61 ++++++++++++------------- bcda/api/v1/api.go | 39 +++++++++------- bcda/api/v1/api_test.go | 92 +++++++++++++++++++------------------- bcda/api/v2/api.go | 62 +++++++++++++------------- bcda/api/v2/api_test.go | 93 ++++++++++++++++++++------------------- bcda/api/v3/api.go | 62 +++++++++++++------------- bcda/api/v3/api_test.go | 93 ++++++++++++++++++++------------------- bcda/auth/middleware.go | 3 ++ bcda/bcdacli/cli.go | 11 +++-- bcda/web/router.go | 45 ++++++++++--------- bcda/web/router_test.go | 21 +++++---- 12 files changed, 304 insertions(+), 291 deletions(-) diff --git a/bcda/api/requests.go b/bcda/api/requests.go index 2cd272d2e..8657e5ad3 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -53,8 +53,7 @@ type Handler struct { // Needed to have access to the repository/db for lookup needed in the bulkRequest. // TODO (BCDA-3412): Remove this reference once we've captured all of the necessary // logic into a service method. 
- r models.Repository - db *sql.DB + r models.Repository } type fhirResponseWriter interface { @@ -63,11 +62,11 @@ type fhirResponseWriter interface { JobsBundle(context.Context, http.ResponseWriter, []*models.Job, string) } -func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string) *Handler { - return newHandler(dataTypes, basePath, apiVersion, database.Connection) +func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connections *database.Connections) *Handler { + return newHandler(dataTypes, basePath, apiVersion, connections) } -func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, db *sql.DB) *Handler { +func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connections *database.Connections) *Handler { h := &Handler{JobTimeout: time.Hour * time.Duration(utils.GetEnvInt("ARCHIVE_THRESHOLD_HR", 24))} h.Enq = queueing.NewEnqueuer() @@ -80,8 +79,8 @@ func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersi log.API.Fatalf("no ACO configs found, these are required for processing logic") } - repository := postgres.NewRepository(db) - h.db, h.r = db, repository + repository := postgres.NewRepository(connections.Connection) + h.r = repository h.Svc = service.NewService(repository, cfg, basePath) h.supportedDataTypes = dataTypes diff --git a/bcda/api/requests_test.go b/bcda/api/requests_test.go index a4c1c27b0..88742a782 100644 --- a/bcda/api/requests_test.go +++ b/bcda/api/requests_test.go @@ -65,7 +65,7 @@ type RequestsTestSuite struct { runoutEnabledEnvVar string - db *sql.DB + connections *database.Connections acoID uuid.UUID @@ -79,9 +79,10 @@ func TestRequestsTestSuite(t *testing.T) { func (s *RequestsTestSuite) SetupSuite() { // See testdata/acos.yml s.acoID = uuid.Parse("ba21d24d-cd96-4d7d-a691-b0e8c88e67a5") - s.db, _ = databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true) + db, _ := databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true) + s.connections = &database.Connections{Connection: db} tf, err := testfixtures.New( - testfixtures.Database(s.db), + testfixtures.Database(db), testfixtures.Dialect("postgres"), testfixtures.Directory("testdata/"), ) @@ -137,7 +138,7 @@ func (s *RequestsTestSuite) TestRunoutEnabled() { mockSvc := &service.MockService{} mockAco := service.ACOConfig{Data: []string{"adjudicated"}} mockSvc.On("GetACOConfigForID", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockAco, true) - h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.db) + h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.connections) h.Svc = mockSvc enqueuer := queueing.NewMockEnqueuer(s.T()) h.Enq = enqueuer @@ -239,7 +240,7 @@ func (s *RequestsTestSuite) TestJobsStatusV1() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.db) + }, fhirPath, apiVersion, s.connections) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -353,7 +354,7 @@ func (s *RequestsTestSuite) TestJobsStatusV2() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, v2BasePath, apiVersionTwo, s.db) + }, v2BasePath, apiVersionTwo, s.connections) if tt.useMock { h.Svc = mockSvc } @@ -472,7 +473,7 @@ func (s *RequestsTestSuite) TestAttributionStatus() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, 
apiVersion, s.db) + h := newHandler(resourceMap, fhirPath, apiVersion, s.connections) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -563,7 +564,11 @@ func (s *RequestsTestSuite) TestDataTypeAuthorization() { "ClaimResponse": {Adjudicated: false, PartiallyAdjudicated: true}, } - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo) + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.connections) + r := models.NewMockRepository(s.T()) + r.On("CreateJob", mock.Anything, mock.Anything).Return(uint(4), nil) + h.r = r + h.supportedDataTypes = dataTypeMap client.SetLogger(log.API) // Set logger so we don't get errors later jsonBytes, _ := json.Marshal("{}") @@ -647,7 +652,7 @@ func (s *RequestsTestSuite) TestRequests() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, apiVersion, s.db) + h := newHandler(resourceMap, fhirPath, apiVersion, s.connections) // Test Group and Patient // Patient, Coverage, and ExplanationOfBenefit @@ -777,7 +782,7 @@ func (s *RequestsTestSuite) TestJobStatusErrorHandling() { for _, tt := range tests { s.T().Run(tt.testName, func(t *testing.T) { - h := newHandler(resourceMap, basePath, apiVersion, s.db) + h := newHandler(resourceMap, basePath, apiVersion, s.connections) if tt.useMockService { mockSrv := service.MockService{} timestp := time.Now() @@ -851,7 +856,7 @@ func (s *RequestsTestSuite) TestJobStatusProgress() { apiVersion := apiVersionTwo requestUrl := v2JobRequestUrl resourceMap := s.resourceType - h := newHandler(resourceMap, basePath, apiVersion, s.db) + h := newHandler(resourceMap, basePath, apiVersion, s.connections) req := httptest.NewRequest("GET", requestUrl, nil) rctx := chi.NewRouteContext() @@ -900,7 +905,7 @@ func (s *RequestsTestSuite) TestDeleteJob() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - handler := newHandler(s.resourceType, basePath, apiVersion, s.db) + handler := newHandler(s.resourceType, basePath, apiVersion, s.connections) if tt.useMockService { mockSrv := service.MockService{} @@ -960,7 +965,7 @@ func (s *RequestsTestSuite) TestJobFailedStatus() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - h := newHandler(resourceMap, tt.basePath, tt.version, s.db) + h := newHandler(resourceMap, tt.basePath, tt.version, s.connections) mockSrv := service.MockService{} timestp := time.Now() mockSrv.On("GetJobAndKeys", testUtils.CtxMatcher, uint(1)).Return( @@ -1018,7 +1023,7 @@ func (s *RequestsTestSuite) TestGetResourceTypes() { {"CT000000", "v2", []string{"Patient", "ExplanationOfBenefit", "Coverage", "Claim", "ClaimResponse"}}, } for _, test := range testCases { - h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.db) + h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.connections) rp := middleware.RequestParameters{ Version: test.apiVersion, ResourceTypes: []string{}, @@ -1051,23 +1056,15 @@ func TestBulkRequest_Integration(t *testing.T) { client.SetLogger(log.API) // Set logger so we don't get errors later - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo) - - cfg, err := database.LoadConfig() - if err != nil { - t.FailNow() - } - d, err := database.CreatePgxv5DB(cfg) - if err != nil { - t.FailNow() - } - driver := riverpgxv5.New(d) + connections := database.Connect() + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, connections) + driver := riverpgxv5.New(connections.Pgxv5Pool) // start from clean river_job slate - _, err = 
driver.GetExecutor().Exec(context.Background(), `delete from river_job`) + _, err := driver.GetExecutor().Exec(context.Background(), `delete from river_job`) assert.Nil(t, err) acoID := "A0002" - repo := postgres.NewRepository(h.db) + repo := postgres.NewRepository(connections.Connection) // our DB is not always cleaned up properly so sometimes this record exists when this test runs and sometimes it doesnt repo.CreateACO(context.Background(), models.ACO{CMSID: &acoID, UUID: uuid.NewUUID()}) // nolint:errcheck @@ -1130,7 +1127,7 @@ func (s *RequestsTestSuite) genGroupRequest(groupID string, rp middleware.Reques rctx := chi.NewRouteContext() rctx.URLParams.Add("groupId", groupID) - aco := postgrestest.GetACOByUUID(s.T(), s.db, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) @@ -1145,7 +1142,7 @@ func (s *RequestsTestSuite) genGroupRequest(groupID string, rp middleware.Reques func (s *RequestsTestSuite) genPatientRequest(rp middleware.RequestParameters) *http.Request { req := httptest.NewRequest("GET", "http://bcda.cms.gov/api/v1/Patient/$export", nil) - aco := postgrestest.GetACOByUUID(s.T(), s.db, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) ctx = middleware.SetRequestParamsCtx(ctx, rp) @@ -1156,7 +1153,7 @@ func (s *RequestsTestSuite) genPatientRequest(rp middleware.RequestParameters) * func (s *RequestsTestSuite) genASRequest() *http.Request { req := httptest.NewRequest("GET", "http://bcda.cms.gov/api/v1/attribution_status", nil) - aco := postgrestest.GetACOByUUID(s.T(), s.db, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) @@ -1184,7 +1181,7 @@ func (s *RequestsTestSuite) genGetJobsRequest(version string, statuses []models. 
req := httptest.NewRequest("GET", target, nil) - aco := postgrestest.GetACOByUUID(s.T(), s.db, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) @@ -1205,7 +1202,7 @@ func (s *RequestsTestSuite) TestValidateResources() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.db) + }, fhirPath, apiVersion, s.connections) err := h.validateResources([]string{"Vegetable"}, "1234") assert.Contains(s.T(), err.Error(), "invalid resource type") } diff --git a/bcda/api/v1/api.go b/bcda/api/v1/api.go index 4e8311b02..9d6cf62c2 100644 --- a/bcda/api/v1/api.go +++ b/bcda/api/v1/api.go @@ -17,6 +17,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/api" "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/constants" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/health" "github.com/CMSgov/bcda-app/bcda/responseutils" "github.com/CMSgov/bcda-app/bcda/service" @@ -25,9 +26,12 @@ import ( "github.com/CMSgov/bcda-app/log" ) -var h *api.Handler +type ApiV1 struct { + handler *api.Handler + connections *database.Connections +} -func init() { +func NewApiV1(connections *database.Connections) *ApiV1 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -35,10 +39,11 @@ func init() { "Observation", }...) - if ok { - h = api.NewHandler(resources, "/v1/fhir", "v1") - } else { + if !ok { panic("Failed to configure resource DataTypes") + } else { + h := api.NewHandler(resources, "/v1/fhir", "v1", connections) + return &ApiV1{handler: h, connections: connections} } } @@ -64,8 +69,8 @@ Responses: 429: tooManyRequestsResponse 500: errorResponse */ -func BulkPatientRequest(w http.ResponseWriter, r *http.Request) { - h.BulkPatientRequest(w, r) +func (a ApiV1) BulkPatientRequest(w http.ResponseWriter, r *http.Request) { + a.handler.BulkPatientRequest(w, r) } /* @@ -92,8 +97,8 @@ func BulkPatientRequest(w http.ResponseWriter, r *http.Request) { 429: tooManyRequestsResponse 500: errorResponse */ -func BulkGroupRequest(w http.ResponseWriter, r *http.Request) { - h.BulkGroupRequest(w, r) +func (a ApiV1) BulkGroupRequest(w http.ResponseWriter, r *http.Request) { + a.handler.BulkGroupRequest(w, r) } /* @@ -122,8 +127,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func JobStatus(w http.ResponseWriter, r *http.Request) { - h.JobStatus(w, r) +func (a ApiV1) JobStatus(w http.ResponseWriter, r *http.Request) { + a.handler.JobStatus(w, r) } /* @@ -162,8 +167,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func JobsStatus(w http.ResponseWriter, r *http.Request) { - h.JobsStatus(w, r) +func (a ApiV1) JobsStatus(w http.ResponseWriter, r *http.Request) { + a.handler.JobsStatus(w, r) } type gzipResponseWriter struct { @@ -204,8 +209,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func DeleteJob(w http.ResponseWriter, r *http.Request) { - h.DeleteJob(w, r) +func (a ApiV1) DeleteJob(w http.ResponseWriter, r *http.Request) { + a.handler.DeleteJob(w, r) } /* @@ -229,8 +234,8 @@ Responses: 200: AttributionFileStatusResponse 404: notFoundResponse */ -func AttributionStatus(w http.ResponseWriter, r *http.Request) { - h.AttributionStatus(w, r) +func (a ApiV1) AttributionStatus(w http.ResponseWriter, r *http.Request) { + a.handler.AttributionStatus(w, r) } /* diff --git a/bcda/api/v1/api_test.go 
b/bcda/api/v1/api_test.go index 7b9c7c143..6f868bad6 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -4,7 +4,6 @@ import ( "compress/gzip" "context" "crypto/tls" - "database/sql" "encoding/json" "fmt" "io" @@ -45,11 +44,15 @@ var ( type APITestSuite struct { suite.Suite - rr *httptest.ResponseRecorder - db *sql.DB + rr *httptest.ResponseRecorder + connections *database.Connections + apiV1 *ApiV1 } func (s *APITestSuite) SetupSuite() { + s.connections = database.Connect() + s.apiV1 = NewApiV1(s.connections) + origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) conf.SetEnv(s.T(), "BB_REQUEST_RETRY_INTERVAL_MS", "10") @@ -66,12 +69,11 @@ func (s *APITestSuite) SetupSuite() { } func (s *APITestSuite) SetupTest() { - s.db = database.Connection s.rr = httptest.NewRecorder() } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.db, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.connections.Connection, acoUnderTest) } func (s *APITestSuite) TestJobStatusBadInputs() { @@ -98,7 +100,7 @@ func (s *APITestSuite) TestJobStatusBadInputs() { newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - JobStatus(rr, req) + s.apiV1.JobStatus(rr, req) respOO := getOperationOutcome(t, rr.Body.Bytes()) @@ -131,13 +133,13 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.db, &j) - defer postgrestest.DeleteJobByID(t, s.db, j.ID) + postgrestest.CreateJobs(t, s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV1.JobStatus(rr, req) assert.Equal(t, tt.expStatusCode, rr.Code) switch rr.Code { case http.StatusAccepted: @@ -146,7 +148,7 @@ func (s *APITestSuite) TestJobStatusNotComplete() { case http.StatusInternalServerError: assert.Contains(t, rr.Body.String(), "Service encountered numerous errors") case http.StatusGone: - assertExpiryEquals(t, j.CreatedAt.Add(h.JobTimeout), rr.Header().Get("Expires")) + assertExpiryEquals(t, j.CreatedAt.Add(s.apiV1.handler.JobTimeout), rr.Header().Get("Expires")) } }) } @@ -159,25 +161,25 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.db, + postgrestest.CreateJobKeys(s.T(), s.connections.Connection, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } req := s.createJobStatusRequest(acoUnderTest, j.ID) - JobStatus(s.rr, req) + s.apiV1.JobStatus(s.rr, req) assert.Equal(s.T(), http.StatusOK, s.rr.Code) assert.Equal(s.T(), "application/json", s.rr.Header().Get(constants.ContentType)) str := s.rr.Header().Get("Expires") fmt.Println(str) - assertExpiryEquals(s.T(), j.CreatedAt.Add(h.JobTimeout), s.rr.Header().Get("Expires")) 
+ assertExpiryEquals(s.T(), j.CreatedAt.Add(s.apiV1.handler.JobTimeout), s.rr.Header().Get("Expires")) var rb api.BulkResponseBody err := json.Unmarshal(s.rr.Body.Bytes(), &rb) @@ -210,7 +212,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -218,7 +220,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.db, jobKey) + postgrestest.CreateJobKeys(s.T(), s.connections.Connection, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -236,7 +238,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { } req := s.createJobStatusRequest(acoUnderTest, j.ID) - JobStatus(s.rr, req) + s.apiV1.JobStatus(s.rr, req) assert.Equal(s.T(), http.StatusOK, s.rr.Code) assert.Equal(s.T(), "application/json", s.rr.Header().Get(constants.ContentType)) @@ -270,16 +272,16 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - j.UpdatedAt = time.Now().Add(-(h.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.db, j) + j.UpdatedAt = time.Now().Add(-(s.apiV1.handler.JobTimeout + time.Second)) + postgrestest.UpdateJob(s.T(), s.connections.Connection, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) - JobStatus(s.rr, req) + s.apiV1.JobStatus(s.rr, req) assert.Equal(s.T(), http.StatusGone, s.rr.Code) - assertExpiryEquals(s.T(), j.UpdatedAt.Add(h.JobTimeout), s.rr.Header().Get("Expires")) + assertExpiryEquals(s.T(), j.UpdatedAt.Add(s.apiV1.handler.JobTimeout), s.rr.Header().Get("Expires")) } func (s *APITestSuite) TestDeleteJobBadInputs() { @@ -306,7 +308,7 @@ func (s *APITestSuite) TestDeleteJobBadInputs() { newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - JobStatus(rr, req) + s.apiV1.JobStatus(rr, req) respOO := getOperationOutcome(t, rr.Body.Bytes()) @@ -339,13 +341,13 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.db, &j) - defer postgrestest.DeleteJobByID(t, s.db, j.ID) + postgrestest.CreateJobs(t, s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - DeleteJob(rr, req) + s.apiV1.DeleteJob(rr, req) assert.Equal(t, tt.expStatusCode, rr.Code) if rr.Code == http.StatusGone { assert.Contains(t, rr.Body.String(), "job was not cancelled because it is not Pending or In Progress") @@ -454,9 +456,9 @@ func (s *APITestSuite) TestJobStatusWithWrongACO() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusPending, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - handler := auth.RequireTokenJobMatch(http.HandlerFunc(JobStatus)) 
+ handler := auth.RequireTokenJobMatch(http.HandlerFunc(s.apiV1.JobStatus)) req := s.createJobStatusRequest(uuid.Parse(constants.LargeACOUUID), j.ID) handler.ServeHTTP(s.rr, req) @@ -476,10 +478,10 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -491,7 +493,7 @@ func (s *APITestSuite) TestJobsStatusNotFound() { req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) rr := httptest.NewRecorder() - JobsStatus(rr, req) + s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) } @@ -508,10 +510,10 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) } @@ -528,10 +530,10 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -548,10 +550,10 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -583,22 +585,22 @@ func (s *APITestSuite) TestGetAttributionStatus() { req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) rr := httptest.NewRecorder() - AttributionStatus(rr, req) + s.apiV1.AttributionStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) var resp api.AttributionFileStatusResponse err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.db, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connections.Connection, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := 
postgrestest.GetACOByUUID(s.T(), s.db, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/api/v2/api.go b/bcda/api/v2/api.go index 454e1a81a..4d8412fb1 100644 --- a/bcda/api/v2/api.go +++ b/bcda/api/v2/api.go @@ -15,20 +15,20 @@ import ( api "github.com/CMSgov/bcda-app/bcda/api" "github.com/CMSgov/bcda-app/bcda/constants" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/service" "github.com/CMSgov/bcda-app/bcda/servicemux" "github.com/CMSgov/bcda-app/conf" "github.com/CMSgov/bcda-app/log" ) -var ( - h *api.Handler - marshaller *jsonformat.Marshaller -) - -func init() { - var err error +type ApiV2 struct { + handler *api.Handler + marshaller *jsonformat.Marshaller + connections *database.Connections +} +func NewApiV2(connections *database.Connections) *ApiV2 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -37,17 +37,17 @@ func init() { "ClaimResponse", }...) - if ok { - h = api.NewHandler(resources, "/v2/fhir", "v2") - } else { + if !ok { panic("Failed to configure resource DataTypes") - } - - // Ensure that we write the serialized FHIR resources as a single line. - // Needed to comply with the NDJSON format that we are using. - marshaller, err = jsonformat.NewMarshaller(false, "", "", fhirversion.R4) - if err != nil { - log.API.Fatalf("Failed to create marshaller %s", err) + } else { + h := api.NewHandler(resources, "/v2/fhir", "v2", connections) + // Ensure that we write the serialized FHIR resources as a single line. + // Needed to comply with the NDJSON format that we are using. + marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) + if err != nil { + log.API.Fatalf("Failed to create marshaller %s", err) + } + return &ApiV2{marshaller: marshaller, handler: h, connections: connections} } } @@ -73,8 +73,8 @@ Responses: 429: tooManyRequestsResponse 500: errorResponse */ -func BulkPatientRequest(w http.ResponseWriter, r *http.Request) { - h.BulkPatientRequest(w, r) +func (a ApiV2) BulkPatientRequest(w http.ResponseWriter, r *http.Request) { + a.handler.BulkPatientRequest(w, r) } /* @@ -101,8 +101,8 @@ func BulkPatientRequest(w http.ResponseWriter, r *http.Request) { 429: tooManyRequestsResponse 500: errorResponse */ -func BulkGroupRequest(w http.ResponseWriter, r *http.Request) { - h.BulkGroupRequest(w, r) +func (a ApiV2) BulkGroupRequest(w http.ResponseWriter, r *http.Request) { + a.handler.BulkGroupRequest(w, r) } /* @@ -131,8 +131,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func JobStatus(w http.ResponseWriter, r *http.Request) { - h.JobStatus(w, r) +func (a ApiV2) JobStatus(w http.ResponseWriter, r *http.Request) { + a.handler.JobStatus(w, r) } /* @@ -171,8 +171,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func JobsStatus(w http.ResponseWriter, r *http.Request) { - h.JobsStatus(w, r) +func (a ApiV2) JobsStatus(w http.ResponseWriter, r *http.Request) { + a.handler.JobsStatus(w, r) } /* @@ -200,8 +200,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func DeleteJob(w http.ResponseWriter, r *http.Request) { - h.DeleteJob(w, r) +func (a ApiV2) DeleteJob(w http.ResponseWriter, r *http.Request) { + a.handler.DeleteJob(w, r) } /* @@ -225,8 +225,8 @@ Responses: 200: AttributionFileStatusResponse 404: notFoundResponse */ -func AttributionStatus(w http.ResponseWriter, r *http.Request) { - h.AttributionStatus(w, r) +func (a ApiV2) 
AttributionStatus(w http.ResponseWriter, r *http.Request) { + a.handler.AttributionStatus(w, r) } /* @@ -245,7 +245,7 @@ Responses: 200: MetadataResponse */ -func Metadata(w http.ResponseWriter, r *http.Request) { +func (a ApiV2) Metadata(w http.ResponseWriter, r *http.Request) { dt := time.Now() bbServer := conf.GetEnv("BB_SERVER_LOCATION") @@ -354,7 +354,7 @@ func Metadata(w http.ResponseWriter, r *http.Request) { resource := &fhirresources.ContainedResource{ OneofResource: &fhirresources.ContainedResource_CapabilityStatement{CapabilityStatement: statement}, } - b, err := marshaller.Marshal(resource) + b, err := a.marshaller.Marshal(resource) if err != nil { log.API.Errorf("Failed to marshal Capability Statement %s", err.Error()) http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/bcda/api/v2/api_test.go b/bcda/api/v2/api_test.go index d87074c3a..b0383a4bc 100644 --- a/bcda/api/v2/api_test.go +++ b/bcda/api/v2/api_test.go @@ -2,7 +2,6 @@ package v2 import ( "context" - "database/sql" "encoding/json" "fmt" "io" @@ -52,10 +51,14 @@ var ( type APITestSuite struct { suite.Suite - db *sql.DB + connections *database.Connections + apiV2 *ApiV2 } func (s *APITestSuite) SetupSuite() { + s.connections = database.Connect() + s.apiV2 = NewApiV2(s.connections) + origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) conf.SetEnv(s.T(), "BB_REQUEST_RETRY_INTERVAL_MS", "10") @@ -70,14 +73,12 @@ func (s *APITestSuite) SetupSuite() { conf.SetEnv(s.T(), "BB_CLIENT_KEY_FILE", origBBKey) }) - s.db = database.Connection - // Set up the logger since we're using the real client client.SetLogger(log.BBAPI) } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.db, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.connections.Connection, acoUnderTest) } func TestAPITestSuite(t *testing.T) { @@ -108,7 +109,7 @@ func (s *APITestSuite) TestJobStatusBadInputs() { newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - JobStatus(rr, req) + s.apiV2.JobStatus(rr, req) respOO := getOperationOutcome(t, rr.Body.Bytes()) @@ -141,13 +142,13 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.db, &j) - defer postgrestest.DeleteJobByID(t, s.db, j.ID) + postgrestest.CreateJobs(t, s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV2.JobStatus(rr, req) assert.Equal(t, tt.expStatusCode, rr.Code) switch rr.Code { case http.StatusAccepted: @@ -156,7 +157,7 @@ func (s *APITestSuite) TestJobStatusNotComplete() { case http.StatusInternalServerError: assert.Contains(t, rr.Body.String(), "Service encountered numerous errors") case http.StatusGone: - assertExpiryEquals(t, j.CreatedAt.Add(h.JobTimeout), rr.Header().Get("Expires")) + assertExpiryEquals(t, j.CreatedAt.Add(s.apiV2.handler.JobTimeout), rr.Header().Get("Expires")) } }) } @@ -169,27 +170,27 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) var 
expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.db, + postgrestest.CreateJobKeys(s.T(), s.connections.Connection, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV2.JobStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) assert.Equal(s.T(), constants.JsonContentType, rr.Header().Get(constants.ContentType)) str := rr.Header().Get("Expires") fmt.Println(str) - assertExpiryEquals(s.T(), j.CreatedAt.Add(h.JobTimeout), rr.Header().Get("Expires")) + assertExpiryEquals(s.T(), j.CreatedAt.Add(s.apiV2.handler.JobTimeout), rr.Header().Get("Expires")) var rb api.BulkResponseBody err := json.Unmarshal(rr.Body.Bytes(), &rb) @@ -222,7 +223,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -230,7 +231,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.db, jobKey) + postgrestest.CreateJobKeys(s.T(), s.connections.Connection, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -250,7 +251,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV2.JobStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) assert.Equal(s.T(), constants.JsonContentType, rr.Header().Get(constants.ContentType)) @@ -284,18 +285,18 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - j.UpdatedAt = time.Now().Add(-(h.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.db, j) + j.UpdatedAt = time.Now().Add(-(s.apiV2.handler.JobTimeout + time.Second)) + postgrestest.UpdateJob(s.T(), s.connections.Connection, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV2.JobStatus(rr, req) assert.Equal(s.T(), http.StatusGone, rr.Code) - assertExpiryEquals(s.T(), j.UpdatedAt.Add(h.JobTimeout), rr.Header().Get("Expires")) + assertExpiryEquals(s.T(), j.UpdatedAt.Add(s.apiV2.handler.JobTimeout), rr.Header().Get("Expires")) } func (s *APITestSuite) TestJobsStatus() { @@ -311,10 +312,10 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ 
-326,7 +327,7 @@ func (s *APITestSuite) TestJobsStatusNotFound() { req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) rr := httptest.NewRecorder() - JobsStatus(rr, req) + s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) } @@ -343,10 +344,10 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) } @@ -363,10 +364,10 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -383,10 +384,10 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -413,7 +414,7 @@ func (s *APITestSuite) TestDeleteJobBadInputs() { req = req.WithContext(context.WithValue(req.Context(), auth.AuthDataContextKey, ad)) newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - JobStatus(rr, req) + s.apiV2.JobStatus(rr, req) respOO := getOperationOutcome(t, rr.Body.Bytes()) @@ -446,8 +447,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: "/api/v2/Patient/$export?_type=Patient,Coverage", Status: tt.status, } - postgrestest.CreateJobs(t, s.db, &j) - defer postgrestest.DeleteJobByID(t, s.db, j.ID) + postgrestest.CreateJobs(t, s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -455,7 +456,7 @@ func (s *APITestSuite) TestDeleteJob() { newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - DeleteJob(rr, req) + s.apiV2.DeleteJob(rr, req) assert.Equal(t, tt.expStatusCode, rr.Code) if rr.Code == http.StatusGone { assert.Contains(t, rr.Body.String(), "job was not cancelled because it is not Pending or In Progress") @@ -464,7 +465,7 @@ func (s *APITestSuite) TestDeleteJob() { } } func (s *APITestSuite) TestMetadataResponse() { - ts := httptest.NewServer(http.HandlerFunc(Metadata)) + ts := httptest.NewServer(http.HandlerFunc(s.apiV2.Metadata)) defer ts.Close() unmarshaller, err := jsonformat.NewUnmarshaller("UTC", fhirversion.R4) @@ -542,7 
+543,7 @@ func (s *APITestSuite) TestResourceTypes() { "ClaimResponse", }...) - h := api.NewHandler(resources, "/v2/fhir", "v2") + h := api.NewHandler(resources, "/v2/fhir", "v2", s.connections) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) @@ -602,27 +603,27 @@ func (s *APITestSuite) TestGetAttributionStatus() { req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) rr := httptest.NewRecorder() - AttributionStatus(rr, req) + s.apiV2.AttributionStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) var resp api.AttributionFileStatusResponse err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.db, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connections.Connection, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) getAuthData() (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) return auth.AuthData{ACOID: acoUnderTest.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.db, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/api/v3/api.go b/bcda/api/v3/api.go index 5cbb3c419..86fc63cbd 100644 --- a/bcda/api/v3/api.go +++ b/bcda/api/v3/api.go @@ -7,6 +7,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/api" "github.com/CMSgov/bcda-app/bcda/constants" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/service" "github.com/CMSgov/bcda-app/bcda/servicemux" "github.com/CMSgov/bcda-app/conf" @@ -21,31 +22,30 @@ import ( fhirvaluesets "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/valuesets_go_proto" ) -var ( - h *api.Handler - marshaller *jsonformat.Marshaller -) - -func init() { - var err error +type ApiV3 struct { + handler *api.Handler + marshaller *jsonformat.Marshaller + connections *database.Connections +} +func NewApiV3(connections *database.Connections) *ApiV3 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", "ExplanationOfBenefit", }...) - if ok { - h = api.NewHandler(resources, constants.BFDV3Path, constants.V3Version) - } else { + if !ok { panic("Failed to configure resource DataTypes") - } - - // Ensure that we write the serialized FHIR resources as a single line. - // Needed to comply with the NDJSON format that we are using. - marshaller, err = jsonformat.NewMarshaller(false, "", "", fhirversion.R4) - if err != nil { - log.API.Fatalf("Failed to create marshaller %s", err) + } else { + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, connections) + // Ensure that we write the serialized FHIR resources as a single line. 
+ // Needed to comply with the NDJSON format that we are using. + marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) + if err != nil { + log.API.Fatalf("Failed to create marshaller %s", err) + } + return &ApiV3{marshaller: marshaller, handler: h, connections: connections} } } @@ -73,8 +73,8 @@ Responses: 429: tooManyRequestsResponse 500: errorResponse */ -func BulkPatientRequest(w http.ResponseWriter, r *http.Request) { - h.BulkPatientRequest(w, r) +func (a ApiV3) BulkPatientRequest(w http.ResponseWriter, r *http.Request) { + a.handler.BulkPatientRequest(w, r) } /* @@ -101,8 +101,8 @@ func BulkPatientRequest(w http.ResponseWriter, r *http.Request) { 429: tooManyRequestsResponse 500: errorResponse */ -func BulkGroupRequest(w http.ResponseWriter, r *http.Request) { - h.BulkGroupRequest(w, r) +func (a ApiV3) BulkGroupRequest(w http.ResponseWriter, r *http.Request) { + a.handler.BulkGroupRequest(w, r) } /* @@ -131,8 +131,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func JobStatus(w http.ResponseWriter, r *http.Request) { - h.JobStatus(w, r) +func (a ApiV3) JobStatus(w http.ResponseWriter, r *http.Request) { + a.handler.JobStatus(w, r) } /* @@ -171,8 +171,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func JobsStatus(w http.ResponseWriter, r *http.Request) { - h.JobsStatus(w, r) +func (a ApiV3) JobsStatus(w http.ResponseWriter, r *http.Request) { + a.handler.JobsStatus(w, r) } /* @@ -200,8 +200,8 @@ Responses: 410: goneResponse 500: errorResponse */ -func DeleteJob(w http.ResponseWriter, r *http.Request) { - h.DeleteJob(w, r) +func (a ApiV3) DeleteJob(w http.ResponseWriter, r *http.Request) { + a.handler.DeleteJob(w, r) } /* @@ -225,8 +225,8 @@ Responses: 200: AttributionFileStatusResponse 404: notFoundResponse */ -func AttributionStatus(w http.ResponseWriter, r *http.Request) { - h.AttributionStatus(w, r) +func (a ApiV3) AttributionStatus(w http.ResponseWriter, r *http.Request) { + a.handler.AttributionStatus(w, r) } /* @@ -245,7 +245,7 @@ Responses: 200: MetadataResponse */ -func Metadata(w http.ResponseWriter, r *http.Request) { +func (a ApiV3) Metadata(w http.ResponseWriter, r *http.Request) { dt := time.Now() bbServer := conf.GetEnv("BB_SERVER_LOCATION") @@ -354,7 +354,7 @@ func Metadata(w http.ResponseWriter, r *http.Request) { resource := &fhirresources.ContainedResource{ OneofResource: &fhirresources.ContainedResource_CapabilityStatement{CapabilityStatement: statement}, } - b, err := marshaller.Marshal(resource) + b, err := a.marshaller.Marshal(resource) if err != nil { log.API.Errorf("Failed to marshal Capability Statement %s", err.Error()) http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/bcda/api/v3/api_test.go b/bcda/api/v3/api_test.go index 5c7f2246e..43fb2bfa7 100644 --- a/bcda/api/v3/api_test.go +++ b/bcda/api/v3/api_test.go @@ -2,7 +2,6 @@ package v3 import ( "context" - "database/sql" "encoding/json" "fmt" "io" @@ -52,10 +51,14 @@ var ( type APITestSuite struct { suite.Suite - db *sql.DB + connections *database.Connections + apiV3 *ApiV3 } func (s *APITestSuite) SetupSuite() { + s.connections = database.Connect() + s.apiV3 = NewApiV3(s.connections) + origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) conf.SetEnv(s.T(), "BB_REQUEST_RETRY_INTERVAL_MS", "10") @@ -70,14 +73,12 @@ func (s *APITestSuite) SetupSuite() { conf.SetEnv(s.T(), "BB_CLIENT_KEY_FILE", origBBKey) }) - s.db = database.Connection - // Set up the logger since we're using the real 
client client.SetLogger(log.BBAPI) } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.db, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.connections.Connection, acoUnderTest) } func TestAPITestSuite(t *testing.T) { @@ -108,7 +109,7 @@ func (s *APITestSuite) TestJobStatusBadInputs() { newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - JobStatus(rr, req) + s.apiV3.JobStatus(rr, req) respOO := getOperationOutcome(t, rr.Body.Bytes()) @@ -141,13 +142,13 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.db, &j) - defer postgrestest.DeleteJobByID(t, s.db, j.ID) + postgrestest.CreateJobs(t, s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV3.JobStatus(rr, req) assert.Equal(t, tt.expStatusCode, rr.Code) switch rr.Code { case http.StatusAccepted: @@ -156,7 +157,7 @@ func (s *APITestSuite) TestJobStatusNotComplete() { case http.StatusInternalServerError: assert.Contains(t, rr.Body.String(), "Service encountered numerous errors") case http.StatusGone: - assertExpiryEquals(t, j.CreatedAt.Add(h.JobTimeout), rr.Header().Get("Expires")) + assertExpiryEquals(t, j.CreatedAt.Add(s.apiV3.handler.JobTimeout), rr.Header().Get("Expires")) } }) } @@ -174,27 +175,27 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.db, + postgrestest.CreateJobKeys(s.T(), s.connections.Connection, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV3.JobStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) assert.Equal(s.T(), constants.JsonContentType, rr.Header().Get(constants.ContentType)) str := rr.Header().Get("Expires") fmt.Println(str) - assertExpiryEquals(s.T(), j.CreatedAt.Add(h.JobTimeout), rr.Header().Get("Expires")) + assertExpiryEquals(s.T(), j.CreatedAt.Add(s.apiV3.handler.JobTimeout), rr.Header().Get("Expires")) var rb api.BulkResponseBody err := json.Unmarshal(rr.Body.Bytes(), &rb) @@ -227,7 +228,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -235,7 +236,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.db, jobKey) + postgrestest.CreateJobKeys(s.T(), s.connections.Connection, jobKey) f 
:= fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -255,7 +256,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV3.JobStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) assert.Equal(s.T(), constants.JsonContentType, rr.Header().Get(constants.ContentType)) @@ -289,18 +290,18 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - j.UpdatedAt = time.Now().Add(-(h.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.db, j) + j.UpdatedAt = time.Now().Add(-(s.apiV3.handler.JobTimeout + time.Second)) + postgrestest.UpdateJob(s.T(), s.connections.Connection, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() - JobStatus(rr, req) + s.apiV3.JobStatus(rr, req) assert.Equal(s.T(), http.StatusGone, rr.Code) - assertExpiryEquals(s.T(), j.UpdatedAt.Add(h.JobTimeout), rr.Header().Get("Expires")) + assertExpiryEquals(s.T(), j.UpdatedAt.Add(s.apiV3.handler.JobTimeout), rr.Header().Get("Expires")) } func (s *APITestSuite) TestJobsStatus() { @@ -316,10 +317,10 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -331,7 +332,7 @@ func (s *APITestSuite) TestJobsStatusNotFound() { req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) rr := httptest.NewRecorder() - JobsStatus(rr, req) + s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) } @@ -348,10 +349,10 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) } @@ -368,10 +369,10 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -388,10 +389,10 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.db, &j) - defer 
postgrestest.DeleteJobByID(s.T(), s.db, j.ID) + postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) - JobsStatus(rr, req) + s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) } @@ -418,7 +419,7 @@ func (s *APITestSuite) TestDeleteJobBadInputs() { req = req.WithContext(context.WithValue(req.Context(), auth.AuthDataContextKey, ad)) newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - JobStatus(rr, req) + s.apiV3.JobStatus(rr, req) respOO := getOperationOutcome(t, rr.Body.Bytes()) @@ -451,8 +452,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=Patient,Coverage", constants.V3Path), Status: tt.status, } - postgrestest.CreateJobs(t, s.db, &j) - defer postgrestest.DeleteJobByID(t, s.db, j.ID) + postgrestest.CreateJobs(t, s.connections.Connection, &j) + defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -460,7 +461,7 @@ func (s *APITestSuite) TestDeleteJob() { newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) - DeleteJob(rr, req) + s.apiV3.DeleteJob(rr, req) assert.Equal(t, tt.expStatusCode, rr.Code) if rr.Code == http.StatusGone { assert.Contains(t, rr.Body.String(), "job was not cancelled because it is not Pending or In Progress") @@ -469,7 +470,7 @@ func (s *APITestSuite) TestDeleteJob() { } } func (s *APITestSuite) TestMetadataResponse() { - ts := httptest.NewServer(http.HandlerFunc(Metadata)) + ts := httptest.NewServer(http.HandlerFunc(s.apiV3.Metadata)) defer ts.Close() unmarshaller, err := jsonformat.NewUnmarshaller("UTC", fhirversion.R4) @@ -545,7 +546,7 @@ func (s *APITestSuite) TestResourceTypes() { "ExplanationOfBenefit", }...) 
- h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version) + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, s.connections) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) @@ -605,27 +606,27 @@ func (s *APITestSuite) TestGetAttributionStatus() { req = req.WithContext(context.WithValue(req.Context(), log.CtxLoggerKey, newLogEntry)) rr := httptest.NewRecorder() - AttributionStatus(rr, req) + s.apiV3.AttributionStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) var resp api.AttributionFileStatusResponse err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.db, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connections.Connection, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) getAuthData() (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) return auth.AuthData{ACOID: acoUnderTest.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.db, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/auth/middleware.go b/bcda/auth/middleware.go index 64e7ca474..b1b8b8e0d 100644 --- a/bcda/auth/middleware.go +++ b/bcda/auth/middleware.go @@ -31,6 +31,9 @@ var ( AuthDataContextKey = &contextKey{"ad"} ) +type Middleware struct { +} + // ParseToken puts the decoded token and AuthData value into the request context. Decoded values come from // tokens verified by our provider as correct and unexpired. Tokens may be presented in requests to // unauthenticated endpoints (mostly swagger?). 
We still want to extract the token data for logging purposes, diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index c232dfa64..822eb81eb 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -3,7 +3,6 @@ package bcdacli import ( "archive/zip" "context" - "database/sql" "encoding/json" "fmt" "io" @@ -47,8 +46,8 @@ const Name = "bcda" const Usage = "Beneficiary Claims Data API CLI" var ( - db *sql.DB - r models.Repository + connections *database.Connections + r models.Repository ) func GetApp() *cli.App { @@ -61,8 +60,8 @@ func setUpApp() *cli.App { app.Usage = Usage app.Version = constants.Version app.Before = func(c *cli.Context) error { - db = database.Connection - r = postgres.NewRepository(db) + connections = database.Connect() + r = postgres.NewRepository(connections.Connection) return nil } var hours, err = safecast.ToUint(utils.GetEnvInt("FILE_ARCHIVE_THRESHOLD_HR", 72)) @@ -122,7 +121,7 @@ func setUpApp() *cli.App { } api := &http.Server{ - Handler: web.NewAPIRouter(), + Handler: web.NewAPIRouter(connections), ReadTimeout: time.Duration(utils.GetEnvInt("API_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("API_WRITE_TIMEOUT", 20)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("API_IDLE_TIMEOUT", 120)) * time.Second, diff --git a/bcda/web/router.go b/bcda/web/router.go index e8be8c4cd..cfb6dce2c 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -28,7 +28,7 @@ var commonAuth = []func(http.Handler) http.Handler{ auth.RequireTokenAuth, auth.CheckBlacklist} -func NewAPIRouter() http.Handler { +func NewAPIRouter(connections *database.Connections) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() r.Use(gcmw.RequestID, appMiddleware.NewTransactionID, auth.ParseToken, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) @@ -54,37 +54,40 @@ func NewAPIRouter() http.Handler { } r.Route("/api/v1", func(r chi.Router) { - r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", v1.BulkPatientRequest)) - r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", v1.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, v1.JobStatus)) - r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", v1.JobsStatus)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, v1.DeleteJob)) - r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", v1.AttributionStatus)) + apiV1 := v1.NewApiV1(connections) + r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) + r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) + r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) + r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV1.JobsStatus)) + r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) + r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV1.AttributionStatus)) r.Get(m.WrapHandler("/metadata", v1.Metadata)) }) if utils.GetEnvBool("VERSION_2_ENDPOINT_ACTIVE", true) { FileServer(r, "/api/v2/swagger", http.Dir("./swaggerui/v2")) + apiV2 := 
v2.NewApiV2(connections) r.Route("/api/v2", func(r chi.Router) { - r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", v2.BulkPatientRequest)) - r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", v2.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, v2.JobStatus)) - r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", v2.JobsStatus)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, v2.DeleteJob)) - r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", v2.AttributionStatus)) - r.Get(m.WrapHandler("/metadata", v2.Metadata)) + r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV2.BulkPatientRequest)) + r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV2.BulkGroupRequest)) + r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, apiV2.JobStatus)) + r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV2.JobsStatus)) + r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, apiV2.DeleteJob)) + r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV2.AttributionStatus)) + r.Get(m.WrapHandler("/metadata", apiV2.Metadata)) }) } if utils.GetEnvBool("VERSION_3_ENDPOINT_ACTIVE", true) { + apiV3 := v3.NewApiV3(connections) r.Route("/api/demo", func(r chi.Router) { - r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", v3.BulkPatientRequest)) - r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", v3.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, v3.JobStatus)) - r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", v3.JobsStatus)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, v3.DeleteJob)) - r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", v3.AttributionStatus)) - r.Get(m.WrapHandler("/metadata", v3.Metadata)) + r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV3.BulkPatientRequest)) + r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV3.BulkGroupRequest)) + r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, apiV3.JobStatus)) + r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV3.JobsStatus)) + r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, apiV3.DeleteJob)) + r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV3.AttributionStatus)) + r.Get(m.WrapHandler("/metadata", apiV3.Metadata)) }) } diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index 6c9f09e3e..eaf9e298c 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -29,14 +29,17 @@ var nDJsonDataRoute string = "/data/test/test.ndjson" type RouterTestSuite struct { suite.Suite - apiRouter http.Handler - dataRouter http.Handler + apiRouter http.Handler + dataRouter http.Handler + connections database.Connections } func (s *RouterTestSuite) SetupTest() { 
conf.SetEnv(s.T(), "DEBUG", "true") - s.apiRouter = NewAPIRouter() + s.connections = *database.Connect() + s.apiRouter = NewAPIRouter(&s.connections) s.dataRouter = NewDataRouter() + } func (s *RouterTestSuite) getAPIRoute(route string) *http.Response { @@ -76,7 +79,7 @@ func (s *RouterTestSuite) TestDefaultProdRoute() { s.FailNow("err in setting env var", err) } // Need a new router because the one in the test setup does not use the environment variable set in this test. - s.apiRouter = NewAPIRouter() + s.apiRouter = NewAPIRouter(&s.connections) res := s.getAPIRoute("/v1/") assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -193,7 +196,7 @@ func (s *RouterTestSuite) TestV2EndpointsDisabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter() + s.apiRouter = NewAPIRouter(&s.connections) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -210,7 +213,7 @@ func (s *RouterTestSuite) TestV2EndpointsEnabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter() + s.apiRouter = NewAPIRouter(&s.connections) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -231,7 +234,7 @@ func (s *RouterTestSuite) TestV3EndpointsDisabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter() + s.apiRouter = NewAPIRouter(&s.connections) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -248,7 +251,7 @@ func (s *RouterTestSuite) TestV3EndpointsEnabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter() + s.apiRouter = NewAPIRouter(&s.connections) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -348,7 +351,7 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []str handler http.Handler paths []string }) { - apiRouter := NewAPIRouter() + apiRouter := NewAPIRouter(&s.connections) configs = []struct { handler http.Handler From ca1b0f136e183a2dba0974eb1262fafd5c6dc987 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 23 Jul 2025 14:23:21 -0400 Subject: [PATCH 03/28] Remove unintended change to middleware --- bcda/auth/middleware.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/bcda/auth/middleware.go b/bcda/auth/middleware.go index b1b8b8e0d..64e7ca474 100644 --- a/bcda/auth/middleware.go +++ b/bcda/auth/middleware.go @@ -31,9 +31,6 @@ var ( AuthDataContextKey = &contextKey{"ad"} ) -type Middleware struct { -} - // ParseToken puts the decoded token and AuthData value into the request context. Decoded values come from // tokens verified by our provider as correct and unexpired. Tokens may be presented in requests to // unauthenticated endpoints (mostly swagger?). 
We still want to extract the token data for logging purposes, From bb1fd60ecf701a0427ee2001da55626cfeeddad3 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 23 Jul 2025 14:27:27 -0400 Subject: [PATCH 04/28] Separate connection and pool --- bcda/database/connection.go | 79 ++++++++++++++++---------------- bcda/database/connection_test.go | 4 +- 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/bcda/database/connection.go b/bcda/database/connection.go index f17152f89..3d510e808 100644 --- a/bcda/database/connection.go +++ b/bcda/database/connection.go @@ -30,15 +30,16 @@ type Connections struct { } func init() { - c := Connect() - Connection = c.Connection - QueueConnection = c.QueueConnection - Pgxv5Pool = c.Pgxv5Pool + Connection = GetConnection() + Pgxv5Pool = GetPool() } func Connect() *Connections { - cfg, err := LoadConfig() + return nil +} +func GetConnection() *sql.DB { + cfg, err := LoadConfig() if err != nil { logrus.Fatalf("Failed to load database config %s", err.Error()) } @@ -48,9 +49,22 @@ func Connect() *Connections { logrus.Fatalf("Failed to create db %s", err.Error()) } - queue, err := createQueue(cfg) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + startConnectionHealthCheck( + ctx, + conn, + time.Duration(cfg.HealthCheckSec)*time.Second, + ) + + return conn +} + +func GetPool() *pgxv5Pool.Pool { + cfg, err := LoadConfig() if err != nil { - logrus.Fatalf("Failed to create queue %s", err.Error()) + logrus.Fatalf("Failed to load database config %s", err.Error()) } pool, err := CreatePgxv5DB(cfg) @@ -61,15 +75,13 @@ func Connect() *Connections { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - startHealthCheck( + startPoolHealthCheck( ctx, - conn, - queue, pool, time.Duration(cfg.HealthCheckSec)*time.Second, ) - return &Connections{conn, queue, pool} + return pool } func createDB(cfg *Config) (*sql.DB, error) { @@ -102,23 +114,6 @@ func createDB(cfg *Config) (*sql.DB, error) { return db, nil } -func createQueue(cfg *Config) (*pgx.ConnPool, error) { - pgxCfg, err := pgx.ParseURI(strings.TrimSpace(cfg.QueueDatabaseURL)) - if err != nil { - return nil, err - } - - pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{ - ConnConfig: pgxCfg, - MaxConnections: cfg.MaxOpenConns, - }) - if err != nil { - return nil, err - } - - return pool, err -} - func CreatePgxv5DB(cfg *Config) (*pgxv5Pool.Pool, error) { ctx := context.Background() @@ -152,7 +147,7 @@ func CreatePgxv5DB(cfg *Config) (*pgxv5Pool.Pool, error) { // // startHealthCheck returns immediately with the health check running in a goroutine that // can be stopped via the supplied context -func startHealthCheck(ctx context.Context, db *sql.DB, pool *pgx.ConnPool, pgxv5Pool *pgxv5Pool.Pool, interval time.Duration) { +func startConnectionHealthCheck(ctx context.Context, db *sql.DB, interval time.Duration) { go func() { ticker := time.NewTicker(interval) for { @@ -168,19 +163,23 @@ func startHealthCheck(ctx context.Context, db *sql.DB, pool *pgx.ConnPool, pgxv5 if err := db.Ping(); err != nil { logrus.Warnf("Failed to ping %s", err.Error()) } + } + } + }() +} - // Acquire and ping Queue DB - c, err := pool.Acquire() - if err != nil { - logrus.Warnf("Failed to acquire Queue DB connection %s", err.Error()) - continue - } - if err := c.Ping(context.Background()); err != nil { - logrus.Warnf("Failed to ping Queue DB %s", err.Error()) - } - pool.Release(c) +func startPoolHealthCheck(ctx context.Context, pgxv5Pool *pgxv5Pool.Pool, interval time.Duration) { + go 
func() { + ticker := time.NewTicker(interval) + for { + select { + case <-ctx.Done(): + ticker.Stop() + logrus.Debug("Stopping health checker") + return + case <-ticker.C: + logrus.StandardLogger().Debug("Sending ping") - // Acquire and ping pgxv5 connection to App DB pgxv5Conn, err := pgxv5Pool.Acquire(ctx) if err != nil { logrus.Warnf("Failed to acquire pgxv5 App DB connection: %s", err.Error()) diff --git a/bcda/database/connection_test.go b/bcda/database/connection_test.go index f52a28ea9..8cb6dc23d 100644 --- a/bcda/database/connection_test.go +++ b/bcda/database/connection_test.go @@ -26,7 +26,7 @@ func TestConnections(t *testing.T) { // TestHealthCheck verifies that we are able to start the health check // and the checks do not cause a panic by waiting some amount of time // to ensure that health checks are being executed. -func TestHealthCheck(t *testing.T) { +func TestConnectionHealthCheck(t *testing.T) { level, reporter := logrus.GetLevel(), logrus.StandardLogger().ReportCaller t.Cleanup(func() { logrus.SetLevel(level) @@ -38,7 +38,7 @@ func TestHealthCheck(t *testing.T) { hook := test.NewGlobal() ctx, cancel := context.WithCancel(context.Background()) - startHealthCheck(ctx, Connection, QueueConnection, Pgxv5Pool, 100*time.Microsecond) + startConnectionHealthCheck(ctx, Connection, 100*time.Microsecond) // Let some time elapse to ensure we've successfully ran health checks time.Sleep(50 * time.Millisecond) cancel() From 6aa2fa2ff25105bba84eb2e7b42fa5e291bf4312 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 23 Jul 2025 14:36:22 -0400 Subject: [PATCH 05/28] Pass connection through routers and services instead of combined struct --- bcda/api/requests.go | 9 +++--- bcda/api/requests_test.go | 46 ++++++++++++++++-------------- bcda/api/v1/api.go | 12 ++++---- bcda/api/v1/api_test.go | 57 +++++++++++++++++++------------------ bcda/api/v2/api.go | 14 ++++----- bcda/api/v2/api_test.go | 57 +++++++++++++++++++------------------ bcda/api/v3/api.go | 14 ++++----- bcda/api/v3/api_test.go | 57 +++++++++++++++++++------------------ bcda/bcdacli/cli.go | 11 +++---- bcda/database/connection.go | 10 ------- bcda/web/router.go | 9 +++--- bcda/web/router_test.go | 23 ++++++++------- 12 files changed, 158 insertions(+), 161 deletions(-) diff --git a/bcda/api/requests.go b/bcda/api/requests.go index 8657e5ad3..625b68496 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -21,7 +21,6 @@ import ( "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/models/postgres" responseutils "github.com/CMSgov/bcda-app/bcda/responseutils" @@ -62,11 +61,11 @@ type fhirResponseWriter interface { JobsBundle(context.Context, http.ResponseWriter, []*models.Job, string) } -func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connections *database.Connections) *Handler { - return newHandler(dataTypes, basePath, apiVersion, connections) +func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB) *Handler { + return newHandler(dataTypes, basePath, apiVersion, connection) } -func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connections *database.Connections) *Handler { +func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB) *Handler { h := &Handler{JobTimeout: time.Hour * 
time.Duration(utils.GetEnvInt("ARCHIVE_THRESHOLD_HR", 24))} h.Enq = queueing.NewEnqueuer() @@ -79,7 +78,7 @@ func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersi log.API.Fatalf("no ACO configs found, these are required for processing logic") } - repository := postgres.NewRepository(connections.Connection) + repository := postgres.NewRepository(connection) h.r = repository h.Svc = service.NewService(repository, cfg, basePath) diff --git a/bcda/api/requests_test.go b/bcda/api/requests_test.go index 88742a782..75b54b0a1 100644 --- a/bcda/api/requests_test.go +++ b/bcda/api/requests_test.go @@ -65,7 +65,7 @@ type RequestsTestSuite struct { runoutEnabledEnvVar string - connections *database.Connections + connection *sql.DB acoID uuid.UUID @@ -80,7 +80,7 @@ func (s *RequestsTestSuite) SetupSuite() { // See testdata/acos.yml s.acoID = uuid.Parse("ba21d24d-cd96-4d7d-a691-b0e8c88e67a5") db, _ := databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true) - s.connections = &database.Connections{Connection: db} + s.connection = db tf, err := testfixtures.New( testfixtures.Database(db), testfixtures.Dialect("postgres"), @@ -138,7 +138,7 @@ func (s *RequestsTestSuite) TestRunoutEnabled() { mockSvc := &service.MockService{} mockAco := service.ACOConfig{Data: []string{"adjudicated"}} mockSvc.On("GetACOConfigForID", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockAco, true) - h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.connections) + h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.connection) h.Svc = mockSvc enqueuer := queueing.NewMockEnqueuer(s.T()) h.Enq = enqueuer @@ -240,7 +240,7 @@ func (s *RequestsTestSuite) TestJobsStatusV1() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.connections) + }, fhirPath, apiVersion, s.connection) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -354,7 +354,7 @@ func (s *RequestsTestSuite) TestJobsStatusV2() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, v2BasePath, apiVersionTwo, s.connections) + }, v2BasePath, apiVersionTwo, s.connection) if tt.useMock { h.Svc = mockSvc } @@ -473,7 +473,7 @@ func (s *RequestsTestSuite) TestAttributionStatus() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, apiVersion, s.connections) + h := newHandler(resourceMap, fhirPath, apiVersion, s.connection) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -564,7 +564,7 @@ func (s *RequestsTestSuite) TestDataTypeAuthorization() { "ClaimResponse": {Adjudicated: false, PartiallyAdjudicated: true}, } - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.connections) + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.connection) r := models.NewMockRepository(s.T()) r.On("CreateJob", mock.Anything, mock.Anything).Return(uint(4), nil) h.r = r @@ -652,7 +652,7 @@ func (s *RequestsTestSuite) TestRequests() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, apiVersion, s.connections) + h := newHandler(resourceMap, fhirPath, apiVersion, s.connection) // Test Group and Patient // Patient, Coverage, and ExplanationOfBenefit @@ -782,7 +782,7 @@ func (s *RequestsTestSuite) TestJobStatusErrorHandling() { for _, tt := range tests { s.T().Run(tt.testName, func(t *testing.T) { - h := newHandler(resourceMap, basePath, apiVersion, s.connections) + h 
:= newHandler(resourceMap, basePath, apiVersion, s.connection) if tt.useMockService { mockSrv := service.MockService{} timestp := time.Now() @@ -856,7 +856,7 @@ func (s *RequestsTestSuite) TestJobStatusProgress() { apiVersion := apiVersionTwo requestUrl := v2JobRequestUrl resourceMap := s.resourceType - h := newHandler(resourceMap, basePath, apiVersion, s.connections) + h := newHandler(resourceMap, basePath, apiVersion, s.connection) req := httptest.NewRequest("GET", requestUrl, nil) rctx := chi.NewRouteContext() @@ -905,7 +905,7 @@ func (s *RequestsTestSuite) TestDeleteJob() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - handler := newHandler(s.resourceType, basePath, apiVersion, s.connections) + handler := newHandler(s.resourceType, basePath, apiVersion, s.connection) if tt.useMockService { mockSrv := service.MockService{} @@ -965,7 +965,7 @@ func (s *RequestsTestSuite) TestJobFailedStatus() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - h := newHandler(resourceMap, tt.basePath, tt.version, s.connections) + h := newHandler(resourceMap, tt.basePath, tt.version, s.connection) mockSrv := service.MockService{} timestp := time.Now() mockSrv.On("GetJobAndKeys", testUtils.CtxMatcher, uint(1)).Return( @@ -1023,7 +1023,7 @@ func (s *RequestsTestSuite) TestGetResourceTypes() { {"CT000000", "v2", []string{"Patient", "ExplanationOfBenefit", "Coverage", "Claim", "ClaimResponse"}}, } for _, test := range testCases { - h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.connections) + h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.connection) rp := middleware.RequestParameters{ Version: test.apiVersion, ResourceTypes: []string{}, @@ -1056,15 +1056,17 @@ func TestBulkRequest_Integration(t *testing.T) { client.SetLogger(log.API) // Set logger so we don't get errors later - connections := database.Connect() - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, connections) - driver := riverpgxv5.New(connections.Pgxv5Pool) + connection := database.GetConnection() + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, connection) + + pool := database.GetPool() + driver := riverpgxv5.New(pool) // start from clean river_job slate _, err := driver.GetExecutor().Exec(context.Background(), `delete from river_job`) assert.Nil(t, err) acoID := "A0002" - repo := postgres.NewRepository(connections.Connection) + repo := postgres.NewRepository(connection) // our DB is not always cleaned up properly so sometimes this record exists when this test runs and sometimes it doesnt repo.CreateACO(context.Background(), models.ACO{CMSID: &acoID, UUID: uuid.NewUUID()}) // nolint:errcheck @@ -1127,7 +1129,7 @@ func (s *RequestsTestSuite) genGroupRequest(groupID string, rp middleware.Reques rctx := chi.NewRouteContext() rctx.URLParams.Add("groupId", groupID) - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) @@ -1142,7 +1144,7 @@ func (s *RequestsTestSuite) genGroupRequest(groupID string, rp middleware.Reques func (s *RequestsTestSuite) genPatientRequest(rp middleware.RequestParameters) *http.Request { req := httptest.NewRequest("GET", "http://bcda.cms.gov/api/v1/Patient/$export", nil) - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) + aco 
:= postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) ctx = middleware.SetRequestParamsCtx(ctx, rp) @@ -1153,7 +1155,7 @@ func (s *RequestsTestSuite) genPatientRequest(rp middleware.RequestParameters) * func (s *RequestsTestSuite) genASRequest() *http.Request { req := httptest.NewRequest("GET", "http://bcda.cms.gov/api/v1/attribution_status", nil) - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) @@ -1181,7 +1183,7 @@ func (s *RequestsTestSuite) genGetJobsRequest(version string, statuses []models. req := httptest.NewRequest("GET", target, nil) - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) @@ -1202,7 +1204,7 @@ func (s *RequestsTestSuite) TestValidateResources() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.connections) + }, fhirPath, apiVersion, s.connection) err := h.validateResources([]string{"Vegetable"}, "1234") assert.Contains(s.T(), err.Error(), "invalid resource type") } diff --git a/bcda/api/v1/api.go b/bcda/api/v1/api.go index 9d6cf62c2..c5df042f7 100644 --- a/bcda/api/v1/api.go +++ b/bcda/api/v1/api.go @@ -3,6 +3,7 @@ package v1 import ( "bytes" "compress/gzip" + "database/sql" "encoding/json" "errors" "fmt" @@ -17,7 +18,6 @@ import ( "github.com/CMSgov/bcda-app/bcda/api" "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/health" "github.com/CMSgov/bcda-app/bcda/responseutils" "github.com/CMSgov/bcda-app/bcda/service" @@ -27,11 +27,11 @@ import ( ) type ApiV1 struct { - handler *api.Handler - connections *database.Connections + handler *api.Handler + connection *sql.DB } -func NewApiV1(connections *database.Connections) *ApiV1 { +func NewApiV1(connection *sql.DB) *ApiV1 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -42,8 +42,8 @@ func NewApiV1(connections *database.Connections) *ApiV1 { if !ok { panic("Failed to configure resource DataTypes") } else { - h := api.NewHandler(resources, "/v1/fhir", "v1", connections) - return &ApiV1{handler: h, connections: connections} + h := api.NewHandler(resources, "/v1/fhir", "v1", connection) + return &ApiV1{handler: h, connection: connection} } } diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index 6f868bad6..05390fae7 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -4,6 +4,7 @@ import ( "compress/gzip" "context" "crypto/tls" + "database/sql" "encoding/json" "fmt" "io" @@ -44,14 +45,14 @@ var ( type APITestSuite struct { suite.Suite - rr *httptest.ResponseRecorder - connections *database.Connections - apiV1 *ApiV1 + rr *httptest.ResponseRecorder + connection *sql.DB + apiV1 *ApiV1 } func (s *APITestSuite) SetupSuite() { 
- s.connections = database.Connect() - s.apiV1 = NewApiV1(s.connections) + s.connection = database.GetConnection() + s.apiV1 = NewApiV1(s.connection) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -73,7 +74,7 @@ func (s *APITestSuite) SetupTest() { } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.connections.Connection, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.connection, acoUnderTest) } func (s *APITestSuite) TestJobStatusBadInputs() { @@ -133,8 +134,8 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) + postgrestest.CreateJobs(t, s.connection, &j) + defer postgrestest.DeleteJobByID(t, s.connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -161,14 +162,14 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.connections.Connection, + postgrestest.CreateJobKeys(s.T(), s.connection, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } @@ -212,7 +213,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -220,7 +221,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.connections.Connection, jobKey) + postgrestest.CreateJobKeys(s.T(), s.connection, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -272,10 +273,10 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) j.UpdatedAt = time.Now().Add(-(s.apiV1.handler.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.connections.Connection, j) + postgrestest.UpdateJob(s.T(), s.connection, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) s.apiV1.JobStatus(s.rr, req) @@ -341,8 +342,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) + postgrestest.CreateJobs(t, s.connection, &j) + defer postgrestest.DeleteJobByID(t, s.connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -456,7 
+457,7 @@ func (s *APITestSuite) TestJobStatusWithWrongACO() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusPending, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) handler := auth.RequireTokenJobMatch(http.HandlerFunc(s.apiV1.JobStatus)) req := s.createJobStatusRequest(uuid.Parse(constants.LargeACOUUID), j.ID) @@ -478,8 +479,8 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -510,8 +511,8 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) @@ -530,8 +531,8 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -550,8 +551,8 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -592,15 +593,15 @@ func (s *APITestSuite) TestGetAttributionStatus() { err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connections.Connection, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connection, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/api/v2/api.go b/bcda/api/v2/api.go index 4d8412fb1..974a033c3 100644 --- a/bcda/api/v2/api.go +++ 
b/bcda/api/v2/api.go @@ -1,6 +1,7 @@ package v2 import ( + "database/sql" "fmt" "net/http" "time" @@ -15,7 +16,6 @@ import ( api "github.com/CMSgov/bcda-app/bcda/api" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/service" "github.com/CMSgov/bcda-app/bcda/servicemux" "github.com/CMSgov/bcda-app/conf" @@ -23,12 +23,12 @@ import ( ) type ApiV2 struct { - handler *api.Handler - marshaller *jsonformat.Marshaller - connections *database.Connections + handler *api.Handler + marshaller *jsonformat.Marshaller + connection *sql.DB } -func NewApiV2(connections *database.Connections) *ApiV2 { +func NewApiV2(connection *sql.DB) *ApiV2 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -40,14 +40,14 @@ func NewApiV2(connections *database.Connections) *ApiV2 { if !ok { panic("Failed to configure resource DataTypes") } else { - h := api.NewHandler(resources, "/v2/fhir", "v2", connections) + h := api.NewHandler(resources, "/v2/fhir", "v2", connection) // Ensure that we write the serialized FHIR resources as a single line. // Needed to comply with the NDJSON format that we are using. marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) if err != nil { log.API.Fatalf("Failed to create marshaller %s", err) } - return &ApiV2{marshaller: marshaller, handler: h, connections: connections} + return &ApiV2{marshaller: marshaller, handler: h, connection: connection} } } diff --git a/bcda/api/v2/api_test.go b/bcda/api/v2/api_test.go index b0383a4bc..a6fea91e2 100644 --- a/bcda/api/v2/api_test.go +++ b/bcda/api/v2/api_test.go @@ -2,6 +2,7 @@ package v2 import ( "context" + "database/sql" "encoding/json" "fmt" "io" @@ -51,13 +52,13 @@ var ( type APITestSuite struct { suite.Suite - connections *database.Connections - apiV2 *ApiV2 + connection *sql.DB + apiV2 *ApiV2 } func (s *APITestSuite) SetupSuite() { - s.connections = database.Connect() - s.apiV2 = NewApiV2(s.connections) + s.connection = database.GetConnection() + s.apiV2 = NewApiV2(s.connection) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -78,7 +79,7 @@ func (s *APITestSuite) SetupSuite() { } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.connections.Connection, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.connection, acoUnderTest) } func TestAPITestSuite(t *testing.T) { @@ -142,8 +143,8 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) + postgrestest.CreateJobs(t, s.connection, &j) + defer postgrestest.DeleteJobByID(t, s.connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -170,14 +171,14 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.connections.Connection, + 
postgrestest.CreateJobKeys(s.T(), s.connection, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } @@ -223,7 +224,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -231,7 +232,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.connections.Connection, jobKey) + postgrestest.CreateJobKeys(s.T(), s.connection, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -285,10 +286,10 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) j.UpdatedAt = time.Now().Add(-(s.apiV2.handler.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.connections.Connection, j) + postgrestest.UpdateJob(s.T(), s.connection, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -312,8 +313,8 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -344,8 +345,8 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) @@ -364,8 +365,8 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -384,8 +385,8 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -447,8 +448,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: 
"/api/v2/Patient/$export?_type=Patient,Coverage", Status: tt.status, } - postgrestest.CreateJobs(t, s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) + postgrestest.CreateJobs(t, s.connection, &j) + defer postgrestest.DeleteJobByID(t, s.connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -543,7 +544,7 @@ func (s *APITestSuite) TestResourceTypes() { "ClaimResponse", }...) - h := api.NewHandler(resources, "/v2/fhir", "v2", s.connections) + h := api.NewHandler(resources, "/v2/fhir", "v2", s.connection) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) @@ -610,20 +611,20 @@ func (s *APITestSuite) TestGetAttributionStatus() { err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connections.Connection, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connection, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) getAuthData() (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) return auth.AuthData{ACOID: acoUnderTest.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/api/v3/api.go b/bcda/api/v3/api.go index 86fc63cbd..5fe0fba97 100644 --- a/bcda/api/v3/api.go +++ b/bcda/api/v3/api.go @@ -1,13 +1,13 @@ package v3 import ( + "database/sql" "fmt" "net/http" "time" "github.com/CMSgov/bcda-app/bcda/api" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/service" "github.com/CMSgov/bcda-app/bcda/servicemux" "github.com/CMSgov/bcda-app/conf" @@ -23,12 +23,12 @@ import ( ) type ApiV3 struct { - handler *api.Handler - marshaller *jsonformat.Marshaller - connections *database.Connections + handler *api.Handler + marshaller *jsonformat.Marshaller + connection *sql.DB } -func NewApiV3(connections *database.Connections) *ApiV3 { +func NewApiV3(connection *sql.DB) *ApiV3 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -38,14 +38,14 @@ func NewApiV3(connections *database.Connections) *ApiV3 { if !ok { panic("Failed to configure resource DataTypes") } else { - h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, connections) + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, connection) // Ensure that we write the serialized FHIR resources as a single line. // Needed to comply with the NDJSON format that we are using. 
marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) if err != nil { log.API.Fatalf("Failed to create marshaller %s", err) } - return &ApiV3{marshaller: marshaller, handler: h, connections: connections} + return &ApiV3{marshaller: marshaller, handler: h, connection: connection} } } diff --git a/bcda/api/v3/api_test.go b/bcda/api/v3/api_test.go index 43fb2bfa7..4f78fe809 100644 --- a/bcda/api/v3/api_test.go +++ b/bcda/api/v3/api_test.go @@ -2,6 +2,7 @@ package v3 import ( "context" + "database/sql" "encoding/json" "fmt" "io" @@ -51,13 +52,13 @@ var ( type APITestSuite struct { suite.Suite - connections *database.Connections - apiV3 *ApiV3 + connection *sql.DB + apiV3 *ApiV3 } func (s *APITestSuite) SetupSuite() { - s.connections = database.Connect() - s.apiV3 = NewApiV3(s.connections) + s.connection = database.GetConnection() + s.apiV3 = NewApiV3(s.connection) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -78,7 +79,7 @@ func (s *APITestSuite) SetupSuite() { } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.connections.Connection, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.connection, acoUnderTest) } func TestAPITestSuite(t *testing.T) { @@ -142,8 +143,8 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) + postgrestest.CreateJobs(t, s.connection, &j) + defer postgrestest.DeleteJobByID(t, s.connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -175,14 +176,14 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.connections.Connection, + postgrestest.CreateJobKeys(s.T(), s.connection, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } @@ -228,7 +229,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -236,7 +237,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.connections.Connection, jobKey) + postgrestest.CreateJobKeys(s.T(), s.connection, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -290,10 +291,10 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) + postgrestest.CreateJobs(s.T(), s.connection, 
&j) j.UpdatedAt = time.Now().Add(-(s.apiV3.handler.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.connections.Connection, j) + postgrestest.UpdateJob(s.T(), s.connection, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -317,8 +318,8 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -349,8 +350,8 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) @@ -369,8 +370,8 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -389,8 +390,8 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connections.Connection, j.ID) + postgrestest.CreateJobs(s.T(), s.connection, &j) + defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -452,8 +453,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=Patient,Coverage", constants.V3Path), Status: tt.status, } - postgrestest.CreateJobs(t, s.connections.Connection, &j) - defer postgrestest.DeleteJobByID(t, s.connections.Connection, j.ID) + postgrestest.CreateJobs(t, s.connection, &j) + defer postgrestest.DeleteJobByID(t, s.connection, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -546,7 +547,7 @@ func (s *APITestSuite) TestResourceTypes() { "ExplanationOfBenefit", }...) 
- h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, s.connections) + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, s.connection) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) @@ -613,20 +614,20 @@ func (s *APITestSuite) TestGetAttributionStatus() { err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connections.Connection, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connection, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) getAuthData() (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoUnderTest) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) return auth.AuthData{ACOID: acoUnderTest.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connections.Connection, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index 822eb81eb..d7e317fcb 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -3,6 +3,7 @@ package bcdacli import ( "archive/zip" "context" + "database/sql" "encoding/json" "fmt" "io" @@ -46,8 +47,8 @@ const Name = "bcda" const Usage = "Beneficiary Claims Data API CLI" var ( - connections *database.Connections - r models.Repository + connection *sql.DB + r models.Repository ) func GetApp() *cli.App { @@ -60,8 +61,8 @@ func setUpApp() *cli.App { app.Usage = Usage app.Version = constants.Version app.Before = func(c *cli.Context) error { - connections = database.Connect() - r = postgres.NewRepository(connections.Connection) + connection = database.GetConnection() + r = postgres.NewRepository(connection) return nil } var hours, err = safecast.ToUint(utils.GetEnvInt("FILE_ARCHIVE_THRESHOLD_HR", 72)) @@ -121,7 +122,7 @@ func setUpApp() *cli.App { } api := &http.Server{ - Handler: web.NewAPIRouter(connections), + Handler: web.NewAPIRouter(connection), ReadTimeout: time.Duration(utils.GetEnvInt("API_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("API_WRITE_TIMEOUT", 20)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("API_IDLE_TIMEOUT", 120)) * time.Second, diff --git a/bcda/database/connection.go b/bcda/database/connection.go index 3d510e808..1edd4e7b5 100644 --- a/bcda/database/connection.go +++ b/bcda/database/connection.go @@ -23,21 +23,11 @@ var ( Pgxv5Pool *pgxv5Pool.Pool ) -type Connections struct { - Connection *sql.DB - QueueConnection *pgx.ConnPool - Pgxv5Pool *pgxv5Pool.Pool -} - func init() { Connection = GetConnection() Pgxv5Pool = GetPool() } -func Connect() *Connections { - return nil -} - func GetConnection() *sql.DB { cfg, err := LoadConfig() if 
err != nil { diff --git a/bcda/web/router.go b/bcda/web/router.go index cfb6dce2c..4e7108ce5 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -1,6 +1,7 @@ package web import ( + "database/sql" "fmt" "net/http" "strings" @@ -28,7 +29,7 @@ var commonAuth = []func(http.Handler) http.Handler{ auth.RequireTokenAuth, auth.CheckBlacklist} -func NewAPIRouter(connections *database.Connections) http.Handler { +func NewAPIRouter(connection *sql.DB) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() r.Use(gcmw.RequestID, appMiddleware.NewTransactionID, auth.ParseToken, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) @@ -54,7 +55,7 @@ func NewAPIRouter(connections *database.Connections) http.Handler { } r.Route("/api/v1", func(r chi.Router) { - apiV1 := v1.NewApiV1(connections) + apiV1 := v1.NewApiV1(connection) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) @@ -66,7 +67,7 @@ func NewAPIRouter(connections *database.Connections) http.Handler { if utils.GetEnvBool("VERSION_2_ENDPOINT_ACTIVE", true) { FileServer(r, "/api/v2/swagger", http.Dir("./swaggerui/v2")) - apiV2 := v2.NewApiV2(connections) + apiV2 := v2.NewApiV2(connection) r.Route("/api/v2", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV2.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV2.BulkGroupRequest)) @@ -79,7 +80,7 @@ func NewAPIRouter(connections *database.Connections) http.Handler { } if utils.GetEnvBool("VERSION_3_ENDPOINT_ACTIVE", true) { - apiV3 := v3.NewApiV3(connections) + apiV3 := v3.NewApiV3(connection) r.Route("/api/demo", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV3.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV3.BulkGroupRequest)) diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index eaf9e298c..c43e83551 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -1,6 +1,7 @@ package web import ( + "database/sql" "encoding/json" "fmt" "io" @@ -29,15 +30,15 @@ var nDJsonDataRoute string = "/data/test/test.ndjson" type RouterTestSuite struct { suite.Suite - apiRouter http.Handler - dataRouter http.Handler - connections database.Connections + apiRouter http.Handler + dataRouter http.Handler + connection *sql.DB } func (s *RouterTestSuite) SetupTest() { conf.SetEnv(s.T(), "DEBUG", "true") - s.connections = *database.Connect() - s.apiRouter = NewAPIRouter(&s.connections) + s.connection = database.GetConnection() + s.apiRouter = NewAPIRouter(s.connection) s.dataRouter = NewDataRouter() } @@ -79,7 +80,7 @@ func (s *RouterTestSuite) TestDefaultProdRoute() { s.FailNow("err in setting env var", err) } // Need a new router because the one in the test setup does not use the environment variable set in this test. 
- s.apiRouter = NewAPIRouter(&s.connections) + s.apiRouter = NewAPIRouter(s.connection) res := s.getAPIRoute("/v1/") assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -196,7 +197,7 @@ func (s *RouterTestSuite) TestV2EndpointsDisabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(&s.connections) + s.apiRouter = NewAPIRouter(s.connection) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -213,7 +214,7 @@ func (s *RouterTestSuite) TestV2EndpointsEnabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(&s.connections) + s.apiRouter = NewAPIRouter(s.connection) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -234,7 +235,7 @@ func (s *RouterTestSuite) TestV3EndpointsDisabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(&s.connections) + s.apiRouter = NewAPIRouter(s.connection) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -251,7 +252,7 @@ func (s *RouterTestSuite) TestV3EndpointsEnabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(&s.connections) + s.apiRouter = NewAPIRouter(s.connection) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -351,7 +352,7 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []str handler http.Handler paths []string }) { - apiRouter := NewAPIRouter(&s.connections) + apiRouter := NewAPIRouter(s.connection) configs = []struct { handler http.Handler From cab0a71e9b177fac990f2a673b000d85f5547db2 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 23 Jul 2025 18:50:33 -0400 Subject: [PATCH 06/28] Add connection as dependency of data router --- bcda/auth/api_test.go | 2 +- bcda/bcdacli/cli.go | 2 +- bcda/web/router.go | 5 ++--- bcda/web/router_test.go | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/bcda/auth/api_test.go b/bcda/auth/api_test.go index 25892ce6f..fb43676fc 100644 --- a/bcda/auth/api_test.go +++ b/bcda/auth/api_test.go @@ -46,7 +46,7 @@ func (s *AuthAPITestSuite) CreateRouter() http.Handler { } func (s *AuthAPITestSuite) SetupSuite() { - s.db = database.Connection + s.db = database.GetConnection() s.r = postgres.NewRepository(s.db) } diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index d7e317fcb..b74cab063 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -130,7 +130,7 @@ func setUpApp() *cli.App { } fileserver := &http.Server{ - Handler: web.NewDataRouter(), + Handler: web.NewDataRouter(connection), ReadTimeout: time.Duration(utils.GetEnvInt("FILESERVER_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("FILESERVER_WRITE_TIMEOUT", 360)) * time.Second, IdleTimeout: 
time.Duration(utils.GetEnvInt("FILESERVER_IDLE_TIMEOUT", 120)) * time.Second, diff --git a/bcda/web/router.go b/bcda/web/router.go index 4e7108ce5..c09fef368 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -11,7 +11,6 @@ import ( v3 "github.com/CMSgov/bcda-app/bcda/api/v3" "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/logging" "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/bcda/monitoring" @@ -102,11 +101,11 @@ func NewAuthRouter() http.Handler { return auth.NewAuthRouter(gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) } -func NewDataRouter() http.Handler { +func NewDataRouter(connection *sql.DB) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() resourceTypeLogger := &logging.ResourceTypeLogger{ - Repository: postgres.NewRepository(database.Connection), + Repository: postgres.NewRepository(connection), } r.Use(auth.ParseToken, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) r.With(append( diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index c43e83551..d332b3c7c 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -39,7 +39,7 @@ func (s *RouterTestSuite) SetupTest() { conf.SetEnv(s.T(), "DEBUG", "true") s.connection = database.GetConnection() s.apiRouter = NewAPIRouter(s.connection) - s.dataRouter = NewDataRouter() + s.dataRouter = NewDataRouter(s.connection) } From 0d37ba1e5d49d7a436b2695bf6b076da0609090b Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 23 Jul 2025 19:04:44 -0400 Subject: [PATCH 07/28] Pass connection as argument to token job middleware --- bcda/api/v1/api_test.go | 2 +- bcda/auth/middleware.go | 77 +++++++++++++++++++----------------- bcda/auth/middleware_test.go | 27 +++++++------ bcda/web/router.go | 14 +++---- 4 files changed, 62 insertions(+), 58 deletions(-) diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index 05390fae7..1a021be90 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -459,7 +459,7 @@ func (s *APITestSuite) TestJobStatusWithWrongACO() { } postgrestest.CreateJobs(s.T(), s.connection, &j) - handler := auth.RequireTokenJobMatch(http.HandlerFunc(s.apiV1.JobStatus)) + handler := auth.RequireTokenJobMatch(s.connection)(http.HandlerFunc(s.apiV1.JobStatus)) req := s.createJobStatusRequest(uuid.Parse(constants.LargeACOUUID), j.ID) handler.ServeHTTP(s.rr, req) diff --git a/bcda/auth/middleware.go b/bcda/auth/middleware.go index 64e7ca474..103ec6def 100644 --- a/bcda/auth/middleware.go +++ b/bcda/auth/middleware.go @@ -2,6 +2,7 @@ package auth import ( "context" + "database/sql" "fmt" "net/http" "regexp" @@ -13,7 +14,6 @@ import ( "github.com/pkg/errors" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" customErrors "github.com/CMSgov/bcda-app/bcda/errors" "github.com/CMSgov/bcda-app/bcda/models/postgres" responseutils "github.com/CMSgov/bcda-app/bcda/responseutils" @@ -163,43 +163,46 @@ func CheckBlacklist(next http.Handler) http.Handler { }) } -func RequireTokenJobMatch(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - rw := getRespWriter(r.URL.Path) - - ad, ok := 
r.Context().Value(AuthDataContextKey).(AuthData) - if !ok { - log.Auth.Error("Auth data not found") - rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusUnauthorized, responseutils.UnauthorizedErr, "AuthData not found") - return - } - - //Throw an invalid request for non-unsigned integers - jobID, err := strconv.ParseUint(chi.URLParam(r, "jobID"), 10, 64) - if err != nil { - log.Auth.Error(err) - rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusBadRequest, responseutils.RequestErr, err.Error()) - return - } - - repository := postgres.NewRepository(database.Connection) - - job, err := repository.GetJobByID(r.Context(), uint(jobID)) - if err != nil { - log.Auth.Error(err) - rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusNotFound, responseutils.NotFoundErr, "") - return - } - - // ACO did not create the job - if !strings.EqualFold(ad.ACOID, job.ACOID.String()) { - log.Auth.Errorf("ACO %s does not have access to job ID %d %s", - ad.ACOID, job.ID, job.ACOID) - rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusUnauthorized, responseutils.UnauthorizedErr, "") - return +func RequireTokenJobMatch(connection *sql.DB) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + rw := getRespWriter(r.URL.Path) + + ad, ok := r.Context().Value(AuthDataContextKey).(AuthData) + if !ok { + log.Auth.Error("Auth data not found") + rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusUnauthorized, responseutils.UnauthorizedErr, "AuthData not found") + return + } + + //Throw an invalid request for non-unsigned integers + jobID, err := strconv.ParseUint(chi.URLParam(r, "jobID"), 10, 64) + if err != nil { + log.Auth.Error(err) + rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusBadRequest, responseutils.RequestErr, err.Error()) + return + } + + repository := postgres.NewRepository(connection) + + job, err := repository.GetJobByID(r.Context(), uint(jobID)) + if err != nil { + log.Auth.Error(err) + rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusNotFound, responseutils.NotFoundErr, "") + return + } + + // ACO did not create the job + if !strings.EqualFold(ad.ACOID, job.ACOID.String()) { + log.Auth.Errorf("ACO %s does not have access to job ID %d %s", + ad.ACOID, job.ID, job.ACOID) + rw.Exception(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, http.StatusUnauthorized, responseutils.UnauthorizedErr, "") + return + } + next.ServeHTTP(w, r) } - next.ServeHTTP(w, r) - }) + return http.HandlerFunc(fn) + } } type fhirResponseWriter interface { diff --git a/bcda/auth/middleware_test.go b/bcda/auth/middleware_test.go index b90cb53e0..e14677391 100644 --- a/bcda/auth/middleware_test.go +++ b/bcda/auth/middleware_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "crypto/rsa" + "database/sql" "errors" "fmt" "log" @@ -39,8 +40,13 @@ var bearerStringMsg string = "Bearer %s" type MiddlewareTestSuite struct { suite.Suite - server *httptest.Server - rr *httptest.ResponseRecorder + server *httptest.Server + rr *httptest.ResponseRecorder + connection *sql.DB +} + +func (s *MiddlewareTestSuite) SetupSuite() { + s.connection = database.GetConnection() } func (s *MiddlewareTestSuite) CreateRouter() http.Handler { @@ -333,14 +339,13 @@ func (s *MiddlewareTestSuite) TestAuthMiddlewareReturnResponse401WhenNoBearerTok // integration 
test: involves db connection to postgres func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenMismatchingDataProvided() { - db := database.Connection j := models.Job{ ACOID: uuid.Parse(constants.TestACOID), RequestURL: constants.V1Path + constants.EOBExportPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), db, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) id, err := safecast.ToInt(j.ID) if err != nil { log.Fatal(err) @@ -358,7 +363,7 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenMismatchingDa {"Mismatching ACOID", jobID, uuid.New(), http.StatusUnauthorized}, } - handler := auth.RequireTokenJobMatch(mockHandler) + handler := auth.RequireTokenJobMatch(s.connection)(mockHandler) for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { @@ -385,14 +390,12 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenMismatchingDa // integration test: involves db connection to postgres func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn200WhenCorrectAccountableCareOrganizationAndJob() { - db := database.Connection - j := models.Job{ ACOID: uuid.Parse(constants.TestACOID), RequestURL: constants.V1Path + constants.EOBExportPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), db, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) id, err := safecast.ToInt(j.ID) if err != nil { log.Fatal(err) @@ -407,7 +410,7 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn200WhenCorrectAccoun rctx := chi.NewRouteContext() rctx.URLParams.Add("jobID", jobID) - handler := auth.RequireTokenJobMatch(mockHandler) + handler := auth.RequireTokenJobMatch(s.connection)(mockHandler) ad := auth.AuthData{ ACOID: j.ACOID.String(), @@ -422,15 +425,13 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn200WhenCorrectAccoun // integration test: involves db connection to postgres func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenNoAuthDataProvidedInContext() { - db := database.Connection - j := models.Job{ ACOID: uuid.Parse(constants.TestACOID), RequestURL: constants.V1Path + constants.EOBExportPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), db, &j) + postgrestest.CreateJobs(s.T(), s.connection, &j) id, err := safecast.ToInt(j.ID) if err != nil { log.Fatal(err) @@ -445,7 +446,7 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenNoAuthDataPro log.Fatal(err) } - handler := auth.RequireTokenJobMatch(mockHandler) + handler := auth.RequireTokenJobMatch(s.connection)(mockHandler) handler.ServeHTTP(s.rr, req) assert.Equal(s.T(), http.StatusUnauthorized, s.rr.Code) diff --git a/bcda/web/router.go b/bcda/web/router.go index c09fef368..73d5778ab 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -57,9 +57,9 @@ func NewAPIRouter(connection *sql.DB) http.Handler { apiV1 := v1.NewApiV1(connection) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) + r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV1.JobsStatus)) - r.With(append(commonAuth, 
auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) + r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV1.AttributionStatus)) r.Get(m.WrapHandler("/metadata", v1.Metadata)) }) @@ -70,9 +70,9 @@ func NewAPIRouter(connection *sql.DB) http.Handler { r.Route("/api/v2", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV2.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV2.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, apiV2.JobStatus)) + r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV2.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV2.JobsStatus)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, apiV2.DeleteJob)) + r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV2.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV2.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV2.Metadata)) }) @@ -83,9 +83,9 @@ func NewAPIRouter(connection *sql.DB) http.Handler { r.Route("/api/demo", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV3.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV3.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Get(m.WrapHandler(constants.JOBIDPath, apiV3.JobStatus)) + r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV3.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV3.JobsStatus)) - r.With(append(commonAuth, auth.RequireTokenJobMatch)...).Delete(m.WrapHandler(constants.JOBIDPath, apiV3.DeleteJob)) + r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV3.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV3.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV3.Metadata)) }) @@ -110,7 +110,7 @@ func NewDataRouter(connection *sql.DB) http.Handler { r.Use(auth.ParseToken, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) r.With(append( commonAuth, - auth.RequireTokenJobMatch, + auth.RequireTokenJobMatch(connection), resourceTypeLogger.LogJobResourceType, )...).Get(m.WrapHandler("/data/{jobID}/{fileName}", v1.ServeData)) return r From ed1783cb4fcc88b13bf0a01d41d59c3de877270b Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 23 Jul 2025 22:20:53 -0400 Subject: [PATCH 08/28] Remove connection global from cli --- bcda/bcdacli/cli.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index b74cab063..0a723bfc9 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -399,8 +399,7 @@ func setUpApp() *cli.App { }, Action: func(c *cli.Context) error { ignoreSignals() - db := database.Connection 
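// Illustrative sketch, not part of the patch: the surrounding hunk removes a direct read of
// the database.Connection global from the CLI; the handle now comes from the package variable
// resolved once in app.Before earlier in this series. Roughly:
//
//	app.Before = func(c *cli.Context) error {
//		connection = database.GetConnection()
//		r = postgres.NewRepository(connection)
//		return nil
//	}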
- r := postgres.NewRepository(db) + r := postgres.NewRepository(connection) var file_handler optout.OptOutFileHandler From 8272973693db634e1a8f82750da066c0a1c8a024 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Thu, 24 Jul 2025 10:19:10 -0400 Subject: [PATCH 09/28] Define db pool as dependency for apis and workers --- bcda/api/requests.go | 9 +++--- bcda/api/requests_test.go | 32 ++++++++++++---------- bcda/api/v1/api.go | 5 ++-- bcda/api/v1/api_test.go | 4 ++- bcda/api/v2/api.go | 5 ++-- bcda/api/v2/api_test.go | 7 +++-- bcda/api/v3/api.go | 5 ++-- bcda/api/v3/api_test.go | 7 +++-- bcda/bcdacli/cli.go | 5 +++- bcda/database/connection_test.go | 6 ---- bcda/web/router.go | 9 +++--- bcda/web/router_test.go | 16 ++++++----- bcdaworker/queueing/enqueue.go | 11 ++++---- bcdaworker/queueing/enqueue_test.go | 7 +++-- bcdaworker/queueing/river.go | 10 +++++-- bcdaworker/queueing/river_test.go | 5 ++-- bcdaworker/queueing/worker_prepare.go | 6 ++-- bcdaworker/queueing/worker_prepare_test.go | 4 +-- bcdaworker/queueing/worker_process_job.go | 5 ++-- 19 files changed, 91 insertions(+), 67 deletions(-) diff --git a/bcda/api/requests.go b/bcda/api/requests.go index 625b68496..ff2d866c6 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -34,6 +34,7 @@ import ( "github.com/CMSgov/bcda-app/conf" "github.com/CMSgov/bcda-app/log" m "github.com/CMSgov/bcda-app/middleware" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" ) type Handler struct { @@ -61,14 +62,14 @@ type fhirResponseWriter interface { JobsBundle(context.Context, http.ResponseWriter, []*models.Job, string) } -func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB) *Handler { - return newHandler(dataTypes, basePath, apiVersion, connection) +func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB, pool *pgxv5Pool.Pool) *Handler { + return newHandler(dataTypes, basePath, apiVersion, connection, pool) } -func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB) *Handler { +func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB, pool *pgxv5Pool.Pool) *Handler { h := &Handler{JobTimeout: time.Hour * time.Duration(utils.GetEnvInt("ARCHIVE_THRESHOLD_HR", 24))} - h.Enq = queueing.NewEnqueuer() + h.Enq = queueing.NewEnqueuer(connection, pool) cfg, err := service.LoadConfig() if err != nil { diff --git a/bcda/api/requests_test.go b/bcda/api/requests_test.go index 75b54b0a1..e3150fb0f 100644 --- a/bcda/api/requests_test.go +++ b/bcda/api/requests_test.go @@ -51,6 +51,7 @@ import ( fhirmodelv2CR "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/bundle_and_contained_resource_go_proto" fhircodesv1 "github.com/google/fhir/go/proto/google/fhir/proto/stu3/codes_go_proto" fhirmodelsv1 "github.com/google/fhir/go/proto/google/fhir/proto/stu3/resources_go_proto" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" ) const apiVersionOne = "v1" @@ -67,6 +68,8 @@ type RequestsTestSuite struct { connection *sql.DB + pool *pgxv5Pool.Pool + acoID uuid.UUID resourceType map[string]service.DataType @@ -81,6 +84,7 @@ func (s *RequestsTestSuite) SetupSuite() { s.acoID = uuid.Parse("ba21d24d-cd96-4d7d-a691-b0e8c88e67a5") db, _ := databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true) s.connection = db + s.pool = database.GetPool() tf, err := testfixtures.New( testfixtures.Database(db), testfixtures.Dialect("postgres"), @@ 
-138,7 +142,7 @@ func (s *RequestsTestSuite) TestRunoutEnabled() { mockSvc := &service.MockService{} mockAco := service.ACOConfig{Data: []string{"adjudicated"}} mockSvc.On("GetACOConfigForID", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockAco, true) - h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.connection) + h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.connection, s.pool) h.Svc = mockSvc enqueuer := queueing.NewMockEnqueuer(s.T()) h.Enq = enqueuer @@ -240,7 +244,7 @@ func (s *RequestsTestSuite) TestJobsStatusV1() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.connection) + }, fhirPath, apiVersion, s.connection, s.pool) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -354,7 +358,7 @@ func (s *RequestsTestSuite) TestJobsStatusV2() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, v2BasePath, apiVersionTwo, s.connection) + }, v2BasePath, apiVersionTwo, s.connection, s.pool) if tt.useMock { h.Svc = mockSvc } @@ -473,7 +477,7 @@ func (s *RequestsTestSuite) TestAttributionStatus() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, apiVersion, s.connection) + h := newHandler(resourceMap, fhirPath, apiVersion, s.connection, s.pool) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -564,7 +568,7 @@ func (s *RequestsTestSuite) TestDataTypeAuthorization() { "ClaimResponse": {Adjudicated: false, PartiallyAdjudicated: true}, } - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.connection) + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.connection, s.pool) r := models.NewMockRepository(s.T()) r.On("CreateJob", mock.Anything, mock.Anything).Return(uint(4), nil) h.r = r @@ -652,7 +656,7 @@ func (s *RequestsTestSuite) TestRequests() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, apiVersion, s.connection) + h := newHandler(resourceMap, fhirPath, apiVersion, s.connection, s.pool) // Test Group and Patient // Patient, Coverage, and ExplanationOfBenefit @@ -782,7 +786,7 @@ func (s *RequestsTestSuite) TestJobStatusErrorHandling() { for _, tt := range tests { s.T().Run(tt.testName, func(t *testing.T) { - h := newHandler(resourceMap, basePath, apiVersion, s.connection) + h := newHandler(resourceMap, basePath, apiVersion, s.connection, s.pool) if tt.useMockService { mockSrv := service.MockService{} timestp := time.Now() @@ -856,7 +860,7 @@ func (s *RequestsTestSuite) TestJobStatusProgress() { apiVersion := apiVersionTwo requestUrl := v2JobRequestUrl resourceMap := s.resourceType - h := newHandler(resourceMap, basePath, apiVersion, s.connection) + h := newHandler(resourceMap, basePath, apiVersion, s.connection, s.pool) req := httptest.NewRequest("GET", requestUrl, nil) rctx := chi.NewRouteContext() @@ -905,7 +909,7 @@ func (s *RequestsTestSuite) TestDeleteJob() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - handler := newHandler(s.resourceType, basePath, apiVersion, s.connection) + handler := newHandler(s.resourceType, basePath, apiVersion, s.connection, s.pool) if tt.useMockService { mockSrv := service.MockService{} @@ -965,7 +969,7 @@ func (s *RequestsTestSuite) TestJobFailedStatus() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - h := newHandler(resourceMap, tt.basePath, tt.version, s.connection) + h := newHandler(resourceMap, 
tt.basePath, tt.version, s.connection, s.pool) mockSrv := service.MockService{} timestp := time.Now() mockSrv.On("GetJobAndKeys", testUtils.CtxMatcher, uint(1)).Return( @@ -1023,7 +1027,7 @@ func (s *RequestsTestSuite) TestGetResourceTypes() { {"CT000000", "v2", []string{"Patient", "ExplanationOfBenefit", "Coverage", "Claim", "ClaimResponse"}}, } for _, test := range testCases { - h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.connection) + h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.connection, s.pool) rp := middleware.RequestParameters{ Version: test.apiVersion, ResourceTypes: []string{}, @@ -1057,9 +1061,9 @@ func TestBulkRequest_Integration(t *testing.T) { client.SetLogger(log.API) // Set logger so we don't get errors later connection := database.GetConnection() - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, connection) - pool := database.GetPool() + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, connection, pool) + driver := riverpgxv5.New(pool) // start from clean river_job slate _, err := driver.GetExecutor().Exec(context.Background(), `delete from river_job`) @@ -1204,7 +1208,7 @@ func (s *RequestsTestSuite) TestValidateResources() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.connection) + }, fhirPath, apiVersion, s.connection, s.pool) err := h.validateResources([]string{"Vegetable"}, "1234") assert.Contains(s.T(), err.Error(), "invalid resource type") } diff --git a/bcda/api/v1/api.go b/bcda/api/v1/api.go index c5df042f7..ae2b67988 100644 --- a/bcda/api/v1/api.go +++ b/bcda/api/v1/api.go @@ -24,6 +24,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/servicemux" "github.com/CMSgov/bcda-app/conf" "github.com/CMSgov/bcda-app/log" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" ) type ApiV1 struct { @@ -31,7 +32,7 @@ type ApiV1 struct { connection *sql.DB } -func NewApiV1(connection *sql.DB) *ApiV1 { +func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV1 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -42,7 +43,7 @@ func NewApiV1(connection *sql.DB) *ApiV1 { if !ok { panic("Failed to configure resource DataTypes") } else { - h := api.NewHandler(resources, "/v1/fhir", "v1", connection) + h := api.NewHandler(resources, "/v1/fhir", "v1", connection, pool) return &ApiV1{handler: h, connection: connection} } } diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index 1a021be90..42979ee2e 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -33,6 +33,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/models/postgres/postgrestest" "github.com/CMSgov/bcda-app/conf" "github.com/CMSgov/bcda-app/log" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" ) const ( @@ -47,12 +48,13 @@ type APITestSuite struct { suite.Suite rr *httptest.ResponseRecorder connection *sql.DB + pool *pgxv5Pool.Pool apiV1 *ApiV1 } func (s *APITestSuite) SetupSuite() { s.connection = database.GetConnection() - s.apiV1 = NewApiV1(s.connection) + s.apiV1 = NewApiV1(s.connection, s.pool) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) diff --git a/bcda/api/v2/api.go b/bcda/api/v2/api.go index 974a033c3..7b3985f97 100644 --- a/bcda/api/v2/api.go +++ b/bcda/api/v2/api.go @@ -20,6 +20,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/servicemux" "github.com/CMSgov/bcda-app/conf" "github.com/CMSgov/bcda-app/log" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" ) type ApiV2 struct { @@ 
-28,7 +29,7 @@ type ApiV2 struct { connection *sql.DB } -func NewApiV2(connection *sql.DB) *ApiV2 { +func NewApiV2(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV2 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -40,7 +41,7 @@ func NewApiV2(connection *sql.DB) *ApiV2 { if !ok { panic("Failed to configure resource DataTypes") } else { - h := api.NewHandler(resources, "/v2/fhir", "v2", connection) + h := api.NewHandler(resources, "/v2/fhir", "v2", connection, pool) // Ensure that we write the serialized FHIR resources as a single line. // Needed to comply with the NDJSON format that we are using. marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) diff --git a/bcda/api/v2/api_test.go b/bcda/api/v2/api_test.go index a6fea91e2..1a8a83040 100644 --- a/bcda/api/v2/api_test.go +++ b/bcda/api/v2/api_test.go @@ -36,6 +36,7 @@ import ( fhircodes "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/codes_go_proto" fhirresources "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/bundle_and_contained_resource_go_proto" fhiroo "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/operation_outcome_go_proto" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -53,12 +54,14 @@ var ( type APITestSuite struct { suite.Suite connection *sql.DB + pool *pgxv5Pool.Pool apiV2 *ApiV2 } func (s *APITestSuite) SetupSuite() { s.connection = database.GetConnection() - s.apiV2 = NewApiV2(s.connection) + s.pool = database.GetPool() + s.apiV2 = NewApiV2(s.connection, s.pool) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -544,7 +547,7 @@ func (s *APITestSuite) TestResourceTypes() { "ClaimResponse", }...) - h := api.NewHandler(resources, "/v2/fhir", "v2", s.connection) + h := api.NewHandler(resources, "/v2/fhir", "v2", s.connection, s.pool) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) diff --git a/bcda/api/v3/api.go b/bcda/api/v3/api.go index 5fe0fba97..40f695c76 100644 --- a/bcda/api/v3/api.go +++ b/bcda/api/v3/api.go @@ -20,6 +20,7 @@ import ( fhirresources "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/bundle_and_contained_resource_go_proto" fhircapabilitystatement "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/capability_statement_go_proto" fhirvaluesets "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/valuesets_go_proto" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" ) type ApiV3 struct { @@ -28,7 +29,7 @@ type ApiV3 struct { connection *sql.DB } -func NewApiV3(connection *sql.DB) *ApiV3 { +func NewApiV3(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV3 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -38,7 +39,7 @@ func NewApiV3(connection *sql.DB) *ApiV3 { if !ok { panic("Failed to configure resource DataTypes") } else { - h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, connection) + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, connection, pool) // Ensure that we write the serialized FHIR resources as a single line. // Needed to comply with the NDJSON format that we are using. 
marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) diff --git a/bcda/api/v3/api_test.go b/bcda/api/v3/api_test.go index 4f78fe809..5fa40b813 100644 --- a/bcda/api/v3/api_test.go +++ b/bcda/api/v3/api_test.go @@ -36,6 +36,7 @@ import ( fhircodes "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/codes_go_proto" fhirresources "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/bundle_and_contained_resource_go_proto" fhiroo "github.com/google/fhir/go/proto/google/fhir/proto/r4/core/resources/operation_outcome_go_proto" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -53,12 +54,14 @@ var ( type APITestSuite struct { suite.Suite connection *sql.DB + pool *pgxv5Pool.Pool apiV3 *ApiV3 } func (s *APITestSuite) SetupSuite() { s.connection = database.GetConnection() - s.apiV3 = NewApiV3(s.connection) + s.pool = database.GetPool() + s.apiV3 = NewApiV3(s.connection, s.pool) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -547,7 +550,7 @@ func (s *APITestSuite) TestResourceTypes() { "ExplanationOfBenefit", }...) - h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, s.connection) + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, s.connection, s.pool) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index 0a723bfc9..d6a0d9abb 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -37,6 +37,7 @@ import ( "github.com/CMSgov/bcda-app/log" "github.com/CMSgov/bcda-app/optout" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/urfave/cli" @@ -48,6 +49,7 @@ const Usage = "Beneficiary Claims Data API CLI" var ( connection *sql.DB + pool *pgxv5Pool.Pool r models.Repository ) @@ -62,6 +64,7 @@ func setUpApp() *cli.App { app.Version = constants.Version app.Before = func(c *cli.Context) error { connection = database.GetConnection() + pool = database.GetPool() r = postgres.NewRepository(connection) return nil } @@ -122,7 +125,7 @@ func setUpApp() *cli.App { } api := &http.Server{ - Handler: web.NewAPIRouter(connection), + Handler: web.NewAPIRouter(connection, pool), ReadTimeout: time.Duration(utils.GetEnvInt("API_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("API_WRITE_TIMEOUT", 20)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("API_IDLE_TIMEOUT", 120)) * time.Second, diff --git a/bcda/database/connection_test.go b/bcda/database/connection_test.go index 8cb6dc23d..d2e326baf 100644 --- a/bcda/database/connection_test.go +++ b/bcda/database/connection_test.go @@ -14,13 +14,7 @@ import ( func TestConnections(t *testing.T) { // Verify that we can initialize the package as expected assert.NotNil(t, Connection) - assert.NotNil(t, QueueConnection) - assert.NoError(t, Connection.Ping()) - c, err := QueueConnection.Acquire() - assert.NoError(t, err) - assert.NoError(t, c.Ping(context.Background())) - QueueConnection.Release(c) } // TestHealthCheck verifies that we are able to start the health check diff --git a/bcda/web/router.go b/bcda/web/router.go index 73d5778ab..332fedaa9 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -21,6 +21,7 @@ 
import ( appMiddleware "github.com/CMSgov/bcda-app/middleware" "github.com/go-chi/chi/v5" gcmw "github.com/go-chi/chi/v5/middleware" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" ) // Auth middleware checks that verifies that caller is authorized @@ -28,7 +29,7 @@ var commonAuth = []func(http.Handler) http.Handler{ auth.RequireTokenAuth, auth.CheckBlacklist} -func NewAPIRouter(connection *sql.DB) http.Handler { +func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() r.Use(gcmw.RequestID, appMiddleware.NewTransactionID, auth.ParseToken, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) @@ -54,7 +55,7 @@ func NewAPIRouter(connection *sql.DB) http.Handler { } r.Route("/api/v1", func(r chi.Router) { - apiV1 := v1.NewApiV1(connection) + apiV1 := v1.NewApiV1(connection, pool) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) @@ -66,7 +67,7 @@ func NewAPIRouter(connection *sql.DB) http.Handler { if utils.GetEnvBool("VERSION_2_ENDPOINT_ACTIVE", true) { FileServer(r, "/api/v2/swagger", http.Dir("./swaggerui/v2")) - apiV2 := v2.NewApiV2(connection) + apiV2 := v2.NewApiV2(connection, pool) r.Route("/api/v2", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV2.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV2.BulkGroupRequest)) @@ -79,7 +80,7 @@ func NewAPIRouter(connection *sql.DB) http.Handler { } if utils.GetEnvBool("VERSION_3_ENDPOINT_ACTIVE", true) { - apiV3 := v3.NewApiV3(connection) + apiV3 := v3.NewApiV3(connection, pool) r.Route("/api/demo", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV3.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV3.BulkGroupRequest)) diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index d332b3c7c..540467f01 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -20,6 +20,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/conf" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -33,12 +34,13 @@ type RouterTestSuite struct { apiRouter http.Handler dataRouter http.Handler connection *sql.DB + pool *pgxv5Pool.Pool } func (s *RouterTestSuite) SetupTest() { conf.SetEnv(s.T(), "DEBUG", "true") s.connection = database.GetConnection() - s.apiRouter = NewAPIRouter(s.connection) + s.apiRouter = NewAPIRouter(s.connection, s.pool) s.dataRouter = NewDataRouter(s.connection) } @@ -80,7 +82,7 @@ func (s *RouterTestSuite) TestDefaultProdRoute() { s.FailNow("err in setting env var", err) } // Need a new router because the one in the test setup does not use the environment variable set in this test. 
- s.apiRouter = NewAPIRouter(s.connection) + s.apiRouter = NewAPIRouter(s.connection, s.pool) res := s.getAPIRoute("/v1/") assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -197,7 +199,7 @@ func (s *RouterTestSuite) TestV2EndpointsDisabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(s.connection) + s.apiRouter = NewAPIRouter(s.connection, s.pool) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -214,7 +216,7 @@ func (s *RouterTestSuite) TestV2EndpointsEnabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(s.connection) + s.apiRouter = NewAPIRouter(s.connection, s.pool) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -235,7 +237,7 @@ func (s *RouterTestSuite) TestV3EndpointsDisabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(s.connection) + s.apiRouter = NewAPIRouter(s.connection, s.pool) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -252,7 +254,7 @@ func (s *RouterTestSuite) TestV3EndpointsEnabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(s.connection) + s.apiRouter = NewAPIRouter(s.connection, s.pool) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -352,7 +354,7 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []str handler http.Handler paths []string }) { - apiRouter := NewAPIRouter(s.connection) + apiRouter := NewAPIRouter(s.connection, s.pool) configs = []struct { handler http.Handler diff --git a/bcdaworker/queueing/enqueue.go b/bcdaworker/queueing/enqueue.go index 95ebaa6e9..c70607299 100644 --- a/bcdaworker/queueing/enqueue.go +++ b/bcdaworker/queueing/enqueue.go @@ -7,11 +7,12 @@ package queueing import ( "context" + "database/sql" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcdaworker/queueing/worker_types" pgxv5 "github.com/jackc/pgx/v5" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver/riverpgxv5" ) @@ -24,16 +25,16 @@ type Enqueuer interface { // Creates a river client for the Job queue. 
If the client does not call .Start(), then it is insert only // We still need the workers and the types of workers to insert them -func NewEnqueuer() Enqueuer { +func NewEnqueuer(connection *sql.DB, pool *pgxv5Pool.Pool) Enqueuer { workers := river.NewWorkers() - river.AddWorker(workers, &JobWorker{}) - prepareWorker, err := NewPrepareJobWorker() + river.AddWorker(workers, &JobWorker{connection: connection}) + prepareWorker, err := NewPrepareJobWorker(connection) if err != nil { panic(err) } river.AddWorker(workers, prepareWorker) - riverClient, err := river.NewClient(riverpgxv5.New(database.Pgxv5Pool), &river.Config{ + riverClient, err := river.NewClient(riverpgxv5.New(pool), &river.Config{ Workers: workers, }) if err != nil { diff --git a/bcdaworker/queueing/enqueue_test.go b/bcdaworker/queueing/enqueue_test.go index 13089241b..aba96de26 100644 --- a/bcdaworker/queueing/enqueue_test.go +++ b/bcdaworker/queueing/enqueue_test.go @@ -20,7 +20,7 @@ import ( func TestEnqueuerImplementation(t *testing.T) { // Test river implementation - enq := NewEnqueuer() + enq := NewEnqueuer(nil, nil) var expectedRiverEnq riverEnqueuer assert.IsType(t, expectedRiverEnq, enq) } @@ -33,9 +33,10 @@ func TestRiverEnqueuer_Integration(t *testing.T) { conf.SetEnv(t, "QUEUE_LIBRARY", "river") // Need access to the queue database to ensure we've enqueued the job successfully - db := database.Connection + db := database.GetConnection() + pool := database.GetPool() - enqueuer := NewEnqueuer() + enqueuer := NewEnqueuer(db, pool) jobID, e := rand.Int(rand.Reader, big.NewInt(math.MaxInt32)) if e != nil { t.Fatalf("failed to generate job ID: %v\n", e) diff --git a/bcdaworker/queueing/river.go b/bcdaworker/queueing/river.go index a40891fdb..ebbdf1f8a 100644 --- a/bcdaworker/queueing/river.go +++ b/bcdaworker/queueing/river.go @@ -49,12 +49,16 @@ type Notifier interface { // TODO: better dependency injection (db, worker, logger). 
Waiting for pgxv5 upgrade func StartRiver(numWorkers int) *queue { + + connection := database.GetConnection() + pool := database.GetPool() + workers := river.NewWorkers() - prepareWorker, err := NewPrepareJobWorker() + prepareWorker, err := NewPrepareJobWorker(connection) if err != nil { panic(err) } - river.AddWorker(workers, &JobWorker{}) + river.AddWorker(workers, &JobWorker{connection: connection}) river.AddWorker(workers, NewCleanupJobWorker()) river.AddWorker(workers, prepareWorker) @@ -76,7 +80,7 @@ func StartRiver(numWorkers int) *queue { logger := getSlogLogger() - riverClient, err := river.NewClient(riverpgxv5.New(database.Pgxv5Pool), &river.Config{ + riverClient, err := river.NewClient(riverpgxv5.New(pool), &river.Config{ Queues: map[string]river.QueueConfig{ river.QueueDefault: {MaxWorkers: numWorkers}, }, diff --git a/bcdaworker/queueing/river_test.go b/bcdaworker/queueing/river_test.go index c7e9c6074..016c83c25 100644 --- a/bcdaworker/queueing/river_test.go +++ b/bcdaworker/queueing/river_test.go @@ -66,7 +66,8 @@ func TestWork_Integration(t *testing.T) { conf.SetEnv(t, "FHIR_PAYLOAD_DIR", tempDir1) conf.SetEnv(t, "FHIR_STAGING_DIR", tempDir2) - db := database.Connection + db := database.GetConnection() + pool := database.GetPool() cmsID := testUtils.RandomHexID()[0:4] aco := models.ACO{UUID: uuid.NewRandom(), CMSID: &cmsID} @@ -82,7 +83,7 @@ func TestWork_Integration(t *testing.T) { id, _ := safecast.ToInt(job.ID) jobArgs := worker_types.JobEnqueueArgs{ID: id, ACOID: cmsID, BBBasePath: uuid.New()} - enqueuer := NewEnqueuer() + enqueuer := NewEnqueuer(db, pool) assert.NoError(t, enqueuer.AddJob(context.Background(), jobArgs, 1)) timeout := time.After(10 * time.Second) diff --git a/bcdaworker/queueing/worker_prepare.go b/bcdaworker/queueing/worker_prepare.go index c505a1d0c..9ee96540f 100644 --- a/bcdaworker/queueing/worker_prepare.go +++ b/bcdaworker/queueing/worker_prepare.go @@ -7,13 +7,13 @@ package queueing import ( "context" + "database/sql" "errors" "fmt" "time" "github.com/CMSgov/bcda-app/bcda/client" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/bcda/service" @@ -37,7 +37,7 @@ type PrepareJobWorker struct { r models.Repository } -func NewPrepareJobWorker() (*PrepareJobWorker, error) { +func NewPrepareJobWorker(connection *sql.DB) (*PrepareJobWorker, error) { logger := log.Worker client.SetLogger(logger) @@ -50,7 +50,7 @@ func NewPrepareJobWorker() (*PrepareJobWorker, error) { logger.Fatalf("no ACO configs found, these are required for downstream processing") } - repository := postgres.NewRepository(database.Connection) + repository := postgres.NewRepository(connection) svc := service.NewService(repository, cfg, "") v1, err := client.NewBlueButtonClient(client.NewConfig(constants.BFDV1Path)) diff --git a/bcdaworker/queueing/worker_prepare_test.go b/bcdaworker/queueing/worker_prepare_test.go index 955f25433..f3dff5e76 100644 --- a/bcdaworker/queueing/worker_prepare_test.go +++ b/bcdaworker/queueing/worker_prepare_test.go @@ -302,7 +302,7 @@ func (s *PrepareWorkerIntegrationTestSuite) TestPrepareWorkerWork_Integration() } func (s *PrepareWorkerIntegrationTestSuite) TestPrepareWorker() { - w, err := NewPrepareJobWorker() + w, err := NewPrepareJobWorker(s.db) assert.Nil(s.T(), err) assert.NotEmpty(s.T(), w) } @@ -323,7 +323,7 @@ func (s *PrepareWorkerIntegrationTestSuite) TestQueueExportJobs() { 
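// Illustrative sketch, not part of the patch: with the queue globals removed, callers wire the
// enqueuer from explicit handles, as the test hunk below does. Roughly (id and cmsID are
// placeholders):
//
//	db := database.GetConnection()
//	pool := database.GetPool()
//	enq := NewEnqueuer(db, pool) // insert-only River client: Start() is never called on it
//	err := enq.AddJob(context.Background(), worker_types.JobEnqueueArgs{ID: id, ACOID: cmsID}, 1)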
ms.On("GetJobPriority", mock.Anything, mock.Anything, mock.Anything).Return(int16(1)) worker := &PrepareJobWorker{svc: ms, v1Client: &client.MockBlueButtonClient{}, v2Client: &client.MockBlueButtonClient{}, r: s.r} - q := NewEnqueuer() + q := NewEnqueuer(s.db, database.GetPool()) a := &worker_types.JobEnqueueArgs{ ID: 33, } diff --git a/bcdaworker/queueing/worker_process_job.go b/bcdaworker/queueing/worker_process_job.go index bbd49ac2a..b3f6d0871 100644 --- a/bcdaworker/queueing/worker_process_job.go +++ b/bcdaworker/queueing/worker_process_job.go @@ -2,8 +2,8 @@ package queueing import ( "context" + "database/sql" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcdaworker/queueing/worker_types" "github.com/CMSgov/bcda-app/bcdaworker/repository/postgres" "github.com/CMSgov/bcda-app/bcdaworker/worker" @@ -15,6 +15,7 @@ import ( type JobWorker struct { river.WorkerDefaults[worker_types.JobEnqueueArgs] + connection *sql.DB } func (w *JobWorker) Work(ctx context.Context, rjob *river.Job[worker_types.JobEnqueueArgs]) error { @@ -31,7 +32,7 @@ func (w *JobWorker) Work(ctx context.Context, rjob *river.Job[worker_types.JobEn ctx, logger := log.SetCtxLogger(ctx, "transaction_id", rjob.Args.TransactionID) // TODO: use pgxv5 when available - mainDB := database.Connection + mainDB := w.connection workerInstance := worker.NewWorker(mainDB) repo := postgres.NewRepository(mainDB) From e30cd6e2fac182217401471c837d8512dd26e05b Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 12:28:44 -0400 Subject: [PATCH 10/28] Remove connection globals from several tests --- bcda/auth/ssas_test.go | 2 +- bcda/bcdacli/cli_test.go | 2 +- bcda/database/connection_test.go | 8 +++++--- bcda/database/database_test.go | 8 +++++--- bcda/database/databasetest/databasetest.go | 2 +- bcda/database/databasetest/databasetest_test.go | 2 +- bcda/database/pgx_test.go | 2 +- bcda/models/postgres/repository_test.go | 2 +- bcda/web/router_test.go | 4 ++-- bcdaworker/repository/postgres/repository_test.go | 2 +- bcdaworker/worker/worker_test.go | 2 +- db/migrations/migrations_test.go | 2 +- 12 files changed, 21 insertions(+), 17 deletions(-) diff --git a/bcda/auth/ssas_test.go b/bcda/auth/ssas_test.go index 67e0ec2d6..4c0b04418 100644 --- a/bcda/auth/ssas_test.go +++ b/bcda/auth/ssas_test.go @@ -71,7 +71,7 @@ func (s *SSASPluginTestSuite) SetupSuite() { origSSASClientID = conf.GetEnv("BCDA_SSAS_CLIENT_ID") origSSASSecret = conf.GetEnv("BCDA_SSAS_SECRET") - s.db = database.Connection + s.db = database.GetConnection() s.r = postgres.NewRepository(s.db) } diff --git a/bcda/bcdacli/cli_test.go b/bcda/bcdacli/cli_test.go index 996befd68..2f315382c 100644 --- a/bcda/bcdacli/cli_test.go +++ b/bcda/bcdacli/cli_test.go @@ -68,7 +68,7 @@ func (s *CLITestSuite) SetupSuite() { s.pendingDeletionDir = dir testUtils.SetPendingDeletionDir(&s.Suite, dir) - s.db = database.Connection + s.db = database.GetConnection() cmsID := testUtils.RandomHexID()[0:4] s.testACO = models.ACO{Name: uuid.New(), UUID: uuid.NewRandom(), ClientID: uuid.New(), CMSID: &cmsID} diff --git a/bcda/database/connection_test.go b/bcda/database/connection_test.go index d2e326baf..a4862acef 100644 --- a/bcda/database/connection_test.go +++ b/bcda/database/connection_test.go @@ -13,8 +13,9 @@ import ( func TestConnections(t *testing.T) { // Verify that we can initialize the package as expected - assert.NotNil(t, Connection) - assert.NoError(t, Connection.Ping()) + c := GetConnection() + assert.NotNil(t, c) + assert.NoError(t, c.Ping()) } // 
TestHealthCheck verifies that we are able to start the health check @@ -32,7 +33,8 @@ func TestConnectionHealthCheck(t *testing.T) { hook := test.NewGlobal() ctx, cancel := context.WithCancel(context.Background()) - startConnectionHealthCheck(ctx, Connection, 100*time.Microsecond) + c := GetConnection() + startConnectionHealthCheck(ctx, c, 100*time.Microsecond) // Let some time elapse to ensure we've successfully ran health checks time.Sleep(50 * time.Millisecond) cancel() diff --git a/bcda/database/database_test.go b/bcda/database/database_test.go index 777cdd20f..93c3caa14 100644 --- a/bcda/database/database_test.go +++ b/bcda/database/database_test.go @@ -9,8 +9,9 @@ import ( ) func TestDBOperations(t *testing.T) { - var q Queryable = &DB{Connection} - var e Executable = &DB{Connection} + c := GetConnection() + var q Queryable = &DB{c} + var e Executable = &DB{c} rows, err := q.QueryContext(context.Background(), constants.TestSelectNowSQL) assert.NoError(t, err) @@ -31,7 +32,8 @@ func TestDBOperations(t *testing.T) { } func TestTxOperations(t *testing.T) { - tx, err := Connection.Begin() + c := GetConnection() + tx, err := c.Begin() assert.NoError(t, err) defer func() { assert.NoError(t, tx.Rollback()) diff --git a/bcda/database/databasetest/databasetest.go b/bcda/database/databasetest/databasetest.go index 911659fdc..f65f1ab72 100644 --- a/bcda/database/databasetest/databasetest.go +++ b/bcda/database/databasetest/databasetest.go @@ -23,7 +23,7 @@ func CreateDatabase(t *testing.T, migrationPath string, cleanup bool) (*sql.DB, cfg, err := database.LoadConfig() assert.NoError(t, err) dsn := cfg.DatabaseURL - db := database.Connection + db := database.GetConnection() newDBName := strings.ReplaceAll(fmt.Sprintf("%s_%s", dbName(dsn), uuid.New()), "-", "_") newDSN := dsnPattern.ReplaceAllString(dsn, fmt.Sprintf("${conn}%s${options}", newDBName)) diff --git a/bcda/database/databasetest/databasetest_test.go b/bcda/database/databasetest/databasetest_test.go index b9153d050..9b500cf0a 100644 --- a/bcda/database/databasetest/databasetest_test.go +++ b/bcda/database/databasetest/databasetest_test.go @@ -28,7 +28,7 @@ func TestCreateDatabase(t *testing.T) { assert.NoError(t, db.Close()) }) - db := database.Connection + db := database.GetConnection() var count int assert.NoError(t, diff --git a/bcda/database/pgx_test.go b/bcda/database/pgx_test.go index 50878e9b5..2e0543321 100644 --- a/bcda/database/pgx_test.go +++ b/bcda/database/pgx_test.go @@ -11,7 +11,7 @@ import ( ) func TestPgxTxOperations(t *testing.T) { - conn, err := stdlib.AcquireConn(Connection) + conn, err := stdlib.AcquireConn(GetConnection()) assert.NoError(t, err) defer func() { assert.NoError(t, conn.Close()) diff --git a/bcda/models/postgres/repository_test.go b/bcda/models/postgres/repository_test.go index 2c8e430cc..24e0b5d22 100644 --- a/bcda/models/postgres/repository_test.go +++ b/bcda/models/postgres/repository_test.go @@ -40,7 +40,7 @@ func TestRepositoryTestSuite(t *testing.T) { } func (r *RepositoryTestSuite) SetupSuite() { - r.db = database.Connection + r.db = database.GetConnection() r.repository = postgres.NewRepository(r.db) } diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index 540467f01..3d98e4416 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -402,7 +402,7 @@ func (s *RouterTestSuite) TestBlacklistedACOReturn403WhenACOBlacklisted() { mock := &auth.MockProvider{} setExpectedMockCalls(s, mock, token, aco, bearerString, cmsID) - db := database.Connection + db := s.connection 
postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) @@ -447,7 +447,7 @@ func (s *RouterTestSuite) TestBlacklistedACOReturnNOT403WhenACONOTBlacklisted() mock := &auth.MockProvider{} setExpectedMockCalls(s, mock, token, aco, bearerString, cmsID) - db := database.Connection + db := s.connection postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) diff --git a/bcdaworker/repository/postgres/repository_test.go b/bcdaworker/repository/postgres/repository_test.go index 286b22b51..55bb19020 100644 --- a/bcdaworker/repository/postgres/repository_test.go +++ b/bcdaworker/repository/postgres/repository_test.go @@ -30,7 +30,7 @@ func TestRepositoryTestSuite(t *testing.T) { } func (r *RepositoryTestSuite) SetupSuite() { - r.db = database.Connection + r.db = database.GetConnection() r.repository = postgres.NewRepository(r.db) } diff --git a/bcdaworker/worker/worker_test.go b/bcdaworker/worker/worker_test.go index 58fafa6b0..ec8b687e7 100644 --- a/bcdaworker/worker/worker_test.go +++ b/bcdaworker/worker/worker_test.go @@ -63,7 +63,7 @@ type WorkerTestSuite struct { } func (s *WorkerTestSuite) SetupSuite() { - s.db = database.Connection + s.db = database.GetConnection() s.r = postgres.NewRepository(s.db) s.w = NewWorker(s.db) diff --git a/db/migrations/migrations_test.go b/db/migrations/migrations_test.go index cf88d66ef..ad9e65f9e 100644 --- a/db/migrations/migrations_test.go +++ b/db/migrations/migrations_test.go @@ -48,7 +48,7 @@ func (s *MigrationTestSuite) SetupSuite() { // postgres://:@: / re := regexp.MustCompile(`(postgresql\:\/\/\S+\:\S+\@\S+\:\d+\/)(.*)(\?.*)`) - db := database.Connection + db := database.GetConnection() databaseURL := conf.GetEnv("DATABASE_URL") bcdaDB := fmt.Sprintf("migrate_test_bcda_%d", time.Now().Nanosecond()) From 82c98f083263e5c18c94ddd5c23c7e941041c6ca Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 13:58:36 -0400 Subject: [PATCH 11/28] Pass connection to health.go --- bcda/api/v1/api.go | 18 ++++++++++-------- bcda/api/v1/api_test.go | 2 +- bcda/health/health.go | 25 ++++++++++++++++--------- bcda/web/router.go | 5 ++--- bcda/web/router_test.go | 2 +- bcdaworker/main.go | 10 ++++++---- 6 files changed, 36 insertions(+), 26 deletions(-) diff --git a/bcda/api/v1/api.go b/bcda/api/v1/api.go index ae2b67988..1854f0490 100644 --- a/bcda/api/v1/api.go +++ b/bcda/api/v1/api.go @@ -28,8 +28,9 @@ import ( ) type ApiV1 struct { - handler *api.Handler - connection *sql.DB + handler *api.Handler + connection *sql.DB + healthChecker health.HealthChecker } func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV1 { @@ -42,10 +43,11 @@ func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV1 { if !ok { panic("Failed to configure resource DataTypes") - } else { - h := api.NewHandler(resources, "/v1/fhir", "v1", connection, pool) - return &ApiV1{handler: h, connection: connection} } + + hc := health.NewHealthChecker(connection) + h := api.NewHandler(resources, "/v1/fhir", "v1", connection, pool) + return &ApiV1{handler: h, connection: connection, healthChecker: hc} } /* @@ -415,11 +417,11 @@ func GetVersion(w http.ResponseWriter, r *http.Request) { } } -func HealthCheck(w http.ResponseWriter, r *http.Request) { +func (a ApiV1) HealthCheck(w http.ResponseWriter, r *http.Request) { m := make(map[string]string) - dbStatus, dbOK := health.IsDatabaseOK() - ssasStatus, ssasOK := health.IsSsasOK() + dbStatus, dbOK := a.healthChecker.IsDatabaseOK() + ssasStatus, ssasOK := 
a.healthChecker.IsSsasOK() m["database"] = dbStatus m["ssas"] = ssasStatus diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index 42979ee2e..ee3f0fe91 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -563,7 +563,7 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { func (s *APITestSuite) TestHealthCheck() { req, err := http.NewRequest("GET", "/_health", nil) assert.Nil(s.T(), err) - handler := http.HandlerFunc(HealthCheck) + handler := http.HandlerFunc(s.apiV1.HealthCheck) handler.ServeHTTP(s.rr, req) assert.Equal(s.T(), http.StatusOK, s.rr.Code) } diff --git a/bcda/health/health.go b/bcda/health/health.go index a870d8a4d..bb1a1a22a 100644 --- a/bcda/health/health.go +++ b/bcda/health/health.go @@ -1,16 +1,24 @@ package health import ( + "database/sql" + ssasClient "github.com/CMSgov/bcda-app/bcda/auth/client" "github.com/CMSgov/bcda-app/bcda/client" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/log" _ "github.com/jackc/pgx" ) -func IsDatabaseOK() (result string, ok bool) { - db := database.Connection - if err := db.Ping(); err != nil { +type HealthChecker struct { + db *sql.DB +} + +func NewHealthChecker(connection *sql.DB) HealthChecker { + return HealthChecker{db: connection} +} + +func (h HealthChecker) IsDatabaseOK() (result string, ok bool) { + if err := h.db.Ping(); err != nil { log.API.Error("Health check: database ping error: ", err.Error()) return "database ping error", false } @@ -18,9 +26,8 @@ func IsDatabaseOK() (result string, ok bool) { return "ok", true } -func IsWorkerDatabaseOK() (result string, ok bool) { - db := database.Connection - if err := db.Ping(); err != nil { +func (h HealthChecker) IsWorkerDatabaseOK() (result string, ok bool) { + if err := h.db.Ping(); err != nil { log.Worker.Error("Health check: database ping error: ", err.Error()) return "database ping error", false } @@ -28,7 +35,7 @@ func IsWorkerDatabaseOK() (result string, ok bool) { return "ok", true } -func IsBlueButtonOK() bool { +func (h HealthChecker) IsBlueButtonOK() bool { bbc, err := client.NewBlueButtonClient(client.NewConfig("/v1/fhir")) if err != nil { log.Worker.Error("Health check: Blue Button client error: ", err.Error()) @@ -44,7 +51,7 @@ func IsBlueButtonOK() bool { return true } -func IsSsasOK() (result string, ok bool) { +func (h HealthChecker) IsSsasOK() (result string, ok bool) { c, err := ssasClient.NewSSASClient() if err != nil { log.Auth.Errorf("no client for SSAS. 
no provider set; %s", err.Error()) diff --git a/bcda/web/router.go b/bcda/web/router.go index 332fedaa9..1fd70c1b0 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -53,9 +53,8 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { r.Get("/", userGuideRedirect) r.Get(`/{:(user_guide|encryption|decryption_walkthrough).html}`, userGuideRedirect) } - + apiV1 := v1.NewApiV1(connection, pool) r.Route("/api/v1", func(r chi.Router) { - apiV1 := v1.NewApiV1(connection, pool) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) @@ -93,7 +92,7 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { } r.Get(m.WrapHandler("/_version", v1.GetVersion)) - r.Get(m.WrapHandler("/_health", v1.HealthCheck)) + r.Get(m.WrapHandler("/_health", apiV1.HealthCheck)) r.Get(m.WrapHandler("/_auth", v1.GetAuthInfo)) return r } diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index 3d98e4416..d2797cea2 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -39,10 +39,10 @@ type RouterTestSuite struct { func (s *RouterTestSuite) SetupTest() { conf.SetEnv(s.T(), "DEBUG", "true") + conf.SetEnv(s.T(), "BB_SERVER_LOCATION", "v1-server-location") s.connection = database.GetConnection() s.apiRouter = NewAPIRouter(s.connection, s.pool) s.dataRouter = NewDataRouter(s.connection) - } func (s *RouterTestSuite) getAPIRoute(route string) *http.Response { diff --git a/bcdaworker/main.go b/bcdaworker/main.go index 2c4969f17..9e9ab382f 100644 --- a/bcdaworker/main.go +++ b/bcdaworker/main.go @@ -10,6 +10,7 @@ import ( "time" "github.com/CMSgov/bcda-app/bcda/client" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/health" "github.com/CMSgov/bcda-app/bcda/utils" "github.com/CMSgov/bcda-app/bcdaworker/queueing" @@ -96,6 +97,7 @@ func waitForSig() { func main() { fmt.Println("Starting bcdaworker...") + healthChecker := health.NewHealthChecker(database.Connection) queue := queueing.StartRiver(utils.GetEnvInt("WORKER_POOL_SIZE", 4)) defer queue.StopRiver() @@ -106,7 +108,7 @@ func main() { for { select { case <-ticker.C: - logHealth() + logHealth(healthChecker) case <-quit: ticker.Stop() return @@ -118,20 +120,20 @@ func main() { waitForSig() } -func logHealth() { +func logHealth(healthChecker health.HealthChecker) { entry := log.Health logFields := logrus.Fields{} logFields["type"] = "health" logFields["id"] = uuid.NewRandom() - if _, ok := health.IsWorkerDatabaseOK(); ok { + if _, ok := healthChecker.IsWorkerDatabaseOK(); ok { logFields["db"] = "ok" } else { logFields["db"] = "error" } - if health.IsBlueButtonOK() { + if healthChecker.IsBlueButtonOK() { logFields["bb"] = "ok" } else { logFields["bb"] = "error" From 134d8fb4029a20ed84c7e317476df7abcc15b0a1 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 15:57:39 -0400 Subject: [PATCH 12/28] Refactor connection globals in cclf --- bcda/bcdacli/cli.go | 7 +--- bcda/cclf/cclf.go | 70 ++++++++++++++++--------------- bcda/cclf/cclf_test.go | 29 ++++++------- bcda/cclf/utils/cclfUtils.go | 8 ++-- bcda/cclf/utils/cclfUtils_test.go | 12 ++++-- bcda/lambda/cclf/main.go | 29 ++++++------- bcda/lambda/cclf/main_test.go | 9 +++- 7 files changed, 85 insertions(+), 
79 deletions(-) diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index d6a0d9abb..39d651911 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -332,10 +332,7 @@ func setUpApp() *cli.App { } } - importer := cclf.CclfImporter{ - Logger: log.API, - FileProcessor: file_processor, - } + importer := cclf.NewCclfImporter(log.API, file_processor, connection) success, failure, skipped, err := importer.ImportCCLFDirectory(filePath) if err != nil { @@ -464,7 +461,7 @@ func setUpApp() *cli.App { return errors.New("Unsupported file type.") } } - err := cclfUtils.ImportCCLFPackage(acoSize, environment, ft) + err := cclfUtils.ImportCCLFPackage(connection, acoSize, environment, ft) return err }, }, diff --git a/bcda/cclf/cclf.go b/bcda/cclf/cclf.go index 9f35d35ab..42fb0c09f 100644 --- a/bcda/cclf/cclf.go +++ b/bcda/cclf/cclf.go @@ -5,6 +5,7 @@ import ( "bufio" "bytes" "context" + "database/sql" "fmt" "strconv" "time" @@ -15,7 +16,6 @@ import ( "github.com/CMSgov/bcda-app/bcda/cclf/metrics" "github.com/CMSgov/bcda-app/bcda/constants" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/bcda/utils" @@ -62,13 +62,18 @@ type CclfFileProcessor interface { // Manages the import process for CCLF files from a given source type CclfImporter struct { - Logger logrus.FieldLogger - FileProcessor CclfFileProcessor + logger logrus.FieldLogger + fileProcessor CclfFileProcessor + db *sql.DB +} + +func NewCclfImporter(logger logrus.FieldLogger, fileProcessor CclfFileProcessor, db *sql.DB) CclfImporter { + return CclfImporter{logger: logger, fileProcessor: fileProcessor, db: db} } func (importer CclfImporter) importCCLF0(ctx context.Context, zipMetadata *cclfZipMetadata) (*cclfFileValidator, error) { fileMetadata := zipMetadata.cclf0Metadata - importer.Logger.Infof("Importing CCLF0 file %s...", fileMetadata) + importer.logger.Infof("Importing CCLF0 file %s...", fileMetadata) const ( fileNumStart, fileNumEnd = 0, 7 @@ -82,7 +87,7 @@ func (importer CclfImporter) importCCLF0(ctx context.Context, zipMetadata *cclfZ rc, err := zipMetadata.cclf0File.Open() if err != nil { err = errors.Wrapf(err, "could not read file %s in CCLF0 archive %s", fileMetadata.name, zipMetadata.filePath) - importer.Logger.Error(err) + importer.logger.Error(err) return nil, err } defer rc.Close() @@ -98,20 +103,20 @@ func (importer CclfImporter) importCCLF0(ctx context.Context, zipMetadata *cclfZ if filetype == "CCLF8" { if validator != nil { err := fmt.Errorf("duplicate %v file type found from CCLF0 file", filetype) - importer.Logger.Error(err) + importer.logger.Error(err) return nil, err } count, err := strconv.Atoi(string(bytes.TrimSpace(b[totalRecordStart:totalRecordEnd]))) if err != nil { err = errors.Wrapf(err, "failed to parse %s record count from CCLF0 file", filetype) - importer.Logger.Error(err) + importer.logger.Error(err) return nil, err } length, err := strconv.Atoi(string(bytes.TrimSpace(b[recordLengthStart:recordLengthEnd]))) if err != nil { err = errors.Wrapf(err, "failed to parse %s record length from CCLF0 file", filetype) - importer.Logger.Error(err) + importer.logger.Error(err) return nil, err } @@ -121,41 +126,40 @@ func (importer CclfImporter) importCCLF0(ctx context.Context, zipMetadata *cclfZ } if validator != nil { - importer.Logger.Infof("Successfully imported CCLF0 file %s.", fileMetadata) + importer.logger.Infof("Successfully imported CCLF0 file %s.", fileMetadata) return validator, nil } err = 
fmt.Errorf("failed to parse CCLF8 from CCLF0 file %s", fileMetadata.name) - importer.Logger.Error(err) + importer.logger.Error(err) return nil, err } func (importer CclfImporter) importCCLF8(ctx context.Context, zipMetadata *cclfZipMetadata, validator cclfFileValidator) (err error) { fileMetadata := zipMetadata.cclf8Metadata - db := database.Connection - repository := postgres.NewRepository(db) + repository := postgres.NewRepository(importer.db) exists, err := repository.GetCCLFFileExistsByName(ctx, fileMetadata.name) if err != nil { err = errors.Wrapf(err, "failed to check existence of CCLF%d file", fileMetadata.cclfNum) - importer.Logger.Error(err) + importer.logger.Error(err) return err } if exists { - importer.Logger.Infof("CCL%d file %s already exists in database, skipping import...", fileMetadata.cclfNum, fileMetadata) + importer.logger.Infof("CCL%d file %s already exists in database, skipping import...", fileMetadata.cclfNum, fileMetadata) return nil } - importer.Logger.Infof("Importing CCLF%d file %s...", fileMetadata.cclfNum, fileMetadata) + importer.logger.Infof("Importing CCLF%d file %s...", fileMetadata.cclfNum, fileMetadata) - conn, err := stdlib.AcquireConn(db) - defer utils.CloseAndLog(logrus.WarnLevel, func() error { return stdlib.ReleaseConn(db, conn) }) + conn, err := stdlib.AcquireConn(importer.db) + defer utils.CloseAndLog(logrus.WarnLevel, func() error { return stdlib.ReleaseConn(importer.db, conn) }) tx, err := conn.BeginEx(ctx, nil) if err != nil { err = fmt.Errorf("failed to start transaction: %w", err) - importer.Logger.Error(err) + importer.logger.Error(err) return err } @@ -165,7 +169,7 @@ func (importer CclfImporter) importCCLF8(ctx context.Context, zipMetadata *cclfZ defer func() { if err != nil { if err1 := tx.Rollback(); err1 != nil { - importer.Logger.Warnf("Failed to rollback transaction %s", err.Error()) + importer.logger.Warnf("Failed to rollback transaction %s", err.Error()) } return } @@ -186,7 +190,7 @@ func (importer CclfImporter) importCCLF8(ctx context.Context, zipMetadata *cclfZ cclfFile.ID, err = rtx.CreateCCLFFile(ctx, cclfFile) if err != nil { err = errors.Wrapf(err, "could not create CCLF%d file record", fileMetadata.cclfNum) - importer.Logger.Error(err) + importer.logger.Error(err) return err } @@ -195,37 +199,37 @@ func (importer CclfImporter) importCCLF8(ctx context.Context, zipMetadata *cclfZ rc, err := zipMetadata.cclf8File.Open() if err != nil { err = errors.Wrapf(err, "could not read file %s for CCLF%d in archive %s", cclfFile.Name, fileMetadata.cclfNum, zipMetadata.filePath) - importer.Logger.Error(err) + importer.logger.Error(err) return err } defer rc.Close() sc := bufio.NewScanner(rc) - importedCount, recordCount, err := CopyFrom(ctx, tx, sc, cclfFile.ID, utils.GetEnvInt("CCLF_IMPORT_STATUS_RECORDS_INTERVAL", 10000), importer.Logger, validator.maxRecordLength) + importedCount, recordCount, err := CopyFrom(ctx, tx, sc, cclfFile.ID, utils.GetEnvInt("CCLF_IMPORT_STATUS_RECORDS_INTERVAL", 10000), importer.logger, validator.maxRecordLength) if err != nil { return errors.Wrap(err, "failed to copy data to beneficiaries table") } if recordCount > validator.totalRecordCount { err := fmt.Errorf("unexpected number of records imported for file %s (expected: %d, actual: %d)", fileMetadata.name, validator.totalRecordCount, recordCount) - importer.Logger.Error(err) + importer.logger.Error(err) return err } err = rtx.UpdateCCLFFileImportStatus(ctx, fileMetadata.fileID, constants.ImportComplete) if err != nil { err = errors.Wrapf(err, "could not 
update cclf file record for file: %s.", fileMetadata.name) - importer.Logger.Error(err) + importer.logger.Error(err) } if err = tx.Commit(); err != nil { - importer.Logger.Error(err.Error()) + importer.logger.Error(err.Error()) failMsg := fmt.Sprintf("failed to commit transaction for CCLF%d import file %s", fileMetadata.cclfNum, fileMetadata.name) return errors.Wrap(err, failMsg) } successMsg := fmt.Sprintf("Successfully imported %d records from CCLF%d file %s.", importedCount, fileMetadata.cclfNum, fileMetadata.name) - importer.Logger.WithFields(logrus.Fields{"imported_count": importedCount}).Info(successMsg) + importer.logger.WithFields(logrus.Fields{"imported_count": importedCount}).Info(successMsg) return nil } @@ -238,7 +242,7 @@ func (importer CclfImporter) ImportCCLFDirectory(filePath string) (success, fail // We are not going to create any children from this parent so we can // safely ignored the returned context. _, c := metrics.NewParent(ctx, "ImportCCLFDirectory#sortCCLFArchives") - cclfMap, skipped, failure, err := importer.FileProcessor.LoadCclfFiles(filePath) + cclfMap, skipped, failure, err := importer.fileProcessor.LoadCclfFiles(filePath) c() if err != nil { @@ -246,7 +250,7 @@ func (importer CclfImporter) ImportCCLFDirectory(filePath string) (success, fail } if len(cclfMap) == 0 { - importer.Logger.Info("Did not find any CCLF files in directory -- returning safely.") + importer.logger.Info("Did not find any CCLF files in directory -- returning safely.") return 0, failure, skipped, err } @@ -259,7 +263,7 @@ func (importer CclfImporter) ImportCCLFDirectory(filePath string) (success, fail cclfvalidator, err := importer.importCCLF0(ctx, zipMetadata) if err != nil { - importer.Logger.Errorf("Failed to import CCLF0 file: %s, Skipping CCLF8 file: %s ", zipMetadata.cclf0Metadata, zipMetadata.cclf8Metadata) + importer.logger.Errorf("Failed to import CCLF0 file: %s, Skipping CCLF8 file: %s ", zipMetadata.cclf0Metadata, zipMetadata.cclf8Metadata) failure++ skipped += 2 } else { @@ -267,7 +271,7 @@ func (importer CclfImporter) ImportCCLFDirectory(filePath string) (success, fail } if err = importer.importCCLF8(ctx, zipMetadata, *cclfvalidator); err != nil { - importer.Logger.Errorf("Failed to import CCLF8 file: %s %s", zipMetadata.cclf8Metadata, err) + importer.logger.Errorf("Failed to import CCLF8 file: %s %s", zipMetadata.cclf8Metadata, err) failure++ } else { zipMetadata.imported = true @@ -280,15 +284,15 @@ func (importer CclfImporter) ImportCCLFDirectory(filePath string) (success, fail if err = func() error { ctx, c := metrics.NewParent(ctx, "ImportCCLFDirectory#cleanupCCLF") defer c() - _, err := importer.FileProcessor.CleanUpCCLF(ctx, cclfMap) + _, err := importer.fileProcessor.CleanUpCCLF(ctx, cclfMap) return err }(); err != nil { - importer.Logger.Error(err) + importer.logger.Error(err) } if failure > 0 { err = errors.New(fmt.Sprintf("Failed to import %d files", failure)) - importer.Logger.Error(err) + importer.logger.Error(err) } else { err = nil } diff --git a/bcda/cclf/cclf_test.go b/bcda/cclf/cclf_test.go index 4654fdc53..457b269e0 100644 --- a/bcda/cclf/cclf_test.go +++ b/bcda/cclf/cclf_test.go @@ -60,10 +60,7 @@ func (s *CCLFTestSuite) SetupTest() { }, } - s.importer = CclfImporter{ - Logger: log.API, - FileProcessor: file_processor, - } + s.importer = NewCclfImporter(log.API, file_processor, s.db) } func (s *CCLFTestSuite) SetupSuite() { @@ -76,7 +73,7 @@ func (s *CCLFTestSuite) SetupSuite() { s.pendingDeletionDir = dir testUtils.SetPendingDeletionDir(&s.Suite, dir) - 
s.db = database.Connection + s.db = database.GetConnection() } func (s *CCLFTestSuite) TearDownSuite() { @@ -97,7 +94,7 @@ func (s *CCLFTestSuite) TestImportCCLF0() { assert := assert.New(s.T()) cclfZipfilePath := filepath.Join(s.basePath, "cclf/archives/valid/T.BCD.A0001.ZCY18.D181120.T1000000") - metadata, zipCloser1 := buildZipMetadata(s.T(), s.importer.FileProcessor, "A0001", cclfZipfilePath, "T.BCD.A0001.ZC0Y18.D181120.T1000011", "", models.FileTypeDefault) + metadata, zipCloser1 := buildZipMetadata(s.T(), s.importer.fileProcessor, "A0001", cclfZipfilePath, "T.BCD.A0001.ZC0Y18.D181120.T1000011", "", models.FileTypeDefault) defer zipCloser1() // positive @@ -107,7 +104,7 @@ func (s *CCLFTestSuite) TestImportCCLF0() { // missing cclf8 from cclf0 cclfZipfilePath = filepath.Join(s.basePath, "cclf/archives/0/missing_data/T.BCD.A0001.ZCY18.D181120.T1000000") - metadata, zipCloser2 := buildZipMetadata(s.T(), s.importer.FileProcessor, "A0001", cclfZipfilePath, "T.BCD.A0001.ZC0Y18.D181120.T1000011", "", models.FileTypeDefault) + metadata, zipCloser2 := buildZipMetadata(s.T(), s.importer.fileProcessor, "A0001", cclfZipfilePath, "T.BCD.A0001.ZC0Y18.D181120.T1000011", "", models.FileTypeDefault) defer zipCloser2() _, err = s.importer.importCCLF0(ctx, metadata) @@ -115,7 +112,7 @@ func (s *CCLFTestSuite) TestImportCCLF0() { // duplicate file types from cclf0 cclfZipfilePath = filepath.Join(s.basePath, "cclf/archives/0/missing_data/T.BCD.A0001.ZCY18.D181122.T1000000") - metadata, zipCloser3 := buildZipMetadata(s.T(), s.importer.FileProcessor, "A0001", cclfZipfilePath, "T.BCD.A0001.ZC0Y18.D181120.T1000013", "", models.FileTypeDefault) + metadata, zipCloser3 := buildZipMetadata(s.T(), s.importer.fileProcessor, "A0001", cclfZipfilePath, "T.BCD.A0001.ZC0Y18.D181120.T1000013", "", models.FileTypeDefault) defer zipCloser3() _, err = s.importer.importCCLF0(ctx, metadata) @@ -123,7 +120,7 @@ func (s *CCLFTestSuite) TestImportCCLF0() { //invalid record count cclfZipfilePath = filepath.Join(s.basePath, "cclf/archives/0/invalid/T.A0001.ACO.ZC0Y18.D181120.Z1000000") - metadata, zipCloser4 := buildZipMetadata(s.T(), s.importer.FileProcessor, "A0001", cclfZipfilePath, "T.A0001.ACO.ZC0Y18.D181120.Z1000011", "", models.FileTypeDefault) + metadata, zipCloser4 := buildZipMetadata(s.T(), s.importer.fileProcessor, "A0001", cclfZipfilePath, "T.A0001.ACO.ZC0Y18.D181120.Z1000011", "", models.FileTypeDefault) defer zipCloser4() _, err = s.importer.importCCLF0(ctx, metadata) @@ -131,7 +128,7 @@ func (s *CCLFTestSuite) TestImportCCLF0() { //invalid record length cclfZipfilePath = filepath.Join(s.basePath, "cclf/archives/0/invalid/T.BCD.ACOB.ZC0Y18.D181120.E0001000") - metadata, zipCloser5 := buildZipMetadata(s.T(), s.importer.FileProcessor, "A0001", cclfZipfilePath, "T.A0001.ACO.ZC0Y18.D181120.E1000011", "", models.FileTypeDefault) + metadata, zipCloser5 := buildZipMetadata(s.T(), s.importer.fileProcessor, "A0001", cclfZipfilePath, "T.A0001.ACO.ZC0Y18.D181120.E1000011", "", models.FileTypeDefault) defer zipCloser5() _, err = s.importer.importCCLF0(ctx, metadata) @@ -181,7 +178,7 @@ func (s *CCLFTestSuite) TestImportCCLF8() { acoID := "A0001" fileTime, _ := time.Parse(time.RFC3339, constants.TestFileTime) - metadata, zipCloser := buildZipMetadata(s.T(), s.importer.FileProcessor, acoID, filepath.Join(s.basePath, constants.CCLF8CompPath), "", constants.CCLF8Name, models.FileTypeDefault) + metadata, zipCloser := buildZipMetadata(s.T(), s.importer.fileProcessor, acoID, filepath.Join(s.basePath, constants.CCLF8CompPath), "", 
constants.CCLF8Name, models.FileTypeDefault) metadata.cclf8Metadata.timestamp = fileTime defer zipCloser() @@ -240,7 +237,7 @@ func (s *CCLFTestSuite) TestImportCCLF8DBErrors() { defer postgrestest.DeleteCCLFFilesByCMSID(s.T(), s.db, "A0002") - metadata, zipCloser := buildZipMetadata(s.T(), s.importer.FileProcessor, "A0001", filepath.Join(s.basePath, constants.CCLF8CompPath), "", constants.CCLF8Name, models.FileTypeDefault) + metadata, zipCloser := buildZipMetadata(s.T(), s.importer.fileProcessor, "A0001", filepath.Join(s.basePath, constants.CCLF8CompPath), "", constants.CCLF8Name, models.FileTypeDefault) defer zipCloser() validator := cclfFileValidator{ @@ -268,7 +265,7 @@ func (s *CCLFTestSuite) TestImportCCLF8_alreadyExists() { cclfFile := &models.CCLFFile{CCLFNum: 8, ACOCMSID: acoID, Timestamp: time.Now(), PerformanceYear: 18, Name: constants.CCLF8Name} postgrestest.CreateCCLFFile(s.T(), s.db, cclfFile) - metadata, zipCloser := buildZipMetadata(s.T(), s.importer.FileProcessor, "A0001", filepath.Join(s.basePath, constants.CCLF8CompPath), "", cclfFile.Name, cclfFile.Type) + metadata, zipCloser := buildZipMetadata(s.T(), s.importer.fileProcessor, "A0001", filepath.Join(s.basePath, constants.CCLF8CompPath), "", cclfFile.Name, cclfFile.Type) defer zipCloser() validator := cclfFileValidator{ @@ -298,7 +295,7 @@ func (s *CCLFTestSuite) TestImportCCLF8_Invalid() { fileName, cclfName := createTemporaryCCLF8ZipFile(s.T(), "A 1") defer os.Remove(fileName) - metadata, zipCloser := buildZipMetadata(s.T(), s.importer.FileProcessor, "1234", fileName, "", cclfName, models.FileTypeDefault) + metadata, zipCloser := buildZipMetadata(s.T(), s.importer.fileProcessor, "1234", fileName, "", cclfName, models.FileTypeDefault) defer zipCloser() validator := cclfFileValidator{ @@ -312,7 +309,7 @@ func (s *CCLFTestSuite) TestImportCCLF8_Invalid() { } func (s *CCLFTestSuite) TestImportRunoutCCLF() { - db := database.Connection + db := s.db cmsID := "RNOUT" defer func() { @@ -334,7 +331,7 @@ func (s *CCLFTestSuite) TestImportRunoutCCLF() { fileName, cclfName := createTemporaryCCLF8ZipFile(s.T(), mbi) defer os.Remove(fileName) - metadata, zipCloser := buildZipMetadata(s.T(), s.importer.FileProcessor, "1234", fileName, "", cclfName, tt.fileType) + metadata, zipCloser := buildZipMetadata(s.T(), s.importer.fileProcessor, "1234", fileName, "", cclfName, tt.fileType) defer zipCloser() validator := cclfFileValidator{ diff --git a/bcda/cclf/utils/cclfUtils.go b/bcda/cclf/utils/cclfUtils.go index 844a6fed7..c72fd8dd6 100644 --- a/bcda/cclf/utils/cclfUtils.go +++ b/bcda/cclf/utils/cclfUtils.go @@ -3,6 +3,7 @@ package testutils import ( "archive/zip" "crypto/rand" + "database/sql" "errors" "fmt" "io" @@ -23,7 +24,7 @@ import ( // ImportCCLFPackage will copy the appropriate synthetic CCLF files, rename them, // begin the import of those files and delete them from the place they were copied to after successful import. 
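Note: with the importer's fields now unexported, code outside the cclf package builds it through NewCclfImporter and supplies the database handle explicitly. A hedged usage sketch follows; runImport, importPath, and the trimmed S3 processor wiring are illustrative and not taken from this series (imports elided):

// runImport is an illustrative caller, not part of these patches.
func runImport(db *sql.DB, logger logrus.FieldLogger, importPath string) error {
	fp := &cclf.S3FileProcessor{ // handler fields trimmed for brevity
		Handler: optout.S3FileHandler{Logger: logger},
	}
	importer := cclf.NewCclfImporter(logger, fp, db)

	// ImportCCLFDirectory reports per-file outcomes plus any terminal error.
	success, failure, skipped, err := importer.ImportCCLFDirectory(importPath)
	logger.Infof("CCLF import finished: success=%d failure=%d skipped=%d", success, failure, skipped)
	return err
}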
-func ImportCCLFPackage(acoSize, environment string, fileType models.CCLFFileType) (err error) { +func ImportCCLFPackage(connection *sql.DB, acoSize, environment string, fileType models.CCLFFileType) (err error) { dir, err := os.MkdirTemp("", "*") if err != nil { @@ -148,10 +149,7 @@ func ImportCCLFPackage(acoSize, environment string, fileType models.CCLFFileType }, } - importer := cclf.CclfImporter{ - Logger: log.API, - FileProcessor: file_processor, - } + importer := cclf.NewCclfImporter(log.API, file_processor, connection) success, failure, skipped, err := importer.ImportCCLFDirectory(dir) if err != nil { diff --git a/bcda/cclf/utils/cclfUtils_test.go b/bcda/cclf/utils/cclfUtils_test.go index 0fbf8948e..cc4bae1c7 100644 --- a/bcda/cclf/utils/cclfUtils_test.go +++ b/bcda/cclf/utils/cclfUtils_test.go @@ -2,10 +2,12 @@ package testutils import ( "archive/zip" + "database/sql" "fmt" "os" "testing" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/utils" "github.com/CMSgov/bcda-app/conf" @@ -16,12 +18,14 @@ import ( type CCLFUtilTestSuite struct { suite.Suite + db *sql.DB } var origDate string func (s *CCLFUtilTestSuite) SetupSuite() { origDate = conf.GetEnv("CCLF_REF_DATE") + s.db = database.GetConnection() } func (s *CCLFUtilTestSuite) SetupTest() { @@ -39,19 +43,19 @@ func TestCCLFTestSuite(t *testing.T) { func (s *CCLFUtilTestSuite) TestImportInvalidSizeACO() { assert := assert.New(s.T()) conf.SetEnv(s.T(), "CCLF_REF_DATE", "D190617") - err := ImportCCLFPackage("NOTREAL", "test", models.FileTypeDefault) + err := ImportCCLFPackage(s.db, "NOTREAL", "test", models.FileTypeDefault) assert.EqualError(err, "invalid argument for ACO size") } func (s *CCLFUtilTestSuite) TestImportInvalidEnvironment() { assert := assert.New(s.T()) - err := ImportCCLFPackage("dev", "environment", models.FileTypeDefault) + err := ImportCCLFPackage(s.db, "dev", "environment", models.FileTypeDefault) assert.EqualError(err, "invalid argument for environment") } func (s *CCLFUtilTestSuite) TestInvalidFilePath() { assert := assert.New(s.T()) - err := ImportCCLFPackage("improved-small", "test-partially-adjudicated", models.FileTypeRunout) + err := ImportCCLFPackage(s.db, "improved-small", "test-partially-adjudicated", models.FileTypeRunout) assert.EqualError(err, "unable to locate ../../../../../../shared_files/cclf/files/synthetic/test-partially-adjudicated/small in file path") } @@ -100,7 +104,7 @@ func (s *CCLFUtilTestSuite) TestImport() { for _, fileType := range []models.CCLFFileType{models.FileTypeDefault, models.FileTypeRunout} { s.T().Run(fmt.Sprintf("ACO Size %s - Env %s - File Type %s", tt.acoSize, tt.env, fileType), func(t *testing.T) { - err := ImportCCLFPackage(tt.acoSize, tt.env, fileType) + err := ImportCCLFPackage(s.db, tt.acoSize, tt.env, fileType) assert.NoError(t, err) }) } diff --git a/bcda/lambda/cclf/main.go b/bcda/lambda/cclf/main.go index 93d30e376..aaf2518d8 100644 --- a/bcda/lambda/cclf/main.go +++ b/bcda/lambda/cclf/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "errors" "fmt" "os" @@ -24,7 +25,7 @@ func main() { // Localstack is a local-development server that mimics AWS. The endpoint variable // should only be set in local development to avoid making external calls to a real AWS account. 
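Note: the synthetic-data helper now takes the connection as its first argument, so the CLI command (and any test) chooses which database it writes to. An illustrative call, assuming a handle from database.GetConnection() and the "dev"/"test" arguments exercised by the tests above:

// Sketch: load the small synthetic CCLF package into the test environment.
db := database.GetConnection()
if err := cclfUtils.ImportCCLFPackage(db, "dev", "test", models.FileTypeDefault); err != nil {
	log.API.Error(err)
}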
if os.Getenv("LOCAL_STACK_ENDPOINT") != "" { - res, err := handleCclfImport(os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) + res, err := handleCclfImport(database.GetConnection(), os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) if err != nil { fmt.Printf("Failed to run opt out import: %s\n", err.Error()) } else { @@ -39,6 +40,7 @@ func attributionImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (st env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) + connection := database.GetConnection() s3Event, err := bcdaaws.ParseSQSEvent(sqsEvent) @@ -66,9 +68,9 @@ func attributionImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (st filepath := fmt.Sprintf("%s/%s", e.S3.Bucket.Name, e.S3.Object.Key) logger.Infof("Reading %s event for file %s", e.EventName, filepath) if cclf.CheckIfAttributionCSVFile(e.S3.Object.Key) { - return handleCSVImport(s3AssumeRoleArn, filepath) + return handleCSVImport(connection, s3AssumeRoleArn, filepath) } else { - return handleCclfImport(s3AssumeRoleArn, filepath) + return handleCclfImport(connection, s3AssumeRoleArn, filepath) } } } @@ -77,7 +79,7 @@ func attributionImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (st return "", nil } -func handleCSVImport(s3AssumeRoleArn, s3ImportPath string) (string, error) { +func handleCSVImport(connection *sql.DB, s3AssumeRoleArn, s3ImportPath string) (string, error) { env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) @@ -85,7 +87,7 @@ func handleCSVImport(s3AssumeRoleArn, s3ImportPath string) (string, error) { importer := cclf.CSVImporter{ Logger: logger, - Database: database.Connection, + Database: connection, FileProcessor: &cclf.S3FileProcessor{ Handler: optout.S3FileHandler{ Logger: logger, @@ -128,23 +130,22 @@ func loadBCDAParams() error { return nil } -func handleCclfImport(s3AssumeRoleArn, s3ImportPath string) (string, error) { +func handleCclfImport(connection *sql.DB, s3AssumeRoleArn, s3ImportPath string) (string, error) { env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) logger = logger.WithFields(logrus.Fields{"import_filename": s3ImportPath}) - importer := cclf.CclfImporter{ - Logger: logger, - FileProcessor: &cclf.S3FileProcessor{ - Handler: optout.S3FileHandler{ - Logger: logger, - Endpoint: os.Getenv("LOCAL_STACK_ENDPOINT"), - AssumeRoleArn: s3AssumeRoleArn, - }, + fileProcessor := cclf.S3FileProcessor{ + Handler: optout.S3FileHandler{ + Logger: logger, + Endpoint: os.Getenv("LOCAL_STACK_ENDPOINT"), + AssumeRoleArn: s3AssumeRoleArn, }, } + importer := cclf.NewCclfImporter(logger, &fileProcessor, connection) + success, failure, skipped, err := importer.ImportCCLFDirectory(s3ImportPath) if err != nil { diff --git a/bcda/lambda/cclf/main_test.go b/bcda/lambda/cclf/main_test.go index 8655de33c..a3474c0eb 100644 --- a/bcda/lambda/cclf/main_test.go +++ b/bcda/lambda/cclf/main_test.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "errors" "fmt" "testing" @@ -16,8 +17,12 @@ import ( type AttributionImportMainSuite struct { suite.Suite + db *sql.DB } +func (s *AttributionImportMainSuite) SetupSuite() { + s.db = database.GetConnection() +} func TestAttributionImportMainSuite(t *testing.T) { suite.Run(t, new(AttributionImportMainSuite)) } @@ -54,8 +59,8 @@ func (s *AttributionImportMainSuite) TestImportCCLFDirectory() { } for _, tc := range tests { - 
postgrestest.DeleteCCLFFilesByCMSID(s.T(), database.Connection, targetACO) - defer postgrestest.DeleteCCLFFilesByCMSID(s.T(), database.Connection, targetACO) + postgrestest.DeleteCCLFFilesByCMSID(s.T(), s.db, targetACO) + defer postgrestest.DeleteCCLFFilesByCMSID(s.T(), s.db, targetACO) path, cleanup := testUtils.CopyToS3(s.T(), tc.path) defer cleanup() From 607ce188f237a123e2c5e1d209c2fcc1555cd9ec Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 16:10:11 -0400 Subject: [PATCH 13/28] Refactor connection global in admin create group lambda --- bcda/lambda/admin_create_group/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bcda/lambda/admin_create_group/main.go b/bcda/lambda/admin_create_group/main.go index 4bb2850bb..3d2acec8c 100644 --- a/bcda/lambda/admin_create_group/main.go +++ b/bcda/lambda/admin_create_group/main.go @@ -61,7 +61,7 @@ func handler(ctx context.Context, event json.RawMessage) error { } slackClient := slack.New(slackToken) - db := database.Connection + db := database.GetConnection() r := postgres.NewRepository(db) ssas, err := client.NewSSASClient() From 2749820acf4dd404fc35773119bc14b732cdaf9e Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 16:16:09 -0400 Subject: [PATCH 14/28] Refactor db globals in optout lambda --- bcda/lambda/optout/main.go | 10 ++++++---- bcda/lambda/optout/main_test.go | 10 ++++++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/bcda/lambda/optout/main.go b/bcda/lambda/optout/main.go index b4782271f..8b39acbfe 100644 --- a/bcda/lambda/optout/main.go +++ b/bcda/lambda/optout/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "fmt" "os" "strings" @@ -25,7 +26,7 @@ func main() { // Localstack is a local-development server that mimics AWS. The endpoint variable // should only be set in local development to avoid making external calls to a real AWS account. 
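Note: each Lambda entry point now resolves the shared handle once per invocation and passes it down, rather than having the import code reach for a global. The shape, with handleImport as a hypothetical stand-in for handleCclfImport / handleOptOutImport, is roughly:

// Sketch of the handler-level pattern; handleImport is a placeholder.
func handler(ctx context.Context, sqsEvent events.SQSEvent) (string, error) {
	db := database.GetConnection()
	roleArn := os.Getenv("BFD_BUCKET_ROLE_ARN")   // as in the local-development path above
	importPath := os.Getenv("BFD_S3_IMPORT_PATH") // SQS/S3 event parsing elided
	return handleImport(db, roleArn, importPath)
}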
if os.Getenv("LOCAL_STACK_ENDPOINT") != "" { - res, err := handleOptOutImport(os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) + res, err := handleOptOutImport(database.GetConnection(), os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) if err != nil { fmt.Printf("Failed to run opt out import: %s\n", err.Error()) } else { @@ -40,6 +41,7 @@ func optOutImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (string, env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) + db := database.GetConnection() s3Event, err := bcdaaws.ParseSQSEvent(sqsEvent) @@ -60,7 +62,7 @@ func optOutImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (string, dir := bcdaaws.ParseS3Directory(e.S3.Bucket.Name, e.S3.Object.Key) logger.Infof("Reading %s event for directory %s", e.EventName, dir) - return handleOptOutImport(s3AssumeRoleArn, dir) + return handleOptOutImport(db, s3AssumeRoleArn, dir) } } @@ -84,11 +86,11 @@ func loadBfdS3Params() (string, error) { return param, nil } -func handleOptOutImport(s3AssumeRoleArn, s3ImportPath string) (string, error) { +func handleOptOutImport(db *sql.DB, s3AssumeRoleArn, s3ImportPath string) (string, error) { env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) - repo := postgres.NewRepository(database.Connection) + repo := postgres.NewRepository(db) importer := suppression.OptOutImporter{ FileHandler: &optout.S3FileHandler{ diff --git a/bcda/lambda/optout/main_test.go b/bcda/lambda/optout/main_test.go index 7205c46ad..ed5a42f77 100644 --- a/bcda/lambda/optout/main_test.go +++ b/bcda/lambda/optout/main_test.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "fmt" "testing" @@ -16,6 +17,11 @@ import ( type OptOutImportMainSuite struct { suite.Suite + db *sql.DB +} + +func (s *OptOutImportMainSuite) SetupSuite() { + s.db = database.GetConnection() } func TestOptOutImportMainSuite(t *testing.T) { @@ -45,14 +51,14 @@ func (s *OptOutImportMainSuite) TestOptOutImportHandlerSuccess() { assert.Contains(res, "Files failed: 0") assert.Contains(res, "Files skipped: 0") - fs := postgrestest.GetSuppressionFileByName(s.T(), database.Connection, + fs := postgrestest.GetSuppressionFileByName(s.T(), s.db, "T#EFT.ON.ACO.NGD1800.DPRF.D181120.T1000010", "T#EFT.ON.ACO.NGD1800.DPRF.D190816.T0241391") assert.Len(fs, 2) for _, f := range fs { - postgrestest.DeleteSuppressionFileByID(s.T(), database.Connection, f.ID) + postgrestest.DeleteSuppressionFileByID(s.T(), s.db, f.ID) } } From db9670029906c4da260696301d5b6753a52af3ab Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 18:19:03 -0400 Subject: [PATCH 15/28] Refactor globals in ratelimit middleware --- bcda/web/middleware/ratelimit.go | 81 ++++++++++--------- bcda/web/middleware/ratelimit_test.go | 109 ++++++++++++++------------ bcda/web/router.go | 3 +- 3 files changed, 103 insertions(+), 90 deletions(-) diff --git a/bcda/web/middleware/ratelimit.go b/bcda/web/middleware/ratelimit.go index 59ed7e9e0..f35bbce3e 100644 --- a/bcda/web/middleware/ratelimit.go +++ b/bcda/web/middleware/ratelimit.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "database/sql" "fmt" "net/http" "net/url" @@ -10,7 +11,6 @@ import ( "time" "github.com/CMSgov/bcda-app/bcda/auth" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/bcda/responseutils" @@ -21,58 +21,57 @@ import ( 
"github.com/pkg/errors" ) -var ( +type RateLimitMiddleware struct { + config *service.Config repository models.Repository jobTimeout time.Duration retrySeconds int -) +} -func init() { - repository = postgres.NewRepository(database.Connection) - jobTimeout = time.Hour * time.Duration(utils.GetEnvInt("ARCHIVE_THRESHOLD_HR", 24)) - retrySeconds = utils.GetEnvInt("CLIENT_RETRY_AFTER_IN_SECONDS", 0) +func NewRateLimitMiddleware(config *service.Config, db *sql.DB) RateLimitMiddleware { + r := postgres.NewRepository(db) + jt := time.Hour * time.Duration(utils.GetEnvInt("ARCHIVE_THRESHOLD_HR", 24)) + rs := utils.GetEnvInt("CLIENT_RETRY_AFTER_IN_SECONDS", 0) + return RateLimitMiddleware{config: config, repository: r, jobTimeout: jt, retrySeconds: rs} } -func CheckConcurrentJobs(cfg *service.Config) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - ad, ok := r.Context().Value(auth.AuthDataContextKey).(auth.AuthData) - if !ok { - panic("AuthData should be set before calling this handler") - } +func (m RateLimitMiddleware) CheckConcurrentJobs(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ad, ok := r.Context().Value(auth.AuthDataContextKey).(auth.AuthData) + if !ok { + panic("AuthData should be set before calling this handler") + } - rp, ok := GetRequestParamsFromCtx(r.Context()) - if !ok { - panic("RequestParameters should be set before calling this handler") - } + rp, ok := GetRequestParamsFromCtx(r.Context()) + if !ok { + panic("RequestParameters should be set before calling this handler") + } - rw, _ := getResponseWriterFromRequestPath(w, r) - if rw == nil { - return - } + rw, _ := getResponseWriterFromRequestPath(w, r) + if rw == nil { + return + } - acoID := uuid.Parse(ad.ACOID) + acoID := uuid.Parse(ad.ACOID) - if shouldRateLimit(cfg.RateLimitConfig, ad.CMSID) { - pendingAndInProgressJobs, err := repository.GetJobs(r.Context(), acoID, models.JobStatusInProgress, models.JobStatusPending) - if err != nil { - logger := log.GetCtxLogger(r.Context()) - logger.Error(fmt.Errorf("failed to lookup pending and in-progress jobs: %w", err)) - rw.Exception(r.Context(), w, http.StatusInternalServerError, responseutils.InternalErr, "") + if shouldRateLimit(m.config.RateLimitConfig, ad.CMSID) { + pendingAndInProgressJobs, err := m.repository.GetJobs(r.Context(), acoID, models.JobStatusInProgress, models.JobStatusPending) + if err != nil { + logger := log.GetCtxLogger(r.Context()) + logger.Error(fmt.Errorf("failed to lookup pending and in-progress jobs: %w", err)) + rw.Exception(r.Context(), w, http.StatusInternalServerError, responseutils.InternalErr, "") + return + } + if len(pendingAndInProgressJobs) > 0 { + if m.hasDuplicates(r.Context(), pendingAndInProgressJobs, rp.ResourceTypes, rp.Version, rp.RequestURL) { + w.Header().Set("Retry-After", strconv.Itoa(m.retrySeconds)) + w.WriteHeader(http.StatusTooManyRequests) return } - if len(pendingAndInProgressJobs) > 0 { - if hasDuplicates(r.Context(), pendingAndInProgressJobs, rp.ResourceTypes, rp.Version, rp.RequestURL) { - w.Header().Set("Retry-After", strconv.Itoa(retrySeconds)) - w.WriteHeader(http.StatusTooManyRequests) - return - } - } } - next.ServeHTTP(w, r) } - return http.HandlerFunc(fn) - } + next.ServeHTTP(w, r) + }) } func shouldRateLimit(config service.RateLimitConfig, cmsID string) bool { @@ -82,7 +81,7 @@ func shouldRateLimit(config service.RateLimitConfig, cmsID string) bool { return false } 
-func hasDuplicates(ctx context.Context, pendingAndInProgressJobs []*models.Job, types []string, version string, newRequestUrl string) bool { +func (m RateLimitMiddleware) hasDuplicates(ctx context.Context, pendingAndInProgressJobs []*models.Job, types []string, version string, newRequestUrl string) bool { logger := log.GetCtxLogger(ctx) typeSet := make(map[string]struct{}, len(types)) @@ -116,7 +115,7 @@ func hasDuplicates(ctx context.Context, pendingAndInProgressJobs []*models.Job, } // If the job has timed-out we will allow new job to be created - if time.Now().After(job.CreatedAt.Add(jobTimeout)) { + if time.Now().After(job.CreatedAt.Add(m.jobTimeout)) { logger.Info("Existing job timed out -- ignoring existing job") continue } diff --git a/bcda/web/middleware/ratelimit_test.go b/bcda/web/middleware/ratelimit_test.go index 638470954..1d7e2a538 100644 --- a/bcda/web/middleware/ratelimit_test.go +++ b/bcda/web/middleware/ratelimit_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "database/sql" "errors" "net/http" "net/http/httptest" @@ -10,6 +11,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/constants" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/service" logAPI "github.com/CMSgov/bcda-app/log" @@ -17,10 +19,23 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" ) -func TestNoConcurrentJobs(t *testing.T) { +type RateLimitMiddlewareTestSuite struct { + suite.Suite + db *sql.DB +} + +func TestRateLimitMiddlewareTestSuite(t *testing.T) { + suite.Run(t, new(RateLimitMiddlewareTestSuite)) +} +func (s *RateLimitMiddlewareTestSuite) SetupSuite() { + s.db = database.GetConnection() +} +func (s *RateLimitMiddlewareTestSuite) TestNoConcurrentJobs() { cfg := &service.Config{RateLimitConfig: service.RateLimitConfig{All: true}} + middleware := NewRateLimitMiddleware(cfg, s.db) tests := []struct { name string rp RequestParameters @@ -36,25 +51,24 @@ func TestNoConcurrentJobs(t *testing.T) { } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockRepo := &models.MockRepository{} - mockRepo.On("GetJobs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( - tt.jobs, - nil, - ) - repository = mockRepo - - rr := httptest.NewRecorder() - middleware := CheckConcurrentJobs(cfg) - middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // Conncurrent job test route check, blank return for overrides - })).ServeHTTP(rr, getRequest(tt.rp)) - assert.Equal(t, http.StatusOK, rr.Code) - }) + mockRepo := &models.MockRepository{} + mockRepo.On("GetJobs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + tt.jobs, + nil, + ) + middleware.repository = mockRepo + + rr := httptest.NewRecorder() + middleware.CheckConcurrentJobs(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Conncurrent job test route check, blank return for overrides + })).ServeHTTP(rr, getRequest(tt.rp)) + + assert.Equal(s.T(), http.StatusOK, rr.Code) } } -func TestHasConcurrentJobs(t *testing.T) { +func (s *RateLimitMiddlewareTestSuite) TestHasConcurrentJobs() { cfg := &service.Config{RateLimitConfig: service.RateLimitConfig{All: true}} + middleware := NewRateLimitMiddleware(cfg, s.db) // These jobs are not considered when determine duplicate jobs ignoredJobs := []*models.Job{ {RequestURL: "http://a{b"}, // InvalidURL @@ -77,38 
+91,40 @@ func TestHasConcurrentJobs(t *testing.T) { } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockRepo := &models.MockRepository{} - mockRepo.On("GetJobs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( - append(ignoredJobs, tt.additionalJobs...), - nil, - ) - repository = mockRepo - - rr := httptest.NewRecorder() - middleware := CheckConcurrentJobs(cfg) - middleware(nil).ServeHTTP(rr, getRequest(tt.rp)) - assert.Equal(t, http.StatusTooManyRequests, rr.Code) - assert.NotEmpty(t, rr.Header().Get("Retry-After")) - }) + mockRepo := &models.MockRepository{} + mockRepo.On("GetJobs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + append(ignoredJobs, tt.additionalJobs...), + nil, + ) + middleware.repository = mockRepo + + rr := httptest.NewRecorder() + middleware.CheckConcurrentJobs(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Conncurrent job test route check, blank return for overrides + })).ServeHTTP(rr, getRequest(tt.rp)) + + assert.NotEmpty(s.T(), rr.Header().Get("Retry-After")) } } -func TestFailedToGetJobs(t *testing.T) { +func (s *RateLimitMiddlewareTestSuite) TestFailedToGetJobs() { cfg := &service.Config{RateLimitConfig: service.RateLimitConfig{All: true}} + middleware := NewRateLimitMiddleware(cfg, s.db) mockRepo := &models.MockRepository{} mockRepo.On("GetJobs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( nil, errors.New("FORCING SOME ERROR"), nil, ) - repository = mockRepo + middleware.repository = mockRepo rr := httptest.NewRecorder() - middleware := CheckConcurrentJobs(cfg) - middleware(nil).ServeHTTP(rr, getRequest(RequestParameters{})) - assert.Equal(t, http.StatusInternalServerError, rr.Code) - assert.Contains(t, rr.Body.String(), "code\":\"exception\"") + middleware.CheckConcurrentJobs(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Conncurrent job test route check, blank return for overrides + })).ServeHTTP(rr, getRequest(RequestParameters{})) + + assert.Equal(s.T(), http.StatusInternalServerError, rr.Code) + assert.Contains(s.T(), rr.Body.String(), "code\":\"exception\"") } func getRequest(rp RequestParameters) *http.Request { @@ -119,8 +135,10 @@ func getRequest(rp RequestParameters) *http.Request { return httptest.NewRequest("GET", "/api/v1/Patient", nil).WithContext(ctx) } -func TestHasDuplicatesFullString(t *testing.T) { +func (s *RateLimitMiddlewareTestSuite) TestHasDuplicatesFullString() { ctx := context.Background() + cfg := &service.Config{RateLimitConfig: service.RateLimitConfig{All: true}} + middleware := NewRateLimitMiddleware(cfg, s.db) ctx = logAPI.NewStructuredLoggerEntry(log.New(), ctx) otherJobs := []*models.Job{ {ID: 1, RequestURL: "https://api.abcd.123.net/api/v2/Group/runout/$export?_since=2024-02-11T00%3A00%3A00.0000-00%3A00&_type=Patient%2CCoverage%2CExplanationOfBenefit", CreatedAt: time.Now(), Status: models.JobStatusPending}, @@ -139,15 +157,12 @@ func TestHasDuplicatesFullString(t *testing.T) { } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - responseBool := hasDuplicates(ctx, otherJobs, tt.rp.ResourceTypes, tt.rp.Version, tt.rp.RequestURL) - assert.Equal(t, tt.expectedValue, responseBool) - }) + responseBool := middleware.hasDuplicates(ctx, otherJobs, tt.rp.ResourceTypes, tt.rp.Version, tt.rp.RequestURL) + assert.Equal(s.T(), tt.expectedValue, responseBool) } - } -func TestShouldRateLimit(t *testing.T) { +func (s *RateLimitMiddlewareTestSuite) TestShouldRateLimit() { cmsID := "MyFavoriteACO" 
otherCMSID := "OtherCMSID" @@ -164,9 +179,7 @@ func TestShouldRateLimit(t *testing.T) { } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actualValue := shouldRateLimit(tt.config, tt.cmsID) - assert.Equal(t, tt.expectedValue, actualValue, tt.name) - }) + actualValue := shouldRateLimit(tt.config, tt.cmsID) + assert.Equal(s.T(), tt.expectedValue, actualValue, tt.name) } } diff --git a/bcda/web/router.go b/bcda/web/router.go index 1fd70c1b0..9dad7ba58 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -42,8 +42,9 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { panic(fmt.Errorf("could not load service config file: %w", err)) } + rlm := middleware.NewRateLimitMiddleware(cfg, connection) var requestValidators = []func(http.Handler) http.Handler{ - middleware.ACOEnabled(cfg), middleware.ValidateRequestURL, middleware.ValidateRequestHeaders, middleware.CheckConcurrentJobs(cfg), + middleware.ACOEnabled(cfg), middleware.ValidateRequestURL, middleware.ValidateRequestHeaders, rlm.CheckConcurrentJobs, } nonExportRequestValidators := []func(http.Handler) http.Handler{ middleware.ACOEnabled(cfg), middleware.ValidateRequestURL, middleware.ValidateRequestHeaders, From 3256641f76052bf2991b99409212905a26d0aee3 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 18:26:40 -0400 Subject: [PATCH 16/28] Refactor db connection in worker and river --- bcdaworker/main.go | 5 +++-- bcdaworker/queueing/river.go | 13 +++++-------- bcdaworker/queueing/river_test.go | 2 +- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/bcdaworker/main.go b/bcdaworker/main.go index 9e9ab382f..fc357cd17 100644 --- a/bcdaworker/main.go +++ b/bcdaworker/main.go @@ -97,8 +97,9 @@ func waitForSig() { func main() { fmt.Println("Starting bcdaworker...") - healthChecker := health.NewHealthChecker(database.Connection) - queue := queueing.StartRiver(utils.GetEnvInt("WORKER_POOL_SIZE", 4)) + db := database.GetConnection() + healthChecker := health.NewHealthChecker(db) + queue := queueing.StartRiver(db, utils.GetEnvInt("WORKER_POOL_SIZE", 4)) defer queue.StopRiver() if hInt, err := strconv.Atoi(conf.GetEnv("WORKER_HEALTH_INT_SEC")); err == nil { diff --git a/bcdaworker/queueing/river.go b/bcdaworker/queueing/river.go index ebbdf1f8a..8f2407659 100644 --- a/bcdaworker/queueing/river.go +++ b/bcdaworker/queueing/river.go @@ -48,17 +48,15 @@ type Notifier interface { } // TODO: better dependency injection (db, worker, logger). 
Waiting for pgxv5 upgrade -func StartRiver(numWorkers int) *queue { - - connection := database.GetConnection() +func StartRiver(db *sql.DB, numWorkers int) *queue { pool := database.GetPool() workers := river.NewWorkers() - prepareWorker, err := NewPrepareJobWorker(connection) + prepareWorker, err := NewPrepareJobWorker(db) if err != nil { panic(err) } - river.AddWorker(workers, &JobWorker{connection: connection}) + river.AddWorker(workers, &JobWorker{connection: db}) river.AddWorker(workers, NewCleanupJobWorker()) river.AddWorker(workers, prepareWorker) @@ -101,12 +99,11 @@ func StartRiver(numWorkers int) *queue { panic(err) } - mainDB := database.Connection q := &queue{ ctx: context.Background(), client: riverClient, - worker: worker.NewWorker(mainDB), - repository: postgres.NewRepository(mainDB), + worker: worker.NewWorker(db), + repository: postgres.NewRepository(db), } return q diff --git a/bcdaworker/queueing/river_test.go b/bcdaworker/queueing/river_test.go index 016c83c25..0770c98b9 100644 --- a/bcdaworker/queueing/river_test.go +++ b/bcdaworker/queueing/river_test.go @@ -77,7 +77,7 @@ func TestWork_Integration(t *testing.T) { defer postgrestest.DeleteACO(t, db, aco.UUID) - q := StartRiver(1) + q := StartRiver(db, 1) defer q.StopRiver() id, _ := safecast.ToInt(job.ID) From ae8dc9873fd91ec3d0cfd1ef3ded2ba057613459 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 30 Jul 2025 19:01:57 -0400 Subject: [PATCH 17/28] Refactor db connections in cleanup worker --- bcdaworker/cleanup/cleanup.go | 8 +++----- bcdaworker/cleanup/cleanup_test.go | 20 ++++++++++---------- bcdaworker/queueing/river.go | 2 +- bcdaworker/queueing/river_test.go | 20 +++++++++++--------- bcdaworker/queueing/worker_cleanup.go | 17 ++++++++++------- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/bcdaworker/cleanup/cleanup.go b/bcdaworker/cleanup/cleanup.go index b9cbcffad..d20966a69 100644 --- a/bcdaworker/cleanup/cleanup.go +++ b/bcdaworker/cleanup/cleanup.go @@ -2,13 +2,13 @@ package cleanup import ( "context" + "database/sql" "fmt" "os" "path/filepath" "strconv" "time" - "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/conf" @@ -16,10 +16,9 @@ import ( "github.com/sirupsen/logrus" ) -func ArchiveExpiring(maxDate time.Time) error { +func ArchiveExpiring(db *sql.DB, maxDate time.Time) error { log.API.Info("Archiving expiring job files...") - db := database.Connection r := postgres.NewRepository(db) jobs, err := r.GetJobsByUpdateTimeAndStatus(context.Background(), time.Time{}, maxDate, models.JobStatusCompleted) @@ -56,8 +55,7 @@ func ArchiveExpiring(maxDate time.Time) error { return lastJobError } -func CleanupJob(maxDate time.Time, currentStatus, newStatus models.JobStatus, rootDirsToClean ...string) error { - db := database.Connection +func CleanupJob(db *sql.DB, maxDate time.Time, currentStatus, newStatus models.JobStatus, rootDirsToClean ...string) error { r := postgres.NewRepository(db) jobs, err := r.GetJobsByUpdateTimeAndStatus(context.Background(), time.Time{}, maxDate, currentStatus) diff --git a/bcdaworker/cleanup/cleanup_test.go b/bcdaworker/cleanup/cleanup_test.go index 2b8b2886b..8999f01ea 100644 --- a/bcdaworker/cleanup/cleanup_test.go +++ b/bcdaworker/cleanup/cleanup_test.go @@ -39,7 +39,7 @@ func (s *CleanupTestSuite) SetupSuite() { s.pendingDeletionDir = dir testUtils.SetPendingDeletionDir(&s.Suite, dir) - s.db = database.Connection + s.db = database.GetConnection() 
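Note: ArchiveExpiring and CleanupJob are now plain functions of their inputs, so a scheduled caller supplies the handle it wants cleaned against, and the river cleanup worker later in this series can stub them out through its function fields. A caller sketch; cutoff, archiveDir, and stagingDir are illustrative values:

// Sketch: expire archived jobs past the cutoff, then archive expiring files.
db := database.GetConnection()
cutoff := time.Now().Add(-24 * time.Hour)
if err := cleanup.CleanupJob(db, cutoff, models.JobStatusArchived, models.JobStatusExpired, archiveDir, stagingDir); err != nil {
	log.Worker.Error(err)
}
if err := cleanup.ArchiveExpiring(db, cutoff); err != nil {
	log.Worker.Error(err)
}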
cmsID := testUtils.RandomHexID()[0:4] s.testACO = models.ACO{Name: uuid.New(), UUID: uuid.NewRandom(), ClientID: uuid.New(), CMSID: &cmsID} @@ -120,7 +120,7 @@ func (s *CleanupTestSuite) TestArchiveExpiring() { } defer f.Close() - if err := ArchiveExpiring(t); err != nil { + if err := ArchiveExpiring(s.db, t); err != nil { s.T().Error(err) } @@ -158,7 +158,7 @@ func (s *CleanupTestSuite) TestArchiveExpiringWithoutPayloadDir() { } postgrestest.CreateJobs(s.T(), s.db, &j) - if err := ArchiveExpiring(t); err != nil { + if err := ArchiveExpiring(s.db, t); err != nil { s.T().Error(err) } @@ -197,7 +197,7 @@ func (s *CleanupTestSuite) TestArchiveExpiringWithThreshold() { } defer f.Close() - if err := ArchiveExpiring(time.Now().Add(-24 * time.Hour)); err != nil { + if err := ArchiveExpiring(s.db, time.Now().Add(-24*time.Hour)); err != nil { s.T().Error(err) } @@ -226,20 +226,20 @@ func (s *CleanupTestSuite) TestCleanArchive() { // condition: FHIR_ARCHIVE_DIR doesn't exist conf.UnsetEnv(s.T(), "FHIR_ARCHIVE_DIR") - err := CleanupJob(now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, + err := CleanupJob(s.db, now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, conf.GetEnv("FHIR_ARCHIVE_DIR"), conf.GetEnv("FHIR_STAGING_DIR")) assert.Nil(err) conf.SetEnv(s.T(), "FHIR_ARCHIVE_DIR", constants.TestArchivePath) // condition: FHIR_STAGING_DIR doesn't exist conf.UnsetEnv(s.T(), "FHIR_STAGING_DIR") - err = CleanupJob(now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, + err = CleanupJob(s.db, now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, conf.GetEnv("FHIR_ARCHIVE_DIR"), conf.GetEnv("FHIR_STAGING_DIR")) assert.Nil(err) conf.SetEnv(s.T(), "FHIR_STAGING_DIR", constants.TestStagingPath) // // condition: no jobs exist - err = CleanupJob(now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, + err = CleanupJob(s.db, now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, conf.GetEnv("FHIR_ARCHIVE_DIR"), conf.GetEnv("FHIR_STAGING_DIR")) if err != nil { s.T().Error(err) @@ -257,7 +257,7 @@ func (s *CleanupTestSuite) TestCleanArchive() { // condition: before < Threshold < after <= now // a file created before the Threshold should be deleted; one created after should not // we use last modified as a proxy for created, because these files should not be changed after creation - err = CleanupJob(now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, + err = CleanupJob(s.db, now.Add(-Threshold*time.Hour), models.JobStatusArchived, models.JobStatusExpired, conf.GetEnv("FHIR_ARCHIVE_DIR"), conf.GetEnv("FHIR_STAGING_DIR")) assert.Nil(err) assert.Nil(err) @@ -306,7 +306,7 @@ func (s *CleanupTestSuite) TestCleanupFailed() { } }() - err := CleanupJob(time.Now().Add(-threshold*time.Hour), models.JobStatusFailed, models.JobStatusFailedExpired, + err := CleanupJob(s.db, time.Now().Add(-threshold*time.Hour), models.JobStatusFailed, models.JobStatusFailedExpired, staging, payload) assert.NoError(s.T(), err) @@ -364,7 +364,7 @@ func (s *CleanupTestSuite) TestCleanupCancelled() { } }() - err := CleanupJob(modified, models.JobStatusCancelled, models.JobStatusCancelledExpired, + err := CleanupJob(s.db, modified, models.JobStatusCancelled, models.JobStatusCancelledExpired, staging, payload) assert.NoError(s.T(), err) diff --git a/bcdaworker/queueing/river.go b/bcdaworker/queueing/river.go index 8f2407659..a88bf6fc4 100644 --- 
a/bcdaworker/queueing/river.go +++ b/bcdaworker/queueing/river.go @@ -57,7 +57,7 @@ func StartRiver(db *sql.DB, numWorkers int) *queue { panic(err) } river.AddWorker(workers, &JobWorker{connection: db}) - river.AddWorker(workers, NewCleanupJobWorker()) + river.AddWorker(workers, NewCleanupJobWorker(db)) river.AddWorker(workers, prepareWorker) schedule, err := cron.ParseStandard("0 11,23 * * *") diff --git a/bcdaworker/queueing/river_test.go b/bcdaworker/queueing/river_test.go index 0770c98b9..7e365d7fe 100644 --- a/bcdaworker/queueing/river_test.go +++ b/bcdaworker/queueing/river_test.go @@ -2,6 +2,7 @@ package queueing import ( "context" + "database/sql" "os" "testing" "time" @@ -112,13 +113,13 @@ type MockArchiveExpiring struct { mock.Mock } -func (m *MockCleanupJob) CleanupJob(maxDate time.Time, currentStatus, newStatus models.JobStatus, rootDirsToClean ...string) error { - args := m.Called(maxDate, currentStatus, newStatus, rootDirsToClean) +func (m *MockCleanupJob) CleanupJob(db *sql.DB, maxDate time.Time, currentStatus, newStatus models.JobStatus, rootDirsToClean ...string) error { + args := m.Called(db, maxDate, currentStatus, newStatus, rootDirsToClean) return args.Error(0) } -func (m *MockArchiveExpiring) ArchiveExpiring(maxDate time.Time) error { - args := m.Called(maxDate) +func (m *MockArchiveExpiring) ArchiveExpiring(db *sql.DB, maxDate time.Time) error { + args := m.Called(db, maxDate) return args.Error(0) } @@ -146,15 +147,16 @@ func TestCleanupJobWorker_Work(t *testing.T) { conf.SetEnv(t, "FHIR_STAGING_DIR", stagingPath) conf.SetEnv(t, "FHIR_PAYLOAD_DIR", payloadPath) - mockCleanupJob.On("CleanupJob", mock.AnythingOfType("time.Time"), models.JobStatusArchived, models.JobStatusExpired, []string{archivePath, stagingPath}).Return(nil) - mockCleanupJob.On("CleanupJob", mock.AnythingOfType("time.Time"), models.JobStatusFailed, models.JobStatusFailedExpired, []string{stagingPath, payloadPath}).Return(nil) - mockCleanupJob.On("CleanupJob", mock.AnythingOfType("time.Time"), models.JobStatusCancelled, models.JobStatusCancelledExpired, []string{stagingPath, payloadPath}).Return(nil) - mockArchiveExpiring.On("ArchiveExpiring", mock.AnythingOfType("time.Time")).Return(nil) + mockCleanupJob.On("CleanupJob", mock.Anything, mock.AnythingOfType("time.Time"), models.JobStatusArchived, models.JobStatusExpired, []string{archivePath, stagingPath}).Return(nil) + mockCleanupJob.On("CleanupJob", mock.Anything, mock.AnythingOfType("time.Time"), models.JobStatusFailed, models.JobStatusFailedExpired, []string{stagingPath, payloadPath}).Return(nil) + mockCleanupJob.On("CleanupJob", mock.Anything, mock.AnythingOfType("time.Time"), models.JobStatusCancelled, models.JobStatusCancelledExpired, []string{stagingPath, payloadPath}).Return(nil) + mockArchiveExpiring.On("ArchiveExpiring", mock.Anything, mock.AnythingOfType("time.Time")).Return(nil) // Create a worker instance cleanupJobWorker := &CleanupJobWorker{ cleanupJob: mockCleanupJob.CleanupJob, archiveExpiring: mockArchiveExpiring.ArchiveExpiring, + db: database.GetConnection(), } // Create a mock river.Job @@ -211,7 +213,7 @@ func TestGetAWSParams(t *testing.T) { } func TestNewCleanupJobWorker(t *testing.T) { - worker := NewCleanupJobWorker() + worker := NewCleanupJobWorker(database.GetConnection()) assert.NotNil(t, worker) assert.NotNil(t, worker.cleanupJob) diff --git a/bcdaworker/queueing/worker_cleanup.go b/bcdaworker/queueing/worker_cleanup.go index 1c18349cf..26bc42b5a 100644 --- a/bcdaworker/queueing/worker_cleanup.go +++ 
b/bcdaworker/queueing/worker_cleanup.go @@ -2,6 +2,7 @@ package queueing import ( "context" + "database/sql" "fmt" "time" @@ -19,14 +20,16 @@ import ( type CleanupJobWorker struct { river.WorkerDefaults[worker_types.CleanupJobArgs] - cleanupJob func(time.Time, models.JobStatus, models.JobStatus, ...string) error - archiveExpiring func(time.Time) error + cleanupJob func(*sql.DB, time.Time, models.JobStatus, models.JobStatus, ...string) error + archiveExpiring func(*sql.DB, time.Time) error + db *sql.DB } -func NewCleanupJobWorker() *CleanupJobWorker { +func NewCleanupJobWorker(db *sql.DB) *CleanupJobWorker { return &CleanupJobWorker{ cleanupJob: cleanup.CleanupJob, archiveExpiring: cleanup.ArchiveExpiring, + db: db, } } @@ -59,7 +62,7 @@ func (w *CleanupJobWorker) Work(ctx context.Context, rjob *river.Job[worker_type } // Cleanup archived jobs: remove job directory and files from archive and update job status to Expired - if err := w.cleanupJob(cutoff, models.JobStatusArchived, models.JobStatusExpired, archiveDir, stagingDir); err != nil { + if err := w.cleanupJob(w.db, cutoff, models.JobStatusArchived, models.JobStatusExpired, archiveDir, stagingDir); err != nil { logger.Error(errors.Wrap(err, fmt.Sprintf("failed to process job: %s", constants.CleanupArchArg))) _, _, slackErr := slackClient.PostMessageContext(ctx, slackChannel, slack.MsgOptionText( @@ -73,7 +76,7 @@ func (w *CleanupJobWorker) Work(ctx context.Context, rjob *river.Job[worker_type } // Cleanup failed jobs: remove job directory and files from failed jobs and update job status to FailedExpired - if err := w.cleanupJob(cutoff, models.JobStatusFailed, models.JobStatusFailedExpired, stagingDir, payloadDir); err != nil { + if err := w.cleanupJob(w.db, cutoff, models.JobStatusFailed, models.JobStatusFailedExpired, stagingDir, payloadDir); err != nil { logger.Error(errors.Wrap(err, fmt.Sprintf("failed to process job: %s", constants.CleanupFailedArg))) _, _, slackErr := slackClient.PostMessageContext(ctx, slackChannel, slack.MsgOptionText( @@ -87,7 +90,7 @@ func (w *CleanupJobWorker) Work(ctx context.Context, rjob *river.Job[worker_type } // Cleanup cancelled jobs: remove job directory and files from cancelled jobs and update job status to CancelledExpired - if err := w.cleanupJob(cutoff, models.JobStatusCancelled, models.JobStatusCancelledExpired, stagingDir, payloadDir); err != nil { + if err := w.cleanupJob(w.db, cutoff, models.JobStatusCancelled, models.JobStatusCancelledExpired, stagingDir, payloadDir); err != nil { logger.Error(errors.Wrap(err, fmt.Sprintf("failed to process job: %s", constants.CleanupCancelledArg))) _, _, slackErr := slackClient.PostMessageContext(ctx, slackChannel, slack.MsgOptionText( @@ -101,7 +104,7 @@ func (w *CleanupJobWorker) Work(ctx context.Context, rjob *river.Job[worker_type } // Archive expiring jobs: update job statuses and move files to an inaccessible location - if err := w.archiveExpiring(cutoff); err != nil { + if err := w.archiveExpiring(w.db, cutoff); err != nil { logger.Error(errors.Wrap(err, fmt.Sprintf("failed to process job: %s", constants.ArchiveJobFiles))) _, _, slackErr := slackClient.PostMessageContext(ctx, slackChannel, slack.MsgOptionText( From 486f0c4e5a5756c6c3180906cc8919c492583772 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Thu, 31 Jul 2025 01:58:40 -0400 Subject: [PATCH 18/28] Inject provider as dependency --- bcda/api/v1/api.go | 15 ++++---- bcda/api/v1/api_test.go | 10 +++-- bcda/auth/api.go | 14 +++++-- bcda/auth/api_test.go | 82 
++++++++++++++++++---------------------- bcda/auth/provider.go | 14 ++++++- bcda/auth/router.go | 7 ++-- bcda/auth/router_test.go | 5 ++- bcda/bcdacli/cli.go | 7 +++- bcda/main.go | 3 -- bcda/web/router.go | 14 +++---- bcda/web/router_test.go | 18 +++++---- 11 files changed, 103 insertions(+), 86 deletions(-) diff --git a/bcda/api/v1/api.go b/bcda/api/v1/api.go index 1854f0490..1a394a70e 100644 --- a/bcda/api/v1/api.go +++ b/bcda/api/v1/api.go @@ -28,12 +28,13 @@ import ( ) type ApiV1 struct { - handler *api.Handler connection *sql.DB + handler *api.Handler + provider auth.Provider healthChecker health.HealthChecker } -func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV1 { +func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) *ApiV1 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -47,7 +48,7 @@ func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV1 { hc := health.NewHealthChecker(connection) h := api.NewHandler(resources, "/v1/fhir", "v1", connection, pool) - return &ApiV1{handler: h, connection: connection, healthChecker: hc} + return &ApiV1{connection: connection, handler: h, provider: provider, healthChecker: hc} } /* @@ -368,7 +369,7 @@ Responses: 200: MetadataResponse */ -func Metadata(w http.ResponseWriter, r *http.Request) { +func (a ApiV1) Metadata(w http.ResponseWriter, r *http.Request) { dt := time.Now() scheme := "http" @@ -398,7 +399,7 @@ Responses: 200: VersionResponse */ -func GetVersion(w http.ResponseWriter, r *http.Request) { +func (a ApiV1) GetVersion(w http.ResponseWriter, r *http.Request) { respMap := make(map[string]string) respMap["version"] = constants.Version respBytes, err := json.Marshal(respMap) @@ -462,10 +463,10 @@ Responses: 200: AuthResponse */ -func GetAuthInfo(w http.ResponseWriter, r *http.Request) { +func (a ApiV1) GetAuthInfo(w http.ResponseWriter, r *http.Request) { respMap := make(map[string]string) respMap["auth_provider"] = auth.GetProviderName() - version, err := auth.GetProvider().GetVersion() + version, err := a.provider.GetVersion() if err == nil { respMap["version"] = version } else { diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index ee3f0fe91..428905644 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -49,12 +49,14 @@ type APITestSuite struct { rr *httptest.ResponseRecorder connection *sql.DB pool *pgxv5Pool.Pool + provider auth.Provider apiV1 *ApiV1 } func (s *APITestSuite) SetupSuite() { s.connection = database.GetConnection() - s.apiV1 = NewApiV1(s.connection, s.pool) + s.provider = auth.NewProvider(s.connection) + s.apiV1 = NewApiV1(s.connection, s.pool, s.provider) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -430,7 +432,7 @@ func (s *APITestSuite) TestMetadata() { req := httptest.NewRequest("GET", "/api/v1/metadata", nil) req.TLS = &tls.ConnectionState{} - handler := http.HandlerFunc(Metadata) + handler := http.HandlerFunc(s.apiV1.Metadata) handler.ServeHTTP(s.rr, req) assert.Equal(s.T(), http.StatusOK, s.rr.Code) @@ -439,7 +441,7 @@ func (s *APITestSuite) TestMetadata() { func (s *APITestSuite) TestGetVersion() { req := httptest.NewRequest("GET", "/_version", nil) - handler := http.HandlerFunc(GetVersion) + handler := http.HandlerFunc(s.apiV1.GetVersion) handler.ServeHTTP(s.rr, req) assert.Equal(s.T(), http.StatusOK, s.rr.Code) @@ -570,7 +572,7 @@ func (s *APITestSuite) TestHealthCheck() { func (s *APITestSuite) TestAuthInfo() { req, err := 
http.NewRequest("GET", "/_auth", nil) assert.Nil(s.T(), err) - handler := http.HandlerFunc(GetAuthInfo) + handler := http.HandlerFunc(s.apiV1.GetAuthInfo) handler.ServeHTTP(s.rr, req) assert.Equal(s.T(), http.StatusOK, s.rr.Code) diff --git a/bcda/auth/api.go b/bcda/auth/api.go index af155e3bc..ab1d072af 100644 --- a/bcda/auth/api.go +++ b/bcda/auth/api.go @@ -34,7 +34,15 @@ import ( 500: serverError */ -func GetAuthToken(w http.ResponseWriter, r *http.Request) { +type BaseApi struct { + provider Provider +} + +func NewBaseApi(provider Provider) BaseApi { + return BaseApi{provider: provider} +} + +func (a BaseApi) GetAuthToken(w http.ResponseWriter, r *http.Request) { ctxLogger := log.API.WithFields(logrus.Fields{"transaction_id": r.Context().Value(middleware.CtxTransactionKey)}) clientId, secret, ok := r.BasicAuth() @@ -44,7 +52,7 @@ func GetAuthToken(w http.ResponseWriter, r *http.Request) { return } - tokenInfo, err := GetProvider().MakeAccessToken(Credentials{ClientID: clientId, ClientSecret: secret}, r) + tokenInfo, err := a.provider.MakeAccessToken(Credentials{ClientID: clientId, ClientSecret: secret}, r) if err != nil { switch err.(type) { case *customErrors.RequestTimeoutError: @@ -102,7 +110,7 @@ Responses: 200: welcome 401: invalidCredentials */ -func Welcome(w http.ResponseWriter, r *http.Request) { +func (a BaseApi) Welcome(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"success":"Welcome to the Beneficiary Claims Data API!"}`)) } diff --git a/bcda/auth/api_test.go b/bcda/auth/api_test.go index fb43676fc..2373259a1 100644 --- a/bcda/auth/api_test.go +++ b/bcda/auth/api_test.go @@ -18,6 +18,7 @@ import ( "github.com/sirupsen/logrus/hooks/test" "github.com/CMSgov/bcda-app/bcda/constants" + customErrors "github.com/CMSgov/bcda-app/bcda/errors" "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/bcda/testUtils" bcdaLog "github.com/CMSgov/bcda-app/log" @@ -28,21 +29,13 @@ import ( "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/database" - customErrors "github.com/CMSgov/bcda-app/bcda/errors" "github.com/CMSgov/bcda-app/bcda/models" ) type AuthAPITestSuite struct { suite.Suite - rr *httptest.ResponseRecorder - db *sql.DB - r models.Repository - server *httptest.Server -} - -func (s *AuthAPITestSuite) CreateRouter() http.Handler { - r := auth.NewAuthRouter() - return r + db *sql.DB + r models.Repository } func (s *AuthAPITestSuite) SetupSuite() { @@ -50,25 +43,10 @@ func (s *AuthAPITestSuite) SetupSuite() { s.r = postgres.NewRepository(s.db) } -func (s *AuthAPITestSuite) SetupTest() { - s.rr = httptest.NewRecorder() - s.server = httptest.NewServer(s.CreateRouter()) -} - func (s *AuthAPITestSuite) TestGetAuthTokenErrorSwitchCases() { const errorHappened = "Error Happened!" 
const errMsg = "Error Message" - req, err := http.NewRequest("POST", fmt.Sprintf("%s/auth/token", s.server.URL), nil) - if err != nil { - assert.FailNow(s.T(), err.Error()) - } - //req.Header.Add("Authorization", fmt.Sprintf("Basic %s", tt.authHeader)) - req.Header.Add("Accept", constants.JsonContentType) - req.SetBasicAuth("good", "client") - - client := s.server.Client() - tests := []struct { ScenarioName string ErrorToReturn error @@ -87,12 +65,23 @@ func (s *AuthAPITestSuite) TestGetAuthTokenErrorSwitchCases() { //setup logging hook for log message assertion testLogger := test.NewLocal(testUtils.GetLogger(bcdaLog.API)) - s.T().Run(tt.ScenarioName, func(t *testing.T) { - //setup mocks - mockP := &auth.MockProvider{} - mockP.On("MakeAccessToken", auth.Credentials{ClientID: "good", ClientSecret: "client"}, mock.Anything).Return("", tt.ErrorToReturn) - auth.SetMockProvider(s.T(), mockP) + //setup mocks + mockP := &auth.MockProvider{} + mockP.On("MakeAccessToken", auth.Credentials{ClientID: "good", ClientSecret: "client"}, mock.Anything).Return("", tt.ErrorToReturn) + router := auth.NewAuthRouter(mockP) + server := httptest.NewServer(router) + req, err := http.NewRequest("POST", fmt.Sprintf("%s/auth/token", server.URL), nil) + if err != nil { + assert.FailNow(s.T(), err.Error()) + } + //req.Header.Add("Authorization", fmt.Sprintf("Basic %s", tt.authHeader)) + req.Header.Add("Accept", constants.JsonContentType) + req.SetBasicAuth("good", "client") + + client := server.Client() + + s.T().Run(tt.ScenarioName, func(t *testing.T) { //Act resp, err := client.Do(req) if err != nil { @@ -115,15 +104,6 @@ func (s *AuthAPITestSuite) TestGetAuthTokenErrorSwitchCases() { } func (s *AuthAPITestSuite) TestGetAuthToken() { - req, err := http.NewRequest("POST", fmt.Sprintf("%s/auth/token", s.server.URL), nil) - if err != nil { - assert.FailNow(s.T(), err.Error()) - } - req.Header.Add("Accept", constants.JsonContentType) - req.SetBasicAuth("good", "client") - - client := s.server.Client() - tests := []struct { ScenarioName string ErrorToReturn error @@ -134,15 +114,24 @@ func (s *AuthAPITestSuite) TestGetAuthToken() { } for _, tt := range tests { - s.T().Run(tt.ScenarioName, func(t *testing.T) { + //setup mocks + mockP := &auth.MockProvider{} + mockP.On("MakeAccessToken", auth.Credentials{ClientID: "good", ClientSecret: "client"}, mock.Anything).Return(fmt.Sprintf(`{ "token_type": "bearer", "access_token": "goodToken", "expires_in": "%s" }`, constants.ExpiresInDefault), tt.ErrorToReturn) + // auth.SetMockProvider(s.T(), mockP) - //setup mocks - mockP := &auth.MockProvider{} - mockP.On("MakeAccessToken", auth.Credentials{ClientID: "good", ClientSecret: "client"}, mock.Anything).Return(fmt.Sprintf(`{ "token_type": "bearer", "access_token": "goodToken", "expires_in": "%s" }`, constants.ExpiresInDefault), tt.ErrorToReturn) - auth.SetMockProvider(s.T(), mockP) + router := auth.NewAuthRouter(mockP) + server := httptest.NewServer(router) + req, err := http.NewRequest("POST", fmt.Sprintf("%s/auth/token", server.URL), nil) + if err != nil { + assert.FailNow(s.T(), err.Error()) + } + req.Header.Add("Accept", constants.JsonContentType) + req.SetBasicAuth("good", "client") + + s.T().Run(tt.ScenarioName, func(t *testing.T) { //Act - resp, err := client.Do(req) + resp, err := server.Client().Do(req) if err != nil { log.Fatal(err) } @@ -188,7 +177,8 @@ func (s *AuthAPITestSuite) TestWelcome() { // Expect failure with invalid token router := chi.NewRouter() router.Use(auth.ParseToken) - 
router.With(auth.RequireTokenAuth).Get("/v1/", auth.Welcome) + baseApi := auth.NewBaseApi(mockP) + router.With(auth.RequireTokenAuth).Get("/v1/", baseApi.Welcome) server := httptest.NewServer(router) client := server.Client() req, err := http.NewRequest("GET", fmt.Sprintf("%s/v1/", server.URL), nil) diff --git a/bcda/auth/provider.go b/bcda/auth/provider.go index 8fb2bf6e8..ea256145c 100644 --- a/bcda/auth/provider.go +++ b/bcda/auth/provider.go @@ -2,6 +2,7 @@ package auth import ( "context" + "database/sql" "net/http" "time" @@ -18,7 +19,6 @@ const ( SSAS = "ssas" ) -var providerName = SSAS var repository models.Repository var provider Provider @@ -34,13 +34,23 @@ func init() { } func GetProviderName() string { - return providerName + return SSAS } func GetProvider() Provider { return provider } +func NewProvider(db *sql.DB) Provider { + r := postgres.NewRepository(db) + c, err := client.NewSSASClient() + if err != nil { + log.Auth.Errorf("no client for SSAS. no provider set; %s", err.Error()) + } + + return SSASPlugin{client: c, repository: r} +} + type AuthData struct { ACOID string TokenID string diff --git a/bcda/auth/router.go b/bcda/auth/router.go index dda911327..717627843 100644 --- a/bcda/auth/router.go +++ b/bcda/auth/router.go @@ -7,11 +7,12 @@ import ( "github.com/go-chi/chi/v5" ) -func NewAuthRouter(middlewares ...func(http.Handler) http.Handler) http.Handler { +func NewAuthRouter(provider Provider, middlewares ...func(http.Handler) http.Handler) http.Handler { + baseApi := NewBaseApi(provider) r := chi.NewRouter() m := monitoring.GetMonitor() r.Use(middlewares...) - r.Post(m.WrapHandler("/auth/token", GetAuthToken)) - r.With(ParseToken, RequireTokenAuth, CheckBlacklist).Get(m.WrapHandler("/auth/welcome", Welcome)) + r.Post(m.WrapHandler("/auth/token", baseApi.GetAuthToken)) + r.With(ParseToken, RequireTokenAuth, CheckBlacklist).Get(m.WrapHandler("/auth/welcome", baseApi.Welcome)) return r } diff --git a/bcda/auth/router_test.go b/bcda/auth/router_test.go index 70285690a..3c1d2c834 100644 --- a/bcda/auth/router_test.go +++ b/bcda/auth/router_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "strings" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/conf" "testing" @@ -16,12 +17,14 @@ import ( type AuthRouterTestSuite struct { suite.Suite + provider Provider authRouter http.Handler } func (s *AuthRouterTestSuite) SetupTest() { conf.SetEnv(s.T(), "DEBUG", "true") - s.authRouter = NewAuthRouter() + s.provider = NewProvider(database.GetConnection()) + s.authRouter = NewAuthRouter(s.provider) } func (s *AuthRouterTestSuite) reqAuthRoute(verb string, route string, body io.Reader) *http.Response { diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index 39d651911..6c9269626 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -51,6 +51,7 @@ var ( connection *sql.DB pool *pgxv5Pool.Pool r models.Repository + provider auth.Provider ) func GetApp() *cli.App { @@ -66,6 +67,8 @@ func setUpApp() *cli.App { connection = database.GetConnection() pool = database.GetPool() r = postgres.NewRepository(connection) + provider = auth.NewProvider(connection) + log.API.Info(fmt.Sprintf(`Auth is made possible by %T`, auth.GetProvider())) return nil } var hours, err = safecast.ToUint(utils.GetEnvInt("FILE_ARCHIVE_THRESHOLD_HR", 72)) @@ -117,7 +120,7 @@ func setUpApp() *cli.App { go func() { log.API.Fatal(srv.ListenAndServe()) }() auth := &http.Server{ - Handler: web.NewAuthRouter(), + Handler: web.NewAuthRouter(provider), ReadTimeout: 
time.Duration(utils.GetEnvInt("API_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("API_WRITE_TIMEOUT", 20)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("API_IDLE_TIMEOUT", 120)) * time.Second, @@ -125,7 +128,7 @@ func setUpApp() *cli.App { } api := &http.Server{ - Handler: web.NewAPIRouter(connection, pool), + Handler: web.NewAPIRouter(connection, pool, provider), ReadTimeout: time.Duration(utils.GetEnvInt("API_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("API_WRITE_TIMEOUT", 20)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("API_IDLE_TIMEOUT", 120)) * time.Second, diff --git a/bcda/main.go b/bcda/main.go index b3f5f753d..0a490646b 100644 --- a/bcda/main.go +++ b/bcda/main.go @@ -40,7 +40,6 @@ import ( "github.com/pkg/errors" - "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/bcdacli" "github.com/CMSgov/bcda-app/bcda/client" "github.com/CMSgov/bcda-app/bcda/monitoring" @@ -61,11 +60,9 @@ func init() { if isEtlMode != "true" { log.API.Info("BCDA application is running in API mode.") monitoring.GetMonitor() - log.API.Info(fmt.Sprintf(`Auth is made possible by %T`, auth.GetProvider())) } else { log.API.Info("BCDA application is running in ETL mode.") } - } func createAPIDirs() { diff --git a/bcda/web/router.go b/bcda/web/router.go index 9dad7ba58..39d9517b0 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -29,7 +29,7 @@ var commonAuth = []func(http.Handler) http.Handler{ auth.RequireTokenAuth, auth.CheckBlacklist} -func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { +func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() r.Use(gcmw.RequestID, appMiddleware.NewTransactionID, auth.ParseToken, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) @@ -54,7 +54,7 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { r.Get("/", userGuideRedirect) r.Get(`/{:(user_guide|encryption|decryption_walkthrough).html}`, userGuideRedirect) } - apiV1 := v1.NewApiV1(connection, pool) + apiV1 := v1.NewApiV1(connection, pool, provider) r.Route("/api/v1", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) @@ -62,7 +62,7 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV1.JobsStatus)) r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV1.AttributionStatus)) - r.Get(m.WrapHandler("/metadata", v1.Metadata)) + r.Get(m.WrapHandler("/metadata", apiV1.Metadata)) }) if utils.GetEnvBool("VERSION_2_ENDPOINT_ACTIVE", true) { @@ -92,14 +92,14 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool) http.Handler { }) } - r.Get(m.WrapHandler("/_version", v1.GetVersion)) + r.Get(m.WrapHandler("/_version", apiV1.GetVersion)) r.Get(m.WrapHandler("/_health", apiV1.HealthCheck)) - r.Get(m.WrapHandler("/_auth", v1.GetAuthInfo)) + r.Get(m.WrapHandler("/_auth", apiV1.GetAuthInfo)) return r } -func NewAuthRouter() http.Handler { - return 
auth.NewAuthRouter(gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) +func NewAuthRouter(provider auth.Provider) http.Handler { + return auth.NewAuthRouter(provider, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) } func NewDataRouter(connection *sql.DB) http.Handler { diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index d2797cea2..30ce47ae1 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -33,6 +33,7 @@ type RouterTestSuite struct { suite.Suite apiRouter http.Handler dataRouter http.Handler + provider auth.Provider connection *sql.DB pool *pgxv5Pool.Pool } @@ -41,7 +42,8 @@ func (s *RouterTestSuite) SetupTest() { conf.SetEnv(s.T(), "DEBUG", "true") conf.SetEnv(s.T(), "BB_SERVER_LOCATION", "v1-server-location") s.connection = database.GetConnection() - s.apiRouter = NewAPIRouter(s.connection, s.pool) + s.provider = auth.NewProvider(s.connection) + s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) s.dataRouter = NewDataRouter(s.connection) } @@ -82,7 +84,7 @@ func (s *RouterTestSuite) TestDefaultProdRoute() { s.FailNow("err in setting env var", err) } // Need a new router because the one in the test setup does not use the environment variable set in this test. - s.apiRouter = NewAPIRouter(s.connection, s.pool) + s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) res := s.getAPIRoute("/v1/") assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -199,7 +201,7 @@ func (s *RouterTestSuite) TestV2EndpointsDisabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(s.connection, s.pool) + s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -216,7 +218,7 @@ func (s *RouterTestSuite) TestV2EndpointsEnabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(s.connection, s.pool) + s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -237,7 +239,7 @@ func (s *RouterTestSuite) TestV3EndpointsDisabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(s.connection, s.pool) + s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -254,7 +256,7 @@ func (s *RouterTestSuite) TestV3EndpointsEnabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(s.connection, s.pool) + s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), 
http.StatusUnauthorized, res.StatusCode) @@ -354,7 +356,7 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []str handler http.Handler paths []string }) { - apiRouter := NewAPIRouter(s.connection, s.pool) + apiRouter := NewAPIRouter(s.connection, s.pool, s.provider) configs = []struct { handler http.Handler @@ -364,7 +366,7 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []str constants.V2Path + constants.PatientExportPath, constants.V2Path + constants.GroupExportPath, constants.V1Path + constants.JobsFilePath}}, {s.dataRouter, []string{nDJsonDataRoute}}, - {NewAuthRouter(), []string{"/auth/welcome"}}, + {NewAuthRouter(s.provider), []string{"/auth/welcome"}}, } return configs From c896f87180203d0079acf45efc0534f6913fb18b Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Mon, 4 Aug 2025 19:07:06 -0400 Subject: [PATCH 19/28] Refactor provider-related tests --- bcda/api/v1/api_test.go | 3 +- bcda/auth/api_test.go | 4 +- bcda/auth/middleware.go | 20 ++- bcda/auth/middleware_test.go | 114 +++++++++++------- bcda/auth/provider.go | 20 --- bcda/auth/providertest.go | 16 --- bcda/auth/router.go | 3 +- bcda/auth/ssas_middleware_test.go | 4 +- bcda/auth/ssas_test.go | 11 +- bcda/bcdacli/cli.go | 31 +++-- bcda/bcdacli/cli_test.go | 59 ++++----- bcda/lambda/admin_create_aco_creds/main.go | 7 +- .../admin_create_aco_creds/main_test.go | 7 +- bcda/web/router.go | 22 ++-- bcda/web/router_test.go | 15 ++- 15 files changed, 176 insertions(+), 160 deletions(-) delete mode 100644 bcda/auth/providertest.go diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index 428905644..fe845b310 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -462,8 +462,9 @@ func (s *APITestSuite) TestJobStatusWithWrongACO() { Status: models.JobStatusPending, } postgrestest.CreateJobs(s.T(), s.connection, &j) + am := auth.NewAuthMiddleware(s.provider) - handler := auth.RequireTokenJobMatch(s.connection)(http.HandlerFunc(s.apiV1.JobStatus)) + handler := am.RequireTokenJobMatch(s.connection)(http.HandlerFunc(s.apiV1.JobStatus)) req := s.createJobStatusRequest(uuid.Parse(constants.LargeACOUUID), j.ID) handler.ServeHTTP(s.rr, req) diff --git a/bcda/auth/api_test.go b/bcda/auth/api_test.go index 2373259a1..66b8c4311 100644 --- a/bcda/auth/api_test.go +++ b/bcda/auth/api_test.go @@ -172,11 +172,11 @@ func (s *AuthAPITestSuite) TestWelcome() { mockP.On("VerifyToken", mock.Anything, goodToken).Return(token, nil) mockP.On("VerifyToken", mock.Anything, badToken).Return(nil, errors.New("bad token")) mockP.On("getAuthDataFromClaims", token.Claims).Return(ad, nil) - auth.SetMockProvider(s.T(), mockP) // Expect failure with invalid token router := chi.NewRouter() - router.Use(auth.ParseToken) + am := auth.NewAuthMiddleware(mockP) + router.Use(am.ParseToken) baseApi := auth.NewBaseApi(mockP) router.With(auth.RequireTokenAuth).Get("/v1/", baseApi.Welcome) server := httptest.NewServer(router) diff --git a/bcda/auth/middleware.go b/bcda/auth/middleware.go index 103ec6def..df82df0e4 100644 --- a/bcda/auth/middleware.go +++ b/bcda/auth/middleware.go @@ -31,13 +31,21 @@ var ( AuthDataContextKey = &contextKey{"ad"} ) +type AuthMiddleware struct { + provider Provider +} + +func NewAuthMiddleware(provider Provider) AuthMiddleware { + return AuthMiddleware{provider: provider} +} + // ParseToken puts the decoded token and AuthData value into the request context. Decoded values come from // tokens verified by our provider as correct and unexpired. 
Tokens may be presented in requests to // unauthenticated endpoints (mostly swagger?). We still want to extract the token data for logging purposes, // even when we don't use it for authorization. Authorization for protected endpoints occurs in RequireTokenAuth(). // Only auth code should look at the token claims; API code should rely on the values in AuthData. We use AuthData // to insulate API code from the differences among Provider tokens. -func ParseToken(next http.Handler) http.Handler { +func (m AuthMiddleware) ParseToken(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // ParseToken is called on every request, but not every request has a token @@ -60,7 +68,7 @@ func ParseToken(next http.Handler) http.Handler { tokenString := authSubmatches[1] - token, ad, err := AuthorizeAccess(r.Context(), tokenString) + token, ad, err := m.AuthorizeAccess(r.Context(), tokenString) if err != nil { handleTokenVerificationError(log.NewStructuredLoggerEntry(log.Auth, r.Context()), w, rw, err) return @@ -73,10 +81,10 @@ func ParseToken(next http.Handler) http.Handler { } // AuthorizeAccess asserts that a base64 encoded token string is valid for accessing the BCDA API. -func AuthorizeAccess(ctx context.Context, tokenString string) (*jwt.Token, AuthData, error) { +func (m AuthMiddleware) AuthorizeAccess(ctx context.Context, tokenString string) (*jwt.Token, AuthData, error) { tknEvent := event{op: "AuthorizeAccess"} operationStarted(tknEvent) - token, err := GetProvider().VerifyToken(ctx, tokenString) + token, err := m.provider.VerifyToken(ctx, tokenString) var ad AuthData @@ -92,7 +100,7 @@ func AuthorizeAccess(ctx context.Context, tokenString string) (*jwt.Token, AuthD return nil, ad, errors.New("invalid ssas claims") } - ad, err = GetProvider().getAuthDataFromClaims(claims) + ad, err = m.provider.getAuthDataFromClaims(claims) if err != nil { tknEvent.help = fmt.Sprintf("failed getting AuthData; %s", err.Error()) operationFailed(tknEvent) @@ -163,7 +171,7 @@ func CheckBlacklist(next http.Handler) http.Handler { }) } -func RequireTokenJobMatch(connection *sql.DB) func(next http.Handler) http.Handler { +func (m AuthMiddleware) RequireTokenJobMatch(connection *sql.DB) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { rw := getRespWriter(r.URL.Path) diff --git a/bcda/auth/middleware_test.go b/bcda/auth/middleware_test.go index e14677391..d93805406 100644 --- a/bcda/auth/middleware_test.go +++ b/bcda/auth/middleware_test.go @@ -40,8 +40,9 @@ var bearerStringMsg string = "Bearer %s" type MiddlewareTestSuite struct { suite.Suite - server *httptest.Server - rr *httptest.ResponseRecorder + // server *httptest.Server + rr *httptest.ResponseRecorder + // am auth.AuthMiddleware connection *sql.DB } @@ -49,9 +50,10 @@ func (s *MiddlewareTestSuite) SetupSuite() { s.connection = database.GetConnection() } -func (s *MiddlewareTestSuite) CreateRouter() http.Handler { +func (s *MiddlewareTestSuite) CreateRouter(p auth.Provider) http.Handler { + am := auth.NewAuthMiddleware(p) router := chi.NewRouter() - router.Use(auth.ParseToken) + router.Use(am.ParseToken) router.With(auth.RequireTokenAuth).Get("/v1/", func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("Test router")) if err != nil { @@ -62,21 +64,27 @@ func (s *MiddlewareTestSuite) CreateRouter() http.Handler { return router } +func (s *MiddlewareTestSuite) CreateServer(p auth.Provider) *httptest.Server { 
+ return httptest.NewServer(s.CreateRouter(p)) +} + func (s *MiddlewareTestSuite) SetupTest() { - s.server = httptest.NewServer(s.CreateRouter()) + // s.server = httptest.NewServer(s.CreateRouter(auth.NewProvider(s.connection))) s.rr = httptest.NewRecorder() } -func (s *MiddlewareTestSuite) TearDownTest() { - s.server.Close() -} +// func (s *MiddlewareTestSuite) TearDownTest() { +// s.server.Close() +// } // integration test: makes HTTP request & asserts HTTP response func (s *MiddlewareTestSuite) TestReturn400WhenInvalidTokenAuthWithInvalidSignature() { - client := s.server.Client() + server := s.CreateServer(auth.NewProvider(s.connection)) + defer server.Close() + client := server.Client() badT := "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCIsImtpZCI6ImlUcVhYSTB6YkFuSkNLRGFvYmZoa00xZi02ck1TcFRmeVpNUnBfMnRLSTgifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.cJOP_w-hBqnyTsBm3T6lOE5WpcHaAkLuQGAs1QO-lg2eWs8yyGW8p9WagGjxgvx7h9X72H7pXmXqej3GdlVbFmhuzj45A9SXDOAHZ7bJXwM1VidcPi7ZcrsMSCtP1hiN" - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } @@ -92,7 +100,9 @@ func (s *MiddlewareTestSuite) TestReturn400WhenInvalidTokenAuthWithInvalidSignat // integration test: makes HTTP request & asserts HTTP response func (s *MiddlewareTestSuite) TestReturn401WhenExpiredToken() { - client := s.server.Client() + server := s.CreateServer(auth.NewProvider(s.connection)) + defer server.Close() + client := server.Client() expiredToken := jwt.NewWithClaims(jwt.SigningMethodRS512, &auth.CommonClaims{ StandardClaims: jwt.StandardClaims{ Issuer: "ssas", @@ -105,7 +115,7 @@ func (s *MiddlewareTestSuite) TestReturn401WhenExpiredToken() { pk, _ := rsa.GenerateKey(rand.Reader, 2048) tokenString, _ := expiredToken.SignedString(pk) - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } @@ -147,12 +157,13 @@ func (s *MiddlewareTestSuite) TestAuthMiddlewareReturnResponse200WhenValidBearer mockP := &auth.MockProvider{} mockP.On("VerifyToken", mock.Anything, bearerString).Return(token, nil) mockP.On("getAuthDataFromClaims", token.Claims).Return(authData, nil) - auth.SetMockProvider(s.T(), mockP) - client := s.server.Client() + server := s.CreateServer(mockP) + defer server.Close() + client := server.Client() // Valid token should return a 200 response - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } @@ -198,14 +209,6 @@ func (s *MiddlewareTestSuite) TestTokenVerificationErrorHandling() { const errorHappened = "Error Happened!" 
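The same injection removes the need for a package-level mock hook in tests; a condensed sketch of the new per-test pattern, with an invented test name and token string:

    package auth_test

    import (
        "errors"
        "net/http"
        "net/http/httptest"
        "testing"

        "github.com/go-chi/chi/v5"
        "github.com/stretchr/testify/mock"

        "github.com/CMSgov/bcda-app/bcda/auth"
    )

    func TestRejectsBadToken(t *testing.T) {
        // The mock provider is scoped to this test; no shared provider to swap and restore.
        mockP := &auth.MockProvider{}
        mockP.On("VerifyToken", mock.Anything, "bad-token").Return(nil, errors.New("bad token"))

        am := auth.NewAuthMiddleware(mockP)
        router := chi.NewRouter()
        router.Use(am.ParseToken)
        router.With(auth.RequireTokenAuth).Get("/v1/", func(w http.ResponseWriter, _ *http.Request) {
            w.WriteHeader(http.StatusOK)
        })

        server := httptest.NewServer(router)
        defer server.Close()

        req, err := http.NewRequest("GET", server.URL+"/v1/", nil)
        if err != nil {
            t.Fatal(err)
        }
        req.Header.Add("Authorization", "Bearer bad-token")

        resp, err := server.Client().Do(req)
        if err != nil {
            t.Fatal(err)
        }
        defer resp.Body.Close()

        // VerifyToken fails, so the middleware must not let the request reach the handler.
        if resp.StatusCode == http.StatusOK {
            t.Fatalf("expected an auth error status, got %d", resp.StatusCode)
        }
    }
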
const errMsg = "Error Message" - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) - if err != nil { - log.Fatal(err) - } - req.Header.Add("Authorization", fmt.Sprintf(bearerStringMsg, bearerString)) - - client := s.server.Client() - tests := []struct { ScenarioName string ErrorToReturn error @@ -228,7 +231,15 @@ func (s *MiddlewareTestSuite) TestTokenVerificationErrorHandling() { //setup mocks mockP := &auth.MockProvider{} mockP.On("VerifyToken", mock.Anything, bearerString).Return(nil, tt.ErrorToReturn) - auth.SetMockProvider(s.T(), mockP) + server := s.CreateServer(mockP) + defer server.Close() + client := server.Client() + + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) + if err != nil { + log.Fatal(err) + } + req.Header.Add("Authorization", fmt.Sprintf(bearerStringMsg, bearerString)) //Act resp, err := client.Do(req) @@ -257,19 +268,18 @@ func (s *MiddlewareTestSuite) TestAuthMiddlewareReturnResponse403WhenEntityNotFo mockP := &auth.MockProvider{} mockP.On("VerifyToken", mock.Anything, bearerString).Return(token, nil) mockP.On("getAuthDataFromClaims", token.Claims).Return(authData, entityNotFoundError) - auth.SetMockProvider(s.T(), mockP) + + server := s.CreateServer(mockP) + client := server.Client() + s.rr = httptest.NewRecorder() //fill http request - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } - req.Header.Add("Authorization", fmt.Sprintf(bearerStringMsg, bearerString)) - client := s.server.Client() - s.rr = httptest.NewRecorder() - //Act resp, err := client.Do(req) if err != nil { @@ -284,7 +294,6 @@ func (s *MiddlewareTestSuite) TestAuthMiddlewareReturnResponse403WhenEntityNotFo } func (s *MiddlewareTestSuite) TestAuthMiddlewareReturn401WhenNonEntityNotFoundError() { - bearerString, authData, token, _ := setupDataForAuthMiddlewareTest() //custom error expected @@ -294,18 +303,19 @@ func (s *MiddlewareTestSuite) TestAuthMiddlewareReturn401WhenNonEntityNotFoundEr mockP := &auth.MockProvider{} mockP.On("VerifyToken", mock.Anything, bearerString).Return(token, nil) mockP.On("getAuthDataFromClaims", token.Claims).Return(authData, thrownErr) - auth.SetMockProvider(s.T(), mockP) + + server := s.CreateServer(mockP) + defer server.Close() + client := server.Client() //fill http request - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } req.Header.Add("Authorization", fmt.Sprintf(bearerStringMsg, bearerString)) - client := s.server.Client() - //Act resp, err := client.Do(req) if err != nil { @@ -320,9 +330,11 @@ func (s *MiddlewareTestSuite) TestAuthMiddlewareReturn401WhenNonEntityNotFoundEr // integration test: makes HTTP request & asserts HTTP response func (s *MiddlewareTestSuite) TestAuthMiddlewareReturnResponse401WhenNoBearerTokenSupplied() { - client := s.server.Client() + server := s.CreateServer(auth.NewProvider(s.connection)) + defer server.Close() + client := server.Client() - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } @@ -363,7 +375,11 @@ func (s *MiddlewareTestSuite) 
TestRequireTokenJobMatchReturn404WhenMismatchingDa {"Mismatching ACOID", jobID, uuid.New(), http.StatusUnauthorized}, } - handler := auth.RequireTokenJobMatch(s.connection)(mockHandler) + p := auth.NewProvider(s.connection) + am := auth.NewAuthMiddleware(p) + handler := am.RequireTokenJobMatch(s.connection)(mockHandler) + server := s.CreateServer(p) + defer server.Close() for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { @@ -372,7 +388,7 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenMismatchingDa rctx := chi.NewRouteContext() rctx.URLParams.Add("jobID", tt.jobID) - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) assert.NoError(t, err) ad := auth.AuthData{ @@ -402,7 +418,13 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn200WhenCorrectAccoun } jobID := strconv.Itoa(id) - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + p := auth.NewProvider(s.connection) + am := auth.NewAuthMiddleware(p) + handler := am.RequireTokenJobMatch(s.connection)(mockHandler) + server := s.CreateServer(p) + defer server.Close() + + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } @@ -410,8 +432,6 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn200WhenCorrectAccoun rctx := chi.NewRouteContext() rctx.URLParams.Add("jobID", jobID) - handler := auth.RequireTokenJobMatch(s.connection)(mockHandler) - ad := auth.AuthData{ ACOID: j.ACOID.String(), TokenID: uuid.New(), @@ -438,16 +458,20 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenNoAuthDataPro } jobID := strconv.Itoa(id) + p := auth.NewProvider(s.connection) + am := auth.NewAuthMiddleware(p) + handler := am.RequireTokenJobMatch(s.connection)(mockHandler) + server := s.CreateServer(p) + defer server.Close() + rctx := chi.NewRouteContext() rctx.URLParams.Add("jobID", jobID) - req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, s.server.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf(constants.ServerPath, server.URL), nil) if err != nil { log.Fatal(err) } - handler := auth.RequireTokenJobMatch(s.connection)(mockHandler) - handler.ServeHTTP(s.rr, req) assert.Equal(s.T(), http.StatusUnauthorized, s.rr.Code) } diff --git a/bcda/auth/provider.go b/bcda/auth/provider.go index ea256145c..6f1ce7aff 100644 --- a/bcda/auth/provider.go +++ b/bcda/auth/provider.go @@ -9,8 +9,6 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/CMSgov/bcda-app/bcda/auth/client" - "github.com/CMSgov/bcda-app/bcda/database" - "github.com/CMSgov/bcda-app/bcda/models" "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/log" ) @@ -19,28 +17,10 @@ const ( SSAS = "ssas" ) -var repository models.Repository -var provider Provider - -func init() { - repository = postgres.NewRepository(database.Connection) - - c, err := client.NewSSASClient() - if err != nil { - log.Auth.Errorf("no client for SSAS. 
no provider set; %s", err.Error()) - } - provider = SSASPlugin{client: c, repository: repository} - -} - func GetProviderName() string { return SSAS } -func GetProvider() Provider { - return provider -} - func NewProvider(db *sql.DB) Provider { r := postgres.NewRepository(db) c, err := client.NewSSASClient() diff --git a/bcda/auth/providertest.go b/bcda/auth/providertest.go deleted file mode 100644 index 11444ba81..000000000 --- a/bcda/auth/providertest.go +++ /dev/null @@ -1,16 +0,0 @@ -package auth - -import ( - "testing" -) - -// SetMockProvider sets the current provider to the one that's supplied in this function. -// It leverages the Cleanup() func to ensure the original provider is restored at the end of the test. -func SetMockProvider(t *testing.T, other *MockProvider) { - // Ensure that we restore the original provider when the test completes - originalProvider := provider - t.Cleanup(func() { - provider = originalProvider - }) - provider = other -} diff --git a/bcda/auth/router.go b/bcda/auth/router.go index 717627843..255cf474a 100644 --- a/bcda/auth/router.go +++ b/bcda/auth/router.go @@ -11,8 +11,9 @@ func NewAuthRouter(provider Provider, middlewares ...func(http.Handler) http.Han baseApi := NewBaseApi(provider) r := chi.NewRouter() m := monitoring.GetMonitor() + am := NewAuthMiddleware(provider) r.Use(middlewares...) r.Post(m.WrapHandler("/auth/token", baseApi.GetAuthToken)) - r.With(ParseToken, RequireTokenAuth, CheckBlacklist).Get(m.WrapHandler("/auth/welcome", baseApi.Welcome)) + r.With(am.ParseToken, RequireTokenAuth, CheckBlacklist).Get(m.WrapHandler("/auth/welcome", baseApi.Welcome)) return r } diff --git a/bcda/auth/ssas_middleware_test.go b/bcda/auth/ssas_middleware_test.go index 35a5af2a1..a0f5c4861 100644 --- a/bcda/auth/ssas_middleware_test.go +++ b/bcda/auth/ssas_middleware_test.go @@ -17,6 +17,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/auth" "github.com/CMSgov/bcda-app/bcda/constants" + "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/conf" ) @@ -37,7 +38,8 @@ type SSASMiddlewareTestSuite struct { func (s *SSASMiddlewareTestSuite) createRouter() http.Handler { router := chi.NewRouter() - router.Use(auth.ParseToken) + am := auth.NewAuthMiddleware(auth.NewProvider(database.GetConnection())) + router.Use(am.ParseToken) router.With(auth.RequireTokenAuth).Get("/v1/", func(w http.ResponseWriter, r *http.Request) { ad := r.Context().Value(auth.AuthDataContextKey).(auth.AuthData) render.JSON(w, r, ad) diff --git a/bcda/auth/ssas_test.go b/bcda/auth/ssas_test.go index 4c0b04418..650c24ed9 100644 --- a/bcda/auth/ssas_test.go +++ b/bcda/auth/ssas_test.go @@ -287,6 +287,7 @@ func (s *SSASPluginTestSuite) TestRevokeAccessToken() { } func (s *SSASPluginTestSuite) TestAuthorizeAccessErrIsNilWhenHappyPath() { + am := NewAuthMiddleware(NewProvider(s.db)) _, tokenString, _, err := MockSSASToken() require.NotNil(s.T(), tokenString, sSasTErrorMsg, err) require.Nil(s.T(), err, unexpectedErrorMsg, err) @@ -295,7 +296,7 @@ func (s *SSASPluginTestSuite) TestAuthorizeAccessErrIsNilWhenHappyPath() { c, err := client.NewSSASClient() require.NotNil(s.T(), c, sSasClientErrorMsg, err) s.p = SSASPlugin{client: c, repository: s.r} - _, _, err = AuthorizeAccess(context.Background(), tokenString) + _, _, err = am.AuthorizeAccess(context.Background(), tokenString) require.Nil(s.T(), err) } @@ -309,8 +310,10 @@ func (s *SSASPluginTestSuite) TestAuthorizeAccessErrISReturnedWhenVerifyTokenChe require.NotNil(s.T(), c, sSasClientErrorMsg, err) s.p = SSASPlugin{client: c, 
repository: s.r} + am := NewAuthMiddleware(NewProvider(s.db)) + invalidTokenString := "" - _, _, err = AuthorizeAccess(context.Background(), invalidTokenString) + _, _, err = am.AuthorizeAccess(context.Background(), invalidTokenString) assert.EqualError(s.T(), err, "Requestor Data Error encountered - unable to parse provided tokenString to jwt.token. Err: token contains an invalid number of segments") } @@ -377,11 +380,13 @@ func (s *SSASPluginTestSuite) TestAuthorizeAccessErrIsReturnedWhenGetAuthDataFro MockSSASServer(ts) + am := NewAuthMiddleware(NewProvider(s.db)) + c, err := client.NewSSASClient() require.NotNil(s.T(), c, sSasClientErrorMsg, err) s.p = SSASPlugin{client: c, repository: s.r} - _, _, err = AuthorizeAccess(context.Background(), ts) + _, _, err = am.AuthorizeAccess(context.Background(), ts) assert.EqualError(s.T(), err, "can't decode data claim ac; invalid character 'a' looking for beginning of value") } diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index 6c9269626..94beceb9d 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -68,7 +68,7 @@ func setUpApp() *cli.App { pool = database.GetPool() r = postgres.NewRepository(connection) provider = auth.NewProvider(connection) - log.API.Info(fmt.Sprintf(`Auth is made possible by %T`, auth.GetProvider())) + log.API.Info(fmt.Sprintf(`Auth is made possible by %T`, provider)) return nil } var hours, err = safecast.ToUint(utils.GetEnvInt("FILE_ARCHIVE_THRESHOLD_HR", 72)) @@ -136,7 +136,7 @@ func setUpApp() *cli.App { } fileserver := &http.Server{ - Handler: web.NewDataRouter(connection), + Handler: web.NewDataRouter(connection, provider), ReadTimeout: time.Duration(utils.GetEnvInt("FILESERVER_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("FILESERVER_WRITE_TIMEOUT", 360)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("FILESERVER_IDLE_TIMEOUT", 120)) * time.Second, @@ -272,17 +272,10 @@ func setUpApp() *cli.App { }, }, Action: func(c *cli.Context) error { - aco, err := r.GetACOByCMSID(context.Background(), acoCMSID) + msg, err := resetClientCredentials(acoCMSID) if err != nil { return err } - - // Generate new credentials - creds, err := auth.GetProvider().ResetSecret(aco.ClientID) - if err != nil { - return err - } - msg := fmt.Sprintf("%s\n%s\n%s", creds.ClientName, creds.ClientID, creds.ClientSecret) fmt.Fprintf(app.Writer, "%s\n", msg) return nil }, @@ -588,7 +581,7 @@ func createACO(name, cmsID string) (string, error) { func generateClientCredentials(acoCMSID string, ips []string) (string, error) { // The public key is optional for SSAS, and not used by the ACO API - creds, err := auth.GetProvider().FindAndCreateACOCredentials(acoCMSID, ips) + creds, err := provider.FindAndCreateACOCredentials(acoCMSID, ips) if err != nil { return "", errors.Wrapf(err, "could not register system for %s", acoCMSID) } @@ -596,12 +589,26 @@ func generateClientCredentials(acoCMSID string, ips []string) (string, error) { return creds, nil } +func resetClientCredentials(acoCMSID string) (string, error) { + aco, err := r.GetACOByCMSID(context.Background(), acoCMSID) + if err != nil { + return "", err + } + + // Generate new credentials + creds, err := provider.ResetSecret(aco.ClientID) + if err != nil { + return "", err + } + return fmt.Sprintf("%s\n%s\n%s", creds.ClientName, creds.ClientID, creds.ClientSecret), nil +} + func revokeAccessToken(accessToken string) error { if accessToken == "" { return errors.New("Access token (--access-token) must be provided") } - return 
auth.GetProvider().RevokeAccessToken(accessToken) + return provider.RevokeAccessToken(accessToken) } func setDenylistState(cmsID string, td *models.Termination) error { diff --git a/bcda/bcdacli/cli_test.go b/bcda/bcdacli/cli_test.go index 2f315382c..c6d83dc05 100644 --- a/bcda/bcdacli/cli_test.go +++ b/bcda/bcdacli/cli_test.go @@ -22,6 +22,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/constants" "github.com/CMSgov/bcda-app/bcda/database" "github.com/CMSgov/bcda-app/bcda/models" + "github.com/CMSgov/bcda-app/bcda/models/postgres" "github.com/CMSgov/bcda-app/bcda/models/postgres/postgrestest" "github.com/CMSgov/bcda-app/bcda/testUtils" "github.com/CMSgov/bcda-app/bcda/utils" @@ -69,6 +70,8 @@ func (s *CLITestSuite) SetupSuite() { testUtils.SetPendingDeletionDir(&s.Suite, dir) s.db = database.GetConnection() + connection = s.db + r = postgres.NewRepository(s.db) cmsID := testUtils.RandomHexID()[0:4] s.testACO = models.ACO{Name: uuid.New(), UUID: uuid.NewRandom(), ClientID: uuid.New(), CMSID: &cmsID} @@ -89,6 +92,10 @@ func (s *CLITestSuite) TearDownSuite() { postgrestest.DeleteACO(s.T(), s.db, s.testACO.UUID) } +func (s *CLITestSuite) SetProvider(p auth.Provider) { + provider = p +} + func TestCLITestSuite(t *testing.T) { suite.Run(t, new(CLITestSuite)) } @@ -143,15 +150,18 @@ func (s *CLITestSuite) TestGenerateClientCredentials() { } m := &auth.MockProvider{} m.On("FindAndCreateACOCredentials", *s.testACO.CMSID, ips).Return("mock\ncreds\ntest", nil) - auth.SetMockProvider(t, m) + + oldProvider := provider + s.SetProvider(m) + defer s.SetProvider(oldProvider) buf := new(bytes.Buffer) s.testApp.Writer = buf - args := []string{"bcda", constants.GenClientCred, constants.CMSIDArg, *s.testACO.CMSID, "--ips", strings.Join(ips, ",")} - err := s.testApp.Run(args) + msg, err := generateClientCredentials(*s.testACO.CMSID, ips) assert.Nil(t, err) - assert.Regexp(t, regexp.MustCompile(".+\n.+\n.+"), buf.String()) + assert.Regexp(t, regexp.MustCompile(".+\n.+\n.+"), msg) + assert.Equal(t, "mock\ncreds\ntest", msg) m.AssertExpectations(t) }) } @@ -188,27 +198,19 @@ func (s *CLITestSuite) TestResetSecretCLI() { auth.Credentials{ClientName: *s.testACO.CMSID, ClientID: s.testACO.ClientID, ClientSecret: uuid.New()}, nil) - auth.SetMockProvider(s.T(), mock) + oldProvider := provider + s.SetProvider(mock) + defer s.SetProvider(oldProvider) - // execute positive scenarios via CLI - args := []string{"bcda", constants.ResetClientCred, constants.CMSIDArg, *s.testACO.CMSID} - err := s.testApp.Run(args) + // execute positive scenario + msg, err := resetClientCredentials(*s.testACO.CMSID) assert.Nil(err) - assert.Regexp(outputPattern, buf.String()) - buf.Reset() + assert.Regexp(outputPattern, msg) - // Execute CLI with invalid ACO CMS ID - args = []string{"bcda", constants.ResetClientCred, constants.CMSIDArg, "BLAH"} - err = s.testApp.Run(args) + // Execute with invalid ACO CMS ID + msg, err = resetClientCredentials("BLAH") assert.Equal("no ACO record found for BLAH", err.Error()) - assert.Equal(0, buf.Len()) - buf.Reset() - - // Execute CLI with invalid inputs - args = []string{"bcda", constants.ResetClientCred, "--abcd", "efg"} - err = s.testApp.Run(args) - assert.Equal("flag provided but not defined: -abcd", err.Error()) - assert.Contains(buf.String(), "Incorrect Usage: flag provided but not defined") + assert.Equal(0, len(msg)) mock.AssertExpectations(s.T()) } @@ -222,17 +224,16 @@ func (s *CLITestSuite) TestRevokeToken() { accessToken := uuid.New() mock := &auth.MockProvider{} mock.On("RevokeAccessToken", 
accessToken).Return(nil) - auth.SetMockProvider(s.T(), mock) - assert.NoError(s.testApp.Run([]string{"bcda", "revoke-token", "--access-token", accessToken})) - buf.Reset() + oldProvider := provider + s.SetProvider(mock) + defer s.SetProvider(oldProvider) + + err := revokeAccessToken(accessToken) + assert.Nil(err) // Negative case - attempt to revoke a token passing in a blank token string - args := []string{"bcda", "revoke-token", "--access-token", ""} - err := s.testApp.Run(args) + err = revokeAccessToken("") assert.Equal("Access token (--access-token) must be provided", err.Error()) - assert.Equal(0, buf.Len()) - buf.Reset() - mock.AssertExpectations(s.T()) } diff --git a/bcda/lambda/admin_create_aco_creds/main.go b/bcda/lambda/admin_create_aco_creds/main.go index 67322f40a..edc79339c 100644 --- a/bcda/lambda/admin_create_aco_creds/main.go +++ b/bcda/lambda/admin_create_aco_creds/main.go @@ -12,6 +12,7 @@ import ( "github.com/CMSgov/bcda-app/bcda/auth" bcdaaws "github.com/CMSgov/bcda-app/bcda/aws" + "github.com/CMSgov/bcda-app/bcda/database" log "github.com/sirupsen/logrus" @@ -74,10 +75,11 @@ func handler(ctx context.Context, event json.RawMessage) (string, error) { return "", err } + provider := auth.NewProvider(database.GetConnection()) s3Service := s3.New(session) slackClient := slack.New(params.slackToken) - s3Path, err := handleCreateACOCreds(ctx, data, s3Service, slackClient, params.credsBucket) + s3Path, err := handleCreateACOCreds(ctx, data, provider, s3Service, slackClient, params.credsBucket) if err != nil { log.Errorf("Failed to handle Create ACO creds: %+v", err) return "", err @@ -91,6 +93,7 @@ func handler(ctx context.Context, event json.RawMessage) (string, error) { func handleCreateACOCreds( ctx context.Context, data payload, + provider auth.Provider, s3Service s3iface.S3API, notifier Notifier, credsBucket string, @@ -102,7 +105,7 @@ func handleCreateACOCreds( log.Errorf("Error sending notifier start message: %+v", err) } - creds, err := auth.GetProvider().FindAndCreateACOCredentials(data.ACOID, data.IPs) + creds, err := provider.FindAndCreateACOCredentials(data.ACOID, data.IPs) if err != nil { log.Errorf("Error creating ACO creds: %+v", err) diff --git a/bcda/lambda/admin_create_aco_creds/main_test.go b/bcda/lambda/admin_create_aco_creds/main_test.go index afff2e7ee..3c680e7fb 100644 --- a/bcda/lambda/admin_create_aco_creds/main_test.go +++ b/bcda/lambda/admin_create_aco_creds/main_test.go @@ -23,11 +23,10 @@ func TestHandleCreateACOCreds(t *testing.T) { data := payload{ACOID: "TEST1234", IPs: []string{"1.2.3.4", "1.2.3.5"}} - mock := &auth.MockProvider{} - mock.On("FindAndCreateACOCredentials", data.ACOID, data.IPs).Return("creds\nstring", nil) - auth.SetMockProvider(t, mock) + mockProvider := &auth.MockProvider{} + mockProvider.On("FindAndCreateACOCredentials", data.ACOID, data.IPs).Return("creds\nstring", nil) - s3Path, err := handleCreateACOCreds(ctx, data, &mockS3{}, &mockNotifier{}, "test-bucket") + s3Path, err := handleCreateACOCreds(ctx, data, mockProvider, &mockS3{}, &mockNotifier{}, "test-bucket") assert.Nil(t, err) assert.Equal(t, s3Path, "{\n\n}") } diff --git a/bcda/web/router.go b/bcda/web/router.go index 39d9517b0..404ce0b71 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -32,7 +32,8 @@ var commonAuth = []func(http.Handler) http.Handler{ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() - r.Use(gcmw.RequestID, appMiddleware.NewTransactionID, 
auth.ParseToken, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) + am := auth.NewAuthMiddleware(provider) + r.Use(gcmw.RequestID, appMiddleware.NewTransactionID, am.ParseToken, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) // Serve up the swagger ui folder FileServer(r, "/api/v1/swagger", http.Dir("./swaggerui/v1")) @@ -58,9 +59,9 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provid r.Route("/api/v1", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) + r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV1.JobsStatus)) - r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) + r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV1.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV1.Metadata)) }) @@ -71,9 +72,9 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provid r.Route("/api/v2", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV2.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV2.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV2.JobStatus)) + r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV2.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV2.JobsStatus)) - r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV2.DeleteJob)) + r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV2.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV2.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV2.Metadata)) }) @@ -84,9 +85,9 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provid r.Route("/api/demo", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV3.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV3.BulkGroupRequest)) - r.With(append(commonAuth, auth.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV3.JobStatus)) + r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV3.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV3.JobsStatus)) - r.With(append(commonAuth, 
auth.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV3.DeleteJob)) + r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV3.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV3.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV3.Metadata)) }) @@ -102,16 +103,17 @@ func NewAuthRouter(provider auth.Provider) http.Handler { return auth.NewAuthRouter(provider, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) } -func NewDataRouter(connection *sql.DB) http.Handler { +func NewDataRouter(connection *sql.DB, provider auth.Provider) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() + am := auth.NewAuthMiddleware(provider) resourceTypeLogger := &logging.ResourceTypeLogger{ Repository: postgres.NewRepository(connection), } - r.Use(auth.ParseToken, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) + r.Use(am.ParseToken, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) r.With(append( commonAuth, - auth.RequireTokenJobMatch(connection), + am.RequireTokenJobMatch(connection), resourceTypeLogger.LogJobResourceType, )...).Get(m.WrapHandler("/data/{jobID}/{fileName}", v1.ServeData)) return r diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index 30ce47ae1..27f9da7b2 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -44,7 +44,7 @@ func (s *RouterTestSuite) SetupTest() { s.connection = database.GetConnection() s.provider = auth.NewProvider(s.connection) s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) - s.dataRouter = NewDataRouter(s.connection) + s.dataRouter = NewDataRouter(s.connection, s.provider) } func (s *RouterTestSuite) getAPIRoute(route string) *http.Response { @@ -352,11 +352,11 @@ func createExpectedAuthData(cmsID string, aco models.ACO) auth.AuthData { } } -func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []struct { +func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite, p auth.Provider) (configs []struct { handler http.Handler paths []string }) { - apiRouter := NewAPIRouter(s.connection, s.pool, s.provider) + apiRouter := NewAPIRouter(s.connection, s.pool, p) configs = []struct { handler http.Handler @@ -365,8 +365,8 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []str {apiRouter, []string{"/api/v1/Patient/$export", "/api/v1/Group/all/$export", constants.V2Path + constants.PatientExportPath, constants.V2Path + constants.GroupExportPath, constants.V1Path + constants.JobsFilePath}}, - {s.dataRouter, []string{nDJsonDataRoute}}, - {NewAuthRouter(s.provider), []string{"/auth/welcome"}}, + {NewDataRouter(s.connection, p), []string{nDJsonDataRoute}}, + {NewAuthRouter(p), []string{"/auth/welcome"}}, } return configs @@ -375,7 +375,6 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite) (configs []str func setExpectedMockCalls(s *RouterTestSuite, mockP *auth.MockProvider, token *jwt.Token, aco models.ACO, bearerString string, cmsID string) { mockP.On("VerifyToken", mock.Anything, bearerString).Return(token, nil) mockP.On("getAuthDataFromClaims", token.Claims).Return(createExpectedAuthData(cmsID, aco), nil) - 
auth.SetMockProvider(s.T(), mockP) } // integration test, requires connection to postgres db @@ -408,7 +407,7 @@ func (s *RouterTestSuite) TestBlacklistedACOReturn403WhenACOBlacklisted() { postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) - configs := createConfigsForACOBlacklistingScenarios(s) + configs := createConfigsForACOBlacklistingScenarios(s, mock) for _, config := range configs { for _, path := range config.paths { @@ -453,7 +452,7 @@ func (s *RouterTestSuite) TestBlacklistedACOReturnNOT403WhenACONOTBlacklisted() postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) - configs := createConfigsForACOBlacklistingScenarios(s) + configs := createConfigsForACOBlacklistingScenarios(s, mock) for _, config := range configs { for _, path := range config.paths { From 29ba3cc5ce49edec5384cb5ec2d55b91d2a08145 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Mon, 4 Aug 2025 22:13:14 -0400 Subject: [PATCH 20/28] Remove database connection globals --- bcda/database/connection.go | 10 ---------- bcda/service/service_test.go | 4 ++-- bcda/suppression/suppression_s3_test.go | 2 +- bcda/suppression/suppression_test.go | 4 ++-- bcdaworker/queueing/enqueue_test.go | 2 +- bcdaworker/queueing/worker_prepare_test.go | 19 +++++++++++-------- 6 files changed, 17 insertions(+), 24 deletions(-) diff --git a/bcda/database/connection.go b/bcda/database/connection.go index 989fa0fdf..408a11098 100644 --- a/bcda/database/connection.go +++ b/bcda/database/connection.go @@ -17,16 +17,6 @@ import ( "github.com/sirupsen/logrus" ) -var ( - Connection *sql.DB - Pgxv5Pool *pgxv5Pool.Pool -) - -func init() { - Connection = GetConnection() - Pgxv5Pool = GetPool() -} - func GetConnection() *sql.DB { cfg, err := LoadConfig() if err != nil { diff --git a/bcda/service/service_test.go b/bcda/service/service_test.go index a300c7ccf..35274318c 100644 --- a/bcda/service/service_test.go +++ b/bcda/service/service_test.go @@ -492,7 +492,7 @@ func (s *ServiceTestSuite) TestGetNewAndExistingBeneficiaries_Integration() { // - Diff between CCLF File 1 and CCLF File 2 // - No diff - consider all beneficiaries as pre-existing func (s *ServiceTestSuite) TestGetNewAndExistingBeneficiaries_RecentSinceParameter_Integration() { - db := database.Connection + db := database.GetConnection() acoID := "A0005" // Test Setup @@ -1657,7 +1657,7 @@ func (s *ServiceTestSuiteWithDatabase) TestGetBenesByID_Integration() { } func (s *ServiceTestSuiteWithDatabase) TestGetNewAndExistingBeneficiaries_RecentSinceParameterDatabase_Integration() { - db := database.Connection + db := database.GetConnection() acoID := "A0005" // Test Setup diff --git a/bcda/suppression/suppression_s3_test.go b/bcda/suppression/suppression_s3_test.go index 9fcba0386..60391d9b0 100644 --- a/bcda/suppression/suppression_s3_test.go +++ b/bcda/suppression/suppression_s3_test.go @@ -292,7 +292,7 @@ func (s *SuppressionS3TestSuite) TestCleanupSuppression() { func (s *SuppressionS3TestSuite) TestImportSuppressionDirectoryTable() { assert := assert.New(s.T()) importer, _ := s.createImporter() - db := database.Connection + db := database.GetConnection() importer.Saver = &BCDASaver{ Repo: postgres.NewRepository(db), diff --git a/bcda/suppression/suppression_test.go b/bcda/suppression/suppression_test.go index d71ff8669..69ab3b4b0 100644 --- a/bcda/suppression/suppression_test.go +++ b/bcda/suppression/suppression_test.go @@ -283,7 +283,7 @@ func (s *SuppressionTestSuite) TestLoadOptOutFiles_TimeChange() { assert := 
assert.New(s.T()) importer, _ := s.createImporter() importer.Saver = &BCDASaver{ - Repo: postgres.NewRepository(database.Connection), + Repo: postgres.NewRepository(database.GetConnection()), } folderPath := filepath.Join(s.basePath, "suppressionfile_BadFileNames/") @@ -447,7 +447,7 @@ func (s *SuppressionTestSuite) TestCleanupSuppression_RenameFileError() { func (s *SuppressionTestSuite) TestImportSuppressionDirectoryTable() { assert := assert.New(s.T()) importer, _ := s.createImporter() - db := database.Connection + db := database.GetConnection() importer.Saver = &BCDASaver{ Repo: postgres.NewRepository(db), diff --git a/bcdaworker/queueing/enqueue_test.go b/bcdaworker/queueing/enqueue_test.go index aba96de26..cf0146d40 100644 --- a/bcdaworker/queueing/enqueue_test.go +++ b/bcdaworker/queueing/enqueue_test.go @@ -47,7 +47,7 @@ func TestRiverEnqueuer_Integration(t *testing.T) { assert.NoError(t, enqueuer.AddJob(ctx, jobArgs, 3)) // Use river test helper to assert job was inserted - checkJob := rivertest.RequireInserted(ctx, t, riverpgxv5.New(database.Pgxv5Pool), jobArgs, nil) + checkJob := rivertest.RequireInserted(ctx, t, riverpgxv5.New(pool), jobArgs, nil) assert.NotNil(t, checkJob) // Also Verify that we've inserted the river job as expected via DB queries diff --git a/bcdaworker/queueing/worker_prepare_test.go b/bcdaworker/queueing/worker_prepare_test.go index f3dff5e76..d99d9cd9b 100644 --- a/bcdaworker/queueing/worker_prepare_test.go +++ b/bcdaworker/queueing/worker_prepare_test.go @@ -19,6 +19,7 @@ import ( "github.com/CMSgov/bcda-app/log" cm "github.com/CMSgov/bcda-app/middleware" "github.com/go-testfixtures/testfixtures/v3" + pgxv5Pool "github.com/jackc/pgx/v5/pgxpool" "github.com/pborman/uuid" "github.com/riverqueue/river" "github.com/sirupsen/logrus" @@ -35,9 +36,10 @@ import ( type PrepareWorkerIntegrationTestSuite struct { suite.Suite - r models.Repository - db *sql.DB - ctx context.Context + r models.Repository + db *sql.DB + pool *pgxv5Pool.Pool + ctx context.Context } func TestCleanupTestSuite(t *testing.T) { @@ -46,6 +48,7 @@ func TestCleanupTestSuite(t *testing.T) { func (s *PrepareWorkerIntegrationTestSuite) SetupTest() { s.db, _ = databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true) + s.pool = database.GetPool() tf, err := testfixtures.New( testfixtures.Database(s.db), testfixtures.Dialect("postgres"), @@ -215,7 +218,7 @@ func (s *PrepareWorkerIntegrationTestSuite) TestPrepareWorkerWork() { }, } - driver := riverpgxv5.New(database.Pgxv5Pool) + driver := riverpgxv5.New(s.pool) _, err := driver.GetExecutor().Exec(context.Background(), `delete from river_job`) if err != nil { s.T().Log(err) @@ -227,7 +230,7 @@ func (s *PrepareWorkerIntegrationTestSuite) TestPrepareWorkerWork() { r: r, } w := rivertest.NewWorker(s.T(), driver, &river.Config{}, worker) - d := database.Pgxv5Pool + d := s.pool tx, err := d.Begin(s.ctx) if err != nil { s.T().Log(err) @@ -275,13 +278,13 @@ func (s *PrepareWorkerIntegrationTestSuite) TestPrepareWorkerWork_Integration() } worker := &PrepareJobWorker{svc: svc, v1Client: c, v2Client: c, r: s.r} - driver := riverpgxv5.New(database.Pgxv5Pool) + driver := riverpgxv5.New(s.pool) _, err = driver.GetExecutor().Exec(context.Background(), `delete from river_job`) if err != nil { s.T().Log(err) } w := rivertest.NewWorker(s.T(), driver, &river.Config{}, worker) - d := database.Pgxv5Pool + d := s.pool tx, err := d.Begin(s.ctx) if err != nil { s.T().Log(err) @@ -328,7 +331,7 @@ func (s *PrepareWorkerIntegrationTestSuite) TestQueueExportJobs() 
{ ID: 33, } - driver := riverpgxv5.New(database.Pgxv5Pool) + driver := riverpgxv5.New(s.pool) _, err := driver.GetExecutor().Exec(context.Background(), `delete from river_job`) assert.Nil(s.T(), err) From b00818bf6d59ce4c06abff7522ca2ced23d2375d Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Tue, 5 Aug 2025 00:30:08 -0400 Subject: [PATCH 21/28] Rename connection to db --- bcda/api/requests.go | 10 +-- bcda/api/requests_test.go | 46 ++++++------- bcda/api/v1/api.go | 10 +-- bcda/api/v1/api_test.go | 64 +++++++++---------- bcda/api/v2/api.go | 8 +-- bcda/api/v2/api_test.go | 60 ++++++++--------- bcda/api/v3/api.go | 8 +-- bcda/api/v3/api_test.go | 60 ++++++++--------- bcda/auth/api_test.go | 2 +- bcda/auth/middleware.go | 4 +- bcda/auth/middleware_test.go | 30 ++++----- bcda/auth/router_test.go | 2 +- bcda/auth/ssas_middleware_test.go | 2 +- bcda/auth/ssas_test.go | 2 +- bcda/bcdacli/cli.go | 26 ++++---- bcda/bcdacli/cli_test.go | 4 +- bcda/cclf/cclf_test.go | 2 +- bcda/cclf/utils/cclfUtils.go | 4 +- bcda/cclf/utils/cclfUtils_test.go | 2 +- bcda/database/connection.go | 4 +- bcda/database/connection_test.go | 4 +- bcda/database/database_test.go | 4 +- bcda/database/databasetest/databasetest.go | 2 +- .../databasetest/databasetest_test.go | 2 +- bcda/database/pgx_test.go | 2 +- bcda/health/health.go | 4 +- bcda/lambda/admin_create_aco_creds/main.go | 2 +- bcda/lambda/admin_create_group/main.go | 2 +- bcda/lambda/cclf/main.go | 16 ++--- bcda/lambda/cclf/main_test.go | 2 +- bcda/lambda/optout/main.go | 4 +- bcda/lambda/optout/main_test.go | 2 +- bcda/models/postgres/repository_test.go | 2 +- bcda/service/service_test.go | 4 +- bcda/suppression/suppression_s3_test.go | 2 +- bcda/suppression/suppression_test.go | 4 +- bcda/web/middleware/ratelimit_test.go | 2 +- bcda/web/router.go | 28 ++++---- bcda/web/router_test.go | 28 ++++---- bcdaworker/cleanup/cleanup_test.go | 2 +- bcdaworker/main.go | 2 +- bcdaworker/queueing/enqueue.go | 6 +- bcdaworker/queueing/enqueue_test.go | 4 +- bcdaworker/queueing/river.go | 4 +- bcdaworker/queueing/river_test.go | 8 +-- bcdaworker/queueing/worker_prepare.go | 4 +- bcdaworker/queueing/worker_prepare_test.go | 4 +- bcdaworker/queueing/worker_process_job.go | 4 +- .../repository/postgres/repository_test.go | 2 +- bcdaworker/worker/worker_test.go | 2 +- db/migrations/migrations_test.go | 2 +- 51 files changed, 254 insertions(+), 256 deletions(-) diff --git a/bcda/api/requests.go b/bcda/api/requests.go index ff2d866c6..e9c3a56d7 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -62,14 +62,14 @@ type fhirResponseWriter interface { JobsBundle(context.Context, http.ResponseWriter, []*models.Job, string) } -func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB, pool *pgxv5Pool.Pool) *Handler { - return newHandler(dataTypes, basePath, apiVersion, connection, pool) +func NewHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, db *sql.DB, pool *pgxv5Pool.Pool) *Handler { + return newHandler(dataTypes, basePath, apiVersion, db, pool) } -func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, connection *sql.DB, pool *pgxv5Pool.Pool) *Handler { +func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersion string, db *sql.DB, pool *pgxv5Pool.Pool) *Handler { h := &Handler{JobTimeout: time.Hour * time.Duration(utils.GetEnvInt("ARCHIVE_THRESHOLD_HR", 24))} - h.Enq = queueing.NewEnqueuer(connection, pool) + h.Enq = 
queueing.NewEnqueuer(db, pool) cfg, err := service.LoadConfig() if err != nil { @@ -79,7 +79,7 @@ func newHandler(dataTypes map[string]service.DataType, basePath string, apiVersi log.API.Fatalf("no ACO configs found, these are required for processing logic") } - repository := postgres.NewRepository(connection) + repository := postgres.NewRepository(db) h.r = repository h.Svc = service.NewService(repository, cfg, basePath) diff --git a/bcda/api/requests_test.go b/bcda/api/requests_test.go index e3150fb0f..9c8299c43 100644 --- a/bcda/api/requests_test.go +++ b/bcda/api/requests_test.go @@ -66,7 +66,7 @@ type RequestsTestSuite struct { runoutEnabledEnvVar string - connection *sql.DB + db *sql.DB pool *pgxv5Pool.Pool @@ -83,8 +83,8 @@ func (s *RequestsTestSuite) SetupSuite() { // See testdata/acos.yml s.acoID = uuid.Parse("ba21d24d-cd96-4d7d-a691-b0e8c88e67a5") db, _ := databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true) - s.connection = db - s.pool = database.GetPool() + s.db = db + s.pool = database.ConnectPool() tf, err := testfixtures.New( testfixtures.Database(db), testfixtures.Dialect("postgres"), @@ -142,7 +142,7 @@ func (s *RequestsTestSuite) TestRunoutEnabled() { mockSvc := &service.MockService{} mockAco := service.ACOConfig{Data: []string{"adjudicated"}} mockSvc.On("GetACOConfigForID", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockAco, true) - h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.connection, s.pool) + h := newHandler(resourceMap, fmt.Sprintf("/%s/fhir", tt.apiVersion), tt.apiVersion, s.db, s.pool) h.Svc = mockSvc enqueuer := queueing.NewMockEnqueuer(s.T()) h.Enq = enqueuer @@ -244,7 +244,7 @@ func (s *RequestsTestSuite) TestJobsStatusV1() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.connection, s.pool) + }, fhirPath, apiVersion, s.db, s.pool) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -358,7 +358,7 @@ func (s *RequestsTestSuite) TestJobsStatusV2() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, v2BasePath, apiVersionTwo, s.connection, s.pool) + }, v2BasePath, apiVersionTwo, s.db, s.pool) if tt.useMock { h.Svc = mockSvc } @@ -477,7 +477,7 @@ func (s *RequestsTestSuite) TestAttributionStatus() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, apiVersion, s.connection, s.pool) + h := newHandler(resourceMap, fhirPath, apiVersion, s.db, s.pool) h.Svc = mockSvc rr := httptest.NewRecorder() @@ -568,7 +568,7 @@ func (s *RequestsTestSuite) TestDataTypeAuthorization() { "ClaimResponse": {Adjudicated: false, PartiallyAdjudicated: true}, } - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.connection, s.pool) + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, s.db, s.pool) r := models.NewMockRepository(s.T()) r.On("CreateJob", mock.Anything, mock.Anything).Return(uint(4), nil) h.r = r @@ -656,7 +656,7 @@ func (s *RequestsTestSuite) TestRequests() { fhirPath := "/" + apiVersion + "/fhir" resourceMap := s.resourceType - h := newHandler(resourceMap, fhirPath, apiVersion, s.connection, s.pool) + h := newHandler(resourceMap, fhirPath, apiVersion, s.db, s.pool) // Test Group and Patient // Patient, Coverage, and ExplanationOfBenefit @@ -786,7 +786,7 @@ func (s *RequestsTestSuite) TestJobStatusErrorHandling() { for _, tt := range tests { s.T().Run(tt.testName, func(t *testing.T) { - h := newHandler(resourceMap, basePath, apiVersion, 
s.connection, s.pool) + h := newHandler(resourceMap, basePath, apiVersion, s.db, s.pool) if tt.useMockService { mockSrv := service.MockService{} timestp := time.Now() @@ -860,7 +860,7 @@ func (s *RequestsTestSuite) TestJobStatusProgress() { apiVersion := apiVersionTwo requestUrl := v2JobRequestUrl resourceMap := s.resourceType - h := newHandler(resourceMap, basePath, apiVersion, s.connection, s.pool) + h := newHandler(resourceMap, basePath, apiVersion, s.db, s.pool) req := httptest.NewRequest("GET", requestUrl, nil) rctx := chi.NewRouteContext() @@ -909,7 +909,7 @@ func (s *RequestsTestSuite) TestDeleteJob() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - handler := newHandler(s.resourceType, basePath, apiVersion, s.connection, s.pool) + handler := newHandler(s.resourceType, basePath, apiVersion, s.db, s.pool) if tt.useMockService { mockSrv := service.MockService{} @@ -969,7 +969,7 @@ func (s *RequestsTestSuite) TestJobFailedStatus() { for _, tt := range tests { s.T().Run(tt.name, func(t *testing.T) { - h := newHandler(resourceMap, tt.basePath, tt.version, s.connection, s.pool) + h := newHandler(resourceMap, tt.basePath, tt.version, s.db, s.pool) mockSrv := service.MockService{} timestp := time.Now() mockSrv.On("GetJobAndKeys", testUtils.CtxMatcher, uint(1)).Return( @@ -1027,7 +1027,7 @@ func (s *RequestsTestSuite) TestGetResourceTypes() { {"CT000000", "v2", []string{"Patient", "ExplanationOfBenefit", "Coverage", "Claim", "ClaimResponse"}}, } for _, test := range testCases { - h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.connection, s.pool) + h := newHandler(s.resourceType, "/"+test.apiVersion+"/fhir", test.apiVersion, s.db, s.pool) rp := middleware.RequestParameters{ Version: test.apiVersion, ResourceTypes: []string{}, @@ -1060,9 +1060,9 @@ func TestBulkRequest_Integration(t *testing.T) { client.SetLogger(log.API) // Set logger so we don't get errors later - connection := database.GetConnection() - pool := database.GetPool() - h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, connection, pool) + db := database.Connect() + pool := database.ConnectPool() + h := NewHandler(dataTypeMap, v2BasePath, apiVersionTwo, db, pool) driver := riverpgxv5.New(pool) // start from clean river_job slate @@ -1070,7 +1070,7 @@ func TestBulkRequest_Integration(t *testing.T) { assert.Nil(t, err) acoID := "A0002" - repo := postgres.NewRepository(connection) + repo := postgres.NewRepository(db) // our DB is not always cleaned up properly so sometimes this record exists when this test runs and sometimes it doesnt repo.CreateACO(context.Background(), models.ACO{CMSID: &acoID, UUID: uuid.NewUUID()}) // nolint:errcheck @@ -1133,7 +1133,7 @@ func (s *RequestsTestSuite) genGroupRequest(groupID string, rp middleware.Reques rctx := chi.NewRouteContext() rctx.URLParams.Add("groupId", groupID) - aco := postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.db, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) @@ -1148,7 +1148,7 @@ func (s *RequestsTestSuite) genGroupRequest(groupID string, rp middleware.Reques func (s *RequestsTestSuite) genPatientRequest(rp middleware.RequestParameters) *http.Request { req := httptest.NewRequest("GET", "http://bcda.cms.gov/api/v1/Patient/$export", nil) - aco := postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), 
s.db, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) ctx = middleware.SetRequestParamsCtx(ctx, rp) @@ -1159,7 +1159,7 @@ func (s *RequestsTestSuite) genPatientRequest(rp middleware.RequestParameters) * func (s *RequestsTestSuite) genASRequest() *http.Request { req := httptest.NewRequest("GET", "http://bcda.cms.gov/api/v1/attribution_status", nil) - aco := postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.db, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) newLogEntry := MakeTestStructuredLoggerEntry(logrus.Fields{"cms_id": "A9999", "request_id": uuid.NewRandom().String()}) @@ -1187,7 +1187,7 @@ func (s *RequestsTestSuite) genGetJobsRequest(version string, statuses []models. req := httptest.NewRequest("GET", target, nil) - aco := postgrestest.GetACOByUUID(s.T(), s.connection, s.acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.db, s.acoID) ad := auth.AuthData{ACOID: s.acoID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} ctx := context.WithValue(req.Context(), auth.AuthDataContextKey, ad) @@ -1208,7 +1208,7 @@ func (s *RequestsTestSuite) TestValidateResources() { "Patient": {}, "Coverage": {}, "ExplanationOfBenefit": {}, - }, fhirPath, apiVersion, s.connection, s.pool) + }, fhirPath, apiVersion, s.db, s.pool) err := h.validateResources([]string{"Vegetable"}, "1234") assert.Contains(s.T(), err.Error(), "invalid resource type") } diff --git a/bcda/api/v1/api.go b/bcda/api/v1/api.go index 1a394a70e..c8024e9b5 100644 --- a/bcda/api/v1/api.go +++ b/bcda/api/v1/api.go @@ -28,13 +28,13 @@ import ( ) type ApiV1 struct { - connection *sql.DB + db *sql.DB handler *api.Handler provider auth.Provider healthChecker health.HealthChecker } -func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) *ApiV1 { +func NewApiV1(db *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) *ApiV1 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -46,9 +46,9 @@ func NewApiV1(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) panic("Failed to configure resource DataTypes") } - hc := health.NewHealthChecker(connection) - h := api.NewHandler(resources, "/v1/fhir", "v1", connection, pool) - return &ApiV1{connection: connection, handler: h, provider: provider, healthChecker: hc} + hc := health.NewHealthChecker(db) + h := api.NewHandler(resources, "/v1/fhir", "v1", db, pool) + return &ApiV1{db: db, handler: h, provider: provider, healthChecker: hc} } /* diff --git a/bcda/api/v1/api_test.go b/bcda/api/v1/api_test.go index fe845b310..e664a6970 100644 --- a/bcda/api/v1/api_test.go +++ b/bcda/api/v1/api_test.go @@ -46,17 +46,17 @@ var ( type APITestSuite struct { suite.Suite - rr *httptest.ResponseRecorder - connection *sql.DB - pool *pgxv5Pool.Pool - provider auth.Provider - apiV1 *ApiV1 + rr *httptest.ResponseRecorder + db *sql.DB + pool *pgxv5Pool.Pool + provider auth.Provider + apiV1 *ApiV1 } func (s *APITestSuite) SetupSuite() { - s.connection = database.GetConnection() - s.provider = auth.NewProvider(s.connection) - s.apiV1 = NewApiV1(s.connection, s.pool, s.provider) + s.db = database.Connect() + s.provider = auth.NewProvider(s.db) + s.apiV1 = NewApiV1(s.db, s.pool, s.provider) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), 
"CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -78,7 +78,7 @@ func (s *APITestSuite) SetupTest() { } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.connection, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.db, acoUnderTest) } func (s *APITestSuite) TestJobStatusBadInputs() { @@ -138,8 +138,8 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connection, &j) - defer postgrestest.DeleteJobByID(t, s.connection, j.ID) + postgrestest.CreateJobs(t, s.db, &j) + defer postgrestest.DeleteJobByID(t, s.db, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -166,14 +166,14 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.connection, + postgrestest.CreateJobKeys(s.T(), s.db, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } @@ -217,7 +217,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -225,7 +225,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.connection, jobKey) + postgrestest.CreateJobKeys(s.T(), s.db, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -277,10 +277,10 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) j.UpdatedAt = time.Now().Add(-(s.apiV1.handler.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.connection, j) + postgrestest.UpdateJob(s.T(), s.db, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) s.apiV1.JobStatus(s.rr, req) @@ -346,8 +346,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connection, &j) - defer postgrestest.DeleteJobByID(t, s.connection, j.ID) + postgrestest.CreateJobs(t, s.db, &j) + defer postgrestest.DeleteJobByID(t, s.db, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -461,10 +461,10 @@ func (s *APITestSuite) TestJobStatusWithWrongACO() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusPending, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) am := auth.NewAuthMiddleware(s.provider) - handler := am.RequireTokenJobMatch(s.connection)(http.HandlerFunc(s.apiV1.JobStatus)) + handler := 
am.RequireTokenJobMatch(s.db)(http.HandlerFunc(s.apiV1.JobStatus)) req := s.createJobStatusRequest(uuid.Parse(constants.LargeACOUUID), j.ID) handler.ServeHTTP(s.rr, req) @@ -484,8 +484,8 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -516,8 +516,8 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) @@ -536,8 +536,8 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -556,8 +556,8 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: constants.V1Path + constants.PatientEOBPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV1.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -598,15 +598,15 @@ func (s *APITestSuite) TestGetAttributionStatus() { err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connection, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.db, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/api/v2/api.go b/bcda/api/v2/api.go index 7b3985f97..a66ab5b05 100644 --- a/bcda/api/v2/api.go +++ b/bcda/api/v2/api.go @@ -26,10 +26,10 @@ import ( type ApiV2 struct { handler *api.Handler marshaller *jsonformat.Marshaller - connection *sql.DB + db *sql.DB } -func NewApiV2(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV2 { +func NewApiV2(db *sql.DB, pool *pgxv5Pool.Pool) *ApiV2 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -41,14 +41,14 @@ func NewApiV2(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV2 { if !ok { panic("Failed to configure 
resource DataTypes") } else { - h := api.NewHandler(resources, "/v2/fhir", "v2", connection, pool) + h := api.NewHandler(resources, "/v2/fhir", "v2", db, pool) // Ensure that we write the serialized FHIR resources as a single line. // Needed to comply with the NDJSON format that we are using. marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) if err != nil { log.API.Fatalf("Failed to create marshaller %s", err) } - return &ApiV2{marshaller: marshaller, handler: h, connection: connection} + return &ApiV2{marshaller: marshaller, handler: h, db: db} } } diff --git a/bcda/api/v2/api_test.go b/bcda/api/v2/api_test.go index 1a8a83040..5d0313bcd 100644 --- a/bcda/api/v2/api_test.go +++ b/bcda/api/v2/api_test.go @@ -53,15 +53,15 @@ var ( type APITestSuite struct { suite.Suite - connection *sql.DB - pool *pgxv5Pool.Pool - apiV2 *ApiV2 + db *sql.DB + pool *pgxv5Pool.Pool + apiV2 *ApiV2 } func (s *APITestSuite) SetupSuite() { - s.connection = database.GetConnection() - s.pool = database.GetPool() - s.apiV2 = NewApiV2(s.connection, s.pool) + s.db = database.Connect() + s.pool = database.ConnectPool() + s.apiV2 = NewApiV2(s.db, s.pool) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -82,7 +82,7 @@ func (s *APITestSuite) SetupSuite() { } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.connection, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.db, acoUnderTest) } func TestAPITestSuite(t *testing.T) { @@ -146,8 +146,8 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connection, &j) - defer postgrestest.DeleteJobByID(t, s.connection, j.ID) + postgrestest.CreateJobs(t, s.db, &j) + defer postgrestest.DeleteJobByID(t, s.db, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -174,14 +174,14 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.connection, + postgrestest.CreateJobKeys(s.T(), s.db, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } @@ -227,7 +227,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -235,7 +235,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.connection, jobKey) + postgrestest.CreateJobKeys(s.T(), s.db, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -289,10 +289,10 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V2Path + constants.PatientEOBPath, Status: 
models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) j.UpdatedAt = time.Now().Add(-(s.apiV2.handler.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.connection, j) + postgrestest.UpdateJob(s.T(), s.db, j) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -316,8 +316,8 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -348,8 +348,8 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) @@ -368,8 +368,8 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -388,8 +388,8 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: "/api/v2/Patient/$export?_type=ExplanationOfBenefit", Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV2.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -451,8 +451,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: "/api/v2/Patient/$export?_type=Patient,Coverage", Status: tt.status, } - postgrestest.CreateJobs(t, s.connection, &j) - defer postgrestest.DeleteJobByID(t, s.connection, j.ID) + postgrestest.CreateJobs(t, s.db, &j) + defer postgrestest.DeleteJobByID(t, s.db, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -547,7 +547,7 @@ func (s *APITestSuite) TestResourceTypes() { "ClaimResponse", }...) 
- h := api.NewHandler(resources, "/v2/fhir", "v2", s.connection, s.pool) + h := api.NewHandler(resources, "/v2/fhir", "v2", s.db, s.pool) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) @@ -614,20 +614,20 @@ func (s *APITestSuite) TestGetAttributionStatus() { err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connection, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.db, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) getAuthData() (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) return auth.AuthData{ACOID: acoUnderTest.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/api/v3/api.go b/bcda/api/v3/api.go index 40f695c76..a6c36e0bc 100644 --- a/bcda/api/v3/api.go +++ b/bcda/api/v3/api.go @@ -26,10 +26,10 @@ import ( type ApiV3 struct { handler *api.Handler marshaller *jsonformat.Marshaller - connection *sql.DB + db *sql.DB } -func NewApiV3(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV3 { +func NewApiV3(db *sql.DB, pool *pgxv5Pool.Pool) *ApiV3 { resources, ok := service.GetDataTypes([]string{ "Patient", "Coverage", @@ -39,14 +39,14 @@ func NewApiV3(connection *sql.DB, pool *pgxv5Pool.Pool) *ApiV3 { if !ok { panic("Failed to configure resource DataTypes") } else { - h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, connection, pool) + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, db, pool) // Ensure that we write the serialized FHIR resources as a single line. // Needed to comply with the NDJSON format that we are using. 
marshaller, err := jsonformat.NewMarshaller(false, "", "", fhirversion.R4) if err != nil { log.API.Fatalf("Failed to create marshaller %s", err) } - return &ApiV3{marshaller: marshaller, handler: h, connection: connection} + return &ApiV3{marshaller: marshaller, handler: h, db: db} } } diff --git a/bcda/api/v3/api_test.go b/bcda/api/v3/api_test.go index 5fa40b813..1b966ff9d 100644 --- a/bcda/api/v3/api_test.go +++ b/bcda/api/v3/api_test.go @@ -53,15 +53,15 @@ var ( type APITestSuite struct { suite.Suite - connection *sql.DB - pool *pgxv5Pool.Pool - apiV3 *ApiV3 + db *sql.DB + pool *pgxv5Pool.Pool + apiV3 *ApiV3 } func (s *APITestSuite) SetupSuite() { - s.connection = database.GetConnection() - s.pool = database.GetPool() - s.apiV3 = NewApiV3(s.connection, s.pool) + s.db = database.Connect() + s.pool = database.ConnectPool() + s.apiV3 = NewApiV3(s.db, s.pool) origDate := conf.GetEnv("CCLF_REF_DATE") conf.SetEnv(s.T(), "CCLF_REF_DATE", time.Now().Format("060102 15:01:01")) @@ -82,7 +82,7 @@ func (s *APITestSuite) SetupSuite() { } func (s *APITestSuite) TearDownTest() { - postgrestest.DeleteJobsByACOID(s.T(), s.connection, acoUnderTest) + postgrestest.DeleteJobsByACOID(s.T(), s.db, acoUnderTest) } func TestAPITestSuite(t *testing.T) { @@ -146,8 +146,8 @@ func (s *APITestSuite) TestJobStatusNotComplete() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: tt.status, } - postgrestest.CreateJobs(t, s.connection, &j) - defer postgrestest.DeleteJobByID(t, s.connection, j.ID) + postgrestest.CreateJobs(t, s.db, &j) + defer postgrestest.DeleteJobByID(t, s.db, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -179,14 +179,14 @@ func (s *APITestSuite) TestJobStatusCompleted() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) var expectedUrls []string for i := 1; i <= 10; i++ { fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) expectedurl := fmt.Sprintf("%s/%s/%s", constants.ExpectedTestUrl, fmt.Sprint(j.ID), fileName) expectedUrls = append(expectedUrls, expectedurl) - postgrestest.CreateJobKeys(s.T(), s.connection, + postgrestest.CreateJobKeys(s.T(), s.db, models.JobKey{JobID: j.ID, FileName: fileName, ResourceType: "ExplanationOfBenefit"}) } @@ -232,7 +232,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) fileName := fmt.Sprintf("%s.ndjson", uuid.NewRandom().String()) jobKey := models.JobKey{ @@ -240,7 +240,7 @@ func (s *APITestSuite) TestJobStatusCompletedErrorFileExists() { FileName: fileName, ResourceType: "ExplanationOfBenefit", } - postgrestest.CreateJobKeys(s.T(), s.connection, jobKey) + postgrestest.CreateJobKeys(s.T(), s.db, jobKey) f := fmt.Sprintf("%s/%s", conf.GetEnv("FHIR_PAYLOAD_DIR"), fmt.Sprint(j.ID)) if _, err := os.Stat(f); os.IsNotExist(err) { @@ -294,10 +294,10 @@ func (s *APITestSuite) TestJobStatusNotExpired() { RequestURL: constants.V3Path + constants.PatientEOBPath, Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) j.UpdatedAt = time.Now().Add(-(s.apiV3.handler.JobTimeout + time.Second)) - postgrestest.UpdateJob(s.T(), s.connection, j) + postgrestest.UpdateJob(s.T(), s.db, j) req := 
s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -321,8 +321,8 @@ func (s *APITestSuite) TestJobsStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -353,8 +353,8 @@ func (s *APITestSuite) TestJobsStatusNotFoundWithStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusCompleted, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusNotFound, rr.Code) @@ -373,8 +373,8 @@ func (s *APITestSuite) TestJobsStatusWithStatus() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -393,8 +393,8 @@ func (s *APITestSuite) TestJobsStatusWithStatuses() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=ExplanationOfBenefit", constants.V3Path), Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) - defer postgrestest.DeleteJobByID(s.T(), s.connection, j.ID) + postgrestest.CreateJobs(s.T(), s.db, &j) + defer postgrestest.DeleteJobByID(s.T(), s.db, j.ID) s.apiV3.JobsStatus(rr, req) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -456,8 +456,8 @@ func (s *APITestSuite) TestDeleteJob() { RequestURL: fmt.Sprintf("%sPatient/$export?_type=Patient,Coverage", constants.V3Path), Status: tt.status, } - postgrestest.CreateJobs(t, s.connection, &j) - defer postgrestest.DeleteJobByID(t, s.connection, j.ID) + postgrestest.CreateJobs(t, s.db, &j) + defer postgrestest.DeleteJobByID(t, s.db, j.ID) req := s.createJobStatusRequest(acoUnderTest, j.ID) rr := httptest.NewRecorder() @@ -550,7 +550,7 @@ func (s *APITestSuite) TestResourceTypes() { "ExplanationOfBenefit", }...) 
- h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, s.connection, s.pool) + h := api.NewHandler(resources, constants.BFDV3Path, constants.V3Version, s.db, s.pool) mockSvc := &service.MockService{} mockSvc.On("GetLatestCCLFFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.CCLFFile{PerformanceYear: utils.GetPY()}, nil) @@ -617,20 +617,20 @@ func (s *APITestSuite) TestGetAttributionStatus() { err := json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(s.T(), err) - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) - cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.connection, *aco.CMSID, models.FileTypeDefault) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) + cclfFile := postgrestest.GetLatestCCLFFileByCMSIDAndType(s.T(), s.db, *aco.CMSID, models.FileTypeDefault) assert.Equal(s.T(), "last_attribution_update", resp.Data[0].Type) assert.Equal(s.T(), cclfFile.Timestamp.Format("2006-01-02 15:04:05"), resp.Data[0].Timestamp.Format("2006-01-02 15:04:05")) } func (s *APITestSuite) getAuthData() (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoUnderTest) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoUnderTest) return auth.AuthData{ACOID: acoUnderTest.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } func (s *APITestSuite) makeContextValues(acoID uuid.UUID) (data auth.AuthData) { - aco := postgrestest.GetACOByUUID(s.T(), s.connection, acoID) + aco := postgrestest.GetACOByUUID(s.T(), s.db, acoID) return auth.AuthData{ACOID: aco.UUID.String(), CMSID: *aco.CMSID, TokenID: uuid.NewRandom().String()} } diff --git a/bcda/auth/api_test.go b/bcda/auth/api_test.go index 66b8c4311..d8c025fc1 100644 --- a/bcda/auth/api_test.go +++ b/bcda/auth/api_test.go @@ -39,7 +39,7 @@ type AuthAPITestSuite struct { } func (s *AuthAPITestSuite) SetupSuite() { - s.db = database.GetConnection() + s.db = database.Connect() s.r = postgres.NewRepository(s.db) } diff --git a/bcda/auth/middleware.go b/bcda/auth/middleware.go index df82df0e4..62416e205 100644 --- a/bcda/auth/middleware.go +++ b/bcda/auth/middleware.go @@ -171,7 +171,7 @@ func CheckBlacklist(next http.Handler) http.Handler { }) } -func (m AuthMiddleware) RequireTokenJobMatch(connection *sql.DB) func(next http.Handler) http.Handler { +func (m AuthMiddleware) RequireTokenJobMatch(db *sql.DB) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { rw := getRespWriter(r.URL.Path) @@ -191,7 +191,7 @@ func (m AuthMiddleware) RequireTokenJobMatch(connection *sql.DB) func(next http. 
return } - repository := postgres.NewRepository(connection) + repository := postgres.NewRepository(db) job, err := repository.GetJobByID(r.Context(), uint(jobID)) if err != nil { diff --git a/bcda/auth/middleware_test.go b/bcda/auth/middleware_test.go index d93805406..1d1177c7f 100644 --- a/bcda/auth/middleware_test.go +++ b/bcda/auth/middleware_test.go @@ -40,14 +40,12 @@ var bearerStringMsg string = "Bearer %s" type MiddlewareTestSuite struct { suite.Suite - // server *httptest.Server rr *httptest.ResponseRecorder - // am auth.AuthMiddleware - connection *sql.DB + db *sql.DB } func (s *MiddlewareTestSuite) SetupSuite() { - s.connection = database.GetConnection() + s.db = database.Connect() } func (s *MiddlewareTestSuite) CreateRouter(p auth.Provider) http.Handler { @@ -79,7 +77,7 @@ func (s *MiddlewareTestSuite) SetupTest() { // integration test: makes HTTP request & asserts HTTP response func (s *MiddlewareTestSuite) TestReturn400WhenInvalidTokenAuthWithInvalidSignature() { - server := s.CreateServer(auth.NewProvider(s.connection)) + server := s.CreateServer(auth.NewProvider(s.db)) defer server.Close() client := server.Client() badT := "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCIsImtpZCI6ImlUcVhYSTB6YkFuSkNLRGFvYmZoa00xZi02ck1TcFRmeVpNUnBfMnRLSTgifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.cJOP_w-hBqnyTsBm3T6lOE5WpcHaAkLuQGAs1QO-lg2eWs8yyGW8p9WagGjxgvx7h9X72H7pXmXqej3GdlVbFmhuzj45A9SXDOAHZ7bJXwM1VidcPi7ZcrsMSCtP1hiN" @@ -100,7 +98,7 @@ func (s *MiddlewareTestSuite) TestReturn400WhenInvalidTokenAuthWithInvalidSignat // integration test: makes HTTP request & asserts HTTP response func (s *MiddlewareTestSuite) TestReturn401WhenExpiredToken() { - server := s.CreateServer(auth.NewProvider(s.connection)) + server := s.CreateServer(auth.NewProvider(s.db)) defer server.Close() client := server.Client() expiredToken := jwt.NewWithClaims(jwt.SigningMethodRS512, &auth.CommonClaims{ @@ -330,7 +328,7 @@ func (s *MiddlewareTestSuite) TestAuthMiddlewareReturn401WhenNonEntityNotFoundEr // integration test: makes HTTP request & asserts HTTP response func (s *MiddlewareTestSuite) TestAuthMiddlewareReturnResponse401WhenNoBearerTokenSupplied() { - server := s.CreateServer(auth.NewProvider(s.connection)) + server := s.CreateServer(auth.NewProvider(s.db)) defer server.Close() client := server.Client() @@ -357,7 +355,7 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenMismatchingDa Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) id, err := safecast.ToInt(j.ID) if err != nil { log.Fatal(err) @@ -375,9 +373,9 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenMismatchingDa {"Mismatching ACOID", jobID, uuid.New(), http.StatusUnauthorized}, } - p := auth.NewProvider(s.connection) + p := auth.NewProvider(s.db) am := auth.NewAuthMiddleware(p) - handler := am.RequireTokenJobMatch(s.connection)(mockHandler) + handler := am.RequireTokenJobMatch(s.db)(mockHandler) server := s.CreateServer(p) defer server.Close() @@ -411,16 +409,16 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn200WhenCorrectAccoun RequestURL: constants.V1Path + constants.EOBExportPath, Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) id, err := safecast.ToInt(j.ID) if err != nil { log.Fatal(err) } jobID := strconv.Itoa(id) - p := auth.NewProvider(s.connection) + p := auth.NewProvider(s.db) am := 
auth.NewAuthMiddleware(p) - handler := am.RequireTokenJobMatch(s.connection)(mockHandler) + handler := am.RequireTokenJobMatch(s.db)(mockHandler) server := s.CreateServer(p) defer server.Close() @@ -451,16 +449,16 @@ func (s *MiddlewareTestSuite) TestRequireTokenJobMatchReturn404WhenNoAuthDataPro Status: models.JobStatusFailed, } - postgrestest.CreateJobs(s.T(), s.connection, &j) + postgrestest.CreateJobs(s.T(), s.db, &j) id, err := safecast.ToInt(j.ID) if err != nil { log.Fatal(err) } jobID := strconv.Itoa(id) - p := auth.NewProvider(s.connection) + p := auth.NewProvider(s.db) am := auth.NewAuthMiddleware(p) - handler := am.RequireTokenJobMatch(s.connection)(mockHandler) + handler := am.RequireTokenJobMatch(s.db)(mockHandler) server := s.CreateServer(p) defer server.Close() diff --git a/bcda/auth/router_test.go b/bcda/auth/router_test.go index 3c1d2c834..f136207c2 100644 --- a/bcda/auth/router_test.go +++ b/bcda/auth/router_test.go @@ -23,7 +23,7 @@ type AuthRouterTestSuite struct { func (s *AuthRouterTestSuite) SetupTest() { conf.SetEnv(s.T(), "DEBUG", "true") - s.provider = NewProvider(database.GetConnection()) + s.provider = NewProvider(database.Connect()) s.authRouter = NewAuthRouter(s.provider) } diff --git a/bcda/auth/ssas_middleware_test.go b/bcda/auth/ssas_middleware_test.go index a0f5c4861..36ac865ea 100644 --- a/bcda/auth/ssas_middleware_test.go +++ b/bcda/auth/ssas_middleware_test.go @@ -38,7 +38,7 @@ type SSASMiddlewareTestSuite struct { func (s *SSASMiddlewareTestSuite) createRouter() http.Handler { router := chi.NewRouter() - am := auth.NewAuthMiddleware(auth.NewProvider(database.GetConnection())) + am := auth.NewAuthMiddleware(auth.NewProvider(database.Connect())) router.Use(am.ParseToken) router.With(auth.RequireTokenAuth).Get("/v1/", func(w http.ResponseWriter, r *http.Request) { ad := r.Context().Value(auth.AuthDataContextKey).(auth.AuthData) diff --git a/bcda/auth/ssas_test.go b/bcda/auth/ssas_test.go index 650c24ed9..a83ff0d25 100644 --- a/bcda/auth/ssas_test.go +++ b/bcda/auth/ssas_test.go @@ -71,7 +71,7 @@ func (s *SSASPluginTestSuite) SetupSuite() { origSSASClientID = conf.GetEnv("BCDA_SSAS_CLIENT_ID") origSSASSecret = conf.GetEnv("BCDA_SSAS_SECRET") - s.db = database.GetConnection() + s.db = database.Connect() s.r = postgres.NewRepository(s.db) } diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index 94beceb9d..c8ea79e5a 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -48,10 +48,10 @@ const Name = "bcda" const Usage = "Beneficiary Claims Data API CLI" var ( - connection *sql.DB - pool *pgxv5Pool.Pool - r models.Repository - provider auth.Provider + db *sql.DB + pool *pgxv5Pool.Pool + r models.Repository + provider auth.Provider ) func GetApp() *cli.App { @@ -64,10 +64,10 @@ func setUpApp() *cli.App { app.Usage = Usage app.Version = constants.Version app.Before = func(c *cli.Context) error { - connection = database.GetConnection() - pool = database.GetPool() - r = postgres.NewRepository(connection) - provider = auth.NewProvider(connection) + db = database.Connect() + pool = database.ConnectPool() + r = postgres.NewRepository(db) + provider = auth.NewProvider(db) log.API.Info(fmt.Sprintf(`Auth is made possible by %T`, provider)) return nil } @@ -128,7 +128,7 @@ func setUpApp() *cli.App { } api := &http.Server{ - Handler: web.NewAPIRouter(connection, pool, provider), + Handler: web.NewAPIRouter(db, pool, provider), ReadTimeout: time.Duration(utils.GetEnvInt("API_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: 
time.Duration(utils.GetEnvInt("API_WRITE_TIMEOUT", 20)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("API_IDLE_TIMEOUT", 120)) * time.Second, @@ -136,7 +136,7 @@ func setUpApp() *cli.App { } fileserver := &http.Server{ - Handler: web.NewDataRouter(connection, provider), + Handler: web.NewDataRouter(db, provider), ReadTimeout: time.Duration(utils.GetEnvInt("FILESERVER_READ_TIMEOUT", 10)) * time.Second, WriteTimeout: time.Duration(utils.GetEnvInt("FILESERVER_WRITE_TIMEOUT", 360)) * time.Second, IdleTimeout: time.Duration(utils.GetEnvInt("FILESERVER_IDLE_TIMEOUT", 120)) * time.Second, @@ -328,7 +328,7 @@ func setUpApp() *cli.App { } } - importer := cclf.NewCclfImporter(log.API, file_processor, connection) + importer := cclf.NewCclfImporter(log.API, file_processor, db) success, failure, skipped, err := importer.ImportCCLFDirectory(filePath) if err != nil { @@ -395,7 +395,7 @@ func setUpApp() *cli.App { }, Action: func(c *cli.Context) error { ignoreSignals() - r := postgres.NewRepository(connection) + r := postgres.NewRepository(db) var file_handler optout.OptOutFileHandler @@ -457,7 +457,7 @@ func setUpApp() *cli.App { return errors.New("Unsupported file type.") } } - err := cclfUtils.ImportCCLFPackage(connection, acoSize, environment, ft) + err := cclfUtils.ImportCCLFPackage(db, acoSize, environment, ft) return err }, }, diff --git a/bcda/bcdacli/cli_test.go b/bcda/bcdacli/cli_test.go index c6d83dc05..33d451726 100644 --- a/bcda/bcdacli/cli_test.go +++ b/bcda/bcdacli/cli_test.go @@ -69,8 +69,8 @@ func (s *CLITestSuite) SetupSuite() { s.pendingDeletionDir = dir testUtils.SetPendingDeletionDir(&s.Suite, dir) - s.db = database.GetConnection() - connection = s.db + s.db = database.Connect() + db = s.db r = postgres.NewRepository(s.db) cmsID := testUtils.RandomHexID()[0:4] diff --git a/bcda/cclf/cclf_test.go b/bcda/cclf/cclf_test.go index 457b269e0..b30c53450 100644 --- a/bcda/cclf/cclf_test.go +++ b/bcda/cclf/cclf_test.go @@ -73,7 +73,7 @@ func (s *CCLFTestSuite) SetupSuite() { s.pendingDeletionDir = dir testUtils.SetPendingDeletionDir(&s.Suite, dir) - s.db = database.GetConnection() + s.db = database.Connect() } func (s *CCLFTestSuite) TearDownSuite() { diff --git a/bcda/cclf/utils/cclfUtils.go b/bcda/cclf/utils/cclfUtils.go index c72fd8dd6..ff91aea9d 100644 --- a/bcda/cclf/utils/cclfUtils.go +++ b/bcda/cclf/utils/cclfUtils.go @@ -24,7 +24,7 @@ import ( // ImportCCLFPackage will copy the appropriate synthetic CCLF files, rename them, // begin the import of those files and delete them from the place they were copied to after successful import. 
-func ImportCCLFPackage(connection *sql.DB, acoSize, environment string, fileType models.CCLFFileType) (err error) { +func ImportCCLFPackage(db *sql.DB, acoSize, environment string, fileType models.CCLFFileType) (err error) { dir, err := os.MkdirTemp("", "*") if err != nil { @@ -149,7 +149,7 @@ func ImportCCLFPackage(connection *sql.DB, acoSize, environment string, fileType }, } - importer := cclf.NewCclfImporter(log.API, file_processor, connection) + importer := cclf.NewCclfImporter(log.API, file_processor, db) success, failure, skipped, err := importer.ImportCCLFDirectory(dir) if err != nil { diff --git a/bcda/cclf/utils/cclfUtils_test.go b/bcda/cclf/utils/cclfUtils_test.go index cc4bae1c7..7df4413b5 100644 --- a/bcda/cclf/utils/cclfUtils_test.go +++ b/bcda/cclf/utils/cclfUtils_test.go @@ -25,7 +25,7 @@ var origDate string func (s *CCLFUtilTestSuite) SetupSuite() { origDate = conf.GetEnv("CCLF_REF_DATE") - s.db = database.GetConnection() + s.db = database.Connect() } func (s *CCLFUtilTestSuite) SetupTest() { diff --git a/bcda/database/connection.go b/bcda/database/connection.go index 408a11098..298f2702a 100644 --- a/bcda/database/connection.go +++ b/bcda/database/connection.go @@ -17,7 +17,7 @@ import ( "github.com/sirupsen/logrus" ) -func GetConnection() *sql.DB { +func Connect() *sql.DB { cfg, err := LoadConfig() if err != nil { logrus.Fatalf("Failed to load database config %s", err.Error()) @@ -40,7 +40,7 @@ func GetConnection() *sql.DB { return conn } -func GetPool() *pgxv5Pool.Pool { +func ConnectPool() *pgxv5Pool.Pool { cfg, err := LoadConfig() if err != nil { logrus.Fatalf("Failed to load database config %s", err.Error()) diff --git a/bcda/database/connection_test.go b/bcda/database/connection_test.go index a4862acef..796ed30ef 100644 --- a/bcda/database/connection_test.go +++ b/bcda/database/connection_test.go @@ -13,7 +13,7 @@ import ( func TestConnections(t *testing.T) { // Verify that we can initialize the package as expected - c := GetConnection() + c := Connect() assert.NotNil(t, c) assert.NoError(t, c.Ping()) } @@ -33,7 +33,7 @@ func TestConnectionHealthCheck(t *testing.T) { hook := test.NewGlobal() ctx, cancel := context.WithCancel(context.Background()) - c := GetConnection() + c := Connect() startConnectionHealthCheck(ctx, c, 100*time.Microsecond) // Let some time elapse to ensure we've successfully ran health checks time.Sleep(50 * time.Millisecond) diff --git a/bcda/database/database_test.go b/bcda/database/database_test.go index 93c3caa14..2a2326cbe 100644 --- a/bcda/database/database_test.go +++ b/bcda/database/database_test.go @@ -9,7 +9,7 @@ import ( ) func TestDBOperations(t *testing.T) { - c := GetConnection() + c := Connect() var q Queryable = &DB{c} var e Executable = &DB{c} rows, err := q.QueryContext(context.Background(), constants.TestSelectNowSQL) @@ -32,7 +32,7 @@ func TestDBOperations(t *testing.T) { } func TestTxOperations(t *testing.T) { - c := GetConnection() + c := Connect() tx, err := c.Begin() assert.NoError(t, err) defer func() { diff --git a/bcda/database/databasetest/databasetest.go b/bcda/database/databasetest/databasetest.go index f65f1ab72..99d0e2f4e 100644 --- a/bcda/database/databasetest/databasetest.go +++ b/bcda/database/databasetest/databasetest.go @@ -23,7 +23,7 @@ func CreateDatabase(t *testing.T, migrationPath string, cleanup bool) (*sql.DB, cfg, err := database.LoadConfig() assert.NoError(t, err) dsn := cfg.DatabaseURL - db := database.GetConnection() + db := database.Connect() newDBName := strings.ReplaceAll(fmt.Sprintf("%s_%s", 
dbName(dsn), uuid.New()), "-", "_") newDSN := dsnPattern.ReplaceAllString(dsn, fmt.Sprintf("${conn}%s${options}", newDBName)) diff --git a/bcda/database/databasetest/databasetest_test.go b/bcda/database/databasetest/databasetest_test.go index 9b500cf0a..43767b424 100644 --- a/bcda/database/databasetest/databasetest_test.go +++ b/bcda/database/databasetest/databasetest_test.go @@ -28,7 +28,7 @@ func TestCreateDatabase(t *testing.T) { assert.NoError(t, db.Close()) }) - db := database.GetConnection() + db := database.Connect() var count int assert.NoError(t, diff --git a/bcda/database/pgx_test.go b/bcda/database/pgx_test.go index 2e0543321..70b13cc6d 100644 --- a/bcda/database/pgx_test.go +++ b/bcda/database/pgx_test.go @@ -11,7 +11,7 @@ import ( ) func TestPgxTxOperations(t *testing.T) { - conn, err := stdlib.AcquireConn(GetConnection()) + conn, err := stdlib.AcquireConn(Connect()) assert.NoError(t, err) defer func() { assert.NoError(t, conn.Close()) diff --git a/bcda/health/health.go b/bcda/health/health.go index bb1a1a22a..deb79241e 100644 --- a/bcda/health/health.go +++ b/bcda/health/health.go @@ -13,8 +13,8 @@ type HealthChecker struct { db *sql.DB } -func NewHealthChecker(connection *sql.DB) HealthChecker { - return HealthChecker{db: connection} +func NewHealthChecker(db *sql.DB) HealthChecker { + return HealthChecker{db: db} } func (h HealthChecker) IsDatabaseOK() (result string, ok bool) { diff --git a/bcda/lambda/admin_create_aco_creds/main.go b/bcda/lambda/admin_create_aco_creds/main.go index edc79339c..61e94622a 100644 --- a/bcda/lambda/admin_create_aco_creds/main.go +++ b/bcda/lambda/admin_create_aco_creds/main.go @@ -75,7 +75,7 @@ func handler(ctx context.Context, event json.RawMessage) (string, error) { return "", err } - provider := auth.NewProvider(database.GetConnection()) + provider := auth.NewProvider(database.Connect()) s3Service := s3.New(session) slackClient := slack.New(params.slackToken) diff --git a/bcda/lambda/admin_create_group/main.go b/bcda/lambda/admin_create_group/main.go index 3d2acec8c..b4fa05060 100644 --- a/bcda/lambda/admin_create_group/main.go +++ b/bcda/lambda/admin_create_group/main.go @@ -61,7 +61,7 @@ func handler(ctx context.Context, event json.RawMessage) error { } slackClient := slack.New(slackToken) - db := database.GetConnection() + db := database.Connect() r := postgres.NewRepository(db) ssas, err := client.NewSSASClient() diff --git a/bcda/lambda/cclf/main.go b/bcda/lambda/cclf/main.go index aaf2518d8..711644d42 100644 --- a/bcda/lambda/cclf/main.go +++ b/bcda/lambda/cclf/main.go @@ -25,7 +25,7 @@ func main() { // Localstack is a local-development server that mimics AWS. The endpoint variable // should only be set in local development to avoid making external calls to a real AWS account. 
if os.Getenv("LOCAL_STACK_ENDPOINT") != "" { - res, err := handleCclfImport(database.GetConnection(), os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) + res, err := handleCclfImport(database.Connect(), os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) if err != nil { fmt.Printf("Failed to run opt out import: %s\n", err.Error()) } else { @@ -40,7 +40,7 @@ func attributionImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (st env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) - connection := database.GetConnection() + db := database.Connect() s3Event, err := bcdaaws.ParseSQSEvent(sqsEvent) @@ -68,9 +68,9 @@ func attributionImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (st filepath := fmt.Sprintf("%s/%s", e.S3.Bucket.Name, e.S3.Object.Key) logger.Infof("Reading %s event for file %s", e.EventName, filepath) if cclf.CheckIfAttributionCSVFile(e.S3.Object.Key) { - return handleCSVImport(connection, s3AssumeRoleArn, filepath) + return handleCSVImport(db, s3AssumeRoleArn, filepath) } else { - return handleCclfImport(connection, s3AssumeRoleArn, filepath) + return handleCclfImport(db, s3AssumeRoleArn, filepath) } } } @@ -79,7 +79,7 @@ func attributionImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (st return "", nil } -func handleCSVImport(connection *sql.DB, s3AssumeRoleArn, s3ImportPath string) (string, error) { +func handleCSVImport(db *sql.DB, s3AssumeRoleArn, s3ImportPath string) (string, error) { env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) @@ -87,7 +87,7 @@ func handleCSVImport(connection *sql.DB, s3AssumeRoleArn, s3ImportPath string) ( importer := cclf.CSVImporter{ Logger: logger, - Database: connection, + Database: db, FileProcessor: &cclf.S3FileProcessor{ Handler: optout.S3FileHandler{ Logger: logger, @@ -130,7 +130,7 @@ func loadBCDAParams() error { return nil } -func handleCclfImport(connection *sql.DB, s3AssumeRoleArn, s3ImportPath string) (string, error) { +func handleCclfImport(db *sql.DB, s3AssumeRoleArn, s3ImportPath string) (string, error) { env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) @@ -144,7 +144,7 @@ func handleCclfImport(connection *sql.DB, s3AssumeRoleArn, s3ImportPath string) }, } - importer := cclf.NewCclfImporter(logger, &fileProcessor, connection) + importer := cclf.NewCclfImporter(logger, &fileProcessor, db) success, failure, skipped, err := importer.ImportCCLFDirectory(s3ImportPath) diff --git a/bcda/lambda/cclf/main_test.go b/bcda/lambda/cclf/main_test.go index 8a1e7de88..04426d211 100644 --- a/bcda/lambda/cclf/main_test.go +++ b/bcda/lambda/cclf/main_test.go @@ -21,7 +21,7 @@ type AttributionImportMainSuite struct { } func (s *AttributionImportMainSuite) SetupSuite() { - s.db = database.GetConnection() + s.db = database.Connect() } func TestAttributionImportMainSuite(t *testing.T) { suite.Run(t, new(AttributionImportMainSuite)) diff --git a/bcda/lambda/optout/main.go b/bcda/lambda/optout/main.go index 8b39acbfe..473b91782 100644 --- a/bcda/lambda/optout/main.go +++ b/bcda/lambda/optout/main.go @@ -26,7 +26,7 @@ func main() { // Localstack is a local-development server that mimics AWS. The endpoint variable // should only be set in local development to avoid making external calls to a real AWS account. 
if os.Getenv("LOCAL_STACK_ENDPOINT") != "" { - res, err := handleOptOutImport(database.GetConnection(), os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) + res, err := handleOptOutImport(database.Connect(), os.Getenv("BFD_BUCKET_ROLE_ARN"), os.Getenv("BFD_S3_IMPORT_PATH")) if err != nil { fmt.Printf("Failed to run opt out import: %s\n", err.Error()) } else { @@ -41,7 +41,7 @@ func optOutImportHandler(ctx context.Context, sqsEvent events.SQSEvent) (string, env := conf.GetEnv("ENV") appName := conf.GetEnv("APP_NAME") logger := configureLogger(env, appName) - db := database.GetConnection() + db := database.Connect() s3Event, err := bcdaaws.ParseSQSEvent(sqsEvent) diff --git a/bcda/lambda/optout/main_test.go b/bcda/lambda/optout/main_test.go index 860e097e5..698bab7f4 100644 --- a/bcda/lambda/optout/main_test.go +++ b/bcda/lambda/optout/main_test.go @@ -21,7 +21,7 @@ type OptOutImportMainSuite struct { } func (s *OptOutImportMainSuite) SetupSuite() { - s.db = database.GetConnection() + s.db = database.Connect() } func TestOptOutImportMainSuite(t *testing.T) { diff --git a/bcda/models/postgres/repository_test.go b/bcda/models/postgres/repository_test.go index 24e0b5d22..2c4f101f7 100644 --- a/bcda/models/postgres/repository_test.go +++ b/bcda/models/postgres/repository_test.go @@ -40,7 +40,7 @@ func TestRepositoryTestSuite(t *testing.T) { } func (r *RepositoryTestSuite) SetupSuite() { - r.db = database.GetConnection() + r.db = database.Connect() r.repository = postgres.NewRepository(r.db) } diff --git a/bcda/service/service_test.go b/bcda/service/service_test.go index 35274318c..411259737 100644 --- a/bcda/service/service_test.go +++ b/bcda/service/service_test.go @@ -492,7 +492,7 @@ func (s *ServiceTestSuite) TestGetNewAndExistingBeneficiaries_Integration() { // - Diff between CCLF File 1 and CCLF File 2 // - No diff - consider all beneficiaries as pre-existing func (s *ServiceTestSuite) TestGetNewAndExistingBeneficiaries_RecentSinceParameter_Integration() { - db := database.GetConnection() + db := database.Connect() acoID := "A0005" // Test Setup @@ -1657,7 +1657,7 @@ func (s *ServiceTestSuiteWithDatabase) TestGetBenesByID_Integration() { } func (s *ServiceTestSuiteWithDatabase) TestGetNewAndExistingBeneficiaries_RecentSinceParameterDatabase_Integration() { - db := database.GetConnection() + db := database.Connect() acoID := "A0005" // Test Setup diff --git a/bcda/suppression/suppression_s3_test.go b/bcda/suppression/suppression_s3_test.go index 60391d9b0..06cbd41e3 100644 --- a/bcda/suppression/suppression_s3_test.go +++ b/bcda/suppression/suppression_s3_test.go @@ -292,7 +292,7 @@ func (s *SuppressionS3TestSuite) TestCleanupSuppression() { func (s *SuppressionS3TestSuite) TestImportSuppressionDirectoryTable() { assert := assert.New(s.T()) importer, _ := s.createImporter() - db := database.GetConnection() + db := database.Connect() importer.Saver = &BCDASaver{ Repo: postgres.NewRepository(db), diff --git a/bcda/suppression/suppression_test.go b/bcda/suppression/suppression_test.go index 69ab3b4b0..faa7b9195 100644 --- a/bcda/suppression/suppression_test.go +++ b/bcda/suppression/suppression_test.go @@ -283,7 +283,7 @@ func (s *SuppressionTestSuite) TestLoadOptOutFiles_TimeChange() { assert := assert.New(s.T()) importer, _ := s.createImporter() importer.Saver = &BCDASaver{ - Repo: postgres.NewRepository(database.GetConnection()), + Repo: postgres.NewRepository(database.Connect()), } folderPath := filepath.Join(s.basePath, "suppressionfile_BadFileNames/") @@ -447,7 +447,7 @@ 
func (s *SuppressionTestSuite) TestCleanupSuppression_RenameFileError() { func (s *SuppressionTestSuite) TestImportSuppressionDirectoryTable() { assert := assert.New(s.T()) importer, _ := s.createImporter() - db := database.GetConnection() + db := database.Connect() importer.Saver = &BCDASaver{ Repo: postgres.NewRepository(db), diff --git a/bcda/web/middleware/ratelimit_test.go b/bcda/web/middleware/ratelimit_test.go index 1d7e2a538..b6c4d3290 100644 --- a/bcda/web/middleware/ratelimit_test.go +++ b/bcda/web/middleware/ratelimit_test.go @@ -31,7 +31,7 @@ func TestRateLimitMiddlewareTestSuite(t *testing.T) { suite.Run(t, new(RateLimitMiddlewareTestSuite)) } func (s *RateLimitMiddlewareTestSuite) SetupSuite() { - s.db = database.GetConnection() + s.db = database.Connect() } func (s *RateLimitMiddlewareTestSuite) TestNoConcurrentJobs() { cfg := &service.Config{RateLimitConfig: service.RateLimitConfig{All: true}} diff --git a/bcda/web/router.go b/bcda/web/router.go index 404ce0b71..4411e43e9 100644 --- a/bcda/web/router.go +++ b/bcda/web/router.go @@ -29,7 +29,7 @@ var commonAuth = []func(http.Handler) http.Handler{ auth.RequireTokenAuth, auth.CheckBlacklist} -func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) http.Handler { +func NewAPIRouter(db *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provider) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() am := auth.NewAuthMiddleware(provider) @@ -43,7 +43,7 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provid panic(fmt.Errorf("could not load service config file: %w", err)) } - rlm := middleware.NewRateLimitMiddleware(cfg, connection) + rlm := middleware.NewRateLimitMiddleware(cfg, db) var requestValidators = []func(http.Handler) http.Handler{ middleware.ACOEnabled(cfg), middleware.ValidateRequestURL, middleware.ValidateRequestHeaders, rlm.CheckConcurrentJobs, } @@ -55,39 +55,39 @@ func NewAPIRouter(connection *sql.DB, pool *pgxv5Pool.Pool, provider auth.Provid r.Get("/", userGuideRedirect) r.Get(`/{:(user_guide|encryption|decryption_walkthrough).html}`, userGuideRedirect) } - apiV1 := v1.NewApiV1(connection, pool, provider) + apiV1 := v1.NewApiV1(db, pool, provider) r.Route("/api/v1", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV1.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV1.BulkGroupRequest)) - r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) + r.With(append(commonAuth, am.RequireTokenJobMatch(db))...).Get(m.WrapHandler(constants.JOBIDPath, apiV1.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV1.JobsStatus)) - r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) + r.With(append(commonAuth, am.RequireTokenJobMatch(db))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV1.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV1.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV1.Metadata)) }) if utils.GetEnvBool("VERSION_2_ENDPOINT_ACTIVE", true) { FileServer(r, "/api/v2/swagger", http.Dir("./swaggerui/v2")) - apiV2 := v2.NewApiV2(connection, pool) + apiV2 := v2.NewApiV2(db, pool) r.Route("/api/v2", func(r chi.Router) { r.With(append(commonAuth, 
requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV2.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV2.BulkGroupRequest)) - r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV2.JobStatus)) + r.With(append(commonAuth, am.RequireTokenJobMatch(db))...).Get(m.WrapHandler(constants.JOBIDPath, apiV2.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV2.JobsStatus)) - r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV2.DeleteJob)) + r.With(append(commonAuth, am.RequireTokenJobMatch(db))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV2.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV2.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV2.Metadata)) }) } if utils.GetEnvBool("VERSION_3_ENDPOINT_ACTIVE", true) { - apiV3 := v3.NewApiV3(connection, pool) + apiV3 := v3.NewApiV3(db, pool) r.Route("/api/demo", func(r chi.Router) { r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Patient/$export", apiV3.BulkPatientRequest)) r.With(append(commonAuth, requestValidators...)...).Get(m.WrapHandler("/Group/{groupId}/$export", apiV3.BulkGroupRequest)) - r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Get(m.WrapHandler(constants.JOBIDPath, apiV3.JobStatus)) + r.With(append(commonAuth, am.RequireTokenJobMatch(db))...).Get(m.WrapHandler(constants.JOBIDPath, apiV3.JobStatus)) r.With(append(commonAuth, nonExportRequestValidators...)...).Get(m.WrapHandler("/jobs", apiV3.JobsStatus)) - r.With(append(commonAuth, am.RequireTokenJobMatch(connection))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV3.DeleteJob)) + r.With(append(commonAuth, am.RequireTokenJobMatch(db))...).Delete(m.WrapHandler(constants.JOBIDPath, apiV3.DeleteJob)) r.With(commonAuth...).Get(m.WrapHandler("/attribution_status", apiV3.AttributionStatus)) r.Get(m.WrapHandler("/metadata", apiV3.Metadata)) }) @@ -103,17 +103,17 @@ func NewAuthRouter(provider auth.Provider) http.Handler { return auth.NewAuthRouter(provider, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) } -func NewDataRouter(connection *sql.DB, provider auth.Provider) http.Handler { +func NewDataRouter(db *sql.DB, provider auth.Provider) http.Handler { r := chi.NewRouter() m := monitoring.GetMonitor() am := auth.NewAuthMiddleware(provider) resourceTypeLogger := &logging.ResourceTypeLogger{ - Repository: postgres.NewRepository(connection), + Repository: postgres.NewRepository(db), } r.Use(am.ParseToken, gcmw.RequestID, appMiddleware.NewTransactionID, logging.NewStructuredLogger(), middleware.SecurityHeader, middleware.ConnectionClose, logging.NewCtxLogger) r.With(append( commonAuth, - am.RequireTokenJobMatch(connection), + am.RequireTokenJobMatch(db), resourceTypeLogger.LogJobResourceType, )...).Get(m.WrapHandler("/data/{jobID}/{fileName}", v1.ServeData)) return r diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index 27f9da7b2..e55a3cf0f 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -34,17 +34,17 @@ type RouterTestSuite struct { apiRouter http.Handler dataRouter http.Handler provider auth.Provider - connection *sql.DB + db *sql.DB pool *pgxv5Pool.Pool } func (s *RouterTestSuite) SetupTest() { conf.SetEnv(s.T(), 
"DEBUG", "true") conf.SetEnv(s.T(), "BB_SERVER_LOCATION", "v1-server-location") - s.connection = database.GetConnection() - s.provider = auth.NewProvider(s.connection) - s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) - s.dataRouter = NewDataRouter(s.connection, s.provider) + s.db = database.Connect() + s.provider = auth.NewProvider(s.db) + s.apiRouter = NewAPIRouter(s.db, s.pool, s.provider) + s.dataRouter = NewDataRouter(s.db, s.provider) } func (s *RouterTestSuite) getAPIRoute(route string) *http.Response { @@ -84,7 +84,7 @@ func (s *RouterTestSuite) TestDefaultProdRoute() { s.FailNow("err in setting env var", err) } // Need a new router because the one in the test setup does not use the environment variable set in this test. - s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) + s.apiRouter = NewAPIRouter(s.db, s.pool, s.provider) res := s.getAPIRoute("/v1/") assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -201,7 +201,7 @@ func (s *RouterTestSuite) TestV2EndpointsDisabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) + s.apiRouter = NewAPIRouter(s.db, s.pool, s.provider) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -218,7 +218,7 @@ func (s *RouterTestSuite) TestV2EndpointsEnabled() { v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) + s.apiRouter = NewAPIRouter(s.db, s.pool, s.provider) res := s.getAPIRoute(constants.V2Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -239,7 +239,7 @@ func (s *RouterTestSuite) TestV3EndpointsDisabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "false") - s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) + s.apiRouter = NewAPIRouter(s.db, s.pool, s.provider) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusNotFound, res.StatusCode) @@ -256,7 +256,7 @@ func (s *RouterTestSuite) TestV3EndpointsEnabled() { v3Active := conf.GetEnv("VERSION_3_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", v3Active) conf.SetEnv(s.T(), "VERSION_3_ENDPOINT_ACTIVE", "true") - s.apiRouter = NewAPIRouter(s.connection, s.pool, s.provider) + s.apiRouter = NewAPIRouter(s.db, s.pool, s.provider) res := s.getAPIRoute(constants.V3Path + constants.PatientExportPath) assert.Equal(s.T(), http.StatusUnauthorized, res.StatusCode) @@ -356,7 +356,7 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite, p auth.Provide handler http.Handler paths []string }) { - apiRouter := NewAPIRouter(s.connection, s.pool, p) + apiRouter := NewAPIRouter(s.db, s.pool, p) configs = []struct { handler http.Handler @@ -365,7 +365,7 @@ func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite, p auth.Provide {apiRouter, []string{"/api/v1/Patient/$export", "/api/v1/Group/all/$export", constants.V2Path + constants.PatientExportPath, constants.V2Path + constants.GroupExportPath, constants.V1Path + constants.JobsFilePath}}, - 
{NewDataRouter(s.connection, p), []string{nDJsonDataRoute}}, + {NewDataRouter(s.db, p), []string{nDJsonDataRoute}}, {NewAuthRouter(p), []string{"/auth/welcome"}}, } @@ -403,7 +403,7 @@ func (s *RouterTestSuite) TestBlacklistedACOReturn403WhenACOBlacklisted() { mock := &auth.MockProvider{} setExpectedMockCalls(s, mock, token, aco, bearerString, cmsID) - db := s.connection + db := s.db postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) @@ -448,7 +448,7 @@ func (s *RouterTestSuite) TestBlacklistedACOReturnNOT403WhenACONOTBlacklisted() mock := &auth.MockProvider{} setExpectedMockCalls(s, mock, token, aco, bearerString, cmsID) - db := s.connection + db := s.db postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) diff --git a/bcdaworker/cleanup/cleanup_test.go b/bcdaworker/cleanup/cleanup_test.go index 8999f01ea..b4173c5c8 100644 --- a/bcdaworker/cleanup/cleanup_test.go +++ b/bcdaworker/cleanup/cleanup_test.go @@ -39,7 +39,7 @@ func (s *CleanupTestSuite) SetupSuite() { s.pendingDeletionDir = dir testUtils.SetPendingDeletionDir(&s.Suite, dir) - s.db = database.GetConnection() + s.db = database.Connect() cmsID := testUtils.RandomHexID()[0:4] s.testACO = models.ACO{Name: uuid.New(), UUID: uuid.NewRandom(), ClientID: uuid.New(), CMSID: &cmsID} diff --git a/bcdaworker/main.go b/bcdaworker/main.go index fc357cd17..9f37b26c0 100644 --- a/bcdaworker/main.go +++ b/bcdaworker/main.go @@ -97,7 +97,7 @@ func waitForSig() { func main() { fmt.Println("Starting bcdaworker...") - db := database.GetConnection() + db := database.Connect() healthChecker := health.NewHealthChecker(db) queue := queueing.StartRiver(db, utils.GetEnvInt("WORKER_POOL_SIZE", 4)) defer queue.StopRiver() diff --git a/bcdaworker/queueing/enqueue.go b/bcdaworker/queueing/enqueue.go index c70607299..36d8ecf72 100644 --- a/bcdaworker/queueing/enqueue.go +++ b/bcdaworker/queueing/enqueue.go @@ -25,10 +25,10 @@ type Enqueuer interface { // Creates a river client for the Job queue. If the client does not call .Start(), then it is insert only // We still need the workers and the types of workers to insert them -func NewEnqueuer(connection *sql.DB, pool *pgxv5Pool.Pool) Enqueuer { +func NewEnqueuer(db *sql.DB, pool *pgxv5Pool.Pool) Enqueuer { workers := river.NewWorkers() - river.AddWorker(workers, &JobWorker{connection: connection}) - prepareWorker, err := NewPrepareJobWorker(connection) + river.AddWorker(workers, &JobWorker{db: db}) + prepareWorker, err := NewPrepareJobWorker(db) if err != nil { panic(err) } diff --git a/bcdaworker/queueing/enqueue_test.go b/bcdaworker/queueing/enqueue_test.go index cf0146d40..171d2e811 100644 --- a/bcdaworker/queueing/enqueue_test.go +++ b/bcdaworker/queueing/enqueue_test.go @@ -33,8 +33,8 @@ func TestRiverEnqueuer_Integration(t *testing.T) { conf.SetEnv(t, "QUEUE_LIBRARY", "river") // Need access to the queue database to ensure we've enqueued the job successfully - db := database.GetConnection() - pool := database.GetPool() + db := database.Connect() + pool := database.ConnectPool() enqueuer := NewEnqueuer(db, pool) jobID, e := rand.Int(rand.Reader, big.NewInt(math.MaxInt32)) diff --git a/bcdaworker/queueing/river.go b/bcdaworker/queueing/river.go index a88bf6fc4..0774d6645 100644 --- a/bcdaworker/queueing/river.go +++ b/bcdaworker/queueing/river.go @@ -49,14 +49,14 @@ type Notifier interface { // TODO: better dependency injection (db, worker, logger). 
Waiting for pgxv5 upgrade func StartRiver(db *sql.DB, numWorkers int) *queue { - pool := database.GetPool() + pool := database.ConnectPool() workers := river.NewWorkers() prepareWorker, err := NewPrepareJobWorker(db) if err != nil { panic(err) } - river.AddWorker(workers, &JobWorker{connection: db}) + river.AddWorker(workers, &JobWorker{db: db}) river.AddWorker(workers, NewCleanupJobWorker(db)) river.AddWorker(workers, prepareWorker) diff --git a/bcdaworker/queueing/river_test.go b/bcdaworker/queueing/river_test.go index 7e365d7fe..cd0832836 100644 --- a/bcdaworker/queueing/river_test.go +++ b/bcdaworker/queueing/river_test.go @@ -67,8 +67,8 @@ func TestWork_Integration(t *testing.T) { conf.SetEnv(t, "FHIR_PAYLOAD_DIR", tempDir1) conf.SetEnv(t, "FHIR_STAGING_DIR", tempDir2) - db := database.GetConnection() - pool := database.GetPool() + db := database.Connect() + pool := database.ConnectPool() cmsID := testUtils.RandomHexID()[0:4] aco := models.ACO{UUID: uuid.NewRandom(), CMSID: &cmsID} @@ -156,7 +156,7 @@ func TestCleanupJobWorker_Work(t *testing.T) { cleanupJobWorker := &CleanupJobWorker{ cleanupJob: mockCleanupJob.CleanupJob, archiveExpiring: mockArchiveExpiring.ArchiveExpiring, - db: database.GetConnection(), + db: database.Connect(), } // Create a mock river.Job @@ -213,7 +213,7 @@ func TestGetAWSParams(t *testing.T) { } func TestNewCleanupJobWorker(t *testing.T) { - worker := NewCleanupJobWorker(database.GetConnection()) + worker := NewCleanupJobWorker(database.Connect()) assert.NotNil(t, worker) assert.NotNil(t, worker.cleanupJob) diff --git a/bcdaworker/queueing/worker_prepare.go b/bcdaworker/queueing/worker_prepare.go index 9ee96540f..1b79eeac3 100644 --- a/bcdaworker/queueing/worker_prepare.go +++ b/bcdaworker/queueing/worker_prepare.go @@ -37,7 +37,7 @@ type PrepareJobWorker struct { r models.Repository } -func NewPrepareJobWorker(connection *sql.DB) (*PrepareJobWorker, error) { +func NewPrepareJobWorker(db *sql.DB) (*PrepareJobWorker, error) { logger := log.Worker client.SetLogger(logger) @@ -50,7 +50,7 @@ func NewPrepareJobWorker(connection *sql.DB) (*PrepareJobWorker, error) { logger.Fatalf("no ACO configs found, these are required for downstream processing") } - repository := postgres.NewRepository(connection) + repository := postgres.NewRepository(db) svc := service.NewService(repository, cfg, "") v1, err := client.NewBlueButtonClient(client.NewConfig(constants.BFDV1Path)) diff --git a/bcdaworker/queueing/worker_prepare_test.go b/bcdaworker/queueing/worker_prepare_test.go index d99d9cd9b..3a200a783 100644 --- a/bcdaworker/queueing/worker_prepare_test.go +++ b/bcdaworker/queueing/worker_prepare_test.go @@ -48,7 +48,7 @@ func TestCleanupTestSuite(t *testing.T) { func (s *PrepareWorkerIntegrationTestSuite) SetupTest() { s.db, _ = databasetest.CreateDatabase(s.T(), "../../db/migrations/bcda/", true) - s.pool = database.GetPool() + s.pool = database.ConnectPool() tf, err := testfixtures.New( testfixtures.Database(s.db), testfixtures.Dialect("postgres"), @@ -326,7 +326,7 @@ func (s *PrepareWorkerIntegrationTestSuite) TestQueueExportJobs() { ms.On("GetJobPriority", mock.Anything, mock.Anything, mock.Anything).Return(int16(1)) worker := &PrepareJobWorker{svc: ms, v1Client: &client.MockBlueButtonClient{}, v2Client: &client.MockBlueButtonClient{}, r: s.r} - q := NewEnqueuer(s.db, database.GetPool()) + q := NewEnqueuer(s.db, database.ConnectPool()) a := &worker_types.JobEnqueueArgs{ ID: 33, } diff --git a/bcdaworker/queueing/worker_process_job.go 
b/bcdaworker/queueing/worker_process_job.go index b3f6d0871..8f933e382 100644 --- a/bcdaworker/queueing/worker_process_job.go +++ b/bcdaworker/queueing/worker_process_job.go @@ -15,7 +15,7 @@ import ( type JobWorker struct { river.WorkerDefaults[worker_types.JobEnqueueArgs] - connection *sql.DB + db *sql.DB } func (w *JobWorker) Work(ctx context.Context, rjob *river.Job[worker_types.JobEnqueueArgs]) error { @@ -32,7 +32,7 @@ func (w *JobWorker) Work(ctx context.Context, rjob *river.Job[worker_types.JobEn ctx, logger := log.SetCtxLogger(ctx, "transaction_id", rjob.Args.TransactionID) // TODO: use pgxv5 when available - mainDB := w.connection + mainDB := w.db workerInstance := worker.NewWorker(mainDB) repo := postgres.NewRepository(mainDB) diff --git a/bcdaworker/repository/postgres/repository_test.go b/bcdaworker/repository/postgres/repository_test.go index 55bb19020..b0a141cf3 100644 --- a/bcdaworker/repository/postgres/repository_test.go +++ b/bcdaworker/repository/postgres/repository_test.go @@ -30,7 +30,7 @@ func TestRepositoryTestSuite(t *testing.T) { } func (r *RepositoryTestSuite) SetupSuite() { - r.db = database.GetConnection() + r.db = database.Connect() r.repository = postgres.NewRepository(r.db) } diff --git a/bcdaworker/worker/worker_test.go b/bcdaworker/worker/worker_test.go index ec8b687e7..d884b8ce2 100644 --- a/bcdaworker/worker/worker_test.go +++ b/bcdaworker/worker/worker_test.go @@ -63,7 +63,7 @@ type WorkerTestSuite struct { } func (s *WorkerTestSuite) SetupSuite() { - s.db = database.GetConnection() + s.db = database.Connect() s.r = postgres.NewRepository(s.db) s.w = NewWorker(s.db) diff --git a/db/migrations/migrations_test.go b/db/migrations/migrations_test.go index a41852c50..f0aded9f4 100644 --- a/db/migrations/migrations_test.go +++ b/db/migrations/migrations_test.go @@ -47,7 +47,7 @@ func (s *MigrationTestSuite) SetupSuite() { // postgres://:@: / re := regexp.MustCompile(`(postgresql\:\/\/\S+\:\S+\@\S+\:\d+\/)(.*)(\?.*)`) - db := database.GetConnection() + db := database.Connect() databaseURL := conf.GetEnv("DATABASE_URL") bcdaDB := fmt.Sprintf("migrate_test_bcda_%d", time.Now().Nanosecond()) From 12372794003a3d5d07a963b0ef3e569b264d1abf Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Tue, 5 Aug 2025 14:49:29 -0400 Subject: [PATCH 22/28] Remove comments --- bcda/auth/middleware_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/bcda/auth/middleware_test.go b/bcda/auth/middleware_test.go index 1d1177c7f..20928a798 100644 --- a/bcda/auth/middleware_test.go +++ b/bcda/auth/middleware_test.go @@ -67,14 +67,9 @@ func (s *MiddlewareTestSuite) CreateServer(p auth.Provider) *httptest.Server { } func (s *MiddlewareTestSuite) SetupTest() { - // s.server = httptest.NewServer(s.CreateRouter(auth.NewProvider(s.connection))) s.rr = httptest.NewRecorder() } -// func (s *MiddlewareTestSuite) TearDownTest() { -// s.server.Close() -// } - // integration test: makes HTTP request & asserts HTTP response func (s *MiddlewareTestSuite) TestReturn400WhenInvalidTokenAuthWithInvalidSignature() { server := s.CreateServer(auth.NewProvider(s.db)) From beb66ca8d69fc4fc0bdfc1b488522474e2e94e9e Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Tue, 5 Aug 2025 18:51:51 -0400 Subject: [PATCH 23/28] Pass repo as argument in cli function --- bcda/bcdacli/cli.go | 4 ++-- bcda/bcdacli/cli_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index c8ea79e5a..cc2e87eac 100644 --- a/bcda/bcdacli/cli.go +++ 
b/bcda/bcdacli/cli.go @@ -272,7 +272,7 @@ func setUpApp() *cli.App { }, }, Action: func(c *cli.Context) error { - msg, err := resetClientCredentials(acoCMSID) + msg, err := resetClientCredentials(r, acoCMSID) if err != nil { return err } @@ -589,7 +589,7 @@ func generateClientCredentials(acoCMSID string, ips []string) (string, error) { return creds, nil } -func resetClientCredentials(acoCMSID string) (string, error) { +func resetClientCredentials(repo models.Repository, acoCMSID string) (string, error) { aco, err := r.GetACOByCMSID(context.Background(), acoCMSID) if err != nil { return "", err diff --git a/bcda/bcdacli/cli_test.go b/bcda/bcdacli/cli_test.go index 33d451726..15d423841 100644 --- a/bcda/bcdacli/cli_test.go +++ b/bcda/bcdacli/cli_test.go @@ -203,12 +203,12 @@ func (s *CLITestSuite) TestResetSecretCLI() { defer s.SetProvider(oldProvider) // execute positive scenario - msg, err := resetClientCredentials(*s.testACO.CMSID) + msg, err := resetClientCredentials(r, *s.testACO.CMSID) assert.Nil(err) assert.Regexp(outputPattern, msg) // Execute with invalid ACO CMS ID - msg, err = resetClientCredentials("BLAH") + msg, err = resetClientCredentials(r, "BLAH") assert.Equal("no ACO record found for BLAH", err.Error()) assert.Equal(0, len(msg)) From a54157f66062dff3930a42957c4a3b9ad3147d81 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Tue, 5 Aug 2025 19:04:01 -0400 Subject: [PATCH 24/28] Pass globals as arguments in cli functions --- bcda/bcdacli/cli.go | 42 ++++++++++++++++++++-------------------- bcda/bcdacli/cli_test.go | 22 ++++++--------------- 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/bcda/bcdacli/cli.go b/bcda/bcdacli/cli.go index cc2e87eac..5ac916a99 100644 --- a/bcda/bcdacli/cli.go +++ b/bcda/bcdacli/cli.go @@ -48,10 +48,10 @@ const Name = "bcda" const Usage = "Beneficiary Claims Data API CLI" var ( - db *sql.DB - pool *pgxv5Pool.Pool - r models.Repository - provider auth.Provider + db *sql.DB + pool *pgxv5Pool.Pool + repository models.Repository + provider auth.Provider ) func GetApp() *cli.App { @@ -66,7 +66,7 @@ func setUpApp() *cli.App { app.Before = func(c *cli.Context) error { db = database.Connect() pool = database.ConnectPool() - r = postgres.NewRepository(db) + repository = postgres.NewRepository(db) provider = auth.NewProvider(db) log.API.Info(fmt.Sprintf(`Auth is made possible by %T`, provider)) return nil @@ -174,7 +174,7 @@ func setUpApp() *cli.App { }, }, Action: func(c *cli.Context) error { - ssasID, err := createGroup(groupID, groupName, acoID) + ssasID, err := createGroup(repository, groupID, groupName, acoID) if err != nil { return err } @@ -199,7 +199,7 @@ func setUpApp() *cli.App { }, }, Action: func(c *cli.Context) error { - acoUUID, err := createACO(acoName, acoCMSID) + acoUUID, err := createACO(repository, acoName, acoCMSID) if err != nil { return err } @@ -220,7 +220,7 @@ func setUpApp() *cli.App { }, }, Action: func(c *cli.Context) error { - err := revokeAccessToken(accessToken) + err := revokeAccessToken(provider, accessToken) if err != nil { return err } @@ -252,7 +252,7 @@ func setUpApp() *cli.App { if len(ips) > 0 { ipAddr = strings.Split(ips, ",") } - msg, err := generateClientCredentials(acoCMSID, ipAddr) + msg, err := generateClientCredentials(provider, acoCMSID, ipAddr) if err != nil { return err } @@ -272,7 +272,7 @@ func setUpApp() *cli.App { }, }, Action: func(c *cli.Context) error { - msg, err := resetClientCredentials(r, acoCMSID) + msg, err := resetClientCredentials(repository, provider, acoCMSID) if err != nil { 
return err } @@ -478,7 +478,7 @@ func setUpApp() *cli.App { CutoffDate: time.Now(), DenylistType: models.Involuntary, } - return setDenylistState(acoCMSID, td) + return setDenylistState(repository, acoCMSID, td) }, }, { @@ -493,14 +493,14 @@ func setUpApp() *cli.App { }, }, Action: func(c *cli.Context) error { - return setDenylistState(acoCMSID, nil) + return setDenylistState(repository, acoCMSID, nil) }, }, } return app } -func createGroup(id, name, acoID string) (string, error) { +func createGroup(r models.Repository, id, name, acoID string) (string, error) { if id == "" || name == "" || acoID == "" { return "", errors.New("ID (--id), name (--name), and ACO ID (--aco-id) are required") } @@ -554,7 +554,7 @@ func createGroup(id, name, acoID string) (string, error) { return ssasID, nil } -func createACO(name, cmsID string) (string, error) { +func createACO(r models.Repository, name, cmsID string) (string, error) { if name == "" { return "", errors.New("ACO name (--name) must be provided") } @@ -579,9 +579,9 @@ func createACO(name, cmsID string) (string, error) { return aco.UUID.String(), nil } -func generateClientCredentials(acoCMSID string, ips []string) (string, error) { +func generateClientCredentials(p auth.Provider, acoCMSID string, ips []string) (string, error) { // The public key is optional for SSAS, and not used by the ACO API - creds, err := provider.FindAndCreateACOCredentials(acoCMSID, ips) + creds, err := p.FindAndCreateACOCredentials(acoCMSID, ips) if err != nil { return "", errors.Wrapf(err, "could not register system for %s", acoCMSID) } @@ -589,29 +589,29 @@ func generateClientCredentials(acoCMSID string, ips []string) (string, error) { return creds, nil } -func resetClientCredentials(repo models.Repository, acoCMSID string) (string, error) { +func resetClientCredentials(r models.Repository, p auth.Provider, acoCMSID string) (string, error) { aco, err := r.GetACOByCMSID(context.Background(), acoCMSID) if err != nil { return "", err } // Generate new credentials - creds, err := provider.ResetSecret(aco.ClientID) + creds, err := p.ResetSecret(aco.ClientID) if err != nil { return "", err } return fmt.Sprintf("%s\n%s\n%s", creds.ClientName, creds.ClientID, creds.ClientSecret), nil } -func revokeAccessToken(accessToken string) error { +func revokeAccessToken(p auth.Provider, accessToken string) error { if accessToken == "" { return errors.New("Access token (--access-token) must be provided") } - return provider.RevokeAccessToken(accessToken) + return p.RevokeAccessToken(accessToken) } -func setDenylistState(cmsID string, td *models.Termination) error { +func setDenylistState(r models.Repository, cmsID string, td *models.Termination) error { aco, err := r.GetACOByCMSID(context.Background(), cmsID) if err != nil { return err diff --git a/bcda/bcdacli/cli_test.go b/bcda/bcdacli/cli_test.go index 15d423841..34ffb02d2 100644 --- a/bcda/bcdacli/cli_test.go +++ b/bcda/bcdacli/cli_test.go @@ -71,7 +71,7 @@ func (s *CLITestSuite) SetupSuite() { s.db = database.Connect() db = s.db - r = postgres.NewRepository(s.db) + repository = postgres.NewRepository(s.db) cmsID := testUtils.RandomHexID()[0:4] s.testACO = models.ACO{Name: uuid.New(), UUID: uuid.NewRandom(), ClientID: uuid.New(), CMSID: &cmsID} @@ -151,14 +151,10 @@ func (s *CLITestSuite) TestGenerateClientCredentials() { m := &auth.MockProvider{} m.On("FindAndCreateACOCredentials", *s.testACO.CMSID, ips).Return("mock\ncreds\ntest", nil) - oldProvider := provider - s.SetProvider(m) - defer s.SetProvider(oldProvider) - buf := 
new(bytes.Buffer) s.testApp.Writer = buf - msg, err := generateClientCredentials(*s.testACO.CMSID, ips) + msg, err := generateClientCredentials(m, *s.testACO.CMSID, ips) assert.Nil(t, err) assert.Regexp(t, regexp.MustCompile(".+\n.+\n.+"), msg) assert.Equal(t, "mock\ncreds\ntest", msg) @@ -198,17 +194,14 @@ func (s *CLITestSuite) TestResetSecretCLI() { auth.Credentials{ClientName: *s.testACO.CMSID, ClientID: s.testACO.ClientID, ClientSecret: uuid.New()}, nil) - oldProvider := provider - s.SetProvider(mock) - defer s.SetProvider(oldProvider) // execute positive scenario - msg, err := resetClientCredentials(r, *s.testACO.CMSID) + msg, err := resetClientCredentials(repository, mock, *s.testACO.CMSID) assert.Nil(err) assert.Regexp(outputPattern, msg) // Execute with invalid ACO CMS ID - msg, err = resetClientCredentials(r, "BLAH") + msg, err = resetClientCredentials(repository, mock, "BLAH") assert.Equal("no ACO record found for BLAH", err.Error()) assert.Equal(0, len(msg)) @@ -224,15 +217,12 @@ func (s *CLITestSuite) TestRevokeToken() { accessToken := uuid.New() mock := &auth.MockProvider{} mock.On("RevokeAccessToken", accessToken).Return(nil) - oldProvider := provider - s.SetProvider(mock) - defer s.SetProvider(oldProvider) - err := revokeAccessToken(accessToken) + err := revokeAccessToken(mock, accessToken) assert.Nil(err) // Negative case - attempt to revoke a token passing in a blank token string - err = revokeAccessToken("") + err = revokeAccessToken(mock, "") assert.Equal("Access token (--access-token) must be provided", err.Error()) mock.AssertExpectations(s.T()) } From 49b753ba7b6d1294738ccc9b271dc68274511b6c Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 6 Aug 2025 12:48:33 -0400 Subject: [PATCH 25/28] Run test scenarios as separate tests --- bcda/web/middleware/ratelimit_test.go | 40 +++++++++++++----------- bcdaworker/worker/worker_test.go | 44 +++++++++++++++------------ 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/bcda/web/middleware/ratelimit_test.go b/bcda/web/middleware/ratelimit_test.go index b6c4d3290..ae8eb4832 100644 --- a/bcda/web/middleware/ratelimit_test.go +++ b/bcda/web/middleware/ratelimit_test.go @@ -91,19 +91,21 @@ func (s *RateLimitMiddlewareTestSuite) TestHasConcurrentJobs() { } for _, tt := range tests { - mockRepo := &models.MockRepository{} - mockRepo.On("GetJobs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( - append(ignoredJobs, tt.additionalJobs...), - nil, - ) - middleware.repository = mockRepo - - rr := httptest.NewRecorder() - middleware.CheckConcurrentJobs(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // Conncurrent job test route check, blank return for overrides - })).ServeHTTP(rr, getRequest(tt.rp)) - - assert.NotEmpty(s.T(), rr.Header().Get("Retry-After")) + s.T().Run(tt.name, func(t *testing.T) { + mockRepo := &models.MockRepository{} + mockRepo.On("GetJobs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + append(ignoredJobs, tt.additionalJobs...), + nil, + ) + middleware.repository = mockRepo + + rr := httptest.NewRecorder() + middleware.CheckConcurrentJobs(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Conncurrent job test route check, blank return for overrides + })).ServeHTTP(rr, getRequest(tt.rp)) + + assert.NotEmpty(s.T(), rr.Header().Get("Retry-After")) + }) } } @@ -157,8 +159,10 @@ func (s *RateLimitMiddlewareTestSuite) TestHasDuplicatesFullString() { } for _, tt := range tests { - responseBool := 
middleware.hasDuplicates(ctx, otherJobs, tt.rp.ResourceTypes, tt.rp.Version, tt.rp.RequestURL) - assert.Equal(s.T(), tt.expectedValue, responseBool) + s.T().Run(tt.name, func(t *testing.T) { + responseBool := middleware.hasDuplicates(ctx, otherJobs, tt.rp.ResourceTypes, tt.rp.Version, tt.rp.RequestURL) + assert.Equal(s.T(), tt.expectedValue, responseBool) + }) } } @@ -179,7 +183,9 @@ func (s *RateLimitMiddlewareTestSuite) TestShouldRateLimit() { } for _, tt := range tests { - actualValue := shouldRateLimit(tt.config, tt.cmsID) - assert.Equal(s.T(), tt.expectedValue, actualValue, tt.name) + s.T().Run(tt.name, func(t *testing.T) { + actualValue := shouldRateLimit(tt.config, tt.cmsID) + assert.Equal(s.T(), tt.expectedValue, actualValue, tt.name) + }) } } diff --git a/bcdaworker/worker/worker_test.go b/bcdaworker/worker/worker_test.go index d884b8ce2..1f41505a1 100644 --- a/bcdaworker/worker/worker_test.go +++ b/bcdaworker/worker/worker_test.go @@ -193,15 +193,17 @@ func (s *WorkerTestSuite) TestGetBlueButtonID_NonHappyPaths() { } for _, tt := range tests { - mockCall := bbc.On("GetPatientByMbi", cclfBeneficiary.MBI).Return(tt.patientJSON, nil) - bbID, err := getBlueButtonID(bbc, beneficiaryID, jobArgs) - if tt.err != nil { - assert.Error(s.T(), err) - assert.Equal(s.T(), fmt.Sprint(tt.err), fmt.Sprint(err)) - } else { - assert.Equal(s.T(), tt.expectedID, bbID) - } - mockCall.Unset() + s.T().Run(tt.name, func(t *testing.T) { + mockCall := bbc.On("GetPatientByMbi", cclfBeneficiary.MBI).Return(tt.patientJSON, nil) + bbID, err := getBlueButtonID(bbc, beneficiaryID, jobArgs) + if tt.err != nil { + assert.Error(s.T(), err) + assert.Equal(s.T(), fmt.Sprint(tt.err), fmt.Sprint(err)) + } else { + assert.Equal(s.T(), tt.expectedID, bbID) + } + mockCall.Unset() + }) } } @@ -222,18 +224,20 @@ func (s *WorkerTestSuite) TestWriteResourcesToFile() { } for _, tt := range tests { - ctx, jobArgs, bbc := SetupWriteResourceToFile(s, tt.resource) - jobKeys, err := writeBBDataToFile(ctx, s.r, bbc, *s.testACO.CMSID, testUtils.CryptoRandInt63(), jobArgs, s.tempDir) - if tt.err == nil { + s.T().Run(tt.resource, func(t *testing.T) { + ctx, jobArgs, bbc := SetupWriteResourceToFile(s, tt.resource) + jobKeys, err := writeBBDataToFile(ctx, s.r, bbc, *s.testACO.CMSID, testUtils.CryptoRandInt63(), jobArgs, s.tempDir) + if tt.err == nil { + assert.NoError(s.T(), err) + } else { + assert.Error(s.T(), err) + } + files, err := os.ReadDir(s.tempDir) assert.NoError(s.T(), err) - } else { - assert.Error(s.T(), err) - } - files, err := os.ReadDir(s.tempDir) - assert.NoError(s.T(), err) - assert.Len(s.T(), jobKeys, tt.jobKeysCount) - assert.Len(s.T(), files, tt.fileCount) - VerifyFileContent(s.T(), files, tt.resource, tt.expectedCount, s.tempDir) + assert.Len(s.T(), jobKeys, tt.jobKeysCount) + assert.Len(s.T(), files, tt.fileCount) + VerifyFileContent(s.T(), files, tt.resource, tt.expectedCount, s.tempDir) + }) } } From 16dd65f9695659d029c0e4045c81778288654719 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 6 Aug 2025 12:55:16 -0400 Subject: [PATCH 26/28] Rename blacklist to denylist in a few places --- bcda/web/router_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/bcda/web/router_test.go b/bcda/web/router_test.go index e55a3cf0f..21ae3123d 100644 --- a/bcda/web/router_test.go +++ b/bcda/web/router_test.go @@ -324,8 +324,8 @@ func (s *RouterTestSuite) TestHTTPServerRedirect() { assert.Equal(s.T(), http.StatusMethodNotAllowed, res.StatusCode, "http to https redirect 
rejects POST requests") } -func createACO(cmsID string, blackListValue *models.Termination) models.ACO { - return models.ACO{Name: "TestRegisterSystem", CMSID: &cmsID, UUID: uuid.NewUUID(), ClientID: uuid.New(), TerminationDetails: blackListValue} +func createACO(cmsID string, denyListValue *models.Termination) models.ACO { + return models.ACO{Name: "TestRegisterSystem", CMSID: &cmsID, UUID: uuid.NewUUID(), ClientID: uuid.New(), TerminationDetails: denyListValue} } func createTestToken(cmsID string) (token *jwt.Token) { token = &jwt.Token{ @@ -352,7 +352,7 @@ func createExpectedAuthData(cmsID string, aco models.ACO) auth.AuthData { } } -func createConfigsForACOBlacklistingScenarios(s *RouterTestSuite, p auth.Provider) (configs []struct { +func createConfigsForACODenylistingScenarios(s *RouterTestSuite, p auth.Provider) (configs []struct { handler http.Handler paths []string }) { @@ -378,8 +378,8 @@ func setExpectedMockCalls(s *RouterTestSuite, mockP *auth.MockProvider, token *j } // integration test, requires connection to postgres db -// TestBlacklistedACOs ensures that we return 403 FORBIDDEN when a call is made from a blacklisted ACO. -func (s *RouterTestSuite) TestBlacklistedACOReturn403WhenACOBlacklisted() { +// TestDenylistedACOs ensures that we return 403 FORBIDDEN when a call is made from a denylisted ACO. +func (s *RouterTestSuite) TestDenylistedACOReturn403WhenACODenylisted() { // Use a new router to ensure that v2 endpoints are active v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) @@ -388,14 +388,14 @@ func (s *RouterTestSuite) TestBlacklistedACOReturn403WhenACOBlacklisted() { // Set up cmsID := testUtils.RandomHexID()[0:4] - blackListValue := &models.Termination{ + denyListValue := &models.Termination{ TerminationDate: time.Date(2020, time.December, 31, 23, 59, 59, 0, time.Local), CutoffDate: time.Date(2020, time.December, 31, 23, 59, 59, 0, time.Local), DenylistType: models.Involuntary, } - aco := createACO(cmsID, blackListValue) + aco := createACO(cmsID, denyListValue) bearerString := uuid.New() token := createTestToken(cmsID) @@ -407,12 +407,12 @@ func (s *RouterTestSuite) TestBlacklistedACOReturn403WhenACOBlacklisted() { postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) - configs := createConfigsForACOBlacklistingScenarios(s, mock) + configs := createConfigsForACODenylistingScenarios(s, mock) for _, config := range configs { for _, path := range config.paths { - s.T().Run(fmt.Sprintf("blacklist-value-%v-%s", blackListValue, path), func(t *testing.T) { + s.T().Run(fmt.Sprintf("denylist-value-%v-%s", denyListValue, path), func(t *testing.T) { fmt.Println(aco.Denylisted()) fmt.Println(aco.UUID.String()) postgrestest.UpdateACO(t, db, aco) @@ -431,7 +431,7 @@ func (s *RouterTestSuite) TestBlacklistedACOReturn403WhenACOBlacklisted() { mock.AssertExpectations(s.T()) } -func (s *RouterTestSuite) TestBlacklistedACOReturnNOT403WhenACONOTBlacklisted() { +func (s *RouterTestSuite) TestDenylistedACOReturnNOT403WhenACONOTDenylisted() { // Use a new router to ensure that v2 endpoints are active v2Active := conf.GetEnv("VERSION_2_ENDPOINT_ACTIVE") defer conf.SetEnv(s.T(), "VERSION_2_ENDPOINT_ACTIVE", v2Active) @@ -452,12 +452,12 @@ func (s *RouterTestSuite) TestBlacklistedACOReturnNOT403WhenACONOTBlacklisted() postgrestest.CreateACO(s.T(), db, aco) defer postgrestest.DeleteACO(s.T(), db, aco.UUID) - configs := createConfigsForACOBlacklistingScenarios(s, mock) + configs := 
createConfigsForACODenylistingScenarios(s, mock) for _, config := range configs { for _, path := range config.paths { - s.T().Run(fmt.Sprintf("blacklist-value-%v-%s", nil, path), func(t *testing.T) { + s.T().Run(fmt.Sprintf("denylist-value-%v-%s", nil, path), func(t *testing.T) { fmt.Println(aco.Denylisted()) fmt.Println(aco.UUID.String()) postgrestest.UpdateACO(t, db, aco) From cfb33efa94d67e6a0a70e9ff4f208fa6ae4b6ec6 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 6 Aug 2025 12:59:07 -0400 Subject: [PATCH 27/28] Add comment for refactoring todo --- bcdaworker/queueing/worker_cleanup.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bcdaworker/queueing/worker_cleanup.go b/bcdaworker/queueing/worker_cleanup.go index 26bc42b5a..20e3612de 100644 --- a/bcdaworker/queueing/worker_cleanup.go +++ b/bcdaworker/queueing/worker_cleanup.go @@ -18,6 +18,7 @@ import ( "github.com/slack-go/slack" ) +// TODO: Consider moving functions like cleanupJob and archiveExpiring to receiver methods of CleanupJobWorker type CleanupJobWorker struct { river.WorkerDefaults[worker_types.CleanupJobArgs] cleanupJob func(*sql.DB, time.Time, models.JobStatus, models.JobStatus, ...string) error From bd29e55c4e0bf929efc60bf87a55c7cd94cfb1c2 Mon Sep 17 00:00:00 2001 From: Michael Valdes Date: Wed, 6 Aug 2025 16:50:51 -0400 Subject: [PATCH 28/28] Add unit test for pool connection --- bcda/database/connection.go | 12 ++++---- bcda/database/connection_test.go | 47 ++++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/bcda/database/connection.go b/bcda/database/connection.go index 298f2702a..8c4b90550 100644 --- a/bcda/database/connection.go +++ b/bcda/database/connection.go @@ -31,7 +31,7 @@ func Connect() *sql.DB { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - startConnectionHealthCheck( + startDBHealthCheck( ctx, conn, time.Duration(cfg.HealthCheckSec)*time.Second, @@ -126,17 +126,17 @@ func CreatePgxv5DB(cfg *Config) (*pgxv5Pool.Pool, error) { // // startHealthCheck returns immediately with the health check running in a goroutine that // can be stopped via the supplied context -func startConnectionHealthCheck(ctx context.Context, db *sql.DB, interval time.Duration) { +func startDBHealthCheck(ctx context.Context, db *sql.DB, interval time.Duration) { go func() { ticker := time.NewTicker(interval) for { select { case <-ctx.Done(): ticker.Stop() - logrus.Debug("Stopping health checker") + logrus.Debug("Stopping DB health checker") return case <-ticker.C: - logrus.StandardLogger().Debug("Sending ping") + logrus.StandardLogger().Debug("Sending ping via DB connection") // Handle acquiring connection, pinging, and releasing App DB connection if err := db.Ping(); err != nil { @@ -154,10 +154,10 @@ func startPoolHealthCheck(ctx context.Context, pgxv5Pool *pgxv5Pool.Pool, interv select { case <-ctx.Done(): ticker.Stop() - logrus.Debug("Stopping health checker") + logrus.Debug("Stopping pool health checker") return case <-ticker.C: - logrus.StandardLogger().Debug("Sending ping") + logrus.StandardLogger().Debug("Sending ping via pool connection") pgxv5Conn, err := pgxv5Pool.Acquire(ctx) if err != nil { diff --git a/bcda/database/connection_test.go b/bcda/database/connection_test.go index 796ed30ef..09c55d8fd 100644 --- a/bcda/database/connection_test.go +++ b/bcda/database/connection_test.go @@ -14,14 +14,17 @@ import ( func TestConnections(t *testing.T) { // Verify that we can initialize the package as expected c := Connect() + p := ConnectPool() 
assert.NotNil(t, c) assert.NoError(t, c.Ping()) + assert.NotNil(t, p) + assert.NoError(t, p.Ping(context.Background())) } // TestHealthCheck verifies that we are able to start the health check // and the checks do not cause a panic by waiting some amount of time // to ensure that health checks are being executed. -func TestConnectionHealthCheck(t *testing.T) { +func TestDBHealthCheck(t *testing.T) { level, reporter := logrus.GetLevel(), logrus.StandardLogger().ReportCaller t.Cleanup(func() { logrus.SetLevel(level) @@ -34,8 +37,8 @@ func TestConnectionHealthCheck(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) c := Connect() - startConnectionHealthCheck(ctx, c, 100*time.Microsecond) - // Let some time elapse to ensure we've successfully ran health checks + startDBHealthCheck(ctx, c, 100*time.Microsecond) + // Let some time elapse to ensure we've successfully run health checks time.Sleep(50 * time.Millisecond) cancel() time.Sleep(100 * time.Millisecond) @@ -43,9 +46,43 @@ func TestConnectionHealthCheck(t *testing.T) { var hasPing, hasClose bool for _, entry := range hook.AllEntries() { if strings.Contains(entry.Caller.File, "database/connection.go") { - if strings.Contains(entry.Message, "Sending ping") { + if strings.Contains(entry.Message, "Sending ping via DB connection") { hasPing = true - } else if strings.Contains(entry.Message, "Stopping health checker") { + } else if strings.Contains(entry.Message, "Stopping DB health checker") { + hasClose = true + } + } + } + + assert.True(t, hasPing, "Should've received a ping message in the logs.") + assert.True(t, hasClose, "Should've received a close message in the logs.") +} + +func TestPoolHealthCheck(t *testing.T) { + level, reporter := logrus.GetLevel(), logrus.StandardLogger().ReportCaller + t.Cleanup(func() { + logrus.SetLevel(level) + logrus.SetReportCaller(reporter) + }) + + logrus.SetLevel(logrus.DebugLevel) + logrus.SetReportCaller(true) + hook := test.NewGlobal() + + ctx, cancel := context.WithCancel(context.Background()) + p := ConnectPool() + startPoolHealthCheck(ctx, p, 100*time.Microsecond) + // Let some time elapse to ensure we've successfully run health checks + time.Sleep(50 * time.Millisecond) + cancel() + time.Sleep(100 * time.Millisecond) + + var hasPing, hasClose bool + for _, entry := range hook.AllEntries() { + if strings.Contains(entry.Caller.File, "database/connection.go") { + if strings.Contains(entry.Message, "Sending ping via pool connection") { + hasPing = true + } else if strings.Contains(entry.Message, "Stopping pool health checker") { hasClose = true } }