librelist archives

« back to archive

[(broken) PATCH] switch MT to use a proper recursive mutex

[(broken) PATCH] switch MT to use a proper recursive mutex

From:
Eric Wong
Date:
2010-12-11 @ 01:54
This seems to interact badly with the GVL; proper thread-safety is easier
to do purely in C.  And recursive mutexes suck anyway...

Wait, threads suck, too :P

From c8357e676bb7bb1339bfa8e5965dd51714251355 Mon Sep 17 00:00:00 2001
From: Eric Wong <normalperson@yhbt.net>
Date: Fri, 10 Dec 2010 16:59:36 -0800
Subject: [PATCH] switch MT to use a proper recursive mutex

And also beef up the thread torture test
---
 ext/tdb/rec_mutex.c    |  114 ++++++++++++++++++++++++++++++++++++++++++++++++
 ext/tdb/tdb.c          |    5 ++
 lib/tdb/mt.rb          |   14 ++----
 test/test_rec_mutex.rb |   36 +++++++++++++++
 test/test_tdb_mt.rb    |   20 +++++---
 5 files changed, 172 insertions(+), 17 deletions(-)
 create mode 100644 ext/tdb/rec_mutex.c
 create mode 100644 test/test_rec_mutex.rb

diff --git a/ext/tdb/rec_mutex.c b/ext/tdb/rec_mutex.c
new file mode 100644
index 0000000..2cf92dc
--- /dev/null
+++ b/ext/tdb/rec_mutex.c
@@ -0,0 +1,139 @@
+#include <ruby.h>
+#ifdef HAVE_RB_THREAD_BLOCKING_REGION
+#include <errno.h>
+#include <pthread.h>
+struct my_mutex {
+	pthread_mutex_t lock;
+	/* ... */
+};
+
+/* pthread_* functions return an error code instead of setting errno */
+static void fail(const char *str, int err)
+{
+	errno = err;
+	rb_sys_fail(str);
+}
+
+/* safe: alloc() only ever wraps a successfully initialized mutex */
+static void gcfree(void *ptr)
+{
+	struct my_mutex *mutex = ptr;
+
+	(void)pthread_mutex_destroy(&mutex->lock);
+	xfree(mutex);
+}
+
+/*
+ * The recursive mutex is created here rather than in #initialize, so
+ * objects made by RecMutex.allocate or #dup (which skip #initialize) are
+ * usable too, and gcfree() never destroys an uninitialized mutex.
+ */
+static VALUE alloc(VALUE klass)
+{
+	pthread_mutexattr_t a;
+	const char *func = "pthread_mutexattr_init";
+	struct my_mutex *mutex = ALLOC(struct my_mutex);
+	int ret = pthread_mutexattr_init(&a);
+
+	if (ret == 0) {
+		func = "pthread_mutexattr_settype";
+		ret = pthread_mutexattr_settype(&a, PTHREAD_MUTEX_RECURSIVE);
+		if (ret == 0) {
+			func = "pthread_mutex_init";
+			ret = pthread_mutex_init(&mutex->lock, &a);
+		}
+		/* destroyed on every path, so a settype failure cannot leak it */
+		(void)pthread_mutexattr_destroy(&a);
+	}
+	if (ret) {
+		xfree(mutex);
+		fail(func, ret);
+	}
+	return Data_Wrap_Struct(klass, NULL, gcfree, mutex);
+}
+
+static struct my_mutex *get(VALUE self)
+{
+	struct my_mutex *mutex;
+
+	Data_Get_Struct(self, struct my_mutex, mutex);
+	return mutex;
+}
+
+/*
+ * RecMutex#initialize: alloc() has already set up the mutex; this only
+ * keeps RecMutex.new at zero arguments, and calling it again is harmless
+ */
+static VALUE init(VALUE self)
+{
+	return self;
+}
+
+/* passed to rb_thread_blocking_region, so it runs without the GVL */
+static VALUE nogvl_lock(void *ptr)
+{
+	return (VALUE)pthread_mutex_lock(ptr);
+}
+
+/*
+ * Only lock can block, so only lock releases the GVL.  nogvl_lock()
+ * exists because calling pthread_mutex_lock through a pointer cast to a
+ * different function type is undefined behavior.
+ * NOTE(review): no unblocking function is given, so a waiting thread
+ * presumably cannot be interrupted (e.g. Thread#kill) until it gets the lock
+ */
+static VALUE lock(VALUE self)
+{
+	struct my_mutex *mutex = get(self);
+	int ret = (int)rb_thread_blocking_region(nogvl_lock, &mutex->lock, 0, 0);
+
+	if (ret)
+		fail("pthread_mutex_lock", ret);
+
+	return self;
+}
+
+/* pthread_mutex_unlock never blocks, so there is no need to drop the GVL */
+static VALUE unlock(VALUE self)
+{
+	struct my_mutex *mutex = get(self);
+	int ret = pthread_mutex_unlock(&mutex->lock);
+
+	if (ret)
+		fail("pthread_mutex_unlock", ret);
+
+	return Qnil;
+}
+
+/* never blocks: returns false (EBUSY) while another thread holds it */
+static VALUE try_lock(VALUE self)
+{
+	struct my_mutex *mutex = get(self);
+	int ret = pthread_mutex_trylock(&mutex->lock);
+
+	if (ret == EBUSY)
+		return Qfalse;
+	if (ret)
+		fail("pthread_mutex_trylock", ret);
+	return Qtrue;
+}
+
+/* yields with the mutex held; rb_ensure unlocks even if the block raises */
+static VALUE synchronize(VALUE self)
+{
+	lock(self);
+	return rb_ensure(rb_yield, Qnil, unlock, self);
+}
+
+/* TODO: split this out into a new project */
+void Init_rec_mutex(void)
+{
+	VALUE cRecMutex = rb_define_class("RecMutex", rb_cObject);
+	rb_define_alloc_func(cRecMutex, alloc);
+	rb_define_method(cRecMutex, "initialize", init, 0);
+	rb_define_method(cRecMutex, "lock", lock, 0);
+	rb_define_method(cRecMutex, "unlock", unlock, 0);
+	rb_define_method(cRecMutex, "try_lock", try_lock, 0);
+	rb_define_method(cRecMutex, "synchronize", synchronize, 0);
+}
+#endif /* HAVE_RB_THREAD_BLOCKING_REGION */
diff --git a/ext/tdb/tdb.c b/ext/tdb/tdb.c
index d1b573a..29450a0 100644
--- a/ext/tdb/tdb.c
+++ b/ext/tdb/tdb.c
@@ -641,9 +641,14 @@ static VALUE repack(VALUE self)
 }
 #endif /* HAVE_TDB_REPACK */
 
+void Init_rec_mutex(void);
+
 void Init_tdb_ext(void)
 {
 	cTDB = rb_define_class("TDB", rb_cObject);
+#ifdef HAVE_RB_THREAD_BLOCKING_REGION
+	Init_rec_mutex();
+#endif
 
 	hashes = rb_hash_new();
 
diff --git a/lib/tdb/mt.rb b/lib/tdb/mt.rb
index 7e7a09d..c1abb52 100644
--- a/lib/tdb/mt.rb
+++ b/lib/tdb/mt.rb
@@ -2,7 +2,7 @@
 module TDB::MT
   def initialize
     super
-    @lock = Mutex.new
+    @lock = RecMutex.new
   end
 
   wrap_methods = %w(
@@ -12,30 +12,24 @@ module TDB::MT
     lockall trylockall unlockall
     lockall_read trylockall_read unlockall_read
     lockall_mark lockall_unmark
-    clear
+    clear each
   )
   wrap_methods << :repack if TDB.method_defined?(:repack)
   wrap_methods.each do |meth|
     eval "def #{meth}(*args); @lock.synchronize { super }; end"
   end
 
-  def each(&block)
-    @lock.synchronize do
-      super { |k,v| @lock.exclusive_unlock { yield(k,v) } }
-    end
-  end
-
   def threadsafe?
     true
   end
 
   def self.extended(obj)
-    obj.instance_eval { @lock = Mutex.new unless defined?(@lock) }
+    obj.instance_eval { @lock = RecMutex.new unless defined?(@lock) }
   end
 
   def self.included(klass)
     ObjectSpace.each_object(klass) { |obj|
-      obj.instance_eval { @lock = Mutex.new unless defined?(@lock) }
+      obj.instance_eval { @lock = RecMutex.new unless defined?(@lock) }
     }
   end
 end
diff --git a/test/test_rec_mutex.rb b/test/test_rec_mutex.rb
new file mode 100644
index 0000000..159d3d3
--- /dev/null
+++ b/test/test_rec_mutex.rb
@@ -0,0 +1,37 @@
+# -*- encoding: binary -*-
+require 'test/unit'
+$-w = true
+# TODO: split out into a separate project
+require 'tdb'
+
+# RecMutex is only defined when the extension was built with
+# HAVE_RB_THREAD_BLOCKING_REGION, hence the trailing "if" on this class
+class TestRecMutex < Test::Unit::TestCase
+
+  def test_lock_unlock
+    lock = RecMutex.new
+    assert_equal lock, lock.lock
+    assert_equal lock, lock.lock
+    assert_nil lock.unlock
+    assert_nil lock.unlock
+    assert_raises(Errno::EPERM) { lock.unlock }
+  end
+
+  def test_try_lock
+    lock = RecMutex.new
+    assert_equal true, lock.try_lock
+    assert_equal false, Thread.new { lock.try_lock }.value
+    assert_nil lock.unlock
+    # only the owner may unlock, so release it before the thread exits
+    assert_equal [ true, nil ], Thread.new { [ lock.try_lock, lock.unlock ] }.value
+  end
+
+  def test_synchronize
+    lock = RecMutex.new
+    assert_equal :ok, lock.synchronize { lock.synchronize { :ok } }
+    assert_raises(RuntimeError) { lock.synchronize { raise "boom" } }
+
+    # the ensure-unlock must have freed it for another thread
+    assert_equal [ true, nil ], Thread.new { [ lock.try_lock, lock.unlock ] }.value
+  end
+end if defined?(RecMutex)
diff --git a/test/test_tdb_mt.rb b/test/test_tdb_mt.rb
index 0fbe09f..e8e7ac6 100644
--- a/test/test_tdb_mt.rb
+++ b/test/test_tdb_mt.rb
@@ -39,13 +39,15 @@ class Test_TDB_MT < Test::Unit::TestCase
   end
 
   def test_thread_safe_torture_test
-    @tdb = TDB.new(nil)
+    @tmp = Tempfile.new('tdb_test')
+    File.unlink(@tmp.path)
+    @tdb = TDB.new(@tmp.path)
     assert_nothing_raised { @tdb.threadsafe! }
-    pid = fork do
-      Thread.abort_on_exception = true
+    Thread.abort_on_exception = true
+    nr = 1000
+    blob = 'foo' * 1000
+    crazy = proc do
       threads = []
-      blob = 'foo' * 1000
-      nr = 10000
       t = Thread.new do
         while true
           Thread.pass
@@ -57,12 +59,16 @@ class Test_TDB_MT < Test::Unit::TestCase
       threads << Thread.new { nr.times { |i| @tdb[i.to_s] = blob } }
       threads << Thread.new { nr.times { |i| @tdb[i.to_s] = blob } }
       threads << t
+      sleep 1
 
       t.kill
       threads.each { |t| t.join }
     end
-    _, status = Process.waitpid2(pid)
-    assert status.success?, status.inspect
+    10.times { fork &crazy }
+    Process.waitall.each do |(pid,status)|
+      assert status.success?, status.inspect
+    end
+    nr.times { |i| assert_equal blob, @tdb[i.to_s] }
   end
 
   def test_check_methods
-- 
Eric Wong