From 46bb19e51c2b394a2b20ef4040d63a63ef98889e Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@inrae.fr>
Date: Tue, 22 Aug 2023 20:06:08 +0200
Subject: [PATCH 1/2] ENH: option to expand dims in otbtf.Argmax

---
 otbtf/layers.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/otbtf/layers.py b/otbtf/layers.py
index a3680421..8fc76deb 100644
--- a/otbtf/layers.py
+++ b/otbtf/layers.py
@@ -136,13 +136,15 @@ class Argmax(tf.keras.layers.Layer):
     Useful to transform a softmax into a "categorical" map for instance.
 
     """
-    def __init__(self, name: str = None):
+    def __init__(self, name: str = None, expand_last_dim: bool = True):
         """
         Params:
             name: layer name
+            expand_last_dim: expand the last dimension when True
 
         """
         super().__init__(name=name)
+        self.expand_last_dim = expand_last_dim
 
     def call(self, inputs):
         """
@@ -157,7 +159,10 @@ class Argmax(tf.keras.layers.Layer):
             (nb_classes - 1).
 
         """
-        return tf.expand_dims(tf.math.argmax(inputs, axis=-1), axis=-1)
+        argmax = tf.math.argmax(inputs, axis=-1)
+        if self.expand_last_dim:
+            return tf.expand_dims(argmax, axis=-1)
+        return argmax
 
 
 class Max(tf.keras.layers.Layer):
-- 
GitLab
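
For illustration, a minimal usage sketch of the new option (not part of the patch itself; the batched softmax input and its shape are assumptions, and the otbtf.layers.Argmax import simply reflects the module touched above):

    import tensorflow as tf
    from otbtf.layers import Argmax

    # Hypothetical softmax output: batch of 1, 4x4 pixels, 3 classes
    probs = tf.nn.softmax(tf.random.uniform((1, 4, 4, 3)), axis=-1)

    print(Argmax()(probs).shape)                       # (1, 4, 4, 1): trailing dim kept (default)
    print(Argmax(expand_last_dim=False)(probs).shape)  # (1, 4, 4): trailing dim dropped

With the default expand_last_dim=True the existing behaviour is preserved, so current models are unaffected.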


From 887fb17332dd485edb528c4a53d2d030e3ca7b31 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@inrae.fr>
Date: Tue, 22 Aug 2023 20:08:01 +0200
Subject: [PATCH 2/2] STY: linting

---
 otbtf/layers.py | 2 +-
 otbtf/ops.py    | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/otbtf/layers.py b/otbtf/layers.py
index 8fc76deb..ef65ec1c 100644
--- a/otbtf/layers.py
+++ b/otbtf/layers.py
@@ -55,7 +55,7 @@ class DilatedMask(tf.keras.layers.Layer):
         nodata_mask = tf.cast(tf.math.equal(inp, self.nodata_value), tf.uint8)
 
         se_size = 1 + 2 * self.radius
-        # Create a morphological kernel suitable for binary dilatation, see 
+        # Create a morphological kernel suitable for binary dilatation, see
         # https://stackoverflow.com/q/54686895/13711499
         kernel = tf.zeros((se_size, se_size, 1), dtype=tf.uint8)
         conv2d_out = tf.nn.dilation2d(
diff --git a/otbtf/ops.py b/otbtf/ops.py
index 4a8d0b96..ef5c52b9 100644
--- a/otbtf/ops.py
+++ b/otbtf/ops.py
@@ -30,6 +30,8 @@ import tensorflow as tf
 
 Tensor = Any
 Scalars = List[float] | Tuple[float]
+
+
 def one_hot(labels: Tensor, nb_classes: int):
     """
     Converts labels values into one-hot vector.
@@ -43,4 +45,4 @@ def one_hot(labels: Tensor, nb_classes: int):
 
     """
     labels_xy = tf.squeeze(tf.cast(labels, tf.int32), axis=-1)  # shape [x, y]
-    return tf.one_hot(labels_xy, depth=nb_classes)  # shape [x, y, nb_classes]
\ No newline at end of file
+    return tf.one_hot(labels_xy, depth=nb_classes)  # shape [x, y, nb_classes]
-- 
GitLab
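
For illustration, a short sketch of the one_hot helper whose file is reformatted above (the example labels are assumptions; the [x, y, 1] input and [x, y, nb_classes] output shapes follow the function's docstring, not this patch):

    import tensorflow as tf
    from otbtf.ops import one_hot

    # Hypothetical labels of shape [x, y, 1] with values in [0, nb_classes - 1]
    labels = tf.constant([[[0], [2]], [[1], [1]]])
    print(one_hot(labels, nb_classes=3).shape)  # (2, 2, 3)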