Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,36 @@ Built-in standard errors include:

Example Usage: `... | Jackknife('CookieBucket', confidence=.95)`

#### Transformations

Transformations are functions that can be applied to metrics to perform
element-wise operations on their results.

Currently supported transformations include:

+ `ExponentialTransform(metric)`: Applies an exponential transformation to
the metric result.
+ `LogTransform(metric, base='ln')`: Applies a logarithmic transformation to
the metric result. The `base` can be `ln` or `log10`.
+ `ExponentialPercentTransform(metric, base='ln')`: Computes
  `100 * (e^metric - 1)` when `base='ln'`, or
  `100 * (10^metric - 1)` when `base='log10'`.
It's useful for converting log-transformed
metrics back to a percent scale. For example, when
`PercentChange(.., Sum(x))` is skewed, applying the transformation sequence:

```
(Sum(x)
| LogTransform()
| AbsoluteChange(...)
| Jackknife(...)
| ExponentialPercentTransform())
```

computes the same percent change, while the confidence interval is
calculated in the log-transformed space, which often results in less
skewness.

#### Distributions

A **distribution** operation produces the distribution of the metric over
Expand Down
1 change: 1 addition & 0 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def to_sql(table, split_by=None):
'Bootstrap',
'PoissonBootstrap',
'LogTransform',
'ExponentialTransform',
'ExponentialPercentTransform',
'LogTransformedPercentChangeWithCI',
# Diversity Operations
Expand Down
96 changes: 78 additions & 18 deletions operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4364,19 +4364,45 @@ class MetricFunction(Operation):

Attributes:
func: The function to apply to the result of child Metric.
sql_func: The function to apply to the result of child Metric in SQL.
children: A tuple containing the child Metric.
name_tmpl: The template to generate the name from child Metric's name.
"""

def __init__(self, child, func, name_tmpl, **kwargs):
def __init__(self, child, func, sql_func, name_tmpl, **kwargs):
  """Initializes a MetricFunction.

  Args:
    child: The child Metric whose result the function is applied to.
    func: Callable applied to the child Metric's result in Python.
    sql_func: Template applied to the child's columns in SQL generation;
      falsy when SQL is not supported for this transformation.
    name_tmpl: Template used to derive this Metric's name from the child's.
    **kwargs: Forwarded to Operation.__init__.
  """
  super().__init__(child, name_tmpl, **kwargs)
  self.func = func
  self.sql_func = sql_func

def compute_on_children(self, children, split_by):
  """Applies self.func to the child result, keeping meterstick metadata."""
  transformed = self.func(children)
  return copy_meterstick_metadata(children, transformed)

def get_sql_and_with_clause(
    self, table, split_by, global_filter, indexes, local_filter, with_data
):
  """Gets the SQL for the transformed Metric and the updated WITH clause.

  Generates the child Metric's SQL, then wraps every non-index column with
  self.sql_func so the transformation is applied element-wise in SQL.

  Args:
    table: The table or subquery to query from.
    split_by: The columns the result is split by.
    global_filter: Filters that apply to the whole query.
    indexes: Index columns; these pass through untransformed.
    local_filter: Filters that only apply to this Metric.
    with_data: The accumulated WITH clause data.

  Returns:
    The SQL object for this Metric and the updated with_data.

  Raises:
    NotImplementedError: If no SQL template was provided for this Metric.
  """
  # SQL generation is only possible when a SQL template was supplied.
  if not self.sql_func:
    raise NotImplementedError(
        f'SQL generation not supported for {type(self)}.'
    )
  local_filter = (
      sql.Filters(self.where_).add(local_filter).remove(global_filter)
  )
  child_sql, with_data = self.children[0].get_sql_and_with_clause(
      table, split_by, global_filter, indexes, local_filter, with_data)
  columns = sql.Columns()
  for c in child_sql.all_columns:
    if c.alias in indexes.aliases:
      # Index columns (e.g. split_by) pass through unchanged.
      columns.add(c)
    else:
      # Value columns get wrapped with the transform and renamed via
      # name_tmpl so the output column reflects the transformation.
      col = sql.Column(c.expression, self.sql_func,
                       alias=self.name_tmpl.format(c.alias_raw))
      columns.add(col)
  # Deep-copy so the child's SQL object is not mutated in place.
  child_sql = copy.deepcopy(child_sql)
  child_sql.columns = columns
  return child_sql, with_data

def manipulate(
self, res, melted=False, return_dataframe=True, apply_name_tmpl=None
):
Expand All @@ -4396,24 +4422,43 @@ class LogTransform(MetricFunction):
Attributes:
base: The logarithm base, 'ln' or 'log10'.
func: The log function (np.log or np.log10).
sql_func: The log function in SQL (sql.LN_FN or sql.LOG10_FN).
children: A tuple containing the child Metric.
name_tmpl: The template to generate the name from child Metric's name.
"""

def __init__(self, child=None, base: str = 'ln', **kwargs):
def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs):
  """Initializes a LogTransform.

  Args:
    child: The child Metric to log-transform.
    base: The logarithm base, 'ln' or 'log10'.
    name_tmpl: Optional name template; defaults to 'Ln({})' or 'Log10({})'
      depending on the base.
    **kwargs: Other keyword arguments passed to MetricFunction.__init__.
  """
  if base not in ('ln', 'log10'):
    raise ValueError("base must be 'ln' or 'log10'")
  self.base = base
  use_natural_log = base == 'ln'
  if name_tmpl is None:
    name_tmpl = 'Ln({})' if use_natural_log else 'Log10({})'
  super().__init__(
      child,
      np.log if use_natural_log else np.log10,
      sql.LN_FN if use_natural_log else sql.LOG10_FN,
      name_tmpl,
      additional_fingerprint_attrs=['base'],
      **kwargs
  )


class ExponentialTransform(MetricFunction):
  """Applies the element-wise exponential (e^x) to the child Metric's result."""

  def __init__(self, child=None, name_tmpl='Exp({})', **kwargs):
    # 'EXP({})' is the SQL counterpart of np.exp.
    super().__init__(child, np.exp, 'EXP({})', name_tmpl, **kwargs)


class ExponentialPercentTransform(MetricFunction):
"""Applies exponential and percent transformations to Metric results.

Expand All @@ -4440,30 +4485,36 @@ class ExponentialPercentTransform(MetricFunction):
base: The logarithm base, 'ln' or 'log10', used in inverse transformation.
func: The inverse function: 100*(exp(x)-1) for 'ln', 100*(10^x-1) for
'log10'.
sql_func: The inverse function in SQL: 100*(EXP(x)-1) for base='ln',
100*(10^x-1) for base='log10'.
children: A tuple containing the child Metric.
name_tmpl: The template to generate the name from child Metric's name.
"""

def __init__(self, child=None, base: str = 'ln', **kwargs):
def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs):
"""Initializes an ExponentialPercentTransform.

Args:
child: The child Metric to apply exp transform to.
base: The logarithm base, 'ln' or 'log10'. Default is 'ln'.
name_tmpl: The template to generate the name from child Metric's name.
**kwargs: other keyword arguments passed to MetricFunction.__init__.
"""
if base not in ('ln', 'log10'):
raise ValueError("base must be 'ln' or 'log10'")
self.base = base
if base == 'ln':
func = lambda x: 100 * (np.exp(x) - 1)
sql_func = '100 * (EXP({}) - 1)'
name_tmpl = '100 * Exp({}) - 1'
else:
func = lambda x: 100 * (10**x - 1)
sql_func = '100 * (POWER(10, {}) - 1)'
name_tmpl = '100 * 10^({}) - 1'
super().__init__(
child,
func,
sql_func,
name_tmpl,
additional_fingerprint_attrs=['base'],
**kwargs
Expand Down Expand Up @@ -4513,10 +4564,11 @@ def _check_and_update_for_log_transformed_abs_change(self):
ab = ci_method.children[0]
log_transform = ab.children[0]

log_transform.name_tmpl = '{}'
ab.children = tuple([log_transform])
self.name_tmpl = '{}'
self.children = tuple([ci_method])
log_transform = LogTransform(
log_transform.children[0], log_transform.base, name_tmpl='{}'
)
self.children = tuple([ci_method(ab(log_transform))])
return True

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -4609,17 +4661,25 @@ def _get_equiv(self, child):
| ExponentialPercentTransform()
)

def compute_on(
self,
df,
split_by=None,
melted=False,
return_dataframe=True,
cache_key=None,
cache=None,
):
def compute_slices(self, df, split_by=None):
"""Computes CI on log-scale and transform back to percent change."""
equiv = self._get_equiv(self.children[0])
return equiv.compute_on(
df, split_by, melted, return_dataframe, cache_key, cache
return self.compute_util_metric_on(equiv, df, split_by)

def compute_through_sql(self, table, split_by, execute, mode):
  """Computes the result in SQL via the equivalent composite metric."""
  equivalent = self._get_equiv(self.children[0])
  return self.compute_util_metric_on_sql(
      equivalent, table, split_by, execute, mode)

def manipulate(
    self, res, melted=False, return_dataframe=True, apply_name_tmpl=None
):
  """Manipulates the result while carrying over meterstick metadata."""
  manipulated = super().manipulate(res, melted, return_dataframe,
                                   apply_name_tmpl)
  return copy_meterstick_metadata(res, manipulated)

def final_compute(self, res, melted, return_dataframe, split_by, df):
  """Runs the final computation, preserving meterstick metadata."""
  final_res = super().final_compute(res, melted, return_dataframe, split_by,
                                    df)
  return copy_meterstick_metadata(res, final_res)
59 changes: 52 additions & 7 deletions operations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def test_log_transform_ln(self):
metric = operations.LogTransform(metrics.Sum('x'), base='ln')
output = metric.compute_on(self.df)
expected = pd.DataFrame(
{'Log(sum(x))': [np.log(8)]}
{'Ln(sum(x))': [np.log(8)]}
)
testing.assert_frame_equal(output, expected)

Expand All @@ -294,6 +294,14 @@ def test_log_transform_log10(self):
)
testing.assert_frame_equal(output, expected)

def test_exponential_transform(self):
  m = operations.ExponentialTransform(metrics.Sum('x'))
  result = m.compute_on(self.df)
  # Expect e^8, matching sum(x) == 8 over the fixture.
  testing.assert_frame_equal(
      result, pd.DataFrame({'Exp(sum(x))': [np.exp(8)]})
  )

def test_exponential_percent_transform_ln(self):
metric = operations.ExponentialPercentTransform(metrics.Sum('x'), base='ln')
output = metric.compute_on(self.df)
Expand Down Expand Up @@ -1605,7 +1613,7 @@ def test_display_raises_for_duplicate_metric_names(self):
)

@parameterized.product(base=['ln', 'log10'], melted=[True, False])
def test_dispplay_log_transformed_percent_change(self, base, melted):
def test_display_log_transformed_percent_change(self, base, melted):
df = pd.DataFrame({
'x': list(range(8, 13)) + list(range(98, 103)),
'grp': list('A' * 5 + 'B' * 5),
Expand Down Expand Up @@ -1633,7 +1641,7 @@ def test_dispplay_log_transformed_percent_change(self, base, melted):
testing.assert_frame_equal(actual, expected)

@parameterized.product(base=['ln', 'log10'], melted=[True, False])
def test_dispplay_log_transformed_percent_change_split_by(self, base, melted):
def test_display_log_transformed_percent_change_split_by(self, base, melted):
df = pd.DataFrame({
'x': np.random.random(10).round(5),
'grp': list('A' * 5 + 'B' * 5),
Expand Down Expand Up @@ -1665,6 +1673,43 @@ def test_dispplay_log_transformed_percent_change_split_by(self, base, melted):

testing.assert_frame_equal(actual, expected)

def test_display_log_transformed_percent_change_metric_list(self):
  """The chained pipeline on a MetricList displays like the shorthand op."""
  df = pd.DataFrame({
      'x': list(range(8, 13)) + list(range(98, 103)),
      'grp': list('A' * 5 + 'B' * 5),
      'unit': list(range(5)) * 2,
  })
  metric = metrics.MetricList((metrics.Mean('x'), metrics.Sum('x')))
  # Manual pipeline: log -> absolute change -> jackknife CI -> back to %.
  m = (
      metric
      | operations.LogTransform()
      | operations.AbsoluteChange('grp', 'A')
      | operations.Jackknife('unit', confidence=0.9)
      | operations.ExponentialPercentTransform()
  )

  actual = m.compute_on(df)
  actual = actual.display(return_formatted_df=True)
  # The shorthand operation, applied to each metric individually.
  log_pct = operations.LogTransformedPercentChangeWithCI(
      'grp', 'A', 'unit', 0.9
  )
  expected1 = (
      log_pct(metrics.Mean('x'))
      .compute_on(df)
      .display(return_formatted_df=True)
  )
  expected2 = (
      log_pct(metrics.Sum('x'))
      .compute_on(df)
      .display(return_formatted_df=True)
  )
  # Join the two single-metric displays side by side on Dimensions.
  expected = (
      expected1.set_index('Dimensions')
      .join(expected2.set_index('Dimensions'))
      .reset_index()
  )
  testing.assert_frame_equal(actual, expected)

def test_display_log_transformed_percent_change_with_ci(self):
df = pd.DataFrame({
'x': np.random.random(10).round(3),
Expand Down Expand Up @@ -2809,10 +2854,10 @@ def test_different_metrics_have_different_fingerprints(self):
operations.Bootstrap('x', n_replicates=10),
operations.Bootstrap('x', confidence=0.9),
operations.Bootstrap('x', confidence=0.95),
operations.LogTransform('x'),
operations.LogTransform('x', base='log10'),
operations.ExponentialPercentTransform('x'),
operations.ExponentialPercentTransform('x', base='log10'),
operations.LogTransform(),
operations.LogTransform(base='log10'),
operations.ExponentialPercentTransform(),
operations.ExponentialPercentTransform(base='log10'),
diversity.HHI('x'),
diversity.HHI('y'),
diversity.Entropy('x'),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "meterstick"
version = "1.5.8"
version = "1.6.0"
authors = [
{ name="Xunmo Yang", email="xunmo@google.com" },
{ name="Dennis Sun", email="dlsun@google.com" },
Expand Down
14 changes: 13 additions & 1 deletion sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
RAND_FN = None
CEIL_FN = None
SAFE_DIVIDE_FN = None
LN_FN = None
LOG10_FN = None
QUANTILE_FN = None
ARRAY_AGG_FN = None
ARRAY_INDEX_FN = None
Expand Down Expand Up @@ -448,6 +450,14 @@ def covar_samp_not_implemented():
'Default': safe_divide_fn_default,
'GoogleSQL': 'SAFE_DIVIDE({numer}, {denom})'.format,
}
# Per-dialect SQL templates for the natural logarithm; '{}' is the operand.
LN_OPTIONS = {
    'Default': 'LN({})',
    'SQL Server': 'LOG({})',
}
# Per-dialect SQL templates for the base-10 logarithm. For Oracle, the base
# (10) is passed as the first argument of LOG.
LOG10_OPTIONS = {
    'Default': 'LOG10({})',
    'Oracle': 'LOG(10, {})',
}
# When make changes, manually evaluate the run_only_once_in_with_clause and
# update the VOLATILE_RAND_IN_WITH_CLAUSE_OPTIONS.
RAND_OPTIONS = {
Expand Down Expand Up @@ -602,7 +612,7 @@ def set_dialect(dialect: Optional[str]):
"""Sets the dialect of the SQL query."""
# You can manually override the options below. You can manually test it in
# https://colab.research.google.com/drive/1y3UigzEby1anMM3-vXocBx7V8LVblIAp?usp=sharing.
global DIALECT, VOLATILE_RAND_IN_WITH_CLAUSE, CREATE_TEMP_TABLE_FN, SUPPORT_FULL_JOIN, SUPPORT_JOIN_WITH_USING, ROW_NUMBER_REQUIRE_ORDER_BY, GROUP_BY_FN, RAND_FN, CEIL_FN, SAFE_DIVIDE_FN, QUANTILE_FN, ARRAY_AGG_FN, ARRAY_INDEX_FN, NTH_VALUE_FN, COUNTIF_FN, STRING_CAST_FN, FLOAT_CAST_FN, UNIFORM_MAPPING_FN, UNNEST_ARRAY_FN, UNNEST_ARRAY_LITERAL_FN, GENERATE_ARRAY_FN, DUPLICATE_DATA_N_TIMES_FN, STDDEV_POP_FN, STDDEV_SAMP_FN, VARIANCE_POP_FN, VARIANCE_SAMP_FN, CORR_FN, COVAR_POP_FN, COVAR_SAMP_FN
global DIALECT, VOLATILE_RAND_IN_WITH_CLAUSE, CREATE_TEMP_TABLE_FN, SUPPORT_FULL_JOIN, SUPPORT_JOIN_WITH_USING, ROW_NUMBER_REQUIRE_ORDER_BY, GROUP_BY_FN, RAND_FN, CEIL_FN, LN_FN, LOG10_FN, SAFE_DIVIDE_FN, QUANTILE_FN, ARRAY_AGG_FN, ARRAY_INDEX_FN, NTH_VALUE_FN, COUNTIF_FN, STRING_CAST_FN, FLOAT_CAST_FN, UNIFORM_MAPPING_FN, UNNEST_ARRAY_FN, UNNEST_ARRAY_LITERAL_FN, GENERATE_ARRAY_FN, DUPLICATE_DATA_N_TIMES_FN, STDDEV_POP_FN, STDDEV_SAMP_FN, VARIANCE_POP_FN, VARIANCE_SAMP_FN, CORR_FN, COVAR_POP_FN, COVAR_SAMP_FN
if not dialect:
return
if dialect not in BUILTIN_DIALECTS:
Expand All @@ -624,6 +634,8 @@ def set_dialect(dialect: Optional[str]):
GROUP_BY_FN = _get_dialect_option(GROUP_BY_OPTIONS)
RAND_FN = _get_dialect_option(RAND_OPTIONS)
CEIL_FN = _get_dialect_option(CEIL_OPTIONS)
LN_FN = _get_dialect_option(LN_OPTIONS)
LOG10_FN = _get_dialect_option(LOG10_OPTIONS)
SAFE_DIVIDE_FN = _get_dialect_option(SAFE_DIVIDE_OPTIONS)
QUANTILE_FN = _get_dialect_option(QUANTILE_OPTIONS)
ARRAY_AGG_FN = _get_dialect_option(ARRAY_AGG_OPTIONS)
Expand Down
Loading