387 changes: 387 additions & 0 deletions precovery/brute_force.py
@@ -0,0 +1,387 @@
import os

# Cap BLAS/OpenMP thread pools before any numpy-dependent imports happen
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

from precovery.precovery_db import PrecoveryDatabase

import logging
import multiprocessing as mp
import time
from functools import partial
from typing import Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
from astropy.time import Time
from sklearn.neighbors import BallTree

from .healpix_geom import radec_to_healpixel, radec_to_healpixel_with_neighbors

# replace this usage with Orbit.compute_ephemeris
from .orbit import Orbit
from .residuals import calcResiduals
from .utils import _checkParallel, _initWorker, calcChunkSize, yieldChunks

logger = logging.getLogger(__name__)

__all__ = [
"get_observations_and_ephemerides_for_orbits",
"attribution_worker",
"attributeObservations",
]


def get_observations_and_ephemerides_for_orbits(
    orbits: Iterable[Orbit],
    mjd_start: float,
    mjd_end: float,
    precovery_db: PrecoveryDatabase,
    obscode: str = "W84",
):
    """
    Returns:
    - ephemerides for the orbit list, propagated to all epochs of frames in
      the range (mjd_start, mjd_end)
    - observations in the indexed PrecoveryDatabase consistent with intersecting
      and neighboring frames for orbits propagated to each of the represented epochs

    *** Currently breaks on non-NSC datasets.
    TODO: propagation targets should be unique frames w.r.t. (obscode, mjd).
    """

    # Find all the mjds we need to propagate to
    all_frame_mjd = precovery_db.frames.idx.unique_frame_times()
Member:
You've probably already thought about this, but you can solve your comment on line 54 by changing this function to return tuples (or similar) of obscode and the mjds for that obscode. Then you can remove the obscode kwarg altogether.

Member Author:
Yep - in my local branch now, will push a fix.

This does break down a little with per-obs MJDs though.

Member:
Ah yes, excellent point. We might be able to get around that later on by loading in the frames and then adding a check: if frame.mjd != obs.mjd, propagate to the correct epoch. Not sure how we would generalize it so that nelson-force and brute-force can use the same logic and code.
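A minimal sketch of the suggested return shape (hypothetical helper, not part of this PR; it only assumes frames expose .obscode and .mjd as the code below does):

from collections import defaultdict
from typing import Dict, List, Set, Tuple

def unique_frame_times_by_obscode(frames) -> List[Tuple[str, List[float]]]:
    # Group unique frame epochs by observatory code so each orbit can be
    # propagated per obscode, removing the need for an obscode kwarg.
    times: Dict[str, Set[float]] = defaultdict(set)
    for frame in frames:
        times[frame.obscode].add(frame.mjd)
    return [(obscode, sorted(mjds)) for obscode, mjds in times.items()]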

    frame_mjd_within_range = [
        x for x in all_frame_mjd if mjd_start < x < mjd_end
    ]

    ephemeris_dfs = []
    frame_dfs = []
    # TODO: rather than a single obscode kwarg, this should use the obscodes
    # of the relevant frames
    for orbit in orbits:
        eph = orbit.compute_ephemeris(obscode=obscode, epochs=frame_mjd_within_range)
        mjd = [w.mjd for w in eph]
        ra = [w.ra for w in eph]
        dec = [w.dec for w in eph]
        ephemeris_df = pd.DataFrame.from_dict(
            {
                "mjd_utc": mjd,
                "RA_deg": ra,
                "Dec_deg": dec,
            }
        )
        ephemeris_df["orbit_id"] = orbit.orbit_id
        ephemeris_df["observatory_code"] = obscode
        ephemeris_dfs.append(ephemeris_df)

        # Now we gather the healpixels and neighboring pixels for each propagated position
        healpix = radec_to_healpixel_with_neighbors(
            ra, dec, precovery_db.frames.healpix_nside
        )

        frame_df = pd.concat(
            [
                pd.DataFrame.from_dict(
                    {
                        # 9 = the healpixel containing the position plus its 8 neighbors
                        "mjd_utc": [x[0] for y in range(9)],
Member:
What is range(9) doing here exactly? Might be worth adding a comment since it's a hard-coded list length.

Member Author:
Good catch, this is unclear. Healpixel plus neighbors is a set of 9 healpixels.
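For readers unfamiliar with the 9-pixel convention, a sketch of the idea using healpy (the actual implementation is radec_to_healpixel_with_neighbors in healpix_geom.py; nested ordering is an assumption here):

import healpy as hp
import numpy as np

def pixel_plus_neighbors(ra_deg: float, dec_deg: float, nside: int) -> np.ndarray:
    # The pixel containing (ra, dec) plus its 8 neighbors: the 9-pixel set
    # discussed above.
    pix = hp.ang2pix(nside, ra_deg, dec_deg, nest=True, lonlat=True)
    neighbors = hp.get_all_neighbours(nside, pix, nest=True)
    return np.concatenate([[pix], neighbors])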

"obscode": [x[1] for y in range(9)],
"healpix": list(x[2]),
}
)
for x in zip(
mjd, ephemeris_df["observatory_code"], list(healpix.transpose())
)
],
ignore_index=True,
)
frame_dfs.append(frame_df)

# This will be passed back to be used as the ephemeris dataframe later
ephemeris = pd.concat(ephemeris_dfs, ignore_index=True)
unique_frames = pd.concat(frame_dfs, ignore_index=True).drop_duplicates()

# This filter is very inefficient - there is surely a better way to do this
Member:
To make this more efficient I'd delegate this to SQL altogether and add another convenience function. This PR already adds some very useful convenience functions to the FrameIndex class; one that returns all frames with a specific healpixel, mjd, and obscode would be another useful addition.

Actually, I'd imagine that a function like this is necessary, especially in the context of the HEALPix down/upsampling that Tanay's been working on. We will want to be able to query for all frames that correspond to one HEALPixel value at a given mjd and obscode, so that we can pull those frames in and then check them for matches. In the context of resampling nside, you could imagine adding an nside_search parameter to this convenience function. If nside_search is not equal to the nside at indexing time, you resample the pixels and return the correct frames for the given nside_search. Of course, this is out of scope for this PR, but having part of this function already included might be super useful.

Please feel free to push back here given you have more knowledge of what Tanay's code looks like.

Member Author:
I think your sense here is correct. I was reticent to add too much bloat to the FrameIndex class - functions convenient just for brute-force. But I agree that this is a case where we could get utility out of that function outside of brute-force as well.

I'll see what could meet the requirements of both the healpix downsampling and this, and rewrite.

Contributor:
Seconded on pushing to SQL. Let the database handle filtering heuristics.
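For illustration, a rough sketch of such a FrameIndex convenience method (the method name, the self.frames table, and self.dbconn are assumed names, not from this PR):

import sqlalchemy as sa

def frames_for_healpixel(self, healpixel: int, mjd: float, obscode: str):
    # Hypothetical FrameIndex helper: push the (healpixel, mjd, obscode)
    # filter into SQL instead of scanning all frames in Python.
    stmt = (
        sa.select(self.frames)
        .where(self.frames.c.healpixel == healpixel)
        .where(self.frames.c.mjd == mjd)
        .where(self.frames.c.obscode == obscode)
    )
    return list(self.dbconn.execute(stmt))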

    filtered_frames = []
    all_frames = precovery_db.frames.idx.frames_by_date(mjd_start, mjd_end)
    for frame in all_frames:
        if (
            (unique_frames["mjd_utc"] == frame.mjd)
            & (unique_frames["obscode"] == frame.obscode)
            & (unique_frames["healpix"] == frame.healpixel)
        ).any():
            filtered_frames.append(frame)

    observations = precovery_db.extract_observations_by_frames(filtered_frames)

    return ephemeris, observations


def attribution_worker(
    ephemeris,
    observations,
    eps=1 / 3600,
    include_probabilistic=True,
):
    """
    Gather attributions for a dataframe of ephemerides and observations.

    First filters the ephemerides to match the chunk of observations
    passed to this worker.
    """

    # Create observers dictionary from observations
    observers = {}
    for observatory_code in observations["observatory_code"].unique():
        observers[observatory_code] = Time(
            observations[observations["observatory_code"].isin([observatory_code])][
                "mjd_utc"
            ].unique(),
            scale="utc",
            format="mjd",
        )

    # Group the predicted ephemerides and observations by visit / exposure
    observations_grouped = observations.groupby(by=["observatory_code", "mjd_utc"])
    observations_visits = [
        observations_grouped.get_group(g) for g in observations_grouped.groups
    ]

    # We pre-computed the ephemerides. Now filter the ephemeris to only those
    # visits that have observations in the group passed to this worker
    ephemeris_pre_grouped = ephemeris.groupby(by=["observatory_code", "mjd_utc"])
    obs_group_keys = list(observations_grouped.groups.keys())
    indices_to_drop = pd.Int64Index([])
    for g_key in list(ephemeris_pre_grouped.groups.keys()):
        if g_key not in obs_group_keys:
            indices_to_drop = indices_to_drop.union(
                ephemeris_pre_grouped.get_group(g_key).index
            )

    ephemeris_filtered = ephemeris.drop(indices_to_drop)

    # Group the now-filtered ephemerides. Only visits present in the
    # observation set should remain
    ephemeris_grouped = ephemeris_filtered.groupby(by=["observatory_code", "mjd_utc"])
    ephemeris_visits = [
        ephemeris_grouped.get_group(g) for g in ephemeris_grouped.groups
    ]

    # Loop through each unique exposure and visit, find the nearest observations
    # within eps (haversine metric)
    distances = []
    orbit_ids_associated = []
    obs_ids_associated = []
    obs_times_associated = []
    eps_rad = np.radians(eps)
    residuals = []
    stats = []
    for ephemeris_visit, observations_visit in zip(
        ephemeris_visits, observations_visits
    ):
        assert len(ephemeris_visit["mjd_utc"].unique()) == 1
        assert len(observations_visit["mjd_utc"].unique()) == 1
        assert (
            observations_visit["mjd_utc"].unique()[0]
            == ephemeris_visit["mjd_utc"].unique()[0]
        )

        obs_ids = observations_visit[["obs_id"]].values
        obs_times = observations_visit[["mjd_utc"]].values
        orbit_ids = ephemeris_visit[["orbit_id"]].values
        coords = observations_visit[["RA_deg", "Dec_deg"]].values
        coords_predicted = ephemeris_visit[["RA_deg", "Dec_deg"]].values
        coords_sigma = observations_visit[["RA_sigma_deg", "Dec_sigma_deg"]].values

        # Haversine metric requires latitude first, then longitude...
        coords_latlon = np.radians(observations_visit[["Dec_deg", "RA_deg"]].values)
        coords_predicted_latlon = np.radians(
            ephemeris_visit[["Dec_deg", "RA_deg"]].values
        )

        # k cannot exceed the number of points in the tree
        num_predicted = len(coords_predicted)
        k = np.minimum(3, num_predicted)

        # Build BallTree with a haversine metric on predicted ephemerides
        tree = BallTree(coords_predicted_latlon, metric="haversine")
        # Query tree using observed RA, Dec
        d, i = tree.query(
            coords_latlon,
            k=k,
            return_distance=True,
            dualtree=True,
            breadth_first=False,
            sort_results=False,
        )

        # Select all matches with distance smaller than or equal
        # to the maximum given distance
        mask = np.where(d <= eps_rad)

        if len(d[mask]) > 0:
            orbit_ids_associated.append(orbit_ids[i[mask]])
            obs_ids_associated.append(obs_ids[mask[0]])
            obs_times_associated.append(obs_times[mask[0]])
            distances.append(d[mask].reshape(-1, 1))

            residuals_visit, stats_visit = calcResiduals(
                coords[mask[0]],
                coords_predicted[i[mask]],
                sigmas_actual=coords_sigma[mask[0]],
                include_probabilistic=include_probabilistic,
            )
            residuals.append(residuals_visit)
            stats.append(np.vstack(stats_visit).T)

    if len(distances) > 0:
        distances = np.degrees(np.vstack(distances))
        orbit_ids_associated = np.vstack(orbit_ids_associated)
        obs_ids_associated = np.vstack(obs_ids_associated)
        obs_times_associated = np.vstack(obs_times_associated)
        residuals = np.vstack(residuals)
        stats = np.vstack(stats)

        attributions = {
            "orbit_id": orbit_ids_associated[:, 0],
            "obs_id": obs_ids_associated[:, 0],
            "mjd_utc": obs_times_associated[:, 0],
            "distance": distances[:, 0],
            "residual_ra_arcsec": residuals[:, 0] * 3600,
            "residual_dec_arcsec": residuals[:, 1] * 3600,
            "chi2": stats[:, 0],
        }
        if include_probabilistic:
            attributions["probability"] = stats[:, 1]
            attributions["mahalanobis_distance"] = stats[:, 2]

        attributions = pd.DataFrame(attributions)

    else:
        columns = [
            "orbit_id",
            "obs_id",
            "mjd_utc",
            "distance",
            "residual_ra_arcsec",
            "residual_dec_arcsec",
            "chi2",
        ]
        if include_probabilistic:
            columns += ["probability", "mahalanobis_distance"]

        attributions = pd.DataFrame(columns=columns)

    return attributions
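A self-contained illustration of the BallTree matching pattern used above, with synthetic coordinates and k=1 for brevity (not part of the PR):

import numpy as np
from sklearn.neighbors import BallTree

# Predicted and observed positions as (Dec, RA) in radians;
# the haversine metric expects (lat, lon) order
predicted = np.radians([[10.0, 150.0], [12.0, 151.0]])
observed = np.radians([[10.0001, 150.0001], [40.0, 200.0]])

tree = BallTree(predicted, metric="haversine")
d, i = tree.query(observed, k=1)  # d is in radians of great-circle distance

eps_rad = np.radians(5 / 3600)  # 5 arcsec matching radius
mask = d[:, 0] <= eps_rad
print(np.degrees(d[mask]) * 3600, i[mask])  # arcsec separations, matched indices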


def attributeObservations(
Member:
I'd suggest renaming this to attribute_observations. THOR inherited LSST's use of camel case, but ideally let's match the rest of the precovery code and go with underscores. The new version of THOR is headed that direction as well.

cc @akoumjian: function naming going forward might be a good style guide addition for us to adopt at some point.

Member Author:
Agreed with the casing. I'm in favor of sticking to underscore casing for Python, camelCase for JS. Good to know new THOR is moving that way.

Contributor:
Yes, although rather immaterial, snake_case is the way to go.

    orbits,
    mjd_start: float,
    mjd_end: float,
    precovery_db: PrecoveryDatabase,
    eps=5 / 3600,
    include_probabilistic=True,
    backend="PYOORB",
    backend_kwargs=None,
    orbits_chunk_size=10,
    observations_chunk_size=100000,
    num_jobs=1,
    parallel_backend="mp",
):
    logger.info("Running observation attribution...")
    time_start = time.time()

    num_orbits = len(orbits)

    attribution_dfs = []

    # prepare ephemeris and observation dataframes
    ephemeris, observations = get_observations_and_ephemerides_for_orbits(
Contributor:
I may be completely misunderstanding this, so I apologize if I'm off.
Can this function be broken up into separate functions? The process seems to be:

  1. orbits -> ephemerides
  2. ephemerides -> all unique intersecting frames + neighbors
  3. frames -> get all observations

The reason being that I believe most of these are re-usable and consistent, and may simply require parameters to adjust results for brute-force vs. nelson-force, etc. Here is what I'm thinking this looks like, alternatively:

frame_times = get_frame_times_in_range(mjd_start, mjd_end)
ephemerides = ephemerides_from_orbits(orbits, frame_times, obscode=obscode)
frames_to_search = intersecting_frames(ephemerides, neighbors=True)  # Default is False
observations = precovery_db.extract_observations_by_frames(frames_to_search)

I think each of these steps is usable in multiple cases, especially with optional parameters.
This makes each individual piece more testable and gives back simpler return values.

        orbits, mjd_start, mjd_end, precovery_db
    )

    parallel, num_workers = _checkParallel(num_jobs, parallel_backend)
    if num_workers > 1:
        p = mp.Pool(
            processes=num_workers,
            initializer=_initWorker,
        )

        # Send up to orbits_chunk_size orbits to each worker for processing
        chunk_size_ = calcChunkSize(
            num_orbits, num_workers, orbits_chunk_size, min_chunk_size=1
        )

        orbits_split = [
            orbits[i : i + chunk_size_] for i in range(0, len(orbits), chunk_size_)
        ]

        # Pre-filter the ephemerides into one dataframe per orbit chunk
        eph_split = []
        for orbit_c in orbits_split:
            eph_split.append(
                ephemeris[
                    ephemeris["orbit_id"].isin([orbit.orbit_id for orbit in orbit_c])
                ]
            )
        for observations_c in yieldChunks(observations, observations_chunk_size):
            obs = [observations_c for _ in range(len(orbits_split))]
            attribution_dfs_i = p.starmap(
                partial(
                    attribution_worker,
                    eps=eps,
                    include_probabilistic=include_probabilistic,
                ),
                zip(
                    eph_split,
                    obs,
                ),
            )
            attribution_dfs += attribution_dfs_i

        p.close()

    else:
        for observations_c in yieldChunks(observations, observations_chunk_size):
            for orbit_c in [
                orbits[i : i + orbits_chunk_size]
                for i in range(0, len(orbits), orbits_chunk_size)
            ]:
                eph_c = ephemeris[
                    ephemeris["orbit_id"].isin([orbit.orbit_id for orbit in orbit_c])
                ]
                attribution_df_i = attribution_worker(
                    eph_c,
                    observations_c,
                    eps=eps,
                    include_probabilistic=include_probabilistic,
                )
                attribution_dfs.append(attribution_df_i)

    attributions = pd.concat(attribution_dfs)
    attributions.sort_values(
        by=["orbit_id", "mjd_utc", "distance"], inplace=True, ignore_index=True
    )

    time_end = time.time()
    logger.info(
        "Attributed {} observations to {} orbits.".format(
            attributions["obs_id"].nunique(), attributions["orbit_id"].nunique()
        )
    )
    logger.info(
        "Attribution completed in {:.3f} seconds.".format(time_end - time_start)
    )
    return attributions
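A hypothetical invocation, for context (the database path and orbit loading are placeholders, and the from_dir constructor is an assumed entry point):

from precovery.precovery_db import PrecoveryDatabase

db = PrecoveryDatabase.from_dir("/path/to/precovery_db")  # assumed constructor
orbits = [...]  # list of precovery Orbit objects, loaded elsewhere
attributions = attributeObservations(
    orbits,
    mjd_start=57000.0,
    mjd_end=57030.0,
    precovery_db=db,
    eps=5 / 3600,  # 5 arcsec matching radius
    num_jobs=4,
)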