◐ Shell
reader mode source ↗
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
File filter
Conversations
Jump to
Diff view
Apply and reload
Show whitespace
Diff view
Apply and reload
64 changes: 19 additions & 45 deletions sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,11 @@ def to_df(
validation_reference (optional): The validation to apply against the retrieved dataframe.
timeout (optional): The query timeout if applicable.
"""
features_df = self._to_df_internal(timeout=timeout)

if self.on_demand_feature_views:
# TODO(adchia): Fix requirement to specify dependent feature views in feature_refs
for odfv in self.on_demand_feature_views:
if odfv.mode not in {"pandas", "substrait"}:
raise Exception(
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
)
features_df = features_df.join(
odfv.get_transformed_features_df(
features_df,
self.full_feature_names,
)
)

if validation_reference:
if not flags_helper.is_test():
warnings.warn(
"Dataset validation is an experimental feature. "
"This API is unstable and it could and most probably will be changed in the future. "
"We do not guarantee that future changes will maintain backward compatibility.",
RuntimeWarning,
)

validation_result = validation_reference.profile.validate(features_df)
if not validation_result.is_success:
raise ValidationFailed(validation_result)

return features_df

def to_arrow(
self,
Expand All @@ -122,23 +97,20 @@ def to_arrow(
validation_reference (optional): The validation to apply against the retrieved dataframe.
timeout (optional): The query timeout if applicable.
"""
if not self.on_demand_feature_views and not validation_reference:
return self._to_arrow_internal(timeout=timeout)

features_df = self._to_df_internal(timeout=timeout)
if self.on_demand_feature_views:
for odfv in self.on_demand_feature_views:
if odfv.mode not in {"pandas", "substrait"}:
raise Exception(
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
)
features_df = features_df.join(
odfv.get_transformed_features_df(
features_df,
self.full_feature_names,
)
)

if validation_reference:
if not flags_helper.is_test():
warnings.warn(
Expand All @@ -148,11 +120,13 @@ def to_arrow(
RuntimeWarning,
)

validation_result = validation_reference.profile.validate(features_df)
if not validation_result.is_success:
raise ValidationFailed(validation_result)

return pyarrow.Table.from_pandas(features_df)

def to_sql(self) -> str:
"""
Expand Down
14 changes: 14 additions & 0 deletions sdk/python/feast/transformation/pandas_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import dill
import pandas as pd

from feast.field import Field, from_value_type
from feast.protos.feast.core.Transformation_pb2 import (
Expand All @@ -26,6 +27,19 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
self.udf = udf
self.udf_string = udf_string

def transform(self, input_df: pd.DataFrame) -> pd.DataFrame:
if not isinstance(input_df, pd.DataFrame):
raise TypeError(
Expand Down
6 changes: 6 additions & 0 deletions sdk/python/feast/transformation/python_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List

import dill

from feast.field import Field, from_value_type
from feast.protos.feast.core.Transformation_pb2 import (
Expand All @@ -24,6 +25,11 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
self.udf = udf
self.udf_string = udf_string

def transform(self, input_dict: Dict) -> Dict:
if not isinstance(input_dict, Dict):
raise TypeError(
Expand Down
9 changes: 9 additions & 0 deletions sdk/python/feast/transformation/substrait_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def table_provider(names, schema: pyarrow.Schema):
).read_all()
return table.to_pandas()

def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]:
df = pd.DataFrame.from_dict(random_input)
output_df: pd.DataFrame = self.transform(df)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_to_sql():

@pytest.mark.parametrize("timeout", (None, 30))
def test_to_df_timeout(retrieval_job, timeout: Optional[int]):
with patch.object(retrieval_job, "_to_df_internal") as mock_to_df_internal:
retrieval_job.to_df(timeout=timeout)
mock_to_df_internal.assert_called_once_with(timeout=timeout)

Expand Down
Toggle all file notes Toggle all file annotations