Source code for pandagg.response

from __future__ import annotations

import copy
import dataclasses
from typing_extensions import Literal, TypedDict
from typing import (
    Iterator,
    Optional,
    List,
    TYPE_CHECKING,
    Dict,
    Tuple,
    Union,
    overload,
    Any,
)

from elasticsearch import Elasticsearch
from lighttree.node import NodeId

from pandagg.query import Query
from pandagg.aggs import Aggs, Composite
from pandagg.node.aggs.abstract import UniqueBucketAgg, MetricAgg, Root, AggClause
from pandagg.node.aggs.bucket import Nested, ReverseNested
from pandagg.types import (
    HitDict,
    HitsDict,
    DocSource,
    TotalDict,
    SearchResponseDict,
    ShardsDict,
    AggregationsResponseDict,
    AggName,
    ProfileDict,
    BucketDict,
    BucketKey,
    CompositeBucketKey,
    BucketKeyAtom,
)

if TYPE_CHECKING:
    import pandas as pd
    from pandagg.search import Search
    from pandagg import DocumentMeta
    from pandagg.document import DocumentSource


GroupingKeysDict = Dict[AggName, BucketKeyAtom]
GroupingKeysTuple = Tuple[BucketKeyAtom, ...]

RowValues = Dict[str, Any]
# dictionary containing both grouping keys and values
Row = Dict[str, Any]


class NormalizedBucketDict(TypedDict, total=False):
    level: AggName
    key: BucketKey
    value: Any
    # children are themselves NormalizedBucketDict, but mypy doesn't support recursive types
    children: List[Any]


@dataclasses.dataclass
class Hit:
    data: HitDict
    _document_class: Optional[DocumentMeta]

    @property
    def _source(self) -> Optional[Union[DocSource, DocumentSource]]:
        source = self.data.get("_source")
        if self._document_class is not None:
            return self._document_class._from_dict_(source)  # type: ignore
        return source

    @property
    def _score(self) -> Optional[float]:
        return self.data.get("_score")

    @property
    def _id(self) -> Optional[str]:
        return self.data.get("_id")

    @property
    def _index(self) -> Optional[str]:
        return self.data.get("_index")

    def __repr__(self) -> str:
        if self._score is None:
            return "<Hit %s>" % self._id
        return "<Hit %s> score=%.2f" % (self._id, self._score)
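

# Illustrative usage sketch (not part of the original module): a Hit wraps one raw
# Elasticsearch hit dict. The index, id and source values below are hypothetical.
#
#   >>> hit = Hit(
#   ...     data={"_index": "movies", "_id": "42", "_score": 1.5, "_source": {"title": "Vertigo"}},
#   ...     _document_class=None,
#   ... )
#   >>> hit._source
#   {'title': 'Vertigo'}
#   >>> hit
#   <Hit 42> score=1.50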


@dataclasses.dataclass
class Hits:
    data: Optional[HitsDict]
    _document_class: Optional[DocumentMeta]

    @property
    def total(self) -> Optional[TotalDict]:
        return self.data.get("total") if self.data else None

    @property
    def hits(self) -> List[Hit]:
        return (
            [
                Hit(hit, _document_class=self._document_class)
                for hit in self.data.get("hits", [])
            ]
            if self.data
            else []
        )

    @property
    def max_score(self) -> Optional[float]:
        return self.data.get("max_score") if self.data else None

    def __len__(self) -> int:
        return len(self.hits)

    def __iter__(self) -> Iterator[Hit]:
        return iter(self.hits)

    def _total_repr(self) -> str:
        if self.total is None:
            return 'Unknown total (probably filtered by "filter_path")'
        if self.total.get("relation") == "eq":
            return str(self.total["value"])
        if self.total.get("relation") == "gte":
            return ">=%d" % self.total["value"]
        raise ValueError("Invalid total %s" % self.total)

    def to_dataframe(
        self, expand_source: bool = True, source_only: bool = True
    ) -> pd.DataFrame:
        """
        Return hits as a pandas dataframe. Requires the pandas dependency.

        :param expand_source: if True, `_source` sub-fields are expanded as columns
        :param source_only: if True, doesn't include hit metadata (except id, which is used as dataframe index)
        """
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                'Using dataframe output format requires pandas. Please install "pandas" or '
                "use another output format."
            )
        hits: List[HitDict] = self.data.get("hits", []) if self.data else []
        if not hits:
            return pd.DataFrame()
        if not expand_source:
            return pd.DataFrame(hits).set_index("_id")
        flattened_hits: List[DocSource] = []
        hit: HitDict
        for hit in hits:
            hit_metadata: HitDict = hit.copy()
            hit_source: DocSource = hit_metadata.pop("_source")
            if source_only:
                hit_source["_id"] = hit_metadata["_id"]
            else:
                hit_source.update(hit_metadata)
            flattened_hits.append(hit_source)
        return pd.DataFrame(flattened_hits).set_index("_id")
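
    # Illustrative sketch: with expand_source=True (default) and source_only=True,
    # each `_source` sub-field becomes a dataframe column, indexed by `_id`
    # (hypothetical values):
    #
    #   >>> hits.to_dataframe()
    #          title
    #   _id
    #   42   Vertigo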

    def __repr__(self) -> str:
        if not isinstance(self.total, dict):
            total_repr = str(self.total)
        elif self.total.get("relation") == "eq":
            total_repr = str(self.total["value"])
        elif self.total.get("relation") == "gte":
            total_repr = ">=%d" % self.total["value"]
        else:
            raise ValueError("Invalid total %s" % self.total)
        return "<Hits> total: %s, contains %d hits" % (total_repr, len(self.hits))
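

# Illustrative usage sketch (hypothetical values): Hits wraps the "hits" section of
# a raw Elasticsearch response and behaves as a sized iterable of Hit instances.
#
#   >>> hits = Hits(
#   ...     data={
#   ...         "total": {"value": 1, "relation": "eq"},
#   ...         "max_score": 1.5,
#   ...         "hits": [{"_id": "42", "_score": 1.5, "_source": {"title": "Vertigo"}}],
#   ...     },
#   ...     _document_class=None,
#   ... )
#   >>> len(hits)
#   1
#   >>> hits
#   <Hits> total: 1, contains 1 hits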


@dataclasses.dataclass
class SearchResponse:
    data: SearchResponseDict
    _search: Search

    @property
    def took(self) -> Optional[int]:
        return self.data.get("took")

    @property
    def timed_out(self) -> Optional[bool]:
        return self.data.get("timed_out")

    @property
    def _shards(self) -> Optional[ShardsDict]:
        return self.data.get("_shards")

    @property
    def hits(self) -> Hits:
        return Hits(
            data=self.data.get("hits"), _document_class=self._search._document_class
        )

    @property
    def aggregations(self) -> Aggregations:
        return Aggregations(self.data.get("aggregations", {}), _search=self._search)

    @property
    def profile(self) -> Optional[ProfileDict]:
        return self.data.get("profile")

    def __iter__(self) -> Iterator[Hit]:
        return iter(self.hits)

    @property
    def success(self) -> bool:
        if (
            self._shards is None
            or self._shards.get("total") is None
            or self._shards.get("successful") is None
        ):
            # if total result is filtered by 'filter_path', ignore
            return False
        return (
            self._shards["total"] == self._shards["successful"] and not self.timed_out
        )

    def __len__(self) -> int:
        return len(self.hits)

    def __repr__(self) -> str:
        return (
            "<Response> took %sms, success: %s, total result %s, contains %s hits"
            % (self.took, self.success, self.hits._total_repr(), len(self.hits))
        )
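

# Illustrative usage sketch: a SearchResponse is normally obtained by executing a
# pandagg Search. The client, index name and query below are hypothetical.
#
#   >>> from elasticsearch import Elasticsearch
#   >>> from pandagg.search import Search
#   >>> client = Elasticsearch(hosts=["localhost:9200"])
#   >>> response = Search(using=client, index="movies").query("match", title="vertigo").execute()
#   >>> response
#   <Response> took 10ms, success: True, total result 1, contains 1 hits
#   >>> [hit._id for hit in response]
#   ['42']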


@dataclasses.dataclass
class Aggregations:
    data: AggregationsResponseDict
    _search: Search

    @property
    def _aggs(self) -> Aggs:
        return self._search._aggs

    @property
    def _query(self) -> Query:
        return self._search._query

    @property
    def _client(self) -> Optional[Elasticsearch]:
        return self._search._using

    @property
    def _index(self) -> Optional[List[str]]:
        return self._search._index

    def keys(self) -> List[AggName]:
        return list(self.data.keys())

    @overload
    def parse_group_by(
        self,
        *,
        response: AggregationsResponseDict,
        until: Optional[AggName],
        with_single_bucket_groups: bool = False,
        row_as_tuple: Literal[True],
    ) -> Tuple[List[AggName], List[Tuple[GroupingKeysTuple, BucketDict]]]:
        ...

    @overload
    def parse_group_by(
        self,
        *,
        response: AggregationsResponseDict,
        until: Optional[AggName],
        with_single_bucket_groups: bool = False,
        row_as_tuple: Literal[False],
    ) -> Tuple[List[AggName], List[Tuple[GroupingKeysDict, BucketDict]]]:
        ...

    def parse_group_by(
        self,
        *,
        response: AggregationsResponseDict,
        until: Optional[AggName],
        with_single_bucket_groups: bool = False,
        row_as_tuple: bool = False,
    ) -> Tuple[
        List[AggName],
        Union[
            List[Tuple[GroupingKeysTuple, BucketDict]],
            List[Tuple[GroupingKeysDict, BucketDict]],
        ],
    ]:
        if not until:
            index_names_: List[AggName] = []
            if row_as_tuple:
                r_: List[Tuple[GroupingKeysTuple, BucketDict]] = [(tuple(), response)]
                return index_names_, r_
            r__: GroupingKeysDict = {}
            return index_names_, [(r__, response)]

        # initialization: cache ancestors once for faster computation
        until_id: NodeId = self._aggs.id_from_key(until)
        # remove root (not an aggregation clause), ignore type warning about key None (since only root
        # can have a None key)
        ancestors: List[Tuple[AggName, AggClause]] = [
            (k, n)  # type: ignore
            for k, n in self._aggs.ancestors(
                until_id, include_current=True, from_root=True
            )[1:]
        ]
        if not ancestors:
            index_names__: List[AggName] = []
            if row_as_tuple:
                r___: List[Tuple[GroupingKeysTuple, BucketDict]] = [
                    (tuple(), response)
                ]
                return index_names__, r___
            r____: GroupingKeysDict = {}
            return index_names__, [(r____, response)]

        # from root aggregation to deepest aggregation clause
        index_names: List[AggName] = []
        for name, a in ancestors:
            if isinstance(a, UniqueBucketAgg) and not with_single_bucket_groups:
                continue
            if isinstance(a, Composite):
                # a composite aggregation can generate multiple grouping columns
                index_names.extend(a.source_names)
            else:
                index_names.append(name)

        first_agg_name: AggName
        first_agg_name, _ = ancestors[0]
        index_values: List[Tuple[GroupingKeysDict, BucketDict]] = list(
            self._parse_group_by(
                response=response,
                until=until,
                agg_clauses_per_name={k: a for k, a in ancestors},
                agg_name=first_agg_name,
                row={},
                with_single_bucket_groups=with_single_bucket_groups,
            )
        )
        if not row_as_tuple:
            return index_names, index_values
        values_: List[Tuple[GroupingKeysTuple, BucketDict]] = [
            (tuple(grouping_row[index_name] for index_name in index_names), raw_bucket)
            for grouping_row, raw_bucket in index_values
        ]
        return index_names, values_
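
    # Illustrative sketch (hypothetical aggregation names and buckets): with
    # row_as_tuple=True, each returned row pairs a tuple of grouping keys with
    # the corresponding raw bucket.
    #
    #   >>> names, rows = response.aggregations.parse_group_by(
    #   ...     response=response.aggregations.data, until="genres", row_as_tuple=True
    #   ... )
    #   >>> names
    #   ['genres']
    #   >>> rows[0]
    #   (('drama',), {'doc_count': 12, ...})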

    def _parse_group_by(
        self,
        response: AggregationsResponseDict,
        until: AggName,
        row: GroupingKeysDict,
        agg_name: AggName,
        agg_clauses_per_name: Dict[AggName, AggClause],
        with_single_bucket_groups: bool,
    ) -> Iterator[Tuple[GroupingKeysDict, BucketDict]]:
        """
        Recursive parsing of a succession of grouping aggregation clauses.

        Yields each row for which the last bucket aggregation generated buckets.
        """
        if agg_name not in response:
            return None
        agg_node = agg_clauses_per_name[agg_name]
        key: BucketKey
        raw_bucket: BucketDict
        for key, raw_bucket in agg_node.extract_buckets(response[agg_name]):
            sub_row: GroupingKeysDict = copy.copy(row)
            if not isinstance(agg_node, UniqueBucketAgg) or with_single_bucket_groups:
                if isinstance(agg_node, Composite):
                    key_: CompositeBucketKey = key  # type: ignore
                    for source_name in agg_node.source_names:
                        sub_row[source_name] = key_[source_name]
                else:
                    key__: BucketKeyAtom = key  # type: ignore
                    sub_row[agg_name] = key__
            if agg_name == until:
                # end: real yield
                yield sub_row, raw_bucket
            elif agg_name in agg_clauses_per_name.keys():
                # yield children
                child_name: AggName
                for child_name, _ in self._aggs.children(  # type: ignore
                    agg_node.identifier
                ):
                    for nrow, nraw_bucket in self._parse_group_by(
                        row=sub_row,
                        response=raw_bucket,
                        agg_name=child_name,
                        until=until,
                        agg_clauses_per_name=agg_clauses_per_name,
                        with_single_bucket_groups=with_single_bucket_groups,
                    ):
                        yield nrow, nraw_bucket

    def _normalize_buckets(
        self, agg_response: AggregationsResponseDict, agg_name: AggName
    ) -> Iterator[NormalizedBucketDict]:
        """
        Recursive function to parse an aggregation response as normalized entities.

        Each response bucket is represented as a dict with keys (key, level, value, children)::

            {
                "level": "owner.id",
                "key": 35,
                "value": 235,
                "children": [
                ]
            }
        """
        id_: NodeId = self._aggs.id_from_key(agg_name)
        _, agg_node = self._aggs.get(id_)
        agg_children = self._aggs.children(id_)
        key: BucketKey
        raw_bucket: BucketDict
        for key, raw_bucket in agg_node.extract_buckets(agg_response[agg_name]):
            result: NormalizedBucketDict = {
                "level": agg_name,
                "key": key,
                "value": agg_node.extract_bucket_value(raw_bucket),
            }
            child_key: AggName
            normalized_children: List[NormalizedBucketDict] = [
                normalized_child
                for child_key, child in agg_children
                for normalized_child in self._normalize_buckets(
                    # ignore warning about child_key not necessarily being an AggName (str): it is
                    agg_name=child_key,  # type: ignore
                    agg_response=raw_bucket,
                )
            ]
            if normalized_children:
                result["children"] = normalized_children
            yield result

    def _grouping_agg(
        self, name: Optional[AggName] = None
    ) -> Tuple[Optional[AggName], AggClause]:
        """
        Return the aggregation node that is used as grouping node.

        Note: if there is only a nested aggregation below that node, group by the nested clause.
        """
        key: str
        if name is not None:
            # override existing groupby_ptr
            id_ = self._aggs.id_from_key(name)
            if not self._aggs._is_eligible_grouping_node(id_):
                raise ValueError(
                    "Cannot group by <%s>, not a valid grouping aggregation" % name
                )
            key, node = self._aggs.get(id_)  # type: ignore
        else:
            key, node = self._aggs.get(self._aggs._groupby_ptr)  # type: ignore

        # if parent of single nested clause and nested_autocorrect
        if self._aggs.nested_autocorrect:
            children = self._aggs.children(node.identifier)
            if len(children) == 1:
                child_key: str
                child_key, child_node = children[0]  # type: ignore
                if isinstance(child_node, (Nested, ReverseNested)):
                    return child_key, child_node
        return key, node

    @overload
    def to_tabular(
        self,
        *,
        index_orient: Literal[True] = True,
        grouped_by: Optional[AggName] = None,
        expand_columns: bool = True,
        expand_sep: str = "|",
        normalize: bool = True,
        with_single_bucket_groups: bool = False,
    ) -> Tuple[List[AggName], Dict[GroupingKeysTuple, RowValues]]:
        ...

    @overload
    def to_tabular(
        self,
        *,
        index_orient: Literal[False],
        grouped_by: Optional[AggName] = None,
        expand_columns: bool = True,
        expand_sep: str = "|",
        normalize: bool = True,
        with_single_bucket_groups: bool = False,
    ) -> Tuple[List[AggName], List[Row]]:
        ...

    def to_tabular(
        self,
        *,
        index_orient: bool = True,
        grouped_by: Optional[AggName] = None,
        expand_columns: bool = True,
        expand_sep: str = "|",
        normalize: bool = True,
        with_single_bucket_groups: bool = False,
    ) -> Tuple[List[AggName], Union[Dict[GroupingKeysTuple, RowValues], List[Row]]]:
        """
        Build a tabular view of the ES response: grouping levels are used as rows, until the
        'grouped_by' aggregation node (included) is reached, and the children aggregations of
        that grouping level are used as values (columns) for each generated group.

        Suppose an aggregation of this shape (A & B bucket aggregations)::

            A──> B──> C1
                 ├──> C2
                 └──> C3

        With grouped_by='B', the Elasticsearch response (a tree structure) is broken down into a
        tabular structure of this shape::

                                  C1  C2  C3
            A           B
            wood        blue      10  4   0
                        red       7   5   2
            steel       blue      1   9   0
                        red       23  4   2

        :param index_orient: if True, level-key samples are returned as tuples, else in a dictionary
        :param grouped_by: name of the aggregation node used as last grouping level
        :param normalize: if True, normalize columns buckets
        :return: index_names, values
        """
        grouping_agg_name, grouping_agg = self._grouping_agg(grouped_by)
        index_names: List[AggName]
        if index_orient:
            index_values: List[Tuple[GroupingKeysTuple, BucketDict]]
            index_names, index_values = self.parse_group_by(
                response=self.data,
                until=grouping_agg_name,
                with_single_bucket_groups=with_single_bucket_groups,
                row_as_tuple=True,
            )
            rows: Dict[GroupingKeysTuple, Row] = {
                row_index: self._serialize_columns(
                    row_raw_data=row_raw_data,
                    normalize=normalize,
                    total_agg=grouping_agg,
                    expand_columns=expand_columns,
                    expand_sep=expand_sep,
                )
                for row_index, row_raw_data in index_values
            }
            return index_names, rows
        index_values_: List[Tuple[GroupingKeysDict, BucketDict]]
        index_names, index_values_ = self.parse_group_by(
            response=self.data,
            until=grouping_agg_name,
            with_single_bucket_groups=with_single_bucket_groups,
            row_as_tuple=False,
        )
        rows_ = [
            dict(
                row_index,
                **self._serialize_columns(
                    row_raw_data=row_raw_data,
                    normalize=normalize,
                    total_agg=grouping_agg,
                    expand_columns=expand_columns,
                    expand_sep=expand_sep,
                ),
            )
            for row_index, row_raw_data in index_values_
        ]
        return index_names, rows_
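
    # Illustrative sketch: with the A/B/C aggregation shown in the docstring above
    # and index_orient=True (default), rows are keyed by (A, B) tuples, and the
    # grouping bucket's own doc_count is included as a column (hypothetical values):
    #
    #   >>> index_names, rows = response.aggregations.to_tabular(grouped_by="B")
    #   >>> index_names
    #   ['A', 'B']
    #   >>> rows[("wood", "blue")]
    #   {'doc_count': 14, 'C1': 10, 'C2': 4, 'C3': 0}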

    def _serialize_columns(
        self,
        row_raw_data: BucketDict,
        normalize: bool,
        expand_columns: bool,
        expand_sep: str,
        total_agg: AggClause,
    ) -> RowValues:
        # extract value (usually 'doc_count') of grouping agg node
        result: RowValues = {}
        if not isinstance(total_agg, Root):
            result[total_agg.VALUE_ATTRS[0]] = total_agg.extract_bucket_value(
                row_raw_data
            )
        # extract values of children, one column per child
        child_key: AggName
        child: AggClause
        for child_key, child in self._aggs.children(  # type: ignore
            total_agg.identifier
        ):
            if isinstance(child, (UniqueBucketAgg, MetricAgg)):
                result[child_key] = child.extract_bucket_value(row_raw_data[child_key])
            elif expand_columns:
                for key, bucket in child.extract_buckets(row_raw_data[child_key]):
                    result[
                        "%s%s%s" % (child_key, expand_sep, key)
                    ] = child.extract_bucket_value(bucket)
            elif normalize:
                result[child_key] = next(
                    self._normalize_buckets(row_raw_data, child_key), None
                )
            else:
                result[child_key] = row_raw_data[child_key]
        return result
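
    # Note: with expand_columns=True, each bucket of a multi-bucket child
    # aggregation becomes its own column named "<child_key><expand_sep><bucket_key>",
    # e.g. a hypothetical "genres" terms aggregation yields columns such as
    # "genres|drama" and "genres|comedy".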

    def to_dataframe(
        self,
        grouped_by: Optional[str] = None,
        normalize_children: bool = True,
        with_single_bucket_groups: bool = False,
    ) -> pd.DataFrame:
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                'Using dataframe output format requires pandas. Please install "pandas" or '
                "use another output format."
            )
        index_names: List[AggName]
        rows: Dict[GroupingKeysTuple, RowValues]
        index_names, rows = self.to_tabular(
            index_orient=True,
            grouped_by=grouped_by,
            normalize=normalize_children,
            with_single_bucket_groups=with_single_bucket_groups,
        )
        if not rows:
            return pd.DataFrame()
        index: Tuple[GroupingKeysTuple, ...]
        values: Tuple[RowValues, ...]
        index, values = zip(*rows.items())
        # empty index
        if len(index[0]) == 0:
            return pd.DataFrame(index=(None,) * len(values), data=list(values))
        # single or multi-index
        return pd.DataFrame(
            index=pd.MultiIndex.from_tuples(index, names=index_names), data=list(values)
        ).sort_index()
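
    # Illustrative sketch: with the A/B/C aggregation of the to_tabular docstring,
    # the dataframe is indexed by the grouping levels and sorted by index
    # (hypothetical values, consistent with the table above):
    #
    #   >>> response.aggregations.to_dataframe(grouped_by="B")
    #               doc_count  C1  C2  C3
    #   A     B
    #   steel blue         10   1   9   0
    #         red          29  23   4   2
    #   wood  blue         14  10   4   0
    #         red          14   7   5   2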

    def to_normalized(self) -> NormalizedBucketDict:
        children: List[NormalizedBucketDict] = []
        for k in self.data.keys():
            for child in self._normalize_buckets(self.data, k):
                children.append(child)
        return {"level": "root", "key": None, "value": None, "children": children}
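
    # Illustrative sketch of the normalized output shape (hypothetical values;
    # leaf buckets carry no "children" key):
    #
    #   {
    #       "level": "root",
    #       "key": None,
    #       "value": None,
    #       "children": [
    #           {"level": "genres", "key": "drama", "value": 12},
    #           {"level": "genres", "key": "comedy", "value": 8},
    #       ],
    #   }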

    def __repr__(self) -> str:
        if not self.keys():
            return "<Aggregations> empty"
        return "<Aggregations> %s" % list(map(str, self.keys()))