"""Source code for invertedai_simulate.dataset.env_dataset."""

import time
from typing import Iterator
import logging
from torch.utils.data import IterableDataset
from typing import Dict
import pickle
from invertedai_simulate.dataset.db import CacheDataDB
from invertedai_simulate.dataset.models import DataLoader
from invertedai_simulate.interface import IAIEnv, ServerTimeoutError
import argparse
import gym

logger = logging.getLogger(__name__)


class IAIEnvDataset(IterableDataset):
    """Iterable dataset that steps an ``IAIEnv`` and yields one record per step.

    The environment is either supplied through ``env`` or built with
    ``gym.make(config.env_name, config=config)``. Each step's record can be kept
    in memory (``config.cache_in_memory``) and/or persisted to a ``CacheDataDB``
    (``config.save_cache_to``). A previously saved cache can be read back by
    index through ``__getitem__`` (``config.load_cache_from``).
    """

    def __init__(self, config, scenario_name, world_parameters=None, vehicle_physics=None,
                 scenario_parameters=None, sensors=None, env: "IAIEnv" = None):
        """
        Args:
            config: parsed options providing ``load_cache_from``, ``save_cache_to``,
                ``cache_in_memory``, ``save_to_disk_interval`` (minutes),
                ``enable_rendering``, ``obs_only`` and, when ``env`` is None,
                ``env_name`` (see :meth:`add_config`).
            scenario_name, world_parameters, vehicle_physics, scenario_parameters, sensors:
                forwarded unchanged to :meth:`set_env_scenario`.
            env: an existing environment to wrap; when None one is created with ``gym.make``.
        """
        super().__init__()
        self.config = config
        self.load_cache_from = config.load_cache_from
        self.save_cache_to = config.save_cache_to
        self.cache_in_memory = config.cache_in_memory

        if self.load_cache_from:
            self.cache_loading_db = CacheDataDB(self.load_cache_from)
            # Records already on disk occupy indices [0, starting_index_offset).
            self.starting_index_offset = self.cache_loading_db.count_all()
        else:
            self.cache_loading_db = None
            self.starting_index_offset = 0

        if self.load_cache_from == self.save_cache_to:
            # Same path (or both empty): reuse the loading handle (None if not loading).
            self.cache_saving_db = self.cache_loading_db
        elif self.save_cache_to:
            self.cache_saving_db = CacheDataDB(self.save_cache_to)
        else:
            self.cache_saving_db = None

        self.save_to_disk_interval = config.save_to_disk_interval
        self.saved_data = []
        # Pickled (obs, rewards, info) tuples waiting to be written to cache_saving_db.
        self._pending_records = []
        self.last_saved_at = time.time()
        self._env: "IAIEnv" = gym.make(config.env_name, config=config) if env is None else env
        self.set_env_scenario(scenario_name, world_parameters, vehicle_physics, scenario_parameters, sensors)

    def set_env_scenario(self, scenario_name, world_parameters=None, vehicle_physics=None,
                         scenario_parameters=None, sensors=None):
        """Forward the scenario definition to the wrapped environment's ``set_scenario``."""
        return self._env.set_scenario(scenario_name, world_parameters, vehicle_physics,
                                      scenario_parameters, sensors)

    def __iter__(self) -> Iterator[Dict]:
        """Run one episode, yielding ``dict(obs=..., rewards=..., info=...)`` per step.

        With ``config.obs_only`` set, ``rewards`` and ``info`` are None in the
        yielded records. The action for the next step comes from
        :meth:`get_next_action` applied to the current step's ``info``.
        Records destined for disk are buffered. They are written every
        ``save_to_disk_interval`` minutes and once more when the episode ends, so
        drive the generator to completion or call :meth:`close` to persist the tail.
        """
        if self.config.enable_rendering:
            self._render_env()
        obs = self._reset_env()
        action = obs['prev_action']
        done = False
        while not done:
            obs, rewards, done, info = self._step_env(action)
            action = self.get_next_action(info)
            if self.config.obs_only:
                data = dict(obs=obs, rewards=None, info=None)
            else:
                data = dict(obs=obs, rewards=rewards, info=info)
            if self.cache_in_memory:
                self.saved_data.append(data)
            if self.cache_saving_db is not None:
                # Pickle now so later mutation of obs/info by the env cannot alter the record.
                self._pending_records.append((pickle.dumps(obs), pickle.dumps(rewards), pickle.dumps(info)))
                # save_to_disk_interval is in minutes; the elapsed time is in seconds.
                if time.time() - self.last_saved_at >= self.save_to_disk_interval * 60:
                    self._flush_to_disk()
            yield data
        self._flush_to_disk()

    def __getitem__(self, index) -> Dict:
        """Return the record at ``index``.

        When a disk cache was loaded, indices below ``starting_index_offset`` are
        read from it (``load_data_by_idx`` is called with ``index + 1``, so its
        ids appear to be 1-based). Later indices come from the in-memory records
        gathered by ``__iter__``. Without in-memory caching, every lookup goes to
        the loading database.
        """
        if self.cache_in_memory:
            if self.load_cache_from and index < self.starting_index_offset:
                return self.build_data_from_queried_result(self.cache_loading_db.load_data_by_idx(index + 1))
            # starting_index_offset is 0 when no disk cache was loaded.
            return self.saved_data[index - self.starting_index_offset]
        return self.build_data_from_queried_result(self.cache_loading_db.load_data_by_idx(index + 1))

    def __enter__(self):
        """Allow ``with IAIEnvDataset(...) as ds:``; the environment is closed on exit."""
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Close the dataset; exceptions raised inside the ``with`` body propagate."""
        self.close()

    @staticmethod
    def build_data_from_queried_result(result: "DataLoader") -> Dict:
        """Unpickle a cached row's ``OBS``/``REWARDS``/``INFO`` columns into a record dict."""
        # NOTE(security): pickle.loads can execute arbitrary code -- only load caches you trust.
        return dict(obs=pickle.loads(result.OBS),
                    info=pickle.loads(result.INFO),
                    rewards=pickle.loads(result.REWARDS))

    def close(self):
        """Write any buffered records to disk, then close the wrapped environment."""
        self._flush_to_disk()
        self._close_env()

    @staticmethod
    def get_next_action(info):
        """Return the ``expert_action`` entry of a step's ``info`` dict."""
        return info['expert_action']

    def _flush_to_disk(self):
        """Insert buffered records into ``cache_saving_db`` and restart the save timer."""
        if self.cache_saving_db is not None and self._pending_records:
            self.cache_saving_db.insert_cached_object(self._pending_records)
            # Rebind rather than clear(), in case the DB layer kept a reference to the list.
            self._pending_records = []
            self.last_saved_at = time.time()

    def _step_env(self, action):
        return self._env.step(action)

    def _reset_env(self):
        return self._env.reset()

    def _close_env(self):
        self._env.close()

    def _render_env(self):
        self._env.render()

    def _seed_env(self, seed=None):
        self._env.seed(seed)

    def pre_load_data(self, limit=None):
        """Iterate the dataset to fill the caches, stopping after ``limit`` records if ``limit`` is truthy."""
        for i, _ in enumerate(self):
            if limit and i + 1 == limit:
                break

    @staticmethod
    def add_config(parser: argparse.ArgumentParser) -> None:
        """Register ``IAIEnv``'s options plus this dataset's options on ``parser``."""
        IAIEnv.add_config(parser)
        parser.add_argument('--enable_rendering', type=int, default=0)
        parser.add_argument('--obs_only', type=int, default=0)
        parser.add_argument('--cache_in_memory', type=int, default=1)
        parser.add_argument('--load_cache_from', type=str, default='')
        parser.add_argument('--save_cache_to', type=str, default='')
        parser.add_argument('--save_to_disk_interval', type=int, default=15,
                            help='The frequency we save to disk in minutes')
        parser.add_argument('--env_name', type=str, default='iai/GenericEnv-v0',
                            help='The gym environment name to build if env is not provided in __init__')