Skip to content

Commit dcff49c

Browse files
minor fixes to API
1 parent a6ccdfb commit dcff49c

File tree

8 files changed

+126
-36
lines changed

8 files changed

+126
-36
lines changed

.github/workflows/test.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: Run ScAPE Tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- name: Checkout code
15+
uses: actions/checkout@v4
16+
17+
- name: Set up Python
18+
uses: actions/setup-python@v5
19+
with:
20+
python-version: '3.10'
21+
22+
- name: Cache _data directory
23+
uses: actions/cache@v4
24+
with:
25+
path: _data
26+
key: scape-data-${{ runner.os }}-${{ hashFiles('**/pyproject.toml') }}
27+
restore-keys: |
28+
scape-data-${{ runner.os }}-
29+
30+
- name: Install package
31+
run: pip install .
32+
33+
- name: Run test code
34+
run: |
35+
python -c "
36+
import scape
37+
scape.io.download_from_zenodo(target_dir = '.')
38+
result = scape.api.train(
39+
de_file='_data/de_train.parquet',
40+
lfc_file='_data/lfc_train.parquet',
41+
cv_drug='Belinostat',
42+
n_genes=64
43+
)
44+
scape.util.plot_result(result._last_train_results)
45+
"

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,4 @@ Thumbs.db
184184

185185
# Project-specific
186186
CLAUDE.md
187+
scape-data.zip

README.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ pip install git+https://github.com/scapeML/scape.git
3232
```python
3333
import scape
3434

35+
# Data from Zenodo can be downloaded via
36+
scape.io.download_from_zenodo(target_dir = ".")
37+
3538
# Train model with drug cross-validation
36-
result = scape.train(
37-
de_file="de_train.parquet",
38-
lfc_file="lfc_train.parquet",
39+
result = scape.api.train(
40+
de_file="_data/de_train.parquet",
41+
lfc_file="_data/lfc_train.parquet",
3942
cv_drug="Belinostat",
4043
n_genes=64
4144
)
4245

4346
# Visualize performance vs baselines
44-
scape.plot_result(result)
47+
scape.util.plot_result(result._last_train_results)
4548
```
4649

4750
## 📋 Overview
@@ -75,7 +78,7 @@ Key design choices:
7578

7679
```bash
7780
# Command line
78-
python -m scape train --n-genes 64 --cv-drug Belinostat de_train.parquet lfc_train.parquet
81+
python -m scape train --n-genes 64 --cv-drug Belinostat _data/de_train.parquet _data/lfc_train.parquet
7982

8083
# Python API
8184
import scape

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ dependencies = [
3131
"scikit-learn>=1.2.2",
3232
"fastparquet>=2023.10.1",
3333
"keras>=3.6",
34+
"matplotlib",
35+
"requests"
3436
]
3537

3638
[project.scripts]

scape/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Set JAX backend for Keras before any imports
33
os.environ.setdefault('KERAS_BACKEND', 'jax')
44

5+
import scape._api as api
56
import scape._model as model
67
import scape._losses as losses
78
import scape._io as io

scape/__main__.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -105,37 +105,18 @@ def main():
105105

106106
# If the command was train, train the model
107107
if parser.parse_args().command == "train":
108-
train(args)
109-
110-
111-
def train(args):
112-
# Read the files
113-
df_de = scape.io.load_slogpvals(args.slogpval)
114-
print(f"DE shape: {df_de.shape}")
115-
df_lfc = scape.io.load_lfc(args.lfc)
116-
print(f"LFC shape: {df_lfc.shape}")
117-
val_cells = [args.cv_cell] if args.cv_cell else None
118-
val_drugs = [args.cv_drug] if args.cv_drug else None
119-
print(f"Training model with {args.n_genes} genes")
120-
print(f"Validation cell(s): {val_cells}")
121-
print(f"Validation drug(s): {val_drugs}")
122-
# Create a default model
123-
model = scape.model.create_default_model(args.n_genes, df_de, df_lfc)
124-
top_genes = top_genes = scape.util.select_top_variable([df_de], k=args.n_genes)
125-
model.train(
126-
val_cells=val_cells,
127-
val_drugs=val_drugs,
128-
output_data="slogpval",
129-
callbacks="default",
130-
input_columns=top_genes,
131-
optimizer=None,
132-
epochs=args.epochs,
133-
batch_size=args.batch_size,
134-
output_folder=args.output_dir,
135-
config_file_name=f"{args.config_name}.pkl",
136-
model_file_name=f"{args.model_name}.keras",
137-
baselines=["zero", "slogpval_drug"]
138-
)
108+
scape.api.train(
109+
de_file=args.slogpval,
110+
lfc_file=args.lfc,
111+
n_genes=args.n_genes,
112+
output_dir=args.output_dir,
113+
cv_cell=args.cv_cell,
114+
cv_drug=args.cv_drug,
115+
epochs=args.epochs,
116+
batch_size=args.batch_size,
117+
config_name=args.config_name,
118+
model_name=args.model_name
119+
)
139120

140121

141122
if __name__ == "__main__":

scape/_api.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import scape
2+
3+
def train(
    de_file,
    lfc_file,
    n_genes=64,
    output_dir=None,
    cv_cell="NK cells",
    cv_drug=None,
    epochs=600,
    batch_size=128,
    config_name="config",
    model_name="model"
):
    """Load the DE / LFC tables, build the default model and train it.

    Args:
        de_file: Path handed to ``scape.io.load_slogpvals``.
        lfc_file: Path handed to ``scape.io.load_lfc``.
        n_genes: Number of top-variable genes selected from the DE table
            and passed to ``scape.model.create_default_model``.
        output_dir: Forwarded to the model's ``train`` as ``output_folder``.
        cv_cell: Single held-out cell name; a falsy value disables it.
        cv_drug: Single held-out drug name; a falsy value disables it.
        epochs: Forwarded to the model's ``train``.
        batch_size: Forwarded to the model's ``train``.
        config_name: Stem of the config file name (``.pkl`` is appended).
        model_name: Stem of the model file name (``.keras`` is appended).

    Returns:
        The model object created by ``scape.model.create_default_model``,
        after its ``train`` method has been called.
    """
    # Load both input tables, reporting each shape right after its load.
    de_frame = scape.io.load_slogpvals(de_file)
    print(f"DE shape: {de_frame.shape}")
    lfc_frame = scape.io.load_lfc(lfc_file)
    print(f"LFC shape: {lfc_frame.shape}")

    # The model takes lists of held-out names, or None for "no hold-out".
    held_out_cells = None if not cv_cell else [cv_cell]
    held_out_drugs = None if not cv_drug else [cv_drug]

    print(f"Training model with {n_genes} genes")
    print(f"Validation cell(s): {held_out_cells}")
    print(f"Validation drug(s): {held_out_drugs}")

    # Default architecture plus the input gene subset chosen from the DE table.
    scape_model = scape.model.create_default_model(n_genes, de_frame, lfc_frame)
    feature_genes = scape.util.select_top_variable([de_frame], k=n_genes)

    fit_options = {
        "val_cells": held_out_cells,
        "val_drugs": held_out_drugs,
        "output_data": "slogpval",
        "callbacks": "default",
        "input_columns": feature_genes,
        "optimizer": None,
        "epochs": epochs,
        "batch_size": batch_size,
        "output_folder": output_dir,
        "config_file_name": f"{config_name}.pkl",
        "model_file_name": f"{model_name}.keras",
        "baselines": ["zero", "slogpval_drug"],
    }
    scape_model.train(**fit_options)
    return scape_model

scape/_io.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44
import zipfile
55
import lzma
66
import tempfile
7+
import requests
8+
9+
def download_from_zenodo(target_dir, cache_dir='_data'):
    """Download the ScAPE data archive from Zenodo and unpack it.

    The archive is saved as ``scape-data.zip`` inside ``target_dir`` and
    extracted there. The download is skipped when ``cache_dir`` already
    exists inside ``target_dir``.

    NOTE(review): this assumes the archive unpacks into a ``cache_dir``
    sub-directory (``_data`` by default). If it does not, the skip check
    simply never triggers and the archive is downloaded again.

    Args:
        target_dir: Directory that receives the zip file and its extracted
            contents. It is created if it does not exist.
        cache_dir: Name of the data directory, relative to ``target_dir``,
            whose presence means the data is already available.

    Returns:
        None.

    Raises:
        ValueError: If Zenodo answers with a status code other than 200.
    """
    # Check the data directory, not target_dir itself: target_dir is usually
    # an existing directory such as ".", which would skip every download.
    data_dir = os.path.join(target_dir, cache_dir)
    if os.path.exists(data_dir):
        print(f"Data directory {data_dir} already exists. Skipping download.")
        return

    # An empty target_dir means the current directory; makedirs('') would fail.
    os.makedirs(target_dir or '.', exist_ok=True)

    url = 'https://zenodo.org/records/10617221/files/scape-data.zip?download=1'
    zip_path = os.path.join(target_dir, "scape-data.zip")

    # Stream to disk so the whole archive is never held in memory; the
    # timeout keeps a stalled connection from hanging forever.
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise ValueError(f"Failed to download file from Zenodo: {response.status_code}")
        with open(zip_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(target_dir)
722

823

924
def compress(file_path, zip_file_path=None, delete=False):

0 commit comments

Comments
 (0)