[gnumeric] Solver: add analytic gradient computation.



commit c83b1d792ec6b0c9ea2c65420c36f214ef9d72f4
Author: Morten Welinder <terra gnome org>
Date:   Sun Sep 11 18:42:26 2016 -0400

    Solver: add analytic gradient computation.

 src/tools/gnm-solver.c |   68 ++++++++++++++++++++++++++++++++++++++++++++++-
 src/tools/gnm-solver.h |    5 +++
 2 files changed, 71 insertions(+), 2 deletions(-)
---
diff --git a/src/tools/gnm-solver.c b/src/tools/gnm-solver.c
index 7b2ba0b..c34b8a0 100644
--- a/src/tools/gnm-solver.c
+++ b/src/tools/gnm-solver.c
@@ -3,6 +3,7 @@
 #include "value.h"
 #include "cell.h"
 #include "expr.h"
+#include "expr-deriv.h"
 #include "sheet.h"
 #include "workbook.h"
 #include "rangefunc.h"
@@ -850,6 +851,12 @@ gnm_solver_dispose (GObject *obj)
                gnm_solver_update_derived (sol);
        }
 
+       if (sol->gradient) {
+               sol->gradient_status = 0;
+               g_ptr_array_unref (sol->gradient);
+               sol->gradient = NULL;
+       }
+
        gnm_solver_parent_class->dispose (obj);
 }
 
@@ -1943,6 +1950,57 @@ gnm_solver_restore_vars (GnmSolver *sol, GPtrArray *vals)
        g_ptr_array_free (vals, TRUE);
 }
 
+gboolean /* TRUE when gradient_status is 1, i.e. every input cell yielded a derivative expression */
+gnm_solver_has_analytic_gradient (GnmSolver *sol)
+{ /* Builds sol->gradient on the first call; later calls only report the cached status. */
+       const int n = sol->input_cells->len;
+       /* gradient_status: 0 = not tried, 1 = ok, 2 = fail. */
+       if (sol->gradient_status == 0) {
+               int i;
+
+               sol->gradient_status++; /* 0 -> 1: treated as ok unless a derivative is missing below */
+               /* Elements are released with gnm_expr_top_unref when the array is freed. */
+               sol->gradient = g_ptr_array_new_with_free_func ((GDestroyNotify)gnm_expr_top_unref);
+               for (i = 0; i < n; i++) {
+                       GnmCell *cell = g_ptr_array_index (sol->input_cells, i);
+                       GnmExprTop const *te = /* presumably d(target)/d(cell); NULL is treated as failure */
+                               gnm_expr_cell_deriv (sol->target, cell);
+                       if (te)
+                               g_ptr_array_add (sol->gradient, (gpointer)te);
+                       else {
+                               if (gnm_solver_debug ())
+                                       g_printerr ("Unable to compute analytic gradient\n");
+                               sol->gradient_status++; /* 1 -> 2: fail; entries added so far stay in sol->gradient */
+                               break;
+                       }
+               }
+       }
+
+       return sol->gradient_status == 1;
+}
+
+static gnm_float * /* returns a g_new'd array of n values, one per input cell; not freed here */
+gnm_solver_compute_gradient_analytically (GnmSolver *sol, gnm_float const *xs)
+{ /* xs is not read here: the visible caller applies it with gnm_solver_set_vars first. */
+       const int n = sol->input_cells->len;
+       int i;
+       gnm_float *g = g_new (gnm_float, n);
+       GnmEvalPos ep;
+       /* ep is initialised from sol->target and shared by all evaluations below. */
+       eval_pos_init_cell (&ep, sol->target);
+       for (i = 0; i < n; i++) {
+               GnmExprTop const *te = g_ptr_array_index (sol->gradient, i); /* assumes sol->gradient holds n entries */
+               GnmValue *v = gnm_expr_top_eval
+                       (te, &ep, GNM_EXPR_EVAL_SCALAR_NON_EMPTY);
+               g[i] = VALUE_IS_NUMBER (v) ? value_get_as_float (v) : gnm_nan; /* non-number result becomes gnm_nan */
+               value_release (v);
+       }
+       /* NOTE(review): no sign flip is applied here, yet the gnm_solver_compute_gradient doc mentions flip-sign -- confirm. */
+       if (gnm_solver_debug ())
+               print_vector ("Analytic gradient", g, n);
+
+       return g;
+}
 
 /**
  * gnm_solver_compute_gradient:
@@ -1950,8 +2008,8 @@ gnm_solver_restore_vars (GnmSolver *sol, GPtrArray *vals)
  * @xs: Point to compute gradient at
  *
  * Returns: (transfer full): A vector containing the gradient.  This
- * function takes the flip-sign property into account.  Note, that this
- * is a numerical approximation.
+ * function takes the flip-sign property into account.  This uses the
+ * analytic gradient if possible, and a numerical approximation otherwise.
  */
 gnm_float *
 gnm_solver_compute_gradient (GnmSolver *sol, gnm_float const *xs)
@@ -1965,6 +2023,9 @@ gnm_solver_compute_gradient (GnmSolver *sol, gnm_float const *xs)
        gnm_solver_set_vars (sol, xs);
        y0 = gnm_solver_get_target_value (sol);
 
+       if (gnm_solver_has_analytic_gradient (sol))
+               return gnm_solver_compute_gradient_analytically (sol, xs);
+
        g = g_new (gnm_float, n);
        for (i = 0; i < n; i++) {
                gnm_float x0 = xs[i];
@@ -1999,6 +2060,9 @@ gnm_solver_compute_gradient (GnmSolver *sol, gnm_float const *xs)
                gnm_solver_set_var (sol, i, x0);
        }
 
+       if (gnm_solver_debug ())
+               print_vector ("Numerical gradient", g, n);
+
        return g;
 }
 
diff --git a/src/tools/gnm-solver.h b/src/tools/gnm-solver.h
index 8697575..db7fb22 100644
--- a/src/tools/gnm-solver.h
+++ b/src/tools/gnm-solver.h
@@ -243,6 +243,10 @@ struct GnmSolver_ {
        gnm_float *min;
        gnm_float *max;
        guint8 *discrete;
+
+       // Analytic gradient
+       int gradient_status; // 0: not tried; 1: ok; 2: fail
+       GPtrArray *gradient;
 };
 
 typedef struct {
@@ -297,6 +301,7 @@ void gnm_solver_set_vars (GnmSolver *sol, gnm_float const *xs);
 GPtrArray *gnm_solver_save_vars (GnmSolver *sol);
 void gnm_solver_restore_vars (GnmSolver *sol, GPtrArray *vals);
 
+gboolean gnm_solver_has_analytic_gradient (GnmSolver *sol);
 gnm_float *gnm_solver_compute_gradient (GnmSolver *sol, gnm_float const *xs);
 
 gnm_float gnm_solver_line_search (GnmSolver *sol,


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]