Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 91 additions & 4 deletions balancer/catabalancer/catalyst_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package catabalancer

import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math/rand"
"sort"
Expand All @@ -23,6 +25,7 @@ type CataBalancer struct {
NodeMetrics map[string]NodeMetrics // Node name -> NodeMetrics
metricTimeout time.Duration
ingestStreamTimeout time.Duration
NodeStatsDB *sql.DB
}

type Streams map[string]Stream // Stream ID -> Stream
Expand Down Expand Up @@ -79,7 +82,35 @@ func (s ScoredNode) String() string {
)
}

func NewBalancer(nodeName string, metricTimeout time.Duration, ingestStreamTimeout time.Duration) *CataBalancer {
// JSON representation is deliberately truncated to keep the message size small
type NodeUpdateEvent struct {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just moved this code from the events package to make the dependencies easier

Resource string `json:"resource,omitempty"`
NodeID string `json:"n,omitempty"`
NodeMetrics NodeMetrics `json:"nm,omitempty"`
Streams string `json:"s,omitempty"`
}

func (n *NodeUpdateEvent) SetStreams(streamIDs []string, ingestStreamIDs []string) {
n.Streams = strings.Join(streamIDs, "|") + "~" + strings.Join(ingestStreamIDs, "|")
}

func (n *NodeUpdateEvent) GetStreams() []string {
before, _, _ := strings.Cut(n.Streams, "~")
if len(before) > 0 {
return strings.Split(before, "|")
}
return []string{}
}

func (n *NodeUpdateEvent) GetIngestStreams() []string {
_, after, _ := strings.Cut(n.Streams, "~")
if len(after) > 0 {
return strings.Split(after, "|")
}
return []string{}
}

func NewBalancer(nodeName string, metricTimeout time.Duration, ingestStreamTimeout time.Duration, nodeStatsDB *sql.DB) *CataBalancer {
return &CataBalancer{
NodeName: nodeName,
Nodes: make(map[string]*Node),
Expand All @@ -88,6 +119,7 @@ func NewBalancer(nodeName string, metricTimeout time.Duration, ingestStreamTimeo
NodeMetrics: make(map[string]NodeMetrics),
metricTimeout: metricTimeout,
ingestStreamTimeout: ingestStreamTimeout,
NodeStatsDB: nodeStatsDB,
}
}

Expand Down Expand Up @@ -134,6 +166,10 @@ func (c *CataBalancer) UpdateMembers(ctx context.Context, members []cluster.Memb
}

func (c *CataBalancer) GetBestNode(ctx context.Context, redirectPrefixes []string, playbackID, lat, lon, fallbackPrefix string, isStudioReq bool) (string, string, error) {
if err := c.RefreshNodes(); err != nil {
return "", "", fmt.Errorf("error refreshing nodes: %w", err)
}

var err error
latf := 0.0
if lat != "" {
Expand Down Expand Up @@ -291,7 +327,58 @@ func truncateReturned(scoredNodes []ScoredNode, numNodes int) []ScoredNode {
return scoredNodes[:numNodes]
}

func (c *CataBalancer) RefreshNodes() error {
log.LogNoRequestID("catabalancer refreshing nodes")
if c.NodeStatsDB == nil {
return fmt.Errorf("node stats DB was nil")
}

query := "SELECT stats FROM node_stats"
rows, err := c.NodeStatsDB.Query(query)
if err != nil {
return fmt.Errorf("failed to query node stats: %w", err)
}
defer rows.Close()

// Process the result set
for rows.Next() {
var statsBytes []byte
if err := rows.Scan(&statsBytes); err != nil {
return fmt.Errorf("failed to scan node stats row: %w", err)
}

var event NodeUpdateEvent
err = json.Unmarshal(statsBytes, &event)
if err != nil {
return fmt.Errorf("failed to unmarshal node update event: %w", err)
}

if isStale(event.NodeMetrics.Timestamp, c.metricTimeout) {
log.LogNoRequestID("catabalancer skipping stale data while refreshing", "nodeID", event.NodeID, "timestamp", event.NodeMetrics.Timestamp)
continue
}

c.UpdateNodes(event.NodeID, event.NodeMetrics)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need to keep any internal state for nodes? If we make a query to DB with every playback redirect request, then couldn't we reason everything from there?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep sorry I totally forgot to mention refactoring that as a plan for another PR, in two minds about whether I should do that in this PR but I think since it's been approved I think I'll just do it now but in a fresh PR.

for _, stream := range event.GetStreams() {
c.UpdateStreams(event.NodeID, stream, false)
}
for _, stream := range event.GetIngestStreams() {
c.UpdateStreams(event.NodeID, stream, true)
}
}

// Check for errors after iterating through rows
if err := rows.Err(); err != nil {
return err
}
return nil
}

func (c *CataBalancer) MistUtilLoadSource(ctx context.Context, streamID, lat, lon string) (string, error) {
if err := c.RefreshNodes(); err != nil {
return "", fmt.Errorf("error refreshing nodes: %w", err)
}

c.nodesLock.Lock()
defer c.nodesLock.Unlock()
for nodeName := range c.Nodes {
Expand All @@ -318,13 +405,13 @@ func (c *CataBalancer) checkAndCreateNode(nodeName string) {
}
}

func (c *CataBalancer) UpdateNodes(id string, nodeMetrics NodeMetrics) {
func (c *CataBalancer) UpdateNodes(nodeName string, nodeMetrics NodeMetrics) {
c.nodesLock.Lock()
defer c.nodesLock.Unlock()

c.checkAndCreateNode(id)
c.checkAndCreateNode(nodeName)
nodeMetrics.Timestamp = time.Now()
c.NodeMetrics[id] = nodeMetrics
c.NodeMetrics[nodeName] = nodeMetrics
}

var UpdateNodeStatsEvery = 5 * time.Second
Expand Down
31 changes: 22 additions & 9 deletions balancer/catabalancer/catalyst_balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package catabalancer

import (
"context"
"database/sql"
"fmt"
"github.com/DATA-DOG/go-sqlmock"
"github.com/livepeer/catalyst-api/cluster"
"golang.org/x/sync/errgroup"
"math/rand"
Expand Down Expand Up @@ -40,16 +42,27 @@ var BandwidthOverloadedNode = ScoredNode{
},
}

func mockDB(t *testing.T) *sql.DB {
db, mock, err := sqlmock.New()
require.NoError(t, err)
for i := 0; i < 10; i++ {
mock.ExpectQuery("SELECT stats FROM node_stats").
WillReturnRows(sqlmock.NewRows([]string{"stats"}).AddRow("{}"))

}
return db
}

func TestItReturnsItselfWhenNoOtherNodesPresent(t *testing.T) {
c := NewBalancer("me", time.Second, time.Second)
c := NewBalancer("me", time.Second, time.Second, mockDB(t))
nodeName, prefix, err := c.GetBestNode(context.Background(), nil, "playbackID", "", "", "", false)
require.NoError(t, err)
require.Equal(t, "me", nodeName)
require.Equal(t, "video+playbackID", prefix)
}

func TestStaleNodes(t *testing.T) {
c := NewBalancer("me", time.Second, time.Second)
c := NewBalancer("me", time.Second, time.Second, mockDB(t))
err := c.UpdateMembers(context.Background(), []cluster.Member{{Name: "node1"}})
require.NoError(t, err)

Expand Down Expand Up @@ -281,7 +294,7 @@ func scores(node1 ScoredNode, node2 ScoredNode) ScoredNode {

func TestSetMetrics(t *testing.T) {
// simple check that node metrics make it through to the load balancing algo
c := NewBalancer("", time.Second, time.Second)
c := NewBalancer("", time.Second, time.Second, mockDB(t))
err := c.UpdateMembers(context.Background(), []cluster.Member{{Name: "node1"}, {Name: "node2"}})
require.NoError(t, err)

Expand All @@ -296,7 +309,7 @@ func TestSetMetrics(t *testing.T) {

func TestUnknownNode(t *testing.T) {
// check that the node metrics call creates the unknown node
c := NewBalancer("", time.Second, time.Second)
c := NewBalancer("", time.Second, time.Second, mockDB(t))

c.UpdateNodes("node1", NodeMetrics{CPUUsagePercentage: 90})
c.UpdateNodes("bgw-node1", NodeMetrics{CPUUsagePercentage: 10})
Expand All @@ -307,7 +320,7 @@ func TestUnknownNode(t *testing.T) {
}

func TestNoIngestStream(t *testing.T) {
c := NewBalancer("", time.Second, time.Second)
c := NewBalancer("", time.Second, time.Second, mockDB(t))
// first test no nodes available
c.UpdateNodes("id", NodeMetrics{})
c.UpdateStreams("id", "stream", false)
Expand All @@ -329,7 +342,7 @@ func TestNoIngestStream(t *testing.T) {
}

func TestMistUtilLoadSource(t *testing.T) {
c := NewBalancer("", time.Second, time.Second)
c := NewBalancer("", time.Second, time.Second, mockDB(t))
err := c.UpdateMembers(context.Background(), []cluster.Member{{
Name: "node",
Tags: map[string]string{
Expand All @@ -356,7 +369,7 @@ func TestMistUtilLoadSource(t *testing.T) {
}

func TestStreamTimeout(t *testing.T) {
c := NewBalancer("", time.Second, time.Second)
c := NewBalancer("", time.Second, time.Second, mockDB(t))
err := c.UpdateMembers(context.Background(), []cluster.Member{{
Name: "node",
Tags: map[string]string{
Expand Down Expand Up @@ -402,7 +415,7 @@ func TestStreamTimeout(t *testing.T) {

// needs to be run with go test -race
func TestConcurrentUpdates(t *testing.T) {
c := NewBalancer("", time.Second, time.Second)
c := NewBalancer("", time.Second, time.Second, mockDB(t))

err := c.UpdateMembers(context.Background(), []cluster.Member{{Name: "node"}})
require.NoError(t, err)
Expand Down Expand Up @@ -436,7 +449,7 @@ func TestSimulate(t *testing.T) {

updateEvery := 5 * time.Second

c := NewBalancer("node0", time.Second, time.Second)
c := NewBalancer("node0", time.Second, time.Second, mockDB(t))
var nodes []cluster.Member
for i := 0; i < nodeCount; i++ {
nodes = append(nodes, cluster.Member{Name: fmt.Sprintf("node%d", i)})
Expand Down
1 change: 1 addition & 0 deletions config/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Cli struct {
ExternalTranscoder string
VodPipelineStrategy string
MetricsDBConnectionString string
NodeStatsConnectionString string
ImportIPFSGatewayURLs []*url.URL
ImportArweaveGatewayURLs []*url.URL
NodeName string
Expand Down
Loading
Loading