Skip to content

16 v2 global embedding#49

Open
meiertgrootes wants to merge 6 commits into
mainfrom
16_v2_global_embedding
Open

16 v2 global embedding#49
meiertgrootes wants to merge 6 commits into
mainfrom
16_v2_global_embedding

Conversation

@meiertgrootes
Copy link
Copy Markdown
Collaborator

This pull request replaces #48 which was corrupted by mal-configured local github setup

This pull request adds fully sphere aware geo position and scale encoding for patches.
Geo position is encoded using real-valued spherical harmonics as basis.
To create embeddings real-valued spherical harmonics at the lat/lon positions of pixels of the input data (i.e. native resolution) are calculated up to a user-defined order L. This results in an (L+1)^2 dimensional embedding vector. L ~ 10 should be fine.
Subsequently a sphere aware area-weighted PCA is performed on the native resolution SH embdding grid, with the requirment that the target diemension for the PCA (the sh_embed_dim) be smaller than (L+1)^2. The ranked PCA components up to sh_embed_dim are retained to be used as basis functions. The SH embedding vectors are then reprojected to this basis and scaled to zero mean and unit variance with tanh based soft-clipping at ~3 sigma to suppress pathological outliers in high order harmonics.
For each patch, a patch position embedding is then constructed as the area weighted mean of the token/pixel embeddings in the patch.

In addition for each patch a scale embedding is constructed consisting of: the patch physical extent in lat and lon directions, the patch area, the anisotropy of the patch extent, the pixel scale in phyical units [m] in lat/lon, anisotropy and isotropized linear scale, as well as finally effective harmonic order cutoff in lat/lon.

Both embeddings are precalculated once for all patches

The 10-dimensional scale embedding is concatenated with the geo position embedding for each patch. During training an trainable MLP is used to project the concatented embeddings into the desired embedding diemension for additive incorporation.

Note, it makes sense to choose the hidden diemension for the projection larger than bith the dimension of the concatenated embeddings, as well as the target embedding dimension.

Copy link
Copy Markdown
Member

@SarahAlidoost SarahAlidoost left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@meiertgrootes thanks, the implementation looks good 👍 I left some comments. Most of them are related to code style. If something isnot clear, let me know.

Also, please consider running ruff as it can fixes things automatically and helps saving time in reviewing:

pip install ruff
ruff check --fix your_script.py   # this fixes/shows errors
ruff format --check your_script.py --diff  # this shows formatting issues

Comment thread climanet/dataset.py
Comment on lines +100 to +102



Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Comment thread climanet/dataset.py
Comment on lines +201 to +203
#geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim)


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim)

Comment thread climanet/dataset.py
lon_patch = self.lon_coords[j : j + pw] # (W,) -> (pW,)

#get patch geo pos embedding
#geo_pos_embedding_tensor = compute_patch_geo_pos_embedding(geo_pos_tensor, lat_patch)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#geo_pos_embedding_tensor = compute_patch_geo_pos_embedding(geo_pos_tensor, lat_patch)

Comment thread climanet/dataset.py
geo_pos_embedding_tensor = self.patch_geo_embeddings[idx]

#get scale feature for patch
#scale_feature_tensor = compute_patch_scale_features(lat_patch, lon_patch) # -> (10,)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#scale_feature_tensor = compute_patch_scale_features(lat_patch, lon_patch) # -> (10,)

Comment on lines +727 to +729
#self.spatial_pe = SpatialPositionalEncoding2D(
# embed_dim=embed_dim, max_H=max_H, max_W=max_W
#)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#self.spatial_pe = SpatialPositionalEncoding2D(
# embed_dim=embed_dim, max_H=max_H, max_W=max_W
#)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please remove "SpatialPositionalEncoding2D" from the script since we are not using it anymore.


# east-west extent
dx = earth_radius * cos_lat_c * dlon
dx_pix = dx/max(lon_ext -1, 1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dx_pix = dx/max(lon_ext -1, 1)
dx_pix = dx / max(lon_ext -1, eps)

Comment thread climanet/dataset.py



def _set_geo_pos_table(self, sh_pos_table: str):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function doesnot return anything, but in __init__ method, it is called as self.geo_pos_table = self._set_geo_pos_table(sh_pos_table). As a result, self.geo_pos_table will be None.

Comment thread climanet/dataset.py
Comment on lines +36 to 40
self.sh_order_L = sh_order_L



# Check that the input data has the expected dimensions
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.sh_order_L = sh_order_L
# Check that the input data has the expected dimensions
self.sh_order_L = sh_order_L
# Check that the input data has the expected dimensions

Comment thread climanet/dataset.py
"land_mask_patch": land_tensor, # (pH,pW) True=Land
"daily_timef_patch": daily_timef_tensor, #(M, T=31, 2)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
#"sh_geo_pos_patch": geo_pos_tensor, # (pH, pW, sh_embed_dim)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#"sh_geo_pos_patch": geo_pos_tensor, # (pH, pW, sh_embed_dim)

Comment on lines +831 to +833
geo_emb = geo_emb[:, None, None, :] # (B,1,1,E)

x = agg_latent + geo_emb # (B, M, Hp*Wp, E)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
geo_emb = geo_emb[:, None, None, :] # (B,1,1,E)
x = agg_latent + geo_emb # (B, M, Hp*Wp, E)
# Broadcasting: same geo embedding for all M months and all Hp*Wp locations
# we use weighted mean patch embedding, see `geo_embedding_utils.py`
geo_emb = geo_emb[:, None, None, :] # (B,1,1,E)
x = agg_latent + geo_emb # (B, M, Hp*Wp, E)

Copy link
Copy Markdown
Member

@SarahAlidoost SarahAlidoost left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@meiertgrootes thanks, the implementation looks good 👍 I left some comments. Most of them are related to code style. If something isnot clear, let me know.

Also, please consider running ruff as it can fixes things automatically and helps saving time in reviewing:

pip install ruff
ruff check --fix your_script.py   # this fixes/shows errors
ruff format --check your_script.py --diff  # this shows formatting issues

Copy link
Copy Markdown
Collaborator

@rogerkuou rogerkuou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @meiertgrootes , besides Sarah's comments, I only have two very minor comments from my side. So I will already approve this PR. Please go ahead and merge after adapting Sarah's comment.

Comment thread climanet/train.py
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["geo_pos_embedding_patch"].to(device, non_blocking=use_cuda),
batch["scale_feature_patch"].to(device, non_blocking=use_cuda),
batch["padded_days_mask"].to(device, non_blocking=use_cuda) ,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
batch["padded_days_mask"].to(device, non_blocking=use_cuda) ,
batch["padded_days_mask"].to(device, non_blocking=use_cuda),

Comment thread notebooks/example.ipynb
@@ -35,7 +35,7 @@
"metadata": {},
"outputs": [],
"source": [
"data_folder = Path(\"./eso4clima\")\n",
"data_folder = Path(\"/Users/mwgrootes/Projects/REPOS/ESO4CLIMA/data/output\") #(\"./eso4clima\")\n",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the generality of the note book, maybe let's keep the original relative path ./eso4clima?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants