From 238585ca1524d9ce955e8ee8ce37d81ae37c086d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi?= <remi.cresson@inrae.fr>
Date: Fri, 15 Mar 2024 16:56:29 +0100
Subject: [PATCH] Apply to items or assets

---
 stacflow_stac_extension/testing.py | 58 +++++++++++++++++++++++++++++
 tests/extensions_test.py           | 60 +++---------------------------
 2 files changed, 64 insertions(+), 54 deletions(-)

diff --git a/stacflow_stac_extension/testing.py b/stacflow_stac_extension/testing.py
index 05480ca..fa256aa 100644
--- a/stacflow_stac_extension/testing.py
+++ b/stacflow_stac_extension/testing.py
@@ -1,6 +1,7 @@
 import pystac
 from datetime import datetime
 import random
+import json
 
 
 def create_dummy_item(date=None):
@@ -50,3 +51,60 @@ def create_dummy_item(date=None):
     col.add_item(item)
 
     return item
+
+
+def basic_test(ext_md, ext_cls, validate: bool = True):
+    print(
+        f"Extension metadata model: \n{ext_md.__class__.schema_json(indent=2)}"
+    )
+
+    def apply(stac_obj):
+        """
+        Apply the extension to the item
+        """
+        print(f"Check extension applied to {stac_obj.__class__.__name__}")
+        ext = ext_cls.ext(stac_obj, add_if_missing=True)
+        ext.apply(ext_md)
+
+    def print_item(item):
+        """
+        Print item as JSON
+        """
+        print(json.dumps(item.to_dict(), indent=2))
+
+    def comp(stac_obj):
+        """
+        Compare the metadata carried by the stac object with the expected metadata.
+        """
+        read_ext = ext_cls(stac_obj)
+        for field in ext_md.__class__.__fields__:
+            ref = getattr(ext_md, field)
+            got = getattr(read_ext, field)
+            assert got == ref, f"'{field}': values differ: {got} (expected {ref})"
+
+    def test_item():
+        """
+        Test extension against item
+        """
+        item = create_dummy_item()
+        apply(item)
+        print_item(item)
+        if validate:
+            item.validate()  # <--- This will try to read the actual schema URI
+        # Check that we can retrieve the extension metadata from the item
+        comp(item)
+
+    def test_asset():
+        """
+        Test extension against asset
+        """
+        item = create_dummy_item()
+        apply(item.assets["ndvi"])
+        print_item(item)
+        if validate:
+            item.validate()  # <--- This will try to read the actual schema URI
+        # Check that we can retrieve the extension metadata from the asset
+        comp(item.assets["ndvi"])
+
+    test_item()
+    test_asset()
diff --git a/tests/extensions_test.py b/tests/extensions_test.py
index 6d8d3b7..1dd48ac 100644
--- a/tests/extensions_test.py
+++ b/tests/extensions_test.py
@@ -1,5 +1,5 @@
 from stacflow_stac_extension import create_extension_cls
-from stacflow_stac_extension.testing import create_dummy_item
+from stacflow_stac_extension.testing import basic_test
 from pydantic import BaseModel, Field, ConfigDict
 from typing import List
 import json
@@ -9,8 +9,8 @@ SCHEMA_URI: str = "https://example.com/image-process/v1.0.0/schema.json"
 PREFIX: str = "some_prefix"
 
 
-# Extension model
-class ExtensionModelExample(BaseModel):
+# Extension metadata model
+class MyExtensionMetadataModel(BaseModel):
     # Required so that one model can be instantiated with the attribute name
     # rather than the alias
     model_config = ConfigDict(populate_by_name=True)
@@ -23,63 +23,15 @@ class ExtensionModelExample(BaseModel):
 
 # Create the extension class
 MyExtension = create_extension_cls(
-    model_cls=ExtensionModelExample,
+    model_cls=MyExtensionMetadataModel,
     schema_uri=SCHEMA_URI
 )
 
 # Metadata fields
-ext_md = ExtensionModelExample(
+ext_md = MyExtensionMetadataModel(
     name="test",
     authors=["michel", "denis"],
     version="alpha"
 )
 
-
-def apply(stac_obj):
-    """
-    Apply the extension to the STAC object, which is modified inplace
-
-    """
-    processing_ext = MyExtension.ext(stac_obj, add_if_missing=True)
-    processing_ext.apply(ext_md)
-    # item.validate()  # <--- This will try to read the actual schema URI
-
-
-def check(props):
-    """
-    Check that the properties are well filled with the metadata payload
-
-    """
-    assert all(
-        f"{PREFIX}:{field}" in props
-        for field in ext_md.__fields__
-    )
-    assert all(
-        props[f"{PREFIX}:{field}"] == getattr(ext_md, field)
-        for field in ext_md.__fields__
-    )
-
-
-def print_item(item):
-    item_dic = item.to_dict()
-    print(f"Item metadata:\n{json.dumps(item_dic, indent=2)}")
-
-
-def test_item():
-    print("Test item")
-    item = create_dummy_item()
-    apply(item)
-    print_item(item)
-    check(item.to_dict()["properties"])
-
-
-def test_asset():
-    print("Test asset")
-    item = create_dummy_item()
-    apply(item.assets["ndvi"])
-    print_item(item)
-    check(item.assets["ndvi"].extra_fields)
-
-
-test_asset()
-test_item()
+basic_test(ext_md, MyExtension, validate=False)
-- 
GitLab