◐ Shell
clean mode source ↗

feat: Add online_read_async for dynamodb by robhowley · Pull Request #4244 · feast-dev/feast

Expand Up @@ -33,6 +33,8 @@
try: import boto3 from aiobotocore import session from boto3.dynamodb.types import TypeDeserializer from botocore.config import Config from botocore.exceptions import ClientError except ImportError as e: Expand Down Expand Up @@ -80,6 +82,7 @@ class DynamoDBOnlineStore(OnlineStore):
_dynamodb_client = None _dynamodb_resource = None _aioboto_session = None
def update( self, Expand Down Expand Up @@ -223,69 +226,103 @@ def online_read( """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig)
dynamodb_resource = self._get_dynamodb_resource( online_config.region, online_config.endpoint_url ) table_instance = dynamodb_resource.Table( _get_table_name(online_config, config, table) )
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] entity_ids = [ compute_entity_id( entity_key, entity_key_serialization_version=config.entity_key_serialization_version, ) for entity_key in entity_keys ] batch_size = online_config.batch_size entity_ids = self._to_entity_ids(config, entity_keys) entity_ids_iter = iter(entity_ids) result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
while True: batch = list(itertools.islice(entity_ids_iter, batch_size)) batch_result: List[ Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] ] = []
# No more items to insert if len(batch) == 0: break batch_entity_ids = { table_instance.name: { "Keys": [{"entity_id": entity_id} for entity_id in batch], "ConsistentRead": online_config.consistent_reads, } } batch_entity_ids = self._to_resource_batch_get_payload( online_config, table_instance.name, batch ) response = dynamodb_resource.batch_get_item( RequestItems=batch_entity_ids, ) response = response.get("Responses") table_responses = response.get(table_instance.name) if table_responses: table_responses = self._sort_dynamodb_response( table_responses, entity_ids ) entity_idx = 0 for tbl_res in table_responses: entity_id = tbl_res["entity_id"] while entity_id != batch[entity_idx]: batch_result.append((None, None)) entity_idx += 1 res = {} for feature_name, value_bin in tbl_res["values"].items(): val = ValueProto() val.ParseFromString(value_bin.value) res[feature_name] = val batch_result.append( (datetime.fromisoformat(tbl_res["event_ts"]), res) ) entity_idx += 1
# Not all entities in a batch may have responses # Pad with remaining values in batch that were not found batch_size_nones = ((None, None),) * (len(batch) - len(batch_result)) batch_result.extend(batch_size_nones) batch_result = self._process_batch_get_response( table_instance.name, response, entity_ids, batch ) result.extend(batch_result) return result
async def online_read_async( self, config: RepoConfig, table: FeatureView, entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: """ Reads features values for the given entity keys asynchronously.
Args: config: The config for the current feature store. table: The feature view whose feature values should be read. entity_keys: The list of entity keys for which feature values should be read. requested_features: The list of features that should be read.
Returns: A list of the same length as entity_keys. Each item in the list is a tuple where the first item is the event timestamp for the row, and the second item is a dict mapping feature names to values, which are returned in proto format. """ online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig)
batch_size = online_config.batch_size entity_ids = self._to_entity_ids(config, entity_keys) entity_ids_iter = iter(entity_ids) result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] table_name = _get_table_name(online_config, config, table)
deserialize = TypeDeserializer().deserialize
def to_tbl_resp(raw_client_response): return { "entity_id": deserialize(raw_client_response["entity_id"]), "event_ts": deserialize(raw_client_response["event_ts"]), "values": deserialize(raw_client_response["values"]), }
async with self._get_aiodynamodb_client(online_config.region) as client: while True: batch = list(itertools.islice(entity_ids_iter, batch_size))
# No more items to insert if len(batch) == 0: break batch_entity_ids = self._to_client_batch_get_payload( online_config, table_name, batch ) response = await client.batch_get_item( RequestItems=batch_entity_ids, ) batch_result = self._process_batch_get_response( table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp ) result.extend(batch_result) return result
def _get_aioboto_session(self): if self._aioboto_session is None: self._aioboto_session = session.get_session() return self._aioboto_session
def _get_aiodynamodb_client(self, region: str): return self._get_aioboto_session().create_client("dynamodb", region_name=region)
def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None): if self._dynamodb_client is None: self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url) Expand All @@ -298,13 +335,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None ) return self._dynamodb_resource
def _sort_dynamodb_response(self, responses: list, order: list) -> Any: def _sort_dynamodb_response( self, responses: list, order: list, to_tbl_response: Callable = lambda raw_dict: raw_dict, ) -> Any: """DynamoDB Batch Get Item doesn't return items in a particular order.""" # Assign an index to order order_with_index = {value: idx for idx, value in enumerate(order)} # Sort table responses by index table_responses_ordered: Any = [ (order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses (order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in map(to_tbl_response, responses) ] table_responses_ordered = sorted( table_responses_ordered, key=lambda tup: tup[0] Expand Down Expand Up @@ -341,6 +384,64 @@ def _write_batch_non_duplicates( if progress: progress(1)
def _process_batch_get_response( self, table_name, response, entity_ids, batch, **sort_kwargs ): response = response.get("Responses") table_responses = response.get(table_name)
batch_result = [] if table_responses: table_responses = self._sort_dynamodb_response( table_responses, entity_ids, **sort_kwargs ) entity_idx = 0 for tbl_res in table_responses: entity_id = tbl_res["entity_id"] while entity_id != batch[entity_idx]: batch_result.append((None, None)) entity_idx += 1 res = {} for feature_name, value_bin in tbl_res["values"].items(): val = ValueProto() val.ParseFromString(value_bin.value) res[feature_name] = val batch_result.append((datetime.fromisoformat(tbl_res["event_ts"]), res)) entity_idx += 1 # Not all entities in a batch may have responses # Pad with remaining values in batch that were not found batch_size_nones = ((None, None),) * (len(batch) - len(batch_result)) batch_result.extend(batch_size_nones) return batch_result
@staticmethod def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]): return [ compute_entity_id( entity_key, entity_key_serialization_version=config.entity_key_serialization_version, ) for entity_key in entity_keys ]
@staticmethod def _to_resource_batch_get_payload(online_config, table_name, batch): return { table_name: { "Keys": [{"entity_id": entity_id} for entity_id in batch], "ConsistentRead": online_config.consistent_reads, } }
@staticmethod def _to_client_batch_get_payload(online_config, table_name, batch): return { table_name: { "Keys": [{"entity_id": {"S": entity_id}} for entity_id in batch], "ConsistentRead": online_config.consistent_reads, } }

def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None): return boto3.client( Expand Down