Source code for invertedai_simulate.dataset.db
from typing import List, Tuple
import os
from invertedai_simulate.dataset.models import DataLoader
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
[docs]class CacheDataDB:
"""
A class holds a sqlite engine session that is able to read from and write to the passed sql_file
"""
def __init__(self, sql_file):
self.sql_file = sql_file
if os.path.dirname(sql_file) and not os.path.exists(os.path.dirname(sql_file)):
os.makedirs(os.path.dirname(sql_file))
self.engine = create_engine(f'sqlite:///{sql_file}')
self.session = sessionmaker(self.engine)
self._init_database()
def _init_database(self):
DataLoader.__table__.create(self.engine, checkfirst=True)
[docs] def insert_cached_object(self, items: List[Tuple[bytes, bytes, bytes]]):
with self.session() as session:
for i, (obs, rewards, info) in enumerate(items):
data = DataLoader(OBS=obs, REWARDS=rewards, INFO=info)
session.add(data)
if i % 1000 == 0:
session.commit()
session.commit()
[docs] def load_data_by_idx(self, idx: int) -> DataLoader:
with self.session() as session:
result = session.query(DataLoader).filter(DataLoader.IDX == idx).first()
return result
[docs] def select_all(self, limit=1000, offset=0) -> List[DataLoader]:
with self.session() as session:
results = session.query(DataLoader).limit(limit).offset(offset).all()
return results
[docs] def count_all(self) -> int:
with self.session() as session:
results = session.query(DataLoader).count()
return results