Skip to content

Commit cfa817a

Browse files
committed
.
1 parent c721592 commit cfa817a

3 files changed

Lines changed: 224 additions & 32 deletions

File tree

configs/grid_index_config.yaml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
# Grid generation configuration for NestEOGrid (consumed by scripts/generate_grid_index.py)

# All resolution levels to generate, in meters (coarse → fine)
levels: [120000, 12000, 2400, 1200, 600, 300]
# Preferred level ordering for lookups (fine → coarse)
default_levels: [300, 600, 1200, 2400, 12000, 120000]

# Leave utm_zones undefined to process ALL zones (1N–60N, 1S–60S).
# If set, provide an explicit list of zone strings, e.g. ["31N", "32N", "1S"].
# utm_zones: ["31N", "32N"]

# Output settings
output_format: "PARQUET" # Format: PARQUET or SHP
output_dir: "D:/nesteo_hf/index_structure" # Where to write output

# Optional enhancements
# save_geohash: true # Save GeoHash column
include_polar: true # If True, also generates the NP/SP polar grids
# skip_existing: true # Skip grid generation if output already exists
# save_wgs_files: true # Export WGS84 versions alongside UTM
# save_single_file: true # Save full grid as one .parquet per level
generate: false # Flag to trigger actual generation (safe for dry-run toggle)

# Optional advanced
# chunked_levels: [300] # For large levels, save in multiple chunks
# ref_level: 12000 # Reference level for spatial alignment
# ref_dir: "D:/NestEO_hf/metadata_current/grids_geo/grid_12000m"

scripts/generate_grid_index.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from pathlib import Path
2+
from NestEO.grid import NestEOGrid
3+
import yaml
4+
5+
def load_config(config_path: Path) -> dict:
    """Load a YAML configuration file and return it as a dict.

    Parameters
    ----------
    config_path : Path
        Path to the YAML file.

    Returns
    -------
    dict
        Parsed configuration. An empty/blank file yields ``{}`` (not ``None``),
        so callers can safely use ``config.get(...)``.
    """
    # Explicit encoding avoids locale-dependent decoding on Windows.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f) or {}
8+
9+
10+
def main(config_file="generate_grid_index.yaml"):
    """Build the grid tile index described by a YAML configuration file.

    Parameters
    ----------
    config_file : str, default "generate_grid_index.yaml"
        Path to the YAML configuration (see configs/grid_index_config.yaml).

    Raises
    ------
    FileNotFoundError
        If *config_file* does not exist.
    ValueError
        If the configuration omits the required ``output_dir`` key.
    """
    config_path = Path(config_file)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path.resolve()}")

    config = load_config(config_path)

    # Fail early with a clear message instead of Path(None) blowing up below.
    output_dir = config.get("output_dir")
    if not output_dir:
        raise ValueError("Config must define 'output_dir'.")

    # Only the parameters currently exercised are forwarded. NestEOGrid
    # supports more options (utm_zones, output_format, save_wgs_files,
    # chunked_levels, ref_level, ref_dir, ...) — wire them up here as needed.
    grid = NestEOGrid(
        levels=config.get("levels", [120000, 12000, 2400, 1200]),
        default_levels=config.get("default_levels"),
        include_polar=config.get("include_polar", True),
        output_dir=output_dir,
        generate=config.get("generate", False),
    )

    # Write the flat tile_id/super_id index (no geometries are generated).
    grid.build_tile_index_parquet(Path(output_dir) / "grid_index.parquet")
45+
46+
47+
if __name__ == "__main__":
    import sys

    # Fallback matches main()'s own default; previously the CLI fell back to
    # "grid_config.yaml", which matched neither main()'s default nor the
    # shipped configs/grid_index_config.yaml.
    config_arg = sys.argv[1] if len(sys.argv) > 1 else "generate_grid_index.yaml"
    main(config_file=config_arg)

src/NestEO/grid/grid_generator.py

Lines changed: 149 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -515,27 +515,8 @@ def _zero_tile_tuples(self, zone: str) -> set[tuple[int,int]]:
515515
self._zero_cache[zone] = tuples
516516
return tuples
517517

518-
519-
# pat = join(
520-
# self.ref_dir,
521-
# f"lc_proportions_*_{zone}_{self.ref_level}m.parquet")
522-
# files = glob.glob(pat)
523-
# if not files:
524-
# raise FileNotFoundError(f"No ref‑level parquet for zone {zone} under {pat}")
525-
526-
# df = pd.read_parquet(files[0], columns=["tile_id", "landcover_props"])
527-
# df = df[df["landcover_props"] == "{0: 1.0}"]
528-
529-
# # parse tile_id → (x_idx, y_idx) ‑‑ vectorised regex
530-
# rgx = re.compile(r"_X(\d+)_Y(\d+)")
531-
# tuples = set(
532-
# df["tile_id"].str.extract(rgx).astype(int).apply(tuple, axis=1)
533-
# )
534-
# self._zero_cache[zone] = tuples
535-
# return tuples
536518
# ─────────────────────────────────────────────────────────────────────────────
537519

538-
539520
# def ancestor_id_series(self, tile_id_series: pd.Series) -> pd.Series:
540521
# """
541522
# Vectorised: return the tile_id of the ancestor at *ref_level*
@@ -559,19 +540,6 @@ def _zero_tile_tuples(self, zone: str) -> set[tuple[int,int]]:
559540
# df["zone"] + "_X" + x_anc.astype(str).str.zfill(6) +
560541
# "_Y" + y_anc.astype(str).str.zfill(6)
561542
# )
562-
563-
# def _prefilter_grid_centroids(self, cols, rows, grid_size, crs, lon_bounds: Tuple[float, float]):
564-
# grid_x, grid_y = np.meshgrid(cols, rows)
565-
# grid_x = grid_x.ravel()
566-
# grid_y = grid_y.ravel()
567-
# print("Grid X and Y shape: ", grid_x.shape, grid_y.shape)
568-
# cx = grid_x + grid_size / 2
569-
# cy = grid_y + grid_size / 2
570-
# transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
571-
# lons, _ = transformer.transform(cx, cy)
572-
# lon_min, lon_max = lon_bounds
573-
# mask = (lons >= lon_min) & (lons <= lon_max)
574-
# return grid_x[mask], grid_y[mask]
575543

576544
def _prefilter_grid_centroids(self, cols, rows, grid_size, crs,
577545
lon_bounds: Tuple[float, float]):
@@ -749,6 +717,155 @@ def _construct_tile_file_path(self, zone: str, level: int, ext: Optional[str] =
749717
return join(self.output_dir, fname)
750718

751719

720+
# ──────────────────────────────────────────────────────────────
721+
def _iter_valid_xy(self, grid_size: int, zone: str):
    """
    Yield (x_idx, y_idx) integers for every tile that would exist at
    *grid_size* and *zone* without creating any geometry. Internal helper
    for fast index generation.

    Parameters
    ----------
    grid_size : int
        Tile edge length in meters (one of ``self.levels``).
    zone : str
        UTM zone label such as "31N"/"7S", or "NP"/"SP" for the polar zones.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Integer x / y tile indices relative to the zone origin.
    """
    if zone in ("NP", "SP"):  # Polar
        # Polar stereographic CRS: EPSG:3413 (Arctic) / EPSG:3031 (Antarctic).
        EPSG = 3413 if zone == "NP" else 3031
        crs = CRS.from_epsg(EPSG)
        # NOTE(review): these origins span a 9,000 km × 4,500 km window only
        # (y >= 0 for NP, y <= 0 for SP) — confirm the half-plane is intended
        # and not a truncation of the polar cap.
        bounds = (-4_500_000, 0) if zone == "NP" else (-4_500_000, -4_500_000)
        origin_x, origin_y = bounds
        xmax, ymax = origin_x + 9_000_000, origin_y + 4_500_000
        # Overlapping tiles advance by a reduced step; otherwise tiles abut.
        step = int(grid_size * (1 - self.overlap_ratio)) if self.overlap_ratio > 0 else grid_size
        cols = np.arange(origin_x, xmax, step)
        rows = np.arange(origin_y, ymax, step)

        # Keep only tiles whose centroid latitude lies inside the polar cap
        # (>= 84°N for NP, <= 80°S for SP), matching the UTM system's limits.
        transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
        grid_x, grid_y = np.meshgrid(cols, rows)
        cx = grid_x.ravel() + grid_size / 2
        cy = grid_y.ravel() + grid_size / 2
        _, lats = transformer.transform(cx, cy)
        mask = (lats >= 84) if zone == "NP" else (lats <= -80)
        # Convert surviving lower-left corners to integer grid indices.
        x_idx = ((grid_x.ravel()[mask] - origin_x) // grid_size).astype(int)
        y_idx = ((grid_y.ravel()[mask] - origin_y) // grid_size).astype(int)
        return x_idx, y_idx

    # ─── UTM ──────────────────────────────────────────────────
    zone_num = int(zone[:-1]); hemi = zone[-1]
    # WGS84 / UTM EPSG codes: 326xx for the northern, 327xx for the southern hemisphere.
    epsg = 32600 + zone_num if hemi == "N" else 32700 + zone_num
    crs = CRS.from_epsg(epsg)

    origin_x = 100_000
    origin_y = 0 if hemi == "N" else 10_000_000
    xmin, xmax = origin_x, 900_000
    ymin, ymax = (0, 9_329_005) if hemi == "N" else (0, origin_y)

    step = int(grid_size * (1 - self.overlap_ratio)) if self.overlap_ratio > 0 else grid_size
    cols = np.arange(xmin, xmax, step)
    rows = np.arange(ymin, ymax, step)

    # Fast centroid-based lon–lat filter (reuse existing logic);
    # zone n covers longitudes [(n-1)*6-180, n*6-180] degrees.
    valid_x, valid_y = self._prefilter_grid_centroids(
        cols, rows, grid_size, crs,
        ((zone_num - 1) * 6 - 180, zone_num * 6 - 180)
    )
    x_idx = ((valid_x - origin_x) // grid_size).astype(int)
    y_idx = ((valid_y - origin_y) // grid_size).astype(int)
    # NOTE(review): for "S" zones origin_y = 10,000,000 while rows start at 0,
    # so y_idx is negative there — verify downstream id formatting expects that.
    keep = self._mask_by_ref(grid_size, zone, x_idx, y_idx)
    return x_idx[keep], y_idx[keep]
770+
771+
def build_tile_index_parquet(
    self,
    output_path: str = "grid_index.parquet",
    row_group_target: int | None = None,
) -> None:
    """
    Create a single Parquet file containing *all* tile_id / super_id pairs
    for every configured level and every UTM + optional polar zone—without
    generating geometries.

    Parameters
    ----------
    output_path : str, default "grid_index.parquet"
        Destination path.
    row_group_target : int | None
        Desired row-group size. If None, it is set to
        max(total_rows // 1024, 1024).

    Raises
    ------
    RuntimeError
        If the current configuration yields no tiles at all.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    # ------------------------------------------------------------------ zones
    # All 120 UTM zones, plus the two polar zones when enabled.
    zones = [f"{i}{h}" for i in range(1, 61) for h in "NS"]
    if getattr(self, "include_polar", False):
        zones += ["NP", "SP"]

    # ---------------------------------------------------------------- pass 1
    # Count rows only. Indices are deliberately recomputed in pass 2 rather
    # than cached, so peak memory stays bounded for the finest levels.
    total_rows = 0
    for level in self.levels:
        print("Processing Level: ", level)
        for zone in zones:
            x_idx, _ = self._iter_valid_xy(level, zone)
            total_rows += x_idx.size

    if total_rows == 0:
        raise RuntimeError("No tiles found with current configuration.")

    if row_group_target is None:
        # Divisor is 1024 (the earlier docstring incorrectly said 512).
        row_group_target = max(total_rows // 1024, 1024)

    # -------------------------------------------------------------- writer
    schema = pa.schema(
        [("tile_id", pa.string()), ("super_id", pa.string())]
    )
    writer = pq.ParquetWriter(
        output_path, schema, version="2.6", compression="snappy"
    )

    # --------------------------------------------------------------- pass 2
    buffer_tile, buffer_super = [], []

    def _flush(count: int) -> None:
        """Write the first *count* buffered rows as one row group."""
        nonlocal buffer_tile, buffer_super
        if count:
            table = pa.table(
                {
                    "tile_id": pa.array(buffer_tile[:count]),
                    "super_id": pa.array(buffer_super[:count]),
                }
            )
            writer.write_table(table, row_group_size=row_group_target)
            buffer_tile = buffer_tile[count:]
            buffer_super = buffer_super[count:]

    try:
        for level in self.levels:
            for zone in zones:
                x_idx, y_idx = self._iter_valid_xy(level, zone)
                if x_idx.size == 0:
                    continue

                buffer_tile.extend(
                    self._make_tile_id(level, zone, xi, yi)
                    for xi, yi in zip(x_idx, y_idx)
                )
                buffer_super.extend(
                    self._compute_super_id(level, zone, xi, yi)
                    for xi, yi in zip(x_idx, y_idx)
                )

                # Emit full row groups as soon as the buffer can fill one.
                while len(buffer_tile) >= row_group_target:
                    _flush(row_group_target)

        _flush(len(buffer_tile))  # write any remainder
    finally:
        # Always close so the Parquet footer is written even if pass 2 raises;
        # otherwise the output file is unreadable.
        writer.close()

    print(f"[OK] grid index written → {output_path} "
          f"({total_rows:,} rows, row_group {row_group_target})")
866+
867+
868+
752869
def check_satellite_resolution_compatibility(self, grid_sizes: List[int], satellite_resolutions: List[int]) -> pd.DataFrame:
753870
"""
754871
Computes how well each tile level intersects with a set of satellite resolutions.

0 commit comments

Comments
 (0)