from collections.abc import Iterable
from ppc_robot_lib.steps.abstract_step import AbstractStep
from ppc_robot_lib.tasks import TaskContextInterface, StepPerformance
from ppc_robot_lib.utils.types import JoinType
import pandas
import numpy
# Specification of join keys: either a single column name or a list of column names.
JoinSetSpec = str | list[str]
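# [docs]  (appears to be a leftover Sphinx "view source" link marker from HTML extraction; not part of the code)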
class JoinOnColumnStep(AbstractStep):
    """
    Performs an SQL-like join on two tables.

    This step supports 4 types of joins:

    * :py:attr:`ppc_robot_lib.utils.types.JoinType.INNER`
    * :py:attr:`ppc_robot_lib.utils.types.JoinType.OUTER`
    * :py:attr:`ppc_robot_lib.utils.types.JoinType.LEFT`
    * :py:attr:`ppc_robot_lib.utils.types.JoinType.RIGHT`

    You can join either by columns with the same name in both tables (``on``), or you can provide a list of
    columns for both the left and the right table (``left_on`` / ``right_on``). The lists have to be of equal
    length. Instead of columns, the index of either table can be used as the join key (``left_index`` /
    ``right_index``).

    When performing the join, values of columns on corresponding indexes must be equal in order to match the rows.

    For details on joins in Pandas, see :ref:`pandas:merging.join`.

    **Example:**

        >>> from ppc_robot_lib.steps.transformations import JoinOnColumnStep, JoinType
        >>> JoinOnColumnStep("keywords", "adgroups",
        ...                  left_on=["CampaignName", "AdGroupName"],
        ...                  right_on=['CampaignName', "Name"],
        ...                  join_type=JoinType.INNER,
        ...                  output_table="keywords_with_adgroup_data")
    """

    def __init__(
        self,
        left_table: str,
        right_table: str,
        on: JoinSetSpec | None = None,
        left_on: JoinSetSpec | None = None,
        right_on: JoinSetSpec | None = None,
        join_type: JoinType = JoinType.INNER,
        output_table: str | None = None,
        do_sort: bool = True,
        ambiguous_column_suffixes=('_left', '_right'),
        left_index: bool = False,
        right_index: bool = False,
    ):
        """
        :param left_table: Left table.
        :param right_table: Right table.
        :param on: Column names that must match in both tables. Exclusive with ``left_on``/``left_index`` and
            ``right_on``/``right_index``.
        :param left_on: List of columns to match in the left table. Must be the same length as ``right_on``.
            Exclusive with ``on``.
        :param right_on: List of columns to match in the right table. Must be the same length as ``left_on``.
            Exclusive with ``on``.
        :param join_type: Type of the join. See :ref:`pandas:merging.join` for explanation of individual types.
        :param output_table: Output table. If none is given, left table is used.
        :param do_sort: Set to ``True`` if you would like to sort the tables by join columns
            (passed to :py:func:`pandas.merge` as ``sort``).
        :param ambiguous_column_suffixes: Pair of column suffixes that is used when tables with conflicting column
            names are used. The first element is used for the left table, the second one for the right table.
        :param left_index: Set to ``True`` to join on the index of the left table
            (passed to :py:func:`pandas.merge` as ``left_index``). Exclusive with ``on``.
        :param right_index: Set to ``True`` to join on the index of the right table
            (passed to :py:func:`pandas.merge` as ``right_index``). Exclusive with ``on``.
        :raises ValueError: If neither ``on`` nor a key for both sides is given, or if ``on`` is combined with
            any of the side-specific key arguments.
        """
        # A side has no key when neither its column list nor its index flag was supplied.
        left_key_missing = left_on is None and left_index is False
        right_key_missing = right_on is None and right_index is False

        if on is None:
            if left_key_missing or right_key_missing:
                raise ValueError('Either on or both left_on/left_index and right_on/right_index must be specified.')
        elif left_on is not None or left_index is True or right_on is not None or right_index is True:
            raise ValueError('When on is specified, left_on/left_index and right_on/right_index must not be given.')

        self.left_table = left_table
        self.right_table = right_table
        self.on = on
        self.left_on = left_on
        self.left_index = left_index
        self.right_on = right_on
        self.right_index = right_index
        self.join_type = join_type
        self.do_sort = do_sort
        self.ambiguous_column_suffixes = ambiguous_column_suffixes
        # The join result overwrites the left table unless a dedicated output table is requested.
        self.output_table = output_table if output_table is not None else left_table

    def execute(self, task_ctx: TaskContextInterface) -> StepPerformance:
        """
        Merge the two configured tables and store the result in the output table.

        :param task_ctx: Task context whose work set provides the input tables and receives the result.
        :return: Performance record with the row count of the left table as ``rows_in`` and the row count of
            the joined table as ``rows_out``.
        """
        left_table = task_ctx.work_set.get_table(self.left_table)
        right_table = task_ctx.work_set.get_table(self.right_table)
        rows_in = len(left_table.index)

        # Fix object dtypes on empty tables: the join keys of an empty table are aligned with the dtypes
        # of the other table's keys before merging.
        if rows_in == 0:
            self._fix_object_dtypes(
                left_table, self.on, self.left_on, self.left_index, right_table, self.right_on, self.right_index
            )
        if len(right_table.index) == 0:
            self._fix_object_dtypes(
                right_table, self.on, self.right_on, self.right_index, left_table, self.left_on, self.left_index
            )

        new_table = pandas.merge(
            left_table,
            right_table,
            how=self.join_type.value,
            on=self.on,
            left_on=self.left_on,
            right_on=self.right_on,
            left_index=self.left_index,
            right_index=self.right_index,
            sort=self.do_sort,
            suffixes=self.ambiguous_column_suffixes,
        )

        # Remove an existing table of the same name before storing the result.
        if self.output_table in task_ctx.work_set:
            task_ctx.work_set.delete_table(self.output_table)
        task_ctx.work_set.set_table(self.output_table, new_table)

        return StepPerformance(new_table, rows_in=rows_in, rows_out=len(new_table.index))

    def _fix_object_dtypes(
        self,
        table: pandas.DataFrame,
        on: JoinSetSpec | None,
        table_on: JoinSetSpec | None,
        table_index: bool,
        foreign_table: pandas.DataFrame,
        foreign_on: JoinSetSpec | None,
        foreign_index: bool,
    ) -> None:
        """
        Cast ``object``-typed join keys of ``table`` to the dtypes of the matching keys in ``foreign_table``.

        The table is modified in place: either the key columns are reassigned, or the index is replaced.

        :param table: Table whose join keys should be fixed.
        :param on: Join columns shared by both tables; when set, it takes precedence over the other key arguments.
        :param table_on: Join columns of ``table``.
        :param table_index: Whether the index of ``table`` is used as the join key.
        :param foreign_table: The other table of the join, used as the source of the target dtypes.
        :param foreign_on: Join columns of ``foreign_table``.
        :param foreign_index: Whether the index of ``foreign_table`` is used as the join key.
        :raises ValueError: If no join key is given for one of the tables, or if the number of join keys differs
            between the tables.
        """
        if on:
            source_cols, source_cols_names = self._get_columns(table, on)
            foreign_cols, _ = self._get_columns(foreign_table, on)
        else:
            if table_on:
                source_cols, source_cols_names = self._get_columns(table, table_on)
            elif table_index:
                source_cols = self._get_index(table)
                # No column names means the keys come from the index and must be written back as an index.
                source_cols_names = None
            else:
                raise ValueError('No join columns/indexes given!')

            if foreign_on:
                foreign_cols, _ = self._get_columns(foreign_table, foreign_on)
            elif foreign_index:
                foreign_cols = self._get_index(foreign_table)
            else:
                raise ValueError('No join columns/indexes given!')

        if len(source_cols) != len(foreign_cols):
            raise ValueError(
                'Cannot fix columns on invalid join conditions - main table contains '
                f'{len(source_cols)} columns, foreign table contains {len(foreign_cols)} columns.'
            )

        # Only object-typed keys are converted, and only when the other side has a concrete dtype.
        fixed_cols = []
        modified = False
        for source_col, foreign_col in zip(source_cols, foreign_cols):
            if source_col.dtype == numpy.object_ and foreign_col.dtype != numpy.object_:
                source_col = source_col.astype(foreign_col.dtype)
                modified = True
            fixed_cols.append(source_col)

        if not modified:
            return

        if source_cols_names is not None:
            for col_name, fixed_col in zip(source_cols_names, fixed_cols):
                table[col_name] = fixed_col
        elif len(fixed_cols) == 1:
            new_index = fixed_cols[0]
            if not isinstance(new_index, pandas.Index):
                new_index = pandas.Index(new_index)
            table.index = new_index
        else:
            table.index = pandas.MultiIndex.from_arrays(fixed_cols)

    def _get_columns(self, table: pandas.DataFrame, columns: str | list[str]) -> tuple[list, list[str]]:
        """
        Look up the given column(s) in ``table``.

        :param table: Table to read the columns from.
        :param columns: Single column name or an iterable of column names.
        :return: Pair ``(series, names)`` with one entry per requested column, in the requested order.
        """
        # A string is iterable too, but it denotes a single column name.
        if isinstance(columns, str) or not isinstance(columns, Iterable):
            return [table[columns]], [columns]

        # Materialise once so that one-shot iterables yield both the series and the names.
        names = list(columns)
        return [table[name] for name in names], names

    def _get_index(self, table: pandas.DataFrame):
        """
        Return the values of every level of the table's index.

        :param table: Table whose index should be read.
        :return: List with one entry per index level, in level order.
        """
        index = table.index
        return [index.get_level_values(level) for level in range(index.nlevels)]