[gtk+/wip/otte/shader: 89/176] gskslexpression: Split multiplication from other binary operations



commit bab24cc93a366e7723174d339532ed30815296bd
Author: Benjamin Otte <otte redhat com>
Date:   Wed Oct 4 00:07:35 2017 +0200

    gskslexpression: Split multiplication from other binary operations
    
    Also implement the get_constant() and write_spv() vfuncs for the new multiplication expression.

 gsk/gskslexpression.c |  686 ++++++++++++++++++++++++++++++++++++++++++-------
 1 files changed, 591 insertions(+), 95 deletions(-)
---
diff --git a/gsk/gskslexpression.c b/gsk/gskslexpression.c
index 11fa643..14bda0f 100644
--- a/gsk/gskslexpression.c
+++ b/gsk/gskslexpression.c
@@ -171,10 +171,524 @@ static const GskSlExpressionClass GSK_SL_EXPRESSION_ASSIGNMENT = {
   gsk_sl_expression_assignment_write_spv
 };
 
-/* BINARY */
+/* MULTIPLICATION */
+
+typedef struct _GskSlExpressionMultiplication GskSlExpressionMultiplication;
+
+struct _GskSlExpressionMultiplication {
+  GskSlExpression parent;
+
+  GskSlType *type;
+  GskSlExpression *left;
+  GskSlExpression *right;
+};
+
+static void
+gsk_sl_expression_multiplication_free (GskSlExpression *expression)
+{
+  GskSlExpressionMultiplication *multiplication = (GskSlExpressionMultiplication *) expression;
+
+  gsk_sl_expression_unref (multiplication->left);
+  gsk_sl_expression_unref (multiplication->right);
+  gsk_sl_type_unref (multiplication->type);
+
+  g_slice_free (GskSlExpressionMultiplication, multiplication);
+}
+
+static void
+gsk_sl_expression_multiplication_print (const GskSlExpression *expression,
+                                        GskSlPrinter          *printer)
+{
+  GskSlExpressionMultiplication *multiplication = (GskSlExpressionMultiplication *) expression;
+
+  gsk_sl_expression_print (multiplication->left, printer);
+  gsk_sl_printer_append (printer, " * ");
+  gsk_sl_expression_print (multiplication->right, printer);
+}
+
+static GskSlType *
+gsk_sl_expression_multiplication_get_result_type (GskSlPreprocessor *preproc,
+                                                  GskSlType         *ltype,
+                                                  GskSlType         *rtype)
+{
+  GskSlScalarType scalar;
+
+  if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (ltype),
+                                      gsk_sl_type_get_scalar_type (rtype)))
+    scalar = gsk_sl_type_get_scalar_type (ltype);
+  else if (gsk_sl_scalar_type_can_convert (gsk_sl_type_get_scalar_type (rtype),
+                                           gsk_sl_type_get_scalar_type (ltype)))
+    scalar = gsk_sl_type_get_scalar_type (rtype);
+  else
+    {
+      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                 "Operand types %s and %s do not share compatible scalar types.",
+                                 gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
+      return NULL;
+    }
+  if (scalar == GSK_SL_BOOL)
+    {
+      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH, "Cannot multiply booleans.");
+      return NULL;
+    }
+
+  if (gsk_sl_type_is_matrix (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)))
+            {
+              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                         "Incompatible dimensions when multiplying %s * %s.",
+                                         gsk_sl_type_get_name (ltype),
+                                         gsk_sl_type_get_name (rtype));
+              return NULL;
+            }
+          return gsk_sl_type_get_matrix (scalar,
+                                         gsk_sl_type_get_length (rtype),
+                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (rtype))
+            {
+              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                         "Matrix column count doesn't match vector length.");
+              return NULL;
+            }
+          return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_matrix (scalar,
+                                         gsk_sl_type_get_length (ltype),
+                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
+        }
+      else
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                     "Right operand is incompatible type for multiplication.");
+          return NULL;
+        }
+    }
+  else if (gsk_sl_type_is_vector (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)))
+            {
+              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                         "Vector length for %s doesn't match row count for %s",
+                                         gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
+              return NULL;
+            }
+          return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (rtype));
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (rtype))
+            {
+              gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                         "Vector operands %s and %s to arithmetic multiplication have different length.",
+                                         gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
+              return NULL;
+            }
+          return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (ltype));
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_vector (scalar,
+                                         gsk_sl_type_get_length (ltype));
+        }
+      else
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH,
+                                     "Right operand is incompatible type for multiplication.");
+          return NULL;
+        }
+    }
+  else if (gsk_sl_type_is_scalar (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          return gsk_sl_type_get_matrix (scalar,
+                                         gsk_sl_type_get_length (rtype),
+                                         gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)));
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          return gsk_sl_type_get_vector (scalar,
+                                         gsk_sl_type_get_length (rtype));
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          return gsk_sl_type_get_scalar (scalar);
+        }
+      else
+        {
+          gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH, "Right operand is incompatible type for multiplication.");
+          return NULL;
+        }
+    }
+  else
+    {
+      gsk_sl_preprocessor_error (preproc, TYPE_MISMATCH, "Left operand is incompatible type for multiplication.");
+      return NULL;
+    }
+}
+
+static GskSlType *
+gsk_sl_expression_multiplication_get_return_type (const GskSlExpression *expression)
+{
+  GskSlExpressionMultiplication *multiplication = (GskSlExpressionMultiplication *) expression;
+
+  return multiplication->type;
+}
+
+#define GSK_SL_OPERATION_FUNC_SCALAR(func,type,...) \
+static void \
+func (gpointer value, gpointer scalar) \
+{ \
+  type x = *(type *) value; \
+  type y = *(type *) scalar; \
+  __VA_ARGS__ \
+  *(type *) value = x; \
+}
+GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_multiplication_int, gint32, x *= y;)
+GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_multiplication_uint, guint32, x *= y;)
+GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_multiplication_float, float, x *= y;)
+GSK_SL_OPERATION_FUNC_SCALAR(gsk_sl_expression_multiplication_double, double, x *= y;)
+static void (* mult_funcs[]) (gpointer, gpointer) = {
+  [GSK_SL_INT] = gsk_sl_expression_multiplication_int,
+  [GSK_SL_UINT] = gsk_sl_expression_multiplication_uint,
+  [GSK_SL_FLOAT] = gsk_sl_expression_multiplication_float,
+  [GSK_SL_DOUBLE] = gsk_sl_expression_multiplication_double,
+};
+
+static GskSlValue *
+gsk_sl_expression_multiplication_get_constant (const GskSlExpression *expression)
+{
+  const GskSlExpressionMultiplication *multiplication = (const GskSlExpressionMultiplication *) expression;
+  GskSlValue *result, *lvalue, *rvalue;
+  GskSlType *ltype, *rtype;
+  GskSlScalarType scalar;
+
+  scalar = gsk_sl_type_get_scalar_type (multiplication->type);
+  lvalue = gsk_sl_expression_get_constant (multiplication->left);
+  if (lvalue == NULL)
+    return NULL;
+  rvalue = gsk_sl_expression_get_constant (multiplication->right);
+  if (rvalue == NULL)
+    {
+      gsk_sl_value_free (lvalue);
+      return NULL;
+    }
+  lvalue = gsk_sl_value_convert_components (lvalue, scalar);
+  rvalue = gsk_sl_value_convert_components (rvalue, scalar);
+  ltype = gsk_sl_value_get_type (lvalue);
+  rtype = gsk_sl_value_get_type (rvalue);
+
+  if ((gsk_sl_type_is_matrix (rtype) && gsk_sl_type_is_matrix (ltype)) ||
+      (gsk_sl_type_is_vector (rtype) && gsk_sl_type_is_matrix (ltype)) ||
+      (gsk_sl_type_is_matrix (rtype) && gsk_sl_type_is_vector (ltype)))
+    {
+      gsize c, cols;
+      gsize r, rows;
+      gsize i, n;
+      gpointer data, ldata, rdata;
+
+      result = gsk_sl_value_new (multiplication->type);
+      data = gsk_sl_value_get_data (result);
+      ldata = gsk_sl_value_get_data (lvalue);
+      rdata = gsk_sl_value_get_data (rvalue);
+
+      if (gsk_sl_type_is_vector (rtype))
+        {
+          cols = 1;
+          rows = gsk_sl_type_get_length (gsk_sl_value_get_type (result));
+          n = gsk_sl_type_get_length (rtype);
+        }
+      else if (gsk_sl_type_is_vector (ltype))
+        {
+          cols = gsk_sl_type_get_length (gsk_sl_value_get_type (result));
+          rows = 1;
+          n = gsk_sl_type_get_length (ltype);
+        }
+      else
+        {
+          cols = gsk_sl_type_get_length (gsk_sl_value_get_type (result));
+          rows = gsk_sl_type_get_length (gsk_sl_type_get_index_type (gsk_sl_value_get_type (result)));
+          n = gsk_sl_type_get_length (ltype);
+        }
+#define MATRIXMULT(TYPE) G_STMT_START{\
+        for (c = 0; c < cols; c++) \
+          { \
+            for (r = 0; r < rows; r++) \
+              { \
+                TYPE result = 0; \
+                for (i = 0; i < n; i++) \
+                  { \
+                    result += *((TYPE *) rdata + c * n + i) *  \
+                              *((TYPE *) ldata + i * rows + r); \
+                  } \
+                *((TYPE *) data + c * rows + r) = result; \
+              } \
+          } \
+      }G_STMT_END
+      if (gsk_sl_type_get_scalar_type (multiplication->type) == GSK_SL_DOUBLE)
+        MATRIXMULT(double);
+      else
+        MATRIXMULT(float);
+      gsk_sl_value_free (lvalue);
+      gsk_sl_value_free (rvalue);
+      return result;
+    }
+  else
+    {
+      /* we can multiply componentwise */
+      gsize ln, rn;
+
+      ln = gsk_sl_type_get_n_components (ltype);
+      rn = gsk_sl_type_get_n_components (rtype);
+      if (ln == 1)
+        {
+          gsk_sl_value_componentwise (rvalue, mult_funcs[scalar], gsk_sl_value_get_data (lvalue));
+          gsk_sl_value_free (lvalue);
+          result = rvalue;
+        }
+      else if (rn == 1)
+        {
+          gsk_sl_value_componentwise (lvalue, mult_funcs[scalar], gsk_sl_value_get_data (rvalue));
+          gsk_sl_value_free (rvalue);
+          result = lvalue;
+        }
+      else
+        {
+          guchar *ldata, *rdata;
+          gsize i, stride;
+
+          stride = gsk_sl_scalar_type_get_size (scalar);
+          ldata = gsk_sl_value_get_data (lvalue);
+          rdata = gsk_sl_value_get_data (rvalue);
+          for (i = 0; i < ln; i++)
+            {
+              mult_funcs[scalar] (ldata + i * stride, rdata + i * stride);
+            }
+          gsk_sl_value_free (rvalue);
+          result = lvalue;
+        }
+    }
+
+  return result;
+}
+
+static guint32
+gsk_sl_expression_multiplication_write_spv (const GskSlExpression *expression,
+                                            GskSpvWriter          *writer)
+{
+  const GskSlExpressionMultiplication *multiplication = (const GskSlExpressionMultiplication *) expression;
+  GskSlType *ltype, *rtype;
+  guint32 left_id, right_id, result_id, result_type_id;
+
+  ltype = gsk_sl_expression_get_return_type (multiplication->left);
+  rtype = gsk_sl_expression_get_return_type (multiplication->right);
+
+  left_id = gsk_sl_expression_write_spv (multiplication->left, writer);
+  if (gsk_sl_type_get_scalar_type (ltype) != gsk_sl_type_get_scalar_type (multiplication->type))
+    {
+      GskSlType *new_type = gsk_sl_type_get_matrix (gsk_sl_type_get_scalar_type (multiplication->type),
+                                                    gsk_sl_type_get_length (ltype),
+                                                    gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
+      left_id = gsk_spv_writer_add_conversion (writer, left_id, ltype, new_type);
+    }
+  right_id = gsk_sl_expression_write_spv (multiplication->right, writer);
+  if (gsk_sl_type_get_scalar_type (rtype) != gsk_sl_type_get_scalar_type (multiplication->type))
+    {
+      GskSlType *new_type = gsk_sl_type_get_matrix (gsk_sl_type_get_scalar_type (multiplication->type),
+                                                    gsk_sl_type_get_length (rtype),
+                                                    gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)));
+      right_id = gsk_spv_writer_add_conversion (writer, right_id, rtype, new_type);
+    }
+
+  result_type_id = gsk_spv_writer_get_id_for_type (writer, multiplication->type);
+  result_id = gsk_spv_writer_next_id (writer);
+
+  if (gsk_sl_type_is_matrix (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          gsk_spv_writer_add (writer,
+                              GSK_SPV_WRITER_SECTION_CODE,
+                              5, GSK_SPV_OP_MATRIX_TIMES_MATRIX,
+                              (guint32[4]) { result_type_id,
+                                             result_id,
+                                             left_id,
+                                             right_id });
+
+          return result_id;
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          gsk_spv_writer_add (writer,
+                              GSK_SPV_WRITER_SECTION_CODE,
+                              5, GSK_SPV_OP_VECTOR_TIMES_MATRIX,
+                              (guint32[4]) { result_type_id,
+                                             result_id,
+                                             right_id,
+                                             left_id });
+
+          return result_id;
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          gsk_spv_writer_add (writer,
+                              GSK_SPV_WRITER_SECTION_CODE,
+                              5, GSK_SPV_OP_MATRIX_TIMES_SCALAR,
+                              (guint32[4]) { result_type_id,
+                                             result_id,
+                                             left_id,
+                                             right_id });
+
+          return result_id;
+        }
+    }
+  else if (gsk_sl_type_is_vector (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          gsk_spv_writer_add (writer,
+                              GSK_SPV_WRITER_SECTION_CODE,
+                              5, GSK_SPV_OP_MATRIX_TIMES_VECTOR,
+                              (guint32[4]) { result_type_id,
+                                             result_id,
+                                             right_id,
+                                             left_id });
+
+          return result_id;
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          switch (gsk_sl_type_get_scalar_type (multiplication->type))
+            {
+            case GSK_SL_FLOAT:
+            case GSK_SL_DOUBLE:
+              gsk_spv_writer_add (writer,
+                                  GSK_SPV_WRITER_SECTION_CODE,
+                                  5, GSK_SPV_OP_F_MUL,
+                                  (guint32[4]) { result_type_id,
+                                                 result_id,
+                                                 left_id,
+                                                 right_id });
+              break;
+            case GSK_SL_INT:
+            case GSK_SL_UINT:
+              gsk_spv_writer_add (writer,
+                                  GSK_SPV_WRITER_SECTION_CODE,
+                                  5, GSK_SPV_OP_I_MUL,
+                                  (guint32[4]) { result_type_id,
+                                                 result_id,
+                                                 left_id,
+                                                 right_id });
+              break;
+            case GSK_SL_VOID:
+            case GSK_SL_BOOL:
+            default:
+              g_assert_not_reached ();
+              break;
+            }
+
+          return result_id;
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          gsk_spv_writer_add (writer,
+                              GSK_SPV_WRITER_SECTION_CODE,
+                              5, GSK_SPV_OP_VECTOR_TIMES_SCALAR,
+                              (guint32[4]) { result_type_id,
+                                             result_id,
+                                             left_id,
+                                             right_id });
+
+          return result_id;
+        }
+    }
+  else if (gsk_sl_type_is_scalar (ltype))
+    {
+      if (gsk_sl_type_is_matrix (rtype))
+        {
+          gsk_spv_writer_add (writer,
+                              GSK_SPV_WRITER_SECTION_CODE,
+                              5, GSK_SPV_OP_MATRIX_TIMES_SCALAR,
+                              (guint32[4]) { result_type_id,
+                                             result_id,
+                                             right_id,
+                                             left_id });
+
+          return result_id;
+        }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          gsk_spv_writer_add (writer,
+                              GSK_SPV_WRITER_SECTION_CODE,
+                              5, GSK_SPV_OP_VECTOR_TIMES_SCALAR,
+                              (guint32[4]) { result_type_id,
+                                             result_id,
+                                             right_id,
+                                             left_id });
+
+          return result_id;
+        }
+      else if (gsk_sl_type_is_scalar (rtype))
+        {
+          switch (gsk_sl_type_get_scalar_type (multiplication->type))
+            {
+            case GSK_SL_FLOAT:
+            case GSK_SL_DOUBLE:
+              gsk_spv_writer_add (writer,
+                                  GSK_SPV_WRITER_SECTION_CODE,
+                                  5, GSK_SPV_OP_F_MUL,
+                                  (guint32[4]) { result_type_id,
+                                                 result_id,
+                                                 left_id,
+                                                 right_id });
+              break;
+            case GSK_SL_INT:
+            case GSK_SL_UINT:
+              gsk_spv_writer_add (writer,
+                                  GSK_SPV_WRITER_SECTION_CODE,
+                                  5, GSK_SPV_OP_I_MUL,
+                                  (guint32[4]) { result_type_id,
+                                                 result_id,
+                                                 left_id,
+                                                 right_id });
+              break;
+            case GSK_SL_VOID:
+            case GSK_SL_BOOL:
+            default:
+              g_assert_not_reached ();
+              break;
+            }
+
+          return result_id;
+        }
+    }
+
+  g_assert_not_reached ();
+
+  return 0;
+}
+
+static const GskSlExpressionClass GSK_SL_EXPRESSION_MULTIPLICATION = {
+  gsk_sl_expression_multiplication_free,
+  gsk_sl_expression_multiplication_print,
+  gsk_sl_expression_multiplication_get_return_type,
+  gsk_sl_expression_multiplication_get_constant,
+  gsk_sl_expression_multiplication_write_spv
+};
+
+/* OPERATION */
 
 typedef enum {
-  GSK_SL_OPERATION_MUL,
   GSK_SL_OPERATION_DIV,
   GSK_SL_OPERATION_MOD,
   GSK_SL_OPERATION_ADD,
@@ -222,7 +736,6 @@ gsk_sl_expression_operation_print (const GskSlExpression *expression,
                                    GskSlPrinter          *printer)
 {
   const char *op_str[] = {
-    [GSK_SL_OPERATION_MUL] = " * ",
     [GSK_SL_OPERATION_DIV] = " / ",
     [GSK_SL_OPERATION_MOD] = " % ",
     [GSK_SL_OPERATION_ADD] = " + ",
@@ -253,7 +766,6 @@ gsk_sl_expression_operation_print (const GskSlExpression *expression,
 
 static GskSlType *
 gsk_sl_expression_arithmetic_type_check (GskSlPreprocessor *stream,
-                                         gboolean           multiply,
                                          GskSlType         *ltype,
                                          GskSlType         *rtype)
 {
@@ -280,60 +792,30 @@ gsk_sl_expression_arithmetic_type_check (GskSlPreprocessor *stream,
     {
       if (gsk_sl_type_is_matrix (rtype))
         {
-          if (multiply)
-            {
-              if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)))
-                {
-                  if (stream)
-                    gsk_sl_preprocessor_error (stream, TYPE_MISMATCH,
-                                               "Matrices to multiplication have incompatible dimensions.");
-                  return NULL;
-                }
-              return gsk_sl_type_get_matrix (scalar,
-                                             gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)),
-                                             gsk_sl_type_get_length (rtype));
-            }
-          else
+          if (gsk_sl_type_can_convert (ltype, rtype))
             {
-              if (gsk_sl_type_can_convert (ltype, rtype))
-                {
-                  return ltype;
-                }
-              else if (gsk_sl_type_can_convert (rtype, ltype))
-                {
-                  return rtype;
-                }
-              else
-                {
-                  if (stream)
-                    gsk_sl_preprocessor_error (stream, TYPE_MISMATCH,
-                                               "Matrix types %s and %s have different size.",
-                                               gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
-                  return NULL;
-                }
+              return ltype;
             }
-        }
-      else if (gsk_sl_type_is_vector (rtype))
-        {
-          if (multiply)
+          else if (gsk_sl_type_can_convert (rtype, ltype))
             {
-              if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (rtype))
-                {
-                  if (stream)
-                    gsk_sl_preprocessor_error (stream, TYPE_MISMATCH,
-                                               "Matrix column count doesn't match vector length.");
-                  return NULL;
-                }
-              return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (gsk_sl_type_get_index_type (ltype)));
+              return rtype;
             }
           else
             {
               if (stream)
                 gsk_sl_preprocessor_error (stream, TYPE_MISMATCH,
-                                           "Cannot perform arithmetic operation between matrix and vector.");
+                                           "Matrix types %s and %s have different size.",
+                                           gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
               return NULL;
             }
         }
+      else if (gsk_sl_type_is_vector (rtype))
+        {
+          if (stream)
+            gsk_sl_preprocessor_error (stream, TYPE_MISMATCH,
+                                       "Cannot perform arithmetic operation between matrix and vector.");
+          return NULL;
+        }
       else if (gsk_sl_type_is_scalar (rtype))
         {
           return gsk_sl_type_get_matrix (scalar,
@@ -352,24 +834,9 @@ gsk_sl_expression_arithmetic_type_check (GskSlPreprocessor *stream,
     {
       if (gsk_sl_type_is_matrix (rtype))
         {
-          if (multiply)
-            {
-              if (gsk_sl_type_get_length (ltype) != gsk_sl_type_get_length (gsk_sl_type_get_index_type (rtype)))
-                {
-                  if (stream)
-                    gsk_sl_preprocessor_error (stream, TYPE_MISMATCH,
-                                               "Vector length for %s doesn't match row count for %s",
-                                               gsk_sl_type_get_name (ltype), gsk_sl_type_get_name (rtype));
-                  return NULL;
-                }
-              return gsk_sl_type_get_vector (scalar, gsk_sl_type_get_length (rtype));
-            }
-          else
-            {
-              if (stream)
-                gsk_sl_preprocessor_error (stream, TYPE_MISMATCH, "Cannot perform arithmetic operation between vector and matrix.");
-              return NULL;
-            }
+          if (stream)
+            gsk_sl_preprocessor_error (stream, TYPE_MISMATCH, "Cannot perform arithmetic operation between vector and matrix.");
+          return NULL;
         }
       else if (gsk_sl_type_is_vector (rtype))
         {
@@ -568,16 +1035,10 @@ gsk_sl_expression_operation_get_return_type (const GskSlExpression *expression)
 
   switch (operation->op)
   {
-    case GSK_SL_OPERATION_MUL:
-      return gsk_sl_expression_arithmetic_type_check (NULL,
-                                                      TRUE,
-                                                      gsk_sl_expression_get_return_type (operation->left),
-                                                      gsk_sl_expression_get_return_type (operation->right));
     case GSK_SL_OPERATION_DIV:
     case GSK_SL_OPERATION_ADD:
     case GSK_SL_OPERATION_SUB:
       return gsk_sl_expression_arithmetic_type_check (NULL,
-                                                      FALSE,
                                                       gsk_sl_expression_get_return_type (operation->left),
                                                       gsk_sl_expression_get_return_type (operation->right));
     case GSK_SL_OPERATION_LSHIFT:
@@ -2321,9 +2782,8 @@ gsk_sl_expression_parse_multiplicative (GskSlScope        *scope,
                                         GskSlPreprocessor *stream)
 {
   const GskSlToken *token;
-  GskSlExpression *expression;
-  GskSlExpressionOperation *operation;
-  GskSlOperation op;
+  GskSlExpression *expression, *right;
+  enum { MUL, DIV, MOD } op;
 
   expression = gsk_sl_expression_parse_unary (scope, stream);
 
@@ -2331,35 +2791,72 @@ gsk_sl_expression_parse_multiplicative (GskSlScope        *scope,
     {
       token = gsk_sl_preprocessor_get (stream);
       if (gsk_sl_token_is (token, GSK_SL_TOKEN_STAR))
-        op = GSK_SL_OPERATION_MUL;
+        op = MUL;
       else if (gsk_sl_token_is (token, GSK_SL_TOKEN_SLASH))
-        op = GSK_SL_OPERATION_DIV;
+        op = DIV;
       else if (gsk_sl_token_is (token, GSK_SL_TOKEN_PERCENT))
-        op = GSK_SL_OPERATION_MOD;
+        op = MOD;
       else
         return expression;
 
-      operation = gsk_sl_expression_new (GskSlExpressionOperation, &GSK_SL_EXPRESSION_OPERATION);
-      operation->left = expression;
-      operation->op = op;
-      gsk_sl_preprocessor_consume (stream, (GskSlExpression *) operation);
-      operation->right = gsk_sl_expression_parse_unary (scope, stream);
-      if ((op == GSK_SL_OPERATION_MOD &&
-           !gsk_sl_expression_bitwise_type_check (stream,
-                                            gsk_sl_expression_get_return_type (operation->left),
-                                            gsk_sl_expression_get_return_type (operation->right))) ||
-          (op != GSK_SL_OPERATION_MOD &&
-           !gsk_sl_expression_arithmetic_type_check (stream,
-                                               FALSE,
-                                               gsk_sl_expression_get_return_type (operation->left),
-                                               gsk_sl_expression_get_return_type (operation->right))))
+      gsk_sl_preprocessor_consume (stream, NULL);
+      right = gsk_sl_expression_parse_unary (scope, stream);
+      if (op == MUL)
         {
-          gsk_sl_expression_ref (expression);
-          gsk_sl_expression_unref ((GskSlExpression *) operation);
+          GskSlType *result_type;
+
+          result_type = gsk_sl_expression_multiplication_get_result_type (stream,
+                                                                          gsk_sl_expression_get_return_type (expression),
+                                                                          gsk_sl_expression_get_return_type (right));
+          if (result_type)
+            {
+              GskSlExpressionMultiplication *multiplication;
+              multiplication = gsk_sl_expression_new (GskSlExpressionMultiplication, &GSK_SL_EXPRESSION_MULTIPLICATION);
+              multiplication->type = gsk_sl_type_ref (result_type);
+              multiplication->left = expression;
+              multiplication->right = right;
+              expression = (GskSlExpression *) multiplication;
+            }
+          else
+            {
+              gsk_sl_expression_unref ((GskSlExpression *) right);
+            }
+        }
+      else if (op == DIV)
+        {
+          if (gsk_sl_expression_arithmetic_type_check (stream,
+                                                       gsk_sl_expression_get_return_type (expression),
+                                                       gsk_sl_expression_get_return_type (right)))
+            {
+              GskSlExpressionOperation *operation;
+              operation = gsk_sl_expression_new (GskSlExpressionOperation, &GSK_SL_EXPRESSION_OPERATION);
+              operation->op = GSK_SL_OPERATION_DIV;
+              operation->left = expression;
+              operation->right = right;
+              expression = (GskSlExpression *) operation;
+            }
+          else
+            {
+              gsk_sl_expression_unref ((GskSlExpression *) right);
+            }
         }
       else
         {
-          expression = (GskSlExpression *) operation;
+          if (gsk_sl_expression_bitwise_type_check (stream,
+                                                    gsk_sl_expression_get_return_type (expression),
+                                                    gsk_sl_expression_get_return_type (right)))
+            {
+              GskSlExpressionOperation *operation;
+              operation = gsk_sl_expression_new (GskSlExpressionOperation, &GSK_SL_EXPRESSION_OPERATION);
+              operation->op = GSK_SL_OPERATION_MOD;
+              operation->left = expression;
+              operation->right = right;
+              expression = (GskSlExpression *) operation;
+            }
+          else
+            {
+              gsk_sl_expression_unref ((GskSlExpression *) right);
+            }
         }
     }
 
@@ -2395,7 +2892,6 @@ gsk_sl_expression_parse_additive (GskSlScope        *scope,
       gsk_sl_preprocessor_consume (stream, (GskSlExpression *) operation);
       operation->right = gsk_sl_expression_parse_additive (scope, stream);
       if (!gsk_sl_expression_arithmetic_type_check (stream,
-                                                    FALSE,
                                                     gsk_sl_expression_get_return_type (operation->left),
                                                     gsk_sl_expression_get_return_type (operation->right)))
         {


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]