summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjbj <devnull@localhost>2003-05-05 21:42:55 +0000
committerjbj <devnull@localhost>2003-05-05 21:42:55 +0000
commit3bb3246247fea289d3a2638fde2f3d0b191774fd (patch)
treea6045517c91b0b080d72716269ea64ac79e6ca0c
parent704ce887bfe69799fff28882443025d59417f733 (diff)
downloadlibrpm-tizen-3bb3246247fea289d3a2638fde2f3d0b191774fd.tar.gz
librpm-tizen-3bb3246247fea289d3a2638fde2f3d0b191774fd.tar.bz2
librpm-tizen-3bb3246247fea289d3a2638fde2f3d0b191774fd.zip
Add unit test {rpmdb,mpw}/test subdirs.
CVS patchset: 6817 CVS date: 2003/05/05 21:42:55
-rw-r--r--python/Makefile.am2
-rw-r--r--python/mpw/Makefile.am5
-rw-r--r--python/mpw/test/Makefile.am13
-rw-r--r--python/mpw/test/test_all.py59
-rw-r--r--python/mpw/test/test_methods.py101
-rw-r--r--python/mpw/test/unittest.py759
-rw-r--r--python/rpmdb/Makefile.am4
-rw-r--r--python/rpmdb/test/.cvsignore6
-rw-r--r--python/rpmdb/test/Makefile.am15
-rw-r--r--python/rpmdb/test/test_all.py76
-rw-r--r--python/rpmdb/test/test_associate.py321
-rw-r--r--python/rpmdb/test/test_basics.py882
-rw-r--r--python/rpmdb/test/test_compat.py166
-rw-r--r--python/rpmdb/test/test_dbobj.py73
-rw-r--r--python/rpmdb/test/test_dbshelve.py301
-rw-r--r--python/rpmdb/test/test_dbtables.py368
-rw-r--r--python/rpmdb/test/test_env_close.py102
-rw-r--r--python/rpmdb/test/test_get_none.py96
-rw-r--r--python/rpmdb/test/test_join.py9
-rw-r--r--python/rpmdb/test/test_lock.py139
-rw-r--r--python/rpmdb/test/test_misc.py53
-rw-r--r--python/rpmdb/test/test_queue.py168
-rw-r--r--python/rpmdb/test/test_recno.py260
-rw-r--r--python/rpmdb/test/test_thread.py495
-rw-r--r--python/rpmdb/test/unittest.py759
-rw-r--r--python/rpmmpw-py.c4
26 files changed, 5232 insertions, 4 deletions
diff --git a/python/Makefile.am b/python/Makefile.am
index 5698b1ce7..dd159b3db 100644
--- a/python/Makefile.am
+++ b/python/Makefile.am
@@ -9,7 +9,7 @@ PYVER= @WITH_PYTHON_VERSION@
pylibdir = $(shell python -c 'import sys; print sys.path[1]')
pyincdir = $(prefix)/include/python${PYVER}
-SUBDIRS = rpmdb test
+SUBDIRS = rpmdb mpw
EXTRA_DIST = rpmdebug-py.c
diff --git a/python/mpw/Makefile.am b/python/mpw/Makefile.am
new file mode 100644
index 000000000..358906224
--- /dev/null
+++ b/python/mpw/Makefile.am
@@ -0,0 +1,5 @@
+# Makefile for rpm library.
+
+AUTOMAKE_OPTIONS = 1.4 foreign
+
+SUBDIRS = test
diff --git a/python/mpw/test/Makefile.am b/python/mpw/test/Makefile.am
new file mode 100644
index 000000000..1e9ba892f
--- /dev/null
+++ b/python/mpw/test/Makefile.am
@@ -0,0 +1,13 @@
+# Makefile for rpm library.
+
+AUTOMAKE_OPTIONS = 1.4 foreign
+
+PYVER= @WITH_PYTHON_VERSION@
+
+pylibdir = $(shell python -c 'import sys; print sys.path[1]')
+
+EXTRA_DIST = \
+ test_all.py test_methods.py \
+ unittest.py
+
+all:
diff --git a/python/mpw/test/test_all.py b/python/mpw/test/test_all.py
new file mode 100644
index 000000000..7690d4efd
--- /dev/null
+++ b/python/mpw/test/test_all.py
@@ -0,0 +1,59 @@
+"""Run all test cases.
+"""
+
+import sys
+import os
+import unittest
+
+verbose = 0
+if 'verbose' in sys.argv:
+ verbose = 1
+ sys.argv.remove('verbose')
+
+if 'silent' in sys.argv: # take care of old flag, just in case
+ verbose = 0
+ sys.argv.remove('silent')
+
+
+def print_versions():
+ from rpm import mpw
+ print
+ print '-=' * 38
+ print 'python version: %s' % sys.version
+ print 'My pid: %s' % os.getpid()
+ print '-=' * 38
+
+
+class PrintInfoFakeTest(unittest.TestCase):
+ def testPrintVersions(self):
+ print_versions()
+
+
+# This little hack is for when this module is run as main and all the
+# other modules import it so they will still be able to get the right
+# verbose setting. It's confusing but it works.
+import test_all
+test_all.verbose = verbose
+
+
+def suite():
+ test_modules = [
+ 'test_methods',
+ ]
+
+ alltests = unittest.TestSuite()
+ for name in test_modules:
+ module = __import__(name)
+ alltests.addTest(module.test_suite())
+ return alltests
+
+
+def test_suite():
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(PrintInfoFakeTest))
+ return suite
+
+
+if __name__ == '__main__':
+ print_versions()
+ unittest.main(defaultTest='suite')
diff --git a/python/mpw/test/test_methods.py b/python/mpw/test/test_methods.py
new file mode 100644
index 000000000..6c851eb63
--- /dev/null
+++ b/python/mpw/test/test_methods.py
@@ -0,0 +1,101 @@
+"""
+Basic TestCases for BTree and hash DBs, with and without a DBEnv, with
+various DB flags, etc.
+"""
+
+import os
+import sys
+import errno
+import shutil
+import string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpm import mpw
+import mpz
+
+from test_all import verbose
+
+DASH = '-'
+
+
+#----------------------------------------------------------------------
+
+class BasicTestCase(unittest.TestCase):
+
+ def setUp(self):
+ mpw().Debug(0)
+ pass
+
+ def tearDown(self):
+ mpw().Debug(0)
+ pass
+
+ #----------------------------------------
+
+ def test01_SimpleMethods(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test01_GetsAndPuts..." % \
+ self.__class__.__name__
+
+ wa = mpw("0000000987654321")
+ wb = mpw("0000000000000010")
+ wc = mpw("0fedcba000000000")
+ za = mpz.mpz(0x0000000987654321)
+ zb = mpz.mpz(0x0000000000000010)
+ zc = mpz.mpz(0x0fedcba000000000)
+
+ print hex(mpw.__add__(wa, wb)), hex(mpz.MPZType.__add__(za, zb))
+ print hex(mpw.__sub__(wa, wb)), hex(mpz.MPZType.__sub__(za, zb))
+ print hex(mpw.__mul__(wa, wb)), hex(mpz.MPZType.__mul__(za, zb))
+ print hex(mpw.__div__(wa, wb)), hex(mpz.MPZType.__div__(za, zb))
+ print hex(mpw.__mod__(wa, wb)), hex(mpz.MPZType.__mod__(za, zb))
+
+# print mpw.__divmod__(a, b)
+# print mpw.__pow__(a, b)
+
+# print mpw.__lshift__(a, b)
+# print mpw.__rshift__(a, b)
+
+# print mpw.__and__(a, c)
+# print mpw.__xor__(a, a)
+# print mpw.__or__(a, c)
+
+# print mpw.__neg__(a)
+# print mpw.__pos__(a)
+# print mpw.__abs__(a)
+# print mpw.__invert__(a)
+
+# print mpw.__int__(b)
+# print mpw.__long__(b)
+# print mpw.__float__(a)
+# print mpw.__complex__(b)
+# print mpw.__oct__(a*b)
+# print mpw.__hex__(a*b)
+# print mpw.__coerce__(b, i)
+
+ del wa
+ del wb
+ del wc
+ del za
+ del zb
+ del zc
+ pass
+
+ #----------------------------------------
+
+#----------------------------------------------------------------------
+#----------------------------------------------------------------------
+
+def test_suite():
+ suite = unittest.TestSuite()
+
+ suite.addTest(unittest.makeSuite(BasicTestCase))
+
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/mpw/test/unittest.py b/python/mpw/test/unittest.py
new file mode 100644
index 000000000..d31e251d4
--- /dev/null
+++ b/python/mpw/test/unittest.py
@@ -0,0 +1,759 @@
+#!/usr/bin/env python
+'''
+Python unit testing framework, based on Erich Gamma's JUnit and Kent Beck's
+Smalltalk testing framework.
+
+This module contains the core framework classes that form the basis of
+specific test cases and suites (TestCase, TestSuite etc.), and also a
+text-based utility class for running the tests and reporting the results
+ (TextTestRunner).
+
+Simple usage:
+
+ import unittest
+
+ class IntegerArithmenticTestCase(unittest.TestCase):
+ def testAdd(self): ## test method names begin 'test*'
+ self.assertEquals((1 + 2), 3)
+ self.assertEquals(0 + 1, 1)
+ def testMultiply(self):
+ self.assertEquals((0 * 10), 0)
+ self.assertEquals((5 * 8), 40)
+
+ if __name__ == '__main__':
+ unittest.main()
+
+Further information is available in the bundled documentation, and from
+
+ http://pyunit.sourceforge.net/
+
+Copyright (c) 1999, 2000, 2001 Steve Purcell
+This module is free software, and you may redistribute it and/or modify
+it under the same terms as Python itself, so long as this copyright message
+and disclaimer are retained in their original form.
+
+IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
+THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
+
+THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
+AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
+SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
+'''
+
+__author__ = "Steve Purcell"
+__email__ = "stephen_purcell at yahoo dot com"
+__version__ = "#Revision: 1.46 $"[11:-2]
+
+import time
+import sys
+import traceback
+import string
+import os
+import types
+
+##############################################################################
+# Test framework core
+##############################################################################
+
+# All classes defined herein are 'new-style' classes, allowing use of 'super()'
+__metaclass__ = type
+
+def _strclass(cls):
+ return "%s.%s" % (cls.__module__, cls.__name__)
+
+class TestResult:
+ """Holder for test result information.
+
+ Test results are automatically managed by the TestCase and TestSuite
+ classes, and do not need to be explicitly manipulated by writers of tests.
+
+ Each instance holds the total number of tests run, and collections of
+ failures and errors that occurred among those test runs. The collections
+ contain tuples of (testcase, exceptioninfo), where exceptioninfo is the
+ formatted traceback of the error that occurred.
+ """
+ def __init__(self):
+ self.failures = []
+ self.errors = []
+ self.testsRun = 0
+ self.shouldStop = 0
+
+ def startTest(self, test):
+ "Called when the given test is about to be run"
+ self.testsRun = self.testsRun + 1
+
+ def stopTest(self, test):
+ "Called when the given test has been run"
+ pass
+
+ def addError(self, test, err):
+ """Called when an error has occurred. 'err' is a tuple of values as
+ returned by sys.exc_info().
+ """
+ self.errors.append((test, self._exc_info_to_string(err)))
+
+ def addFailure(self, test, err):
+ """Called when an error has occurred. 'err' is a tuple of values as
+ returned by sys.exc_info()."""
+ self.failures.append((test, self._exc_info_to_string(err)))
+
+ def addSuccess(self, test):
+ "Called when a test has completed successfully"
+ pass
+
+ def wasSuccessful(self):
+ "Tells whether or not this result was a success"
+ return len(self.failures) == len(self.errors) == 0
+
+ def stop(self):
+ "Indicates that the tests should be aborted"
+ self.shouldStop = 1
+
+ def _exc_info_to_string(self, err):
+ """Converts a sys.exc_info()-style tuple of values into a string."""
+ return string.join(traceback.format_exception(*err), '')
+
+ def __repr__(self):
+ return "<%s run=%i errors=%i failures=%i>" % \
+ (_strclass(self.__class__), self.testsRun, len(self.errors),
+ len(self.failures))
+
+
+class TestCase:
+ """A class whose instances are single test cases.
+
+ By default, the test code itself should be placed in a method named
+ 'runTest'.
+
+ If the fixture may be used for many test cases, create as
+ many test methods as are needed. When instantiating such a TestCase
+ subclass, specify in the constructor arguments the name of the test method
+ that the instance is to execute.
+
+ Test authors should subclass TestCase for their own tests. Construction
+ and deconstruction of the test's environment ('fixture') can be
+ implemented by overriding the 'setUp' and 'tearDown' methods respectively.
+
+ If it is necessary to override the __init__ method, the base class
+ __init__ method must always be called. It is important that subclasses
+ should not change the signature of their __init__ method, since instances
+ of the classes are instantiated automatically by parts of the framework
+ in order to be run.
+ """
+
+ # This attribute determines which exception will be raised when
+ # the instance's assertion methods fail; test methods raising this
+ # exception will be deemed to have 'failed' rather than 'errored'
+
+ failureException = AssertionError
+
+ def __init__(self, methodName='runTest'):
+ """Create an instance of the class that will use the named test
+ method when executed. Raises a ValueError if the instance does
+ not have a method with the specified name.
+ """
+ try:
+ self.__testMethodName = methodName
+ testMethod = getattr(self, methodName)
+ self.__testMethodDoc = testMethod.__doc__
+ except AttributeError:
+ raise ValueError, "no such test method in %s: %s" % \
+ (self.__class__, methodName)
+
+ def setUp(self):
+ "Hook method for setting up the test fixture before exercising it."
+ pass
+
+ def tearDown(self):
+ "Hook method for deconstructing the test fixture after testing it."
+ pass
+
+ def countTestCases(self):
+ return 1
+
+ def defaultTestResult(self):
+ return TestResult()
+
+ def shortDescription(self):
+ """Returns a one-line description of the test, or None if no
+ description has been provided.
+
+ The default implementation of this method returns the first line of
+ the specified test method's docstring.
+ """
+ doc = self.__testMethodDoc
+ return doc and string.strip(string.split(doc, "\n")[0]) or None
+
+ def id(self):
+ return "%s.%s" % (_strclass(self.__class__), self.__testMethodName)
+
+ def __str__(self):
+ return "%s (%s)" % (self.__testMethodName, _strclass(self.__class__))
+
+ def __repr__(self):
+ return "<%s testMethod=%s>" % \
+ (_strclass(self.__class__), self.__testMethodName)
+
+ def run(self, result=None):
+ return self(result)
+
+ def __call__(self, result=None):
+ if result is None: result = self.defaultTestResult()
+ result.startTest(self)
+ testMethod = getattr(self, self.__testMethodName)
+ try:
+ try:
+ self.setUp()
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self.__exc_info())
+ return
+
+ ok = 0
+ try:
+ testMethod()
+ ok = 1
+ except self.failureException, e:
+ result.addFailure(self, self.__exc_info())
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self.__exc_info())
+
+ try:
+ self.tearDown()
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self.__exc_info())
+ ok = 0
+ if ok: result.addSuccess(self)
+ finally:
+ result.stopTest(self)
+
+ def debug(self):
+ """Run the test without collecting errors in a TestResult"""
+ self.setUp()
+ getattr(self, self.__testMethodName)()
+ self.tearDown()
+
+ def __exc_info(self):
+ """Return a version of sys.exc_info() with the traceback frame
+ minimised; usually the top level of the traceback frame is not
+ needed.
+ """
+ exctype, excvalue, tb = sys.exc_info()
+ if sys.platform[:4] == 'java': ## tracebacks look different in Jython
+ return (exctype, excvalue, tb)
+ newtb = tb.tb_next
+ if newtb is None:
+ return (exctype, excvalue, tb)
+ return (exctype, excvalue, newtb)
+
+ def fail(self, msg=None):
+ """Fail immediately, with the given message."""
+ raise self.failureException, msg
+
+ def failIf(self, expr, msg=None):
+ "Fail the test if the expression is true."
+ if expr: raise self.failureException, msg
+
+ def failUnless(self, expr, msg=None):
+ """Fail the test unless the expression is true."""
+ if not expr: raise self.failureException, msg
+
+ def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
+ """Fail unless an exception of class excClass is thrown
+ by callableObj when invoked with arguments args and keyword
+ arguments kwargs. If a different type of exception is
+ thrown, it will not be caught, and the test case will be
+ deemed to have suffered an error, exactly as for an
+ unexpected exception.
+ """
+ try:
+ callableObj(*args, **kwargs)
+ except excClass:
+ return
+ else:
+ if hasattr(excClass,'__name__'): excName = excClass.__name__
+ else: excName = str(excClass)
+ raise self.failureException, excName
+
+ def failUnlessEqual(self, first, second, msg=None):
+ """Fail if the two objects are unequal as determined by the '=='
+ operator.
+ """
+ if not first == second:
+ raise self.failureException, \
+ (msg or '%s != %s' % (`first`, `second`))
+
+ def failIfEqual(self, first, second, msg=None):
+ """Fail if the two objects are equal as determined by the '=='
+ operator.
+ """
+ if first == second:
+ raise self.failureException, \
+ (msg or '%s == %s' % (`first`, `second`))
+
+ def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
+ """Fail if the two objects are unequal as determined by their
+ difference rounded to the given number of decimal places
+ (default 7) and comparing to zero.
+
+ Note that decimal places (from zero) is usually not the same
+ as significant digits (measured from the most signficant digit).
+ """
+ if round(second-first, places) != 0:
+ raise self.failureException, \
+ (msg or '%s != %s within %s places' % (`first`, `second`, `places` ))
+
+ def failIfAlmostEqual(self, first, second, places=7, msg=None):
+ """Fail if the two objects are equal as determined by their
+ difference rounded to the given number of decimal places
+ (default 7) and comparing to zero.
+
+ Note that decimal places (from zero) is usually not the same
+ as significant digits (measured from the most signficant digit).
+ """
+ if round(second-first, places) == 0:
+ raise self.failureException, \
+ (msg or '%s == %s within %s places' % (`first`, `second`, `places`))
+
+ assertEqual = assertEquals = failUnlessEqual
+
+ assertNotEqual = assertNotEquals = failIfEqual
+
+ assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
+
+ assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
+
+ assertRaises = failUnlessRaises
+
+ assert_ = failUnless
+
+
+
+class TestSuite:
+ """A test suite is a composite test consisting of a number of TestCases.
+
+ For use, create an instance of TestSuite, then add test case instances.
+ When all tests have been added, the suite can be passed to a test
+ runner, such as TextTestRunner. It will run the individual test cases
+ in the order in which they were added, aggregating the results. When
+ subclassing, do not forget to call the base class constructor.
+ """
+ def __init__(self, tests=()):
+ self._tests = []
+ self.addTests(tests)
+
+ def __repr__(self):
+ return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)
+
+ __str__ = __repr__
+
+ def countTestCases(self):
+ cases = 0
+ for test in self._tests:
+ cases = cases + test.countTestCases()
+ return cases
+
+ def addTest(self, test):
+ self._tests.append(test)
+
+ def addTests(self, tests):
+ for test in tests:
+ self.addTest(test)
+
+ def run(self, result):
+ return self(result)
+
+ def __call__(self, result):
+ for test in self._tests:
+ if result.shouldStop:
+ break
+ test(result)
+ return result
+
+ def debug(self):
+ """Run the tests without collecting errors in a TestResult"""
+ for test in self._tests: test.debug()
+
+
+class FunctionTestCase(TestCase):
+ """A test case that wraps a test function.
+
+ This is useful for slipping pre-existing test functions into the
+ PyUnit framework. Optionally, set-up and tidy-up functions can be
+ supplied. As with TestCase, the tidy-up ('tearDown') function will
+ always be called if the set-up ('setUp') function ran successfully.
+ """
+
+ def __init__(self, testFunc, setUp=None, tearDown=None,
+ description=None):
+ TestCase.__init__(self)
+ self.__setUpFunc = setUp
+ self.__tearDownFunc = tearDown
+ self.__testFunc = testFunc
+ self.__description = description
+
+ def setUp(self):
+ if self.__setUpFunc is not None:
+ self.__setUpFunc()
+
+ def tearDown(self):
+ if self.__tearDownFunc is not None:
+ self.__tearDownFunc()
+
+ def runTest(self):
+ self.__testFunc()
+
+ def id(self):
+ return self.__testFunc.__name__
+
+ def __str__(self):
+ return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
+
+ def __repr__(self):
+ return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc)
+
+ def shortDescription(self):
+ if self.__description is not None: return self.__description
+ doc = self.__testFunc.__doc__
+ return doc and string.strip(string.split(doc, "\n")[0]) or None
+
+
+
+##############################################################################
+# Locating and loading tests
+##############################################################################
+
+class TestLoader:
+ """This class is responsible for loading tests according to various
+ criteria and returning them wrapped in a Test
+ """
+ testMethodPrefix = 'test'
+ sortTestMethodsUsing = cmp
+ suiteClass = TestSuite
+
+ def loadTestsFromTestCase(self, testCaseClass):
+ """Return a suite of all tests cases contained in testCaseClass"""
+ return self.suiteClass(map(testCaseClass,
+ self.getTestCaseNames(testCaseClass)))
+
+ def loadTestsFromModule(self, module):
+ """Return a suite of all tests cases contained in the given module"""
+ tests = []
+ for name in dir(module):
+ obj = getattr(module, name)
+ if (isinstance(obj, (type, types.ClassType)) and
+ issubclass(obj, TestCase)):
+ tests.append(self.loadTestsFromTestCase(obj))
+ return self.suiteClass(tests)
+
+ def loadTestsFromName(self, name, module=None):
+ """Return a suite of all tests cases given a string specifier.
+
+ The name may resolve either to a module, a test case class, a
+ test method within a test case class, or a callable object which
+ returns a TestCase or TestSuite instance.
+
+ The method optionally resolves the names relative to a given module.
+ """
+ parts = string.split(name, '.')
+ if module is None:
+ if not parts:
+ raise ValueError, "incomplete test name: %s" % name
+ else:
+ parts_copy = parts[:]
+ while parts_copy:
+ try:
+ module = __import__(string.join(parts_copy,'.'))
+ break
+ except ImportError:
+ del parts_copy[-1]
+ if not parts_copy: raise
+ parts = parts[1:]
+ obj = module
+ for part in parts:
+ obj = getattr(obj, part)
+
+ import unittest
+ if type(obj) == types.ModuleType:
+ return self.loadTestsFromModule(obj)
+ elif (isinstance(obj, (type, types.ClassType)) and
+ issubclass(obj, unittest.TestCase)):
+ return self.loadTestsFromTestCase(obj)
+ elif type(obj) == types.UnboundMethodType:
+ return obj.im_class(obj.__name__)
+ elif callable(obj):
+ test = obj()
+ if not isinstance(test, unittest.TestCase) and \
+ not isinstance(test, unittest.TestSuite):
+ raise ValueError, \
+ "calling %s returned %s, not a test" % (obj,test)
+ return test
+ else:
+ raise ValueError, "don't know how to make test from: %s" % obj
+
+ def loadTestsFromNames(self, names, module=None):
+ """Return a suite of all tests cases found using the given sequence
+ of string specifiers. See 'loadTestsFromName()'.
+ """
+ suites = []
+ for name in names:
+ suites.append(self.loadTestsFromName(name, module))
+ return self.suiteClass(suites)
+
+ def getTestCaseNames(self, testCaseClass):
+ """Return a sorted sequence of method names found within testCaseClass
+ """
+ testFnNames = filter(lambda n,p=self.testMethodPrefix: n[:len(p)] == p,
+ dir(testCaseClass))
+ for baseclass in testCaseClass.__bases__:
+ for testFnName in self.getTestCaseNames(baseclass):
+ if testFnName not in testFnNames: # handle overridden methods
+ testFnNames.append(testFnName)
+ if self.sortTestMethodsUsing:
+ testFnNames.sort(self.sortTestMethodsUsing)
+ return testFnNames
+
+
+
+defaultTestLoader = TestLoader()
+
+
+##############################################################################
+# Patches for old functions: these functions should be considered obsolete
+##############################################################################
+
+def _makeLoader(prefix, sortUsing, suiteClass=None):
+ loader = TestLoader()
+ loader.sortTestMethodsUsing = sortUsing
+ loader.testMethodPrefix = prefix
+ if suiteClass: loader.suiteClass = suiteClass
+ return loader
+
+def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
+ return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
+
+def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
+ return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
+
+def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
+ return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
+
+
+##############################################################################
+# Text UI
+##############################################################################
+
+class _WritelnDecorator:
+ """Used to decorate file-like objects with a handy 'writeln' method"""
+ def __init__(self,stream):
+ self.stream = stream
+
+ def __getattr__(self, attr):
+ return getattr(self.stream,attr)
+
+ def writeln(self, *args):
+ if args: self.write(*args)
+ self.write('\n') # text-mode streams translate to \r\n if needed
+
+
+class _TextTestResult(TestResult):
+ """A test result class that can print formatted text results to a stream.
+
+ Used by TextTestRunner.
+ """
+ separator1 = '=' * 70
+ separator2 = '-' * 70
+
+ def __init__(self, stream, descriptions, verbosity):
+ TestResult.__init__(self)
+ self.stream = stream
+ self.showAll = verbosity > 1
+ self.dots = verbosity == 1
+ self.descriptions = descriptions
+
+ def getDescription(self, test):
+ if self.descriptions:
+ return test.shortDescription() or str(test)
+ else:
+ return str(test)
+
+ def startTest(self, test):
+ TestResult.startTest(self, test)
+ if self.showAll:
+ self.stream.write(self.getDescription(test))
+ self.stream.write(" ... ")
+
+ def addSuccess(self, test):
+ TestResult.addSuccess(self, test)
+ if self.showAll:
+ self.stream.writeln("ok")
+ elif self.dots:
+ self.stream.write('.')
+
+ def addError(self, test, err):
+ TestResult.addError(self, test, err)
+ if self.showAll:
+ self.stream.writeln("ERROR")
+ elif self.dots:
+ self.stream.write('E')
+
+ def addFailure(self, test, err):
+ TestResult.addFailure(self, test, err)
+ if self.showAll:
+ self.stream.writeln("FAIL")
+ elif self.dots:
+ self.stream.write('F')
+
+ def printErrors(self):
+ if self.dots or self.showAll:
+ self.stream.writeln()
+ self.printErrorList('ERROR', self.errors)
+ self.printErrorList('FAIL', self.failures)
+
+ def printErrorList(self, flavour, errors):
+ for test, err in errors:
+ self.stream.writeln(self.separator1)
+ self.stream.writeln("%s: %s" % (flavour,self.getDescription(test)))
+ self.stream.writeln(self.separator2)
+ self.stream.writeln("%s" % err)
+
+
+class TextTestRunner:
+ """A test runner class that displays results in textual form.
+
+ It prints out the names of tests as they are run, errors as they
+ occur, and a summary of the results at the end of the test run.
+ """
+ def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1):
+ self.stream = _WritelnDecorator(stream)
+ self.descriptions = descriptions
+ self.verbosity = verbosity
+
+ def _makeResult(self):
+ return _TextTestResult(self.stream, self.descriptions, self.verbosity)
+
+ def run(self, test):
+ "Run the given test case or test suite."
+ result = self._makeResult()
+ startTime = time.time()
+ test(result)
+ stopTime = time.time()
+ timeTaken = float(stopTime - startTime)
+ result.printErrors()
+ self.stream.writeln(result.separator2)
+ run = result.testsRun
+ self.stream.writeln("Ran %d test%s in %.3fs" %
+ (run, run != 1 and "s" or "", timeTaken))
+ self.stream.writeln()
+ if not result.wasSuccessful():
+ self.stream.write("FAILED (")
+ failed, errored = map(len, (result.failures, result.errors))
+ if failed:
+ self.stream.write("failures=%d" % failed)
+ if errored:
+ if failed: self.stream.write(", ")
+ self.stream.write("errors=%d" % errored)
+ self.stream.writeln(")")
+ else:
+ self.stream.writeln("OK")
+ return result
+
+
+
+##############################################################################
+# Facilities for running tests from the command line
+##############################################################################
+
+class TestProgram:
+ """A command-line program that runs a set of tests; this is primarily
+ for making test modules conveniently executable.
+ """
+ USAGE = """\
+Usage: %(progName)s [options] [test] [...]
+
+Options:
+ -h, --help Show this message
+ -v, --verbose Verbose output
+ -q, --quiet Minimal output
+
+Examples:
+ %(progName)s - run default set of tests
+ %(progName)s MyTestSuite - run suite 'MyTestSuite'
+ %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething
+ %(progName)s MyTestCase - run all 'test*' test methods
+ in MyTestCase
+"""
+ def __init__(self, module='__main__', defaultTest=None,
+ argv=None, testRunner=None, testLoader=defaultTestLoader):
+ if type(module) == type(''):
+ self.module = __import__(module)
+ for part in string.split(module,'.')[1:]:
+ self.module = getattr(self.module, part)
+ else:
+ self.module = module
+ if argv is None:
+ argv = sys.argv
+ self.verbosity = 1
+ self.defaultTest = defaultTest
+ self.testRunner = testRunner
+ self.testLoader = testLoader
+ self.progName = os.path.basename(argv[0])
+ self.parseArgs(argv)
+ self.runTests()
+
+ def usageExit(self, msg=None):
+ if msg: print msg
+ print self.USAGE % self.__dict__
+ sys.exit(2)
+
+ def parseArgs(self, argv):
+ import getopt
+ try:
+ options, args = getopt.getopt(argv[1:], 'hHvq',
+ ['help','verbose','quiet'])
+ for opt, value in options:
+ if opt in ('-h','-H','--help'):
+ self.usageExit()
+ if opt in ('-q','--quiet'):
+ self.verbosity = 0
+ if opt in ('-v','--verbose'):
+ self.verbosity = 2
+ if len(args) == 0 and self.defaultTest is None:
+ self.test = self.testLoader.loadTestsFromModule(self.module)
+ return
+ if len(args) > 0:
+ self.testNames = args
+ else:
+ self.testNames = (self.defaultTest,)
+ self.createTests()
+ except getopt.error, msg:
+ self.usageExit(msg)
+
+ def createTests(self):
+ self.test = self.testLoader.loadTestsFromNames(self.testNames,
+ self.module)
+
+ def runTests(self):
+ if self.testRunner is None:
+ self.testRunner = TextTestRunner(verbosity=self.verbosity)
+ result = self.testRunner.run(self.test)
+ sys.exit(not result.wasSuccessful())
+
+main = TestProgram
+
+
+##############################################################################
+# Executing this module from the command line
+##############################################################################
+
+if __name__ == "__main__":
+ main(module=None)
diff --git a/python/rpmdb/Makefile.am b/python/rpmdb/Makefile.am
index 11082eb2e..b162bd81f 100644
--- a/python/rpmdb/Makefile.am
+++ b/python/rpmdb/Makefile.am
@@ -4,11 +4,13 @@ AUTOMAKE_OPTIONS = 1.4 foreign
PYVER= @WITH_PYTHON_VERSION@
-pylibdir = $(shell python -c 'import sys; print sys.path[1]')
+SUBDIRS = test
EXTRA_DIST = \
__init__.py dbobj.py db.py dbrecio.py dbshelve.py dbtables.py dbutils.py
+pylibdir = $(shell python -c 'import sys; print sys.path[1]')
+
rpmdbdir = $(pylibdir)/site-packages/rpmdb
rpmdb_SCRIPTS = \
__init__.py dbobj.py db.py dbrecio.py dbshelve.py dbtables.py dbutils.py
diff --git a/python/rpmdb/test/.cvsignore b/python/rpmdb/test/.cvsignore
new file mode 100644
index 000000000..18b1bb4ed
--- /dev/null
+++ b/python/rpmdb/test/.cvsignore
@@ -0,0 +1,6 @@
+Makefile
+Makefile.in
+db_home
+*.la
+*.lo
+*.pyc
diff --git a/python/rpmdb/test/Makefile.am b/python/rpmdb/test/Makefile.am
new file mode 100644
index 000000000..667ad87c8
--- /dev/null
+++ b/python/rpmdb/test/Makefile.am
@@ -0,0 +1,15 @@
+# Makefile for rpm library.
+
+AUTOMAKE_OPTIONS = 1.4 foreign
+
+PYVER= @WITH_PYTHON_VERSION@
+
+pylibdir = $(shell python -c 'import sys; print sys.path[1]')
+
+EXTRA_DIST = \
+ test_all.py test_associate.py test_basics.py test_compat.py \
+ test_dbobj.py test_dbshelve.py test_dbtables.py test_get_none.py \
+ test_join.py test_lock.py test_misc.py test_queue.py test_recno.py \
+ test_thread.py unittest.py
+
+all:
diff --git a/python/rpmdb/test/test_all.py b/python/rpmdb/test/test_all.py
new file mode 100644
index 000000000..ee5300f29
--- /dev/null
+++ b/python/rpmdb/test/test_all.py
@@ -0,0 +1,76 @@
+"""Run all test cases.
+"""
+
+import sys
+import os
+import unittest
+
+verbose = 0
+if 'verbose' in sys.argv:
+ verbose = 1
+ sys.argv.remove('verbose')
+
+if 'silent' in sys.argv: # take care of old flag, just in case
+ verbose = 0
+ sys.argv.remove('silent')
+
+
+def print_versions():
+ from rpmdb import db
+ print
+ print '-=' * 38
+ print db.DB_VERSION_STRING
+ print 'rpmdb.db.version(): %s' % (db.version(), )
+ print 'rpmdb.db.__version__: %s' % db.__version__
+ print 'rpmdb.db.cvsid: %s' % db.cvsid
+ print 'python version: %s' % sys.version
+ print 'My pid: %s' % os.getpid()
+ print '-=' * 38
+
+
+class PrintInfoFakeTest(unittest.TestCase):
+ def testPrintVersions(self):
+ print_versions()
+
+
+# This little hack is for when this module is run as main and all the
+# other modules import it so they will still be able to get the right
+# verbose setting. It's confusing but it works.
+import test_all
+test_all.verbose = verbose
+
+
+def suite():
+ test_modules = [
+ 'test_associate',
+ 'test_basics',
+ 'test_compat',
+ 'test_dbobj',
+ 'test_dbshelve',
+ 'test_dbtables',
+ 'test_env_close',
+ 'test_get_none',
+ 'test_join',
+ 'test_lock',
+ 'test_misc',
+ 'test_queue',
+ 'test_recno',
+ 'test_thread',
+ ]
+
+ alltests = unittest.TestSuite()
+ for name in test_modules:
+ module = __import__(name)
+ alltests.addTest(module.test_suite())
+ return alltests
+
+
+def test_suite():
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(PrintInfoFakeTest))
+ return suite
+
+
+if __name__ == '__main__':
+ print_versions()
+ unittest.main(defaultTest='suite')
diff --git a/python/rpmdb/test/test_associate.py b/python/rpmdb/test/test_associate.py
new file mode 100644
index 000000000..be6ef610d
--- /dev/null
+++ b/python/rpmdb/test/test_associate.py
@@ -0,0 +1,321 @@
+"""
+TestCases for multi-threaded access to a DB.
+"""
+
+import sys, os, string
+import tempfile
+import time
+from pprint import pprint
+
+try:
+ from threading import Thread, currentThread
+ have_threads = 1
+except ImportError:
+ have_threads = 0
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db, dbshelve
+
+
+#----------------------------------------------------------------------
+
+
+musicdata = {
+1 : ("Bad English", "The Price Of Love", "Rock"),
+2 : ("DNA featuring Suzanne Vega", "Tom's Diner", "Rock"),
+3 : ("George Michael", "Praying For Time", "Rock"),
+4 : ("Gloria Estefan", "Here We Are", "Rock"),
+5 : ("Linda Ronstadt", "Don't Know Much", "Rock"),
+6 : ("Michael Bolton", "How Am I Supposed To Live Without You", "Blues"),
+7 : ("Paul Young", "Oh Girl", "Rock"),
+8 : ("Paula Abdul", "Opposites Attract", "Rock"),
+9 : ("Richard Marx", "Should've Known Better", "Rock"),
+10: ("Rod Stewart", "Forever Young", "Rock"),
+11: ("Roxette", "Dangerous", "Rock"),
+12: ("Sheena Easton", "The Lover In Me", "Rock"),
+13: ("Sinead O'Connor", "Nothing Compares 2 U", "Rock"),
+14: ("Stevie B.", "Because I Love You", "Rock"),
+15: ("Taylor Dayne", "Love Will Lead You Back", "Rock"),
+16: ("The Bangles", "Eternal Flame", "Rock"),
+17: ("Wilson Phillips", "Release Me", "Rock"),
+18: ("Billy Joel", "Blonde Over Blue", "Rock"),
+19: ("Billy Joel", "Famous Last Words", "Rock"),
+20: ("Billy Joel", "Lullabye (Goodnight, My Angel)", "Rock"),
+21: ("Billy Joel", "The River Of Dreams", "Rock"),
+22: ("Billy Joel", "Two Thousand Years", "Rock"),
+23: ("Janet Jackson", "Alright", "Rock"),
+24: ("Janet Jackson", "Black Cat", "Rock"),
+25: ("Janet Jackson", "Come Back To Me", "Rock"),
+26: ("Janet Jackson", "Escapade", "Rock"),
+27: ("Janet Jackson", "Love Will Never Do (Without You)", "Rock"),
+28: ("Janet Jackson", "Miss You Much", "Rock"),
+29: ("Janet Jackson", "Rhythm Nation", "Rock"),
+30: ("Janet Jackson", "State Of The World", "Rock"),
+31: ("Janet Jackson", "The Knowledge", "Rock"),
+32: ("Spyro Gyra", "End of Romanticism", "Jazz"),
+33: ("Spyro Gyra", "Heliopolis", "Jazz"),
+34: ("Spyro Gyra", "Jubilee", "Jazz"),
+35: ("Spyro Gyra", "Little Linda", "Jazz"),
+36: ("Spyro Gyra", "Morning Dance", "Jazz"),
+37: ("Spyro Gyra", "Song for Lorraine", "Jazz"),
+38: ("Yes", "Owner Of A Lonely Heart", "Rock"),
+39: ("Yes", "Rhythm Of Love", "Rock"),
+40: ("Cusco", "Dream Catcher", "New Age"),
+41: ("Cusco", "Geronimos Laughter", "New Age"),
+42: ("Cusco", "Ghost Dance", "New Age"),
+43: ("Blue Man Group", "Drumbone", "New Age"),
+44: ("Blue Man Group", "Endless Column", "New Age"),
+45: ("Blue Man Group", "Klein Mandelbrot", "New Age"),
+46: ("Kenny G", "Silhouette", "Jazz"),
+47: ("Sade", "Smooth Operator", "Jazz"),
+48: ("David Arkenstone", "Papillon (On The Wings Of The Butterfly)",
+ "New Age"),
+49: ("David Arkenstone", "Stepping Stars", "New Age"),
+50: ("David Arkenstone", "Carnation Lily Lily Rose", "New Age"),
+51: ("David Lanz", "Behind The Waterfall", "New Age"),
+52: ("David Lanz", "Cristofori's Dream", "New Age"),
+53: ("David Lanz", "Heartsounds", "New Age"),
+54: ("David Lanz", "Leaves on the Seine", "New Age"),
+}
+
+#----------------------------------------------------------------------
+
+
+class AssociateTestCase(unittest.TestCase):
+ keytype = ''
+
+ def setUp(self):
+ self.filename = self.__class__.__name__ + '.db'
+ homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ self.homeDir = homeDir
+ try: os.mkdir(homeDir)
+ except os.error: pass
+ self.env = db.DBEnv()
+ self.env.open(homeDir, db.DB_CREATE | db.DB_INIT_MPOOL |
+ db.DB_INIT_LOCK | db.DB_THREAD)
+
+ def tearDown(self):
+ self.closeDB()
+ self.env.close()
+ import glob
+ files = glob.glob(os.path.join(self.homeDir, '*'))
+ for file in files:
+ os.remove(file)
+
+ def addDataToDB(self, d):
+ for key, value in musicdata.items():
+ if type(self.keytype) == type(''):
+ key = "%02d" % key
+ d.put(key, string.join(value, '|'))
+
+ def createDB(self):
+ self.primary = db.DB(self.env)
+ self.primary.open(self.filename, "primary", self.dbtype,
+ db.DB_CREATE | db.DB_THREAD)
+
+ def closeDB(self):
+ self.primary.close()
+
+ def getDB(self):
+ return self.primary
+
+ def test01_associateWithDB(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test01_associateWithDB..." % \
+ self.__class__.__name__
+
+ self.createDB()
+
+ secDB = db.DB(self.env)
+ secDB.set_flags(db.DB_DUP)
+ secDB.open(self.filename, "secondary", db.DB_BTREE,
+ db.DB_CREATE | db.DB_THREAD)
+ self.getDB().associate(secDB, self.getGenre)
+
+ self.addDataToDB(self.getDB())
+
+ self.finish_test(secDB)
+
+
+ def test02_associateAfterDB(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test02_associateAfterDB..." % \
+ self.__class__.__name__
+
+ self.createDB()
+ self.addDataToDB(self.getDB())
+
+ secDB = db.DB(self.env)
+ secDB.set_flags(db.DB_DUP)
+ secDB.open(self.filename, "secondary", db.DB_BTREE,
+ db.DB_CREATE | db.DB_THREAD)
+
+ # adding the DB_CREATE flag will cause it to index existing records
+ self.getDB().associate(secDB, self.getGenre, db.DB_CREATE)
+
+ self.finish_test(secDB)
+
+
+ def finish_test(self, secDB):
+ if verbose:
+ print "Primary key traversal:"
+ c = self.getDB().cursor()
+ count = 0
+ rec = c.first()
+ while rec is not None:
+ if type(self.keytype) == type(''):
+ assert string.atoi(rec[0]) # for primary db, key is a number
+ else:
+ assert rec[0] and type(rec[0]) == type(0)
+ count = count + 1
+ if verbose:
+ print rec
+ rec = c.next()
+ assert count == len(musicdata) # all items accounted for
+
+
+ if verbose:
+ print "Secondary key traversal:"
+ c = secDB.cursor()
+ count = 0
+ rec = c.first()
+ assert rec[0] == "Jazz"
+ while rec is not None:
+ count = count + 1
+ if verbose:
+ print rec
+ rec = c.next()
+ # all items accounted for EXCEPT for 1 with "Blues" genre
+ assert count == len(musicdata)-1
+
+ def getGenre(self, priKey, priData):
+ assert type(priData) == type("")
+ if verbose:
+ print 'getGenre key:', `priKey`, 'data:', `priData`
+ genre = string.split(priData, '|')[2]
+ if genre == 'Blues':
+ return db.DB_DONOTINDEX
+ else:
+ return genre
+
+
+#----------------------------------------------------------------------
+
+
+class AssociateHashTestCase(AssociateTestCase):
+ dbtype = db.DB_HASH
+
+class AssociateBTreeTestCase(AssociateTestCase):
+ dbtype = db.DB_BTREE
+
+class AssociateRecnoTestCase(AssociateTestCase):
+ dbtype = db.DB_RECNO
+ keytype = 0
+
+
+#----------------------------------------------------------------------
+
+class ShelveAssociateTestCase(AssociateTestCase):
+
+ def createDB(self):
+ self.primary = dbshelve.open(self.filename,
+ dbname="primary",
+ dbenv=self.env,
+ filetype=self.dbtype)
+
+ def addDataToDB(self, d):
+ for key, value in musicdata.items():
+ if type(self.keytype) == type(''):
+ key = "%02d" % key
+ d.put(key, value) # save the value as is this time
+
+
+ def getGenre(self, priKey, priData):
+ assert type(priData) == type(())
+ if verbose:
+ print 'getGenre key:', `priKey`, 'data:', `priData`
+ genre = priData[2]
+ if genre == 'Blues':
+ return db.DB_DONOTINDEX
+ else:
+ return genre
+
+
+class ShelveAssociateHashTestCase(ShelveAssociateTestCase):
+ dbtype = db.DB_HASH
+
+class ShelveAssociateBTreeTestCase(ShelveAssociateTestCase):
+ dbtype = db.DB_BTREE
+
+class ShelveAssociateRecnoTestCase(ShelveAssociateTestCase):
+ dbtype = db.DB_RECNO
+ keytype = 0
+
+
+#----------------------------------------------------------------------
+
+class ThreadedAssociateTestCase(AssociateTestCase):
+
+ def addDataToDB(self, d):
+ t1 = Thread(target = self.writer1,
+ args = (d, ))
+ t2 = Thread(target = self.writer2,
+ args = (d, ))
+
+ t1.start()
+ t2.start()
+ t1.join()
+ t2.join()
+
+ def writer1(self, d):
+ for key, value in musicdata.items():
+ if type(self.keytype) == type(''):
+ key = "%02d" % key
+ d.put(key, string.join(value, '|'))
+
+ def writer2(self, d):
+ for x in range(100, 600):
+ key = 'z%2d' % x
+ value = [key] * 4
+ d.put(key, string.join(value, '|'))
+
+
+class ThreadedAssociateHashTestCase(ShelveAssociateTestCase):
+ dbtype = db.DB_HASH
+
+class ThreadedAssociateBTreeTestCase(ShelveAssociateTestCase):
+ dbtype = db.DB_BTREE
+
+class ThreadedAssociateRecnoTestCase(ShelveAssociateTestCase):
+ dbtype = db.DB_RECNO
+ keytype = 0
+
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ suite = unittest.TestSuite()
+
+ if db.version() >= (3, 3, 11):
+ suite.addTest(unittest.makeSuite(AssociateHashTestCase))
+ suite.addTest(unittest.makeSuite(AssociateBTreeTestCase))
+ suite.addTest(unittest.makeSuite(AssociateRecnoTestCase))
+
+ suite.addTest(unittest.makeSuite(ShelveAssociateHashTestCase))
+ suite.addTest(unittest.makeSuite(ShelveAssociateBTreeTestCase))
+ suite.addTest(unittest.makeSuite(ShelveAssociateRecnoTestCase))
+
+ if have_threads:
+ suite.addTest(unittest.makeSuite(ThreadedAssociateHashTestCase))
+ suite.addTest(unittest.makeSuite(ThreadedAssociateBTreeTestCase))
+ suite.addTest(unittest.makeSuite(ThreadedAssociateRecnoTestCase))
+
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_basics.py b/python/rpmdb/test/test_basics.py
new file mode 100644
index 000000000..e35002b3f
--- /dev/null
+++ b/python/rpmdb/test/test_basics.py
@@ -0,0 +1,882 @@
+"""
+Basic TestCases for BTree and hash DBs, with and without a DBEnv, with
+various DB flags, etc.
+"""
+
+import os
+import sys
+import errno
+import shutil
+import string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+DASH = '-'
+
+
+#----------------------------------------------------------------------
+
+class VersionTestCase(unittest.TestCase):
+ def test00_version(self):
+ info = db.version()
+ if verbose:
+ print '\n', '-=' * 20
+ print 'rpmdb.db.version(): %s' % (info, )
+ print db.DB_VERSION_STRING
+ print '-=' * 20
+ assert info == (db.DB_VERSION_MAJOR, db.DB_VERSION_MINOR,
+ db.DB_VERSION_PATCH)
+
+#----------------------------------------------------------------------
+
+class BasicTestCase(unittest.TestCase):
+ dbtype = db.DB_UNKNOWN # must be set in derived class
+ dbopenflags = 0
+ dbsetflags = 0
+ dbmode = 0660
+ dbname = None
+ useEnv = 0
+ envflags = 0
+ envsetflags = 0
+
+ def setUp(self):
+ if self.useEnv:
+ homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ self.homeDir = homeDir
+ try:
+ shutil.rmtree(homeDir)
+ except OSError, e:
+ # unix returns ENOENT, windows returns ESRCH
+ if e.errno not in (errno.ENOENT, errno.ESRCH): raise
+ os.mkdir(homeDir)
+ try:
+ self.env = db.DBEnv()
+ self.env.set_lg_max(1024*1024)
+ self.env.set_flags(self.envsetflags, 1)
+ self.env.open(homeDir, self.envflags | db.DB_CREATE)
+ tempfile.tempdir = homeDir
+ self.filename = os.path.split(tempfile.mktemp())[1]
+ tempfile.tempdir = None
+ # Yes, a bare except is intended, since we're re-raising the exc.
+ except:
+ shutil.rmtree(homeDir)
+ raise
+ else:
+ self.env = None
+ self.filename = tempfile.mktemp()
+
+ # create and open the DB
+ self.d = db.DB(self.env)
+ self.d.set_flags(self.dbsetflags)
+ if self.dbname:
+ self.d.open(self.filename, self.dbname, self.dbtype,
+ self.dbopenflags|db.DB_CREATE, self.dbmode)
+ else:
+ self.d.open(self.filename, # try out keyword args
+ mode = self.dbmode,
+ dbtype = self.dbtype,
+ flags = self.dbopenflags|db.DB_CREATE)
+
+ self.populateDB()
+
+
+ def tearDown(self):
+ self.d.close()
+ if self.env is not None:
+ self.env.close()
+ shutil.rmtree(self.homeDir)
+ ## Make a new DBEnv to remove the env files from the home dir.
+ ## (It can't be done while the env is open, nor after it has been
+ ## closed, so we make a new one to do it.)
+ #e = db.DBEnv()
+ #e.remove(self.homeDir)
+ #os.remove(os.path.join(self.homeDir, self.filename))
+ else:
+ os.remove(self.filename)
+
+
+
+ def populateDB(self):
+ d = self.d
+ for x in range(500):
+ key = '%04d' % (1000 - x) # insert keys in reverse order
+ data = self.makeData(key)
+ d.put(key, data)
+
+ for x in range(500):
+ key = '%04d' % x # and now some in forward order
+ data = self.makeData(key)
+ d.put(key, data)
+
+ num = len(d)
+ if verbose:
+ print "created %d records" % num
+
+
+ def makeData(self, key):
+ return DASH.join([key] * 5)
+
+
+
+ #----------------------------------------
+
+ def test01_GetsAndPuts(self):
+ d = self.d
+
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test01_GetsAndPuts..." % self.__class__.__name__
+
+ for key in ['0001', '0100', '0400', '0700', '0999']:
+ data = d.get(key)
+ if verbose:
+ print data
+
+ assert d.get('0321') == '0321-0321-0321-0321-0321'
+
+ # By default non-existant keys return None...
+ assert d.get('abcd') == None
+
+ # ...but they raise exceptions in other situations. Call
+ # set_get_returns_none() to change it.
+ try:
+ d.delete('abcd')
+ except db.DBNotFoundError, val:
+ assert val[0] == db.DB_NOTFOUND
+ if verbose: print val
+ else:
+ self.fail("expected exception")
+
+
+ d.put('abcd', 'a new record')
+ assert d.get('abcd') == 'a new record'
+
+ d.put('abcd', 'same key')
+ if self.dbsetflags & db.DB_DUP:
+ assert d.get('abcd') == 'a new record'
+ else:
+ assert d.get('abcd') == 'same key'
+
+
+ try:
+ d.put('abcd', 'this should fail', flags=db.DB_NOOVERWRITE)
+ except db.DBKeyExistError, val:
+ assert val[0] == db.DB_KEYEXIST
+ if verbose: print val
+ else:
+ self.fail("expected exception")
+
+ if self.dbsetflags & db.DB_DUP:
+ assert d.get('abcd') == 'a new record'
+ else:
+ assert d.get('abcd') == 'same key'
+
+
+ d.sync()
+ d.close()
+ del d
+
+ self.d = db.DB(self.env)
+ if self.dbname:
+ self.d.open(self.filename, self.dbname)
+ else:
+ self.d.open(self.filename)
+ d = self.d
+
+ assert d.get('0321') == '0321-0321-0321-0321-0321'
+ if self.dbsetflags & db.DB_DUP:
+ assert d.get('abcd') == 'a new record'
+ else:
+ assert d.get('abcd') == 'same key'
+
+ rec = d.get_both('0555', '0555-0555-0555-0555-0555')
+ if verbose:
+ print rec
+
+ assert d.get_both('0555', 'bad data') == None
+
+ # test default value
+ data = d.get('bad key', 'bad data')
+ assert data == 'bad data'
+
+ # any object can pass through
+ data = d.get('bad key', self)
+ assert data == self
+
+ s = d.stat()
+ assert type(s) == type({})
+ if verbose:
+ print 'd.stat() returned this dictionary:'
+ pprint(s)
+
+
+ #----------------------------------------
+
+ def test02_DictionaryMethods(self):
+ d = self.d
+
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test02_DictionaryMethods..." % \
+ self.__class__.__name__
+
+ for key in ['0002', '0101', '0401', '0701', '0998']:
+ data = d[key]
+ assert data == self.makeData(key)
+ if verbose:
+ print data
+
+ assert len(d) == 1000
+ keys = d.keys()
+ assert len(keys) == 1000
+ assert type(keys) == type([])
+
+ d['new record'] = 'a new record'
+ assert len(d) == 1001
+ keys = d.keys()
+ assert len(keys) == 1001
+
+ d['new record'] = 'a replacement record'
+ assert len(d) == 1001
+ keys = d.keys()
+ assert len(keys) == 1001
+
+ if verbose:
+ print "the first 10 keys are:"
+ pprint(keys[:10])
+
+ assert d['new record'] == 'a replacement record'
+
+ assert d.has_key('0001') == 1
+ assert d.has_key('spam') == 0
+
+ items = d.items()
+ assert len(items) == 1001
+ assert type(items) == type([])
+ assert type(items[0]) == type(())
+ assert len(items[0]) == 2
+
+ if verbose:
+ print "the first 10 items are:"
+ pprint(items[:10])
+
+ values = d.values()
+ assert len(values) == 1001
+ assert type(values) == type([])
+
+ if verbose:
+ print "the first 10 values are:"
+ pprint(values[:10])
+
+
+
+ #----------------------------------------
+
+ def test03_SimpleCursorStuff(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test03_SimpleCursorStuff..." % \
+ self.__class__.__name__
+
+ if self.env and self.dbopenflags & db.DB_AUTO_COMMIT:
+ txn = self.env.txn_begin()
+ else:
+ txn = None
+ c = self.d.cursor(txn=txn)
+
+ rec = c.first()
+ count = 0
+ while rec is not None:
+ count = count + 1
+ if verbose and count % 100 == 0:
+ print rec
+ rec = c.next()
+
+ assert count == 1000
+
+
+ rec = c.last()
+ count = 0
+ while rec is not None:
+ count = count + 1
+ if verbose and count % 100 == 0:
+ print rec
+ rec = c.prev()
+
+ assert count == 1000
+
+ rec = c.set('0505')
+ rec2 = c.current()
+ assert rec == rec2
+ assert rec[0] == '0505'
+ assert rec[1] == self.makeData('0505')
+
+ try:
+ c.set('bad key')
+ except db.DBNotFoundError, val:
+ assert val[0] == db.DB_NOTFOUND
+ if verbose: print val
+ else:
+ self.fail("expected exception")
+
+ rec = c.get_both('0404', self.makeData('0404'))
+ assert rec == ('0404', self.makeData('0404'))
+
+ try:
+ c.get_both('0404', 'bad data')
+ except db.DBNotFoundError, val:
+ assert val[0] == db.DB_NOTFOUND
+ if verbose: print val
+ else:
+ self.fail("expected exception")
+
+ if self.d.get_type() == db.DB_BTREE:
+ rec = c.set_range('011')
+ if verbose:
+ print "searched for '011', found: ", rec
+
+ rec = c.set_range('011',dlen=0,doff=0)
+ if verbose:
+ print "searched (partial) for '011', found: ", rec
+ if rec[1] != '': set.fail('expected empty data portion')
+
+ c.set('0499')
+ c.delete()
+ try:
+ rec = c.current()
+ except db.DBKeyEmptyError, val:
+ assert val[0] == db.DB_KEYEMPTY
+ if verbose: print val
+ else:
+ self.fail('exception expected')
+
+ c.next()
+ c2 = c.dup(db.DB_POSITION)
+ assert c.current() == c2.current()
+
+ c2.put('', 'a new value', db.DB_CURRENT)
+ assert c.current() == c2.current()
+ assert c.current()[1] == 'a new value'
+
+ c2.put('', 'er', db.DB_CURRENT, dlen=0, doff=5)
+ assert c2.current()[1] == 'a newer value'
+
+ c.close()
+ c2.close()
+ if txn:
+ txn.commit()
+
+ # time to abuse the closed cursors and hope we don't crash
+ methods_to_test = {
+ 'current': (),
+ 'delete': (),
+ 'dup': (db.DB_POSITION,),
+ 'first': (),
+ 'get': (0,),
+ 'next': (),
+ 'prev': (),
+ 'last': (),
+ 'put':('', 'spam', db.DB_CURRENT),
+ 'set': ("0505",),
+ }
+ for method, args in methods_to_test.items():
+ try:
+ if verbose:
+ print "attempting to use a closed cursor's %s method" % \
+ method
+ # a bug may cause a NULL pointer dereference...
+ apply(getattr(c, method), args)
+ except db.DBError, val:
+ assert val[0] == 0
+ if verbose: print val
+ else:
+ self.fail("no exception raised when using a buggy cursor's"
+ "%s method" % method)
+
+ #
+ # free cursor referencing a closed database, it should not barf:
+ #
+ oldcursor = self.d.cursor(txn=txn)
+ self.d.close()
+
+ # this would originally cause a segfault when the cursor for a
+ # closed database was cleaned up. it should not anymore.
+ # SF pybsddb bug id 667343
+ del oldcursor
+
+
+ #----------------------------------------
+
+ def test04_PartialGetAndPut(self):
+ d = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test04_PartialGetAndPut..." % \
+ self.__class__.__name__
+
+ key = "partialTest"
+ data = "1" * 1000 + "2" * 1000
+ d.put(key, data)
+ assert d.get(key) == data
+ assert d.get(key, dlen=20, doff=990) == ("1" * 10) + ("2" * 10)
+
+ d.put("partialtest2", ("1" * 30000) + "robin" )
+ assert d.get("partialtest2", dlen=5, doff=30000) == "robin"
+
+ # There seems to be a bug in DB here... Commented out the test for
+ # now.
+ ##assert d.get("partialtest2", dlen=5, doff=30010) == ""
+
+ if self.dbsetflags != db.DB_DUP:
+ # Partial put with duplicate records requires a cursor
+ d.put(key, "0000", dlen=2000, doff=0)
+ assert d.get(key) == "0000"
+
+ d.put(key, "1111", dlen=1, doff=2)
+ assert d.get(key) == "0011110"
+
+ #----------------------------------------
+
+ def test05_GetSize(self):
+ d = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test05_GetSize..." % self.__class__.__name__
+
+ for i in range(1, 50000, 500):
+ key = "size%s" % i
+ #print "before ", i,
+ d.put(key, "1" * i)
+ #print "after",
+ assert d.get_size(key) == i
+ #print "done"
+
+ #----------------------------------------
+
+ def test06_Truncate(self):
+ if db.version() < (3,3):
+ # truncate is a feature of BerkeleyDB 3.3 and above
+ return
+
+ d = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test99_Truncate..." % self.__class__.__name__
+
+ d.put("abcde", "ABCDE");
+ num = d.truncate()
+ assert num >= 1, "truncate returned <= 0 on non-empty database"
+ num = d.truncate()
+ assert num == 0, "truncate on empty DB returned nonzero (%s)" % `num`
+
+#----------------------------------------------------------------------
+
+
+class BasicBTreeTestCase(BasicTestCase):
+ dbtype = db.DB_BTREE
+
+
+class BasicHashTestCase(BasicTestCase):
+ dbtype = db.DB_HASH
+
+
+class BasicBTreeWithThreadFlagTestCase(BasicTestCase):
+ dbtype = db.DB_BTREE
+ dbopenflags = db.DB_THREAD
+
+
+class BasicHashWithThreadFlagTestCase(BasicTestCase):
+ dbtype = db.DB_HASH
+ dbopenflags = db.DB_THREAD
+
+
+class BasicBTreeWithEnvTestCase(BasicTestCase):
+ dbtype = db.DB_BTREE
+ dbopenflags = db.DB_THREAD
+ useEnv = 1
+ envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+
+class BasicHashWithEnvTestCase(BasicTestCase):
+ dbtype = db.DB_HASH
+ dbopenflags = db.DB_THREAD
+ useEnv = 1
+ envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+
+#----------------------------------------------------------------------
+
+class BasicTransactionTestCase(BasicTestCase):
+ dbopenflags = db.DB_THREAD | db.DB_AUTO_COMMIT
+ useEnv = 1
+ envflags = (db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK |
+ db.DB_INIT_TXN)
+ envsetflags = db.DB_AUTO_COMMIT
+
+
+ def tearDown(self):
+ self.txn.commit()
+ BasicTestCase.tearDown(self)
+
+
+ def populateDB(self):
+ d = self.d
+ txn = self.env.txn_begin()
+ for x in range(500):
+ key = '%04d' % (1000 - x) # insert keys in reverse order
+ data = self.makeData(key)
+ d.put(key, data, txn)
+
+ for x in range(500):
+ key = '%04d' % x # and now some in forward order
+ data = self.makeData(key)
+ d.put(key, data, txn)
+
+ txn.commit()
+
+ num = len(d)
+ if verbose:
+ print "created %d records" % num
+
+ self.txn = self.env.txn_begin()
+
+
+
+ def test06_Transactions(self):
+ d = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test06_Transactions..." % self.__class__.__name__
+
+ assert d.get('new rec', txn=self.txn) == None
+ d.put('new rec', 'this is a new record', self.txn)
+ assert d.get('new rec', txn=self.txn) == 'this is a new record'
+ self.txn.abort()
+ assert d.get('new rec') == None
+
+ self.txn = self.env.txn_begin()
+
+ assert d.get('new rec', txn=self.txn) == None
+ d.put('new rec', 'this is a new record', self.txn)
+ assert d.get('new rec', txn=self.txn) == 'this is a new record'
+ self.txn.commit()
+ assert d.get('new rec') == 'this is a new record'
+
+ self.txn = self.env.txn_begin()
+ c = d.cursor(self.txn)
+ rec = c.first()
+ count = 0
+ while rec is not None:
+ count = count + 1
+ if verbose and count % 100 == 0:
+ print rec
+ rec = c.next()
+ assert count == 1001
+
+ c.close() # Cursors *MUST* be closed before commit!
+ self.txn.commit()
+
+ # flush pending updates
+ try:
+ self.env.txn_checkpoint (0, 0, 0)
+ except db.DBIncompleteError:
+ pass
+
+ # must have at least one log file present:
+ logs = self.env.log_archive(db.DB_ARCH_ABS | db.DB_ARCH_LOG)
+ assert logs != None
+ for log in logs:
+ if verbose:
+ print 'log file: ' + log
+
+ self.txn = self.env.txn_begin()
+
+ #----------------------------------------
+
+ def test07_TxnTruncate(self):
+ if db.version() < (3,3):
+ # truncate is a feature of BerkeleyDB 3.3 and above
+ return
+
+ d = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test07_TxnTruncate..." % self.__class__.__name__
+
+ d.put("abcde", "ABCDE");
+ txn = self.env.txn_begin()
+ num = d.truncate(txn)
+ assert num >= 1, "truncate returned <= 0 on non-empty database"
+ num = d.truncate(txn)
+ assert num == 0, "truncate on empty DB returned nonzero (%s)" % `num`
+ txn.commit()
+
+ #----------------------------------------
+
+ def test08_TxnLateUse(self):
+ txn = self.env.txn_begin()
+ txn.abort()
+ try:
+ txn.abort()
+ except db.DBError, e:
+ pass
+ else:
+ raise RuntimeError, "DBTxn.abort() called after DB_TXN no longer valid w/o an exception"
+
+ txn = self.env.txn_begin()
+ txn.commit()
+ try:
+ txn.commit()
+ except db.DBError, e:
+ pass
+ else:
+ raise RuntimeError, "DBTxn.commit() called after DB_TXN no longer valid w/o an exception"
+
+
+class BTreeTransactionTestCase(BasicTransactionTestCase):
+ dbtype = db.DB_BTREE
+
+class HashTransactionTestCase(BasicTransactionTestCase):
+ dbtype = db.DB_HASH
+
+
+
+#----------------------------------------------------------------------
+
+class BTreeRecnoTestCase(BasicTestCase):
+ dbtype = db.DB_BTREE
+ dbsetflags = db.DB_RECNUM
+
+ def test07_RecnoInBTree(self):
+ d = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test07_RecnoInBTree..." % self.__class__.__name__
+
+ rec = d.get(200)
+ assert type(rec) == type(())
+ assert len(rec) == 2
+ if verbose:
+ print "Record #200 is ", rec
+
+ c = d.cursor()
+ c.set('0200')
+ num = c.get_recno()
+ assert type(num) == type(1)
+ if verbose:
+ print "recno of d['0200'] is ", num
+
+ rec = c.current()
+ assert c.set_recno(num) == rec
+
+ c.close()
+
+
+
+class BTreeRecnoWithThreadFlagTestCase(BTreeRecnoTestCase):
+ dbopenflags = db.DB_THREAD
+
+#----------------------------------------------------------------------
+
+class BasicDUPTestCase(BasicTestCase):
+ dbsetflags = db.DB_DUP
+
+ def test08_DuplicateKeys(self):
+ d = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test08_DuplicateKeys..." % \
+ self.__class__.__name__
+
+ d.put("dup0", "before")
+ for x in "The quick brown fox jumped over the lazy dog.".split():
+ d.put("dup1", x)
+ d.put("dup2", "after")
+
+ data = d.get("dup1")
+ assert data == "The"
+ if verbose:
+ print data
+
+ c = d.cursor()
+ rec = c.set("dup1")
+ assert rec == ('dup1', 'The')
+
+ next = c.next()
+ assert next == ('dup1', 'quick')
+
+ rec = c.set("dup1")
+ count = c.count()
+ assert count == 9
+
+ next_dup = c.next_dup()
+ assert next_dup == ('dup1', 'quick')
+
+ rec = c.set('dup1')
+ while rec is not None:
+ if verbose:
+ print rec
+ rec = c.next_dup()
+
+ c.set('dup1')
+ rec = c.next_nodup()
+ assert rec[0] != 'dup1'
+ if verbose:
+ print rec
+
+ c.close()
+
+
+
+class BTreeDUPTestCase(BasicDUPTestCase):
+ dbtype = db.DB_BTREE
+
+class HashDUPTestCase(BasicDUPTestCase):
+ dbtype = db.DB_HASH
+
+class BTreeDUPWithThreadTestCase(BasicDUPTestCase):
+ dbtype = db.DB_BTREE
+ dbopenflags = db.DB_THREAD
+
+class HashDUPWithThreadTestCase(BasicDUPTestCase):
+ dbtype = db.DB_HASH
+ dbopenflags = db.DB_THREAD
+
+
+#----------------------------------------------------------------------
+
+class BasicMultiDBTestCase(BasicTestCase):
+ dbname = 'first'
+
+ def otherType(self):
+ if self.dbtype == db.DB_BTREE:
+ return db.DB_HASH
+ else:
+ return db.DB_BTREE
+
+ def test09_MultiDB(self):
+ d1 = self.d
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test09_MultiDB..." % self.__class__.__name__
+
+ d2 = db.DB(self.env)
+ d2.open(self.filename, "second", self.dbtype,
+ self.dbopenflags|db.DB_CREATE)
+ d3 = db.DB(self.env)
+ d3.open(self.filename, "third", self.otherType(),
+ self.dbopenflags|db.DB_CREATE)
+
+ for x in "The quick brown fox jumped over the lazy dog".split():
+ d2.put(x, self.makeData(x))
+
+ for x in string.letters:
+ d3.put(x, x*70)
+
+ d1.sync()
+ d2.sync()
+ d3.sync()
+ d1.close()
+ d2.close()
+ d3.close()
+
+ self.d = d1 = d2 = d3 = None
+
+ self.d = d1 = db.DB(self.env)
+ d1.open(self.filename, self.dbname, flags = self.dbopenflags)
+ d2 = db.DB(self.env)
+ d2.open(self.filename, "second", flags = self.dbopenflags)
+ d3 = db.DB(self.env)
+ d3.open(self.filename, "third", flags = self.dbopenflags)
+
+ c1 = d1.cursor()
+ c2 = d2.cursor()
+ c3 = d3.cursor()
+
+ count = 0
+ rec = c1.first()
+ while rec is not None:
+ count = count + 1
+ if verbose and (count % 50) == 0:
+ print rec
+ rec = c1.next()
+ assert count == 1000
+
+ count = 0
+ rec = c2.first()
+ while rec is not None:
+ count = count + 1
+ if verbose:
+ print rec
+ rec = c2.next()
+ assert count == 9
+
+ count = 0
+ rec = c3.first()
+ while rec is not None:
+ count = count + 1
+ if verbose:
+ print rec
+ rec = c3.next()
+ assert count == 52
+
+
+ c1.close()
+ c2.close()
+ c3.close()
+
+ d2.close()
+ d3.close()
+
+
+
+# Strange things happen if you try to use Multiple DBs per file without a
+# DBEnv with MPOOL and LOCKing...
+
+class BTreeMultiDBTestCase(BasicMultiDBTestCase):
+ dbtype = db.DB_BTREE
+ dbopenflags = db.DB_THREAD
+ useEnv = 1
+ envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+class HashMultiDBTestCase(BasicMultiDBTestCase):
+ dbtype = db.DB_HASH
+ dbopenflags = db.DB_THREAD
+ useEnv = 1
+ envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+
+#----------------------------------------------------------------------
+#----------------------------------------------------------------------
+
+def test_suite():
+ suite = unittest.TestSuite()
+
+ suite.addTest(unittest.makeSuite(VersionTestCase))
+ suite.addTest(unittest.makeSuite(BasicBTreeTestCase))
+ suite.addTest(unittest.makeSuite(BasicHashTestCase))
+ suite.addTest(unittest.makeSuite(BasicBTreeWithThreadFlagTestCase))
+ suite.addTest(unittest.makeSuite(BasicHashWithThreadFlagTestCase))
+ suite.addTest(unittest.makeSuite(BasicBTreeWithEnvTestCase))
+ suite.addTest(unittest.makeSuite(BasicHashWithEnvTestCase))
+ suite.addTest(unittest.makeSuite(BTreeTransactionTestCase))
+ suite.addTest(unittest.makeSuite(HashTransactionTestCase))
+ suite.addTest(unittest.makeSuite(BTreeRecnoTestCase))
+ suite.addTest(unittest.makeSuite(BTreeRecnoWithThreadFlagTestCase))
+ suite.addTest(unittest.makeSuite(BTreeDUPTestCase))
+ suite.addTest(unittest.makeSuite(HashDUPTestCase))
+ suite.addTest(unittest.makeSuite(BTreeDUPWithThreadTestCase))
+ suite.addTest(unittest.makeSuite(HashDUPWithThreadTestCase))
+ suite.addTest(unittest.makeSuite(BTreeMultiDBTestCase))
+ suite.addTest(unittest.makeSuite(HashMultiDBTestCase))
+
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_compat.py b/python/rpmdb/test/test_compat.py
new file mode 100644
index 000000000..041081468
--- /dev/null
+++ b/python/rpmdb/test/test_compat.py
@@ -0,0 +1,166 @@
+"""
+Test cases adapted from the test_bsddb.py module in Python's
+regression test suite.
+"""
+
+import sys, os, string
+import rpmdb
+import unittest
+import tempfile
+
+from test_all import verbose
+
+from rpmdb import db, hashopen, btopen, rnopen
+
+class CompatibilityTestCase(unittest.TestCase):
+ def setUp(self):
+ self.filename = tempfile.mktemp()
+
+ def tearDown(self):
+ try:
+ os.remove(self.filename)
+ except os.error:
+ pass
+
+
+ def test01_btopen(self):
+ self.do_bthash_test(btopen, 'btopen')
+
+ def test02_hashopen(self):
+ self.do_bthash_test(hashopen, 'hashopen')
+
+ def test03_rnopen(self):
+ data = string.split("The quick brown fox jumped over the lazy dog.")
+ if verbose:
+ print "\nTesting: rnopen"
+
+ f = rnopen(self.filename, 'c')
+ for x in range(len(data)):
+ f[x+1] = data[x]
+
+ getTest = (f[1], f[2], f[3])
+ if verbose:
+ print '%s %s %s' % getTest
+
+ assert getTest[1] == 'quick', 'data mismatch!'
+
+ f[25] = 'twenty-five'
+ f.close()
+ del f
+
+ f = rnopen(self.filename, 'w')
+ f[20] = 'twenty'
+
+ def noRec(f):
+ rec = f[15]
+ self.assertRaises(KeyError, noRec, f)
+
+ def badKey(f):
+ rec = f['a string']
+ self.assertRaises(TypeError, badKey, f)
+
+ del f[3]
+
+ rec = f.first()
+ while rec:
+ if verbose:
+ print rec
+ try:
+ rec = f.next()
+ except KeyError:
+ break
+
+ f.close()
+
+
+ def test04_n_flag(self):
+ f = hashopen(self.filename, 'n')
+ f.close()
+
+
+
+ def do_bthash_test(self, factory, what):
+ if verbose:
+ print '\nTesting: ', what
+
+ f = factory(self.filename, 'c')
+ if verbose:
+ print 'creation...'
+
+ # truth test
+ if f:
+ if verbose: print "truth test: true"
+ else:
+ if verbose: print "truth test: false"
+
+ f['0'] = ''
+ f['a'] = 'Guido'
+ f['b'] = 'van'
+ f['c'] = 'Rossum'
+ f['d'] = 'invented'
+ f['f'] = 'Python'
+ if verbose:
+ print '%s %s %s' % (f['a'], f['b'], f['c'])
+
+ if verbose:
+ print 'key ordering...'
+ f.set_location(f.first()[0])
+ while 1:
+ try:
+ rec = f.next()
+ except KeyError:
+ assert rec == f.last(), 'Error, last <> last!'
+ f.previous()
+ break
+ if verbose:
+ print rec
+
+ assert f.has_key('f'), 'Error, missing key!'
+
+ f.sync()
+ f.close()
+ # truth test
+ try:
+ if f:
+ if verbose: print "truth test: true"
+ else:
+ if verbose: print "truth test: false"
+ except db.DBError:
+ pass
+ else:
+ self.fail("Exception expected")
+
+ del f
+
+ if verbose:
+ print 'modification...'
+ f = factory(self.filename, 'w')
+ f['d'] = 'discovered'
+
+ if verbose:
+ print 'access...'
+ for key in f.keys():
+ word = f[key]
+ if verbose:
+ print word
+
+ def noRec(f):
+ rec = f['no such key']
+ self.assertRaises(KeyError, noRec, f)
+
+ def badKey(f):
+ rec = f[15]
+ self.assertRaises(TypeError, badKey, f)
+
+ f.close()
+
+
+#----------------------------------------------------------------------
+
+
+def test_suite():
+ return unittest.makeSuite(CompatibilityTestCase)
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_dbobj.py b/python/rpmdb/test/test_dbobj.py
new file mode 100644
index 000000000..60b74d52e
--- /dev/null
+++ b/python/rpmdb/test/test_dbobj.py
@@ -0,0 +1,73 @@
+
+import sys, os, string
+import unittest
+import glob
+
+from rpmdb import db, dbobj
+
+
+#----------------------------------------------------------------------
+
+class dbobjTestCase(unittest.TestCase):
+ """Verify that dbobj.DB and dbobj.DBEnv work properly"""
+ db_home = 'db_home'
+ db_name = 'test-dbobj.db'
+
+ def setUp(self):
+ homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ self.homeDir = homeDir
+ try: os.mkdir(homeDir)
+ except os.error: pass
+
+ def tearDown(self):
+ if hasattr(self, 'db'):
+ del self.db
+ if hasattr(self, 'env'):
+ del self.env
+ files = glob.glob(os.path.join(self.homeDir, '*'))
+ for file in files:
+ os.remove(file)
+
+ def test01_both(self):
+ class TestDBEnv(dbobj.DBEnv): pass
+ class TestDB(dbobj.DB):
+ def put(self, key, *args, **kwargs):
+ key = string.upper(key)
+ # call our parent classes put method with an upper case key
+ return apply(dbobj.DB.put, (self, key) + args, kwargs)
+ self.env = TestDBEnv()
+ self.env.open(self.homeDir, db.DB_CREATE | db.DB_INIT_MPOOL)
+ self.db = TestDB(self.env)
+ self.db.open(self.db_name, db.DB_HASH, db.DB_CREATE)
+ self.db.put('spam', 'eggs')
+ assert self.db.get('spam') == None, \
+ "overridden dbobj.DB.put() method failed [1]"
+ assert self.db.get('SPAM') == 'eggs', \
+ "overridden dbobj.DB.put() method failed [2]"
+ self.db.close()
+ self.env.close()
+
+ def test02_dbobj_dict_interface(self):
+ self.env = dbobj.DBEnv()
+ self.env.open(self.homeDir, db.DB_CREATE | db.DB_INIT_MPOOL)
+ self.db = dbobj.DB(self.env)
+ self.db.open(self.db_name+'02', db.DB_HASH, db.DB_CREATE)
+ # __setitem__
+ self.db['spam'] = 'eggs'
+ # __len__
+ assert len(self.db) == 1
+ # __getitem__
+ assert self.db['spam'] == 'eggs'
+ # __del__
+ del self.db['spam']
+ assert self.db.get('spam') == None, "dbobj __del__ failed"
+ self.db.close()
+ self.env.close()
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ return unittest.makeSuite(dbobjTestCase)
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_dbshelve.py b/python/rpmdb/test/test_dbshelve.py
new file mode 100644
index 000000000..97f9f436f
--- /dev/null
+++ b/python/rpmdb/test/test_dbshelve.py
@@ -0,0 +1,301 @@
+"""
+TestCases for checking dbShelve objects.
+"""
+
+import sys, os, string
+import tempfile, random
+from pprint import pprint
+from types import *
+import unittest
+
+from rpmdb import db, dbshelve
+
+from test_all import verbose
+
+
+#----------------------------------------------------------------------
+
+# We want the objects to be comparable so we can test dbshelve.values
+# later on.
+class DataClass:
+ def __init__(self):
+ self.value = random.random()
+
+ def __cmp__(self, other):
+ return cmp(self.value, other)
+
+class DBShelveTestCase(unittest.TestCase):
+ def setUp(self):
+ self.filename = tempfile.mktemp()
+ self.do_open()
+
+ def tearDown(self):
+ self.do_close()
+ try:
+ os.remove(self.filename)
+ except os.error:
+ pass
+
+ def populateDB(self, d):
+ for x in string.letters:
+ d['S' + x] = 10 * x # add a string
+ d['I' + x] = ord(x) # add an integer
+ d['L' + x] = [x] * 10 # add a list
+
+ inst = DataClass() # add an instance
+ inst.S = 10 * x
+ inst.I = ord(x)
+ inst.L = [x] * 10
+ d['O' + x] = inst
+
+
+ # overridable in derived classes to affect how the shelf is created/opened
+ def do_open(self):
+ self.d = dbshelve.open(self.filename)
+
+ # and closed...
+ def do_close(self):
+ self.d.close()
+
+
+
+ def test01_basics(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test01_basics..." % self.__class__.__name__
+
+ self.populateDB(self.d)
+ self.d.sync()
+ self.do_close()
+ self.do_open()
+ d = self.d
+
+ l = len(d)
+ k = d.keys()
+ s = d.stat()
+ f = d.fd()
+
+ if verbose:
+ print "length:", l
+ print "keys:", k
+ print "stats:", s
+
+ assert 0 == d.has_key('bad key')
+ assert 1 == d.has_key('IA')
+ assert 1 == d.has_key('OA')
+
+ d.delete('IA')
+ del d['OA']
+ assert 0 == d.has_key('IA')
+ assert 0 == d.has_key('OA')
+ assert len(d) == l-2
+
+ values = []
+ for key in d.keys():
+ value = d[key]
+ values.append(value)
+ if verbose:
+ print "%s: %s" % (key, value)
+ self.checkrec(key, value)
+
+ dbvalues = d.values()
+ assert len(dbvalues) == len(d.keys())
+ values.sort()
+ dbvalues.sort()
+ assert values == dbvalues
+
+ items = d.items()
+ assert len(items) == len(values)
+
+ for key, value in items:
+ self.checkrec(key, value)
+
+ assert d.get('bad key') == None
+ assert d.get('bad key', None) == None
+ assert d.get('bad key', 'a string') == 'a string'
+ assert d.get('bad key', [1, 2, 3]) == [1, 2, 3]
+
+ d.set_get_returns_none(0)
+ self.assertRaises(db.DBNotFoundError, d.get, 'bad key')
+ d.set_get_returns_none(1)
+
+ d.put('new key', 'new data')
+ assert d.get('new key') == 'new data'
+ assert d['new key'] == 'new data'
+
+
+
+ def test02_cursors(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test02_cursors..." % self.__class__.__name__
+
+ self.populateDB(self.d)
+ d = self.d
+
+ count = 0
+ c = d.cursor()
+ rec = c.first()
+ while rec is not None:
+ count = count + 1
+ if verbose:
+ print rec
+ key, value = rec
+ self.checkrec(key, value)
+ rec = c.next()
+ del c
+
+ assert count == len(d)
+
+ count = 0
+ c = d.cursor()
+ rec = c.last()
+ while rec is not None:
+ count = count + 1
+ if verbose:
+ print rec
+ key, value = rec
+ self.checkrec(key, value)
+ rec = c.prev()
+
+ assert count == len(d)
+
+ c.set('SS')
+ key, value = c.current()
+ self.checkrec(key, value)
+ del c
+
+
+
+ def checkrec(self, key, value):
+ x = key[1]
+ if key[0] == 'S':
+ assert type(value) == StringType
+ assert value == 10 * x
+
+ elif key[0] == 'I':
+ assert type(value) == IntType
+ assert value == ord(x)
+
+ elif key[0] == 'L':
+ assert type(value) == ListType
+ assert value == [x] * 10
+
+ elif key[0] == 'O':
+ assert type(value) == InstanceType
+ assert value.S == 10 * x
+ assert value.I == ord(x)
+ assert value.L == [x] * 10
+
+ else:
+ raise AssertionError, 'Unknown key type, fix the test'
+
+#----------------------------------------------------------------------
+
+class BasicShelveTestCase(DBShelveTestCase):
+ def do_open(self):
+ self.d = dbshelve.DBShelf()
+ self.d.open(self.filename, self.dbtype, self.dbflags)
+
+ def do_close(self):
+ self.d.close()
+
+
+class BTreeShelveTestCase(BasicShelveTestCase):
+ dbtype = db.DB_BTREE
+ dbflags = db.DB_CREATE
+
+
+class HashShelveTestCase(BasicShelveTestCase):
+ dbtype = db.DB_HASH
+ dbflags = db.DB_CREATE
+
+
+class ThreadBTreeShelveTestCase(BasicShelveTestCase):
+ dbtype = db.DB_BTREE
+ dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+class ThreadHashShelveTestCase(BasicShelveTestCase):
+ dbtype = db.DB_HASH
+ dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+#----------------------------------------------------------------------
+
+class BasicEnvShelveTestCase(DBShelveTestCase):
+ def do_open(self):
+ self.homeDir = homeDir = os.path.join(
+ os.path.dirname(sys.argv[0]), 'db_home')
+ try: os.mkdir(homeDir)
+ except os.error: pass
+ self.env = db.DBEnv()
+ self.env.open(homeDir, self.envflags | db.DB_INIT_MPOOL | db.DB_CREATE)
+
+ self.filename = os.path.split(self.filename)[1]
+ self.d = dbshelve.DBShelf(self.env)
+ self.d.open(self.filename, self.dbtype, self.dbflags)
+
+
+ def do_close(self):
+ self.d.close()
+ self.env.close()
+
+
+ def tearDown(self):
+ self.do_close()
+ import glob
+ files = glob.glob(os.path.join(self.homeDir, '*'))
+ for file in files:
+ os.remove(file)
+
+
+
+class EnvBTreeShelveTestCase(BasicEnvShelveTestCase):
+ envflags = 0
+ dbtype = db.DB_BTREE
+ dbflags = db.DB_CREATE
+
+
+class EnvHashShelveTestCase(BasicEnvShelveTestCase):
+ envflags = 0
+ dbtype = db.DB_HASH
+ dbflags = db.DB_CREATE
+
+
+class EnvThreadBTreeShelveTestCase(BasicEnvShelveTestCase):
+ envflags = db.DB_THREAD
+ dbtype = db.DB_BTREE
+ dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+class EnvThreadHashShelveTestCase(BasicEnvShelveTestCase):
+ envflags = db.DB_THREAD
+ dbtype = db.DB_HASH
+ dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+#----------------------------------------------------------------------
+# TODO: Add test cases for a DBShelf in a RECNO DB.
+
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ suite = unittest.TestSuite()
+
+ suite.addTest(unittest.makeSuite(DBShelveTestCase))
+ suite.addTest(unittest.makeSuite(BTreeShelveTestCase))
+ suite.addTest(unittest.makeSuite(HashShelveTestCase))
+ suite.addTest(unittest.makeSuite(ThreadBTreeShelveTestCase))
+ suite.addTest(unittest.makeSuite(ThreadHashShelveTestCase))
+ suite.addTest(unittest.makeSuite(EnvBTreeShelveTestCase))
+ suite.addTest(unittest.makeSuite(EnvHashShelveTestCase))
+ suite.addTest(unittest.makeSuite(EnvThreadBTreeShelveTestCase))
+ suite.addTest(unittest.makeSuite(EnvThreadHashShelveTestCase))
+
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_dbtables.py b/python/rpmdb/test/test_dbtables.py
new file mode 100644
index 000000000..b90906515
--- /dev/null
+++ b/python/rpmdb/test/test_dbtables.py
@@ -0,0 +1,368 @@
+#!/usr/bin/env python
+#
+#-----------------------------------------------------------------------
+# A test suite for the table interface built on rpmdb.db
+#-----------------------------------------------------------------------
+#
+# Copyright (C) 2000, 2001 by Autonomous Zone Industries
+# Copyright (C) 2002 Gregory P. Smith
+#
+# March 20, 2000
+#
+# License: This is free software. You may use this software for any
+# purpose including modification/redistribution, so long as
+# this header remains intact and that you do not claim any
+# rights of ownership or authorship of this software. This
+# software has been tested, but no warranty is expressed or
+# implied.
+#
+# -- Gregory P. Smith <greg@electricrain.com>
+#
+# $Id: test_dbtables.py,v 1.1 2003/05/05 21:42:55 jbj Exp $
+
+import sys, os, re
+try:
+ import cPickle
+ pickle = cPickle
+except ImportError:
+ import pickle
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db, dbtables
+
+
+
+#----------------------------------------------------------------------
+
+class TableDBTestCase(unittest.TestCase):
+ db_home = 'db_home'
+ db_name = 'test-table.db'
+
+ def setUp(self):
+ homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ self.homeDir = homeDir
+ try: os.mkdir(homeDir)
+ except os.error: pass
+ self.tdb = dbtables.bsdTableDB(
+ filename='tabletest.db', dbhome=homeDir, create=1)
+
+ def tearDown(self):
+ self.tdb.close()
+ import glob
+ files = glob.glob(os.path.join(self.homeDir, '*'))
+ for file in files:
+ os.remove(file)
+
+ def test01(self):
+ tabname = "test01"
+ colname = 'cool numbers'
+ try:
+ self.tdb.Drop(tabname)
+ except dbtables.TableDBError:
+ pass
+ self.tdb.CreateTable(tabname, [colname])
+ self.tdb.Insert(tabname, {colname: pickle.dumps(3.14159, 1)})
+
+ if verbose:
+ self.tdb._db_print()
+
+ values = self.tdb.Select(
+ tabname, [colname], conditions={colname: None})
+
+ colval = pickle.loads(values[0][colname])
+ assert(colval > 3.141 and colval < 3.142)
+
+
+ def test02(self):
+ tabname = "test02"
+ col0 = 'coolness factor'
+ col1 = 'but can it fly?'
+ col2 = 'Species'
+ testinfo = [
+ {col0: pickle.dumps(8, 1), col1: 'no', col2: 'Penguin'},
+ {col0: pickle.dumps(-1, 1), col1: 'no', col2: 'Turkey'},
+ {col0: pickle.dumps(9, 1), col1: 'yes', col2: 'SR-71A Blackbird'}
+ ]
+
+ try:
+ self.tdb.Drop(tabname)
+ except dbtables.TableDBError:
+ pass
+ self.tdb.CreateTable(tabname, [col0, col1, col2])
+ for row in testinfo :
+ self.tdb.Insert(tabname, row)
+
+ values = self.tdb.Select(tabname, [col2],
+ conditions={col0: lambda x: pickle.loads(x) >= 8})
+
+ assert len(values) == 2
+ if values[0]['Species'] == 'Penguin' :
+ assert values[1]['Species'] == 'SR-71A Blackbird'
+ elif values[0]['Species'] == 'SR-71A Blackbird' :
+ assert values[1]['Species'] == 'Penguin'
+ else :
+ if verbose:
+ print "values=", `values`
+ raise "Wrong values returned!"
+
+ def test03(self):
+ tabname = "test03"
+ try:
+ self.tdb.Drop(tabname)
+ except dbtables.TableDBError:
+ pass
+ if verbose:
+ print '...before CreateTable...'
+ self.tdb._db_print()
+ self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e'])
+ if verbose:
+ print '...after CreateTable...'
+ self.tdb._db_print()
+ self.tdb.Drop(tabname)
+ if verbose:
+ print '...after Drop...'
+ self.tdb._db_print()
+ self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e'])
+
+ try:
+ self.tdb.Insert(tabname,
+ {'a': "",
+ 'e': pickle.dumps([{4:5, 6:7}, 'foo'], 1),
+ 'f': "Zero"})
+ assert 0
+ except dbtables.TableDBError:
+ pass
+
+ try:
+ self.tdb.Select(tabname, [], conditions={'foo': '123'})
+ assert 0
+ except dbtables.TableDBError:
+ pass
+
+ self.tdb.Insert(tabname,
+ {'a': '42',
+ 'b': "bad",
+ 'c': "meep",
+ 'e': 'Fuzzy wuzzy was a bear'})
+ self.tdb.Insert(tabname,
+ {'a': '581750',
+ 'b': "good",
+ 'd': "bla",
+ 'c': "black",
+ 'e': 'fuzzy was here'})
+ self.tdb.Insert(tabname,
+ {'a': '800000',
+ 'b': "good",
+ 'd': "bla",
+ 'c': "black",
+ 'e': 'Fuzzy wuzzy is a bear'})
+
+ if verbose:
+ self.tdb._db_print()
+
+ # this should return two rows
+ values = self.tdb.Select(tabname, ['b', 'a', 'd'],
+ conditions={'e': re.compile('wuzzy').search,
+ 'a': re.compile('^[0-9]+$').match})
+ assert len(values) == 2
+
+ # now lets delete one of them and try again
+ self.tdb.Delete(tabname, conditions={'b': dbtables.ExactCond('good')})
+ values = self.tdb.Select(
+ tabname, ['a', 'd', 'b'],
+ conditions={'e': dbtables.PrefixCond('Fuzzy')})
+ assert len(values) == 1
+ assert values[0]['d'] == None
+
+ values = self.tdb.Select(tabname, ['b'],
+ conditions={'c': lambda c: c == 'meep'})
+ assert len(values) == 1
+ assert values[0]['b'] == "bad"
+
+
+ def test04_MultiCondSelect(self):
+ tabname = "test04_MultiCondSelect"
+ try:
+ self.tdb.Drop(tabname)
+ except dbtables.TableDBError:
+ pass
+ self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e'])
+
+ try:
+ self.tdb.Insert(tabname,
+ {'a': "",
+ 'e': pickle.dumps([{4:5, 6:7}, 'foo'], 1),
+ 'f': "Zero"})
+ assert 0
+ except dbtables.TableDBError:
+ pass
+
+ self.tdb.Insert(tabname, {'a': "A", 'b': "B", 'c': "C", 'd': "D",
+ 'e': "E"})
+ self.tdb.Insert(tabname, {'a': "-A", 'b': "-B", 'c': "-C", 'd': "-D",
+ 'e': "-E"})
+ self.tdb.Insert(tabname, {'a': "A-", 'b': "B-", 'c': "C-", 'd': "D-",
+ 'e': "E-"})
+
+ if verbose:
+ self.tdb._db_print()
+
+ # This select should return 0 rows. it is designed to test
+ # the bug identified and fixed in sourceforge bug # 590449
+ # (Big Thanks to "Rob Tillotson (n9mtb)" for tracking this down
+ # and supplying a fix!! This one caused many headaches to say
+ # the least...)
+ values = self.tdb.Select(tabname, ['b', 'a', 'd'],
+ conditions={'e': dbtables.ExactCond('E'),
+ 'a': dbtables.ExactCond('A'),
+ 'd': dbtables.PrefixCond('-')
+ } )
+ assert len(values) == 0, values
+
+
+ def test_CreateOrExtend(self):
+ tabname = "test_CreateOrExtend"
+
+ self.tdb.CreateOrExtendTable(
+ tabname, ['name', 'taste', 'filling', 'alcohol content', 'price'])
+ try:
+ self.tdb.Insert(tabname,
+ {'taste': 'crap',
+ 'filling': 'no',
+ 'is it Guinness?': 'no'})
+ assert 0, "Insert should've failed due to bad column name"
+ except:
+ pass
+ self.tdb.CreateOrExtendTable(tabname,
+ ['name', 'taste', 'is it Guinness?'])
+
+ # these should both succeed as the table should contain the union of both sets of columns.
+ self.tdb.Insert(tabname, {'taste': 'crap', 'filling': 'no',
+ 'is it Guinness?': 'no'})
+ self.tdb.Insert(tabname, {'taste': 'great', 'filling': 'yes',
+ 'is it Guinness?': 'yes',
+ 'name': 'Guinness'})
+
+
+ def test_CondObjs(self):
+ tabname = "test_CondObjs"
+
+ self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e', 'p'])
+
+ self.tdb.Insert(tabname, {'a': "the letter A",
+ 'b': "the letter B",
+ 'c': "is for cookie"})
+ self.tdb.Insert(tabname, {'a': "is for aardvark",
+ 'e': "the letter E",
+ 'c': "is for cookie",
+ 'd': "is for dog"})
+ self.tdb.Insert(tabname, {'a': "the letter A",
+ 'e': "the letter E",
+ 'c': "is for cookie",
+ 'p': "is for Python"})
+
+ values = self.tdb.Select(
+ tabname, ['p', 'e'],
+ conditions={'e': dbtables.PrefixCond('the l')})
+ assert len(values) == 2, values
+ assert values[0]['e'] == values[1]['e'], values
+ assert values[0]['p'] != values[1]['p'], values
+
+ values = self.tdb.Select(
+ tabname, ['d', 'a'],
+ conditions={'a': dbtables.LikeCond('%aardvark%')})
+ assert len(values) == 1, values
+ assert values[0]['d'] == "is for dog", values
+ assert values[0]['a'] == "is for aardvark", values
+
+ values = self.tdb.Select(tabname, None,
+ {'b': dbtables.Cond(),
+ 'e':dbtables.LikeCond('%letter%'),
+ 'a':dbtables.PrefixCond('is'),
+ 'd':dbtables.ExactCond('is for dog'),
+ 'c':dbtables.PrefixCond('is for'),
+ 'p':lambda s: not s})
+ assert len(values) == 1, values
+ assert values[0]['d'] == "is for dog", values
+ assert values[0]['a'] == "is for aardvark", values
+
+ def test_Delete(self):
+ tabname = "test_Delete"
+ self.tdb.CreateTable(tabname, ['x', 'y', 'z'])
+
+ # prior to 2001-05-09 there was a bug where Delete() would
+ # fail if it encountered any rows that did not have values in
+ # every column.
+ # Hunted and Squashed by <Donwulff> (Jukka Santala - donwulff@nic.fi)
+ self.tdb.Insert(tabname, {'x': 'X1', 'y':'Y1'})
+ self.tdb.Insert(tabname, {'x': 'X2', 'y':'Y2', 'z': 'Z2'})
+
+ self.tdb.Delete(tabname, conditions={'x': dbtables.PrefixCond('X')})
+ values = self.tdb.Select(tabname, ['y'],
+ conditions={'x': dbtables.PrefixCond('X')})
+ assert len(values) == 0
+
+ def test_Modify(self):
+ tabname = "test_Modify"
+ self.tdb.CreateTable(tabname, ['Name', 'Type', 'Access'])
+
+ self.tdb.Insert(tabname, {'Name': 'Index to MP3 files.doc',
+ 'Type': 'Word', 'Access': '8'})
+ self.tdb.Insert(tabname, {'Name': 'Nifty.MP3', 'Access': '1'})
+ self.tdb.Insert(tabname, {'Type': 'Unknown', 'Access': '0'})
+
+ def set_type(type):
+ if type == None:
+ return 'MP3'
+ return type
+
+ def increment_access(count):
+ return str(int(count)+1)
+
+ def remove_value(value):
+ return None
+
+ self.tdb.Modify(tabname,
+ conditions={'Access': dbtables.ExactCond('0')},
+ mappings={'Access': remove_value})
+ self.tdb.Modify(tabname,
+ conditions={'Name': dbtables.LikeCond('%MP3%')},
+ mappings={'Type': set_type})
+ self.tdb.Modify(tabname,
+ conditions={'Name': dbtables.LikeCond('%')},
+ mappings={'Access': increment_access})
+
+ # Delete key in select conditions
+ values = self.tdb.Select(
+ tabname, None,
+ conditions={'Type': dbtables.ExactCond('Unknown')})
+ assert len(values) == 1, values
+ assert values[0]['Name'] == None, values
+ assert values[0]['Access'] == None, values
+
+ # Modify value by select conditions
+ values = self.tdb.Select(
+ tabname, None,
+ conditions={'Name': dbtables.ExactCond('Nifty.MP3')})
+ assert len(values) == 1, values
+ assert values[0]['Type'] == "MP3", values
+ assert values[0]['Access'] == "2", values
+
+ # Make sure change applied only to select conditions
+ values = self.tdb.Select(
+ tabname, None, conditions={'Name': dbtables.LikeCond('%doc%')})
+ assert len(values) == 1, values
+ assert values[0]['Type'] == "Word", values
+ assert values[0]['Access'] == "9", values
+
+
+def test_suite():
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(TableDBTestCase))
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_env_close.py b/python/rpmdb/test/test_env_close.py
new file mode 100644
index 000000000..347e12131
--- /dev/null
+++ b/python/rpmdb/test/test_env_close.py
@@ -0,0 +1,102 @@
+"""TestCases for checking that it does not segfault when a DBEnv object
+is closed before its DB objects.
+"""
+
+import os
+import sys
+import tempfile
+import glob
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+# We're going to get warnings in this module about trying to close the db when
+# its env is already closed. Let's just ignore those.
+try:
+ import warnings
+except ImportError:
+ pass
+else:
+ warnings.filterwarnings('ignore',
+ message='DB could not be closed in',
+ category=RuntimeWarning)
+
+
+#----------------------------------------------------------------------
+
+class DBEnvClosedEarlyCrash(unittest.TestCase):
+ def setUp(self):
+ self.homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ try: os.mkdir(self.homeDir)
+ except os.error: pass
+ tempfile.tempdir = self.homeDir
+ self.filename = os.path.split(tempfile.mktemp())[1]
+ tempfile.tempdir = None
+
+ def tearDown(self):
+ files = glob.glob(os.path.join(self.homeDir, '*'))
+ for file in files:
+ os.remove(file)
+
+
+ def test01_close_dbenv_before_db(self):
+ dbenv = db.DBEnv()
+ dbenv.open(self.homeDir,
+ db.DB_INIT_CDB| db.DB_CREATE |db.DB_THREAD|db.DB_INIT_MPOOL,
+ 0666)
+
+ d = db.DB(dbenv)
+ d.open(self.filename, db.DB_BTREE, db.DB_CREATE | db.DB_THREAD, 0666)
+
+ try:
+ dbenv.close()
+ except db.DBError:
+ try:
+ d.close()
+ except db.DBError:
+ return
+ assert 0, \
+ "DB close did not raise an exception about its "\
+ "DBEnv being trashed"
+
+ # XXX This may fail when using older versions of BerkeleyDB.
+ # E.g. 3.2.9 never raised the exception.
+ assert 0, "dbenv did not raise an exception about its DB being open"
+
+
+ def test02_close_dbenv_delete_db_success(self):
+ dbenv = db.DBEnv()
+ dbenv.open(self.homeDir,
+ db.DB_INIT_CDB| db.DB_CREATE |db.DB_THREAD|db.DB_INIT_MPOOL,
+ 0666)
+
+ d = db.DB(dbenv)
+ d.open(self.filename, db.DB_BTREE, db.DB_CREATE | db.DB_THREAD, 0666)
+
+ try:
+ dbenv.close()
+ except db.DBError:
+ pass # good, it should raise an exception
+
+ del d
+ try:
+ import gc
+ except ImportError:
+ gc = None
+ if gc:
+ # force d.__del__ [DB_dealloc] to be called
+ gc.collect()
+
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(DBEnvClosedEarlyCrash))
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_get_none.py b/python/rpmdb/test/test_get_none.py
new file mode 100644
index 000000000..e684e5623
--- /dev/null
+++ b/python/rpmdb/test/test_get_none.py
@@ -0,0 +1,96 @@
+"""
+TestCases for checking set_get_returns_none.
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+
+#----------------------------------------------------------------------
+
+class GetReturnsNoneTestCase(unittest.TestCase):
+ def setUp(self):
+ self.filename = tempfile.mktemp()
+
+ def tearDown(self):
+ try:
+ os.remove(self.filename)
+ except os.error:
+ pass
+
+
+ def test01_get_returns_none(self):
+ d = db.DB()
+ d.open(self.filename, db.DB_BTREE, db.DB_CREATE)
+ d.set_get_returns_none(1)
+
+ for x in string.letters:
+ d.put(x, x * 40)
+
+ data = d.get('bad key')
+ assert data == None
+
+ data = d.get('a')
+ assert data == 'a'*40
+
+ count = 0
+ c = d.cursor()
+ rec = c.first()
+ while rec:
+ count = count + 1
+ rec = c.next()
+
+ assert rec == None
+ assert count == 52
+
+ c.close()
+ d.close()
+
+
+ def test02_get_raises_exception(self):
+ d = db.DB()
+ d.open(self.filename, db.DB_BTREE, db.DB_CREATE)
+ d.set_get_returns_none(0)
+
+ for x in string.letters:
+ d.put(x, x * 40)
+
+ self.assertRaises(db.DBNotFoundError, d.get, 'bad key')
+ self.assertRaises(KeyError, d.get, 'bad key')
+
+ data = d.get('a')
+ assert data == 'a'*40
+
+ count = 0
+ exceptionHappened = 0
+ c = d.cursor()
+ rec = c.first()
+ while rec:
+ count = count + 1
+ try:
+ rec = c.next()
+ except db.DBNotFoundError: # end of the records
+ exceptionHappened = 1
+ break
+
+ assert rec != None
+ assert exceptionHappened
+ assert count == 52
+
+ c.close()
+ d.close()
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ return unittest.makeSuite(GetReturnsNoneTestCase)
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_join.py b/python/rpmdb/test/test_join.py
new file mode 100644
index 000000000..ab75ba196
--- /dev/null
+++ b/python/rpmdb/test/test_join.py
@@ -0,0 +1,9 @@
+"""TestCases for using the DB.join and DBCursor.join_item methods.
+"""
+
+import unittest
+
+
+def test_suite():
+ suite = unittest.TestSuite()
+ return suite
diff --git a/python/rpmdb/test/test_lock.py b/python/rpmdb/test/test_lock.py
new file mode 100644
index 000000000..22bb629e9
--- /dev/null
+++ b/python/rpmdb/test/test_lock.py
@@ -0,0 +1,139 @@
+"""
+TestCases for testing the locking sub-system.
+"""
+
+import sys, os, string
+import tempfile
+import time
+from pprint import pprint
+from whrandom import random
+
+try:
+ from threading import Thread, currentThread
+ have_threads = 1
+except ImportError:
+ have_threads = 0
+
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db
+
+
+#----------------------------------------------------------------------
+
+class LockingTestCase(unittest.TestCase):
+
+ def setUp(self):
+ homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ self.homeDir = homeDir
+ try: os.mkdir(homeDir)
+ except os.error: pass
+ self.env = db.DBEnv()
+ self.env.open(homeDir, db.DB_THREAD | db.DB_INIT_MPOOL |
+ db.DB_INIT_LOCK | db.DB_CREATE)
+
+
+ def tearDown(self):
+ self.env.close()
+ import glob
+ files = glob.glob(os.path.join(self.homeDir, '*'))
+ for file in files:
+ os.remove(file)
+
+
+ def test01_simple(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test01_simple..." % self.__class__.__name__
+
+ anID = self.env.lock_id()
+ if verbose:
+ print "locker ID: %s" % anID
+ lock = self.env.lock_get(anID, "some locked thing", db.DB_LOCK_WRITE)
+ if verbose:
+ print "Aquired lock: %s" % lock
+ time.sleep(1)
+ self.env.lock_put(lock)
+ if verbose:
+ print "Released lock: %s" % lock
+
+
+
+
+ def test02_threaded(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test02_threaded..." % self.__class__.__name__
+
+ threads = []
+ threads.append(Thread(target = self.theThread,
+ args=(5, db.DB_LOCK_WRITE)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_READ)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_READ)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_WRITE)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_READ)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_READ)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_WRITE)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_WRITE)))
+ threads.append(Thread(target = self.theThread,
+ args=(1, db.DB_LOCK_WRITE)))
+
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ def test03_set_timeout(self):
+ # test that the set_timeout call works
+ if hasattr(self.env, 'set_timeout'):
+ self.env.set_timeout(0, db.DB_SET_LOCK_TIMEOUT)
+ self.env.set_timeout(0, db.DB_SET_TXN_TIMEOUT)
+ self.env.set_timeout(123456, db.DB_SET_LOCK_TIMEOUT)
+ self.env.set_timeout(7890123, db.DB_SET_TXN_TIMEOUT)
+
+ def theThread(self, sleepTime, lockType):
+ name = currentThread().getName()
+ if lockType == db.DB_LOCK_WRITE:
+ lt = "write"
+ else:
+ lt = "read"
+
+ anID = self.env.lock_id()
+ if verbose:
+ print "%s: locker ID: %s" % (name, anID)
+
+ lock = self.env.lock_get(anID, "some locked thing", lockType)
+ if verbose:
+ print "%s: Aquired %s lock: %s" % (name, lt, lock)
+
+ time.sleep(sleepTime)
+
+ self.env.lock_put(lock)
+ if verbose:
+ print "%s: Released %s lock: %s" % (name, lt, lock)
+
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ suite = unittest.TestSuite()
+
+ if have_threads:
+ suite.addTest(unittest.makeSuite(LockingTestCase))
+ else:
+ suite.addTest(unittest.makeSuite(LockingTestCase, 'test01'))
+
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_misc.py b/python/rpmdb/test/test_misc.py
new file mode 100644
index 000000000..7c2cfcbea
--- /dev/null
+++ b/python/rpmdb/test/test_misc.py
@@ -0,0 +1,53 @@
+"""Miscellaneous rpmdb module test cases
+"""
+
+import os
+import sys
+import unittest
+
+from rpmdb import db, dbshelve
+
+#----------------------------------------------------------------------
+
+class MiscTestCase(unittest.TestCase):
+ def setUp(self):
+ self.filename = self.__class__.__name__ + '.db'
+ homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ self.homeDir = homeDir
+ try:
+ os.mkdir(homeDir)
+ except OSError:
+ pass
+
+ def tearDown(self):
+ try:
+ os.remove(self.filename)
+ except OSError:
+ pass
+ import glob
+ files = glob.glob(os.path.join(self.homeDir, '*'))
+ for file in files:
+ os.remove(file)
+
+ def test01_badpointer(self):
+ dbs = dbshelve.open(self.filename)
+ dbs.close()
+ self.assertRaises(db.DBError, dbs.get, "foo")
+
+ def test02_db_home(self):
+ env = db.DBEnv()
+ # check for crash fixed when db_home is used before open()
+ assert env.db_home is None
+ env.open(self.homeDir, db.DB_CREATE)
+ assert self.homeDir == env.db_home
+
+
+#----------------------------------------------------------------------
+
+
+def test_suite():
+ return unittest.makeSuite(MiscTestCase)
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_queue.py b/python/rpmdb/test/test_queue.py
new file mode 100644
index 000000000..5a4bca0d1
--- /dev/null
+++ b/python/rpmdb/test/test_queue.py
@@ -0,0 +1,168 @@
+"""
+TestCases for exercising a Queue DB.
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+
+#----------------------------------------------------------------------
+
+class SimpleQueueTestCase(unittest.TestCase):
+ def setUp(self):
+ self.filename = tempfile.mktemp()
+
+ def tearDown(self):
+ try:
+ os.remove(self.filename)
+ except os.error:
+ pass
+
+
+ def test01_basic(self):
+ # Basic Queue tests using the deprecated DBCursor.consume method.
+
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test01_basic..." % self.__class__.__name__
+
+ d = db.DB()
+ d.set_re_len(40) # Queues must be fixed length
+ d.open(self.filename, db.DB_QUEUE, db.DB_CREATE)
+
+ if verbose:
+ print "before appends" + '-' * 30
+ pprint(d.stat())
+
+ for x in string.letters:
+ d.append(x * 40)
+
+ assert len(d) == 52
+
+ d.put(100, "some more data")
+ d.put(101, "and some more ")
+ d.put(75, "out of order")
+ d.put(1, "replacement data")
+
+ assert len(d) == 55
+
+ if verbose:
+ print "before close" + '-' * 30
+ pprint(d.stat())
+
+ d.close()
+ del d
+ d = db.DB()
+ d.open(self.filename)
+
+ if verbose:
+ print "after open" + '-' * 30
+ pprint(d.stat())
+
+ d.append("one more")
+ c = d.cursor()
+
+ if verbose:
+ print "after append" + '-' * 30
+ pprint(d.stat())
+
+ rec = c.consume()
+ while rec:
+ if verbose:
+ print rec
+ rec = c.consume()
+ c.close()
+
+ if verbose:
+ print "after consume loop" + '-' * 30
+ pprint(d.stat())
+
+ assert len(d) == 0, \
+ "if you see this message then you need to rebuild " \
+ "BerkeleyDB 3.1.17 with the patch in patches/qam_stat.diff"
+
+ d.close()
+
+
+
+ def test02_basicPost32(self):
+ # Basic Queue tests using the new DB.consume method in DB 3.2+
+ # (No cursor needed)
+
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test02_basicPost32..." % self.__class__.__name__
+
+ if db.version() < (3, 2, 0):
+ if verbose:
+ print "Test not run, DB not new enough..."
+ return
+
+ d = db.DB()
+ d.set_re_len(40) # Queues must be fixed length
+ d.open(self.filename, db.DB_QUEUE, db.DB_CREATE)
+
+ if verbose:
+ print "before appends" + '-' * 30
+ pprint(d.stat())
+
+ for x in string.letters:
+ d.append(x * 40)
+
+ assert len(d) == 52
+
+ d.put(100, "some more data")
+ d.put(101, "and some more ")
+ d.put(75, "out of order")
+ d.put(1, "replacement data")
+
+ assert len(d) == 55
+
+ if verbose:
+ print "before close" + '-' * 30
+ pprint(d.stat())
+
+ d.close()
+ del d
+ d = db.DB()
+ d.open(self.filename)
+ #d.set_get_returns_none(true)
+
+ if verbose:
+ print "after open" + '-' * 30
+ pprint(d.stat())
+
+ d.append("one more")
+
+ if verbose:
+ print "after append" + '-' * 30
+ pprint(d.stat())
+
+ rec = d.consume()
+ while rec:
+ if verbose:
+ print rec
+ rec = d.consume()
+
+ if verbose:
+ print "after consume loop" + '-' * 30
+ pprint(d.stat())
+
+ d.close()
+
+
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ return unittest.makeSuite(SimpleQueueTestCase)
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_recno.py b/python/rpmdb/test/test_recno.py
new file mode 100644
index 000000000..bd62cfad5
--- /dev/null
+++ b/python/rpmdb/test/test_recno.py
@@ -0,0 +1,260 @@
+"""TestCases for exercising a Recno DB.
+"""
+
+import os
+import sys
+import errno
+import tempfile
+from pprint import pprint
+import unittest
+
+from test_all import verbose
+
+from rpmdb import db
+
+letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+
+
+#----------------------------------------------------------------------
+
+class SimpleRecnoTestCase(unittest.TestCase):
+ def setUp(self):
+ self.filename = tempfile.mktemp()
+
+ def tearDown(self):
+ try:
+ os.remove(self.filename)
+ except OSError, e:
+ if e.errno <> errno.EEXIST: raise
+
+ def test01_basic(self):
+ d = db.DB()
+ d.open(self.filename, db.DB_RECNO, db.DB_CREATE)
+
+ for x in letters:
+ recno = d.append(x * 60)
+ assert type(recno) == type(0)
+ assert recno >= 1
+ if verbose:
+ print recno,
+
+ if verbose: print
+
+ stat = d.stat()
+ if verbose:
+ pprint(stat)
+
+ for recno in range(1, len(d)+1):
+ data = d[recno]
+ if verbose:
+ print data
+
+ assert type(data) == type("")
+ assert data == d.get(recno)
+
+ try:
+ data = d[0] # This should raise a KeyError!?!?!
+ except db.DBInvalidArgError, val:
+ assert val[0] == db.EINVAL
+ if verbose: print val
+ else:
+ self.fail("expected exception")
+
+ try:
+ data = d[100]
+ except KeyError:
+ pass
+ else:
+ self.fail("expected exception")
+
+ data = d.get(100)
+ assert data == None
+
+ keys = d.keys()
+ if verbose:
+ print keys
+ assert type(keys) == type([])
+ assert type(keys[0]) == type(123)
+ assert len(keys) == len(d)
+
+ items = d.items()
+ if verbose:
+ pprint(items)
+ assert type(items) == type([])
+ assert type(items[0]) == type(())
+ assert len(items[0]) == 2
+ assert type(items[0][0]) == type(123)
+ assert type(items[0][1]) == type("")
+ assert len(items) == len(d)
+
+ assert d.has_key(25)
+
+ del d[25]
+ assert not d.has_key(25)
+
+ d.delete(13)
+ assert not d.has_key(13)
+
+ data = d.get_both(26, "z" * 60)
+ assert data == "z" * 60
+ if verbose:
+ print data
+
+ fd = d.fd()
+ if verbose:
+ print fd
+
+ c = d.cursor()
+ rec = c.first()
+ while rec:
+ if verbose:
+ print rec
+ rec = c.next()
+
+ c.set(50)
+ rec = c.current()
+ if verbose:
+ print rec
+
+ c.put(-1, "a replacement record", db.DB_CURRENT)
+
+ c.set(50)
+ rec = c.current()
+ assert rec == (50, "a replacement record")
+ if verbose:
+ print rec
+
+ rec = c.set_range(30)
+ if verbose:
+ print rec
+
+ c.close()
+ d.close()
+
+ d = db.DB()
+ d.open(self.filename)
+ c = d.cursor()
+
+ # put a record beyond the consecutive end of the recno's
+ d[100] = "way out there"
+ assert d[100] == "way out there"
+
+ try:
+ data = d[99]
+ except KeyError:
+ pass
+ else:
+ self.fail("expected exception")
+
+ try:
+ d.get(99)
+ except db.DBKeyEmptyError, val:
+ assert val[0] == db.DB_KEYEMPTY
+ if verbose: print val
+ else:
+ self.fail("expected exception")
+
+ rec = c.set(40)
+ while rec:
+ if verbose:
+ print rec
+ rec = c.next()
+
+ c.close()
+ d.close()
+
+ def test02_WithSource(self):
+ """
+ A Recno file that is given a "backing source file" is essentially a
+ simple ASCII file. Normally each record is delimited by \n and so is
+ just a line in the file, but you can set a different record delimiter
+ if needed.
+ """
+ source = os.path.join(os.path.dirname(sys.argv[0]),
+ 'db_home/test_recno.txt')
+ f = open(source, 'w') # create the file
+ f.close()
+
+ d = db.DB()
+ # This is the default value, just checking if both int
+ d.set_re_delim(0x0A)
+ d.set_re_delim('\n') # and char can be used...
+ d.set_re_source(source)
+ d.open(self.filename, db.DB_RECNO, db.DB_CREATE)
+
+ data = "The quick brown fox jumped over the lazy dog".split()
+ for datum in data:
+ d.append(datum)
+ d.sync()
+ d.close()
+
+ # get the text from the backing source
+ text = open(source, 'r').read()
+ text = text.strip()
+ if verbose:
+ print text
+ print data
+ print text.split('\n')
+
+ assert text.split('\n') == data
+
+ # open as a DB again
+ d = db.DB()
+ d.set_re_source(source)
+ d.open(self.filename, db.DB_RECNO)
+
+ d[3] = 'reddish-brown'
+ d[8] = 'comatose'
+
+ d.sync()
+ d.close()
+
+ text = open(source, 'r').read()
+ text = text.strip()
+ if verbose:
+ print text
+ print text.split('\n')
+
+ assert text.split('\n') == \
+ "The quick reddish-brown fox jumped over the comatose dog".split()
+
+ def test03_FixedLength(self):
+ d = db.DB()
+ d.set_re_len(40) # fixed length records, 40 bytes long
+ d.set_re_pad('-') # sets the pad character...
+ d.set_re_pad(45) # ...test both int and char
+ d.open(self.filename, db.DB_RECNO, db.DB_CREATE)
+
+ for x in letters:
+ d.append(x * 35) # These will be padded
+
+ d.append('.' * 40) # this one will be exact
+
+ try: # this one will fail
+ d.append('bad' * 20)
+ except db.DBInvalidArgError, val:
+ assert val[0] == db.EINVAL
+ if verbose: print val
+ else:
+ self.fail("expected exception")
+
+ c = d.cursor()
+ rec = c.first()
+ while rec:
+ if verbose:
+ print rec
+ rec = c.next()
+
+ c.close()
+ d.close()
+
+
+#----------------------------------------------------------------------
+
+
+def test_suite():
+ return unittest.makeSuite(SimpleRecnoTestCase)
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/test_thread.py b/python/rpmdb/test/test_thread.py
new file mode 100644
index 000000000..3041557cc
--- /dev/null
+++ b/python/rpmdb/test/test_thread.py
@@ -0,0 +1,495 @@
+"""TestCases for multi-threaded access to a DB.
+"""
+
+import os
+import sys
+import time
+import errno
+import shutil
+import tempfile
+from pprint import pprint
+from whrandom import random
+
+try:
+ True, False
+except NameError:
+ True = 1
+ False = 0
+
+DASH = '-'
+
+try:
+ from threading import Thread, currentThread
+ have_threads = True
+except ImportError:
+ have_threads = False
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db, dbutils
+
+
+#----------------------------------------------------------------------
+
+class BaseThreadedTestCase(unittest.TestCase):
+ dbtype = db.DB_UNKNOWN # must be set in derived class
+ dbopenflags = 0
+ dbsetflags = 0
+ envflags = 0
+
+ def setUp(self):
+ if verbose:
+ dbutils._deadlock_VerboseFile = sys.stdout
+
+ homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+ self.homeDir = homeDir
+ try:
+ os.mkdir(homeDir)
+ except OSError, e:
+ if e.errno <> errno.EEXIST: raise
+ self.env = db.DBEnv()
+ self.setEnvOpts()
+ self.env.open(homeDir, self.envflags | db.DB_CREATE)
+
+ self.filename = self.__class__.__name__ + '.db'
+ self.d = db.DB(self.env)
+ if self.dbsetflags:
+ self.d.set_flags(self.dbsetflags)
+ self.d.open(self.filename, self.dbtype, self.dbopenflags|db.DB_CREATE)
+
+ def tearDown(self):
+ self.d.close()
+ self.env.close()
+ shutil.rmtree(self.homeDir)
+
+ def setEnvOpts(self):
+ pass
+
+ def makeData(self, key):
+ return DASH.join([key] * 5)
+
+
+#----------------------------------------------------------------------
+
+
+class ConcurrentDataStoreBase(BaseThreadedTestCase):
+ dbopenflags = db.DB_THREAD
+ envflags = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL
+ readers = 0 # derived class should set
+ writers = 0
+ records = 1000
+
+ def test01_1WriterMultiReaders(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test01_1WriterMultiReaders..." % \
+ self.__class__.__name__
+
+ threads = []
+ for x in range(self.writers):
+ wt = Thread(target = self.writerThread,
+ args = (self.d, self.records, x),
+ name = 'writer %d' % x,
+ )#verbose = verbose)
+ threads.append(wt)
+
+ for x in range(self.readers):
+ rt = Thread(target = self.readerThread,
+ args = (self.d, x),
+ name = 'reader %d' % x,
+ )#verbose = verbose)
+ threads.append(rt)
+
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ def writerThread(self, d, howMany, writerNum):
+ #time.sleep(0.01 * writerNum + 0.01)
+ name = currentThread().getName()
+ start = howMany * writerNum
+ stop = howMany * (writerNum + 1) - 1
+ if verbose:
+ print "%s: creating records %d - %d" % (name, start, stop)
+
+ for x in range(start, stop):
+ key = '%04d' % x
+ dbutils.DeadlockWrap(d.put, key, self.makeData(key),
+ max_retries=12)
+ if verbose and x % 100 == 0:
+ print "%s: records %d - %d finished" % (name, start, x)
+
+ if verbose:
+ print "%s: finished creating records" % name
+
+## # Each write-cursor will be exclusive, the only one that can update the DB...
+## if verbose: print "%s: deleting a few records" % name
+## c = d.cursor(flags = db.DB_WRITECURSOR)
+## for x in range(10):
+## key = int(random() * howMany) + start
+## key = '%04d' % key
+## if d.has_key(key):
+## c.set(key)
+## c.delete()
+
+## c.close()
+ if verbose:
+ print "%s: thread finished" % name
+
+ def readerThread(self, d, readerNum):
+ time.sleep(0.01 * readerNum)
+ name = currentThread().getName()
+
+ for loop in range(5):
+ c = d.cursor()
+ count = 0
+ rec = c.first()
+ while rec:
+ count += 1
+ key, data = rec
+ self.assertEqual(self.makeData(key), data)
+ rec = c.next()
+ if verbose:
+ print "%s: found %d records" % (name, count)
+ c.close()
+ time.sleep(0.05)
+
+ if verbose:
+ print "%s: thread finished" % name
+
+
+class BTreeConcurrentDataStore(ConcurrentDataStoreBase):
+ dbtype = db.DB_BTREE
+ writers = 2
+ readers = 10
+ records = 1000
+
+
+class HashConcurrentDataStore(ConcurrentDataStoreBase):
+ dbtype = db.DB_HASH
+ writers = 2
+ readers = 10
+ records = 1000
+
+
+#----------------------------------------------------------------------
+
+class SimpleThreadedBase(BaseThreadedTestCase):
+ dbopenflags = db.DB_THREAD
+ envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+ readers = 5
+ writers = 3
+ records = 1000
+
+ def setEnvOpts(self):
+ self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
+
+ def test02_SimpleLocks(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test02_SimpleLocks..." % self.__class__.__name__
+
+ threads = []
+ for x in range(self.writers):
+ wt = Thread(target = self.writerThread,
+ args = (self.d, self.records, x),
+ name = 'writer %d' % x,
+ )#verbose = verbose)
+ threads.append(wt)
+ for x in range(self.readers):
+ rt = Thread(target = self.readerThread,
+ args = (self.d, x),
+ name = 'reader %d' % x,
+ )#verbose = verbose)
+ threads.append(rt)
+
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ def writerThread(self, d, howMany, writerNum):
+ name = currentThread().getName()
+ start = howMany * writerNum
+ stop = howMany * (writerNum + 1) - 1
+ if verbose:
+ print "%s: creating records %d - %d" % (name, start, stop)
+
+ # create a bunch of records
+ for x in xrange(start, stop):
+ key = '%04d' % x
+ dbutils.DeadlockWrap(d.put, key, self.makeData(key),
+ max_retries=12)
+
+ if verbose and x % 100 == 0:
+ print "%s: records %d - %d finished" % (name, start, x)
+
+ # do a bit or reading too
+ if random() <= 0.05:
+ for y in xrange(start, x):
+ key = '%04d' % x
+ data = dbutils.DeadlockWrap(d.get, key, max_retries=12)
+ self.assertEqual(data, self.makeData(key))
+
+ # flush them
+ try:
+ dbutils.DeadlockWrap(d.sync, max_retries=12)
+ except db.DBIncompleteError, val:
+ if verbose:
+ print "could not complete sync()..."
+
+ # read them back, deleting a few
+ for x in xrange(start, stop):
+ key = '%04d' % x
+ data = dbutils.DeadlockWrap(d.get, key, max_retries=12)
+ if verbose and x % 100 == 0:
+ print "%s: fetched record (%s, %s)" % (name, key, data)
+ self.assertEqual(data, self.makeData(key))
+ if random() <= 0.10:
+ dbutils.DeadlockWrap(d.delete, key, max_retries=12)
+ if verbose:
+ print "%s: deleted record %s" % (name, key)
+
+ if verbose:
+ print "%s: thread finished" % name
+
+ def readerThread(self, d, readerNum):
+ time.sleep(0.01 * readerNum)
+ name = currentThread().getName()
+
+ for loop in range(5):
+ c = d.cursor()
+ count = 0
+ rec = c.first()
+ while rec:
+ count += 1
+ key, data = rec
+ self.assertEqual(self.makeData(key), data)
+ rec = c.next()
+ if verbose:
+ print "%s: found %d records" % (name, count)
+ c.close()
+ time.sleep(0.05)
+
+ if verbose:
+ print "%s: thread finished" % name
+
+
+class BTreeSimpleThreaded(SimpleThreadedBase):
+ dbtype = db.DB_BTREE
+
+
+class HashSimpleThreaded(SimpleThreadedBase):
+ dbtype = db.DB_HASH
+
+
+#----------------------------------------------------------------------
+
+
+class ThreadedTransactionsBase(BaseThreadedTestCase):
+ dbopenflags = db.DB_THREAD | db.DB_AUTO_COMMIT
+ envflags = (db.DB_THREAD |
+ db.DB_INIT_MPOOL |
+ db.DB_INIT_LOCK |
+ db.DB_INIT_LOG |
+ db.DB_INIT_TXN
+ )
+ readers = 0
+ writers = 0
+ records = 2000
+ txnFlag = 0
+
+ def setEnvOpts(self):
+ #self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
+ pass
+
+ def test03_ThreadedTransactions(self):
+ if verbose:
+ print '\n', '-=' * 30
+ print "Running %s.test03_ThreadedTransactions..." % \
+ self.__class__.__name__
+
+ threads = []
+ for x in range(self.writers):
+ wt = Thread(target = self.writerThread,
+ args = (self.d, self.records, x),
+ name = 'writer %d' % x,
+ )#verbose = verbose)
+ threads.append(wt)
+
+ for x in range(self.readers):
+ rt = Thread(target = self.readerThread,
+ args = (self.d, x),
+ name = 'reader %d' % x,
+ )#verbose = verbose)
+ threads.append(rt)
+
+ dt = Thread(target = self.deadlockThread)
+ dt.start()
+
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ self.doLockDetect = False
+ dt.join()
+
+ def doWrite(self, d, name, start, stop):
+ finished = False
+ while not finished:
+ try:
+ txn = self.env.txn_begin(None, self.txnFlag)
+ for x in range(start, stop):
+ key = '%04d' % x
+ d.put(key, self.makeData(key), txn)
+ if verbose and x % 100 == 0:
+ print "%s: records %d - %d finished" % (name, start, x)
+ txn.commit()
+ finished = True
+ except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
+ if verbose:
+ print "%s: Aborting transaction (%s)" % (name, val[1])
+ txn.abort()
+ time.sleep(0.05)
+
+ def writerThread(self, d, howMany, writerNum):
+ name = currentThread().getName()
+ start = howMany * writerNum
+ stop = howMany * (writerNum + 1) - 1
+ if verbose:
+ print "%s: creating records %d - %d" % (name, start, stop)
+
+ step = 100
+ for x in range(start, stop, step):
+ self.doWrite(d, name, x, min(stop, x+step))
+
+ if verbose:
+ print "%s: finished creating records" % name
+ if verbose:
+ print "%s: deleting a few records" % name
+
+ finished = False
+ while not finished:
+ try:
+ recs = []
+ txn = self.env.txn_begin(None, self.txnFlag)
+ for x in range(10):
+ key = int(random() * howMany) + start
+ key = '%04d' % key
+ data = d.get(key, None, txn, db.DB_RMW)
+ if data is not None:
+ d.delete(key, txn)
+ recs.append(key)
+ txn.commit()
+ finished = True
+ if verbose:
+ print "%s: deleted records %s" % (name, recs)
+ except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
+ if verbose:
+ print "%s: Aborting transaction (%s)" % (name, val[1])
+ txn.abort()
+ time.sleep(0.05)
+
+ if verbose:
+ print "%s: thread finished" % name
+
+ def readerThread(self, d, readerNum):
+ time.sleep(0.01 * readerNum + 0.05)
+ name = currentThread().getName()
+
+ for loop in range(5):
+ finished = False
+ while not finished:
+ try:
+ txn = self.env.txn_begin(None, self.txnFlag)
+ c = d.cursor(txn)
+ count = 0
+ rec = c.first()
+ while rec:
+ count += 1
+ key, data = rec
+ self.assertEqual(self.makeData(key), data)
+ rec = c.next()
+ if verbose: print "%s: found %d records" % (name, count)
+ c.close()
+ txn.commit()
+ finished = True
+ except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
+ if verbose:
+ print "%s: Aborting transaction (%s)" % (name, val[1])
+ c.close()
+ txn.abort()
+ time.sleep(0.05)
+
+ time.sleep(0.05)
+
+ if verbose:
+ print "%s: thread finished" % name
+
+ def deadlockThread(self):
+ self.doLockDetect = True
+ while self.doLockDetect:
+ time.sleep(0.5)
+ try:
+ aborted = self.env.lock_detect(
+ db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT)
+ if verbose and aborted:
+ print "deadlock: Aborted %d deadlocked transaction(s)" \
+ % aborted
+ except db.DBError:
+ pass
+
+
+class BTreeThreadedTransactions(ThreadedTransactionsBase):
+ dbtype = db.DB_BTREE
+ writers = 3
+ readers = 5
+ records = 2000
+
+class HashThreadedTransactions(ThreadedTransactionsBase):
+ dbtype = db.DB_HASH
+ writers = 1
+ readers = 5
+ records = 2000
+
+class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase):
+ dbtype = db.DB_BTREE
+ writers = 3
+ readers = 5
+ records = 2000
+ txnFlag = db.DB_TXN_NOWAIT
+
+class HashThreadedNoWaitTransactions(ThreadedTransactionsBase):
+ dbtype = db.DB_HASH
+ writers = 1
+ readers = 5
+ records = 2000
+ txnFlag = db.DB_TXN_NOWAIT
+
+
+#----------------------------------------------------------------------
+
+def test_suite():
+ suite = unittest.TestSuite()
+
+ if have_threads:
+ suite.addTest(unittest.makeSuite(BTreeConcurrentDataStore))
+ suite.addTest(unittest.makeSuite(HashConcurrentDataStore))
+ suite.addTest(unittest.makeSuite(BTreeSimpleThreaded))
+ suite.addTest(unittest.makeSuite(HashSimpleThreaded))
+ suite.addTest(unittest.makeSuite(BTreeThreadedTransactions))
+ suite.addTest(unittest.makeSuite(HashThreadedTransactions))
+ suite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions))
+ suite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions))
+
+ else:
+ print "Threads not available, skipping thread tests."
+
+ return suite
+
+
+if __name__ == '__main__':
+ unittest.main(defaultTest='test_suite')
diff --git a/python/rpmdb/test/unittest.py b/python/rpmdb/test/unittest.py
new file mode 100644
index 000000000..d31e251d4
--- /dev/null
+++ b/python/rpmdb/test/unittest.py
@@ -0,0 +1,759 @@
+#!/usr/bin/env python
+'''
+Python unit testing framework, based on Erich Gamma's JUnit and Kent Beck's
+Smalltalk testing framework.
+
+This module contains the core framework classes that form the basis of
+specific test cases and suites (TestCase, TestSuite etc.), and also a
+text-based utility class for running the tests and reporting the results
+ (TextTestRunner).
+
+Simple usage:
+
+ import unittest
+
+ class IntegerArithmenticTestCase(unittest.TestCase):
+ def testAdd(self): ## test method names begin 'test*'
+ self.assertEquals((1 + 2), 3)
+ self.assertEquals(0 + 1, 1)
+ def testMultiply(self):
+ self.assertEquals((0 * 10), 0)
+ self.assertEquals((5 * 8), 40)
+
+ if __name__ == '__main__':
+ unittest.main()
+
+Further information is available in the bundled documentation, and from
+
+ http://pyunit.sourceforge.net/
+
+Copyright (c) 1999, 2000, 2001 Steve Purcell
+This module is free software, and you may redistribute it and/or modify
+it under the same terms as Python itself, so long as this copyright message
+and disclaimer are retained in their original form.
+
+IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
+THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
+
+THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
+AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
+SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
+'''
+
+__author__ = "Steve Purcell"
+__email__ = "stephen_purcell at yahoo dot com"
+__version__ = "#Revision: 1.46 $"[11:-2]
+
+import time
+import sys
+import traceback
+import string
+import os
+import types
+
+##############################################################################
+# Test framework core
+##############################################################################
+
+# All classes defined herein are 'new-style' classes, allowing use of 'super()'
+__metaclass__ = type
+
+def _strclass(cls):
+ return "%s.%s" % (cls.__module__, cls.__name__)
+
+class TestResult:
+ """Holder for test result information.
+
+ Test results are automatically managed by the TestCase and TestSuite
+ classes, and do not need to be explicitly manipulated by writers of tests.
+
+ Each instance holds the total number of tests run, and collections of
+ failures and errors that occurred among those test runs. The collections
+ contain tuples of (testcase, exceptioninfo), where exceptioninfo is the
+ formatted traceback of the error that occurred.
+ """
+ def __init__(self):
+ self.failures = []
+ self.errors = []
+ self.testsRun = 0
+ self.shouldStop = 0
+
+ def startTest(self, test):
+ "Called when the given test is about to be run"
+ self.testsRun = self.testsRun + 1
+
+ def stopTest(self, test):
+ "Called when the given test has been run"
+ pass
+
+ def addError(self, test, err):
+ """Called when an error has occurred. 'err' is a tuple of values as
+ returned by sys.exc_info().
+ """
+ self.errors.append((test, self._exc_info_to_string(err)))
+
+ def addFailure(self, test, err):
+ """Called when an error has occurred. 'err' is a tuple of values as
+ returned by sys.exc_info()."""
+ self.failures.append((test, self._exc_info_to_string(err)))
+
+ def addSuccess(self, test):
+ "Called when a test has completed successfully"
+ pass
+
+ def wasSuccessful(self):
+ "Tells whether or not this result was a success"
+ return len(self.failures) == len(self.errors) == 0
+
+ def stop(self):
+ "Indicates that the tests should be aborted"
+ self.shouldStop = 1
+
+ def _exc_info_to_string(self, err):
+ """Converts a sys.exc_info()-style tuple of values into a string."""
+ return string.join(traceback.format_exception(*err), '')
+
+ def __repr__(self):
+ return "<%s run=%i errors=%i failures=%i>" % \
+ (_strclass(self.__class__), self.testsRun, len(self.errors),
+ len(self.failures))
+
+
+class TestCase:
+ """A class whose instances are single test cases.
+
+ By default, the test code itself should be placed in a method named
+ 'runTest'.
+
+ If the fixture may be used for many test cases, create as
+ many test methods as are needed. When instantiating such a TestCase
+ subclass, specify in the constructor arguments the name of the test method
+ that the instance is to execute.
+
+ Test authors should subclass TestCase for their own tests. Construction
+ and deconstruction of the test's environment ('fixture') can be
+ implemented by overriding the 'setUp' and 'tearDown' methods respectively.
+
+ If it is necessary to override the __init__ method, the base class
+ __init__ method must always be called. It is important that subclasses
+ should not change the signature of their __init__ method, since instances
+ of the classes are instantiated automatically by parts of the framework
+ in order to be run.
+ """
+
+ # This attribute determines which exception will be raised when
+ # the instance's assertion methods fail; test methods raising this
+ # exception will be deemed to have 'failed' rather than 'errored'
+
+ failureException = AssertionError
+
+ def __init__(self, methodName='runTest'):
+ """Create an instance of the class that will use the named test
+ method when executed. Raises a ValueError if the instance does
+ not have a method with the specified name.
+ """
+ try:
+ self.__testMethodName = methodName
+ testMethod = getattr(self, methodName)
+ self.__testMethodDoc = testMethod.__doc__
+ except AttributeError:
+ raise ValueError, "no such test method in %s: %s" % \
+ (self.__class__, methodName)
+
+ def setUp(self):
+ "Hook method for setting up the test fixture before exercising it."
+ pass
+
+ def tearDown(self):
+ "Hook method for deconstructing the test fixture after testing it."
+ pass
+
+ def countTestCases(self):
+ return 1
+
+ def defaultTestResult(self):
+ return TestResult()
+
+ def shortDescription(self):
+ """Returns a one-line description of the test, or None if no
+ description has been provided.
+
+ The default implementation of this method returns the first line of
+ the specified test method's docstring.
+ """
+ doc = self.__testMethodDoc
+ return doc and string.strip(string.split(doc, "\n")[0]) or None
+
+ def id(self):
+ return "%s.%s" % (_strclass(self.__class__), self.__testMethodName)
+
+ def __str__(self):
+ return "%s (%s)" % (self.__testMethodName, _strclass(self.__class__))
+
+ def __repr__(self):
+ return "<%s testMethod=%s>" % \
+ (_strclass(self.__class__), self.__testMethodName)
+
+ def run(self, result=None):
+ return self(result)
+
+ def __call__(self, result=None):
+ if result is None: result = self.defaultTestResult()
+ result.startTest(self)
+ testMethod = getattr(self, self.__testMethodName)
+ try:
+ try:
+ self.setUp()
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self.__exc_info())
+ return
+
+ ok = 0
+ try:
+ testMethod()
+ ok = 1
+ except self.failureException, e:
+ result.addFailure(self, self.__exc_info())
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self.__exc_info())
+
+ try:
+ self.tearDown()
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self.__exc_info())
+ ok = 0
+ if ok: result.addSuccess(self)
+ finally:
+ result.stopTest(self)
+
+ def debug(self):
+ """Run the test without collecting errors in a TestResult"""
+ self.setUp()
+ getattr(self, self.__testMethodName)()
+ self.tearDown()
+
+ def __exc_info(self):
+ """Return a version of sys.exc_info() with the traceback frame
+ minimised; usually the top level of the traceback frame is not
+ needed.
+ """
+ exctype, excvalue, tb = sys.exc_info()
+ if sys.platform[:4] == 'java': ## tracebacks look different in Jython
+ return (exctype, excvalue, tb)
+ newtb = tb.tb_next
+ if newtb is None:
+ return (exctype, excvalue, tb)
+ return (exctype, excvalue, newtb)
+
+ def fail(self, msg=None):
+ """Fail immediately, with the given message."""
+ raise self.failureException, msg
+
+ def failIf(self, expr, msg=None):
+ "Fail the test if the expression is true."
+ if expr: raise self.failureException, msg
+
+ def failUnless(self, expr, msg=None):
+ """Fail the test unless the expression is true."""
+ if not expr: raise self.failureException, msg
+
+ def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
+ """Fail unless an exception of class excClass is thrown
+ by callableObj when invoked with arguments args and keyword
+ arguments kwargs. If a different type of exception is
+ thrown, it will not be caught, and the test case will be
+ deemed to have suffered an error, exactly as for an
+ unexpected exception.
+ """
+ try:
+ callableObj(*args, **kwargs)
+ except excClass:
+ return
+ else:
+ if hasattr(excClass,'__name__'): excName = excClass.__name__
+ else: excName = str(excClass)
+ raise self.failureException, excName
+
+ def failUnlessEqual(self, first, second, msg=None):
+ """Fail if the two objects are unequal as determined by the '=='
+ operator.
+ """
+ if not first == second:
+ raise self.failureException, \
+ (msg or '%s != %s' % (`first`, `second`))
+
+ def failIfEqual(self, first, second, msg=None):
+ """Fail if the two objects are equal as determined by the '=='
+ operator.
+ """
+ if first == second:
+ raise self.failureException, \
+ (msg or '%s == %s' % (`first`, `second`))
+
+ def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
+ """Fail if the two objects are unequal as determined by their
+ difference rounded to the given number of decimal places
+ (default 7) and comparing to zero.
+
+ Note that decimal places (from zero) is usually not the same
+ as significant digits (measured from the most signficant digit).
+ """
+ if round(second-first, places) != 0:
+ raise self.failureException, \
+ (msg or '%s != %s within %s places' % (`first`, `second`, `places` ))
+
+ def failIfAlmostEqual(self, first, second, places=7, msg=None):
+ """Fail if the two objects are equal as determined by their
+ difference rounded to the given number of decimal places
+ (default 7) and comparing to zero.
+
+ Note that decimal places (from zero) is usually not the same
+ as significant digits (measured from the most signficant digit).
+ """
+ if round(second-first, places) == 0:
+ raise self.failureException, \
+ (msg or '%s == %s within %s places' % (`first`, `second`, `places`))
+
+ assertEqual = assertEquals = failUnlessEqual
+
+ assertNotEqual = assertNotEquals = failIfEqual
+
+ assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
+
+ assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
+
+ assertRaises = failUnlessRaises
+
+ assert_ = failUnless
+
+
+
+class TestSuite:
+ """A test suite is a composite test consisting of a number of TestCases.
+
+ For use, create an instance of TestSuite, then add test case instances.
+ When all tests have been added, the suite can be passed to a test
+ runner, such as TextTestRunner. It will run the individual test cases
+ in the order in which they were added, aggregating the results. When
+ subclassing, do not forget to call the base class constructor.
+ """
+ def __init__(self, tests=()):
+ self._tests = []
+ self.addTests(tests)
+
+ def __repr__(self):
+ return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)
+
+ __str__ = __repr__
+
+ def countTestCases(self):
+ cases = 0
+ for test in self._tests:
+ cases = cases + test.countTestCases()
+ return cases
+
+ def addTest(self, test):
+ self._tests.append(test)
+
+ def addTests(self, tests):
+ for test in tests:
+ self.addTest(test)
+
+ def run(self, result):
+ return self(result)
+
+ def __call__(self, result):
+ for test in self._tests:
+ if result.shouldStop:
+ break
+ test(result)
+ return result
+
+ def debug(self):
+ """Run the tests without collecting errors in a TestResult"""
+ for test in self._tests: test.debug()
+
+
+class FunctionTestCase(TestCase):
+ """A test case that wraps a test function.
+
+ This is useful for slipping pre-existing test functions into the
+ PyUnit framework. Optionally, set-up and tidy-up functions can be
+ supplied. As with TestCase, the tidy-up ('tearDown') function will
+ always be called if the set-up ('setUp') function ran successfully.
+ """
+
+ def __init__(self, testFunc, setUp=None, tearDown=None,
+ description=None):
+ TestCase.__init__(self)
+ self.__setUpFunc = setUp
+ self.__tearDownFunc = tearDown
+ self.__testFunc = testFunc
+ self.__description = description
+
+ def setUp(self):
+ if self.__setUpFunc is not None:
+ self.__setUpFunc()
+
+ def tearDown(self):
+ if self.__tearDownFunc is not None:
+ self.__tearDownFunc()
+
+ def runTest(self):
+ self.__testFunc()
+
+ def id(self):
+ return self.__testFunc.__name__
+
+ def __str__(self):
+ return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
+
+ def __repr__(self):
+ return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc)
+
+ def shortDescription(self):
+ if self.__description is not None: return self.__description
+ doc = self.__testFunc.__doc__
+ return doc and string.strip(string.split(doc, "\n")[0]) or None
+
+
+
+##############################################################################
+# Locating and loading tests
+##############################################################################
+
+class TestLoader:
+ """This class is responsible for loading tests according to various
+ criteria and returning them wrapped in a Test
+ """
+ testMethodPrefix = 'test'
+ sortTestMethodsUsing = cmp
+ suiteClass = TestSuite
+
+ def loadTestsFromTestCase(self, testCaseClass):
+ """Return a suite of all tests cases contained in testCaseClass"""
+ return self.suiteClass(map(testCaseClass,
+ self.getTestCaseNames(testCaseClass)))
+
+ def loadTestsFromModule(self, module):
+ """Return a suite of all tests cases contained in the given module"""
+ tests = []
+ for name in dir(module):
+ obj = getattr(module, name)
+ if (isinstance(obj, (type, types.ClassType)) and
+ issubclass(obj, TestCase)):
+ tests.append(self.loadTestsFromTestCase(obj))
+ return self.suiteClass(tests)
+
+ def loadTestsFromName(self, name, module=None):
+ """Return a suite of all tests cases given a string specifier.
+
+ The name may resolve either to a module, a test case class, a
+ test method within a test case class, or a callable object which
+ returns a TestCase or TestSuite instance.
+
+ The method optionally resolves the names relative to a given module.
+ """
+ parts = string.split(name, '.')
+ if module is None:
+ if not parts:
+ raise ValueError, "incomplete test name: %s" % name
+ else:
+ parts_copy = parts[:]
+ while parts_copy:
+ try:
+ module = __import__(string.join(parts_copy,'.'))
+ break
+ except ImportError:
+ del parts_copy[-1]
+ if not parts_copy: raise
+ parts = parts[1:]
+ obj = module
+ for part in parts:
+ obj = getattr(obj, part)
+
+ import unittest
+ if type(obj) == types.ModuleType:
+ return self.loadTestsFromModule(obj)
+ elif (isinstance(obj, (type, types.ClassType)) and
+ issubclass(obj, unittest.TestCase)):
+ return self.loadTestsFromTestCase(obj)
+ elif type(obj) == types.UnboundMethodType:
+ return obj.im_class(obj.__name__)
+ elif callable(obj):
+ test = obj()
+ if not isinstance(test, unittest.TestCase) and \
+ not isinstance(test, unittest.TestSuite):
+ raise ValueError, \
+ "calling %s returned %s, not a test" % (obj,test)
+ return test
+ else:
+ raise ValueError, "don't know how to make test from: %s" % obj
+
+ def loadTestsFromNames(self, names, module=None):
+ """Return a suite of all tests cases found using the given sequence
+ of string specifiers. See 'loadTestsFromName()'.
+ """
+ suites = []
+ for name in names:
+ suites.append(self.loadTestsFromName(name, module))
+ return self.suiteClass(suites)
+
+ def getTestCaseNames(self, testCaseClass):
+ """Return a sorted sequence of method names found within testCaseClass
+ """
+ testFnNames = filter(lambda n,p=self.testMethodPrefix: n[:len(p)] == p,
+ dir(testCaseClass))
+ for baseclass in testCaseClass.__bases__:
+ for testFnName in self.getTestCaseNames(baseclass):
+ if testFnName not in testFnNames: # handle overridden methods
+ testFnNames.append(testFnName)
+ if self.sortTestMethodsUsing:
+ testFnNames.sort(self.sortTestMethodsUsing)
+ return testFnNames
+
+
+
+defaultTestLoader = TestLoader()
+
+
+##############################################################################
+# Patches for old functions: these functions should be considered obsolete
+##############################################################################
+
+def _makeLoader(prefix, sortUsing, suiteClass=None):
+ loader = TestLoader()
+ loader.sortTestMethodsUsing = sortUsing
+ loader.testMethodPrefix = prefix
+ if suiteClass: loader.suiteClass = suiteClass
+ return loader
+
+def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
+ return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
+
+def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
+ return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
+
+def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
+ return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
+
+
+##############################################################################
+# Text UI
+##############################################################################
+
+class _WritelnDecorator:
+ """Used to decorate file-like objects with a handy 'writeln' method"""
+ def __init__(self,stream):
+ self.stream = stream
+
+ def __getattr__(self, attr):
+ return getattr(self.stream,attr)
+
+ def writeln(self, *args):
+ if args: self.write(*args)
+ self.write('\n') # text-mode streams translate to \r\n if needed
+
+
+class _TextTestResult(TestResult):
+ """A test result class that can print formatted text results to a stream.
+
+ Used by TextTestRunner.
+ """
+ separator1 = '=' * 70
+ separator2 = '-' * 70
+
+ def __init__(self, stream, descriptions, verbosity):
+ TestResult.__init__(self)
+ self.stream = stream
+ self.showAll = verbosity > 1
+ self.dots = verbosity == 1
+ self.descriptions = descriptions
+
+ def getDescription(self, test):
+ if self.descriptions:
+ return test.shortDescription() or str(test)
+ else:
+ return str(test)
+
+ def startTest(self, test):
+ TestResult.startTest(self, test)
+ if self.showAll:
+ self.stream.write(self.getDescription(test))
+ self.stream.write(" ... ")
+
+ def addSuccess(self, test):
+ TestResult.addSuccess(self, test)
+ if self.showAll:
+ self.stream.writeln("ok")
+ elif self.dots:
+ self.stream.write('.')
+
+ def addError(self, test, err):
+ TestResult.addError(self, test, err)
+ if self.showAll:
+ self.stream.writeln("ERROR")
+ elif self.dots:
+ self.stream.write('E')
+
+ def addFailure(self, test, err):
+ TestResult.addFailure(self, test, err)
+ if self.showAll:
+ self.stream.writeln("FAIL")
+ elif self.dots:
+ self.stream.write('F')
+
+ def printErrors(self):
+ if self.dots or self.showAll:
+ self.stream.writeln()
+ self.printErrorList('ERROR', self.errors)
+ self.printErrorList('FAIL', self.failures)
+
+ def printErrorList(self, flavour, errors):
+ for test, err in errors:
+ self.stream.writeln(self.separator1)
+ self.stream.writeln("%s: %s" % (flavour,self.getDescription(test)))
+ self.stream.writeln(self.separator2)
+ self.stream.writeln("%s" % err)
+
+
+class TextTestRunner:
+ """A test runner class that displays results in textual form.
+
+ It prints out the names of tests as they are run, errors as they
+ occur, and a summary of the results at the end of the test run.
+ """
+ def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1):
+ self.stream = _WritelnDecorator(stream)
+ self.descriptions = descriptions
+ self.verbosity = verbosity
+
+ def _makeResult(self):
+ return _TextTestResult(self.stream, self.descriptions, self.verbosity)
+
+ def run(self, test):
+ "Run the given test case or test suite."
+ result = self._makeResult()
+ startTime = time.time()
+ test(result)
+ stopTime = time.time()
+ timeTaken = float(stopTime - startTime)
+ result.printErrors()
+ self.stream.writeln(result.separator2)
+ run = result.testsRun
+ self.stream.writeln("Ran %d test%s in %.3fs" %
+ (run, run != 1 and "s" or "", timeTaken))
+ self.stream.writeln()
+ if not result.wasSuccessful():
+ self.stream.write("FAILED (")
+ failed, errored = map(len, (result.failures, result.errors))
+ if failed:
+ self.stream.write("failures=%d" % failed)
+ if errored:
+ if failed: self.stream.write(", ")
+ self.stream.write("errors=%d" % errored)
+ self.stream.writeln(")")
+ else:
+ self.stream.writeln("OK")
+ return result
+
+
+
+##############################################################################
+# Facilities for running tests from the command line
+##############################################################################
+
+class TestProgram:
+ """A command-line program that runs a set of tests; this is primarily
+ for making test modules conveniently executable.
+ """
+ USAGE = """\
+Usage: %(progName)s [options] [test] [...]
+
+Options:
+ -h, --help Show this message
+ -v, --verbose Verbose output
+ -q, --quiet Minimal output
+
+Examples:
+ %(progName)s - run default set of tests
+ %(progName)s MyTestSuite - run suite 'MyTestSuite'
+ %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething
+ %(progName)s MyTestCase - run all 'test*' test methods
+ in MyTestCase
+"""
+ def __init__(self, module='__main__', defaultTest=None,
+ argv=None, testRunner=None, testLoader=defaultTestLoader):
+ if type(module) == type(''):
+ self.module = __import__(module)
+ for part in string.split(module,'.')[1:]:
+ self.module = getattr(self.module, part)
+ else:
+ self.module = module
+ if argv is None:
+ argv = sys.argv
+ self.verbosity = 1
+ self.defaultTest = defaultTest
+ self.testRunner = testRunner
+ self.testLoader = testLoader
+ self.progName = os.path.basename(argv[0])
+ self.parseArgs(argv)
+ self.runTests()
+
+ def usageExit(self, msg=None):
+ if msg: print msg
+ print self.USAGE % self.__dict__
+ sys.exit(2)
+
+ def parseArgs(self, argv):
+ import getopt
+ try:
+ options, args = getopt.getopt(argv[1:], 'hHvq',
+ ['help','verbose','quiet'])
+ for opt, value in options:
+ if opt in ('-h','-H','--help'):
+ self.usageExit()
+ if opt in ('-q','--quiet'):
+ self.verbosity = 0
+ if opt in ('-v','--verbose'):
+ self.verbosity = 2
+ if len(args) == 0 and self.defaultTest is None:
+ self.test = self.testLoader.loadTestsFromModule(self.module)
+ return
+ if len(args) > 0:
+ self.testNames = args
+ else:
+ self.testNames = (self.defaultTest,)
+ self.createTests()
+ except getopt.error, msg:
+ self.usageExit(msg)
+
+ def createTests(self):
+ self.test = self.testLoader.loadTestsFromNames(self.testNames,
+ self.module)
+
+ def runTests(self):
+ if self.testRunner is None:
+ self.testRunner = TextTestRunner(verbosity=self.verbosity)
+ result = self.testRunner.run(self.test)
+ sys.exit(not result.wasSuccessful())
+
+main = TestProgram
+
+
+##############################################################################
+# Executing this module from the command line
+##############################################################################
+
+if __name__ == "__main__":
+ main(module=None)
diff --git a/python/rpmmpw-py.c b/python/rpmmpw-py.c
index 45a022389..17dbbee4d 100644
--- a/python/rpmmpw-py.c
+++ b/python/rpmmpw-py.c
@@ -21,7 +21,7 @@
static int _mpw_debug = 0;
/*@unchecked@*/ /*@observer@*/
-static const char initialiser_name[] = "rpm.mpw";
+static const char initialiser_name[] = "mpz";
/*@unchecked@*/ /*@observer@*/
static const struct {
@@ -2149,7 +2149,7 @@ PyTypeObject mpw_Type = {
(getattrofunc) mpw_getattro, /* tp_getattro */
(setattrofunc) mpw_setattro, /* tp_setattro */
0, /* tp_as_buffer */
- Py_TPFLAGS_DEFAULT, /* tp_flags */
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
mpw_doc, /* tp_doc */
#if Py_TPFLAGS_HAVE_ITER
(traverseproc)0, /* tp_traverse */