22import json
33import logging
44import os
5+ import tempfile
56from pathlib import Path
67
78import boto3
9+ import numpy as np
810from pystac import Item
9- from rasterio .errors import RasterioIOError
1011from stacchip .chipper import Chipper
1112from stacchip .indexer import Sentinel2Indexer
1213
1314from embeddings .utils import (
1415 get_embeddings ,
1516 get_pixels ,
16- load_clay ,
1717 load_metadata ,
1818 prepare_datacube ,
1919 write_to_table ,
2626SCENES_LIST = "data/element84-tiles-2023.gz"
2727EMBEDDINGS_BUCKET = "clay-embeddings-sentinel-2"
2828GSD = 10
29+ S2_BUCKET = "sentinel-2-cogs"
2930
3031
3132def open_scenes_list ():
@@ -40,6 +41,21 @@ def open_scenes_list():
4041 return data
4142
4243
44+ def download_scenes_local (tmp , item , bands ):
45+ s3 = boto3 .client ("s3" )
46+ for band in bands :
47+ local_asset_path = f"{ tmp } /{ band } .tif"
48+ remote_asset_key = item .assets [band ].href .replace (
49+ "https://sentinel-cogs.s3.us-west-2.amazonaws.com/" , ""
50+ )
51+ print (f"Downloading band { band } to { local_asset_path } " )
52+ with open (local_asset_path , mode = "w+b" ) as fl :
53+ s3 .download_fileobj ("sentinel-cogs" , remote_asset_key , fl )
54+ item .assets [band ].href = local_asset_path
55+
56+ return item
57+
58+
4359def process_scene (clay , path , batchsize ):
4460 bands , waves , mean , std = load_metadata ("sentinel-2-l2a" )
4561
@@ -57,44 +73,62 @@ def process_scene(clay, path, batchsize):
5773 logger .debug (f"No proj for { path } " )
5874 return
5975
60- try :
76+ all_bboxs = []
77+ all_cls_embeddings = None
78+
79+ with tempfile .TemporaryDirectory () as tmp :
80+ item = download_scenes_local (tmp , item , bands )
6181 indexer = Sentinel2Indexer (item , chip_max_nodata = 0.1 )
6282 chipper = Chipper (indexer , assets = bands )
63- bboxs , datetimes , pixels = get_pixels (
64- item = item , indexer = indexer , chipper = chipper
65- )
66- except RasterioIOError :
67- logger .warning ("Skipping scene due to rasterio io error" )
68- return
83+ logger .debug (f"Creating chips for { item .id } " )
84+ STEP = 50
85+ for index in range (0 , len (chipper ), STEP ):
86+ bboxs , datetimes , pixels = get_pixels (
87+ item = item ,
88+ indexer = indexer ,
89+ chipper = chipper ,
90+ start = index ,
91+ end = index + STEP ,
92+ )
93+
94+ if not len (pixels ):
95+ continue
96+
97+ time_norm , latlon_norm , gsd , pixels_norm = prepare_datacube (
98+ mean = mean ,
99+ std = std ,
100+ datetimes = datetimes ,
101+ bboxs = bboxs ,
102+ pixels = pixels ,
103+ gsd = GSD ,
104+ )
105+
106+ # Embed data
107+ cls_embeddings = get_embeddings (
108+ clay = clay ,
109+ pixels_norm = pixels_norm ,
110+ time_norm = time_norm ,
111+ latlon_norm = latlon_norm ,
112+ waves = waves ,
113+ gsd = gsd ,
114+ batchsize = batchsize ,
115+ )
116+ all_bboxs += bboxs
117+ if all_cls_embeddings is None :
118+ all_cls_embeddings = cls_embeddings
119+ else :
120+ all_cls_embeddings = np .vstack ((all_cls_embeddings , cls_embeddings ))
69121
70- if not len (pixels ):
71- logger .debug ("Finishing early, no valid data found in scene." )
72- return
73-
74- time_norm , latlon_norm , gsd , pixels_norm = prepare_datacube (
75- mean = mean , std = std , datetimes = datetimes , bboxs = bboxs , pixels = pixels , gsd = GSD
76- )
77-
78- # Embed data
79- cls_embeddings = get_embeddings (
80- clay = clay ,
81- pixels_norm = pixels_norm ,
82- time_norm = time_norm ,
83- latlon_norm = latlon_norm ,
84- waves = waves ,
85- gsd = gsd ,
86- batchsize = batchsize ,
87- )
88122 kwargs = dict (
89- bboxs = bboxs ,
123+ bboxs = all_bboxs ,
90124 datestr = str (item .datetime .date ()),
91125 gsd = gsd ,
92126 destination_bucket = EMBEDDINGS_BUCKET ,
93127 path = path ,
94128 source_bucket = "sentinel-cogs" ,
95129 )
96130
97- write_to_table (embeddings = cls_embeddings , ** kwargs )
131+ write_to_table (embeddings = all_cls_embeddings , ** kwargs )
98132
99133
100134def process ():
@@ -105,7 +139,8 @@ def process():
105139 batchsize = int (os .environ .get ("EMBEDDING_BATCH_SIZE" , 50 ))
106140
107141 scenes = open_scenes_list ()
108- clay = load_clay ()
142+ # clay = load_clay()
143+ clay = None
109144
110145 for i in range (index * items_per_job , (index + 1 ) * items_per_job ):
111146 process_scene (
0 commit comments