[gtk+/wip/otte/shader: 15/56] gskslnode: Add arithmetic operations



commit 796e823bec2673875af6e0dd71d0778098710de6
Author: Benjamin Otte <otte redhat com>
Date:   Sun Sep 17 18:55:56 2017 +0200

    gskslnode: Add arithmetic operations
    
    This required adding gsk_sl_type_get_index_type() which is used to get
    the type for an array index operation (or NULL if array indexing is not
    allowed).
    
    It can also be abused to get the row count of matrices, which is what we
    do here.

 gsk/gskslnode.c        |  291 +++++++++++++++++++++++++++++++++++++++++++++++-
 gsk/gsksltype.c        |   78 +++++++++----
 gsk/gsksltypeprivate.h |    3 +
 3 files changed, 345 insertions(+), 27 deletions(-)
---
diff --git a/gsk/gskslnode.c b/gsk/gskslnode.c
index 72c92e9..91cc5c3 100644
--- a/gsk/gskslnode.c
+++ b/gsk/gskslnode.c
@@ -350,6 +350,174 @@ gsk_sl_node_operation_print (GskSlNode *node,
 }
 
 static GskSlType *
+gsk_sl_node_arithmetic_type_check (GskSlPreprocessor *stream,
+                                   gboolean           multiply,
+                                   GskSlType         *ltype,
+                                   GskSlType         *rtype)
+{
+  GskSlScalarType scalar;
+
+  if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (ltype),
+                                      gsk_sl_type_get_scalar_type (rtype)))
+    scalar = gsk_sl_type_get_scalar_type (ltype);
+  else if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (rtype),
+                                           gsk_sl_type_get_scalar_type (ltype)))
+    scalar = gsk_sl_type_get_scalar_type (rtype);
+  else
+    {
+      if (stream)
+        {
+          char *lstr = gsk_sl_type_to_string (ltype);
+          char *rstr = gsk_sl_type_to_string (rtype);
+          gsk_sl_preprocessor_error (stream, "Operand types %s and %s do not share compatible scalar 
types.", lstr, rstr);
+          g_free (lstr);
+          g_free (rstr);
+        }
+      return NULL;
+    }
+
+  if (gsk_sl_type_is_matrix (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          if (multiply)
+            {
+              if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (gsk_sl_type_get_index_type 
(rtype)))
+                {
+                  if (stream)
+                    gsk_sl_preprocessor_error (stream, "Matrices to multiplication have incompatible 
dimensions.");
+                  return NULL;
+                }
+              return gsk_sl_type_get_matrix (scalar,
+                                             gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)),
+                                             gsk_sl_type_get_length (rtype));
+            }
+          else
+            {
+              if (gsk_sl_type_can_convert (ltype, rtype))
+                {
+                  return ltype;
+                }
+              else if (gsk_sl_type_can_convert (rtype, ltype))
+                {
+                  return rtype;
+                }
+              else
+                {
+                  if (stream)
+                    gsk_sl_preprocessor_error (stream, "Matrices to arithmetic operation have different 
size.");
+                  return NULL;
+                }
+            }
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          if (multiply)
+            {
+              if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (rtype))
+                {
+                  if (stream)
+                    gsk_sl_preprocessor_error (stream, "Matrix column count doesn't match vector length.");
+                  return NULL;
+                }
+              return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (gsk_sl_type_get_index_type 
(ltype)));
+            }
+          else
+            {
+              if (stream)
+                gsk_sl_preprocessor_error (stream, "Cannot perform arithmetic operation between matrix and 
vector.");
+              return NULL;
+            }
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_matrix (scalar,
+                                         gsk_sl_type_get_length (ltype),
+                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
+        }
+      else
+        {
+          if (stream)
+            gsk_sl_preprocessor_error (stream, "Right operand is incompatible type for arithemtic 
operation.");
+          return NULL;
+        }
+    }
+  else if (gsk_sl_type_is_vector (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          if (multiply)
+            {
+              if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (gsk_sl_type_get_index_type 
(rtype)))
+                {
+                  if (stream)
+                    gsk_sl_preprocessor_error (stream, "Vector length doesn't match matrix row count.");
+                  return NULL;
+                }
+              return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (rtype));
+            }
+          else
+            {
+              if (stream)
+                gsk_sl_preprocessor_error (stream, "Cannot perform arithmetic operation between vector and 
matrix.");
+              return NULL;
+            }
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (rtype))
+            {
+              if (stream)
+                gsk_sl_preprocessor_error (stream, "Vector operands to arithmetic operation have different 
length.");
+              return NULL;
+            }
+          return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (ltype));
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_vector (scalar,
+                                         gsk_sl_type_get_length (ltype));
+        }
+      else
+        {
+          if (stream)
+            gsk_sl_preprocessor_error (stream, "Right operand is incompatible type for arithemtic 
operation.");
+          return NULL;
+        }
+    }
+  else if (gsk_sl_type_is_scalar (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          return gsk_sl_type_get_matrix (scalar,
+                                         gsk_sl_type_get_length (rtype),
+                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)));
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          return gsk_sl_type_get_vector (scalar,
+                                         gsk_sl_type_get_length (rtype));
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_scalar (scalar);
+        }
+      else
+        {
+          if (stream)
+            gsk_sl_preprocessor_error (stream, "Right operand is incompatible type for arithemtic 
operation.");
+          return NULL;
+        }
+    }
+  else
+    {
+      if (stream)
+        gsk_sl_preprocessor_error (stream, "Left operand is incompatible type for arithemtic operation.");
+      return NULL;
+    }
+}
+
+static GskSlType *
 gsk_sl_node_bitwise_type_check (GskSlPreprocessor *stream,
                                 GskSlType         *ltype,
                                 GskSlType         *rtype)
@@ -489,15 +657,21 @@ gsk_sl_node_operation_get_return_type (GskSlNode *node)
   switch (operation->op)
   {
     case GSK_SL_OPERATION_MUL:
+      return gsk_sl_node_arithmetic_type_check (NULL,
+                                                TRUE,
+                                                gsk_sl_node_get_return_type (operation->left),
+                                                gsk_sl_node_get_return_type (operation->right));
     case GSK_SL_OPERATION_DIV:
-    case GSK_SL_OPERATION_MOD:
     case GSK_SL_OPERATION_ADD:
     case GSK_SL_OPERATION_SUB:
-      g_assert_not_reached ();
-      return NULL;
+      return gsk_sl_node_arithmetic_type_check (NULL,
+                                                FALSE,
+                                                gsk_sl_node_get_return_type (operation->left),
+                                                gsk_sl_node_get_return_type (operation->right));
     case GSK_SL_OPERATION_LSHIFT:
     case GSK_SL_OPERATION_RSHIFT:
       return gsk_sl_node_get_return_type (operation->left);
+    case GSK_SL_OPERATION_MOD:
     case GSK_SL_OPERATION_AND:
     case GSK_SL_OPERATION_XOR:
     case GSK_SL_OPERATION_OR:
@@ -868,11 +1042,120 @@ gsk_sl_node_parse_primary_expression (GskSlNodeProgram  *program,
 }
 
 static GskSlNode *
+gsk_sl_node_parse_unary_expression (GskSlNodeProgram  *program,
+                                    GskSlScope        *scope,
+                                    GskSlPreprocessor *stream)
+{
+  GskSlNode *expression;
+
+  /* No unary operators are handled here: the input is parsed as a
+   * primary expression and handed back unchanged. */
+  expression = gsk_sl_node_parse_primary_expression (program, scope, stream);
+
+  return expression;
+}
+
+/*
+ * gsk_sl_node_parse_multiplicative_expression:
+ * @program: program the expression belongs to
+ * @scope: scope passed on to the operand parsers
+ * @stream: preprocessor to read tokens from and report errors to
+ *
+ * Parses a unary expression followed by any number of `*`, `/` or `%`
+ * operators, each with a unary expression as its right operand. Each
+ * accepted operator wraps the expression parsed so far as the left
+ * operand of a new operation node.
+ *
+ * An operator whose right operand fails to parse or fails the type check
+ * is dropped, and parsing continues with the expression built so far.
+ *
+ * Returns: the parsed expression, or NULL if the first operand could not
+ *     be parsed
+ */
+static GskSlNode *
+gsk_sl_node_parse_multiplicative_expression (GskSlNodeProgram  *program,
+                                             GskSlScope        *scope,
+                                             GskSlPreprocessor *stream)
+{
+  const GskSlToken *token;
+  GskSlNode *node;
+  GskSlNodeOperation *operation;
+  GskSlOperation op;
+  gboolean valid;
+
+  node = gsk_sl_node_parse_unary_expression (program, scope, stream);
+  if (node == NULL)
+    return NULL;
+
+  while (TRUE)
+    {
+      token = gsk_sl_preprocessor_get (stream);
+      if (gsk_sl_token_is (token, GSK_SL_TOKEN_STAR))
+        op = GSK_SL_OPERATION_MUL;
+      else if (gsk_sl_token_is (token, GSK_SL_TOKEN_SLASH))
+        op = GSK_SL_OPERATION_DIV;
+      else if (gsk_sl_token_is (token, GSK_SL_TOKEN_PERCENT))
+        op = GSK_SL_OPERATION_MOD;
+      else
+        return node;
+
+      operation = gsk_sl_node_new (GskSlNodeOperation, &GSK_SL_NODE_OPERATION);
+      operation->left = node;
+      operation->op = op;
+      gsk_sl_preprocessor_consume (stream, (GskSlNode *) operation);
+      operation->right = gsk_sl_node_parse_unary_expression (program, scope, stream);
+
+      if (operation->right == NULL)
+        valid = FALSE;
+      else if (op == GSK_SL_OPERATION_MOD)
+        valid = gsk_sl_node_bitwise_type_check (stream,
+                                                gsk_sl_node_get_return_type (operation->left),
+                                                gsk_sl_node_get_return_type (operation->right)) != NULL;
+      else
+        /* Only `*` follows the linear-algebra rules; `/` is component-wise. */
+        valid = gsk_sl_node_arithmetic_type_check (stream,
+                                                   op == GSK_SL_OPERATION_MUL,
+                                                   gsk_sl_node_get_return_type (operation->left),
+                                                   gsk_sl_node_get_return_type (operation->right)) != NULL;
+
+      if (valid)
+        {
+          node = (GskSlNode *) operation;
+        }
+      else
+        {
+          /* Discard the operation but keep the expression parsed so far.
+           * The extra reference presumably offsets the one the operation
+           * releases on its left operand when it is freed. */
+          gsk_sl_node_ref (node);
+          gsk_sl_node_unref ((GskSlNode *) operation);
+        }
+    }
+
+  return node;
+}
+
+/*
+ * gsk_sl_node_parse_additive_expression:
+ * @program: program the expression belongs to
+ * @scope: scope passed on to the operand parsers
+ * @stream: preprocessor to read tokens from and report errors to
+ *
+ * Parses a multiplicative expression followed by any number of `+` or `-`
+ * operators, each with a multiplicative expression as its right operand.
+ * Each accepted operator wraps the expression parsed so far as the left
+ * operand of a new operation node, so `a - b - c` becomes `(a - b) - c`.
+ *
+ * An operator whose right operand fails to parse or fails the type check
+ * is dropped, and parsing continues with the expression built so far.
+ *
+ * Returns: the parsed expression, or NULL if the first operand could not
+ *     be parsed
+ */
+static GskSlNode *
 gsk_sl_node_parse_additive_expression (GskSlNodeProgram  *program,
                                        GskSlScope        *scope,
                                        GskSlPreprocessor *stream)
 {
-  return gsk_sl_node_parse_primary_expression (program, scope, stream);
+  const GskSlToken *token;
+  GskSlNode *node;
+  GskSlNodeOperation *operation;
+  GskSlOperation op;
+
+  node = gsk_sl_node_parse_multiplicative_expression (program, scope, stream);
+  if (node == NULL)
+    return NULL;
+
+  while (TRUE)
+    {
+      token = gsk_sl_preprocessor_get (stream);
+      if (gsk_sl_token_is (token, GSK_SL_TOKEN_PLUS))
+        op = GSK_SL_OPERATION_ADD;
+      else if (gsk_sl_token_is (token, GSK_SL_TOKEN_DASH))
+        op = GSK_SL_OPERATION_SUB;
+      else
+        return node;
+
+      operation = gsk_sl_node_new (GskSlNodeOperation, &GSK_SL_NODE_OPERATION);
+      operation->left = node;
+      operation->op = op;
+      gsk_sl_preprocessor_consume (stream, (GskSlNode *) operation);
+      /* The right operand is the next-higher precedence level. Recursing
+       * into the additive parser here would swallow the rest of the chain
+       * and make `+` and `-` right-associative. */
+      operation->right = gsk_sl_node_parse_multiplicative_expression (program, scope, stream);
+      if (operation->right == NULL)
+        {
+          /* Discard the operation but keep the expression parsed so far.
+           * The extra reference presumably offsets the one the operation
+           * releases on its left operand when it is freed. */
+          gsk_sl_node_ref (node);
+          gsk_sl_node_unref ((GskSlNode *) operation);
+        }
+      else if (!gsk_sl_node_arithmetic_type_check (stream,
+                                                   FALSE,
+                                                   gsk_sl_node_get_return_type (operation->left),
+                                                   gsk_sl_node_get_return_type (operation->right)))
+        {
+          gsk_sl_node_ref (node);
+          gsk_sl_node_unref ((GskSlNode *) operation);
+        }
+      else
+        {
+          node = (GskSlNode *) operation;
+        }
+    }
+
+  return node;
 }
 
 static GskSlNode *
diff --git a/gsk/gsksltype.c b/gsk/gsksltype.c
index 686f15f..aeb0853 100644
--- a/gsk/gsksltype.c
+++ b/gsk/gsksltype.c
@@ -42,34 +42,12 @@ struct _GskSlTypeClass {
   void                  (* print)                               (GskSlType           *type,
                                                                  GString             *string);
   GskSlScalarType       (* get_scalar_type)                     (GskSlType           *type);
+  GskSlType *           (* get_index_type)                      (GskSlType           *type);
   guint                 (* get_length)                          (GskSlType           *type);
   gboolean              (* can_convert)                         (GskSlType           *target,
                                                                  GskSlType           *source);
 };
 
-static gboolean
-gsk_sl_scalar_type_can_convert (GskSlScalarType target,
-                                GskSlScalarType source)
-{
-  if (target == source)
-    return TRUE;
-
-  switch (source)
-  {
-    case GSK_SL_INT:
-      return target == GSK_SL_UINT
-          || target == GSK_SL_FLOAT
-          || target == GSK_SL_DOUBLE;
-    case GSK_SL_UINT:
-      return target == GSK_SL_FLOAT
-          || target == GSK_SL_DOUBLE;
-    case GSK_SL_FLOAT:
-      return target == GSK_SL_DOUBLE;
-    default:
-      return FALSE;
-  }
-}
-
 /* SCALAR */
 
 typedef struct _GskSlTypeScalar GskSlTypeScalar;
@@ -126,6 +104,12 @@ gsk_sl_type_scalar_get_scalar_type (GskSlType *type)
   return scalar->scalar;
 }
 
+static GskSlType *
+gsk_sl_type_scalar_get_index_type (GskSlType *type)
+{
+  return NULL; /* scalars have no index type: NULL means indexing is not allowed */
+}
+
 static guint
 gsk_sl_type_scalar_get_length (GskSlType *type)
 {
@@ -149,6 +133,7 @@ static const GskSlTypeClass GSK_SL_TYPE_SCALAR = {
   gsk_sl_type_scalar_free,
   gsk_sl_type_scalar_print,
   gsk_sl_type_scalar_get_scalar_type,
+  gsk_sl_type_scalar_get_index_type,
   gsk_sl_type_scalar_get_length,
   gsk_sl_type_scalar_can_convert
 };
@@ -210,6 +195,14 @@ gsk_sl_type_vector_get_scalar_type (GskSlType *type)
   return vector->scalar;
 }
 
+static GskSlType *
+gsk_sl_type_vector_get_index_type (GskSlType *type)
+{
+  /* Indexing a vector yields the scalar type the vector is built from. */
+  return gsk_sl_type_get_scalar (((GskSlTypeVector *) type)->scalar);
+}
+
 static guint
 gsk_sl_type_vector_get_length (GskSlType *type)
 {
@@ -238,6 +231,7 @@ static const GskSlTypeClass GSK_SL_TYPE_VECTOR = {
   gsk_sl_type_vector_free,
   gsk_sl_type_vector_print,
   gsk_sl_type_vector_get_scalar_type,
+  gsk_sl_type_vector_get_index_type,
   gsk_sl_type_vector_get_length,
   gsk_sl_type_vector_can_convert
 };
@@ -282,6 +276,14 @@ gsk_sl_type_matrix_get_scalar_type (GskSlType *type)
   return matrix->scalar;
 }
 
+static GskSlType *
+gsk_sl_type_matrix_get_index_type (GskSlType *type)
+{
+  GskSlTypeMatrix *self = (GskSlTypeMatrix *) type;
+  GskSlType *column_type;
+
+  /* Indexing a matrix yields a vector of the matrix's scalar type with
+   * one component per row. */
+  column_type = gsk_sl_type_get_vector (self->scalar, self->rows);
+
+  return column_type;
+}
+
 static guint
 gsk_sl_type_matrix_get_length (GskSlType *type)
 {
@@ -311,6 +313,7 @@ static const GskSlTypeClass GSK_SL_TYPE_MATRIX = {
   gsk_sl_type_matrix_free,
   gsk_sl_type_matrix_print,
   gsk_sl_type_matrix_get_scalar_type,
+  gsk_sl_type_matrix_get_index_type,
   gsk_sl_type_matrix_get_length,
   gsk_sl_type_matrix_can_convert
 };
@@ -634,6 +637,12 @@ gsk_sl_type_get_scalar_type (const GskSlType *type)
   return type->class->get_scalar_type (type);
 }
 
+GskSlType *
+gsk_sl_type_get_index_type (const GskSlType *type)
+{
+  return type->class->get_index_type (type); /* per-class vfunc; NULL means the type cannot be indexed */
+}
+
 GskSlScalarType
 gsk_sl_type_get_length (const GskSlType *type)
 {
@@ -641,6 +650,29 @@ gsk_sl_type_get_length (const GskSlType *type)
 }
 
 gboolean
+gsk_sl_scalar_type_can_convert (GskSlScalarType target,
+                                GskSlScalarType source)
+{
+  /* A scalar type always converts to itself. */
+  if (source == target)
+    return TRUE;
+
+  /* int converts to uint, float and double. */
+  if (source == GSK_SL_INT)
+    return target == GSK_SL_UINT
+        || target == GSK_SL_FLOAT
+        || target == GSK_SL_DOUBLE;
+
+  /* uint converts to float and double. */
+  if (source == GSK_SL_UINT)
+    return target == GSK_SL_FLOAT
+        || target == GSK_SL_DOUBLE;
+
+  /* float converts to double only. */
+  if (source == GSK_SL_FLOAT)
+    return target == GSK_SL_DOUBLE;
+
+  /* Every other source type converts to nothing but itself. */
+  return FALSE;
+}
+
+gboolean
 gsk_sl_type_can_convert (const GskSlType *target,
                          const GskSlType *source)
 {
diff --git a/gsk/gsksltypeprivate.h b/gsk/gsksltypeprivate.h
index b7166ac..d522e1b 100644
--- a/gsk/gsksltypeprivate.h
+++ b/gsk/gsksltypeprivate.h
@@ -53,7 +53,10 @@ gboolean                gsk_sl_type_is_scalar                   (const GskSlType
 gboolean                gsk_sl_type_is_vector                   (const GskSlType     *type);
 gboolean                gsk_sl_type_is_matrix                   (const GskSlType     *type);
 GskSlScalarType         gsk_sl_type_get_scalar_type             (const GskSlType     *type);
+GskSlType *             gsk_sl_type_get_index_type              (const GskSlType     *type);
 guint                   gsk_sl_type_get_length                  (const GskSlType     *type);
+gboolean                gsk_sl_scalar_type_can_convert          (GskSlScalarType      target,
+                                                                 GskSlScalarType      source);
 gboolean                gsk_sl_type_can_convert                 (const GskSlType     *target,
                                                                  const GskSlType     *source);
 


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]