Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@


ui = thalassa.ThalassaUI(
display_variables=True,
display_stations=True,
)

Expand Down
8 changes: 0 additions & 8 deletions thalassa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from __future__ import annotations

from .api import get_elevation_dmap
from .api import get_tiles
from .api import get_trimesh
from .api import get_wireframe
from .api import get_timeseries
from .ui import ThalassaUI
from .utils import open_dataset
from .utils import reload

__all__: list[str] = [
"open_dataset",
"reload",
"get_trimesh",
"get_tiles",
"get_wireframe",
"get_elevation_dmap",
"ThalassaUI",
]
140 changes: 77 additions & 63 deletions thalassa/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,64 +10,73 @@
from holoviews.operation.datashader import rasterize
from holoviews.streams import PointerXY,DoubleTap
import numpy as np

logger = logging.getLogger(__name__)

# Load bokeh backend
from . import utils
hv.extension("bokeh")


def get_trimesh(
dataset: xr.Dataset,
longitude_var: str,
latitude_var: str,
elevation_var: str,
simplices_var: str,
time_var: str,
timestamp: str | pd.Timestamp,
) -> gv.TriMesh:
simplices = dataset[simplices_var].values
columns = [longitude_var, latitude_var, elevation_var]
if timestamp == "MAXIMUM":
points_df = dataset.max(time_var)[columns].to_dataframe()
elif timestamp == "MINIMUM":
points_df = dataset.min(time_var)[columns].to_dataframe()
else:
points_df = dataset.sel({time_var: timestamp})[columns].to_dataframe().drop(columns=time_var)
points_df = points_df.reset_index(drop=True)
points_gv = gv.Points(points_df, kdims=[longitude_var, latitude_var], vdims=elevation_var)
trimesh = gv.TriMesh((simplices, points_gv))
return trimesh

logger = logging.getLogger(__name__)

def get_tiles() -> gv.Tiles:
tiles = gv.WMTS("http://c.tile.openstreetmap.org/{Z}/{X}/{Y}.png")
return tiles


def get_wireframe(trimesh: gv.TriMesh) -> hv.Layout:
wireframe = dynspread(rasterize(trimesh.edgepaths, precompute=True))
return wireframe


def get_elevation_dmap(trimesh: gv.TriMesh, show_grid: bool = False) -> hv.Overlay:
tiles = get_tiles()
elevation = rasterize(trimesh, precompute=True).opts( # pylint: disable=no-member
title="Elevation Forecast",
colorbar=True,
clabel="meters",
show_legend=True,
)
logger.debug("show grid: %s", show_grid)
if show_grid:
overlay = tiles * elevation * get_wireframe(trimesh=trimesh)
else:
overlay = tiles * elevation
return overlay

#----------------------------------------------------------------------------------------
#time series
#----------------------------------------------------------------------------------------
class MapData:
'''
define a class to store data related to dynamic map
'''
def __init__(self):
#dataset info
self.name = None
self.format = None
self.prj = None
#header info
self.dataset = None #file handle -> xr.Dataset
self.times = None
self.variables = None
#connectivity
self.x = None
self.y = None
self.elnode = None
#dataset snapshot
self.time = None
self.variable = None
self.data = None
self.grid = None
self.trimesh = None
self.trimap = None
self.tiles = get_tiles()

def get_data(self,time,variable,layer):
'''
extract a snapshot from dataset
'''
self.time = time
self.variable = variable
self.layer = layer
tid=int(np.nonzero(np.array(self.times)==time)[0][0])
self.data=utils.read_dataset(self.dataset,2,self.format,time=tid,variable=variable,layer=layer)

def get_plot_map(self):
'''
plot a snapshot: only SCHISM method is defined so far
'''
if self.format=="SCHISM":
if self.x.min()<-360 or self.x.max()>360 or self.y.min()<-90 or self.y.max()>90:
raise ValueError(f"check dataset projection: abs(lat)>360 or abs(lon)>90")
df=pd.DataFrame({'longitude':self.x, 'latitude':self.y, 'data':self.data})
pdf=gv.Points(df,kdims=['longitude','latitude'],vdims='data')
self.trimesh=gv.TriMesh((self.elnode,pdf))
if self.grid is None:
self.grid=dynspread(rasterize(self.trimesh.edgepaths, precompute=True))
self.trimap=rasterize(self.trimesh, precompute=True).opts(
title=f"SCHISM Forecast: {self.variable}",
colorbar=True,
clabel="meters",
cmap="jet",
Copy link
Copy Markdown
Collaborator

@pmav99 pmav99 Apr 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using this colormap throws an exception:

# ...
  File "/home/panos/.conda/envs/thalassa/lib/python3.8/site-packages/holoviews/plotting/util.py", line 912, in process_cmap
    raise ValueError("Supplied cmap %s not found among %s colormaps." %
ValueError: Supplied cmap jet not found among matplotlib, bokeh, or colorcet colormaps.

We should either add matplotlib to the dependencies or use a different colormap. I wouldn't add an extra dependency just for a colormap.

Copy link
Copy Markdown
Collaborator Author

@wzhengui wzhengui Apr 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @pmav99 : Thank you for reviewing the PR. I think these are good suggestions. I will try to address them.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @pmav99 and @brey : Thank Panos' useful suggestions again. I am sorry that I didn't have a chance talking to George this morning, as I was working on Thalassa. I just finished the revision except point 4. I tried to dynamically update the interface according to different variable picked by users, but it turns out to be very tricky. I will keep this on my mind, and see whether this is a good solution.

show_legend=True,
)
else:
raise ValueError(f"please define plot method for dataset format: {dataset_format}")

class TimeseriesData:
'''
define a class to store data related to time series points
Expand All @@ -77,32 +86,32 @@ def __init__(self):
def clear(self):
self.init=False

def extract_timeseries(x,y,sx,sy,data):
def extract_timeseries(x,y,sx,sy,dataset,variable):
'''
function for extracting time series@(x,y) from data
'''
dist=abs(sx+1j*sy-x-1j*y)
mdist=dist.min()
nid=np.nonzero(dist==mdist)[0][0]
mdata=data['elev'].data[:,nid].copy()
mdata=dataset[variable].data[:,nid].copy()
return mdist,mdata

def add_remove_pts(x,y,data,dataset,fmt):
def add_remove_pts(x,y,data,dataset,fmt,variable):
'''
function to dynamically add or remove pts by double clicking on the map
'''
if fmt=='add pts':
if len(data.xys)==0:
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset)
hcurve=hv.Curve((data.time,mdata),'time','elevation').opts(tools=["hover"])
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable)
hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"])
if mdist<=data.mdist:
data.xys.append((x,y))
data.elev.append(mdata)
data.curve.append(hcurve)
else:
if data.xys[-1][0]!=x and data.xys[-1][1]!=y:
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset)
hcurve=hv.Curve((data.time,mdata),'time','elevation').opts(tools=["hover"])
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable)
hcurve=hv.Curve((data.time,mdata),'time',variable).opts(tools=["hover"])
if mdist<=data.mdist:
data.xys.append((x,y))
data.elev.append(mdata)
Expand All @@ -120,15 +129,20 @@ def add_remove_pts(x,y,data,dataset,fmt):
else:
pass

def get_timeseries(source,data,dataset,ymin,ymax,fmt):
def get_timeseries(MData,data,ymin,ymax,fmt):
'''
get time series plots
'''

source, dataset = MData.trimesh, MData.dataset
variable='elev' #todo: add an input for time series variable

#initialize timeseries_data
if data.init is False:
#find the maximum side length
x,y=dataset['SCHISM_hgrid_node_x'].data,dataset['SCHISM_hgrid_node_y'].data
e1,e2,e3=dataset['SCHISM_hgrid_face_nodes'].data.T
x,y=MData.x,MData.y #tmp fix, improve: todo
e1,e2,e3=MData.elnode.T

s1=abs((x[e1]-x[e2])+1j*(y[e1]-y[e2])).max()
s2=abs((x[e2]-x[e3])+1j*(y[e2]-y[e3])).max()
s3=abs((x[e3]-x[e1])+1j*(y[e3]-y[e1])).max()
Expand All @@ -144,7 +158,7 @@ def get_timeseries(source,data,dataset,ymin,ymax,fmt):

def get_plot_point(x,y):
if None not in [x,y]:
add_remove_pts(x,y,data,dataset,fmt)
add_remove_pts(x,y,data,dataset,fmt,variable)

if ((x is None) or (y is None)) and len(data.xys)==0:
xys=[(data.x0,data.y0)]
Expand All @@ -159,7 +173,7 @@ def get_plot_point(x,y):
return hpoint*htext

def get_plot_curve(x,y):
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset)
mdist,mdata=extract_timeseries(x,y,data.sx,data.sy,dataset,variable)
if mdist>data.mdist:
mdata=mdata*np.nan
hdynamic=hv.Curve((data.time,mdata)).opts(color='k',line_width=2,line_dash='dotted')
Expand Down
Loading