import json
import logging
import os
import sqlite3
from functools import wraps
from biothings.utils.common import find_value_in_doc, json_serial
from biothings.utils.dataload import update_dict_recur
from biothings.utils.dotfield import parse_dot_fields
from biothings.utils.hub_db import IDatabase
from biothings.utils.serializer import json_loads
# Lazily-bound biothings config module; stays None until the
# @requires_config decorator resolves it on first use.
config = None

# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)
def requires_config(func):
    """Decorator ensuring the biothings ``config`` module is loaded.

    On first call, binds the module-level ``config`` global to
    ``biothings.config``; raises if biothings has not been set up yet
    (i.e. ``biothings.config_for_app()`` was never called).
    """

    @wraps(func)
    def func_wrapper(*args, **kwargs):
        global config
        if config:
            return func(*args, **kwargs)
        try:
            from biothings import config as config_mod
        except ImportError:
            raise Exception("call biothings.config_for_app() first")
        config = config_mod
        return func(*args, **kwargs)

    return func_wrapper
@requires_config
def get_hub_db_conn():
    """Return a new :class:`DatabaseClient` for the hub's SQLite backend."""
    return DatabaseClient()
@requires_config
def get_src_conn():
    """Return a connection for source data (same backend as the hub db)."""
    return get_hub_db_conn()
@requires_config
def get_src_db():
    """Return the database named by ``config.DATA_SRC_DATABASE``."""
    return get_src_conn()[config.DATA_SRC_DATABASE]
@requires_config
def get_src_dump():
    """Return the src_dump collection (collection name configurable via
    ``config.DATA_SRC_DUMP_COLLECTION``, defaulting to ``"src_dump"``)."""
    hub_db = get_hub_db_conn()[config.DATA_HUB_DB_DATABASE]
    return hub_db[getattr(config, "DATA_SRC_DUMP_COLLECTION", "src_dump")]
@requires_config
def get_src_master():
    """Return the src_master collection from the hub database."""
    hub_db = get_hub_db_conn()[config.DATA_HUB_DB_DATABASE]
    return hub_db[config.DATA_SRC_MASTER_COLLECTION]
@requires_config
def get_src_build():
    """Return the src_build collection from the hub database.

    Decorated with @requires_config for consistency with the other
    accessors: without it, an unconfigured hub fails with an opaque
    ``AttributeError`` on ``config`` instead of the bootstrap message.
    """
    conn = get_hub_db_conn()
    return conn[config.DATA_HUB_DB_DATABASE][config.DATA_SRC_BUILD_COLLECTION]
@requires_config
def get_src_build_config():
    """Return the src_build config collection (src_build name + "_config").

    @requires_config added for consistency with the other accessors so an
    unconfigured hub raises the explicit bootstrap error.
    """
    conn = get_hub_db_conn()
    return conn[config.DATA_HUB_DB_DATABASE][config.DATA_SRC_BUILD_COLLECTION + "_config"]
@requires_config
def get_data_plugin():
    """Return the data_plugin collection from the hub database.

    @requires_config added for consistency with the other accessors so an
    unconfigured hub raises the explicit bootstrap error.
    """
    conn = get_hub_db_conn()
    return conn[config.DATA_HUB_DB_DATABASE][config.DATA_PLUGIN_COLLECTION]
@requires_config
def get_api():
    """Return the API collection from the hub database.

    @requires_config added for consistency with the other accessors so an
    unconfigured hub raises the explicit bootstrap error.
    """
    conn = get_hub_db_conn()
    return conn[config.DATA_HUB_DB_DATABASE][config.API_COLLECTION]
@requires_config
def get_cmd():
    """Return the command collection from the hub database.

    @requires_config added for consistency with the other accessors so an
    unconfigured hub raises the explicit bootstrap error.
    """
    conn = get_hub_db_conn()
    return conn[config.DATA_HUB_DB_DATABASE][config.CMD_COLLECTION]
@requires_config
def get_event():
    """Return the event collection (name configurable via
    ``config.EVENT_COLLECTION``, defaulting to ``"event"``).

    @requires_config added for consistency with the other accessors so an
    unconfigured hub raises the explicit bootstrap error.
    """
    conn = get_hub_db_conn()
    return conn[config.DATA_HUB_DB_DATABASE][getattr(config, "EVENT_COLLECTION", "event")]
@requires_config
def get_hub_config():
    """Return the hub_config collection (name configurable via
    ``config.HUB_CONFIG_COLLECTION``, defaulting to ``"hub_config"``).

    @requires_config added for consistency with the other accessors so an
    unconfigured hub raises the explicit bootstrap error.
    """
    conn = get_hub_db_conn()
    return conn[config.DATA_HUB_DB_DATABASE][getattr(config, "HUB_CONFIG_COLLECTION", "hub_config")]
def get_last_command():
    """Return ``{"_id": <max command id>}`` from the "cmd" table.

    Best-effort: any failure (missing table, empty table, no backend yet)
    falls back to ``{"_id": 1}`` so callers can bootstrap.
    """
    try:
        rows = get_cmd().get_conn().execute("SELECT MAX(_id) FROM cmd").fetchall()
        last_id = rows[0][0]
        assert last_id, "No command ID found, bootstrap ?"
        return {"_id": last_id}
    except Exception:
        # deliberately broad: this is a bootstrap fallback
        return {"_id": 1}
def get_source_fullname(col_name):
    """
    Assuming col_name is a collection created from an upload process,
    find the main source & sub_source associated.
    """
    matched = None
    # scan all src_dump docs; the last doc whose upload jobs mention
    # col_name wins (mirrors original scan-to-the-end behavior)
    for doc in get_src_dump().find():
        jobs = doc.get("upload", {}).get("jobs", {})
        if col_name in jobs:
            matched = doc
    if not matched:
        return None
    main_name = matched["_id"]
    if main_name == col_name:
        return main_name
    # col_name was a sub-source name
    return "%s.%s" % (main_name, col_name)
class Database(IDatabase):
    """SQLite-backed hub database: one database == one SQLite file, each
    collection == a two-column table (``_id TEXT PRIMARY KEY, document TEXT``)
    storing documents as JSON text.
    """

    def __init__(self, db_folder, name=None):
        super(Database, self).__init__()
        # default to the configured hub database name when none is given
        if not name:
            self.name = config.DATA_HUB_DB_DATABASE
        else:
            self.name = name
        self.dbfile = os.path.join(db_folder, self.name)
        self.cols = {}  # cache of Collection wrappers, keyed by name

    @property
    def address(self):
        """Path of the underlying SQLite file."""
        return self.dbfile

    def get_conn(self):
        """Open a new sqlite3 connection to the database file."""
        return sqlite3.connect(self.dbfile)

    def collection_names(self):
        """Return the names of all tables (collections) in this database."""
        tables = self.get_conn().execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
        return [name[0] for name in tables]

    def create_collection(self, colname):
        """Create (if needed) and return the collection named *colname*."""
        return self[colname]

    def create_if_needed(self, table):
        """Create the backing table for *table* unless it already exists."""
        conn = self.get_conn()
        existings = [
            tname[0]
            for tname in conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' and name = ?", (table,)
            ).fetchall()
        ]
        if table not in existings:
            # Identifiers can't be bound as parameters; quote the name and
            # double any embedded quotes so a hostile table name can't
            # inject SQL (fixes the previous "TODO: injection..." note).
            safe_name = table.replace('"', '""')
            conn.execute('CREATE TABLE "%s" (_id TEXT PRIMARY KEY, document TEXT)' % safe_name)

    def __getitem__(self, colname):
        # lazily create the table + Collection wrapper, then cache it
        if colname not in self.cols:
            self.create_if_needed(colname)
            self.cols[colname] = Collection(colname, self)
        return self.cols[colname]
class DatabaseClient(IDatabase):
    """Client entry point over the SQLite backend: indexing by database
    name yields a :class:`Database` bound to a file inside the configured
    ``sqlite_db_folder``.
    """

    def __init__(self):
        super().__init__()
        folder = config.HUB_DB_BACKEND["sqlite_db_folder"]
        # ensure the backing folder exists before any database is opened
        os.makedirs(folder, exist_ok=True)
        self.sqlite_db_folder = folder
        self.name = None
        self.dbfile = None
        self.cols = {}

    def __getitem__(self, name):
        return Database(self.sqlite_db_folder, name)
class Collection(object):
    """One SQLite table exposed through a partial pymongo-like Collection API.

    Documents are stored as JSON text in a two-column table
    (``_id TEXT PRIMARY KEY, document TEXT``).
    """

    def __init__(self, colname, db):
        # colname: table name; db: owning Database (provides .dbfile path)
        self.colname = colname
        self.db = db

    def get_conn(self):
        """Open a new sqlite3 connection to the owning database file."""
        return sqlite3.connect(self.db.dbfile)

    @property
    def name(self):
        # pymongo-style alias for the table name
        return self.colname

    @property
    def database(self):
        # pymongo-style alias for the owning Database
        return self.db

    def find_one(self, *args, **kwargs):
        """Return a single matching document (or None).

        A pure ``{"_id": ...}`` query hits the primary key directly;
        any other dict query delegates to find(..., find_one=True).
        """
        if args and len(args) == 1 and isinstance(args[0], dict):
            if len(args[0]) == 1 and "_id" in args[0]:
                strdoc = (
                    self.get_conn()
                    .execute("SELECT document FROM %s WHERE _id = ?" % self.colname, (args[0]["_id"],))
                    .fetchone()
                )
                if strdoc:
                    return json_loads(strdoc[0])
                else:
                    return None
            else:
                return self.find(*args, find_one=True)
        elif args or kwargs:
            raise NotImplementedError("find(): %s %s" % (repr(args), repr(kwargs)))
        else:
            # no criteria at all: first document found
            return self.find(find_one=True)

    def findv2(self, *args, **kwargs):
        """This is a new version of find() that uses json feature of sqlite3, will replace find in the future"""
        start = kwargs.get("start", 0)
        limit = kwargs.get("limit", 10)
        return_total = kwargs.get("return_total", False)  # return (results, total) tuple if True, default False
        return_list = kwargs.get("return_list", False)  # return list instead of generator if True, default False
        conn = self.get_conn()
        tbl_name = self.colname
        if args and len(args) == 1 and isinstance(args[0], dict) and len(args[0]) > 0:
            # it's key/value search, args[0] like {"a.b": "test", "a.b.c": "value"}
            # NOTE(review): field values are interpolated into the SQL text, not
            # bound as parameters -- injection risk if queries ever carry
            # untrusted input; confirm callers only pass trusted values.
            sub_queries = []
            for k, v in args[0].items():
                if "*" in v or "?" in v:
                    # glob-style wildcards mapped to SQL LIKE wildcards
                    _v = v.replace("*", "%").replace("?", "_")
                    _v = f"LIKE '{_v}'"
                else:
                    _v = f"= '{v}'"
                if k == "_id":
                    where = f"(_id {_v})"
                    sub_query = f"SELECT _id FROM {tbl_name} WHERE {where}"
                elif "." in k:
                    # nested field name like a.b.c, we will use json_tree.fullkey to match
                    # Here is an example for the query {"object.symbol": "BRD1"}:
                    # SELECT document FROM TISSUES, json_tree(TISSUES.document)
                    # WHERE (json_tree.fullkey LIKE '$.%object%.%symbol%' AND json_tree.value = 'BRD1')
                    k = k.replace(".", "%.%")
                    k = f"$.%{k}%"
                    where = f"(json_tree.fullkey LIKE '{k}' AND json_tree.value {_v})"
                    sub_query = f"SELECT _id FROM {tbl_name}, json_tree({tbl_name}.document) WHERE {where}"
                else:
                    # just a top level field, we will use json_each.key to match
                    # _v matches the value directly using LIKE or = (for a scalar field);
                    # _v2 matches the double-quoted value using LIKE (for an array field)
                    # Here is an example for the query {"ancestors": "CHEBI:75771"}:
                    # "SELECT _id FROM chebi, json_each(chebi.document)
                    #  WHERE (json_each.key = 'ancestors' AND
                    #         (json_each.value = 'CHEBI:75771' OR json_each.value LIKE '%"CHEBI:75771"%')
                    #        )
                    _v2 = _v.replace("LIKE '", "LIKE '%\"").replace("= '", "LIKE '%\"")
                    _v2 = _v2[:-1] + "\"%'"
                    where = f"(json_each.key = '{k}' AND (json_each.value {_v} OR json_each.value {_v2}))"
                    sub_query = f"SELECT _id FROM {tbl_name}, json_each({tbl_name}.document) WHERE {where}"
                sub_queries.append(sub_query)
            if sub_queries:
                if len(sub_queries) == 1:
                    query = sub_queries[0].replace("SELECT _id FROM", "SELECT document FROM")
                else:
                    # JOIN multiple sub queries:
                    # Here is an example for the query: q=object.symbol:BRD1%20AND%20subject.id:BTO:0000017
                    # SELECT document FROM TISSUES WHERE _id IN
                    #   (SELECT _id FROM
                    #     (SELECT _id FROM TISSUES, json_tree(TISSUES.document)
                    #      WHERE (json_tree.fullkey LIKE '$.%object%.%symbol%' AND json_tree.value = 'BRD1')
                    #     ) AS subq0
                    #    INNER JOIN
                    #     (SELECT _id FROM TISSUES, json_tree(TISSUES.document)
                    #      WHERE (json_tree.fullkey LIKE '$.%subject%.%id%' AND json_tree.value = 'BTO:0000017')
                    #     ) AS subq1
                    #    USING (_id)
                    #   )
                    query = f"SELECT _id FROM ({sub_queries[0]}) AS subq0"
                    for i, sub_query in enumerate(sub_queries[1:]):
                        query += f" INNER JOIN ({sub_query}) AS subq{i+1} USING (_id)"
                    query = f"SELECT document FROM {tbl_name} WHERE _id IN ({query})"
        elif not args or len(args) == 1 and len(args[0]) == 0:
            # nothing or empty dict: full scan
            query = f"SELECT document FROM {tbl_name}"
        else:
            raise NotImplementedError("find: args=%s kwargs=%s" % (repr(args), repr(kwargs)))
        # include limit and offset
        _query = query + f" LIMIT {limit} OFFSET {start}"
        logger.debug('SQLite query: "%s"', _query)
        results = (json_loads(doc[0]) for doc in conn.execute(_query))  # results is a generator
        if return_list:
            results = list(results)
        if return_total:
            # get total count without limit and offset
            total = conn.execute(query.replace("SELECT document FROM", "SELECT COUNT(*) FROM")).fetchone()[0]
            return results, total
        else:
            return results

    def find(self, *args, **kwargs):
        """Scan the whole table client-side and return matching documents.

        Matching uses find_value_in_doc() per key/value pair (all pairs must
        match). kwargs: find_one (return first match), start/limit (slice
        applied after the full scan).
        """
        results = []
        if args and len(args) == 1 and isinstance(args[0], dict) and len(args[0]) > 0:
            # it's key/value search, let's iterate
            for doc in self.get_conn().execute("SELECT document FROM %s" % self.colname).fetchall():
                found = []
                doc = json_loads(doc[0])
                for k, v in args[0].items():
                    _found = find_value_in_doc(k, v, doc)
                    found.append(_found)
                if all(found):
                    if "find_one" in kwargs:
                        return doc
                    else:
                        results.append(doc)
            if "limit" in kwargs:
                start = kwargs.get("start", 0)
                end = start + kwargs.get("limit", 0)
                return results[start:end]
            return results
        elif not args or len(args) == 1 and len(args[0]) == 0:
            # nothing or empty dict: return everything
            results = [
                json_loads(doc[0])
                for doc in self.get_conn().execute("SELECT document FROM %s" % self.colname).fetchall()
            ]
            if "limit" in kwargs:
                start = kwargs.get("start", 0)
                end = start + kwargs.get("limit", 0)
                return results[start:end]
            return results
        else:
            raise NotImplementedError("find: args=%s kwargs=%s" % (repr(args), repr(kwargs)))

    def insert_one(self, doc):
        """Insert one document; requires an "_id" key (primary key column)."""
        assert "_id" in doc
        with self.get_conn() as conn:
            conn.execute(
                "INSERT INTO %s (_id,document) VALUES (?,?)" % self.colname,
                (doc["_id"], json.dumps(doc, default=json_serial)),
            ).fetchone()
            conn.commit()

    def insert(self, docs, *args, **kwargs):
        """Insert an iterable of documents in a single connection/commit."""
        with self.get_conn() as conn:
            for doc in docs:
                conn.execute(
                    "INSERT INTO %s (_id,document) VALUES (?,?)" % self.colname,
                    (doc["_id"], json.dumps(doc, default=json_serial)),
                ).fetchone()
            conn.commit()

    def bulk_write(self, docs, *args, **kwargs):
        """Insert pymongo-style operation objects (reads their ._doc) and
        return a Cursor carrying the inserted count."""
        doc_objs = [item._doc for item in docs]
        self.insert(doc_objs, *args, **kwargs)
        return Cursor(len(doc_objs))

    def update_one(self, query, what, upsert=False):
        """Apply a single $set/$unset/$push operation to the first document
        matching *query*; optionally insert (upsert) when nothing matches."""
        assert len(what) == 1 and (
            "$set" in what or "$unset" in what or "$push" in what
        ), "$set/$unset/$push operators not found"
        doc = self.find_one(query)
        if doc:
            if "$set" in what:
                # parse_dot_fields uses json.dumps internally, so make sure
                # everything is serializable first
                what = json.loads(json.dumps(what, default=json_serial))
                what = parse_dot_fields(what["$set"])
                doc = update_dict_recur(doc, what)
            elif "$unset" in what:
                for keytounset in what["$unset"].keys():
                    doc.pop(keytounset, None)
            elif "$push" in what:
                for listkey, elem in what["$push"].items():
                    assert "." not in listkey, "$push not supported for nested keys: %s" % listkey
                    doc.setdefault(listkey, []).append(elem)
            self.save(doc)
        elif upsert:
            # only $set makes sense when building a brand new document
            assert "$set" in what
            query.update(what["$set"])
            self.save(query)

    def update(self, query, what, upsert=False):
        """Apply update_one() to every document matching *query*."""
        docs = self.find(query)
        for doc in docs:
            self.update_one({"_id": doc["_id"]}, what, upsert)

    def save(self, doc):
        """Upsert *doc* by its "_id": UPDATE if present, else insert."""
        if self.find_one({"_id": doc["_id"]}):
            with self.get_conn() as conn:
                conn.execute(
                    "UPDATE %s SET document = ? WHERE _id = ?" % self.colname,
                    (json.dumps(doc, default=json_serial), doc["_id"]),
                )
                conn.commit()
        else:
            self.insert_one(doc)

    def replace_one(self, query, doc, upsert=False):
        """Replace the document whose _id matches query["_id"] with *doc*;
        optionally insert it (upsert) when no such row exists."""
        assert "_id" in query
        orig = self.find_one(query)
        if orig:
            orig["_id"] = query["_id"]
            with self.get_conn() as conn:
                conn.execute(
                    "UPDATE %s SET document = ? WHERE _id = ?" % self.colname,
                    (json.dumps(doc, default=json_serial), orig["_id"]),
                )
                conn.commit()
        elif upsert:
            doc["_id"] = query["_id"]
            self.save(doc)

    def remove(self, query):
        """Delete every document matching *query* (client-side find first)."""
        docs = self.find(query)
        with self.get_conn() as conn:
            for doc in docs:
                conn.execute("DELETE FROM %s WHERE _id = ?" % self.colname, (doc["_id"],)).fetchone()
            conn.commit()

    def rename(self, new_name, dropTarget=False):
        """Rename the backing table; optionally drop an existing target first."""
        with self.get_conn() as conn:
            if dropTarget:
                conn.execute(f"DROP TABLE IF EXISTS {new_name}")
            conn.execute(f"ALTER TABLE {self.colname} RENAME TO {new_name}").fetchall()

    def count(self):
        """Return the number of documents (rows) in the collection."""
        return self.get_conn().execute("SELECT count(_id) FROM %s" % self.colname).fetchone()[0]

    def drop(self):
        """Drop the backing table entirely."""
        self.get_conn().execute("DROP TABLE %s" % self.colname).fetchall()

    def __getitem__(self, _id):
        # dict-style access by _id
        return self.find_one({"_id": _id})

    def __getstate__(self):
        # drop the (unpicklable) Database reference when pickling
        self.__dict__.pop("db", None)
        return self.__dict__
class Cursor(object):
    """Minimal stand-in for a pymongo bulk-write result: only carries the
    number of inserted documents, as returned by Collection.bulk_write()."""

    def __init__(self, inserted_count):
        # number of documents written by the bulk operation
        self.inserted_count = inserted_count

    def __repr__(self):
        # added for debuggability; attribute access is unchanged
        return "%s(inserted_count=%r)" % (type(self).__name__, self.inserted_count)