diff --git a/people-and-planet-ai/geospatial-classification/README.md b/people-and-planet-ai/geospatial-classification/README.md index d366f6bbb18..c7f7793727c 100644 --- a/people-and-planet-ai/geospatial-classification/README.md +++ b/people-and-planet-ai/geospatial-classification/README.md @@ -11,6 +11,10 @@ This model uses satellite data to predict if a coal plant is turned on and produ * **Training the model**: [TensorFlow] in [Vertex AI] * **Getting predictions**: [TensorFlow] in [Cloud Run] +## Serving app security configuration + +Set the `MODEL_DIR` environment variable to a trusted model artifact location before starting `serving_app`. The prediction endpoint no longer accepts a user-controlled bucket for model loading because loading untrusted TensorFlow/Keras artifacts can execute attacker-controlled code during deserialization. + [Cloud Run]: https://cloud.google.com/run [Sentinel-2]: https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2 [Earth Engine]: https://earthengine.google.com/ diff --git a/people-and-planet-ai/geospatial-classification/serving_app/main.py b/people-and-planet-ai/geospatial-classification/serving_app/main.py index 61c72509e10..f90275ff7ad 100644 --- a/people-and-planet-ai/geospatial-classification/serving_app/main.py +++ b/people-and-planet-ai/geospatial-classification/serving_app/main.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import os + import flask app = flask.Flask(__name__) @@ -28,15 +31,21 @@ def run_root() -> str: } +def get_trusted_model_dir() -> str: + model_dir = os.environ.get("MODEL_DIR") + if not model_dir: + raise RuntimeError("MODEL_DIR must be set to a trusted model location.") + return model_dir + + @app.route("/predict", methods=["POST"]) def run_predict() -> dict: import predict try: args = flask.request.get_json() or {} - bucket = args["bucket"] - model_dir = f"gs://{bucket}/model_output" data = args["data"] + model_dir = get_trusted_model_dir() predictions = predict.run(data, model_dir) return { @@ -45,10 +54,9 @@ def run_predict() -> dict: "predictions": predictions, } except Exception as e: - return {"error": f"{type(e).__name__}: {e}"} + logging.exception(e) + return ({"error": f"{type(e).__name__}: {e}"}, 500) if __name__ == "__main__": - import os - app.run(debug=True, host="0.0.0.0", port=int(os.environ.get("PORT", 8080))) diff --git a/people-and-planet-ai/geospatial-classification/serving_app/requirements-test.txt b/people-and-planet-ai/geospatial-classification/serving_app/requirements-test.txt new file mode 100644 index 00000000000..15d066af319 --- /dev/null +++ b/people-and-planet-ai/geospatial-classification/serving_app/requirements-test.txt @@ -0,0 +1 @@ +pytest==8.2.0 diff --git a/people-and-planet-ai/geospatial-classification/serving_app/test_main.py b/people-and-planet-ai/geospatial-classification/serving_app/test_main.py new file mode 100644 index 00000000000..ec26d25a8dc --- /dev/null +++ b/people-and-planet-ai/geospatial-classification/serving_app/test_main.py @@ -0,0 +1,101 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Generator +import sys +import types + +from flask.testing import FlaskClient +import pytest + +import main + + +@pytest.fixture +def client() -> Generator[FlaskClient, None, None]: + main.app.config.update(TESTING=True) + with main.app.test_client() as test_client: + yield test_client + + +def test_predict_uses_trusted_model_dir_from_environment( + client: FlaskClient, monkeypatch: pytest.MonkeyPatch +) -> None: + calls: list[tuple[dict, str]] = [] + + def fake_run(data: dict, model_dir: str) -> dict: + calls.append((data, model_dir)) + return {"predictions": [[0.42]]} + + monkeypatch.setenv("MODEL_DIR", "gs://trusted-bucket/model_output") + monkeypatch.setitem(sys.modules, "predict", types.SimpleNamespace(run=fake_run)) + + response = client.post( + "/predict", + json={ + "bucket": "attacker-bucket/evil-prefix", + "data": { + "band_1": [[1, 2, 3]], + "band_2": [[4, 5, 6]], + "band_3": [[7, 8, 9]], + }, + }, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "method": "predict", + "model_dir": "gs://trusted-bucket/model_output", + "predictions": {"predictions": [[0.42]]}, + } + assert calls == [ + ( + { + "band_1": [[1, 2, 3]], + "band_2": [[4, 5, 6]], + "band_3": [[7, 8, 9]], + }, + "gs://trusted-bucket/model_output", + ) + ] + + +def test_predict_returns_error_when_trusted_model_dir_is_missing( + client: FlaskClient, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.delenv("MODEL_DIR", raising=False) + monkeypatch.setitem( + sys.modules, + "predict", + types.SimpleNamespace(run=lambda data, model_dir: {"predictions": [[0.42]]}), + ) + + response = client.post( + "/predict", + json={ + "bucket": "attacker-bucket/evil-prefix", + "data": { + "band_1": [[1, 2, 3]], + "band_2": [[4, 5, 6]], + "band_3": [[7, 8, 9]], + }, + }, + ) + + assert response.status_code == 500 + assert response.get_json() == { + "error": "RuntimeError: MODEL_DIR must be set to a trusted model location." + }