Skip to content

Commit a88a48a

Browse files
support AWS resource cleanup when walltime hits
1 parent a18cc6f commit a88a48a

File tree

6 files changed

+99
-5
lines changed

6 files changed

+99
-5
lines changed

airavata-api/src/main/java/org/apache/airavata/helix/impl/participant/GlobalParticipant.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public class GlobalParticipant extends HelixParticipant<AbstractTask> {
5252
"org.apache.airavata.helix.impl.task.aws.CreateEC2InstanceTask",
5353
"org.apache.airavata.helix.impl.task.aws.NoOperationTask",
5454
"org.apache.airavata.helix.impl.task.aws.AWSJobSubmissionTask",
55+
"org.apache.airavata.helix.impl.task.aws.AWSCompletingTask",
5556
};
5657

5758
@SuppressWarnings("WeakerAccess")

airavata-api/src/main/java/org/apache/airavata/helix/impl/task/AWSTaskFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
*/
2020
package org.apache.airavata.helix.impl.task;
2121

22+
import org.apache.airavata.helix.impl.task.aws.AWSCompletingTask;
2223
import org.apache.airavata.helix.impl.task.aws.AWSJobSubmissionTask;
2324
import org.apache.airavata.helix.impl.task.aws.CreateEC2InstanceTask;
2425
import org.apache.airavata.helix.impl.task.aws.NoOperationTask;
25-
import org.apache.airavata.helix.impl.task.completing.CompletingTask;
2626
import org.slf4j.Logger;
2727
import org.slf4j.LoggerFactory;
2828

@@ -63,7 +63,7 @@ public AiravataTask createJobVerificationTask(String processId) {
6363

6464
@Override
6565
public AiravataTask createCompletingTask(String processId) {
66-
return new CompletingTask();
66+
return new AWSCompletingTask();
6767
}
6868

6969
@Override
airavata-api/src/main/java/org/apache/airavata/helix/impl/task/aws/AWSCompletingTask.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/**
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
* <p>
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
* <p>
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.airavata.helix.impl.task.aws;
20+
21+
import org.apache.airavata.helix.impl.task.AiravataTask;
22+
import org.apache.airavata.helix.impl.task.TaskContext;
23+
import org.apache.airavata.helix.impl.task.aws.utils.AWSTaskUtil;
24+
import org.apache.airavata.helix.task.api.TaskHelper;
25+
import org.apache.airavata.helix.task.api.annotation.TaskDef;
26+
import org.apache.airavata.model.status.ProcessState;
27+
import org.apache.helix.task.TaskResult;
28+
import org.slf4j.Logger;
29+
import org.slf4j.LoggerFactory;
30+
31+
@TaskDef(name = "AWS_COMPLETING_TASK")
32+
public class AWSCompletingTask extends AiravataTask {
33+
34+
private static final Logger logger = LoggerFactory.getLogger(AWSCompletingTask.class);
35+
36+
@Override
37+
public TaskResult onRun(TaskHelper helper, TaskContext taskContext) {
38+
logger.info("Starting completing task for task {}, experiment id {}", getTaskId(), getExperimentId());
39+
logger.info("Process {} successfully completed", getProcessId());
40+
saveAndPublishProcessStatus(ProcessState.COMPLETED);
41+
cleanup();
42+
AWSTaskUtil.terminateEC2Instance(getTaskContext(), getGatewayId());
43+
return onSuccess("Process " + getProcessId() + " successfully completed");
44+
}
45+
46+
@Override
47+
public void onCancel(TaskContext taskContext) {
48+
}
49+
}

airavata-api/src/main/java/org/apache/airavata/helix/impl/task/submission/config/GroovyMapBuilder.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ public GroovyMapData build() throws Exception {
8080
mapData.setWorkingDirectory(taskContext.getWorkingDir());
8181
mapData.setTaskId(taskContext.getTaskId());
8282
mapData.setExperimentDataDir(taskContext.getProcessModel().getExperimentDataDir());
83+
mapData.setExperimentId(taskContext.getExperimentId());
8384

8485
SimpleDateFormat gmtDateFormat = new SimpleDateFormat("yyyy-MM-dd+HH:mmZ");
8586
gmtDateFormat.setTimeZone(TimeZone.getTimeZone("EST"));
@@ -121,6 +122,7 @@ public GroovyMapData build() throws Exception {
121122
((JobSubmissionTaskModel) taskContext.getSubTaskModel());
122123
if (jobSubmissionTaskModel.getWallTime() > 0) {
123124
mapData.setMaxWallTime(maxWallTimeCalculator(jobSubmissionTaskModel.getWallTime()));
125+
mapData.setWallTimeInSeconds(jobSubmissionTaskModel.getWallTime() * 60);
124126
// TODO fix this
125127
/*if (resourceJobManager != null) {
126128
if (resourceJobManager.getResourceJobManagerType().equals(ResourceJobManagerType.LSF)) {
@@ -161,6 +163,7 @@ public GroovyMapData build() throws Exception {
161163
// if so we ignore scheduling configuration.
162164
if (scheduling.getWallTimeLimit() > 0 && mapData.getMaxWallTime() == null) {
163165
mapData.setMaxWallTime(maxWallTimeCalculator(scheduling.getWallTimeLimit()));
166+
mapData.setWallTimeInSeconds(scheduling.getWallTimeLimit() * 60);
164167

165168
// TODO fix this
166169
/*

airavata-api/src/main/java/org/apache/airavata/helix/impl/task/submission/config/GroovyMapData.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ public class GroovyMapData {
107107
@ScriptTag(name = "maxWallTime")
108108
private String maxWallTime;
109109

110+
@ScriptTag(name = "wallTimeInSeconds")
111+
private Integer wallTimeInSeconds;
112+
110113
@ScriptTag(name = "qualityOfService")
111114
private String qualityOfService;
112115

@@ -152,6 +155,9 @@ public class GroovyMapData {
152155
@ScriptTag(name = "experimentDataDir")
153156
private String experimentDataDir;
154157

158+
@ScriptTag(name = "experimentId")
159+
private String experimentId;
160+
155161
@ScriptTag(name = "computeHostName")
156162
private String computeHostName;
157163

@@ -363,6 +369,14 @@ public GroovyMapData setMaxWallTime(String maxWallTime) {
363369
return this;
364370
}
365371

372+
/**
 * Returns the wall-time limit in seconds, the value bound to the {@code wallTimeInSeconds}
 * script tag.
 *
 * @return the stored value, or {@code null} when it has not been set
 */
public Integer getWallTimeInSeconds() {
373+
return wallTimeInSeconds;
374+
}
375+
376+
/**
 * Stores the wall-time limit in seconds for the {@code wallTimeInSeconds} script tag.
 *
 * @param wallTimeInSeconds wall-time limit in seconds; {@code null} is stored as-is
 */
public void setWallTimeInSeconds(Integer wallTimeInSeconds) {
377+
this.wallTimeInSeconds = wallTimeInSeconds;
378+
}
379+
366380
public String getQualityOfService() {
367381
return qualityOfService;
368382
}
@@ -496,6 +510,14 @@ public void setExperimentDataDir(String experimentDataDir) {
496510
this.experimentDataDir = experimentDataDir;
497511
}
498512

513+
/**
 * Returns the experiment id bound to the {@code experimentId} script tag.
 *
 * @return the stored experiment id, or {@code null} when it has not been set
 */
public String getExperimentId() {
514+
return experimentId;
515+
}
516+
517+
/**
 * Stores the experiment id for the {@code experimentId} script tag.
 *
 * @param experimentId experiment id to expose to script templates
 */
public void setExperimentId(String experimentId) {
518+
this.experimentId = experimentId;
519+
}
520+
499521
public String getCurrentTime() {
500522
return currentTime;
501523
}
Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#!${shellName}
22

33
# Cloud execution script generated by Apache Airavata
4+
# User: ${gatewayUserName}
5+
# Experiment ID: ${experimentId}
6+
# Walltime (seconds): ${wallTimeInSeconds}
47
<%
58
if (exports != null) for(com in exports) out.print 'export ' + com +'\n'
69
if (moduleCommands != null) for(mc in moduleCommands) out.print mc +'\n'
@@ -9,6 +12,22 @@
912
if (jobSubmitterCommand != null && jobSubmitterCommand != "") out.print jobSubmitterCommand + ' '
1013
if (executablePath != null && executablePath != "") out.print executablePath + ' '
1114
if (inputs != null) for(input in inputs) out.print input + ' '
12-
out.print '\n'
13-
if (postJobCommands != null) for(pjc in postJobCommands) out.print pjc +'\n'
14-
%>
15+
// Run the assembled job command in the background so the script can supervise it.
out.print '&\n'
%>
# PID of the job command started in the background just above.
MAIN_JOB_PID=\$!

<% if (wallTimeInSeconds != null && wallTimeInSeconds > 0) { %>
# Walltime watchdog: only emitted when a positive walltime was configured.
(
    sleep ${wallTimeInSeconds}

    if ps -p \$MAIN_JOB_PID > /dev/null; then
        echo "Walltime of ${wallTimeInSeconds} seconds exceeded. Terminating job PID \$MAIN_JOB_PID." >&2
        pkill -P \$MAIN_JOB_PID
        kill -9 \$MAIN_JOB_PID
    fi
) &

WATCHDOG_PID=\$!
<% } %>
wait \$MAIN_JOB_PID
<% if (wallTimeInSeconds != null && wallTimeInSeconds > 0) { %>
# The job is done: stop the watchdog and the sleep it is still running.
pkill -P \$WATCHDOG_PID 2>/dev/null
kill \$WATCHDOG_PID 2>/dev/null
<% } %>

<% if (postJobCommands != null) for(pjc in postJobCommands) out.print pjc +'\n' %>

0 commit comments

Comments
 (0)