[shotwell/wip/phako/enhanced-faces: 39/136] Basic face recognition is working!



commit f7b2c2203d2d7b5a722f5133f79ea8baf14d7fcd
Author: NarendraMA <narendra_m_a yahoo com>
Date:   Mon Jul 23 13:35:31 2018 +0530

    Basic face recognition is working!

 facedetect/facedetect-opencv.cpp |   7 ++-
 src/AppDirs.vala                 |   6 --
 src/db/FaceLocationTable.vala    |  44 +++++++++++++-
 src/db/FaceTable.vala            |  33 ++++++++++-
 src/faces/FaceDetect.vala        |  11 ++++
 src/faces/FaceLocation.vala      |   8 +--
 src/faces/FaceShape.vala         |  59 ++++++++++--------
 src/faces/FacesTool.vala         | 125 +++++++++++++++++++--------------------
 8 files changed, 192 insertions(+), 101 deletions(-)
---
diff --git a/facedetect/facedetect-opencv.cpp b/facedetect/facedetect-opencv.cpp
index 203c487a..d3b8ee62 100644
--- a/facedetect/facedetect-opencv.cpp
+++ b/facedetect/facedetect-opencv.cpp
@@ -80,9 +80,11 @@ std::vector<double> faceToVecMat(cv::Mat img) {
                                           cv::Scalar(), true, false);
     faceRecogNet.setInput(blob);
     cv::Mat vec = faceRecogNet.forward();
     // Return vector
-    ret.assign((double*)vec.datastart, (double*)vec.dataend);
-    std::cout << "Recognition done! " << ret.back() << std::endl;
+    // The output is read as float (CV_32F assumed, matching ptr<float>):
+    // widen each value to double instead of reinterpreting the raw buffer.
+    for (int i = 0; i < vec.rows; ++i)
+        ret.insert(ret.end(), vec.ptr<float>(i), vec.ptr<float>(i) + vec.cols);
     return ret;
 }
 
diff --git a/src/AppDirs.vala b/src/AppDirs.vala
index e2a5065a..40b28814 100644
--- a/src/AppDirs.vala
+++ b/src/AppDirs.vala
@@ -211,12 +211,6 @@ class AppDirs {
         return tmp_dir;
     }
 
-    public static string get_temp_filename() {
-        string tmp_dir = get_temp_dir().get_path();
-        assert(tmp_dir != null);
-        return Path.build_filename(tmp_dir, "face_XXXXXX.png");
-    }
-    
     public static File get_data_subdir(string name, string? subname = null) {
         File subdir = get_data_dir().get_child(name);
         if (subname != null)
diff --git a/src/db/FaceLocationTable.vala b/src/db/FaceLocationTable.vala
index 219928d9..2d84424c 100644
--- a/src/db/FaceLocationTable.vala
+++ b/src/db/FaceLocationTable.vala
@@ -46,7 +46,8 @@ public class FaceLocationTable : DatabaseTable {
             + "face_id INTEGER NOT NULL, "
             + "photo_id INTEGER NOT NULL, "
             + "geometry TEXT, "
-            + "vec TEXT"
+            + "vec TEXT, "
+            + "guess INTEGER DEFAULT 0" // NOTE(review): not read or written elsewhere in this patch
             + ")", -1, out stmt);
         assert(res == Sqlite.OK);
         
@@ -222,6 +223,57 @@ public class FaceLocationTable : DatabaseTable {
         if (res != Sqlite.DONE)
             throw_error("FaceLocationTable.update_face_location_serialized_geometry", res);
     }
+
+    // Returns the face locations (with their stored vectors) found in the
+    // reference photos (FaceRow.ref) of face_rows. null entries in face_rows
+    // are ignored and locations without a stored vector are skipped.
+    public Gee.List<FaceLocationRow?> get_face_ref_vecs(Gee.List<FaceRow?> face_rows)
+        throws DatabaseError {
+        Gee.List<FaceLocationRow?> rows = new Gee.ArrayList<FaceLocationRow?>();
+
+        string[] where_in = {};
+        foreach (var r in face_rows) {
+            if (r != null)
+                where_in += "?";
+        }
+        // No reference photos: nothing to look up.
+        if (where_in.length == 0)
+            return rows;
+
+        Sqlite.Statement stmt;
+        int res = db.prepare_v2(
+            "SELECT id, face_id, photo_id, geometry, vec FROM FaceLocationTable "
+                + "WHERE vec IS NOT NULL AND photo_id IN (%s)".printf(string.joinv(",", where_in)),
+            -1, out stmt);
+        assert(res == Sqlite.OK);
+
+        // A placeholder exists only for each non-null row, so the bind index
+        // must advance only for those rows.
+        int c = 1;
+        foreach (var r in face_rows) {
+            if (r == null)
+                continue;
+            res = stmt.bind_int64(c++, r.ref.id);
+            assert(res == Sqlite.OK);
+        }
+
+        for (;;) {
+            res = stmt.step();
+            if (res == Sqlite.DONE)
+                break;
+            else if (res != Sqlite.ROW)
+                throw_error("FaceLocationTable.get_face_ref_vecs", res);
+
+            FaceLocationRow row = new FaceLocationRow();
+            row.face_location_id = FaceLocationID(stmt.column_int64(0));
+            row.face_id = FaceID(stmt.column_int64(1));
+            row.photo_id = PhotoID(stmt.column_int64(2));
+            row.geometry = stmt.column_text(3);
+            row.vec = stmt.column_text(4);
+            rows.add(row);
+        }
+        return rows;
+    }
 }
 
 #endif
diff --git a/src/db/FaceTable.vala b/src/db/FaceTable.vala
index c8e934cd..be53515f 100644
--- a/src/db/FaceTable.vala
+++ b/src/db/FaceTable.vala
@@ -27,6 +27,8 @@ public class FaceRow {
     public FaceID face_id;
     public string name;
     public time_t time_created;
+    public PhotoID ref; // reference photo; the table column defaults to -1 (none)
+    public string vec; // NOTE(review): not filled in by get_ref_rows()
 }
 
 public class FaceTable : DatabaseTable {
@@ -41,7 +43,8 @@ public class FaceTable : DatabaseTable {
             + "("
             + "id INTEGER NOT NULL PRIMARY KEY, "
             + "name TEXT NOT NULL, "
-            + "time_created TIMESTAMP"
+            + "time_created TIMESTAMP, "
+            + "ref INTEGER DEFAULT -1"
             + ")", -1, out stmt);
         assert(res == Sqlite.OK);
         
@@ -180,5 +183,35 @@ public class FaceTable : DatabaseTable {
         if (res != Sqlite.DONE)
             throw_error("FaceTable.set_reference", res);
     }
+
+    // Returns every face that has a reference photo (ref != -1). The vec
+    // field of the returned rows is left unset.
+    public Gee.List<FaceRow?> get_ref_rows() throws DatabaseError {
+        Sqlite.Statement stmt;
+        int res = db.prepare_v2("SELECT id, name, time_created, ref FROM FaceTable WHERE ref != -1", -1,
+            out stmt);
+        assert(res == Sqlite.OK);
+        
+        Gee.List<FaceRow?> rows = new Gee.ArrayList<FaceRow?>();
+        
+        for (;;) {
+            res = stmt.step();
+            if (res == Sqlite.DONE)
+                break;
+            else if (res != Sqlite.ROW)
+                throw_error("FaceTable.get_ref_rows", res);
+            
+            // res == Sqlite.ROW
+            FaceRow row = new FaceRow();
+            row.face_id = FaceID(stmt.column_int64(0));
+            row.name = stmt.column_text(1);
+            row.time_created = (time_t) stmt.column_int64(2);
+            row.ref = PhotoID(stmt.column_int64(3));
+            
+            rows.add(row);
+        }
+        
+        return rows;
+    }
 }
 #endif
diff --git a/src/faces/FaceDetect.vala b/src/faces/FaceDetect.vala
index 7a027b48..3b269803 100644
--- a/src/faces/FaceDetect.vala
+++ b/src/faces/FaceDetect.vala
@@ -79,4 +79,16 @@ public class FaceDetect {
         }
         connected = true;
     }
+
+    // Dot product of two face vectors. Vectors of different lengths are not
+    // comparable and yield 0.
+    public static double dot_product(double[] vec1, double[] vec2) {
+        double sum = 0;
+        if (vec1.length == vec2.length) {
+            int i = 0;
+            foreach (double a in vec1)
+                sum += a * vec2[i++];
+        }
+        return sum;
+    }
 }
diff --git a/src/faces/FaceLocation.vala b/src/faces/FaceLocation.vala
index c7e2c6c2..48abd2d4 100644
--- a/src/faces/FaceLocation.vala
+++ b/src/faces/FaceLocation.vala
@@ -206,11 +206,11 @@ public class FaceLocation : Object {
     public string get_serialized_geometry() {
         return face_data.geometry;
     }
-/*    
-    private void set_serialized_geometry(string geometry) {
-        this.face_data.geometry = geometry;
+
+    public string get_serialized_vec() { // serialized face vector (comma-separated numbers) stored for this location
+        return face_data.vec;
     }
-*/
+
     public FaceLocationData get_face_data() {
         return face_data;
     }
diff --git a/src/faces/FaceShape.vala b/src/faces/FaceShape.vala
index 0afc5a72..349627b9 100644
--- a/src/faces/FaceShape.vala
+++ b/src/faces/FaceShape.vala
@@ -20,6 +20,7 @@ public abstract class FaceShape : Object {
     protected Gdk.CursorType current_cursor_type = Gdk.CursorType.BOTTOM_RIGHT_CORNER;
     protected EditingTools.PhotoCanvas canvas;
     protected string serialized = null;
+    protected double[] face_vec; // face vector; may be empty (FaceRectangle's constructor defaults it to {})
     
     private bool editable = true;
     private bool visible = true;
@@ -27,7 +28,7 @@ public abstract class FaceShape : Object {
     
     private weak FacesTool.FaceWidget face_widget = null;
     
-    public FaceShape(EditingTools.PhotoCanvas canvas) {
+    public FaceShape(EditingTools.PhotoCanvas canvas, double[] vec) {
         this.canvas = canvas;
         this.canvas.new_surface.connect(prepare_ctx);
         
@@ -40,6 +41,7 @@ public abstract class FaceShape : Object {
         face_window.hide();
         
         this.canvas.get_drawing_window().set_cursor(new Gdk.Cursor(current_cursor_type));
+        this.face_vec = vec;
     }
     
     ~FaceShape() {
@@ -162,7 +164,7 @@ public abstract class FaceShape : Object {
         return true;
     }
     
-    public abstract string serialize();
+    public abstract string serialize(bool geometry_only = false); // geometry_only: leave out the face vector
     public abstract void update_face_window_position();
     public abstract void prepare_ctx(Cairo.Context ctx, Dimensions dim);
     public abstract void on_resized_pixbuf(Dimensions old_dim, Gdk.Pixbuf scaled);
@@ -172,7 +174,7 @@ public abstract class FaceShape : Object {
     public abstract bool cursor_is_over(int x, int y);
     public abstract bool equals(FaceShape face_shape);
     public abstract double get_distance(int x, int y);
-    public abstract Gdk.Pixbuf? get_pixbuf();
+    public abstract double[] get_face_vec();
     
     protected abstract void paint();
     protected abstract void erase();
@@ -192,11 +194,10 @@ public class FaceRectangle : FaceShape {
     private Cairo.Context thin_white_ctx = null;
     private int last_grab_x = -1;
     private int last_grab_y = -1;
-    private Gdk.Pixbuf? face_pix;
     
     public FaceRectangle(EditingTools.PhotoCanvas canvas, int x, int y,
-        int half_width = NULL_SIZE, int half_height = NULL_SIZE) {
-        base(canvas);
+        int half_width = NULL_SIZE, int half_height = NULL_SIZE, double[] vec = {}) {
+        base(canvas, vec);
         
         Gdk.Rectangle scaled_pixbuf_pos = canvas.get_scaled_pixbuf_position();
         x -= scaled_pixbuf_pos.x;
@@ -217,12 +218,6 @@ public class FaceRectangle : FaceShape {
         
             box = Box(x - half_width, y - half_height, right, bottom);
         }
-        
-        Gdk.Pixbuf original = canvas.get_scaled_pixbuf();
-        message("pixbuf get %d, %d, %d, %d of %d, %d", box.left, box.top,
-                                box.get_width(), box.get_height(), original.width, original.height);
-        face_pix = new Gdk.Pixbuf.subpixbuf(original, box.left, box.top,
-                                box.get_width(), box.get_height());
     }
     
     ~FaceRectangle() {
@@ -236,7 +231,6 @@ public class FaceRectangle : FaceShape {
         
         Photo photo = canvas.get_photo();
         Dimensions raw_dim = photo.get_raw_dimensions();
-        
         int x = (int) (raw_dim.width * double.parse(args[1]));
         int y = (int) (raw_dim.height * double.parse(args[2]));
         int half_width = (int) (raw_dim.width * double.parse(args[3]));
@@ -275,9 +269,18 @@ public class FaceRectangle : FaceShape {
         
         if (half_width < FACE_MIN_SIZE || half_height < FACE_MIN_SIZE)
             throw new FaceShapeError.CANT_CREATE("FaceShape is out of cropped photo area");
-        
+
+        // Optional sixth field: the comma-separated face vector. It is read
+        // into 128 entries; missing entries (no vector, or a short one) become
+        // 0 instead of being read past the end of vec_str.
+        string[] vec_str = {};
+        if (args.length == 6)
+            vec_str = args[5].split(",");
+        double[] vec = new double[128];
+        for (int i = 0; i < vec.length; i++)
+            vec[i] = (i < vec_str.length) ? double.parse(vec_str[i]) : 0.0;
         return new FaceRectangle(canvas, box.left + half_width, box.top + half_height,
-            half_width, half_height);
+            half_width, half_height, vec);
     }
     
     public override void update_face_window_position() {
@@ -378,7 +384,10 @@ public class FaceRectangle : FaceShape {
         ctx.restore();
     }
     
-    public override string serialize() {
+    public override string serialize(bool geometry_only = false) {
+        // The cache only ever holds the geometry part ("type;x;y;hw;hh;").
+        // The face vector is appended on demand, so geometry-only and full
+        // requests can never be answered with each other's cached string.
         if (serialized != null)
-            return serialized;
+            return geometry_only ? serialized : serialized + serialize_face_vec();
         
@@ -388,10 +397,17 @@ public class FaceRectangle : FaceShape {
         double half_height;
         
         get_geometry(out x, out y, out half_width, out half_height);
-        
-        serialized = "%s;%s;%s;%s;%s".printf(SHAPE_TYPE, x.to_string(),
+        serialized = "%s;%s;%s;%s;%s;".printf(SHAPE_TYPE, x.to_string(),
             y.to_string(), half_width.to_string(), half_height.to_string());
-        
-        return serialized;
+        return geometry_only ? serialized : serialized + serialize_face_vec();
     }
+
+    // Joins face_vec with commas (no trailing comma). Yields "" for an empty
+    // vector, e.g. a rectangle created without one.
+    private string serialize_face_vec() {
+        string[] parts = new string[face_vec.length];
+        for (int i = 0; i < face_vec.length; i++)
+            parts[i] = face_vec[i].to_string();
+        return string.joinv(",", parts);
+    }
     
@@ -435,9 +446,13 @@ public class FaceRectangle : FaceShape {
         half_width = (width_right_end - width_left_end) / 2;
         half_height = (height_bottom_end - height_top_end) / 2;
     }
+
+    public override double[] get_face_vec() { // may be empty when no vector was given at construction
+        return face_vec;
+    }
     
     public override bool equals(FaceShape face_shape) {
-        return serialize() == face_shape.serialize();
+        return serialize(true) == face_shape.serialize(true); // compare geometry only, ignoring face vectors
     }
     
     public override void prepare_ctx(Cairo.Context ctx, Dimensions dim) {
@@ -786,10 +801,6 @@ public class FaceRectangle : FaceShape {
         
         return Math.sqrt((center_x - x) * (center_x - x) + (center_y - y) * (center_y - y));
     }
-
-    public override Gdk.Pixbuf? get_pixbuf() {
-        return face_pix;
-    }
 }
 
 #endif
diff --git a/src/faces/FacesTool.vala b/src/faces/FacesTool.vala
index 65c532ab..85a96919 100644
--- a/src/faces/FacesTool.vala
+++ b/src/faces/FacesTool.vala
@@ -341,64 +341,32 @@ public class FacesTool : EditingTools.EditingTool {
             }
             faces = new Gee.PriorityQueue<string>();
             for (int i = 0; i < rects.length; i++) {
-                string serialized = "%s;%s".printf(
-                       FaceRectangle.SHAPE_TYPE,
-                       parse_serialized_geometry("x=%s&y=%s&width=%s&height=%s".printf(
-                            rects[i].x.to_string(), rects[i].y.to_string(), rects[i].width.to_string(), rects[i].height.to_string())));
+                // Convert the detector's corner/size rectangle into the
+                // centre/half-size geometry that FaceRectangle uses.
+                double rect_w = rects[i].width / 2.0;
+                double rect_h = rects[i].height / 2.0;
+                double rect_x = rects[i].x + rect_w;
+                double rect_y = rects[i].y + rect_h;
+                // Comma-separated face vector without a trailing comma,
+                // the same shape FaceRectangle.serialize() aims to write.
+                StringBuilder face_vec_str = new StringBuilder();
+                if (rects[i].vec != null) {
+                    foreach (var d in rects[i].vec) {
+                        if (face_vec_str.len > 0)
+                            face_vec_str.append_c(',');
+                        face_vec_str.append(d.to_string());
+                    }
+                }
+                // to_string() is locale independent and keeps full precision,
+                // unlike "%f", which is locale dependent and rounds to 6 places.
+                string serialized = "%s;%s;%s;%s;%s;%s".printf(FaceRectangle.SHAPE_TYPE,
+                    rect_x.to_string(), rect_y.to_string(), rect_w.to_string(),
+                    rect_h.to_string(), face_vec_str.str);
                 debug("saw face %s", serialized);
                 faces.add(serialized);
             }
         }
 
-        private string parse_serialized_geometry(string serialized_geometry) {
-            string[] serialized_geometry_pieces = serialized_geometry.split("&");
-            if (serialized_geometry_pieces.length != 4) {
-                critical("Wrong serialized line in face detection program output.");
-                assert_not_reached();
-            }
-
-            double x = 0;
-            double y = 0;
-            double width = 0;
-            double height = 0;
-            foreach (string piece in serialized_geometry_pieces) {
-
-                string[] name_and_value = piece.split("=");
-                if (name_and_value.length != 2) {
-                    critical("Wrong serialized line in face detection program output.");
-                    assert_not_reached();
-                }
-
-                switch (name_and_value[0]) {
-                    case "x":
-                        x = name_and_value[1].to_double();
-                        break;
-
-                    case "y":
-                        y = name_and_value[1].to_double();
-                        break;
-
-                    case "width":
-                        width = name_and_value[1].to_double();
-                        break;
-
-                    case "height":
-                        height = name_and_value[1].to_double();
-                        break;
-
-                    default:
-                        critical("Wrong serialized line in face detection program output.");
-                        assert_not_reached();
-                }
-            }
-
-            double half_width = width / 2;
-            double half_height = height / 2;
-
-            return "%s;%s;%s;%s".printf((x + half_width).to_string(), (y + half_height).to_string(),
-                half_width.to_string(), half_height.to_string());
-        }
-
         public string? get_next() {
             if (faces == null)
                 return null;
@@ -451,8 +410,10 @@ public class FacesTool : EditingTools.EditingTool {
             foreach (Gee.Map.Entry<FaceID?, FaceLocation> entry in face_locations.entries) {
                 FaceShape new_face_shape;
                 string serialized_geometry = entry.value.get_serialized_geometry();
+                string serialized_vec = entry.value.get_serialized_vec();
+                string face_shape_str = serialized_geometry + ";" + serialized_vec; // appends the vector as one more ';'-separated field
                 try {
-                    new_face_shape = FaceShape.from_serialized(canvas, serialized_geometry);
+                    new_face_shape = FaceShape.from_serialized(canvas, face_shape_str);
                 } catch (FaceShapeError e) {
                     if (e is FaceShapeError.CANT_CREATE)
                         continue;
@@ -763,13 +724,14 @@ public class FacesTool : EditingTools.EditingTool {
                 continue;
 
             Face new_face = Face.for_name(face_shape.get_name());
-            string face_vec_str = "";
-            if (face_vec != null) {
-                foreach (var d in face_vec) { face_vec_str += d.to_string() + ","; }
-            }
-            FaceLocationData face_data = {
-                face_shape.serialize(), face_vec_str
-            };
+            string[] face_string = face_shape.serialize().split(";");
+            string face_vec_str, face_geometry;
+            face_geometry = string.joinv(";", face_string[0:5]); // first five fields: type;x;y;half_width;half_height
+            face_vec_str = face_string[5]; // sixth field: the face vector (assumes serialize() always emits it)
+            FaceLocationData face_data =
+                {
+                 face_geometry, face_vec_str
+                };
             new_faces.set(new_face, face_data);
         }
 
@@ -923,13 +885,61 @@ public class FacesTool : EditingTools.EditingTool {
                 continue;
 
             c++;
+            // Reference faces to match with
+            Face? guess = get_face_match(face_shape, 0.7);
 
-            face_shape.set_name("Unknown face #%d".printf(c));
-            face_shape.set_known(false);
+            if (guess == null) {
+                face_shape.set_name("Unknown face #%d".printf(c));
+                face_shape.set_known(false);
+            } else {
+                face_shape.set_name(guess.get_name());
+                face_shape.set_known(true);
+            }
             add_face(face_shape);
         }
     }
 
+    // Returns the known face whose stored reference vector has the largest
+    // dot product with face_shape's vector, provided that product exceeds
+    // threshold; null when none does or reference faces cannot be loaded.
+    private Face? get_face_match(FaceShape face_shape, double threshold) {
+        Gee.List<FaceLocationRow?> face_vecs;
+        try {
+            Gee.List<FaceRow?> face_rows = FaceTable.get_instance().get_ref_rows();
+            face_vecs = FaceLocationTable.get_instance().get_face_ref_vecs(face_rows);
+        } catch(DatabaseError err) {
+            warning("Cannot get reference faces from DB: %s", err.message);
+            return null;
+        }
+        double[] face_vec = face_shape.get_face_vec();
+        FaceID? guess_id = null;
+        double max_product = threshold;
+        foreach (var row in face_vecs) {
+            // Locations without a stored vector cannot be compared.
+            if (row == null || row.vec == null)
+                continue;
+            // Parse only non-empty tokens so a trailing comma adds no spurious 0;
+            // a vector of the wrong length scores 0 in dot_product().
+            double[] vec = {};
+            foreach (var d in row.vec.split(",")) {
+                if (d != "")
+                    vec += double.parse(d);
+            }
+            double product = FaceDetect.dot_product(face_vec, vec);
+            if (product > max_product) {
+                max_product = product;
+                guess_id = row.face_id;
+            }
+        }
+
+        Face? face = null;
+        if (guess_id != null) {
+            face = Face.global.fetch(guess_id);
+            assert(face != null);
+        }
+        return face;
+    }
+    
     private void on_faces_detected() {
         face_detection_cancellable.reset();
         


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]