◐ Shell
reader mode source ↗
Skip to content
Merged
Show file tree
Changes from all commits
File filter
Conversations
Jump to
Diff view
Apply and reload
Show whitespace
Diff view
Apply and reload
191 changes: 146 additions & 45 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

try:
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
except ImportError as e:
Expand Down Expand Up @@ -80,6 +82,7 @@ class DynamoDBOnlineStore(OnlineStore):

_dynamodb_client = None
_dynamodb_resource = None

def update(
self,
Expand Down Expand Up @@ -223,69 +226,103 @@ def online_read(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids_iter = iter(entity_ids)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))
batch_result: List[
Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
] = []
# No more entity ids left to read
if len(batch) == 0:
break
batch_entity_ids = {
table_instance.name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}
response = dynamodb_resource.batch_get_item(
RequestItems=batch_entity_ids,
)
response = response.get("Responses")
table_responses = response.get(table_instance.name)
if table_responses:
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids
)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
while entity_id != batch[entity_idx]:
batch_result.append((None, None))
entity_idx += 1
res = {}
for feature_name, value_bin in tbl_res["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
batch_result.append(
(datetime.fromisoformat(tbl_res["event_ts"]), res)
)
entity_idx += 1

# Not all entities in a batch may have responses
# Pad with remaining values in batch that were not found
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
batch_result.extend(batch_size_nones)
result.extend(batch_result)
return result

def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None):
if self._dynamodb_client is None:
self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url)
Expand All @@ -298,13 +335,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
)
return self._dynamodb_resource

def _sort_dynamodb_response(self, responses: list, order: list) -> Any:
"""DynamoDB Batch Get Item doesn't return items in a particular order."""
# Assign an index to order
order_with_index = {value: idx for idx, value in enumerate(order)}
# Sort table responses by index
table_responses_ordered: Any = [
(order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses
]
table_responses_ordered = sorted(
table_responses_ordered, key=lambda tup: tup[0]
Expand Down Expand Up @@ -341,6 +384,64 @@ def _write_batch_non_duplicates(
if progress:
progress(1)


def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
return boto3.client(
Expand Down
Loading
Toggle all file notes Toggle all file annotations