Optimization of `compute_tide_corrections` with `FES2014` for multiple lat/lons
robbibt opened this issue · 4 comments
First of all, congrats @tsutterley on an incredible package... such an amazing resource! I've been looking into using `pyTMD` for modelling tide heights from `FES2014` for our DEA Coastlines coastline mapping work. Essentially, our current process is to:
- For a coastal study area, load every image taken by the Landsat satellites between 1987 and 2022
- For each point on a 2 x 2 km point grid over the ocean, model tides for the exact time each Landsat image was taken (e.g. typically 1000-1500 images per point)
I've been testing out `compute_tide_corrections` as a way to achieve this, passing in the lat/lon of a given 2 x 2 km grid point, and all of the times from my satellite datasets:
```python
import pandas as pd
from pyTMD import compute_tide_corrections

lat, lon = -32, 155
example_times = pd.date_range("2022-01-01", "2022-01-02", freq="1h").values

out = compute_tide_corrections(
    x=lon,
    y=lat,
    delta_time=example_times,
    DIRECTORY="FES2014",
    MODEL="FES2014",
    EPSG=4326,
    TYPE="time series",
    TIME="datetime",
    METHOD="bilinear",
)
```
This works great, but it's pretty slow: about 38.4 seconds in total for a single lat/lon point. Because I can have up to 100+ lat/lon points in a given study area, this will quickly blow out if I want to apply `compute_tide_corrections` to multiple points.
Using `line_profiler`, it appears that by far the majority of this time (38.2 seconds, or over 99%) is taken up by the `extract_FES_constants` function:
```
Timer unit: 1e-06 s

Total time: 38.408 s
File: /env/lib/python3.8/site-packages/pyTMD/compute_tide_corrections.py
Function: compute_tide_corrections at line 125

Line #      Hits         Time    Per Hit   % Time  Line Contents
==============================================================
...
   277         2   38292108.0 19146054.0     99.7  amp,ph = extract_FES_constants(lon, lat, model.model_file,
   278         1          4.0        4.0      0.0      TYPE=model.type, VERSION=model.version, METHOD=METHOD,
   279         1          4.0        4.0      0.0      EXTRAPOLATE=EXTRAPOLATE, CUTOFF=CUTOFF, SCALE=model.scale,
   280         1          4.0        4.0      0.0      GZIP=model.compressed)
...
```
Profiling `extract_FES_constants` itself, it seems like the vast majority of the time in that function (37.5 seconds) is taken up by `read_netcdf_file`:
```
Timer unit: 1e-06 s

Total time: 38.2932 s
File: /env/lib/python3.8/site-packages/pyTMD/read_FES_model.py
Function: extract_FES_constants at line 86

Line #      Hits         Time    Per Hit   % Time  Line Contents
==============================================================
...
   158        68   37470465.0   551036.2     97.9  hc,lon,lat = read_netcdf_file(os.path.expanduser(fi),
   159        34         53.0        1.6      0.0      GZIP=GZIP, TYPE=TYPE, VERSION=VERSION)
...
```
So essentially, loading the FES2014 files with `read_netcdf_file` occupies almost all of the time taken to run `compute_tide_corrections`. For analyses involving many timesteps at a single lat/lon this isn't a problem, as the files only have to be read once. However, for analyses where `compute_tide_corrections` needs to be called multiple times to model tides at multiple lat/lons, the FES2014 data has to be loaded again and again, leading to extremely long processing times.
Instead of loading the FES2014 files with `read_netcdf_file` every time `compute_tide_corrections` is called, would it be possible to give users the option to load the FES files themselves outside the function, and then pass the loaded data (i.e. `hc`, `lon`, `lat`) directly to the function via an optional parameter? This would allow users to greatly optimise processing time for analyses that include many lat/lon tide modelling locations.
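As a purely hypothetical sketch of what I mean (the `read_fes_data` helper and `CONSTITUENTS` parameter below do not exist in `pyTMD`; they're invented here just to illustrate the idea):

```python
import pandas as pd
from pyTMD import compute_tide_corrections

example_times = pd.date_range("2022-01-01", "2022-01-02", freq="1h").values
grid_points = [(155, -32), (157, -33), (161, -34)]

# Hypothetical: read the FES2014 NetCDF files a single time up front
# (read_fes_data is invented for illustration, not a real pyTMD function)
hc, lon, lat = read_fes_data("FES2014")

# Reuse the pre-loaded data across many calls, one per modelling point
for point_lon, point_lat in grid_points:
    out = compute_tide_corrections(
        x=point_lon,
        y=point_lat,
        delta_time=example_times,
        MODEL="FES2014",
        EPSG=4326,
        TYPE="time series",
        TIME="datetime",
        METHOD="bilinear",
        CONSTITUENTS=(hc, lon, lat),  # invented optional parameter
    )
```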
Since posting this issue, I have discovered that the `drift` method gets me closer to what I want, as I can pass in multiple lat/lons as well as multiple times, and avoid the repeated NetCDF reads:
```python
import pandas as pd
from pyTMD import compute_tide_corrections

# Input data (multiple times per point)
example_times = pd.date_range("2022-01-01", "2022-01-30", freq="1D")
point1_df = pd.DataFrame({'lat': -32, 'lon': 155, 'time': example_times})
point2_df = pd.DataFrame({'lat': -33, 'lon': 157, 'time': example_times})
point3_df = pd.DataFrame({'lat': -34, 'lon': 161, 'time': example_times})

# Combine into a single dataframe
points_df = pd.concat([point1_df, point2_df, point3_df])

# Model tide heights using 'drift'
out = compute_tide_corrections(
    x=points_df.lon,
    y=points_df.lat,
    delta_time=points_df.time.values,
    DIRECTORY="FES2014",
    MODEL="FES2014",
    EPSG=4326,
    TYPE="drift",
    TIME="datetime",
    METHOD="bilinear",
)

# Add back into dataframe
points_df['tide_height'] = out
```
However, because I have static points with multiple timesteps at each, `drift` still ends up being less efficient than I'd like: it assumes each time also has a unique lat/lon, so the spatial interpolation step is run for every individual lat/lon/time pair, rather than being interpolated once per unique point location and then re-used across all of that point's times.
I think for this application (many timesteps for a smaller set of static modelling point locations), the most efficient processing flow might be something like this?
- Only once per entire analysis: read NetCDF tide model data
- Only once per static modelling point: Extract and interpolate constants based on lat/lon
- For every time: Model tide heights based on extracted constants at each point
(or alternatively, perhaps some method to detect duplicate/repeated lat/lons, then batch those together to reduce the number of required interpolations...)
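A minimal sketch of that deduplication idea (plain NumPy, not pyTMD-specific; `interpolate_once` is a stand-in for the real per-point spatial interpolation):

```python
import numpy as np

# Example coordinates: 3 unique points, each repeated for 30 timesteps
lons = np.repeat([155.0, 157.0, 161.0], 30)
lats = np.repeat([-32.0, -33.0, -34.0], 30)

# Find unique (lon, lat) pairs and the mapping back to the full arrays
coords = np.column_stack([lons, lats])
unique_coords, inverse = np.unique(coords, axis=0, return_inverse=True)

def interpolate_once(lon, lat):
    # Stand-in for the expensive per-point interpolation of constants
    return lon + lat

# Interpolate once per unique point location...
unique_values = np.array([interpolate_once(lo, la) for lo, la in unique_coords])

# ...then fan the results back out to every lat/lon/time pair
all_values = unique_values[inverse]
assert all_values.shape == lons.shape
```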
@robbibt still thinking about the best way to enact these changes. One idea I've been floating is to cache the interpolation objects for each constituent so that reads won't have to be repeated. I'm worried about this being a bit memory intensive though, so I need to put in some tests.
I've also been reorganizing the code structure lately in #132 and #135. Everything should still be backwards compatible just with some additional warnings.
Hey @tsutterley, I'm doing some further optimisations of our tide modelling code, as we're moving towards a multi-model tide modelling system where we choose the best tide model locally based on comparisons with our satellite data. Because of this, our modelling now takes a lot longer than previously, so I'm looking into parallelising some of the underlying `pyTMD` code to improve performance.
Our two big bottlenecks are:
- Loading the tide constituent NetCDF files (which we have largely addressed by clipping the files to a bounding box around Australia; see the sketch after this list)
- Extracting tide constituents from the NetCDFs
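In case it's useful, here's roughly how we do the clipping in point 1 with `xarray` (a sketch; it assumes the constituent files use ascending `lon`/`lat` coordinates on a 0-360 longitude grid, so double-check against the actual files):

```python
import glob
import os
import xarray as xr

# Bounding box around Australia (FES2014 grids use 0-360 longitudes)
lon_min, lon_max = 105.0, 170.0
lat_min, lat_max = -48.0, -5.0

out_dir = "FES2014_clipped/ocean_tide"
os.makedirs(out_dir, exist_ok=True)

for path in glob.glob("FES2014/ocean_tide/*.nc"):
    with xr.open_dataset(path) as ds:
        # Assumes "lon"/"lat" coordinates with ascending latitudes;
        # adjust the coordinate names/order if the files differ
        clipped = ds.sel(lon=slice(lon_min, lon_max),
                         lat=slice(lat_min, lat_max))
        clipped.to_netcdf(os.path.join(out_dir, os.path.basename(path)))
```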
For number 2, I've been able to get a big speed-up by parallelising entire `pyTMD.io.*.extract_constants` calls across smaller chunks of lat/lon points using `concurrent.futures`. However, I think there are still gains to be made, as `pyTMD.io.*.extract_constants` includes the slow NetCDF read step itself, so we're effectively wasting time in each parallel run by loading the same data multiple times.
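For context, a simplified sketch of the chunked approach I'm using (the `type`/`version`/`method` keywords are my best reading of the current docs and may differ between pyTMD versions):

```python
import concurrent.futures
import glob
import numpy as np
import pyTMD.io

# Many modelling points (e.g. a 2 x 2 km grid over the study area)
lons = np.linspace(150.0, 155.0, 1000)
lats = np.linspace(-35.0, -30.0, 1000)

model_files = sorted(glob.glob("FES2014/ocean_tide/*.nc"))

def extract_chunk(chunk):
    chunk_lons, chunk_lats = chunk
    # Each worker re-reads the NetCDF files internally, which is
    # the redundant IO described above
    return pyTMD.io.FES.extract_constants(
        chunk_lons, chunk_lats, model_files,
        type="z", version="FES2014", method="bilinear",
    )

# Split the points into chunks and extract constants in parallel
chunks = list(zip(np.array_split(lons, 8), np.array_split(lats, 8)))
with concurrent.futures.ProcessPoolExecutor() as pool:
    results = list(pool.map(extract_chunk, chunks))

# Stitch the per-chunk amplitudes and phases back together
amp = np.ma.concatenate([r[0] for r in results])
ph = np.ma.concatenate([r[1] for r in results])
```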
I know you made some changes to address this last year when I first posted this issue, but I wanted to double-check: are the newer `pyTMD.io.*.read_constants` and `pyTMD.io.*.interpolate_constants` functions intended to completely replicate the existing functionality in `pyTMD.io.*.extract_constants`? Or is there any functionality I'd lose by running those two functions instead of `pyTMD.io.*.extract_constants`?
Ideally, I'd love to do something like this (rough sketch below):
- Run once: `pyTMD.io.*.read_constants`
- Run many times in parallel, using previously-loaded constituents: `pyTMD.io.*.interpolate_constants`
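That is, something like this for FES2014 (keyword names assumed from the `extract_constants` equivalents, so may not be exact):

```python
import glob
import numpy as np
import pyTMD.io

model_files = sorted(glob.glob("FES2014/ocean_tide/*.nc"))

# Run once: read every constituent file into memory
constituents = pyTMD.io.FES.read_constants(
    model_files, type="z", version="FES2014",
)

# Run many times (e.g. across parallel workers), reusing the
# already-loaded constituents instead of re-reading the NetCDFs
lons = np.array([155.0, 157.0, 161.0])
lats = np.array([-32.0, -33.0, -34.0])
amp, ph = pyTMD.io.FES.interpolate_constants(
    lons, lats, constituents, method="bilinear",
)
```

One thing I'd still need to check is how well the in-memory constituents object plays with process-based parallelism, since it would have to be pickled or shared into each worker.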
Hey @robbibt, basically yes, that was the plan. The new functions can completely replicate the prior functionality. The difference is that the new read-and-interpolate method keeps all of the constituent data in memory. In some cases this may be slower, such as when running on a small (possibly distributed) machine, so I've kept both methods.
In cases where you want to run for multiple points with the same data, there is a potential speed-up with the new method since (as you mentioned) there's the I/O bottleneck.
I've thought about switching to `dask` arrays (probably using `xarray`) but need to do some testing. I'm completely open to suggestions for squeezing out performance.
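For what it's worth, a rough sketch of what a dask-backed read could look like (untested; assumes `amplitude`/`phase` variables and `lon`/`lat` coordinates in the constituent files):

```python
import glob
import xarray as xr

# Lazily open every constituent file as one dask-backed dataset,
# stacked along a new "constituent" dimension with chunked spatial tiles
paths = sorted(glob.glob("FES2014/ocean_tide/*.nc"))
ds = xr.open_mfdataset(
    paths,
    combine="nested",
    concat_dim="constituent",
    chunks={"lat": 720, "lon": 720},
)

# Nothing is read from disk until values are actually needed,
# e.g. a nearest-neighbour point lookup
amp = ds["amplitude"].sel(lon=155.0, lat=-32.0, method="nearest").compute()
```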