Skip to content

Commit ac1434b

Browse files
committed
Pre-download S2 scene, batch pixel load and embedding generation
1 parent 6c867ee commit ac1434b

2 files changed

Lines changed: 71 additions & 32 deletions

File tree

embeddings/all-sentinel.py

Lines changed: 65 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22
import json
33
import logging
44
import os
5+
import tempfile
56
from pathlib import Path
67

78
import boto3
9+
import numpy as np
810
from pystac import Item
9-
from rasterio.errors import RasterioIOError
1011
from stacchip.chipper import Chipper
1112
from stacchip.indexer import Sentinel2Indexer
1213

1314
from embeddings.utils import (
1415
get_embeddings,
1516
get_pixels,
16-
load_clay,
1717
load_metadata,
1818
prepare_datacube,
1919
write_to_table,
@@ -26,6 +26,7 @@
2626
SCENES_LIST = "data/element84-tiles-2023.gz"
2727
EMBEDDINGS_BUCKET = "clay-embeddings-sentinel-2"
2828
GSD = 10
29+
S2_BUCKET = "sentinel-2-cogs"
2930

3031

3132
def open_scenes_list():
@@ -40,6 +41,21 @@ def open_scenes_list():
4041
return data
4142

4243

44+
def download_scenes_local(tmp, item, bands):
    """Download the requested band assets of a scene into a local directory.

    For every entry in ``bands`` the asset is fetched from the
    ``sentinel-cogs`` S3 bucket, written to ``<tmp>/<band>.tif`` and the
    asset href on ``item`` is repointed at that local file.

    Args:
        tmp: Path of an existing directory that receives the downloaded files.
        item: Object exposing ``assets[band].href``; it is modified in place.
        bands: Iterable of asset keys to download.

    Returns:
        The same ``item`` object, with the hrefs of ``bands`` replaced by
        local file paths.

    Raises:
        ValueError: If an asset href does not start with the expected
            ``sentinel-cogs`` HTTPS prefix, so no valid S3 key can be
            derived from it.
    """
    # Single source of truth for the bucket; the URL prefix is derived from it.
    bucket = "sentinel-cogs"
    url_prefix = f"https://{bucket}.s3.us-west-2.amazonaws.com/"

    s3 = boto3.client("s3")
    for band in bands:
        asset = item.assets[band]
        # Fail early with a clear message instead of handing a full URL to S3
        # as an object key, which would only surface as an opaque S3 error.
        if not asset.href.startswith(url_prefix):
            raise ValueError(
                f"Unexpected href for band {band}: {asset.href!r} "
                f"does not start with {url_prefix!r}"
            )
        # Strip only the leading prefix; the remainder is the object key.
        remote_asset_key = asset.href[len(url_prefix):]
        local_asset_path = os.path.join(tmp, f"{band}.tif")
        print(f"Downloading band {band} to {local_asset_path}")
        with open(local_asset_path, mode="w+b") as fl:
            s3.download_fileobj(bucket, remote_asset_key, fl)
        # Point the asset at the local copy so later reads avoid the network.
        asset.href = local_asset_path

    return item
57+
58+
4359
def process_scene(clay, path, batchsize):
4460
bands, waves, mean, std = load_metadata("sentinel-2-l2a")
4561

@@ -57,44 +73,62 @@ def process_scene(clay, path, batchsize):
5773
logger.debug(f"No proj for {path}")
5874
return
5975

60-
try:
76+
all_bboxs = []
77+
all_cls_embeddings = None
78+
79+
with tempfile.TemporaryDirectory() as tmp:
80+
item = download_scenes_local(tmp, item, bands)
6181
indexer = Sentinel2Indexer(item, chip_max_nodata=0.1)
6282
chipper = Chipper(indexer, assets=bands)
63-
bboxs, datetimes, pixels = get_pixels(
64-
item=item, indexer=indexer, chipper=chipper
65-
)
66-
except RasterioIOError:
67-
logger.warning("Skipping scene due to rasterio io error")
68-
return
83+
logger.debug(f"Creating chips for {item.id}")
84+
STEP = 50
85+
for index in range(0, len(chipper), STEP):
86+
bboxs, datetimes, pixels = get_pixels(
87+
item=item,
88+
indexer=indexer,
89+
chipper=chipper,
90+
start=index,
91+
end=index + STEP,
92+
)
93+
94+
if not len(pixels):
95+
continue
96+
97+
time_norm, latlon_norm, gsd, pixels_norm = prepare_datacube(
98+
mean=mean,
99+
std=std,
100+
datetimes=datetimes,
101+
bboxs=bboxs,
102+
pixels=pixels,
103+
gsd=GSD,
104+
)
105+
106+
# Embed data
107+
cls_embeddings = get_embeddings(
108+
clay=clay,
109+
pixels_norm=pixels_norm,
110+
time_norm=time_norm,
111+
latlon_norm=latlon_norm,
112+
waves=waves,
113+
gsd=gsd,
114+
batchsize=batchsize,
115+
)
116+
all_bboxs += bboxs
117+
if all_cls_embeddings is None:
118+
all_cls_embeddings = cls_embeddings
119+
else:
120+
all_cls_embeddings = np.vstack((all_cls_embeddings, cls_embeddings))
69121

70-
if not len(pixels):
71-
logger.debug("Finishing early, no valid data found in scene.")
72-
return
73-
74-
time_norm, latlon_norm, gsd, pixels_norm = prepare_datacube(
75-
mean=mean, std=std, datetimes=datetimes, bboxs=bboxs, pixels=pixels, gsd=GSD
76-
)
77-
78-
# Embed data
79-
cls_embeddings = get_embeddings(
80-
clay=clay,
81-
pixels_norm=pixels_norm,
82-
time_norm=time_norm,
83-
latlon_norm=latlon_norm,
84-
waves=waves,
85-
gsd=gsd,
86-
batchsize=batchsize,
87-
)
88122
kwargs = dict(
89-
bboxs=bboxs,
123+
bboxs=all_bboxs,
90124
datestr=str(item.datetime.date()),
91125
gsd=gsd,
92126
destination_bucket=EMBEDDINGS_BUCKET,
93127
path=path,
94128
source_bucket="sentinel-cogs",
95129
)
96130

97-
write_to_table(embeddings=cls_embeddings, **kwargs)
131+
write_to_table(embeddings=all_cls_embeddings, **kwargs)
98132

99133

100134
def process():
@@ -105,7 +139,8 @@ def process():
105139
batchsize = int(os.environ.get("EMBEDDING_BATCH_SIZE", 50))
106140

107141
scenes = open_scenes_list()
108-
clay = load_clay()
142+
# clay = load_clay()
143+
clay = None
109144

110145
for i in range(index * items_per_job, (index + 1) * items_per_job):
111146
process_scene(

embeddings/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,17 @@ def prepare_datacube(mean, std, datetimes, bboxs, pixels, gsd):
7373
return time_norm, latlon_norm, gsd, pixels_norm
7474

7575

76-
def get_pixels(item, indexer, chipper):
76+
def get_pixels(item, indexer, chipper, start=None, end=None):
7777
chips = []
7878
datetimes = []
7979
bboxs = []
8080
chip_ids = []
8181
item_ids = []
82-
for index in range(len(chipper)):
82+
if start:
83+
index_range = range(start, min(end, len(chipper)))
84+
else:
85+
index_range = range(len(chipper))
86+
for index in index_range:
8387
y = index // chipper.indexer.x_size
8488
x = index % chipper.indexer.x_size
8589

0 commit comments

Comments
 (0)