Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dev-tools/airavata-python-sdk/airavata_auth/device_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@


class AuthContext:

@staticmethod
def get_access_token():
if os.environ.get("CS_ACCESS_TOKEN", None) is None:
context = AuthContext()
context.login()
return os.environ["CS_ACCESS_TOKEN"]

def __init__(self):
self.settings = Settings()
Expand All @@ -21,6 +28,8 @@ def __init__(self):
self.console = Console()

def login(self):
if os.environ.get('CS_ACCESS_TOKEN', None) is not None:
return
# Step 1: Request device and user code
auth_device_url = f"{self.settings.AUTH_SERVER_URL}/realms/{self.settings.AUTH_REALM}/protocol/openid-connect/auth/device"
response = requests.post(auth_device_url, data={
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@

from . import base, plan
from airavata_auth.device_auth import AuthContext
from .runtime import list_runtimes, Runtime
from .runtime import find_runtimes, Runtime
from typing import Any
from . import md, neuro


context = AuthContext()

def login():
context.login()

__all__ = ["list_runtimes", "base", "plan", "login"]
__all__ = ["find_runtimes", "base", "plan", "login", "md", "neuro"]


def display_runtimes(runtimes: list[Runtime]) -> None:
Expand Down
163 changes: 111 additions & 52 deletions dev-tools/airavata-python-sdk/airavata_experiments/airavata.py

Large diffs are not rendered by default.

45 changes: 30 additions & 15 deletions dev-tools/airavata-python-sdk/airavata_experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,49 +85,64 @@ def with_resource(self, resource: Runtime) -> Experiment[T]:
self.resource = resource
return self

def create_task(self, *allowed_runtimes: Runtime, name: str | None = None) -> None:
def add_run(self, use: list[Runtime], cpus: int, nodes: int, walltime: int, name: str | None = None, **extra_params) -> None:
"""
Create a task to run the experiment on a given runtime.
"""
runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 else self.resource
runtime = random.choice(use) if len(use) > 0 else self.resource
uuid_str = str(uuid.uuid4())[:4].upper()

# override runtime args with given values
runtime = runtime.model_copy()
runtime.args["cpu_count"] = cpus
runtime.args["node_count"] = nodes
runtime.args["walltime"] = walltime
# add extra inputs to task inputs
task_inputs = {**self.inputs, **extra_params}
# create a task with the given runtime and inputs
self.tasks.append(
Task(
name=name or f"{self.name}_{uuid_str}",
name=f"{name or self.name}_{uuid_str}",
app_id=self.application.app_id,
inputs={**self.inputs},
inputs=task_inputs,
runtime=runtime,
)
)
print(f"Task created. ({len(self.tasks)} tasks in total)")

def add_sweep(self, *allowed_runtimes: Runtime, **space: list) -> None:
def add_sweep(self, use: list[Runtime], cpus: int, nodes: int, walltime: int, name: str | None = None, **space: list) -> None:
"""
Add a sweep to the experiment.

"""
for values in product(space.values()):
runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 else self.resource
runtime = random.choice(use) if len(use) > 0 else self.resource
uuid_str = str(uuid.uuid4())[:4].upper()

# override runtime args with given values
runtime = runtime.model_copy()
runtime.args["cpu_count"] = cpus
runtime.args["node_count"] = nodes
runtime.args["walltime"] = walltime
# add sweep params to task inputs
task_specific_params = dict(zip(space.keys(), values))
agg_inputs = {**self.inputs, **task_specific_params}
task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in self.input_mapping.items()}

# create a task with the given runtime and inputs
self.tasks.append(Task(
name=f"{self.name}_{uuid_str}",
name=f"{name or self.name}_{uuid_str}",
app_id=self.application.app_id,
inputs=task_inputs,
runtime=runtime or self.resource,
))

def plan(self, **kwargs) -> Plan:
if len(self.tasks) == 0:
self.create_task(self.resource)
def plan(self) -> Plan:
assert len(self.tasks) > 0, "add_run() must be called before plan() to define runtimes and resources."
tasks = []
for t in self.tasks:
agg_inputs = {**self.inputs, **t.inputs}
task_inputs = {k: {"value": agg_inputs[v[0]], "type": v[1]} for k, v in self.input_mapping.items()}
tasks.append(Task(name=t.name, app_id=self.application.app_id, inputs=task_inputs, runtime=t.runtime))
return Plan(tasks=tasks)
task = Task(name=t.name, app_id=self.application.app_id, inputs=task_inputs, runtime=t.runtime)
# task.freeze() # TODO upload the task-related data and freeze the task
tasks.append(task)
plan = Plan(tasks=tasks)
plan.save()
return plan
27 changes: 15 additions & 12 deletions dev-tools/airavata-python-sdk/airavata_experiments/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .runtime import is_terminal_state
from .task import Task
import uuid
from airavata_auth.device_auth import AuthContext

from .airavata import AiravataOperator

Expand Down Expand Up @@ -66,19 +67,21 @@ def __stage_status__(self) -> list:
statuses.append(task.status())
return statuses

def __stage_stop__(self) -> None:
print("Stopping task(s)...")
for task in self.tasks:
task.stop()
print("Task(s) stopped.")
def __stage_stop__(self, runs: list[int] = []) -> None:
runs = runs if len(runs) > 0 else list(range(len(self.tasks)))
print(f"Stopping task(s): {runs}")
for i, task in enumerate(self.tasks):
if i in runs:
task.stop()
print(f"Task(s) stopped: {runs}")

def __stage_fetch__(self, local_dir: str) -> list[list[str]]:
print("Fetching results...")
fps = list[list[str]]()
for task in self.tasks:
fps.append(task.download_all(local_dir))
print("Results fetched.")
self.save_json(os.path.join(local_dir, "plan.json"))
self.export(os.path.join(local_dir, "plan.json"))
return fps

def launch(self, silent: bool = True) -> None:
Expand Down Expand Up @@ -119,17 +122,17 @@ def download(self, local_dir: str):
assert os.path.isdir(local_dir)
self.__stage_fetch__(local_dir)

def stop(self) -> None:
self.__stage_stop__()
def stop(self, runs: list[int] = []) -> None:
self.__stage_stop__(runs)
self.save()

def save_json(self, filename: str) -> None:
def export(self, filename: str) -> None:
with open(filename, "w") as f:
json.dump(self.model_dump(), f, indent=2)

def save(self) -> None:
settings = Settings()
av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
av = AiravataOperator(AuthContext.get_access_token())
az = av.__airavata_token__(av.access_token, av.default_gateway_id())
assert az.accessToken is not None
assert az.claimsMap is not None
Expand Down Expand Up @@ -162,7 +165,7 @@ def load_json(filename: str) -> Plan:
def load(id: str | None) -> Plan:
settings = Settings()
assert id is not None
av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
av = AiravataOperator(AuthContext.get_access_token())
az = av.__airavata_token__(av.access_token, av.default_gateway_id())
assert az.accessToken is not None
assert az.claimsMap is not None
Expand All @@ -183,7 +186,7 @@ def load(id: str | None) -> Plan:

def query() -> list[Plan]:
settings = Settings()
av = AiravataOperator(os.environ['CS_ACCESS_TOKEN'])
av = AiravataOperator(AuthContext.get_access_token())
az = av.__airavata_token__(av.access_token, av.default_gateway_id())
assert az.accessToken is not None
assert az.claimsMap is not None
Expand Down
Loading