From 38076a1b48096ef7b4f1e5005fdf9e14425d56a8 Mon Sep 17 00:00:00 2001
From: wangjiezhe <wangjiezhe@gmail.com>
Date: Sat, 23 Dec 2023 10:07:02 +0800
Subject: [PATCH 2/2] Revert "Update TF Estimator to use new TF API Generator"

This reverts commit f7653f9adf978acb3bd459e6ec779a659f1f9a2a.
---
 tensorflow_estimator/BUILD                    |  14 +-
 .../python/estimator/api/BUILD                |  48 +-
 .../python/estimator/api/api_gen.bzl          | 422 ++++--------------
 .../api/create_python_api_wrapper.py          |  30 ++
 .../python/estimator/api/extractor_wrapper.py |  21 -
 .../python/estimator/api/generator_wrapper.py |  20 -
 6 files changed, 155 insertions(+), 400 deletions(-)
 create mode 100644 tensorflow_estimator/python/estimator/api/create_python_api_wrapper.py
 delete mode 100644 tensorflow_estimator/python/estimator/api/extractor_wrapper.py
 delete mode 100644 tensorflow_estimator/python/estimator/api/generator_wrapper.py

diff --git a/tensorflow_estimator/BUILD b/tensorflow_estimator/BUILD
index 680cc4e..037c7a9 100644
--- a/tensorflow_estimator/BUILD
+++ b/tensorflow_estimator/BUILD
@@ -5,7 +5,7 @@ load(
     "//tensorflow_estimator/python/estimator/api:api_gen.bzl",
     "ESTIMATOR_API_INIT_FILES_V1",
     "ESTIMATOR_API_INIT_FILES_V2",
-    "generate_apis",
+    "gen_api_init_files",
 )
 
 licenses(["notice"])
@@ -67,17 +67,17 @@ py_library(
 genrule(
     name = "root_init_gen",
     srcs = select({
-        "api_version_2": ["_api/v2/v2.py"],
-        "//conditions:default": ["_api/v1/v1.py"],
+        "api_version_2": [":estimator_python_api_gen_compat_v2"],
+        "//conditions:default": [":estimator_python_api_gen_compat_v1"],
     }),
     outs = ["__init__.py"],
     cmd = select({
-        "api_version_2": "cp $(location :_api/v2/v2.py) $(OUTS)",
-        "//conditions:default": "cp $(location :_api/v1/v1.py) $(OUTS)",
+        "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
+        "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
     }),
 )
 
-generate_apis(
+gen_api_init_files(
     name = "estimator_python_api_gen_compat_v1",
     api_version = 1,
     output_dir = "_api/v1/",
@@ -86,7 +86,7 @@ generate_apis(
     root_file_name = "v1.py",
 )
 
-generate_apis(
+gen_api_init_files(
     name = "estimator_python_api_gen_compat_v2",
     api_version = 2,
     output_dir = "_api/v2/",
diff --git a/tensorflow_estimator/python/estimator/api/BUILD b/tensorflow_estimator/python/estimator/api/BUILD
index 01dce90..96ac567 100644
--- a/tensorflow_estimator/python/estimator/api/BUILD
+++ b/tensorflow_estimator/python/estimator/api/BUILD
@@ -1,10 +1,17 @@
-# Placeholder: load aliased py_binary
-load("//tensorflow_estimator/python/estimator/api:api_gen.bzl", "ESTIMATOR_API_INIT_FILES_V1", "ESTIMATOR_API_INIT_FILES_V2", "generate_apis")
-
 package(default_visibility = ["//tensorflow_estimator:internal"])
 
 licenses(["notice"])
 
+load("//tensorflow_estimator/python/estimator/api:api_gen.bzl", "gen_api_init_files")
+load("//tensorflow_estimator/python/estimator/api:api_gen.bzl", "ESTIMATOR_API_INIT_FILES_V1")
+load("//tensorflow_estimator/python/estimator/api:api_gen.bzl", "ESTIMATOR_API_INIT_FILES_V2")
+
+exports_files(
+    [
+        "create_python_api_wrapper.py",
+    ],
+)
+
 # This flag specifies whether Estimator 2.0 API should be built instead
 # of 1.* API. Note that Estimator 2.0 API is currently under development.
 config_setting(
@@ -12,53 +19,36 @@ config_setting(
     define_values = {"estimator_api_version": "2"},
 )
 
-py_binary(
-    name = "extractor_wrapper",
-    srcs = ["extractor_wrapper.py"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow_estimator/python/estimator:expect_absl_installed",  # absl:app
-    ],
-)
-
-py_binary(
-    name = "generator_wrapper",
-    srcs = ["generator_wrapper.py"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow_estimator/python/estimator:expect_absl_installed",  # absl:app
-    ],
-)
-
 genrule(
     name = "estimator_python_api_gen",
     srcs = select({
-        "api_version_2": ["_v2/v2.py"],
-        "//conditions:default": ["_v1/v1.py"],
+        "api_version_2": [":estimator_python_api_gen_compat_v2"],
+        "//conditions:default": [":estimator_python_api_gen_compat_v1"],
     }),
     outs = ["__init__.py"],
     cmd = select({
-        "api_version_2": "cp $(location :_v2/v2.py) $(OUTS)",
-        "//conditions:default": "cp $(location :_v1/v1.py) $(OUTS)",
+        # Copy the right init file and replace 'from . import'
+        # with 'from ._vN import'.
+        "api_version_2": "cp $(@D)/_v2/v2.py $(OUTS) && sed -i'.original' 's/from . import/from ._v2 import/g' $(OUTS)",
+        "//conditions:default": "cp $(@D)/_v1/v1.py $(OUTS) && sed -i'.original' 's/from . import/from ._v1 import/g' $(OUTS)",
     }),
+    visibility = ["//visibility:public"],
 )
 
-generate_apis(
+gen_api_init_files(
     name = "estimator_python_api_gen_compat_v1",
     api_version = 1,
     output_dir = "_v1/",
     output_files = ESTIMATOR_API_INIT_FILES_V1,
     output_package = "tensorflow_estimator.python.estimator.api._v1",
     root_file_name = "v1.py",
-    visibility = ["//visibility:public"],
 )
 
-generate_apis(
+gen_api_init_files(
     name = "estimator_python_api_gen_compat_v2",
     api_version = 2,
     output_dir = "_v2/",
     output_files = ESTIMATOR_API_INIT_FILES_V2,
     output_package = "tensorflow_estimator.python.estimator.api._v2",
     root_file_name = "v2.py",
-    visibility = ["//visibility:public"],
 )
diff --git a/tensorflow_estimator/python/estimator/api/api_gen.bzl b/tensorflow_estimator/python/estimator/api/api_gen.bzl
index b8eaf84..87dd65b 100644
--- a/tensorflow_estimator/python/estimator/api/api_gen.bzl
+++ b/tensorflow_estimator/python/estimator/api/api_gen.bzl
@@ -1,7 +1,7 @@
-"""Targets for generating TensorFlow Estimator Python API __init__.py files.
+"""Targets for generating TensorFlow Python API __init__.py files.
 
 This bzl file is copied with slight modifications from
-tensorflow/python/tools/api/generator2/generate_api.bzl
+tensorflow/python/estimator/api/api_gen.bzl
 so that we can avoid needing to depend on TF source code in Bazel build.
 
 It should be noted that because this file is executed during the build,
@@ -10,15 +10,7 @@ is required to Bazel build Estimator.
 """
 
 load("//tensorflow_estimator:estimator.bzl", "if_indexing_source_code")
-
-_TARGET_PATTERNS = [
-    "//tensorflow_estimator:",
-    "//tensorflow_estimator/",
-]
-
-_DECORATOR = "tensorflow_estimator.python.estimator.estimator_export.estimator_export"
-
-_MODULE_PREFIX = ""
+# Placeholder: load aliased py_binary
 
 ESTIMATOR_API_INIT_FILES_V1 = [
     "__init__.py",
@@ -38,332 +30,116 @@ ESTIMATOR_API_INIT_FILES_V2 = [
     "estimator/inputs/__init__.py",
 ]
 
-def _any_match(label):
-    full_target = "//" + label.package + ":" + label.name
-    for pattern in _TARGET_PATTERNS:
-        if pattern in full_target:
-            return True
-    return False
-
-def _join(path, *others):
-    result = path
-
-    for p in others:
-        if not result or result.endswith("/"):
-            result += p
-        else:
-            result += "/" + p
-
-    return result
-
-def _api_info_init(*, transitive_api):
-    if type(transitive_api) != type(depset()):
-        fail("ApiInfo.transitive_api must be a depset")
-    return {"transitive_api": transitive_api}
-
-ApiInfo, _new_api_info = provider(
-    doc = "Provider for API symbols and docstrings extracted from Python files.",
-    fields = {
-        "transitive_api": "depset of files with extracted API.",
-    },
-    init = _api_info_init,
-)
-
-def _py_files(f):
-    if f.basename.endswith(".py") or f.basename.endswith(".py3"):
-        return f.path
-    return None
-
-def _merge_py_info(
-        deps,
-        direct_sources = None,
-        direct_imports = None,
-        has_py2_only_sources = False,
-        has_py3_only_sources = False,
-        uses_shared_libraries = False):
-    transitive_sources = []
-    transitive_imports = []
-    for dep in deps:
-        if PyInfo in dep:
-            transitive_sources.append(dep[PyInfo].transitive_sources)
-            transitive_imports.append(dep[PyInfo].imports)
-            has_py2_only_sources = has_py2_only_sources or dep[PyInfo].has_py2_only_sources
-            has_py3_only_sources = has_py3_only_sources or dep[PyInfo].has_py3_only_sources
-            uses_shared_libraries = uses_shared_libraries or dep[PyInfo].uses_shared_libraries
-
-    return PyInfo(
-        transitive_sources = depset(direct = direct_sources, transitive = transitive_sources),
-        imports = depset(direct = direct_imports, transitive = transitive_imports),
-        has_py2_only_sources = has_py2_only_sources,
-        has_py3_only_sources = has_py3_only_sources,
-        uses_shared_libraries = uses_shared_libraries,
-    )
-
-def _merge_api_info(
-        deps,
-        direct_api = None):
-    transitive_api = []
-    for dep in deps:
-        if ApiInfo in dep:
-            transitive_api.append(dep[ApiInfo].transitive_api)
-    return ApiInfo(transitive_api = depset(direct = direct_api, transitive = transitive_api))
-
-def _api_extractor_impl(target, ctx):
-    direct_api = []
-
-    # Make sure the rule has a non-empty srcs attribute.
-    if (
-        _any_match(target.label) and
-        hasattr(ctx.rule.attr, "srcs") and
-        ctx.rule.attr.srcs
-    ):
-        output = ctx.actions.declare_file("_".join([
-            target.label.name,
-            "extracted_tensorflow_estimator_api.json",
-        ]))
-
-        args = ctx.actions.args()
-        args.set_param_file_format("multiline")
-        args.use_param_file("--flagfile=%s")
-
-        args.add("--output", output)
-        args.add("--decorator", _DECORATOR)
-        args.add("--api_name", "tensorflow_estimator")
-        args.add_all(ctx.rule.files.srcs, expand_directories = True, map_each = _py_files)
-
-        ctx.actions.run(
-            mnemonic = "ExtractAPI",
-            executable = ctx.executable._extractor_bin,
-            inputs = ctx.rule.files.srcs,
-            outputs = [output],
-            arguments = [args],
-            progress_message = "Extracting tensorflow_estimator APIs for %{label} to %{output}.",
-        )
-
-        direct_api.append(output)
-
-    return [
-        _merge_api_info(ctx.rule.attr.deps if hasattr(ctx.rule.attr, "deps") else [], direct_api = direct_api),
-    ]
-
-api_extractor = aspect(
-    doc = "Extracts the exported API for the given target and its dependencies.",
-    implementation = _api_extractor_impl,
-    attr_aspects = ["deps"],
-    provides = [ApiInfo],
-    # Currently the Python rules do not correctly advertise their providers.
-    # required_providers = [PyInfo],
-    attrs = {
-        "_extractor_bin": attr.label(
-            default = Label("//tensorflow_estimator/python/estimator/api:extractor_wrapper"),
-            executable = True,
-            cfg = "exec",
-        ),
-    },
-)
-
-def _extract_api_impl(ctx):
-    return [
-        _merge_api_info(ctx.attr.deps),
-        _merge_py_info(ctx.attr.deps),
-    ]
-
-extract_api = rule(
-    doc = "Extract Python API for all targets in transitive dependencies.",
-    implementation = _extract_api_impl,
-    attrs = {
-        "deps": attr.label_list(
-            doc = "Targets to extract API from.",
-            allow_empty = False,
-            aspects = [api_extractor],
-            providers = [PyInfo],
-            mandatory = True,
-        ),
-    },
-    provides = [ApiInfo, PyInfo],
-)
-
-def _generate_api_impl(ctx):
-    args = ctx.actions.args()
-    args.set_param_file_format("multiline")
-    args.use_param_file("--flagfile=%s")
-
-    args.add_joined("--output_files", ctx.outputs.output_files, join_with = ",")
-    args.add("--output_dir", _join(ctx.bin_dir.path, ctx.label.package, ctx.attr.output_dir))
-    if ctx.file.root_init_template:
-        args.add("--root_init_template", ctx.file.root_init_template)
-    args.add("--apiversion", ctx.attr.api_version)
-    args.add_joined("--compat_api_versions", ctx.attr.compat_api_versions, join_with = ",")
-    args.add_joined("--compat_init_templates", ctx.files.compat_init_templates, join_with = ",")
-    args.add("--output_package", ctx.attr.output_package)
-    args.add_joined("--packages_to_ignore", ctx.attr.packages_to_ignore, join_with = ",")
-    if _MODULE_PREFIX:
-        args.add("--module_prefix", _MODULE_PREFIX)
-    if ctx.attr.use_lazy_loading:
-        args.add("--use_lazy_loading")
-    else:
-        args.add("--nouse_lazy_loading")
-    if ctx.attr.proxy_module_root:
-        args.add("--proxy_module_root", ctx.attr.proxy_module_root)
-    args.add_joined("--file_prefixes_to_strip", [ctx.bin_dir.path, ctx.genfiles_dir.path], join_with = ",")
-    if ctx.attr.root_file_name:
-        args.add("--root_file_name", ctx.attr.root_file_name)
-
-    inputs = depset(transitive = [
-        dep[ApiInfo].transitive_api
-        for dep in ctx.attr.deps
-    ])
-    args.add_all(
-        inputs,
-        expand_directories = True,
-    )
-
-    transitive_inputs = [inputs]
-    if ctx.attr.root_init_template:
-        transitive_inputs.append(ctx.attr.root_init_template.files)
-
-    ctx.actions.run(
-        mnemonic = "GenerateAPI",
-        executable = ctx.executable._generator_bin,
-        inputs = depset(
-            direct = ctx.files.compat_init_templates,
-            transitive = transitive_inputs,
-        ),
-        outputs = ctx.outputs.output_files,
-        arguments = [args],
-        progress_message = "Generating APIs for %{label} to %{output}.",
-    )
-
-generate_api = rule(
-    doc = "Generate Python API for all targets in transitive dependencies.",
-    implementation = _generate_api_impl,
-    attrs = {
-        "deps": attr.label_list(
-            doc = "extract_api targets to generate API from.",
-            allow_empty = True,
-            providers = [ApiInfo, PyInfo],
-            mandatory = True,
-        ),
-        "root_init_template": attr.label(
-            doc = "Template for the top level __init__.py file",
-            allow_single_file = True,
-        ),
-        "api_version": attr.int(
-            doc = "The API version to generate (1 or 2)",
-            values = [1, 2],
-        ),
-        "compat_api_versions": attr.int_list(
-            doc = "Additional versions to generate in compat/ subdirectory.",
-        ),
-        "compat_init_templates": attr.label_list(
-            doc = "Template for top-level __init__files under compat modules. This list must be " +
-                  "in the same order as the list of versions in compat_apiversions",
-            allow_files = True,
-        ),
-        "output_package": attr.string(
-            doc = "Root output package.",
-        ),
-        "output_dir": attr.string(
-            doc = "Subdirectory to output API to. If non-empty, must end with '/'.",
-        ),
-        "proxy_module_root": attr.string(
-            doc = "Module root for proxy-import format. If specified, proxy files with " +
-                  "`from proxy_module_root.proxy_module import *` will be created to enable " +
-                  "import resolution under TensorFlow.",
-        ),
-        "output_files": attr.output_list(
-            doc = "List of __init__.py files that should be generated. This list should include " +
-                  "file name for every module exported using tf_export. For e.g. if an op is " +
-                  "decorated with @tf_export('module1.module2', 'module3'). Then, output_files " +
-                  "should include module1/module2/__init__.py and module3/__init__.py.",
-        ),
-        "use_lazy_loading": attr.bool(
-            doc = "If true, lazy load imports in the generated API rather then imporing them all statically.",
-        ),
-        "packages_to_ignore": attr.string_list(
-            doc = "List of packages to ignore tf_exports from.",
-        ),
-        "root_file_name": attr.string(
-            doc = "The file name that should be generated for the top level API.",
-        ),
-        "_generator_bin": attr.label(
-            default = Label("//tensorflow_estimator/python/estimator/api:generator_wrapper"),
-            executable = True,
-            cfg = "exec",
-        ),
-    },
-)
-
-def generate_apis(
+def gen_api_init_files(
         name,
-        deps = [
+        output_files,
+        root_init_template = None,
+        srcs = [],
+        api_name = "estimator",
+        api_version = 2,
+        compat_api_versions = [],
+        compat_init_templates = [],
+        packages = ["tensorflow_estimator.python.estimator"],
+        package_deps = [
             "//tensorflow_estimator/python/estimator:estimator_py",
             # "//third_party/tensorflow/lite/python:analyzer",
             # "//third_party/tensorflow/lite/python:lite",
             # "//third_party/tensorflow/lite/python/authoring",
         ],
-        output_files = ESTIMATOR_API_INIT_FILES_V2,
-        root_init_template = None,
-        api_version = 2,
-        compat_api_versions = [],
-        compat_init_templates = [],
         output_package = "tensorflow_estimator.python.estimator.api",
         output_dir = "",
-        proxy_module_root = None,
-        packages_to_ignore = [],
-        root_file_name = "__init__.py",
-        visibility = ["//visibility:private"]):
-    """Generate TensorFlow APIs for a set of libraries.
+        root_file_name = "__init__.py"):
+    """Creates API directory structure and __init__.py files.
+
+    Creates a genrule that generates a directory structure with __init__.py
+    files that import all exported modules (i.e. modules with tf_export
+    decorators).
 
     Args:
-        name: name of generate_api target.
-        deps: python_library targets to serve as roots for extracting APIs.
-        output_files: The list of files that the API generator is exected to create.
-        root_init_template: The template for the top level __init__.py file generated.
-            "#API IMPORTS PLACEHOLDER" comment will be replaced with imports.
-        api_version: THhe API version to generate. (1 or 2)
-        compat_api_versions: Additional versions to generate in compat/ subdirectory.
-        compat_init_templates: Template for top level __init__.py files under the compat modules.
-            The list must be in the same order as the list of versions in 'compat_api_versions'
-        output_package: Root output package.
-        output_dir: Directory where the generated output files are placed. This should be a prefix
-            of every directory in 'output_files'
-        proxy_module_root: Module root for proxy-import format. If specified, proxy files with
-            `from proxy_module_root.proxy_module import *` will be created to enable import
-            resolution under TensorFlow.
-        packages_to_ignore: List of packages to ignore tf_exports from.
-        root_file_name: The file name that should be generated for the top level API.
-        visibility: Visibility of the target containing the generated files.
+      name: name of genrule to create.
+      output_files: List of __init__.py files that should be generated.
+        This list should include file name for every module exported using
+        tf_export. For e.g. if an op is decorated with
+        @tf_export('module1.module2', 'module3'). Then, output_files should
+        include module1/module2/__init__.py and module3/__init__.py.
+      root_init_template: Python init file that should be used as template for
+        root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
+        template will be replaced with root imports collected by this genrule.
+      srcs: genrule sources. If passing root_init_template, the template file
+        must be included in sources.
+      api_name: Name of the project that you want to generate API files for
+        (e.g. "tensorflow" or "estimator").
+      api_version: TensorFlow API version to generate. Must be either 1 or 2.
+      compat_api_versions: Older TensorFlow API versions to generate under
+        compat/ directory.
+      compat_init_templates: Python init file that should be used as template
+        for top level __init__.py files under compat/vN directories.
+        "# API IMPORTS PLACEHOLDER" comment inside this
+        template will be replaced with root imports collected by this genrule.
+      packages: Python packages containing the @tf_export decorators you want to
+        process
+      package_deps: Python library target containing your packages.
+      output_package: Package where generated API will be added to.
+      output_dir: Subdirectory to output API to.
+        If non-empty, must end with '/'.
+      root_file_name: Name of the root file with all the root imports.
     """
-    extract_name = name + ".extract-tensorflow-estimator"
-    extract_api(
-        name = extract_name,
-        deps = deps,
-        visibility = ["//visibility:private"],
+    root_init_template_flag = ""
+    if root_init_template:
+        root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
+
+    primary_package = packages[0]
+    api_gen_binary_target = ("create_" + primary_package + "_api_%d_%s") % (api_version, name)
+    native.py_binary(
+        name = api_gen_binary_target,
+        srcs = ["//tensorflow_estimator/python/estimator/api:create_python_api_wrapper.py"],
+        main = "//tensorflow_estimator/python/estimator/api:create_python_api_wrapper.py",
+        python_version = "PY3",
+        srcs_version = "PY3",
+        visibility = ["//visibility:public"],
+        deps = package_deps,
     )
 
-    if proxy_module_root != None:
-        # Avoid conflicts between the __init__.py file of TensorFlow and proxy module.
-        output_files = [f for f in output_files if f != "__init__.py"]
-
-    if root_file_name != None:
-        output_files = [f if f != "__init__.py" else root_file_name for f in output_files]
+    # Replace name of root file with root_file_name.
+    output_files = [
+        root_file_name if f == "__init__.py" else f
+        for f in output_files
+    ]
+    all_output_files = ["%s%s" % (output_dir, f) for f in output_files]
+    compat_api_version_flags = ""
+    for compat_api_version in compat_api_versions:
+        compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version
+
+    compat_init_template_flags = ""
+    for compat_init_template in compat_init_templates:
+        compat_init_template_flags += (
+            " --compat_init_template=$(location %s)" % compat_init_template
+        )
 
-    all_output_files = [_join(output_dir, f) for f in output_files]
+    flags = [
+        root_init_template_flag,
+        "--apidir=$(@D)" + output_dir,
+        "--apiname=" + api_name,
+        "--apiversion=" + str(api_version),
+        compat_api_version_flags,
+        compat_init_template_flags,
+        "--packages=" + ",".join(packages),
+        "--output_package=" + output_package,
+    ]
 
-    generate_api(
+    native.genrule(
         name = name,
-        deps = [":" + extract_name],
-        output_files = all_output_files,
-        output_dir = output_dir,
-        root_init_template = root_init_template,
-        compat_api_versions = compat_api_versions,
-        compat_init_templates = compat_init_templates,
-        api_version = api_version,
-        proxy_module_root = proxy_module_root,
-        visibility = visibility,
-        packages_to_ignore = packages_to_ignore,
-        use_lazy_loading = False,
-        output_package = output_package,
-        root_file_name = root_file_name,
+        outs = all_output_files,
+        cmd = if_indexing_source_code(
+            _make_cmd(api_gen_binary_target, flags, loading = "static"),
+            _make_cmd(api_gen_binary_target, flags, loading = "default"),
+        ),
+        srcs = srcs,
+        tools = [":" + api_gen_binary_target],
+        visibility = ["//visibility:public"],
     )
+
+def _make_cmd(api_gen_binary_target, flags, loading = "default"):
+    binary = "$(location :" + api_gen_binary_target + ")"
+    flags.append("--loading=" + loading)
+    return " ".join([binary] + flags + ["$(OUTS)"])
diff --git a/tensorflow_estimator/python/estimator/api/create_python_api_wrapper.py b/tensorflow_estimator/python/estimator/api/create_python_api_wrapper.py
new file mode 100644
index 0000000..9d52a02
--- /dev/null
+++ b/tensorflow_estimator/python/estimator/api/create_python_api_wrapper.py
@@ -0,0 +1,30 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Thin wrapper to call TensorFlow's API generation script.
+
+This file exists to provide a main function for the py_binary in the API
+generation genrule. It just calls the main function for the actual API
+generation script in TensorFlow.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow_estimator.python.estimator import estimator_lib  # pylint: disable=unused-import
+from tensorflow.python.tools.api.generator import create_python_api
+
+if __name__ == '__main__':
+  create_python_api.main()
diff --git a/tensorflow_estimator/python/estimator/api/extractor_wrapper.py b/tensorflow_estimator/python/estimator/api/extractor_wrapper.py
deleted file mode 100644
index 884fcba..0000000
--- a/tensorflow_estimator/python/estimator/api/extractor_wrapper.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Thin wrapper to call TensorFlow's API extractor script."""
-from absl import app
-
-from tensorflow.python.tools.api.generator2.extractor import extractor
-
-if __name__ == "__main__":
-  app.run(extractor.main)
diff --git a/tensorflow_estimator/python/estimator/api/generator_wrapper.py b/tensorflow_estimator/python/estimator/api/generator_wrapper.py
deleted file mode 100644
index ffcd49a..0000000
--- a/tensorflow_estimator/python/estimator/api/generator_wrapper.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Thin wrapper to call TensorFlow's API generator script."""
-from absl import app
-from tensorflow.python.tools.api.generator2.generator import generator
-
-if __name__ == "__main__":
-  app.run(generator.main)
-- 
2.41.0

