From 42ff82cb704eb5162eb15408fde017c126a3b468 Mon Sep 17 00:00:00 2001 From: Xunmo Yang Date: Thu, 2 Apr 2026 11:23:22 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 893607810 --- README.md | 30 +++++++++++++++ metrics.py | 1 + operations.py | 96 +++++++++++++++++++++++++++++++++++++--------- operations_test.py | 59 ++++++++++++++++++++++++---- pyproject.toml | 2 +- sql.py | 14 ++++++- 6 files changed, 175 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 4ca7d77..62d4758 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,36 @@ Built-in standard errors include: Example Usage: `... | Jackknife('CookieBucket', confidence=.95)` +#### Transformations + +Transformations are functions that can be applied to metrics to perform +element-wise operations on their results. + +Currently supported transformations include: + ++ `ExponentialTransform(metric)`: Applies an exponential transformation to + the metric result. ++ `LogTransform(metric, base='ln')`: Applies a logarithmic transformation to + the metric result. The `base` can be `ln` or `log10`. ++ `ExponentialPercentTransform(metric, base='ln')`: Computes + `100 * (base^metric - 1)`. If `base='log10'`, it computes + `100 * (10^metric - 1)`. + It's useful for converting log-transformed + metrics back to a percent scale. For example, when + `PercentChange(.., Sum(x))` is skewed, applying the transformation sequence: + + ``` + (Sum(x) + | LogTransform() + | AbsoluteChange(...) + | Jackknife(...) + | ExponentialPercentTransform()) + ``` + + computes the same percent change, while the confidence interval is + calculated in the log-transformed space, which often results in less + skewness. + #### Distributions A **distribution** operation produces the distribution of the metric over diff --git a/metrics.py b/metrics.py index fa37283..d2c0313 100644 --- a/metrics.py +++ b/metrics.py @@ -131,6 +131,7 @@ def to_sql(table, split_by=None): 'Bootstrap', 'PoissonBootstrap', 'LogTransform', + 'ExponentialTransform', 'ExponentialPercentTransform', 'LogTransformedPercentChangeWithCI', # Diversity Operations diff --git a/operations.py b/operations.py index 05674d5..10b08ad 100644 --- a/operations.py +++ b/operations.py @@ -4364,19 +4364,45 @@ class MetricFunction(Operation): Attributes: func: The function to apply to the result of child Metric. + sql_func: The function to apply to the result of child Metric in SQL. children: A tuple containing the child Metric. name_tmpl: The template to generate the name from child Metric's name. """ - def __init__(self, child, func, name_tmpl, **kwargs): + def __init__(self, child, func, sql_func, name_tmpl, **kwargs): super().__init__(child, name_tmpl, **kwargs) self.func = func + self.sql_func = sql_func def compute_on_children(self, children, split_by): new_df = self.func(children) new_df = copy_meterstick_metadata(children, new_df) return new_df + def get_sql_and_with_clause( + self, table, split_by, global_filter, indexes, local_filter, with_data + ): + if not self.sql_func: + raise NotImplementedError( + f'SQL generation not supported for {type(self)}.' + ) + local_filter = ( + sql.Filters(self.where_).add(local_filter).remove(global_filter) + ) + child_sql, with_data = self.children[0].get_sql_and_with_clause( + table, split_by, global_filter, indexes, local_filter, with_data) + columns = sql.Columns() + for c in child_sql.all_columns: + if c.alias in indexes.aliases: + columns.add(c) + else: + col = sql.Column(c.expression, self.sql_func, + alias=self.name_tmpl.format(c.alias_raw)) + columns.add(col) + child_sql = copy.deepcopy(child_sql) + child_sql.columns = columns + return child_sql, with_data + def manipulate( self, res, melted=False, return_dataframe=True, apply_name_tmpl=None ): @@ -4396,24 +4422,43 @@ class LogTransform(MetricFunction): Attributes: base: The logarithm base, 'ln' or 'log10'. func: The log function (np.log or np.log10). + sql_func: The log function in SQL (sql.LN_FN or sql.LOG10_FN). children: A tuple containing the child Metric. name_tmpl: The template to generate the name from child Metric's name. """ - def __init__(self, child=None, base: str = 'ln', **kwargs): + def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs): if base not in ('ln', 'log10'): raise ValueError("base must be 'ln' or 'log10'") self.base = base func = np.log if base == 'ln' else np.log10 + sql_func = sql.LN_FN if base == 'ln' else sql.LOG10_FN + if name_tmpl is None: + name_tmpl = 'Ln({})' if base == 'ln' else 'Log10({})' super().__init__( child, func, - 'Log({})' if base == 'ln' else 'Log10({})', + sql_func, + name_tmpl, additional_fingerprint_attrs=['base'], **kwargs ) +class ExponentialTransform(MetricFunction): + """Base class for applying exponential transformations to Metric.""" + + def __init__(self, child=None, name_tmpl='Exp({})', **kwargs): + sql_func = 'EXP({})' + super().__init__( + child, + np.exp, + sql_func, + name_tmpl, + **kwargs + ) + + class ExponentialPercentTransform(MetricFunction): """Applies exponential and percent transformations to Metric results. @@ -4440,16 +4485,19 @@ class ExponentialPercentTransform(MetricFunction): base: The logarithm base, 'ln' or 'log10', used in inverse transformation. func: The inverse function: 100*(exp(x)-1) for 'ln', 100*(10^x-1) for 'log10'. + sql_func: The inverse function in SQL: 100*(EXP(x)-1) for base='ln', + 100*(10^x-1) for base='log10'. children: A tuple containing the child Metric. name_tmpl: The template to generate the name from child Metric's name. """ - def __init__(self, child=None, base: str = 'ln', **kwargs): + def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs): """Initializes an ExponentialPercentTransform. Args: child: The child Metric to apply exp transform to. base: The logarithm base, 'ln' or 'log10'. Default is 'ln'. + name_tmpl: The template to generate the name from child Metric's name. **kwargs: other keyword arguments passed to MetricFunction.__init__. """ if base not in ('ln', 'log10'): @@ -4457,13 +4505,16 @@ def __init__(self, child=None, base: str = 'ln', **kwargs): self.base = base if base == 'ln': func = lambda x: 100 * (np.exp(x) - 1) + sql_func = '100 * (EXP({}) - 1)' name_tmpl = '100 * Exp({}) - 1' else: func = lambda x: 100 * (10**x - 1) + sql_func = '100 * (POWER(10, {}) - 1)' name_tmpl = '100 * 10^({}) - 1' super().__init__( child, func, + sql_func, name_tmpl, additional_fingerprint_attrs=['base'], **kwargs @@ -4513,10 +4564,11 @@ def _check_and_update_for_log_transformed_abs_change(self): ab = ci_method.children[0] log_transform = ab.children[0] - log_transform.name_tmpl = '{}' - ab.children = tuple([log_transform]) self.name_tmpl = '{}' - self.children = tuple([ci_method]) + log_transform = LogTransform( + log_transform.children[0], log_transform.base, name_tmpl='{}' + ) + self.children = tuple([ci_method(ab(log_transform))]) return True def __call__(self, *args, **kwargs): @@ -4609,17 +4661,25 @@ def _get_equiv(self, child): | ExponentialPercentTransform() ) - def compute_on( - self, - df, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None, - cache=None, - ): + def compute_slices(self, df, split_by=None): """Computes CI on log-scale and transform back to percent change.""" equiv = self._get_equiv(self.children[0]) - return equiv.compute_on( - df, split_by, melted, return_dataframe, cache_key, cache + return self.compute_util_metric_on(equiv, df, split_by) + + def compute_through_sql(self, table, split_by, execute, mode): + equiv = self._get_equiv(self.children[0]) + return self.compute_util_metric_on_sql( + equiv, table, split_by, execute, mode ) + + def manipulate( + self, res, melted=False, return_dataframe=True, apply_name_tmpl=None + ): + new_res = super().manipulate(res, melted, return_dataframe, apply_name_tmpl) + new_res = copy_meterstick_metadata(res, new_res) + return new_res + + def final_compute(self, res, melted, return_dataframe, split_by, df): + new_res = super().final_compute(res, melted, return_dataframe, split_by, df) + new_res = copy_meterstick_metadata(res, new_res) + return new_res diff --git a/operations_test.py b/operations_test.py index eaf9a07..6e4acfa 100644 --- a/operations_test.py +++ b/operations_test.py @@ -282,7 +282,7 @@ def test_log_transform_ln(self): metric = operations.LogTransform(metrics.Sum('x'), base='ln') output = metric.compute_on(self.df) expected = pd.DataFrame( - {'Log(sum(x))': [np.log(8)]} + {'Ln(sum(x))': [np.log(8)]} ) testing.assert_frame_equal(output, expected) @@ -294,6 +294,14 @@ def test_log_transform_log10(self): ) testing.assert_frame_equal(output, expected) + def test_exponential_transform(self): + metric = operations.ExponentialTransform(metrics.Sum('x')) + output = metric.compute_on(self.df) + expected = pd.DataFrame( + {'Exp(sum(x))': [np.exp(8)]} + ) + testing.assert_frame_equal(output, expected) + def test_exponential_percent_transform_ln(self): metric = operations.ExponentialPercentTransform(metrics.Sum('x'), base='ln') output = metric.compute_on(self.df) @@ -1605,7 +1613,7 @@ def test_display_raises_for_duplicate_metric_names(self): ) @parameterized.product(base=['ln', 'log10'], melted=[True, False]) - def test_dispplay_log_transformed_percent_change(self, base, melted): + def test_display_log_transformed_percent_change(self, base, melted): df = pd.DataFrame({ 'x': list(range(8, 13)) + list(range(98, 103)), 'grp': list('A' * 5 + 'B' * 5), @@ -1633,7 +1641,7 @@ def test_dispplay_log_transformed_percent_change(self, base, melted): testing.assert_frame_equal(actual, expected) @parameterized.product(base=['ln', 'log10'], melted=[True, False]) - def test_dispplay_log_transformed_percent_change_split_by(self, base, melted): + def test_display_log_transformed_percent_change_split_by(self, base, melted): df = pd.DataFrame({ 'x': np.random.random(10).round(5), 'grp': list('A' * 5 + 'B' * 5), @@ -1665,6 +1673,43 @@ def test_dispplay_log_transformed_percent_change_split_by(self, base, melted): testing.assert_frame_equal(actual, expected) + def test_display_log_transformed_percent_change_metric_list(self): + df = pd.DataFrame({ + 'x': list(range(8, 13)) + list(range(98, 103)), + 'grp': list('A' * 5 + 'B' * 5), + 'unit': list(range(5)) * 2, + }) + metric = metrics.MetricList((metrics.Mean('x'), metrics.Sum('x'))) + m = ( + metric + | operations.LogTransform() + | operations.AbsoluteChange('grp', 'A') + | operations.Jackknife('unit', confidence=0.9) + | operations.ExponentialPercentTransform() + ) + + actual = m.compute_on(df) + actual = actual.display(return_formatted_df=True) + log_pct = operations.LogTransformedPercentChangeWithCI( + 'grp', 'A', 'unit', 0.9 + ) + expected1 = ( + log_pct(metrics.Mean('x')) + .compute_on(df) + .display(return_formatted_df=True) + ) + expected2 = ( + log_pct(metrics.Sum('x')) + .compute_on(df) + .display(return_formatted_df=True) + ) + expected = ( + expected1.set_index('Dimensions') + .join(expected2.set_index('Dimensions')) + .reset_index() + ) + testing.assert_frame_equal(actual, expected) + def test_display_log_transformed_percent_change_with_ci(self): df = pd.DataFrame({ 'x': np.random.random(10).round(3), @@ -2809,10 +2854,10 @@ def test_different_metrics_have_different_fingerprints(self): operations.Bootstrap('x', n_replicates=10), operations.Bootstrap('x', confidence=0.9), operations.Bootstrap('x', confidence=0.95), - operations.LogTransform('x'), - operations.LogTransform('x', base='log10'), - operations.ExponentialPercentTransform('x'), - operations.ExponentialPercentTransform('x', base='log10'), + operations.LogTransform(), + operations.LogTransform(base='log10'), + operations.ExponentialPercentTransform(), + operations.ExponentialPercentTransform(base='log10'), diversity.HHI('x'), diversity.HHI('y'), diversity.Entropy('x'), diff --git a/pyproject.toml b/pyproject.toml index 8393b65..f24e5be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "meterstick" -version = "1.5.8" +version = "1.6.0" authors = [ { name="Xunmo Yang", email="xunmo@google.com" }, { name="Dennis Sun", email="dlsun@google.com" }, diff --git a/sql.py b/sql.py index 42ec689..80a155d 100644 --- a/sql.py +++ b/sql.py @@ -40,6 +40,8 @@ RAND_FN = None CEIL_FN = None SAFE_DIVIDE_FN = None +LN_FN = None +LOG10_FN = None QUANTILE_FN = None ARRAY_AGG_FN = None ARRAY_INDEX_FN = None @@ -448,6 +450,14 @@ def covar_samp_not_implemented(): 'Default': safe_divide_fn_default, 'GoogleSQL': 'SAFE_DIVIDE({numer}, {denom})'.format, } +LN_OPTIONS = { + 'Default': 'LN({})', + 'SQL Server': 'LOG({})', +} +LOG10_OPTIONS = { + 'Default': 'LOG10({})', + 'Oracle': 'LOG(10, {})', +} # When make changes, manually evaluate the run_only_once_in_with_clause and # update the VOLATILE_RAND_IN_WITH_CLAUSE_OPTIONS. RAND_OPTIONS = { @@ -602,7 +612,7 @@ def set_dialect(dialect: Optional[str]): """Sets the dialect of the SQL query.""" # You can manually override the options below. You can manually test it in # https://colab.research.google.com/drive/1y3UigzEby1anMM3-vXocBx7V8LVblIAp?usp=sharing. - global DIALECT, VOLATILE_RAND_IN_WITH_CLAUSE, CREATE_TEMP_TABLE_FN, SUPPORT_FULL_JOIN, SUPPORT_JOIN_WITH_USING, ROW_NUMBER_REQUIRE_ORDER_BY, GROUP_BY_FN, RAND_FN, CEIL_FN, SAFE_DIVIDE_FN, QUANTILE_FN, ARRAY_AGG_FN, ARRAY_INDEX_FN, NTH_VALUE_FN, COUNTIF_FN, STRING_CAST_FN, FLOAT_CAST_FN, UNIFORM_MAPPING_FN, UNNEST_ARRAY_FN, UNNEST_ARRAY_LITERAL_FN, GENERATE_ARRAY_FN, DUPLICATE_DATA_N_TIMES_FN, STDDEV_POP_FN, STDDEV_SAMP_FN, VARIANCE_POP_FN, VARIANCE_SAMP_FN, CORR_FN, COVAR_POP_FN, COVAR_SAMP_FN + global DIALECT, VOLATILE_RAND_IN_WITH_CLAUSE, CREATE_TEMP_TABLE_FN, SUPPORT_FULL_JOIN, SUPPORT_JOIN_WITH_USING, ROW_NUMBER_REQUIRE_ORDER_BY, GROUP_BY_FN, RAND_FN, CEIL_FN, LN_FN, LOG10_FN, SAFE_DIVIDE_FN, QUANTILE_FN, ARRAY_AGG_FN, ARRAY_INDEX_FN, NTH_VALUE_FN, COUNTIF_FN, STRING_CAST_FN, FLOAT_CAST_FN, UNIFORM_MAPPING_FN, UNNEST_ARRAY_FN, UNNEST_ARRAY_LITERAL_FN, GENERATE_ARRAY_FN, DUPLICATE_DATA_N_TIMES_FN, STDDEV_POP_FN, STDDEV_SAMP_FN, VARIANCE_POP_FN, VARIANCE_SAMP_FN, CORR_FN, COVAR_POP_FN, COVAR_SAMP_FN if not dialect: return if dialect not in BUILTIN_DIALECTS: @@ -624,6 +634,8 @@ def set_dialect(dialect: Optional[str]): GROUP_BY_FN = _get_dialect_option(GROUP_BY_OPTIONS) RAND_FN = _get_dialect_option(RAND_OPTIONS) CEIL_FN = _get_dialect_option(CEIL_OPTIONS) + LN_FN = _get_dialect_option(LN_OPTIONS) + LOG10_FN = _get_dialect_option(LOG10_OPTIONS) SAFE_DIVIDE_FN = _get_dialect_option(SAFE_DIVIDE_OPTIONS) QUANTILE_FN = _get_dialect_option(QUANTILE_OPTIONS) ARRAY_AGG_FN = _get_dialect_option(ARRAY_AGG_OPTIONS)