Skip to content

Commit d29582a

Browse files
committed
Support custom kwargs for timeseries
1 parent 5857fe2 commit d29582a

File tree

4 files changed

+49
-21
lines changed

4 files changed

+49
-21
lines changed

src/autogluon/cloud/backend/timeseries_sagemaker_backend.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ class TimeSeriesSagemakerBackend(SagemakerBackend):
1515
def _preprocess_data(
1616
self,
1717
data: Union[pd.DataFrame, str],
18-
id_column: str,
19-
timestamp_column: str,
20-
target: str,
18+
id_column: Optional[str] = None,
19+
timestamp_column: Optional[str] = None,
20+
target: Optional[str] = None,
2121
static_features: Optional[Union[pd.DataFrame, str]] = None,
2222
) -> pd.DataFrame:
2323
if isinstance(data, str):
@@ -27,12 +27,15 @@ def _preprocess_data(
2727
cols = data.columns.to_list()
2828
# Make sure the id and timestamp columns are the first two columns and the target column is at the end
2929
# This ensures that, inside the container, we know how to find the id and timestamp columns and whether static features have been merged in
30-
timestamp_index = cols.index(timestamp_column)
31-
cols.insert(0, cols.pop(timestamp_index))
32-
id_index = cols.index(id_column)
33-
cols.insert(0, cols.pop(id_index))
34-
target_index = cols.index(target)
35-
cols.append(cols.pop(target_index))
30+
if timestamp_column is not None:
31+
timestamp_index = cols.index(timestamp_column)
32+
cols.insert(0, cols.pop(timestamp_index))
33+
if id_column is not None:
34+
id_index = cols.index(id_column)
35+
cols.insert(0, cols.pop(id_index))
36+
if target is not None:
37+
target_index = cols.index(target)
38+
cols.append(cols.pop(target_index))
3639
data = data[cols]
3740

3841
if static_features is not None:
@@ -48,8 +51,8 @@ def fit(
4851
*,
4952
predictor_init_args: Dict[str, Any],
5053
predictor_fit_args: Dict[str, Any],
51-
id_column: str,
52-
timestamp_column: str,
54+
id_column: Optional[str] = None,
55+
timestamp_column: Optional[str] = None,
5356
static_features: Optional[Union[str, pd.DataFrame]] = None,
5457
framework_version: str = "latest",
5558
job_name: Optional[str] = None,
@@ -199,9 +202,9 @@ def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
199202
def predict(
200203
self,
201204
test_data: Union[str, pd.DataFrame],
202-
id_column: str,
203-
timestamp_column: str,
204-
target: str,
205+
id_column: Optional[str] = None,
206+
timestamp_column: Optional[str] = None,
207+
target: Optional[str] = None,
205208
static_features: Optional[Union[str, pd.DataFrame]] = None,
206209
**kwargs,
207210
) -> Optional[pd.DataFrame]:

src/autogluon/cloud/predictor/timeseries_cloud_predictor.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def fit(
5050
*,
5151
predictor_init_args: Dict[str, Any],
5252
predictor_fit_args: Dict[str, Any],
53-
id_column: str = "item_id",
54-
timestamp_column: str = "timestamp",
53+
id_column: Optional[str] = None,
54+
timestamp_column: Optional[str] = None,
5555
static_features: Optional[Union[str, pd.DataFrame]] = None,
5656
framework_version: str = "latest",
5757
job_name: Optional[str] = None,
@@ -120,7 +120,7 @@ def fit(
120120
if backend_kwargs is None:
121121
backend_kwargs = {}
122122

123-
self.target_column = predictor_init_args.get("target", "target")
123+
self.target_column = predictor_init_args.get("target")
124124
self.id_column = id_column
125125
self.timestamp_column = timestamp_column
126126

@@ -146,6 +146,9 @@ def fit(
146146
def predict_real_time(
147147
self,
148148
test_data: Union[str, pd.DataFrame],
149+
id_column: Optional[str] = None,
150+
timestamp_column: Optional[str] = None,
151+
target: Optional[str] = None,
149152
static_features: Optional[Union[str, pd.DataFrame]] = None,
150153
accept: str = "application/x-parquet",
151154
**kwargs,
@@ -175,13 +178,18 @@ def predict_real_time(
175178
Pandas.DataFrame
176179
Predict results in DataFrame
177180
"""
181+
self.id_column = id_column or self.id_column
182+
self.timestamp_column = timestamp_column or self.timestamp_column
183+
self.target_column = target or self.target_column
184+
178185
return self.backend.predict_real_time(
179186
test_data=test_data,
180187
id_column=self.id_column,
181188
timestamp_column=self.timestamp_column,
182189
target=self.target_column,
183190
static_features=static_features,
184191
accept=accept,
192+
inference_kwargs=kwargs
185193
)
186194

187195
def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
@@ -190,6 +198,9 @@ def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
190198
def predict(
191199
self,
192200
test_data: Union[str, pd.DataFrame],
201+
id_column: Optional[str] = None,
202+
timestamp_column: Optional[str] = None,
203+
target: Optional[str] = None,
193204
static_features: Optional[Union[str, pd.DataFrame]] = None,
194205
predictor_path: Optional[str] = None,
195206
framework_version: str = "latest",
@@ -199,6 +210,7 @@ def predict(
199210
custom_image_uri: Optional[str] = None,
200211
wait: bool = True,
201212
backend_kwargs: Optional[Dict] = None,
213+
**kwargs,
202214
) -> Optional[pd.DataFrame]:
203215
"""
204216
Predict using SageMaker batch transform.
@@ -263,6 +275,10 @@ def predict(
263275
Please refer to
264276
https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.
265277
"""
278+
self.id_column = id_column or self.id_column
279+
self.timestamp_column = timestamp_column or self.timestamp_column
280+
self.target_column = target or self.target_column
281+
266282
if backend_kwargs is None:
267283
backend_kwargs = {}
268284
backend_kwargs = self.backend.parse_backend_predict_kwargs(backend_kwargs)
@@ -279,6 +295,7 @@ def predict(
279295
instance_count=instance_count,
280296
custom_image_uri=custom_image_uri,
281297
wait=wait,
298+
inference_kwargs=kwargs,
282299
**backend_kwargs,
283300
)
284301

src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_serve.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
from io import BytesIO, StringIO
66

77
import pandas as pd
8+
import logging
9+
import sys
810

911
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor
10-
12+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
13+
logger = logging.getLogger(__name__)
1114

1215
def model_fn(model_dir):
1316
"""loads model from previously saved artifact"""
@@ -31,12 +34,15 @@ def model_fn(model_dir):
3134
def prepare_timeseries_dataframe(df, predictor):
3235
target = predictor.target
3336
cols = df.columns.to_list()
37+
logger.info(f"COLUMN {cols}")
3438
id_column = cols[0]
3539
timestamp_column = cols[1]
3640
df[timestamp_column] = pd.to_datetime(df[timestamp_column])
3741
static_features = None
3842
if target != cols[-1]:
3943
# If target is not the last column, static features have been merged in after it
44+
logger.info(f"Inside condition: {cols}, {target}")
45+
logger.info(f"Inside condition: {cols}, {target}")
4046
target_index = cols.index(target)
4147
static_columns = cols[target_index + 1 :]
4248
static_features = df[[id_column] + static_columns].groupby([id_column], sort=False).head(1)
@@ -56,6 +62,7 @@ def transform_fn(model, request_body, input_content_type, output_content_type="a
5662

5763
elif input_content_type == "text/csv":
5864
buf = StringIO(request_body)
65+
logger.info(f"request body data path: {buf}")
5966
data = pd.read_csv(buf)
6067

6168
elif input_content_type == "application/json":
@@ -77,6 +84,8 @@ def transform_fn(model, request_body, input_content_type, output_content_type="a
7784
else:
7885
raise ValueError(f"{input_content_type} input content type not supported.")
7986

87+
logger.info(f"Model is: {model}")
88+
logger.info(f"Columns are: {data.columns}")
8089
data = prepare_timeseries_dataframe(data, model)
8190
prediction = model.predict(data, **inference_kwargs)
8291
prediction = pd.DataFrame(prediction)

src/autogluon/cloud/scripts/sagemaker_scripts/train.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def get_env_if_present(name):
3131

3232

3333
def prepare_timeseries_dataframe(df, predictor_init_args):
34-
target = predictor_init_args["target"]
34+
target = predictor_init_args.get("target")
3535
cols = df.columns.to_list()
3636
id_column = cols[0]
3737
timestamp_column = cols[1]
3838
df[timestamp_column] = pd.to_datetime(df[timestamp_column])
3939
static_features = None
40-
if target != cols[-1]:
40+
if target is not None and target != cols[-1]:
4141
# If target is not the last column, static features have been merged in after it
4242
target_index = cols.index(target)
4343
static_columns = cols[target_index + 1 :]
@@ -46,7 +46,6 @@ def prepare_timeseries_dataframe(df, predictor_init_args):
4646
df.drop(columns=static_columns, inplace=True)
4747
df = TimeSeriesDataFrame.from_data_frame(df, id_column=id_column, timestamp_column=timestamp_column)
4848
if static_features is not None:
49-
print(static_features)
5049
df.static_features = static_features
5150
return df
5251

0 commit comments

Comments
 (0)