diff options
author | jbj <devnull@localhost> | 2002-06-03 20:44:08 +0000 |
---|---|---|
committer | jbj <devnull@localhost> | 2002-06-03 20:44:08 +0000 |
commit | 9b2ac14866c7740e67e17d37392f2397cd1dcdd3 (patch) | |
tree | 1461f7e71854c806bb11bacef8d44ec91b5b883c | |
parent | b3308c5f878ec970f48c7559bb62646b182c3128 (diff) | |
download | rpm-9b2ac14866c7740e67e17d37392f2397cd1dcdd3.tar.gz rpm-9b2ac14866c7740e67e17d37392f2397cd1dcdd3.tar.bz2 rpm-9b2ac14866c7740e67e17d37392f2397cd1dcdd3.zip |
Functional unit tests after renaming bsdddb3 -> rpmdb, _db -> _rpmdb.
CVS patchset: 5459
CVS date: 2002/06/03 20:44:08
24 files changed, 3889 insertions, 22 deletions
diff --git a/python/_rpmdb.c b/python/_rpmdb.c index 8df27b9ae..b91b7fcb3 100644 --- a/python/_rpmdb.c +++ b/python/_rpmdb.c @@ -88,7 +88,7 @@ #define PY_BSDDB_VERSION "3.3.1" -static char *rcs_id = "$Id: _rpmdb.c,v 1.1 2002/06/02 20:50:49 jbj Exp $"; +static char *rcs_id = "$Id: _rpmdb.c,v 1.2 2002/06/03 20:44:08 jbj Exp $"; #ifdef WITH_THREAD @@ -1996,7 +1996,7 @@ DB_set_get_returns_none(DBObject* self, PyObject* args) /*-------------------------------------------------------------- */ /* Mapping and Dictionary-like access routines */ -int DB_length(DBObject* self) +static int DB_length(DBObject* self) { int err; long size = 0; @@ -2034,7 +2034,7 @@ int DB_length(DBObject* self) } -PyObject* DB_subscript(DBObject* self, PyObject* keyobj) +static PyObject* DB_subscript(DBObject* self, PyObject* keyobj) { int err; PyObject* retval; @@ -3927,6 +3927,8 @@ static PyMethodDef bsddb_methods[] = { +void init_rpmdb(void); /* XXX remove compiler warning */ + DL_EXPORT(void) init_rpmdb(void) { PyObject* m; @@ -3948,7 +3950,7 @@ DL_EXPORT(void) init_rpmdb(void) #endif /* Create the module and add the functions */ - m = Py_InitModule("_db", bsddb_methods); + m = Py_InitModule("_rpmdb", bsddb_methods); /* Add some symbolic constants to the module */ d = PyModule_GetDict(m); diff --git a/python/rpmdb/Makefile.am b/python/rpmdb/Makefile.am index 543a08b21..981231234 100644 --- a/python/rpmdb/Makefile.am +++ b/python/rpmdb/Makefile.am @@ -6,5 +6,4 @@ PYVER= @WITH_PYTHON_VERSION@ rpmdbdir = $(prefix)/lib/python${PYVER}/site-packages/rpmdb rpmdb_SCRIPTS = __init__.py \ - rpmdbobj.py rpmdb.py rpmdbrecio.py rpmdbshelve.py rpmdbtables.py \ - rpmdbutils.py + dbobj.py db.py dbrecio.py dbshelve.py dbtables.py dbutils.py diff --git a/python/rpmdb/__init__.py b/python/rpmdb/__init__.py index 563e18fae..6dd102dba 100644 --- a/python/rpmdb/__init__.py +++ b/python/rpmdb/__init__.py @@ -35,18 +35,18 @@ """ This package initialization module provides a compatibility interface -that should enable bsddb3 to be a near drop-in replacement for the original +that should enable rpmdb to be a near drop-in replacement for the original old bsddb module. The functions and classes provided here are all -wrappers around the new functionality provided in the bsddb3.db module. +wrappers around the new functionality provided in the rpmdb.db module. People interested in the more advanced capabilites of Berkeley DB 3.x -should use the bsddb3.db module directly. +should use the rpmdb.db module directly. """ -import _db +import _rpmdb as _db __version__ = _db.__version__ -error = _db.DBError # So bsddb3.error will mean something... +error = _db.DBError # So rpmdb.error will mean something... #---------------------------------------------------------------------- diff --git a/python/rpmdb/rpmdb.py b/python/rpmdb/db.py index b4365d06d..3bf5f8e35 100644 --- a/python/rpmdb/rpmdb.py +++ b/python/rpmdb/db.py @@ -37,8 +37,8 @@ # case we ever want to augment the stuff in _db in any way. For now # it just simply imports everything from _db. -from _db import * -from _db import __version__ +from _rpmdb import * +from _rpmdb import __version__ if version() < (3, 1, 0): raise ImportError, "BerkeleyDB 3.x symbols not found. Perhaps python was statically linked with an older version?" diff --git a/python/rpmdb/rpmdbobj.py b/python/rpmdb/dbobj.py index 9c3e90f63..9c3e90f63 100644 --- a/python/rpmdb/rpmdbobj.py +++ b/python/rpmdb/dbobj.py diff --git a/python/rpmdb/rpmdbrecio.py b/python/rpmdb/dbrecio.py index 995dad713..4ef6f6b78 100644 --- a/python/rpmdb/rpmdbrecio.py +++ b/python/rpmdb/dbrecio.py @@ -1,6 +1,6 @@ """ -File-like objects that read from or write to a bsddb3 record. +File-like objects that read from or write to a rpmdb record. This implements (nearly) all stdio methods. diff --git a/python/rpmdb/rpmdbshelve.py b/python/rpmdb/dbshelve.py index dab8caab2..e4ca9335b 100644 --- a/python/rpmdb/rpmdbshelve.py +++ b/python/rpmdb/dbshelve.py @@ -24,14 +24,14 @@ #------------------------------------------------------------------------ """ -Manage shelves of pickled objects using bsddb3 database files for the +Manage shelves of pickled objects using rpmdb database files for the storage. """ #------------------------------------------------------------------------ import cPickle -from bsddb3 import db +from rpmdb import db #------------------------------------------------------------------------ @@ -43,7 +43,7 @@ def open(filename, flags=db.DB_CREATE, mode=0660, filetype=db.DB_HASH, shleve.py module. It can be used like this, where key is a string and data is a pickleable object: - from bsddb3 import dbshelve + from rpmdb import dbshelve db = dbshelve.open(filename) db[key] = data @@ -63,7 +63,7 @@ def open(filename, flags=db.DB_CREATE, mode=0660, filetype=db.DB_HASH, elif sflag == 'n': flags = db.DB_TRUNCATE | db.DB_CREATE else: - raise error, "flags should be one of 'r', 'w', 'c' or 'n' or use the bsddb3.db.DB_* flags" + raise error, "flags should be one of 'r', 'w', 'c' or 'n' or use the rpmdb.db.DB_* flags" d = DBShelf(dbenv) d.open(filename, dbname, filetype, flags, mode) @@ -73,7 +73,7 @@ def open(filename, flags=db.DB_CREATE, mode=0660, filetype=db.DB_HASH, class DBShelf: """ - A shelf to hold pickled objects, built upon a bsddb3 DB object. It + A shelf to hold pickled objects, built upon a rpmdb DB object. It automatically pickles/unpickles data objects going to/from the DB. """ def __init__(self, dbenv=None): diff --git a/python/rpmdb/rpmdbtables.py b/python/rpmdb/dbtables.py index 8ffed9105..05bf2ef82 100644 --- a/python/rpmdb/rpmdbtables.py +++ b/python/rpmdb/dbtables.py @@ -28,7 +28,7 @@ import xdrlib import re import copy -from bsddb3.db import * +from rpmdb.db import * class TableDBError(StandardError): pass diff --git a/python/rpmdb/rpmdbutils.py b/python/rpmdb/dbutils.py index fe08407b1..d1dbd5cc6 100644 --- a/python/rpmdb/rpmdbutils.py +++ b/python/rpmdb/dbutils.py @@ -28,13 +28,13 @@ # # import the time.sleep function in a namespace safe way to allow -# "from bsddb3.db import *" +# "from rpmdb.db import *" # from time import sleep _sleep = sleep del sleep -import _db +import _rpmdb as _db _deadlock_MinSleepTime = 1.0/64 # always sleep at least N seconds between retrys _deadlock_MaxSleepTime = 1.0 # never sleep more than N seconds between retrys diff --git a/python/test/test_all.py b/python/test/test_all.py new file mode 100644 index 000000000..6fc011dd9 --- /dev/null +++ b/python/test/test_all.py @@ -0,0 +1,58 @@ +""" +Run all test cases. +""" + +import sys +import unittest + +verbose = 0 +if 'verbose' in sys.argv: + verbose = 1 + sys.argv.remove('verbose') + +if 'silent' in sys.argv: # take care of old flag, just in case + verbose = 0 + sys.argv.remove('silent') + + +# This little hack is for when this module is run as main and all the +# other modules import it so they will still be able to get the right +# verbose setting. It's confusing but it works. +import test_all +test_all.verbose = verbose + + +def suite(): + test_modules = [ 'test_compat', + 'test_basics', + 'test_misc', + 'test_dbobj', + 'test_recno', + 'test_queue', + 'test_get_none', + 'test_dbshelve', + 'test_dbtables', + 'test_thread', + 'test_lock', + 'test_associate', + ] + + alltests = unittest.TestSuite() + for name in test_modules: + module = __import__(name) + alltests.addTest(module.suite()) + return alltests + + +if __name__ == '__main__': + from rpmdb import db + print '-=' * 38 + print db.DB_VERSION_STRING + print 'rpmdb.db.version(): %s' % (db.version(), ) + print 'rpmdb.db.__version__: %s' % db.__version__ + print 'rpmdb.db.cvsid: %s' % db.cvsid + print 'python version: %s' % sys.version + print '-=' * 38 + + unittest.main( defaultTest='suite' ) + diff --git a/python/test/test_associate.py b/python/test/test_associate.py new file mode 100644 index 000000000..0db04ecd5 --- /dev/null +++ b/python/test/test_associate.py @@ -0,0 +1,323 @@ +""" +TestCases for multi-threaded access to a DB. +""" + +import sys, os, string +import tempfile +import time +from pprint import pprint + +try: + from threading import Thread, currentThread + have_threads = 1 +except ImportError: + have_threads = 0 + +import unittest +from test_all import verbose + +from rpmdb import db, dbshelve + + +#---------------------------------------------------------------------- + + +musicdata = { +1 : ("Bad English", "The Price Of Love", "Rock"), +2 : ("DNA featuring Suzanne Vega", "Tom's Diner", "Rock"), +3 : ("George Michael", "Praying For Time", "Rock"), +4 : ("Gloria Estefan", "Here We Are", "Rock"), +5 : ("Linda Ronstadt", "Don't Know Much", "Rock"), +6 : ("Michael Bolton", "How Am I Supposed To Live Without You", "Blues"), +7 : ("Paul Young", "Oh Girl", "Rock"), +8 : ("Paula Abdul", "Opposites Attract", "Rock"), +9 : ("Richard Marx", "Should've Known Better", "Rock"), +10: ("Rod Stewart", "Forever Young", "Rock"), +11: ("Roxette", "Dangerous", "Rock"), +12: ("Sheena Easton", "The Lover In Me", "Rock"), +13: ("Sinead O'Connor", "Nothing Compares 2 U", "Rock"), +14: ("Stevie B.", "Because I Love You", "Rock"), +15: ("Taylor Dayne", "Love Will Lead You Back", "Rock"), +16: ("The Bangles", "Eternal Flame", "Rock"), +17: ("Wilson Phillips", "Release Me", "Rock"), +18: ("Billy Joel", "Blonde Over Blue", "Rock"), +19: ("Billy Joel", "Famous Last Words", "Rock"), +20: ("Billy Joel", "Lullabye (Goodnight, My Angel)", "Rock"), +21: ("Billy Joel", "The River Of Dreams", "Rock"), +22: ("Billy Joel", "Two Thousand Years", "Rock"), +23: ("Janet Jackson", "Alright", "Rock"), +24: ("Janet Jackson", "Black Cat", "Rock"), +25: ("Janet Jackson", "Come Back To Me", "Rock"), +26: ("Janet Jackson", "Escapade", "Rock"), +27: ("Janet Jackson", "Love Will Never Do (Without You)", "Rock"), +28: ("Janet Jackson", "Miss You Much", "Rock"), +29: ("Janet Jackson", "Rhythm Nation", "Rock"), +30: ("Janet Jackson", "State Of The World", "Rock"), +31: ("Janet Jackson", "The Knowledge", "Rock"), +32: ("Spyro Gyra", "End of Romanticism", "Jazz"), +33: ("Spyro Gyra", "Heliopolis", "Jazz"), +34: ("Spyro Gyra", "Jubilee", "Jazz"), +35: ("Spyro Gyra", "Little Linda", "Jazz"), +36: ("Spyro Gyra", "Morning Dance", "Jazz"), +37: ("Spyro Gyra", "Song for Lorraine", "Jazz"), +38: ("Yes", "Owner Of A Lonely Heart", "Rock"), +39: ("Yes", "Rhythm Of Love", "Rock"), +40: ("Cusco", "Dream Catcher", "New Age"), +41: ("Cusco", "Geronimos Laughter", "New Age"), +42: ("Cusco", "Ghost Dance", "New Age"), +43: ("Blue Man Group", "Drumbone", "New Age"), +44: ("Blue Man Group", "Endless Column", "New Age"), +45: ("Blue Man Group", "Klein Mandelbrot", "New Age"), +46: ("Kenny G", "Silhouette", "Jazz"), +47: ("Sade", "Smooth Operator", "Jazz"), +48: ("David Arkenstone", "Papillon (On The Wings Of The Butterfly)", "New Age"), +49: ("David Arkenstone", "Stepping Stars", "New Age"), +50: ("David Arkenstone", "Carnation Lily Lily Rose", "New Age"), +51: ("David Lanz", "Behind The Waterfall", "New Age"), +52: ("David Lanz", "Cristofori's Dream", "New Age"), +53: ("David Lanz", "Heartsounds", "New Age"), +54: ("David Lanz", "Leaves on the Seine", "New Age"), +} + +#---------------------------------------------------------------------- + + +class AssociateTestCase(unittest.TestCase): + keytype = '' + + def setUp(self): + self.filename = self.__class__.__name__ + '.db' + homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + self.homeDir = homeDir + try: os.mkdir(homeDir) + except os.error: pass + self.env = db.DBEnv() + self.env.open(homeDir, db.DB_CREATE | db.DB_INIT_MPOOL | + db.DB_INIT_LOCK | db.DB_THREAD) + + def tearDown(self): + self.closeDB() + self.env.close() + import glob + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + def addDataToDB(self, d): + for key, value in musicdata.items(): + if type(self.keytype) == type(''): + key = "%02d" % key + d.put(key, string.join(value, '|')) + + + + def createDB(self): + self.primary = db.DB(self.env) + self.primary.open(self.filename, "primary", self.dbtype, + db.DB_CREATE | db.DB_THREAD) + + def closeDB(self): + self.primary.close() + + def getDB(self): + return self.primary + + + + def test01_associateWithDB(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test01_associateWithDB..." % self.__class__.__name__ + + self.createDB() + + secDB = db.DB(self.env) + secDB.set_flags(db.DB_DUP) + secDB.open(self.filename, "secondary", db.DB_BTREE, db.DB_CREATE | db.DB_THREAD) + self.getDB().associate(secDB, self.getGenre) + + self.addDataToDB(self.getDB()) + + self.finish_test(secDB) + + + def test02_associateAfterDB(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test02_associateAfterDB..." % self.__class__.__name__ + + self.createDB() + self.addDataToDB(self.getDB()) + + secDB = db.DB(self.env) + secDB.set_flags(db.DB_DUP) + secDB.open(self.filename, "secondary", db.DB_BTREE, db.DB_CREATE | db.DB_THREAD) + + # adding the DB_CREATE flag will cause it to index existing records + self.getDB().associate(secDB, self.getGenre, db.DB_CREATE) + + self.finish_test(secDB) + + + + + def finish_test(self, secDB): + if verbose: + print "Primary key traversal:" + c = self.getDB().cursor() + count = 0 + rec = c.first() + while rec is not None: + if type(self.keytype) == type(''): + assert string.atoi(rec[0]) # for primary db, key is a number + else: + assert rec[0] and type(rec[0]) == type(0) + count = count + 1 + if verbose: + print rec + rec = c.next() + assert count == len(musicdata) # all items accounted for + + + if verbose: + print "Secondary key traversal:" + c = secDB.cursor() + count = 0 + rec = c.first() + assert rec[0] == "Jazz" + while rec is not None: + count = count + 1 + if verbose: + print rec + rec = c.next() + assert count == len(musicdata)-1 # all items accounted for EXCEPT for 1 with "Blues" genre + + + + def getGenre(self, priKey, priData): + assert type(priData) == type("") + if verbose: + print 'getGenre key:', `priKey`, 'data:', `priData` + genre = string.split(priData, '|')[2] + if genre == 'Blues': + return db.DB_DONOTINDEX + else: + return genre + + +#---------------------------------------------------------------------- + + +class AssociateHashTestCase(AssociateTestCase): + dbtype = db.DB_HASH + +class AssociateBTreeTestCase(AssociateTestCase): + dbtype = db.DB_BTREE + +class AssociateRecnoTestCase(AssociateTestCase): + dbtype = db.DB_RECNO + keytype = 0 + + +#---------------------------------------------------------------------- + +class ShelveAssociateTestCase(AssociateTestCase): + + def createDB(self): + self.primary = dbshelve.open(self.filename, + dbname="primary", + dbenv=self.env, + filetype=self.dbtype) + + def addDataToDB(self, d): + for key, value in musicdata.items(): + if type(self.keytype) == type(''): + key = "%02d" % key + d.put(key, value) # save the value as is this time + + + def getGenre(self, priKey, priData): + assert type(priData) == type(()) + if verbose: + print 'getGenre key:', `priKey`, 'data:', `priData` + genre = priData[2] + if genre == 'Blues': + return db.DB_DONOTINDEX + else: + return genre + + +class ShelveAssociateHashTestCase(ShelveAssociateTestCase): + dbtype = db.DB_HASH + +class ShelveAssociateBTreeTestCase(ShelveAssociateTestCase): + dbtype = db.DB_BTREE + +class ShelveAssociateRecnoTestCase(ShelveAssociateTestCase): + dbtype = db.DB_RECNO + keytype = 0 + + +#---------------------------------------------------------------------- + +class ThreadedAssociateTestCase(AssociateTestCase): + + def addDataToDB(self, d): + t1 = Thread(target = self.writer1, + args = (d, )) + t2 = Thread(target = self.writer2, + args = (d, )) + + t1.start() + t2.start() + t1.join() + t2.join() + + def writer1(self, d): + for key, value in musicdata.items(): + if type(self.keytype) == type(''): + key = "%02d" % key + d.put(key, string.join(value, '|')) + + def writer2(self, d): + for x in range(100, 600): + key = 'z%2d' % x + value = [key] * 4 + d.put(key, string.join(value, '|')) + + +class ThreadedAssociateHashTestCase(ShelveAssociateTestCase): + dbtype = db.DB_HASH + +class ThreadedAssociateBTreeTestCase(ShelveAssociateTestCase): + dbtype = db.DB_BTREE + +class ThreadedAssociateRecnoTestCase(ShelveAssociateTestCase): + dbtype = db.DB_RECNO + keytype = 0 + + +#---------------------------------------------------------------------- + +def suite(): + theSuite = unittest.TestSuite() + + if db.version() >= (3, 3, 11): + theSuite.addTest(unittest.makeSuite(AssociateHashTestCase)) + theSuite.addTest(unittest.makeSuite(AssociateBTreeTestCase)) + theSuite.addTest(unittest.makeSuite(AssociateRecnoTestCase)) + + theSuite.addTest(unittest.makeSuite(ShelveAssociateHashTestCase)) + theSuite.addTest(unittest.makeSuite(ShelveAssociateBTreeTestCase)) + theSuite.addTest(unittest.makeSuite(ShelveAssociateRecnoTestCase)) + + if have_threads: + theSuite.addTest(unittest.makeSuite(ThreadedAssociateHashTestCase)) + theSuite.addTest(unittest.makeSuite(ThreadedAssociateBTreeTestCase)) + theSuite.addTest(unittest.makeSuite(ThreadedAssociateRecnoTestCase)) + + return theSuite + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) diff --git a/python/test/test_basics.py b/python/test/test_basics.py new file mode 100644 index 000000000..1b62bd825 --- /dev/null +++ b/python/test/test_basics.py @@ -0,0 +1,776 @@ +""" +Basic TestCases for BTree and hash DBs, with and without a DBEnv, with +various DB flags, etc. +""" + +import sys, os, string +import tempfile +from pprint import pprint +import unittest + +from rpmdb import db + +from test_all import verbose + + +#---------------------------------------------------------------------- + +class VersionTestCase(unittest.TestCase): + def test00_version(self): + info = db.version() + if verbose: + print '\n', '-=' * 20 + print 'rpmdb.db.version(): %s' % (info, ) + print db.DB_VERSION_STRING + print '-=' * 20 + assert info == (db.DB_VERSION_MAJOR, db.DB_VERSION_MINOR, db.DB_VERSION_PATCH) + +#---------------------------------------------------------------------- + +class BasicTestCase(unittest.TestCase): + dbtype = db.DB_UNKNOWN # must be set in derived class + dbopenflags = 0 + dbsetflags = 0 + dbmode = 0660 + dbname = None + useEnv = 0 + envflags = 0 + + def setUp(self): + if self.useEnv: + homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + try: os.mkdir(homeDir) + except os.error: pass + self.env = db.DBEnv() + self.env.set_lg_max(1024*1024) + self.env.open(homeDir, self.envflags | db.DB_CREATE) + tempfile.tempdir = homeDir + self.filename = os.path.split(tempfile.mktemp())[1] + tempfile.tempdir = None + self.homeDir = homeDir + else: + self.env = None + self.filename = tempfile.mktemp() + + # create and open the DB + self.d = db.DB(self.env) + self.d.set_flags(self.dbsetflags) + if self.dbname: + self.d.open(self.filename, self.dbname, self.dbtype, + self.dbopenflags|db.DB_CREATE, self.dbmode) + else: + self.d.open(self.filename, # try out keyword args + mode = self.dbmode, + dbtype = self.dbtype, flags = self.dbopenflags|db.DB_CREATE) + + self.populateDB() + + + def tearDown(self): + self.d.close() + if self.env is not None: + self.env.close() + + import glob + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + ## Make a new DBEnv to remove the env files from the home dir. + ## (It can't be done while the env is open, nor after it has been + ## closed, so we make a new one to do it.) + #e = db.DBEnv() + #e.remove(self.homeDir) + #os.remove(os.path.join(self.homeDir, self.filename)) + + else: + os.remove(self.filename) + + + + def populateDB(self): + d = self.d + for x in range(500): + key = '%04d' % (1000 - x) # insert keys in reverse order + data = self.makeData(key) + d.put(key, data) + + for x in range(500): + key = '%04d' % x # and now some in forward order + data = self.makeData(key) + d.put(key, data) + + num = len(d) + if verbose: + print "created %d records" % num + + + def makeData(self, key): + return string.join([key] * 5, '-') + + + + #---------------------------------------- + + def test01_GetsAndPuts(self): + d = self.d + + if verbose: + print '\n', '-=' * 30 + print "Running %s.test01_GetsAndPuts..." % self.__class__.__name__ + + for key in ['0001', '0100', '0400', '0700', '0999']: + data = d.get(key) + if verbose: + print data + + assert d.get('0321') == '0321-0321-0321-0321-0321' + + # By default non-existant keys return None... + assert d.get('abcd') == None + + # ...but they raise exceptions in other situations. Call + # set_get_returns_none() to change it. + try: + d.delete('abcd') + except db.DBNotFoundError, val: + assert val[0] == db.DB_NOTFOUND + if verbose: print val + else: + self.fail("expected exception") + + + d.put('abcd', 'a new record') + assert d.get('abcd') == 'a new record' + + d.put('abcd', 'same key') + if self.dbsetflags & db.DB_DUP: + assert d.get('abcd') == 'a new record' + else: + assert d.get('abcd') == 'same key' + + + try: + d.put('abcd', 'this should fail', flags=db.DB_NOOVERWRITE) + except db.DBKeyExistError, val: + assert val[0] == db.DB_KEYEXIST + if verbose: print val + else: + self.fail("expected exception") + + if self.dbsetflags & db.DB_DUP: + assert d.get('abcd') == 'a new record' + else: + assert d.get('abcd') == 'same key' + + + d.sync() + d.close() + del d + + self.d = db.DB(self.env) + if self.dbname: + self.d.open(self.filename, self.dbname) + else: + self.d.open(self.filename) + d = self.d + + assert d.get('0321') == '0321-0321-0321-0321-0321' + if self.dbsetflags & db.DB_DUP: + assert d.get('abcd') == 'a new record' + else: + assert d.get('abcd') == 'same key' + + rec = d.get_both('0555', '0555-0555-0555-0555-0555') + if verbose: + print rec + + assert d.get_both('0555', 'bad data') == None + + # test default value + data = d.get('bad key', 'bad data') + assert data == 'bad data' + + # any object can pass through + data = d.get('bad key', self) + assert data == self + + s = d.stat() + assert type(s) == type({}) + if verbose: + print 'd.stat() returned this dictionary:' + pprint(s) + + + #---------------------------------------- + + def test02_DictionaryMethods(self): + d = self.d + + if verbose: + print '\n', '-=' * 30 + print "Running %s.test02_DictionaryMethods..." % self.__class__.__name__ + + for key in ['0002', '0101', '0401', '0701', '0998']: + data = d[key] + assert data == self.makeData(key) + if verbose: + print data + + assert len(d) == 1000 + keys = d.keys() + assert len(keys) == 1000 + assert type(keys) == type([]) + + d['new record'] = 'a new record' + assert len(d) == 1001 + keys = d.keys() + assert len(keys) == 1001 + + d['new record'] = 'a replacement record' + assert len(d) == 1001 + keys = d.keys() + assert len(keys) == 1001 + + if verbose: + print "the first 10 keys are:" + pprint(keys[:10]) + + assert d['new record'] == 'a replacement record' + + assert d.has_key('0001') == 1 + assert d.has_key('spam') == 0 + + items = d.items() + assert len(items) == 1001 + assert type(items) == type([]) + assert type(items[0]) == type(()) + assert len(items[0]) == 2 + + if verbose: + print "the first 10 items are:" + pprint(items[:10]) + + values = d.values() + assert len(values) == 1001 + assert type(values) == type([]) + + if verbose: + print "the first 10 values are:" + pprint(values[:10]) + + + + #---------------------------------------- + + def test03_SimpleCursorStuff(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test03_SimpleCursorStuff..." % self.__class__.__name__ + + c = self.d.cursor() + + + rec = c.first() + count = 0 + while rec is not None: + count = count + 1 + if verbose and count % 100 == 0: + print rec + rec = c.next() + + assert count == 1000 + + + rec = c.last() + count = 0 + while rec is not None: + count = count + 1 + if verbose and count % 100 == 0: + print rec + rec = c.prev() + + assert count == 1000 + + rec = c.set('0505') + rec2 = c.current() + assert rec == rec2 + assert rec[0] == '0505' + assert rec[1] == self.makeData('0505') + + try: + c.set('bad key') + except db.DBNotFoundError, val: + assert val[0] == db.DB_NOTFOUND + if verbose: print val + else: + self.fail("expected exception") + + rec = c.get_both('0404', self.makeData('0404')) + assert rec == ('0404', self.makeData('0404')) + + try: + c.get_both('0404', 'bad data') + except db.DBNotFoundError, val: + assert val[0] == db.DB_NOTFOUND + if verbose: print val + else: + self.fail("expected exception") + + if self.d.get_type() == db.DB_BTREE: + rec = c.set_range('011') + if verbose: + print "searched for '011', found: ", rec + + c.set('0499') + c.delete() + try: + rec = c.current() + except db.DBKeyEmptyError, val: + assert val[0] == db.DB_KEYEMPTY + if verbose: print val + else: + self.fail('exception expected') + + c.next() + c2 = c.dup(db.DB_POSITION) + assert c.current() == c2.current() + + c2.put('', 'a new value', db.DB_CURRENT) + assert c.current() == c2.current() + assert c.current()[1] == 'a new value' + + c.close() + c2.close() + + # time to abuse the closed cursors and hope we don't crash + methods_to_test = { + 'current': (), + 'delete': (), + 'dup': (db.DB_POSITION,), + 'first': (), + 'get': (0,), + 'next': (), + 'prev': (), + 'last': (), + 'put':('', 'spam', db.DB_CURRENT), + 'set': ("0505",), + } + for method, args in methods_to_test.items(): + try: + if verbose: + print "attempting to use a closed cursor's %s method" % method + # a bug may cause a NULL pointer dereference... + apply(getattr(c, method), args) + except db.DBError, val: + assert val[0] == 0 + if verbose: print val + else: + self.fail("no exception raised when using a buggy cursor's %s method" % method) + + #---------------------------------------- + + def test04_PartialGetAndPut(self): + d = self.d + if verbose: + print '\n', '-=' * 30 + print "Running %s.test04_PartialGetAndPut..." % self.__class__.__name__ + + key = "partialTest" + data = "1" * 1000 + "2" * 1000 + d.put(key, data) + assert d.get(key) == data + assert d.get(key, dlen=20, doff=990) == ("1" * 10) + ("2" * 10) + + d.put("partialtest2", ("1" * 30000) + "robin" ) + assert d.get("partialtest2", dlen=5, doff=30000) == "robin" + + # There seems to be a bug in DB here... Commented out the test for now. + ##assert d.get("partialtest2", dlen=5, doff=30010) == "" + + if self.dbsetflags != db.DB_DUP: + # Partial put with duplicate records requires a cursor + d.put(key, "0000", dlen=2000, doff=0) + assert d.get(key) == "0000" + + d.put(key, "1111", dlen=1, doff=2) + assert d.get(key) == "0011110" + + #---------------------------------------- + + def test05_GetSize(self): + d = self.d + if verbose: + print '\n', '-=' * 30 + print "Running %s.test05_GetSize..." % self.__class__.__name__ + + for i in range(1, 50000, 500): + key = "size%s" % i + #print "before ", i, + d.put(key, "1" * i) + #print "after", + assert d.get_size(key) == i + #print "done" + +#---------------------------------------------------------------------- + + +class BasicBTreeTestCase(BasicTestCase): + dbtype = db.DB_BTREE + + +class BasicHashTestCase(BasicTestCase): + dbtype = db.DB_HASH + + +class BasicBTreeWithThreadFlagTestCase(BasicTestCase): + dbtype = db.DB_BTREE + dbopenflags = db.DB_THREAD + + +class BasicHashWithThreadFlagTestCase(BasicTestCase): + dbtype = db.DB_HASH + dbopenflags = db.DB_THREAD + + +class BasicBTreeWithEnvTestCase(BasicTestCase): + dbtype = db.DB_BTREE + dbopenflags = db.DB_THREAD + useEnv = 1 + envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK + + +class BasicHashWithEnvTestCase(BasicTestCase): + dbtype = db.DB_HASH + dbopenflags = db.DB_THREAD + useEnv = 1 + envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK + + +#---------------------------------------------------------------------- + +class BasicTransactionTestCase(BasicTestCase): + dbopenflags = db.DB_THREAD + useEnv = 1 + envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK | db.DB_INIT_TXN + + + def tearDown(self): + self.txn.commit() + BasicTestCase.tearDown(self) + + + def populateDB(self): + d = self.d + txn = self.env.txn_begin() + for x in range(500): + key = '%04d' % (1000 - x) # insert keys in reverse order + data = self.makeData(key) + d.put(key, data, txn) + + for x in range(500): + key = '%04d' % x # and now some in forward order + data = self.makeData(key) + d.put(key, data, txn) + + txn.commit() + + num = len(d) + if verbose: + print "created %d records" % num + + self.txn = self.env.txn_begin() + + + + def test06_Transactions(self): + d = self.d + if verbose: + print '\n', '-=' * 30 + print "Running %s.test06_Transactions..." % self.__class__.__name__ + + assert d.get('new rec', txn=self.txn) == None + d.put('new rec', 'this is a new record', self.txn) + assert d.get('new rec', txn=self.txn) == 'this is a new record' + self.txn.abort() + assert d.get('new rec') == None + + self.txn = self.env.txn_begin() + + assert d.get('new rec', txn=self.txn) == None + d.put('new rec', 'this is a new record', self.txn) + assert d.get('new rec', txn=self.txn) == 'this is a new record' + self.txn.commit() + assert d.get('new rec') == 'this is a new record' + + self.txn = self.env.txn_begin() + c = d.cursor(self.txn) + rec = c.first() + count = 0 + while rec is not None: + count = count + 1 + if verbose and count % 100 == 0: + print rec + rec = c.next() + assert count == 1001 + + c.close() # Cursors *MUST* be closed before commit! + self.txn.commit() + + # flush pending updates + try: + self.env.txn_checkpoint (0, 0, 0) + except db.DBIncompleteError: + pass + + # must have at least one log file present: + logs = self.env.log_archive(db.DB_ARCH_ABS | db.DB_ARCH_LOG) + assert logs != None + for log in logs: + if verbose: + print 'log file: ' + log + + self.txn = self.env.txn_begin() + + + +class BTreeTransactionTestCase(BasicTransactionTestCase): + dbtype = db.DB_BTREE + +class HashTransactionTestCase(BasicTransactionTestCase): + dbtype = db.DB_HASH + + + +#---------------------------------------------------------------------- + +class BTreeRecnoTestCase(BasicTestCase): + dbtype = db.DB_BTREE + dbsetflags = db.DB_RECNUM + + def test07_RecnoInBTree(self): + d = self.d + if verbose: + print '\n', '-=' * 30 + print "Running %s.test07_RecnoInBTree..." % self.__class__.__name__ + + rec = d.get(200) + assert type(rec) == type(()) + assert len(rec) == 2 + if verbose: + print "Record #200 is ", rec + + c = d.cursor() + c.set('0200') + num = c.get_recno() + assert type(num) == type(1) + if verbose: + print "recno of d['0200'] is ", num + + rec = c.current() + assert c.set_recno(num) == rec + + c.close() + + + +class BTreeRecnoWithThreadFlagTestCase(BTreeRecnoTestCase): + dbopenflags = db.DB_THREAD + +#---------------------------------------------------------------------- + +class BasicDUPTestCase(BasicTestCase): + dbsetflags = db.DB_DUP + + def test08_DuplicateKeys(self): + d = self.d + if verbose: + print '\n', '-=' * 30 + print "Running %s.test08_DuplicateKeys..." % self.__class__.__name__ + + d.put("dup0", "before") + for x in string.split("The quick brown fox jumped over the lazy dog."): + d.put("dup1", x) + d.put("dup2", "after") + + data = d.get("dup1") + assert data == "The" + if verbose: + print data + + c = d.cursor() + rec = c.set("dup1") + assert rec == ('dup1', 'The') + + next = c.next() + assert next == ('dup1', 'quick') + + rec = c.set("dup1") + count = c.count() + assert count == 9 + + next_dup = c.next_dup() + assert next_dup == ('dup1', 'quick') + + rec = c.set('dup1') + while rec is not None: + if verbose: + print rec + rec = c.next_dup() + + c.set('dup1') + rec = c.next_nodup() + assert rec[0] != 'dup1' + if verbose: + print rec + + c.close() + + + +class BTreeDUPTestCase(BasicDUPTestCase): + dbtype = db.DB_BTREE + +class HashDUPTestCase(BasicDUPTestCase): + dbtype = db.DB_HASH + +class BTreeDUPWithThreadTestCase(BasicDUPTestCase): + dbtype = db.DB_BTREE + dbopenflags = db.DB_THREAD + +class HashDUPWithThreadTestCase(BasicDUPTestCase): + dbtype = db.DB_HASH + dbopenflags = db.DB_THREAD + + +#---------------------------------------------------------------------- + +class BasicMultiDBTestCase(BasicTestCase): + dbname = 'first' + + def otherType(self): + if self.dbtype == db.DB_BTREE: + return db.DB_HASH + else: + return db.DB_BTREE + + def test09_MultiDB(self): + d1 = self.d + if verbose: + print '\n', '-=' * 30 + print "Running %s.test09_MultiDB..." % self.__class__.__name__ + + d2 = db.DB(self.env) + d2.open(self.filename, "second", self.dbtype, self.dbopenflags|db.DB_CREATE) + d3 = db.DB(self.env) + d3.open(self.filename, "third", self.otherType(), self.dbopenflags|db.DB_CREATE) + + for x in string.split("The quick brown fox jumped over the lazy dog"): + d2.put(x, self.makeData(x)) + + for x in string.letters: + d3.put(x, x*70) + + d1.sync() + d2.sync() + d3.sync() + d1.close() + d2.close() + d3.close() + + self.d = d1 = d2 = d3 = None + + self.d = d1 = db.DB(self.env) + d1.open(self.filename, self.dbname, flags = self.dbopenflags) + d2 = db.DB(self.env) + d2.open(self.filename, "second", flags = self.dbopenflags) + d3 = db.DB(self.env) + d3.open(self.filename, "third", flags = self.dbopenflags) + + c1 = d1.cursor() + c2 = d2.cursor() + c3 = d3.cursor() + + count = 0 + rec = c1.first() + while rec is not None: + count = count + 1 + if verbose and (count % 50) == 0: + print rec + rec = c1.next() + assert count == 1000 + + count = 0 + rec = c2.first() + while rec is not None: + count = count + 1 + if verbose: + print rec + rec = c2.next() + assert count == 9 + + count = 0 + rec = c3.first() + while rec is not None: + count = count + 1 + if verbose: + print rec + rec = c3.next() + assert count == 52 + + + c1.close() + c2.close() + c3.close() + + d2.close() + d3.close() + + + +# Strange things happen if you try to use Multiple DBs per file without a +# DBEnv with MPOOL and LOCKing... + +class BTreeMultiDBTestCase(BasicMultiDBTestCase): + dbtype = db.DB_BTREE + dbopenflags = db.DB_THREAD + useEnv = 1 + envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK + +class HashMultiDBTestCase(BasicMultiDBTestCase): + dbtype = db.DB_HASH + dbopenflags = db.DB_THREAD + useEnv = 1 + envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK + + +#---------------------------------------------------------------------- +#---------------------------------------------------------------------- + +def suite(): + theSuite = unittest.TestSuite() + + theSuite.addTest(unittest.makeSuite(VersionTestCase)) + theSuite.addTest(unittest.makeSuite(BasicBTreeTestCase)) + theSuite.addTest(unittest.makeSuite(BasicHashTestCase)) + theSuite.addTest(unittest.makeSuite(BasicBTreeWithThreadFlagTestCase)) + theSuite.addTest(unittest.makeSuite(BasicHashWithThreadFlagTestCase)) + theSuite.addTest(unittest.makeSuite(BasicBTreeWithEnvTestCase)) + theSuite.addTest(unittest.makeSuite(BasicHashWithEnvTestCase)) + theSuite.addTest(unittest.makeSuite(BTreeTransactionTestCase)) + theSuite.addTest(unittest.makeSuite(HashTransactionTestCase)) + theSuite.addTest(unittest.makeSuite(BTreeRecnoTestCase)) + theSuite.addTest(unittest.makeSuite(BTreeRecnoWithThreadFlagTestCase)) + theSuite.addTest(unittest.makeSuite(BTreeDUPTestCase)) + theSuite.addTest(unittest.makeSuite(HashDUPTestCase)) + theSuite.addTest(unittest.makeSuite(BTreeDUPWithThreadTestCase)) + theSuite.addTest(unittest.makeSuite(HashDUPWithThreadTestCase)) + theSuite.addTest(unittest.makeSuite(BTreeMultiDBTestCase)) + theSuite.addTest(unittest.makeSuite(HashMultiDBTestCase)) + + return theSuite + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) + diff --git a/python/test/test_compat.py b/python/test/test_compat.py new file mode 100644 index 000000000..35f281bd9 --- /dev/null +++ b/python/test/test_compat.py @@ -0,0 +1,169 @@ +""" +Test cases adapted from the test_bsddb.py module in Python's +regression test suite. +""" + +import sys, os, string +from rpmdb import hashopen, btopen, rnopen +import rpmdb +import unittest +import tempfile + +from test_all import verbose + + + +class CompatibilityTestCase(unittest.TestCase): + def setUp(self): + self.filename = tempfile.mktemp() + + def tearDown(self): + try: + os.remove(self.filename) + except os.error: + pass + + + def test01_btopen(self): + self.do_bthash_test(btopen, 'btopen') + + def test02_hashopen(self): + self.do_bthash_test(hashopen, 'hashopen') + + def test03_rnopen(self): + data = string.split("The quick brown fox jumped over the lazy dog.") + if verbose: + print "\nTesting: rnopen" + + f = rnopen(self.filename, 'c') + for x in range(len(data)): + f[x+1] = data[x] + + getTest = (f[1], f[2], f[3]) + if verbose: + print '%s %s %s' % getTest + + assert getTest[1] == 'quick', 'data mismatch!' + + f[25] = 'twenty-five' + f.close() + del f + + f = rnopen(self.filename, 'w') + f[20] = 'twenty' + + def noRec(f): + rec = f[15] + self.assertRaises(KeyError, noRec, f) + + def badKey(f): + rec = f['a string'] + self.assertRaises(TypeError, badKey, f) + + del f[3] + + rec = f.first() + while rec: + if verbose: + print rec + try: + rec = f.next() + except KeyError: + break + + f.close() + + + def test04_n_flag(self): + f = hashopen(self.filename, 'n') + f.close() + + + + def do_bthash_test(self, factory, what): + if verbose: + print '\nTesting: ', what + + f = factory(self.filename, 'c') + if verbose: + print 'creation...' + + # truth test + if f: + if verbose: print "truth test: true" + else: + if verbose: print "truth test: false" + + f['0'] = '' + f['a'] = 'Guido' + f['b'] = 'van' + f['c'] = 'Rossum' + f['d'] = 'invented' + f['f'] = 'Python' + if verbose: + print '%s %s %s' % (f['a'], f['b'], f['c']) + + if verbose: + print 'key ordering...' + f.set_location(f.first()[0]) + while 1: + try: + rec = f.next() + except KeyError: + assert rec == f.last(), 'Error, last <> last!' + f.previous() + break + if verbose: + print rec + + assert f.has_key('f'), 'Error, missing key!' + + f.sync() + f.close() + # truth test + try: + if f: + if verbose: print "truth test: true" + else: + if verbose: print "truth test: false" + except rpmdb.error: + pass + else: + self.fail("Exception expected") + + del f + + if verbose: + print 'modification...' + f = factory(self.filename, 'w') + f['d'] = 'discovered' + + if verbose: + print 'access...' + for key in f.keys(): + word = f[key] + if verbose: + print word + + def noRec(f): + rec = f['no such key'] + self.assertRaises(KeyError, noRec, f) + + def badKey(f): + rec = f[15] + self.assertRaises(TypeError, badKey, f) + + f.close() + + +#---------------------------------------------------------------------- + + +def suite(): + return unittest.makeSuite(CompatibilityTestCase) + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) + + diff --git a/python/test/test_dbobj.py b/python/test/test_dbobj.py new file mode 100644 index 000000000..2bd77847a --- /dev/null +++ b/python/test/test_dbobj.py @@ -0,0 +1,72 @@ + +import sys, os, string +import unittest +import glob + +from rpmdb import db, dbobj + + +#---------------------------------------------------------------------- + +class dbobjTestCase(unittest.TestCase): + """Verify that dbobj.DB and dbobj.DBEnv work properly""" + db_home = 'db_home' + db_name = 'test-dbobj.db' + + def setUp(self): + homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + self.homeDir = homeDir + try: os.mkdir(homeDir) + except os.error: pass + + def tearDown(self): + if hasattr(self, 'db'): + del self.db + if hasattr(self, 'env'): + del self.env + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + def test01_both(self): + class TestDBEnv(dbobj.DBEnv): pass + class TestDB(dbobj.DB): + def put(self, key, *args, **kwargs): + key = string.upper(key) + # call our parent classes put method with an upper case key + return apply(dbobj.DB.put, (self, key) + args, kwargs) + self.env = TestDBEnv() + self.env.open(self.db_home, db.DB_CREATE | db.DB_INIT_MPOOL) + self.db = TestDB(self.env) + self.db.open(self.db_name, db.DB_HASH, db.DB_CREATE) + self.db.put('spam', 'eggs') + assert self.db.get('spam') == None, "overridden dbobj.DB.put() method failed [1]" + assert self.db.get('SPAM') == 'eggs', "overridden dbobj.DB.put() method failed [2]" + self.db.close() + self.env.close() + + def test02_dbobj_dict_interface(self): + self.env = dbobj.DBEnv() + self.env.open(self.db_home, db.DB_CREATE | db.DB_INIT_MPOOL) + self.db = dbobj.DB(self.env) + self.db.open(self.db_name+'02', db.DB_HASH, db.DB_CREATE) + # __setitem__ + self.db['spam'] = 'eggs' + # __len__ + assert len(self.db) == 1 + # __getitem__ + assert self.db['spam'] == 'eggs' + # __del__ + del self.db['spam'] + assert self.db.get('spam') == None, "dbobj __del__ failed" + self.db.close() + self.env.close() + +#---------------------------------------------------------------------- + +def suite(): + return unittest.makeSuite(dbobjTestCase) + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) + diff --git a/python/test/test_dbshelve.py b/python/test/test_dbshelve.py new file mode 100644 index 000000000..d5b974ac1 --- /dev/null +++ b/python/test/test_dbshelve.py @@ -0,0 +1,305 @@ +""" +TestCases for checking dbShelve objects. +""" + +import sys, os, string +import tempfile, random +from pprint import pprint +from types import * +import unittest + +from rpmdb import dbshelve, db + +from test_all import verbose + + +#---------------------------------------------------------------------- + +# We want the objects to be comparable so we can test dbshelve.values +# later on. +class DataClass: + def __init__(self): + self.value = random.random() + + def __cmp__(self, other): + return cmp(self.value, other) + +class DBShelveTestCase(unittest.TestCase): + def setUp(self): + self.filename = tempfile.mktemp() + self.do_open() + + def tearDown(self): + self.do_close() + try: + os.remove(self.filename) + except os.error: + pass + + def populateDB(self, d): + for x in string.letters: + d['S' + x] = 10 * x # add a string + d['I' + x] = ord(x) # add an integer + d['L' + x] = [x] * 10 # add a list + + inst = DataClass() # add an instance + inst.S = 10 * x + inst.I = ord(x) + inst.L = [x] * 10 + d['O' + x] = inst + + + # overridable in derived classes to affect how the shelf is created/opened + def do_open(self): + self.d = dbshelve.open(self.filename) + + # and closed... + def do_close(self): + self.d.close() + + + + def test01_basics(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test01_basics..." % self.__class__.__name__ + + self.populateDB(self.d) + self.d.sync() + self.do_close() + self.do_open() + d = self.d + + l = len(d) + k = d.keys() + s = d.stat() + f = d.fd() + + if verbose: + print "length:", l + print "keys:", k + print "stats:", s + + assert 0 == d.has_key('bad key') + assert 1 == d.has_key('IA') + assert 1 == d.has_key('OA') + + d.delete('IA') + del d['OA'] + assert 0 == d.has_key('IA') + assert 0 == d.has_key('OA') + assert len(d) == l-2 + + values = [] + for key in d.keys(): + value = d[key] + values.append(value) + if verbose: + print "%s: %s" % (key, value) + self.checkrec(key, value) + + dbvalues = d.values() + assert len(dbvalues) == len(d.keys()) + values.sort() + dbvalues.sort() + assert values == dbvalues + + items = d.items() + assert len(items) == len(values) + + for key, value in items: + self.checkrec(key, value) + + assert d.get('bad key') == None + assert d.get('bad key', None) == None + assert d.get('bad key', 'a string') == 'a string' + assert d.get('bad key', [1, 2, 3]) == [1, 2, 3] + + d.set_get_returns_none(0) + self.assertRaises(db.DBNotFoundError, d.get, 'bad key') + d.set_get_returns_none(1) + + d.put('new key', 'new data') + assert d.get('new key') == 'new data' + assert d['new key'] == 'new data' + + + + def test02_cursors(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test02_cursors..." % self.__class__.__name__ + + self.populateDB(self.d) + d = self.d + + count = 0 + c = d.cursor() + rec = c.first() + while rec is not None: + count = count + 1 + if verbose: + print rec + key, value = rec + self.checkrec(key, value) + rec = c.next() + + assert count == len(d) + + count = 0 + c = d.cursor() + rec = c.last() + while rec is not None: + count = count + 1 + if verbose: + print rec + key, value = rec + self.checkrec(key, value) + rec = c.prev() + + assert count == len(d) + + c.set('SS') + key, value = c.current() + self.checkrec(key, value) + + c.close() + + + + + def checkrec(self, key, value): + x = key[1] + if key[0] == 'S': + assert type(value) == StringType + assert value == 10 * x + + elif key[0] == 'I': + assert type(value) == IntType + assert value == ord(x) + + elif key[0] == 'L': + assert type(value) == ListType + assert value == [x] * 10 + + elif key[0] == 'O': + assert type(value) == InstanceType + assert value.S == 10 * x + assert value.I == ord(x) + assert value.L == [x] * 10 + + else: + raise AssertionError, 'Unknown key type, fix the test' + +#---------------------------------------------------------------------- + +class BasicShelveTestCase(DBShelveTestCase): + def do_open(self): + self.d = dbshelve.DBShelf() + self.d.open(self.filename, self.dbtype, self.dbflags) + + def do_close(self): + self.d.close() + + + + +class BTreeShelveTestCase(BasicShelveTestCase): + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE + + +class HashShelveTestCase(BasicShelveTestCase): + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE + + +class ThreadBTreeShelveTestCase(BasicShelveTestCase): + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE | db.DB_THREAD + + +class ThreadHashShelveTestCase(BasicShelveTestCase): + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE | db.DB_THREAD + + +#---------------------------------------------------------------------- + +class BasicEnvShelveTestCase(DBShelveTestCase): + def do_open(self): + self.homeDir = homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + try: os.mkdir(homeDir) + except os.error: pass + self.env = db.DBEnv() + self.env.open(homeDir, self.envflags | db.DB_INIT_MPOOL | db.DB_CREATE) + + self.filename = os.path.split(self.filename)[1] + self.d = dbshelve.DBShelf(self.env) + self.d.open(self.filename, self.dbtype, self.dbflags) + + + def do_close(self): + self.d.close() + self.env.close() + + + def tearDown(self): + self.do_close() + import glob + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + + +class EnvBTreeShelveTestCase(BasicEnvShelveTestCase): + envflags = 0 + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE + + +class EnvHashShelveTestCase(BasicEnvShelveTestCase): + envflags = 0 + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE + + +class EnvThreadBTreeShelveTestCase(BasicEnvShelveTestCase): + envflags = db.DB_THREAD + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE | db.DB_THREAD + + +class EnvThreadHashShelveTestCase(BasicEnvShelveTestCase): + envflags = db.DB_THREAD + dbtype = db.DB_BTREE + dbflags = db.DB_CREATE | db.DB_THREAD + + +#---------------------------------------------------------------------- +# TODO: Add test cases for a DBShelf in a RECNO DB. + + +#---------------------------------------------------------------------- + +def suite(): + theSuite = unittest.TestSuite() + + theSuite.addTest(unittest.makeSuite(DBShelveTestCase)) + theSuite.addTest(unittest.makeSuite(BTreeShelveTestCase)) + theSuite.addTest(unittest.makeSuite(HashShelveTestCase)) + theSuite.addTest(unittest.makeSuite(ThreadBTreeShelveTestCase)) + theSuite.addTest(unittest.makeSuite(ThreadHashShelveTestCase)) + theSuite.addTest(unittest.makeSuite(EnvBTreeShelveTestCase)) + theSuite.addTest(unittest.makeSuite(EnvHashShelveTestCase)) + theSuite.addTest(unittest.makeSuite(EnvThreadBTreeShelveTestCase)) + theSuite.addTest(unittest.makeSuite(EnvThreadHashShelveTestCase)) + + return theSuite + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) + + diff --git a/python/test/test_dbtables.py b/python/test/test_dbtables.py new file mode 100644 index 000000000..75fc2add7 --- /dev/null +++ b/python/test/test_dbtables.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python +# +#----------------------------------------------------------------------- +# A test suite for the table interface built on rpmdb.db +#----------------------------------------------------------------------- +# +# Copyright (C) 2000, 2001 by Autonomous Zone Industries +# +# March 20, 2000 +# +# License: This is free software. You may use this software for any +# purpose including modification/redistribution, so long as +# this header remains intact and that you do not claim any +# rights of ownership or authorship of this software. This +# software has been tested, but no warranty is expressed or +# implied. +# +# -- Gregory P. Smith <greg@electricrain.com> +# +# Id: test_dbtables.py,v 1.6 2001/05/14 20:48:18 greg Exp + +import sys, os, re +try: + import cPickle + pickle = cPickle +except ImportError: + import pickle + +import unittest +from test_all import verbose + +from rpmdb import db, dbtables + + + +#---------------------------------------------------------------------- + +class TableDBTestCase(unittest.TestCase): + db_home = 'db_home' + db_name = 'test-table.db' + + def setUp(self): + homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + self.homeDir = homeDir + try: os.mkdir(homeDir) + except os.error: pass + self.tdb = dbtables.bsdTableDB(filename='tabletest.db', dbhome='db_home', create=1) + + def tearDown(self): + self.tdb.close() + import glob + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + def test01(self): + tabname = "test01" + colname = 'cool numbers' + try: + self.tdb.Drop(tabname) + except dbtables.TableDBError: + pass + self.tdb.CreateTable(tabname, [colname]) + self.tdb.Insert(tabname, {colname: pickle.dumps(3.14159, 1)}) + + if verbose: + self.tdb._db_print() + + values = self.tdb.Select(tabname, [colname], conditions={colname: None}) + + colval = pickle.loads(values[0][colname]) + assert(colval > 3.141 and colval < 3.142) + + + def test02(self): + tabname = "test02" + col0 = 'coolness factor' + col1 = 'but can it fly?' + col2 = 'Species' + testinfo = [ + {col0: pickle.dumps(8, 1), col1: 'no', col2: 'Penguin'}, + {col0: pickle.dumps(-1, 1), col1: 'no', col2: 'Turkey'}, + {col0: pickle.dumps(9, 1), col1: 'yes', col2: 'SR-71A Blackbird'} + ] + + try: + self.tdb.Drop(tabname) + except dbtables.TableDBError: + pass + self.tdb.CreateTable(tabname, [col0, col1, col2]) + for row in testinfo : + self.tdb.Insert(tabname, row) + + values = self.tdb.Select(tabname, [col2], + conditions={col0: lambda x: pickle.loads(x) >= 8}) + + assert len(values) == 2 + if values[0]['Species'] == 'Penguin' : + assert values[1]['Species'] == 'SR-71A Blackbird' + elif values[0]['Species'] == 'SR-71A Blackbird' : + assert values[1]['Species'] == 'Penguin' + else : + if verbose: + print "values=", `values` + raise "Wrong values returned!" + + def test03(self): + tabname = "test03" + try: + self.tdb.Drop(tabname) + except dbtables.TableDBError: + pass + if verbose: + print '...before CreateTable...' + self.tdb._db_print() + self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e']) + if verbose: + print '...after CreateTable...' + self.tdb._db_print() + self.tdb.Drop(tabname) + if verbose: + print '...after Drop...' + self.tdb._db_print() + self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e']) + + try: + self.tdb.Insert(tabname, {'a': "", 'e': pickle.dumps([{4:5, 6:7}, 'foo'], 1), 'f': "Zero"}) + assert 0 + except dbtables.TableDBError: + pass + + try: + self.tdb.Select(tabname, [], conditions={'foo': '123'}) + assert 0 + except dbtables.TableDBError: + pass + + self.tdb.Insert(tabname, {'a': '42', 'b': "bad", 'c': "meep", 'e': 'Fuzzy wuzzy was a bear'}) + self.tdb.Insert(tabname, {'a': '581750', 'b': "good", 'd': "bla", 'c': "black", 'e': 'fuzzy was here'}) + self.tdb.Insert(tabname, {'a': '800000', 'b': "good", 'd': "bla", 'c': "black", 'e': 'Fuzzy wuzzy is a bear'}) + + if verbose: + self.tdb._db_print() + + # this should return two rows + values = self.tdb.Select(tabname, ['b', 'a', 'd'], + conditions={'e': re.compile('wuzzy').search, 'a': re.compile('^[0-9]+$').match}) + assert len(values) == 2 + + # now lets delete one of them and try again + self.tdb.Delete(tabname, conditions={'b': dbtables.ExactCond('good')}) + values = self.tdb.Select(tabname, ['a', 'd', 'b'], conditions={'e': dbtables.PrefixCond('Fuzzy')}) + assert len(values) == 1 + assert values[0]['d'] == None + + values = self.tdb.Select(tabname, ['b'], + conditions={'c': lambda c: c == 'meep'}) + assert len(values) == 1 + assert values[0]['b'] == "bad" + + + def test_CreateOrExtend(self): + tabname = "test_CreateOrExtend" + + self.tdb.CreateOrExtendTable(tabname, ['name', 'taste', 'filling', 'alcohol content', 'price']) + try: + self.tdb.Insert(tabname, {'taste': 'crap', 'filling': 'no', 'is it Guinness?': 'no'}) + assert 0, "Insert should've failed due to bad column name" + except: + pass + self.tdb.CreateOrExtendTable(tabname, ['name', 'taste', 'is it Guinness?']) + + # these should both succeed as the table should contain the union of both sets of columns. + self.tdb.Insert(tabname, {'taste': 'crap', 'filling': 'no', 'is it Guinness?': 'no'}) + self.tdb.Insert(tabname, {'taste': 'great', 'filling': 'yes', 'is it Guinness?': 'yes', 'name': 'Guinness'}) + + + def test_CondObjs(self): + tabname = "test_CondObjs" + + self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e', 'p']) + + self.tdb.Insert(tabname, {'a': "the letter A", 'b': "the letter B", 'c': "is for cookie"}) + self.tdb.Insert(tabname, {'a': "is for aardvark", 'e': "the letter E", 'c': "is for cookie", 'd': "is for dog"}) + self.tdb.Insert(tabname, {'a': "the letter A", 'e': "the letter E", 'c': "is for cookie", 'p': "is for Python"}) + + values = self.tdb.Select(tabname, ['p', 'e'], conditions={'e': dbtables.PrefixCond('the l')}) + assert len(values) == 2, values + assert values[0]['e'] == values[1]['e'], values + assert values[0]['p'] != values[1]['p'], values + + values = self.tdb.Select(tabname, ['d', 'a'], conditions={'a': dbtables.LikeCond('%aardvark%')}) + assert len(values) == 1, values + assert values[0]['d'] == "is for dog", values + assert values[0]['a'] == "is for aardvark", values + + values = self.tdb.Select(tabname, None, {'b': dbtables.Cond(), 'e':dbtables.LikeCond('%letter%'), 'a':dbtables.PrefixCond('is'), 'd':dbtables.ExactCond('is for dog'), 'c':dbtables.PrefixCond('is for'), 'p':lambda s: not s}) + assert len(values) == 1, values + assert values[0]['d'] == "is for dog", values + assert values[0]['a'] == "is for aardvark", values + + def test_Delete(self): + tabname = "test_Delete" + self.tdb.CreateTable(tabname, ['x', 'y', 'z']) + + # prior to 2001-05-09 there was a bug where Delete() would + # fail if it encountered any rows that did not have values in + # every column. + # Hunted and Squashed by <Donwulff> (Jukka Santala - donwulff@nic.fi) + self.tdb.Insert(tabname, {'x': 'X1', 'y':'Y1'}) + self.tdb.Insert(tabname, {'x': 'X2', 'y':'Y2', 'z': 'Z2'}) + + self.tdb.Delete(tabname, conditions={'x': dbtables.PrefixCond('X')}) + values = self.tdb.Select(tabname, ['y'], conditions={'x': dbtables.PrefixCond('X')}) + assert len(values) == 0 + + def test_Modify(self): + tabname = "test_Modify" + self.tdb.CreateTable(tabname, ['Name', 'Type', 'Access']) + + self.tdb.Insert(tabname, {'Name': 'Index to MP3 files.doc', 'Type': 'Word', 'Access': '8'}) + self.tdb.Insert(tabname, {'Name': 'Nifty.MP3', 'Access': '1'}) + self.tdb.Insert(tabname, {'Type': 'Unknown', 'Access': '0'}) + + def set_type(type): + if type == None: + return 'MP3' + return type + + def increment_access(count): + return str(int(count)+1) + + def remove_value(value): + return None + + self.tdb.Modify(tabname, conditions={'Access': dbtables.ExactCond('0')}, mappings={'Access': remove_value}) + self.tdb.Modify(tabname, conditions={'Name': dbtables.LikeCond('%MP3%')}, mappings={'Type': set_type}) + self.tdb.Modify(tabname, conditions={'Name': dbtables.LikeCond('%')}, mappings={'Access': increment_access}) + + # Delete key in select conditions + values = self.tdb.Select(tabname, None, conditions={'Type': dbtables.ExactCond('Unknown')}) + assert len(values) == 1, values + assert values[0]['Name'] == None, values + assert values[0]['Access'] == None, values + + # Modify value by select conditions + values = self.tdb.Select(tabname, None, conditions={'Name': dbtables.ExactCond('Nifty.MP3')}) + assert len(values) == 1, values + assert values[0]['Type'] == "MP3", values + assert values[0]['Access'] == "2", values + + # Make sure change applied only to select conditions + values = self.tdb.Select(tabname, None, conditions={'Name': dbtables.LikeCond('%doc%')}) + assert len(values) == 1, values + assert values[0]['Type'] == "Word", values + assert values[0]['Access'] == "9", values + +def suite(): + theSuite = unittest.TestSuite() + theSuite.addTest(unittest.makeSuite(TableDBTestCase)) + return theSuite + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) diff --git a/python/test/test_get_none.py b/python/test/test_get_none.py new file mode 100644 index 000000000..c39cc51e4 --- /dev/null +++ b/python/test/test_get_none.py @@ -0,0 +1,98 @@ +""" +TestCases for checking set_get_returns_none. +""" + +import sys, os, string +import tempfile +from pprint import pprint +import unittest + +from rpmdb import db + +from test_all import verbose + + +#---------------------------------------------------------------------- + +class GetReturnsNoneTestCase(unittest.TestCase): + def setUp(self): + self.filename = tempfile.mktemp() + + def tearDown(self): + try: + os.remove(self.filename) + except os.error: + pass + + + def test01_get_returns_none(self): + d = db.DB() + d.open(self.filename, db.DB_BTREE, db.DB_CREATE) + d.set_get_returns_none(1) + + for x in string.letters: + d.put(x, x * 40) + + data = d.get('bad key') + assert data == None + + data = d.get('a') + assert data == 'a'*40 + + count = 0 + c = d.cursor() + rec = c.first() + while rec: + count = count + 1 + rec = c.next() + + assert rec == None + assert count == 52 + + c.close() + d.close() + + + def test02_get_raises_exception(self): + d = db.DB() + d.open(self.filename, db.DB_BTREE, db.DB_CREATE) + d.set_get_returns_none(0) + + for x in string.letters: + d.put(x, x * 40) + + self.assertRaises(db.DBNotFoundError, d.get, 'bad key') + self.assertRaises(KeyError, d.get, 'bad key') + + data = d.get('a') + assert data == 'a'*40 + + count = 0 + exceptionHappened = 0 + c = d.cursor() + rec = c.first() + while rec: + count = count + 1 + try: + rec = c.next() + except db.DBNotFoundError: # end of the records + exceptionHappened = 1 + break + + assert rec != None + assert exceptionHappened + assert count == 52 + + c.close() + d.close() + +#---------------------------------------------------------------------- + +def suite(): + return unittest.makeSuite(GetReturnsNoneTestCase) + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) + + diff --git a/python/test/test_join.py b/python/test/test_join.py new file mode 100644 index 000000000..a2f172e31 --- /dev/null +++ b/python/test/test_join.py @@ -0,0 +1,14 @@ +""" +TestCases for using the DB.join and DBCursor.join_item methods. +""" + +import sys, os, string +import tempfile +from pprint import pprint +import unittest + +from rpmdb import db + +from test_all import verbose + + diff --git a/python/test/test_lock.py b/python/test/test_lock.py new file mode 100644 index 000000000..f78ad41a8 --- /dev/null +++ b/python/test/test_lock.py @@ -0,0 +1,124 @@ +""" +TestCases for testing the locking sub-system. +""" + +import sys, os, string +import tempfile +import time +from pprint import pprint +from whrandom import random + +try: + from threading import Thread, currentThread + have_threads = 1 +except ImportError: + have_threads = 0 + + +import unittest +from test_all import verbose + +from rpmdb import db + + +#---------------------------------------------------------------------- + +class LockingTestCase(unittest.TestCase): + + def setUp(self): + homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + self.homeDir = homeDir + try: os.mkdir(homeDir) + except os.error: pass + self.env = db.DBEnv() + self.env.open(homeDir, db.DB_THREAD | db.DB_INIT_MPOOL | + db.DB_INIT_LOCK | db.DB_CREATE) + + + def tearDown(self): + self.env.close() + import glob + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + + def test01_simple(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test01_simple..." % self.__class__.__name__ + + anID = self.env.lock_id() + if verbose: + print "locker ID: %s" % anID + lock = self.env.lock_get(anID, "some locked thing", db.DB_LOCK_WRITE) + if verbose: + print "Aquired lock: %s" % lock + time.sleep(1) + self.env.lock_put(lock) + if verbose: + print "Released lock: %s" % lock + + + + + def test02_threaded(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test02_threaded..." % self.__class__.__name__ + + threads = [] + threads.append(Thread(target = self.theThread, args=(5, db.DB_LOCK_WRITE))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE))) + threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE))) + + for t in threads: + t.start() + for t in threads: + t.join() + + + + def theThread(self, sleepTime, lockType): + name = currentThread().getName() + if lockType == db.DB_LOCK_WRITE: + lt = "write" + else: + lt = "read" + + anID = self.env.lock_id() + if verbose: + print "%s: locker ID: %s" % (name, anID) + + lock = self.env.lock_get(anID, "some locked thing", lockType) + if verbose: + print "%s: Aquired %s lock: %s" % (name, lt, lock) + + time.sleep(sleepTime) + + self.env.lock_put(lock) + if verbose: + print "%s: Released %s lock: %s" % (name, lt, lock) + + +#---------------------------------------------------------------------- + +def suite(): + theSuite = unittest.TestSuite() + + if have_threads: + theSuite.addTest(unittest.makeSuite(LockingTestCase)) + else: + theSuite.addTest(unittest.makeSuite(LockingTestCase, 'test01')) + + return theSuite + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) diff --git a/python/test/test_misc.py b/python/test/test_misc.py new file mode 100644 index 000000000..07e2e8495 --- /dev/null +++ b/python/test/test_misc.py @@ -0,0 +1,56 @@ +""" +Misc TestCases +""" + +import sys, os, string +import tempfile +from pprint import pprint +import unittest + +from rpmdb import db +from rpmdb import dbshelve + +from test_all import verbose + +#---------------------------------------------------------------------- + +class MiscTestCase(unittest.TestCase): + def setUp(self): + self.filename = self.__class__.__name__ + '.db' + homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + self.homeDir = homeDir + try: os.mkdir(homeDir) + except os.error: pass + + def tearDown(self): + try: os.remove(self.filename) + except os.error: pass + import glob + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + + + def test01_badpointer(self): + dbs = dbshelve.open(self.filename) + dbs.close() + self.assertRaises(db.DBError, dbs.get, "foo") + + + def test02_db_home(self): + env = db.DBEnv() + # check for crash fixed when db_home is used before open() + assert env.db_home is None + env.open(self.homeDir, db.DB_CREATE) + assert self.homeDir == env.db_home + +#---------------------------------------------------------------------- + + +def suite(): + return unittest.makeSuite(MiscTestCase) + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) diff --git a/python/test/test_queue.py b/python/test/test_queue.py new file mode 100644 index 000000000..6a55845e5 --- /dev/null +++ b/python/test/test_queue.py @@ -0,0 +1,168 @@ +""" +TestCases for exercising a Queue DB. +""" + +import sys, os, string +import tempfile +from pprint import pprint +import unittest + +from rpmdb import db + +from test_all import verbose + + +#---------------------------------------------------------------------- + +class SimpleQueueTestCase(unittest.TestCase): + def setUp(self): + self.filename = tempfile.mktemp() + + def tearDown(self): + try: + os.remove(self.filename) + except os.error: + pass + + + def test01_basic(self): + # Basic Queue tests using the deprecated DBCursor.consume method. + + if verbose: + print '\n', '-=' * 30 + print "Running %s.test01_basic..." % self.__class__.__name__ + + d = db.DB() + d.set_re_len(40) # Queues must be fixed length + d.open(self.filename, db.DB_QUEUE, db.DB_CREATE) + + if verbose: + print "before appends" + '-' * 30 + pprint(d.stat()) + + for x in string.letters: + d.append(x * 40) + + assert len(d) == 52 + + d.put(100, "some more data") + d.put(101, "and some more ") + d.put(75, "out of order") + d.put(1, "replacement data") + + assert len(d) == 55 + + if verbose: + print "before close" + '-' * 30 + pprint(d.stat()) + + d.close() + del d + d = db.DB() + d.open(self.filename) + + if verbose: + print "after open" + '-' * 30 + pprint(d.stat()) + + d.append("one more") + c = d.cursor() + + if verbose: + print "after append" + '-' * 30 + pprint(d.stat()) + + rec = c.consume() + while rec: + if verbose: + print rec + rec = c.consume() + c.close() + + if verbose: + print "after consume loop" + '-' * 30 + pprint(d.stat()) + + assert len(d) == 0, \ + "if you see this message then you need to rebuild BerkeleyDB 3.1.17 "\ + "with the patch in patches/qam_stat.diff" + + d.close() + + + + def test02_basicPost32(self): + # Basic Queue tests using the new DB.consume method in DB 3.2+ + # (No cursor needed) + + if verbose: + print '\n', '-=' * 30 + print "Running %s.test02_basicPost32..." % self.__class__.__name__ + + if db.version() < (3, 2, 0): + if verbose: + print "Test not run, DB not new enough..." + return + + d = db.DB() + d.set_re_len(40) # Queues must be fixed length + d.open(self.filename, db.DB_QUEUE, db.DB_CREATE) + + if verbose: + print "before appends" + '-' * 30 + pprint(d.stat()) + + for x in string.letters: + d.append(x * 40) + + assert len(d) == 52 + + d.put(100, "some more data") + d.put(101, "and some more ") + d.put(75, "out of order") + d.put(1, "replacement data") + + assert len(d) == 55 + + if verbose: + print "before close" + '-' * 30 + pprint(d.stat()) + + d.close() + del d + d = db.DB() + d.open(self.filename) + #d.set_get_returns_none(true) + + if verbose: + print "after open" + '-' * 30 + pprint(d.stat()) + + d.append("one more") + + if verbose: + print "after append" + '-' * 30 + pprint(d.stat()) + + rec = d.consume() + while rec: + if verbose: + print rec + rec = d.consume() + + if verbose: + print "after consume loop" + '-' * 30 + pprint(d.stat()) + + d.close() + + + +#---------------------------------------------------------------------- + +def suite(): + return unittest.makeSuite(SimpleQueueTestCase) + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) diff --git a/python/test/test_recno.py b/python/test/test_recno.py new file mode 100644 index 000000000..ccd2d753a --- /dev/null +++ b/python/test/test_recno.py @@ -0,0 +1,258 @@ +""" +TestCases for exercising a Recno DB. +""" + +import sys, os, string +import tempfile +from pprint import pprint +import unittest + +from rpmdb import db + +from test_all import verbose + +#---------------------------------------------------------------------- + +class SimpleRecnoTestCase(unittest.TestCase): + def setUp(self): + self.filename = tempfile.mktemp() + + def tearDown(self): + try: + os.remove(self.filename) + except os.error: + pass + + + + def test01_basic(self): + d = db.DB() + d.open(self.filename, db.DB_RECNO, db.DB_CREATE) + + for x in string.letters: + recno = d.append(x * 60) + assert type(recno) == type(0) + assert recno >= 1 + if verbose: + print recno, + + if verbose: print + + stat = d.stat() + if verbose: + pprint(stat) + + for recno in range(1, len(d)+1): + data = d[recno] + if verbose: + print data + + assert type(data) == type("") + assert data == d.get(recno) + + try: + data = d[0] # This should raise a KeyError!?!?! + except db.DBInvalidArgError, val: + assert val[0] == db.EINVAL + if verbose: print val + else: + self.fail("expected exception") + + try: + data = d[100] + except KeyError: + pass + else: + self.fail("expected exception") + + data = d.get(100) + assert data == None + + keys = d.keys() + if verbose: + print keys + assert type(keys) == type([]) + assert type(keys[0]) == type(123) + assert len(keys) == len(d) + + + items = d.items() + if verbose: + pprint(items) + assert type(items) == type([]) + assert type(items[0]) == type(()) + assert len(items[0]) == 2 + assert type(items[0][0]) == type(123) + assert type(items[0][1]) == type("") + assert len(items) == len(d) + + assert d.has_key(25) + + del d[25] + assert not d.has_key(25) + + d.delete(13) + assert not d.has_key(13) + + data = d.get_both(26, "z" * 60) + assert data == "z" * 60 + if verbose: + print data + + fd = d.fd() + if verbose: + print fd + + c = d.cursor() + rec = c.first() + while rec: + if verbose: + print rec + rec = c.next() + + c.set(50) + rec = c.current() + if verbose: + print rec + + c.put(-1, "a replacement record", db.DB_CURRENT) + + c.set(50) + rec = c.current() + assert rec == (50, "a replacement record") + if verbose: + print rec + + rec = c.set_range(30) + if verbose: + print rec + + c.close() + d.close() + + d = db.DB() + d.open(self.filename) + c = d.cursor() + + # put a record beyond the consecutive end of the recno's + d[100] = "way out there" + assert d[100] == "way out there" + + try: + data = d[99] + except KeyError: + pass + else: + self.fail("expected exception") + + try: + d.get(99) + except db.DBKeyEmptyError, val: + assert val[0] == db.DB_KEYEMPTY + if verbose: print val + else: + self.fail("expected exception") + + rec = c.set(40) + while rec: + if verbose: + print rec + rec = c.next() + + c.close() + d.close() + + + def test02_WithSource(self): + """ + A Recno file that is given a "backing source file" is essentially a simple ASCII + file. Normally each record is delimited by \n and so is just a line in the file, + but you can set a different record delimiter if needed. + """ + source = os.path.join(os.path.dirname(sys.argv[0]), 'db_home/test_recno.txt') + f = open(source, 'w') # create the file + f.close() + + d = db.DB() + d.set_re_delim(0x0A) # This is the default value, just checking if both int + d.set_re_delim('\n') # and char can be used... + d.set_re_source(source) + d.open(self.filename, db.DB_RECNO, db.DB_CREATE) + + data = string.split("The quick brown fox jumped over the lazy dog") + for datum in data: + d.append(datum) + d.sync() + d.close() + + # get the text from the backing source + text = open(source, 'r').read() + text = string.strip(text) + if verbose: + print text + print data + print string.split(text, '\n') + + assert string.split(text, '\n') == data + + # open as a DB again + d = db.DB() + d.set_re_source(source) + d.open(self.filename, db.DB_RECNO) + + d[3] = 'reddish-brown' + d[8] = 'comatose' + + d.sync() + d.close() + + text = open(source, 'r').read() + text = string.strip(text) + if verbose: + print text + print string.split(text, '\n') + + assert string.split(text, '\n') == string.split("The quick reddish-brown fox jumped over the comatose dog") + + + def test03_FixedLength(self): + d = db.DB() + d.set_re_len(40) # fixed length records, 40 bytes long + d.set_re_pad('-') # sets the pad character... + d.set_re_pad(45) # ...test both int and char + d.open(self.filename, db.DB_RECNO, db.DB_CREATE) + + for x in string.letters: + d.append(x * 35) # These will be padded + + d.append('.' * 40) # this one will be exact + + try: # this one will fail + d.append('bad' * 20) + except db.DBInvalidArgError, val: + assert val[0] == db.EINVAL + if verbose: print val + else: + self.fail("expected exception") + + c = d.cursor() + rec = c.first() + while rec: + if verbose: + print rec + rec = c.next() + + c.close() + d.close() + +#---------------------------------------------------------------------- + + +def suite(): + return unittest.makeSuite(SimpleRecnoTestCase) + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) + + diff --git a/python/test/test_thread.py b/python/test/test_thread.py new file mode 100644 index 000000000..f8722e27d --- /dev/null +++ b/python/test/test_thread.py @@ -0,0 +1,487 @@ +""" +TestCases for multi-threaded access to a DB. +""" + +import sys, os, string +import tempfile +import time +from pprint import pprint +from whrandom import random + +try: + from threading import Thread, currentThread + have_threads = 1 +except ImportError: + have_threads = 0 + + +import unittest +from test_all import verbose + +from rpmdb import db + + +#---------------------------------------------------------------------- + +class BaseThreadedTestCase(unittest.TestCase): + dbtype = db.DB_UNKNOWN # must be set in derived class + dbopenflags = 0 + dbsetflags = 0 + envflags = 0 + + + def setUp(self): + homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home') + self.homeDir = homeDir + try: os.mkdir(homeDir) + except os.error: pass + self.env = db.DBEnv() + self.setEnvOpts() + self.env.open(homeDir, self.envflags | db.DB_CREATE) + + self.filename = self.__class__.__name__ + '.db' + self.d = db.DB(self.env) + if self.dbsetflags: + self.d.set_flags(self.dbsetflags) + self.d.open(self.filename, self.dbtype, self.dbopenflags|db.DB_CREATE) + + + def tearDown(self): + self.d.close() + self.env.close() + import glob + files = glob.glob(os.path.join(self.homeDir, '*')) + for file in files: + os.remove(file) + + + def setEnvOpts(self): + pass + + + def makeData(self, key): + return string.join([key] * 5, '-') + + +#---------------------------------------------------------------------- + + +class ConcurrentDataStoreBase(BaseThreadedTestCase): + dbopenflags = db.DB_THREAD + envflags = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL + readers = 0 # derived class should set + writers = 0 + records = 1000 + + + def test01_1WriterMultiReaders(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test01_1WriterMultiReaders..." % self.__class__.__name__ + + threads = [] + for x in range(self.writers): + wt = Thread(target = self.writerThread, + args = (self.d, self.records, x), + name = 'writer %d' % x, + )#verbose = verbose) + threads.append(wt) + + for x in range(self.readers): + rt = Thread(target = self.readerThread, + args = (self.d, x), + name = 'reader %d' % x, + )#verbose = verbose) + threads.append(rt) + + for t in threads: + t.start() + for t in threads: + t.join() + + + def writerThread(self, d, howMany, writerNum): + #time.sleep(0.01 * writerNum + 0.01) + name = currentThread().getName() + start, stop = howMany * writerNum, howMany * (writerNum + 1) - 1 + if verbose: + print "%s: creating records %d - %d" % (name, start, stop) + + for x in range(start, stop): + key = '%04d' % x + d.put(key, self.makeData(key)) + if verbose and x % 100 == 0: + print "%s: records %d - %d finished" % (name, start, x) + + if verbose: print "%s: finished creating records" % name + +## # Each write-cursor will be exclusive, the only one that can update the DB... +## if verbose: print "%s: deleting a few records" % name +## c = d.cursor(flags = db.DB_WRITECURSOR) +## for x in range(10): +## key = int(random() * howMany) + start +## key = '%04d' % key +## if d.has_key(key): +## c.set(key) +## c.delete() + +## c.close() + if verbose: print "%s: thread finished" % name + + + def readerThread(self, d, readerNum): + time.sleep(0.01 * readerNum) + name = currentThread().getName() + + for loop in range(5): + c = d.cursor() + count = 0 + rec = c.first() + while rec: + count = count + 1 + key, data = rec + assert self.makeData(key) == data + rec = c.next() + if verbose: print "%s: found %d records" % (name, count) + c.close() + time.sleep(0.05) + + if verbose: print "%s: thread finished" % name + + + +class BTreeConcurrentDataStore(ConcurrentDataStoreBase): + dbtype = db.DB_BTREE + writers = 2 + readers = 10 + records = 1000 + + +class HashConcurrentDataStore(ConcurrentDataStoreBase): + dbtype = db.DB_HASH + writers = 2 + readers = 10 + records = 1000 + +#---------------------------------------------------------------------- + +class SimpleThreadedBase(BaseThreadedTestCase): + dbopenflags = db.DB_THREAD + envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK + readers = 5 + writers = 3 + records = 1000 + + + def setEnvOpts(self): + self.env.set_lk_detect(db.DB_LOCK_DEFAULT) + + + def test02_SimpleLocks(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test02_SimpleLocks..." % self.__class__.__name__ + + threads = [] + for x in range(self.writers): + wt = Thread(target = self.writerThread, + args = (self.d, self.records, x), + name = 'writer %d' % x, + )#verbose = verbose) + threads.append(wt) + for x in range(self.readers): + rt = Thread(target = self.readerThread, + args = (self.d, x), + name = 'reader %d' % x, + )#verbose = verbose) + threads.append(rt) + + for t in threads: + t.start() + for t in threads: + t.join() + + + + def writerThread(self, d, howMany, writerNum): + name = currentThread().getName() + start, stop = howMany * writerNum, howMany * (writerNum + 1) - 1 + if verbose: + print "%s: creating records %d - %d" % (name, start, stop) + + # create a bunch of records + for x in xrange(start, stop): + key = '%04d' % x + d.put(key, self.makeData(key)) + + if verbose and x % 100 == 0: + print "%s: records %d - %d finished" % (name, start, x) + + # do a bit or reading too + if random() <= 0.05: + for y in xrange(start, x): + key = '%04d' % x + data = d.get(key) + assert data == self.makeData(key) + + # flush them + try: + d.sync() + except db.DBIncompleteError, val: + if verbose: + print "could not complete sync()..." + + # read them back, deleting a few + for x in xrange(start, stop): + key = '%04d' % x + data = d.get(key) + if verbose and x % 100 == 0: + print "%s: fetched record (%s, %s)" % (name, key, data) + assert data == self.makeData(key) + if random() <= 0.10: + d.delete(key) + if verbose: + print "%s: deleted record %s" % (name, key) + + if verbose: print "%s: thread finished" % name + + + def readerThread(self, d, readerNum): + time.sleep(0.01 * readerNum) + name = currentThread().getName() + + for loop in range(5): + c = d.cursor() + count = 0 + rec = c.first() + while rec: + count = count + 1 + key, data = rec + assert self.makeData(key) == data + rec = c.next() + if verbose: print "%s: found %d records" % (name, count) + c.close() + time.sleep(0.05) + + if verbose: print "%s: thread finished" % name + + + + +class BTreeSimpleThreaded(SimpleThreadedBase): + dbtype = db.DB_BTREE + + +class HashSimpleThreaded(SimpleThreadedBase): + dbtype = db.DB_BTREE + + +#---------------------------------------------------------------------- + + + +class ThreadedTransactionsBase(BaseThreadedTestCase): + dbopenflags = db.DB_THREAD + envflags = (db.DB_THREAD | + db.DB_INIT_MPOOL | + db.DB_INIT_LOCK | + db.DB_INIT_LOG | + db.DB_INIT_TXN + ) + readers = 0 + writers = 0 + records = 2000 + + txnFlag = 0 + + + def setEnvOpts(self): + #self.env.set_lk_detect(db.DB_LOCK_DEFAULT) + pass + + + def test03_ThreadedTransactions(self): + if verbose: + print '\n', '-=' * 30 + print "Running %s.test03_ThreadedTransactions..." % self.__class__.__name__ + + threads = [] + for x in range(self.writers): + wt = Thread(target = self.writerThread, + args = (self.d, self.records, x), + name = 'writer %d' % x, + )#verbose = verbose) + threads.append(wt) + + for x in range(self.readers): + rt = Thread(target = self.readerThread, + args = (self.d, x), + name = 'reader %d' % x, + )#verbose = verbose) + threads.append(rt) + + dt = Thread(target = self.deadlockThread) + dt.start() + + for t in threads: + t.start() + for t in threads: + t.join() + + self.doLockDetect = 0 + dt.join() + + + def doWrite(self, d, name, start, stop): + finished = 0 + while not finished: + try: + txn = self.env.txn_begin(None, self.txnFlag) + for x in range(start, stop): + key = '%04d' % x + d.put(key, self.makeData(key), txn) + if verbose and x % 100 == 0: + print "%s: records %d - %d finished" % (name, start, x) + txn.commit() + finished = 1 + except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val: + if verbose: + print "%s: Aborting transaction (%s)" % (name, val[1]) + txn.abort() + time.sleep(0.05) + + + + def writerThread(self, d, howMany, writerNum): + name = currentThread().getName() + start, stop = howMany * writerNum, howMany * (writerNum + 1) - 1 + if verbose: + print "%s: creating records %d - %d" % (name, start, stop) + + step = 100 + for x in range(start, stop, step): + self.doWrite(d, name, x, min(stop, x+step)) + + if verbose: print "%s: finished creating records" % name + if verbose: print "%s: deleting a few records" % name + + finished = 0 + while not finished: + try: + recs = [] + txn = self.env.txn_begin(None, self.txnFlag) + for x in range(10): + key = int(random() * howMany) + start + key = '%04d' % key + data = d.get(key, None, txn, db.DB_RMW) + if data is not None: + d.delete(key, txn) + recs.append(key) + txn.commit() + finished = 1 + if verbose: print "%s: deleted records %s" % (name, recs) + except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val: + if verbose: + print "%s: Aborting transaction (%s)" % (name, val[1]) + txn.abort() + time.sleep(0.05) + + if verbose: print "%s: thread finished" % name + + + def readerThread(self, d, readerNum): + time.sleep(0.01 * readerNum + 0.05) + name = currentThread().getName() + + for loop in range(5): + finished = 0 + while not finished: + try: + txn = self.env.txn_begin(None, self.txnFlag) + c = d.cursor(txn) + count = 0 + rec = c.first() + while rec: + count = count + 1 + key, data = rec + assert self.makeData(key) == data + rec = c.next() + if verbose: print "%s: found %d records" % (name, count) + c.close() + txn.commit() + finished = 1 + except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val: + if verbose: + print "%s: Aborting transaction (%s)" % (name, val[1]) + c.close() + txn.abort() + time.sleep(0.05) + + time.sleep(0.05) + + if verbose: print "%s: thread finished" % name + + + def deadlockThread(self): + self.doLockDetect = 1 + while self.doLockDetect: + time.sleep(0.5) + try: + aborted = self.env.lock_detect(db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT) + if verbose and aborted: + print "deadlock: Aborted %d deadlocked transaction(s)" % aborted + except db.DBError: + pass + + + +class BTreeThreadedTransactions(ThreadedTransactionsBase): + dbtype = db.DB_BTREE + writers = 3 + readers = 5 + records = 2000 + +class HashThreadedTransactions(ThreadedTransactionsBase): + dbtype = db.DB_HASH + writers = 1 + readers = 5 + records = 2000 + +class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase): + dbtype = db.DB_BTREE + writers = 3 + readers = 5 + records = 2000 + txnFlag = db.DB_TXN_NOWAIT + +class HashThreadedNoWaitTransactions(ThreadedTransactionsBase): + dbtype = db.DB_HASH + writers = 1 + readers = 5 + records = 2000 + txnFlag = db.DB_TXN_NOWAIT + + +#---------------------------------------------------------------------- + +def suite(): + theSuite = unittest.TestSuite() + + if have_threads: + theSuite.addTest(unittest.makeSuite(BTreeConcurrentDataStore)) + theSuite.addTest(unittest.makeSuite(HashConcurrentDataStore)) + theSuite.addTest(unittest.makeSuite(BTreeSimpleThreaded)) + theSuite.addTest(unittest.makeSuite(HashSimpleThreaded)) + theSuite.addTest(unittest.makeSuite(BTreeThreadedTransactions)) + theSuite.addTest(unittest.makeSuite(HashThreadedTransactions)) + theSuite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions)) + theSuite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions)) + + else: + print "Threads not available, skipping thread tests." + + return theSuite + + +if __name__ == '__main__': + unittest.main( defaultTest='suite' ) diff --git a/python/test/unittest.py b/python/test/unittest.py new file mode 100644 index 000000000..44b05c984 --- /dev/null +++ b/python/test/unittest.py @@ -0,0 +1,693 @@ +#!/usr/bin/env python +""" +Python unit testing framework, based on Erich Gamma's JUnit and Kent Beck's +Smalltalk testing framework. + +Further information is available in the bundled documentation, and from + + http://pyunit.sourceforge.net/ + +This module contains the core framework classes that form the basis of +specific test cases and suites (TestCase, TestSuite etc.), and also a +text-based utility class for running the tests and reporting the results +(TextTestRunner). + +Copyright (c) 1999, 2000, 2001 Steve Purcell +This module is free software, and you may redistribute it and/or modify +it under the same terms as Python itself, so long as this copyright message +and disclaimer are retained in their original form. + +IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF +THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. + +THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, +AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE, +SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. +""" + +__author__ = "Steve Purcell (stephen_purcell@yahoo.com)" +__version__ = "Revision: 1.3 "[11:-2] + +import time +import sys +import traceback +import string +import os + +############################################################################## +# A platform-specific concession to help the code work for JPython users +############################################################################## + +plat = string.lower(sys.platform) +_isJPython = string.find(plat, 'java') >= 0 or string.find(plat, 'jdk') >= 0 +del plat + + +############################################################################## +# Test framework core +############################################################################## + +class TestResult: + """Holder for test result information. + + Test results are automatically managed by the TestCase and TestSuite + classes, and do not need to be explicitly manipulated by writers of tests. + + Each instance holds the total number of tests run, and collections of + failures and errors that occurred among those test runs. The collections + contain tuples of (testcase, exceptioninfo), where exceptioninfo is a + tuple of values as returned by sys.exc_info(). + """ + def __init__(self): + self.failures = [] + self.errors = [] + self.testsRun = 0 + self.shouldStop = 0 + + def startTest(self, test): + "Called when the given test is about to be run" + self.testsRun = self.testsRun + 1 + + def stopTest(self, test): + "Called when the given test has been run" + pass + + def addError(self, test, err): + "Called when an error has occurred" + self.errors.append((test, err)) + + def addFailure(self, test, err): + "Called when a failure has occurred" + self.failures.append((test, err)) + + def wasSuccessful(self): + "Tells whether or not this result was a success" + return len(self.failures) == len(self.errors) == 0 + + def stop(self): + "Indicates that the tests should be aborted" + self.shouldStop = 1 + + def __repr__(self): + return "<%s run=%i errors=%i failures=%i>" % \ + (self.__class__, self.testsRun, len(self.errors), + len(self.failures)) + + +class TestCase: + """A class whose instances are single test cases. + + Test authors should subclass TestCase for their own tests. Construction + and deconstruction of the test's environment ('fixture') can be + implemented by overriding the 'setUp' and 'tearDown' methods respectively. + + By default, the test code itself should be placed in a method named + 'runTest'. + + If the fixture may be used for many test cases, create as + many test methods as are needed. When instantiating such a TestCase + subclass, specify in the constructor arguments the name of the test method + that the instance is to execute. + + If it is necessary to override the __init__ method, the base class + __init__ method must always be called. + """ + def __init__(self, methodName='runTest'): + """Create an instance of the class that will use the named test + method when executed. Raises a ValueError if the instance does + not have a method with the specified name. + """ + try: + self.__testMethod = getattr(self,methodName) + except AttributeError: + raise ValueError, "no such test method in %s: %s" % \ + (self.__class__, methodName) + + def setUp(self): + "Hook method for setting up the test fixture before exercising it." + pass + + def tearDown(self): + "Hook method for deconstructing the test fixture after testing it." + pass + + def countTestCases(self): + return 1 + + def defaultTestResult(self): + return TestResult() + + def shortDescription(self): + """Returns a one-line description of the test, or None if no + description has been provided. + + The default implementation of this method returns the first line of + the specified test method's docstring. + """ + doc = self.__testMethod.__doc__ + return doc and string.strip(string.split(doc, "\n")[0]) or None + + def id(self): + return "%s.%s" % (self.__class__, self.__testMethod.__name__) + + def __str__(self): + return "%s (%s)" % (self.__testMethod.__name__, self.__class__) + + def __repr__(self): + return "<%s testMethod=%s>" % \ + (self.__class__, self.__testMethod.__name__) + + def run(self, result=None): + return self(result) + + def __call__(self, result=None): + if result is None: result = self.defaultTestResult() + result.startTest(self) + try: + try: + self.setUp() + except: + result.addError(self,self.__exc_info()) + return + + try: + self.__testMethod() + except AssertionError, e: + result.addFailure(self,self.__exc_info()) + except: + result.addError(self,self.__exc_info()) + + try: + self.tearDown() + except: + result.addError(self,self.__exc_info()) + finally: + result.stopTest(self) + + def debug(self): + """Run the test without collecting errors in a TestResult""" + self.setUp() + self.__testMethod() + self.tearDown() + + def assert_(self, expr, msg=None): + """Equivalent of built-in 'assert', but is not optimised out when + __debug__ is false. + """ + if not expr: + raise AssertionError, msg + + failUnless = assert_ + + def failIf(self, expr, msg=None): + "Fail the test if the expression is true." + apply(self.assert_,(not expr,msg)) + + def assertRaises(self, excClass, callableObj, *args, **kwargs): + """Assert that an exception of class excClass is thrown + by callableObj when invoked with arguments args and keyword + arguments kwargs. If a different type of exception is + thrown, it will not be caught, and the test case will be + deemed to have suffered an error, exactly as for an + unexpected exception. + """ + try: + apply(callableObj, args, kwargs) + except excClass: + return + else: + if hasattr(excClass,'__name__'): excName = excClass.__name__ + else: excName = str(excClass) + raise AssertionError, excName + + def fail(self, msg=None): + """Fail immediately, with the given message.""" + raise AssertionError, msg + + def __exc_info(self): + """Return a version of sys.exc_info() with the traceback frame + minimised; usually the top level of the traceback frame is not + needed. + """ + exctype, excvalue, tb = sys.exc_info() + newtb = tb.tb_next + if newtb is None: + return (exctype, excvalue, tb) + return (exctype, excvalue, newtb) + + +class TestSuite: + """A test suite is a composite test consisting of a number of TestCases. + + For use, create an instance of TestSuite, then add test case instances. + When all tests have been added, the suite can be passed to a test + runner, such as TextTestRunner. It will run the individual test cases + in the order in which they were added, aggregating the results. When + subclassing, do not forget to call the base class constructor. + """ + def __init__(self, tests=()): + self._tests = [] + self.addTests(tests) + + def __repr__(self): + return "<%s tests=%s>" % (self.__class__, self._tests) + + __str__ = __repr__ + + def countTestCases(self): + cases = 0 + for test in self._tests: + cases = cases + test.countTestCases() + return cases + + def addTest(self, test): + self._tests.append(test) + + def addTests(self, tests): + for test in tests: + self.addTest(test) + + def run(self, result): + return self(result) + + def __call__(self, result): + for test in self._tests: + if result.shouldStop: + break + test(result) + return result + + def debug(self): + """Run the tests without collecting errors in a TestResult""" + for test in self._tests: test.debug() + + +class FunctionTestCase(TestCase): + """A test case that wraps a test function. + + This is useful for slipping pre-existing test functions into the + PyUnit framework. Optionally, set-up and tidy-up functions can be + supplied. As with TestCase, the tidy-up ('tearDown') function will + always be called if the set-up ('setUp') function ran successfully. + """ + + def __init__(self, testFunc, setUp=None, tearDown=None, + description=None): + TestCase.__init__(self) + self.__setUpFunc = setUp + self.__tearDownFunc = tearDown + self.__testFunc = testFunc + self.__description = description + + def setUp(self): + if self.__setUpFunc is not None: + self.__setUpFunc() + + def tearDown(self): + if self.__tearDownFunc is not None: + self.__tearDownFunc() + + def runTest(self): + self.__testFunc() + + def id(self): + return self.__testFunc.__name__ + + def __str__(self): + return "%s (%s)" % (self.__class__, self.__testFunc.__name__) + + def __repr__(self): + return "<%s testFunc=%s>" % (self.__class__, self.__testFunc) + + def shortDescription(self): + if self.__description is not None: return self.__description + doc = self.__testFunc.__doc__ + return doc and string.strip(string.split(doc, "\n")[0]) or None + + + +############################################################################## +# Convenience functions +############################################################################## + +def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp): + """Extracts all the names of functions in the given test case class + and its base classes that start with the given prefix. This is used + by makeSuite(). + """ + testFnNames = filter(lambda n,p=prefix: n[:len(p)] == p, + dir(testCaseClass)) + for baseclass in testCaseClass.__bases__: + testFnNames = testFnNames + \ + getTestCaseNames(baseclass, prefix, sortUsing=None) + if sortUsing: + testFnNames.sort(sortUsing) + return testFnNames + + +def makeSuite(testCaseClass, prefix='test', sortUsing=cmp): + """Returns a TestSuite instance built from all of the test functions + in the given test case class whose names begin with the given + prefix. The cases are sorted by their function names + using the supplied comparison function, which defaults to 'cmp'. + """ + cases = map(testCaseClass, + getTestCaseNames(testCaseClass, prefix, sortUsing)) + return TestSuite(cases) + + +def createTestInstance(name, module=None): + """Finds tests by their name, optionally only within the given module. + + Return the newly-constructed test, ready to run. If the name contains a ':' + then the portion of the name after the colon is used to find a specific + test case within the test case class named before the colon. + + Examples: + findTest('examples.listtests.suite') + -- returns result of calling 'suite' + findTest('examples.listtests.ListTestCase:checkAppend') + -- returns result of calling ListTestCase('checkAppend') + findTest('examples.listtests.ListTestCase:check-') + -- returns result of calling makeSuite(ListTestCase, prefix="check") + """ + + spec = string.split(name, ':') + if len(spec) > 2: raise ValueError, "illegal test name: %s" % name + if len(spec) == 1: + testName = spec[0] + caseName = None + else: + testName, caseName = spec + parts = string.split(testName, '.') + if module is None: + if len(parts) < 2: + raise ValueError, "incomplete test name: %s" % name + constructor = __import__(string.join(parts[:-1],'.')) + parts = parts[1:] + else: + constructor = module + for part in parts: + constructor = getattr(constructor, part) + if not callable(constructor): + raise ValueError, "%s is not a callable object" % constructor + if caseName: + if caseName[-1] == '-': + prefix = caseName[:-1] + if not prefix: + raise ValueError, "prefix too short: %s" % name + test = makeSuite(constructor, prefix=prefix) + else: + test = constructor(caseName) + else: + test = constructor() + if not hasattr(test,"countTestCases"): + raise TypeError, \ + "object %s found with spec %s is not a test" % (test, name) + return test + + +############################################################################## +# Text UI +############################################################################## + +class _WritelnDecorator: + """Used to decorate file-like objects with a handy 'writeln' method""" + def __init__(self,stream): + self.stream = stream + if _isJPython: + import java.lang.System + self.linesep = java.lang.System.getProperty("line.separator") + else: + self.linesep = os.linesep + + def __getattr__(self, attr): + return getattr(self.stream,attr) + + def writeln(self, *args): + if args: apply(self.write, args) + self.write(self.linesep) + + +class _JUnitTextTestResult(TestResult): + """A test result class that can print formatted text results to a stream. + + Used by JUnitTextTestRunner. + """ + def __init__(self, stream): + self.stream = stream + TestResult.__init__(self) + + def addError(self, test, error): + TestResult.addError(self,test,error) + self.stream.write('E') + self.stream.flush() + if error[0] is KeyboardInterrupt: + self.shouldStop = 1 + + def addFailure(self, test, error): + TestResult.addFailure(self,test,error) + self.stream.write('F') + self.stream.flush() + + def startTest(self, test): + TestResult.startTest(self,test) + self.stream.write('.') + self.stream.flush() + + def printNumberedErrors(self,errFlavour,errors): + if not errors: return + if len(errors) == 1: + self.stream.writeln("There was 1 %s:" % errFlavour) + else: + self.stream.writeln("There were %i %ss:" % + (len(errors), errFlavour)) + i = 1 + for test,error in errors: + errString = string.join(apply(traceback.format_exception,error),"") + self.stream.writeln("%i) %s" % (i, test)) + self.stream.writeln(errString) + i = i + 1 + + def printErrors(self): + self.printNumberedErrors("error",self.errors) + + def printFailures(self): + self.printNumberedErrors("failure",self.failures) + + def printHeader(self): + self.stream.writeln() + if self.wasSuccessful(): + self.stream.writeln("OK (%i tests)" % self.testsRun) + else: + self.stream.writeln("!!!FAILURES!!!") + self.stream.writeln("Test Results") + self.stream.writeln() + self.stream.writeln("Run: %i ; Failures: %i ; Errors: %i" % + (self.testsRun, len(self.failures), + len(self.errors))) + + def printResult(self): + self.printHeader() + self.printErrors() + self.printFailures() + + +class JUnitTextTestRunner: + """A test runner class that displays results in textual form. + + The display format approximates that of JUnit's 'textui' test runner. + This test runner may be removed in a future version of PyUnit. + """ + def __init__(self, stream=sys.stderr): + self.stream = _WritelnDecorator(stream) + + def run(self, test): + "Run the given test case or test suite." + result = _JUnitTextTestResult(self.stream) + startTime = time.time() + test(result) + stopTime = time.time() + self.stream.writeln() + self.stream.writeln("Time: %.3fs" % float(stopTime - startTime)) + result.printResult() + return result + + +############################################################################## +# Verbose text UI +############################################################################## + +class _VerboseTextTestResult(TestResult): + """A test result class that can print formatted text results to a stream. + + Used by VerboseTextTestRunner. + """ + def __init__(self, stream, descriptions): + TestResult.__init__(self) + self.stream = stream + self.lastFailure = None + self.descriptions = descriptions + + def startTest(self, test): + TestResult.startTest(self, test) + if self.descriptions: + self.stream.write(test.shortDescription() or str(test)) + else: + self.stream.write(str(test)) + self.stream.write(" ... ") + + def stopTest(self, test): + TestResult.stopTest(self, test) + if self.lastFailure is not test: + self.stream.writeln("ok") + + def addError(self, test, err): + TestResult.addError(self, test, err) + self._printError("ERROR", test, err) + self.lastFailure = test + if err[0] is KeyboardInterrupt: + self.shouldStop = 1 + + def addFailure(self, test, err): + TestResult.addFailure(self, test, err) + self._printError("FAIL", test, err) + self.lastFailure = test + + def _printError(self, flavour, test, err): + errLines = [] + separator1 = "\t" + '=' * 70 + separator2 = "\t" + '-' * 70 + if not self.lastFailure is test: + self.stream.writeln() + self.stream.writeln(separator1) + self.stream.writeln("\t%s" % flavour) + self.stream.writeln(separator2) + for line in apply(traceback.format_exception, err): + for l in string.split(line,"\n")[:-1]: + self.stream.writeln("\t%s" % l) + self.stream.writeln(separator1) + + +class VerboseTextTestRunner: + """A test runner class that displays results in textual form. + + It prints out the names of tests as they are run, errors as they + occur, and a summary of the results at the end of the test run. + """ + def __init__(self, stream=sys.stderr, descriptions=1): + self.stream = _WritelnDecorator(stream) + self.descriptions = descriptions + + def run(self, test): + "Run the given test case or test suite." + result = _VerboseTextTestResult(self.stream, self.descriptions) + startTime = time.time() + test(result) + stopTime = time.time() + timeTaken = float(stopTime - startTime) + self.stream.writeln("-" * 78) + run = result.testsRun + self.stream.writeln("Ran %d test%s in %.3fs" % + (run, run > 1 and "s" or "", timeTaken)) + self.stream.writeln() + if not result.wasSuccessful(): + self.stream.write("FAILED (") + failed, errored = map(len, (result.failures, result.errors)) + if failed: + self.stream.write("failures=%d" % failed) + if errored: + if failed: self.stream.write(", ") + self.stream.write("errors=%d" % errored) + self.stream.writeln(")") + else: + self.stream.writeln("OK") + return result + + +# Which flavour of TextTestRunner is the default? +TextTestRunner = VerboseTextTestRunner + + +############################################################################## +# Facilities for running tests from the command line +############################################################################## + +class TestProgram: + """A command-line program that runs a set of tests; this is primarily + for making test modules conveniently executable. + """ + USAGE = """\ +Usage: %(progName)s [-h|--help] [test[:(casename|prefix-)]] [...] + +Examples: + %(progName)s - run default set of tests + %(progName)s MyTestSuite - run suite 'MyTestSuite' + %(progName)s MyTestCase:checkSomething - run MyTestCase.checkSomething + %(progName)s MyTestCase:check- - run all 'check*' test methods + in MyTestCase +""" + def __init__(self, module='__main__', defaultTest=None, + argv=None, testRunner=None): + if type(module) == type(''): + self.module = __import__(module) + for part in string.split(module,'.')[1:]: + self.module = getattr(self.module, part) + else: + self.module = module + if argv is None: + argv = sys.argv + self.defaultTest = defaultTest + self.testRunner = testRunner + self.progName = os.path.basename(argv[0]) + self.parseArgs(argv) + self.createTests() + self.runTests() + + def usageExit(self, msg=None): + if msg: print msg + print self.USAGE % self.__dict__ + sys.exit(2) + + def parseArgs(self, argv): + import getopt + try: + options, args = getopt.getopt(argv[1:], 'hH', ['help']) + opts = {} + for opt, value in options: + if opt in ('-h','-H','--help'): + self.usageExit() + if len(args) == 0 and self.defaultTest is None: + raise getopt.error, "No default test is defined." + if len(args) > 0: + self.testNames = args + else: + self.testNames = (self.defaultTest,) + except getopt.error, msg: + self.usageExit(msg) + + def createTests(self): + tests = [] + for testName in self.testNames: + tests.append(createTestInstance(testName, self.module)) + self.test = TestSuite(tests) + + def runTests(self): + if self.testRunner is None: + self.testRunner = TextTestRunner() + result = self.testRunner.run(self.test) + sys.exit(not result.wasSuccessful()) + +main = TestProgram + + +############################################################################## +# Executing this module from the command line +############################################################################## + +if __name__ == "__main__": + main(module=None) |