Doge log

Abby CTO 雑賀 力王のオフィシャルサイトです

DBUnitTestCase 0.10

こんな感じかね?というか今までなかったつーのがアレだけど。
Excel使えないとめんどいというのはよくわかるのだけどなんというかやっぱdoctestで使えないとダメ?
まあちょっといじるだけでできるけど。
まあざざーっと全文公開です。
というかどっかにあげてみんなで機能増やしたりしたいかも。

dbunit/__init__.py
from dbunit.dataset import *
import unittest

DEBUG = False

class DBUnitTestCase(unittest.TestCase):

    def engine(self):pass

    def connection(self):pass
    
    def setUp(self):
        self.connection = self.connection()
        self.introspection_module = (lambda: __import__('dbunit.db.%s.introspection' % self.engine(), '', '', ['']))()
        self.cursor = self.connection.cursor()
        if DEBUG:
            print "connect databse"
    
    def tearDown(self):
        self.connection.rollback()
        self.connection.close()
        if DEBUG:
            print "rollback databse"
    
    def read_xls_write_db(self, file):
        reader = XlsReader(file)
        for table in reader.dataset :
            self._insert(table)

    def read_xls_all_replace_db(self, file):
        reader = XlsReader(file)
        for table in reader.dataset :
            self._delete(table)
            self._insert(table)

    def read_xls_replace_db(self, file):
        reader = XlsReader(file)
        for table in reader.dataset :
            self._update(table)

    def _delete(self, table):
        sql = self._create_delete_sql(table)
        self.cursor.execute(sql, [])                
        
    def _insert(self, table):
        sql = self._create_insert_sql(table)
        for data in table.data:
            row_data = []
            for cnt, v in enumerate(data):
                type, pk = table.meta[table.header[cnt]]
                row_data.append(self._get_object(type, v))
            if DEBUG :
                print row_data
            self.cursor.execute(sql, row_data)                

    def _update(self, table):
        sql = self._create_update_sql(table)
        for data in table.data:
            row_data = []
            where = []
            for cnt, v in enumerate(data):
                type, pk = table.meta[table.header[cnt]]
                if pk:
                    where.append(self._get_object(type, v))
                else :
                    row_data.append(self._get_object(type, v))
            row_data = row_data+where 
            if DEBUG :
                print row_data
            self.cursor.execute(sql, row_data)                

    def _create_delete_sql(self, table):
        table_name = table.name
        sql = 'DELETE FROM %s ' % table_name
        return str(sql)
                    
    def _create_insert_sql(self, table):
        table_name = table.name
        if not table.__dict__.has_key('meta') :
            table.meta = self._description(table_name)
        sql = 'INSERT INTO %s (%s) VALUES(%s)' % (table_name, ','.join(table.header), ','.join('%s' for x in range(len(table.header)))) 
        return str(sql)

    def _create_update_sql(self, table):
        table_name = table.name
        if not table.__dict__.has_key('meta') :
            table.meta = self._description(table_name)
        update = [];
        where = [];
        for column in table.header:
            field, pk = table.meta[column]
            if pk:
                where.append(column)
            else :
                update.append(column)            
        sql = 'UPDATE %s SET %s WHERE %s' % (table_name, ','.join([ '%s=%%s ' % x for x in update]), 'AND '.join([ '%s=%%s ' % x for x in where])) 
        return str(sql)
        
    def _description(self, table_name):
        meta = {}
        cursor = self.cursor
        if isinstance(table_name, unicode):
            table_name = table_name.encode('ms932')
        indexes = self.introspection_module.get_indexes(cursor, table_name)
        for i, row in enumerate(self.introspection_module.get_table_description(cursor, table_name)):
            name = row[0]
            type = self.introspection_module.DATA_TYPES_REVERSE[row[1]]
            isPk = False
            if name in indexes:
                if indexes[name]['primary_key']:
                    isPk = True                    
            meta[name] =  type, isPk
        return meta   

    def _get_object(self, type, value):
        if type == 'DateTimeField' :
            return value
        elif type == 'DateField' :
            return value
        elif type == 'CharField' :
            return str(value)
        elif type == 'IntegerField' :
            return int(value)
        elif type == 'FloatField' :
            return float(value)
        else :
            return value         

ふつーにunittest.TestCaseを継承。
中身は汚いです。ヘタレなので
engineとconnectionをオーバライドする必要アリ。
engineは要らないかと思ったけどやっぱキツイ。
サンプルは下

dataset.py
from pyExcelerator import *
import sys

class DataTable(object):pass
    
class XlsReader(object):
    def __init__(self, file):
        self.file = file;
        self.dataset = []
        for sheet_name, sheet_values in parse_xls(file): 
            self.create_table(sheet_name, sheet_values)
        
    def create_table(self , sheet_name, sheet_values):
        table = DataTable();
        table.name = sheet_name
        table.header = []
        table.data = []
        self.dataset.append(table)
        for row_idx, col_idx in sorted(sheet_values.keys()):
            v = sheet_values[(row_idx, col_idx)]
            if isinstance(v, unicode):
                v = v.encode('utf8')
            if row_idx == 0:
                table.header.append(v)
            else :
                if col_idx == 0:
                    data = []
                    data.append(v)
                    table.data.append(data)
                else:
                    table.data[-1].append(v)

Excel読み込み系モジュール。
DataTableは・・・そのうち拡張するかなあと思ってたんですけどいらんかも。
pyExceleratorを必須です。
w32comでも良かったんですがこちらを使えという天の声が・・・・

dbunit/db/postgres/introspection.py
def quote_name(name):
    if name.startswith('"') and name.endswith('"'):
        return name 
    return '"%s"' % name

def get_table_description(cursor, table_name):
    cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name))
    return cursor.description

def get_indexes(cursor, table_name):
    cursor.execute("""
        SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
            pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
        WHERE c.oid = idx.indrelid
            AND idx.indexrelid = c2.oid
            AND attr.attrelid = c.oid
            AND attr.attnum = idx.indkey[0]
            AND c.relname = %s""", [table_name])
    indexes = {}
    for row in cursor.fetchall():
        if ' ' in row[1]:
            continue
        indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
    return indexes

DATA_TYPES_REVERSE = {
    16: 'BooleanField', 
    21: 'SmallIntegerField', 
    23: 'IntegerField', 
    25: 'TextField', 
    869: 'IPAddressField', 
    1043: 'CharField', 
    1082: 'DateField', 
    1083: 'TimeField', 
    1114: 'DateTimeField', 
    1184: 'DateTimeField', 
    1266: 'TimeField', 
    1700: 'FloatField', 
}

とりあえずサンプルでpostgres。
どっかでみたなあ・・・というマニアックなツッコミ歓迎。
dbunit/db//introspection.py
を実装すれば他のDBも使えます。
実装するのは

  • テーブルのカラムなどの情報を取得するためのget_table_description
  • テーブルPK情報を取得するためのget_indexes
  • DBごとの型変換用のdict DATA_TYPES_REVERSE

です。

sampletest.py
import dbunit
import unittest

class DummyTest(dbunit.DBUnitTestCase):

    def engine(self):
        return 'postgresql'
    
    def connection(self):
        import psycopg as Database
        conn_string = "dbname=%s" % "test"
        conn_string = "user=%s %s" % ("postgres", conn_string)
        conn_string += " password='%s'" % "postgres"
        conn_string += " host=%s" % "localhost"
        conn_string += " port=%s" % "5432"
        return Database.connect(conn_string)

    def testAAA(self):
        self.read_xls_write_db('C:\sample.xls')
        #TestCode

    def testBBB(self):
        self.read_xls_replace_db('C:\sample.xls')
        #TestCode
        
if __name__=="__main__":
    unittest.main();

とまあこんな感じで書けます。
すげー適当だけどなんか動くっぽい。
あとは多分

  • DataSetのassert
  • list→DataSet変換
  • object→DataSet変換

を考えないといかんですな。結果を簡単にassertできないと。
あと実はpyExceleratorで値を取る時っつーかExcelの書き方なんだけどその辺でへんてこな値が返ってきたり。
書式重要。
とりあえず書いてみてpythonだと短く書けそうだと思ってしまいそこに考えがとられがち。あとpydevのリファクタリングが意外に使えるということがわかった。
とりあえず参考になった方はコメント下さい。
うくく。