
ModelBase: Base njit class for abstracting model execution #110

@rburghol


Overview

Decision Points

  • Two types of classes for each operation:
    • Handler: non-@njit class that parses/translates data into model run-time class definitions
      • These can use all of Python's features during model preparation, where execution speed is less critical
      • These can also benefit from inheritance
    • Model: run-time objects, @njit compatible
      • Can use only numba-compatible language constructs
      • Cannot use inheritance, so they will employ common data structures wherever possible to emulate inheritance and make the code easier to understand/extend
    • Is there a reason why some variables come from the UCI in ALLCAPS and others in lowercase?
  • numba limitations for object-oriented @jitclass
    • A numba typed List (NumbaList) is homogeneous: every item must have the same type, so one list cannot mix different class definitions.
    • Classes may not inherit from other classes.
      • Class properties (such as value in the ModelBase example below) may be defined in a shared spec that separate class definitions reuse, but methods must be defined separately in each class.
  • Benefits of class model objects
    • Class properties persist in the object between calls, so attributes do not need to be relocated/reloaded each timestep. For example:
      • rchres_object[n]
  • Performance testing:
    • Cost of hiding all state access behind set_state(), as opposed to direct array manipulation.
    • Cost of storing the property path or index as a class member attribute and using it for direct array manipulation.
    • Storing persistent data in model.attribute format avoids reloading it from state at the beginning of each timestep; however, it would prevent other objects (such as a special action) from modifying that value in the state.

ModelBase Declaration

from typing import List
import numba
import time
from numpy import zeros
from numba.types import string as numba_str
from numba import njit, types, int32, float32, float64, typeof   # import the types
from numba.experimental import jitclass
from numba.typed import Dict, List as NumbaList

def model_make_spec(prop_names, prop_type):
    new_spec = [(x, prop_type) for x in prop_names]
    return new_spec

# state vars - two options to see which is fastest
state_ix = Dict.empty(key_type=types.int64, value_type=types.float64)
state_paths = Dict.empty(key_type=types.unicode_type, value_type=types.float64)

state_paths_ty = ('state_paths', typeof(state_paths))
state_ix_ty = ('state_ix', typeof(state_ix))

model_num_type = float32
model_str_type = numba_str # Imported from numba.types.string
model_str_props = ['name', 'path']
model_num_props = ['value']
model_base = [state_paths_ty, state_ix_ty] + model_make_spec(model_str_props, model_str_type) + model_make_spec(model_num_props, model_num_type)

@jitclass(model_base)
class ModelBase:
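    def __init__(self):
        # initialize string members too; reading an unset unicode member is unsafe
        self.name = ""
        self.path = ""
        self.value = 0
        self.state_ix = Dict.empty(key_type=types.int64, value_type=types.float64)
        self.state_paths = Dict.empty(key_type=types.unicode_type, value_type=types.float64)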
    
    def step(self, step_num):
        # placeholder behaviour: increment value so execution can be verified
        self.value += 1
    
    def pre_step(self):
        # get remote inputs if needed, load timeseries
        self.step_TSGET()
        return
    
    def post_step(self):
        # perform logging actions (timeseries writes, etc)
        return
    
    def step_TSGET(self):
        # TSGET: get timeseries if need be
        return

class HandlerBase:
    def __init__(self, props = None):
        self.model_props = props
        return
    
    def make_model(self):
        # Create an empty model
        model = ModelBase()
        print("Creating ModelBase with props:", self.model_props)
        # Populate model props and return
        self.set_props(model, self.model_props)
        return model
    
    def set_props(self, model, model_props, strict = False ):
        if model_props is None:
            return
        for prop in model_props:
            if hasattr(model, prop):
                propval = self.handle_propval(model, prop, model_props[prop], strict)
                setattr(model,prop,propval)
    
    def handle_propval(self, model, prop, propval, strict = False ):
        # sub classes can check to see if this is in the right format, or change to numba compatible types etc.
        return propval
    
    @staticmethod
    def make_numba_dict(props):
        # convert a plain Python dict of numeric props into a numba typed Dict
        ret = Dict.empty(key_type=types.unicode_type, value_type=types.float64)
        for i in props.keys():
            ret[i] = props[i]
        return ret

@njit
def iteration_test(it_ops, it_nums):
    ctr = 0
    for n in range(it_nums):
        for i in range(len(it_ops)):
            it_ops[i].step(n)
        ctr=ctr+1
    print("Completed ", ctr, " loops")


Testing ModelBase

  • Must run the base class declaration above first
m1 = ModelBase()
m2 = ModelBase()

# advance m1 five times and m2 twice, so m1.value == 5 and m2.value == 2
for _ in range(5):
    m1.step(1)
for _ in range(2):
    m2.step(1)

ref_list = [m1, m2]
obj_list = numba.typed.List(ref_list)

@njit
def omtest_check(m_list):
    # go through and show the value of each object passed in
    for i in m_list:
        # show value at this point
        print(i.value)
        i.step(1)  # increment; step() requires a step argument

omtest_check(obj_list)
# Did the in-method updates propagate back? Values should be +1 from the output of omtest_check()
print(m1.value, m2.value)

# Set the state info
for i in range(len(obj_list)):
    obj_list[i].state_paths = state_paths
    obj_list[i].state_ix = state_ix

# test and precompile
iteration_test(obj_list, 3)

start = time.time()
iteration_test(obj_list, 300000)
end = time.time()
print(len(obj_list), "Dict components took", end - start, "seconds")
print(m1.value, m2.value)
