Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions people-and-planet-ai/geospatial-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ This model uses satellite data to predict if a coal plant is turned on and produ
* **Training the model**: [TensorFlow] in [Vertex AI]
* **Getting predictions**: [TensorFlow] in [Cloud Run]

## Serving app security configuration

Set the `MODEL_DIR` environment variable to a trusted model artifact location before starting `serving_app`. The prediction endpoint no longer accepts a user-controlled bucket for model loading because loading untrusted TensorFlow/Keras artifacts can execute attacker-controlled code during deserialization.

[Cloud Run]: https://cloud.google.com/run
[Sentinel-2]: https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2
[Earth Engine]: https://earthengine.google.com/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import flask

app = flask.Flask(__name__)
Expand All @@ -28,15 +31,21 @@ def run_root() -> str:
}


def get_trusted_model_dir() -> str:
model_dir = os.environ.get("MODEL_DIR")
if not model_dir:
raise RuntimeError("MODEL_DIR must be set to a trusted model location.")
return model_dir


@app.route("/predict", methods=["POST"])
def run_predict() -> dict:
import predict

try:
args = flask.request.get_json() or {}
bucket = args["bucket"]
model_dir = f"gs://{bucket}/model_output"
data = args["data"]
model_dir = get_trusted_model_dir()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The run_predict function currently returns a 200 OK status code even when an error occurs, such as when MODEL_DIR is unset and get_trusted_model_dir() raises a RuntimeError. For better API design and clearer communication to API consumers, it's recommended to return an appropriate HTTP status code (e.g., 400 Bad Request for client-side configuration errors or 500 Internal Server Error for server-side issues) when an exception is caught. This would provide a more standard and expected API response for error conditions.

predictions = predict.run(data, model_dir)

return {
Expand All @@ -45,10 +54,9 @@ def run_predict() -> dict:
"predictions": predictions,
}
except Exception as e:
return {"error": f"{type(e).__name__}: {e}"}
logging.exception(e)
return ({"error": f"{type(e).__name__}: {e}"}, 500)


if __name__ == "__main__":
import os

app.run(debug=True, host="0.0.0.0", port=int(os.environ.get("PORT", 8080)))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest==8.2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Generator
import sys
import types

from flask.testing import FlaskClient
import pytest

import main


@pytest.fixture
def client() -> Generator[FlaskClient, None, None]:
main.app.config.update(TESTING=True)
with main.app.test_client() as test_client:
yield test_client


def test_predict_uses_trusted_model_dir_from_environment(
client: FlaskClient, monkeypatch: pytest.MonkeyPatch
) -> None:
calls: list[tuple[dict, str]] = []

def fake_run(data: dict, model_dir: str) -> dict:
calls.append((data, model_dir))
return {"predictions": [[0.42]]}

monkeypatch.setenv("MODEL_DIR", "gs://trusted-bucket/model_output")
monkeypatch.setitem(sys.modules, "predict", types.SimpleNamespace(run=fake_run))

response = client.post(
"/predict",
json={
"bucket": "attacker-bucket/evil-prefix",
"data": {
"band_1": [[1, 2, 3]],
"band_2": [[4, 5, 6]],
"band_3": [[7, 8, 9]],
},
},
)

assert response.status_code == 200
assert response.get_json() == {
"method": "predict",
"model_dir": "gs://trusted-bucket/model_output",
"predictions": {"predictions": [[0.42]]},
}
assert calls == [
(
{
"band_1": [[1, 2, 3]],
"band_2": [[4, 5, 6]],
"band_3": [[7, 8, 9]],
},
"gs://trusted-bucket/model_output",
)
]


def test_predict_returns_error_when_trusted_model_dir_is_missing(
client: FlaskClient, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("MODEL_DIR", raising=False)
monkeypatch.setitem(
sys.modules,
"predict",
types.SimpleNamespace(run=lambda data, model_dir: {"predictions": [[0.42]]}),
)

response = client.post(
"/predict",
json={
"bucket": "attacker-bucket/evil-prefix",
"data": {
"band_1": [[1, 2, 3]],
"band_2": [[4, 5, 6]],
"band_3": [[7, 8, 9]],
},
},
)

assert response.status_code == 500
assert response.get_json() == {
"error": "RuntimeError: MODEL_DIR must be set to a trusted model location."
}