# Description:
#   Contains the Keras engine API (internal TensorFlow version).

# Package-wide defaults. Targets in this file use default_visibility unless
# they set their own `visibility` attribute (e.g. the filegroup below).
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    # TODO(scottzhu): Remove non-keras deps from TF.
    default_visibility = [
        "//tensorflow/python:__pkg__",
        "//tensorflow/python/feature_column:__pkg__",
        "//tensorflow/python/keras:__subpackages__",
    ],
    licenses = ["notice"],
)

# Every *.py file directly in this directory (non-recursive glob). Visibility
# is overridden so only the private TF API test package can reference it.
filegroup(
    name = "all_py_srcs",
    srcs = glob(["*.py"]),
    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
)

# Top-level engine library: the functional/sequential/saving/training sources
# (including the *_v1 training variants) listed in `srcs`, built on top of the
# finer-grained engine targets defined later in this package (the ":" deps).
py_library(
    name = "engine",
    srcs = [
        "__init__.py",
        "compile_utils.py",
        "functional.py",
        "input_layer.py",
        "partial_batch_padding_handler.py",
        "saving.py",
        "sequential.py",
        "training.py",
        "training_arrays_v1.py",
        "training_distributed_v1.py",
        "training_eager_v1.py",
        "training_generator_v1.py",
        "training_utils.py",
        "training_utils_v1.py",
        "training_v1.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":base_layer",
        ":base_preprocessing_layer",
        ":data_adapter",
        ":input_spec",
        ":keras_tensor",
        ":node",
        "//tensorflow/python/data",
        "//tensorflow/python/distribute:distribute_coordinator",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:parameter_server_strategy",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute:reduce_util",
        "//tensorflow/python/distribute/coordinator:cluster_coordinator",
        "//tensorflow/python/eager:monitoring",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:tensor",
        "//tensorflow/python/framework:tensor_conversion",
        "//tensorflow/python/keras:activations",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras:callbacks",
        "//tensorflow/python/keras:callbacks_v1",
        "//tensorflow/python/keras:constraints",
        "//tensorflow/python/keras:losses",
        "//tensorflow/python/keras:metrics",
        "//tensorflow/python/keras:optimizers",
        "//tensorflow/python/keras:regularizers",
        "//tensorflow/python/keras/distribute",
        "//tensorflow/python/keras/distribute:distribute_coordinator_utils",
        "//tensorflow/python/keras/initializers",
        "//tensorflow/python/keras/mixed_precision:autocast_variable",
        "//tensorflow/python/keras/mixed_precision:loss_scale_optimizer",
        "//tensorflow/python/keras/mixed_precision:policy",
        "//tensorflow/python/keras/saving",
        "//tensorflow/python/keras/utils:engine_utils",
        "//tensorflow/python/keras/utils:metrics_utils",
        "//tensorflow/python/keras/utils:mode_keys",
        "//tensorflow/python/keras/utils:tf_utils",
        "//tensorflow/python/keras/utils:version_utils",
        "//tensorflow/python/module",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/ops/ragged:ragged_util",
        "//tensorflow/python/platform:tf_logging",
        "//tensorflow/python/profiler:trace",
        "//tensorflow/python/saved_model:constants",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/tpu:tpu_lib",
        "//tensorflow/python/trackable:data_structures",
        "//tensorflow/python/training:py_checkpoint_reader",
        "//tensorflow/python/types:data",
        "//tensorflow/python/util:nest",
        "//tensorflow/python/util:tf_decorator",
        "//tensorflow/python/util:tf_export",
        "//tensorflow/tools/docs:doc_controls",
        # External repository (PyPI h5py package).
        "@pypi_h5py//:pkg",
    ],
)

# base_layer_utils.py in its own target; both :base_layer and :node list it
# in their deps.
py_library(
    name = "base_layer_utils",
    srcs = ["base_layer_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:tf2",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/framework:auto_control_deps",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras/utils:tf_inspect",
        "//tensorflow/python/keras/utils:tf_utils",
        "//tensorflow/python/ops:array_ops",
        "//tensorflow/python/ops:control_flow_v2_func_graphs",
        "//tensorflow/python/ops:variable_v1",
        "//tensorflow/python/ops:variables",
        "//tensorflow/python/util:nest",
        "//tensorflow/python/util:tf_export",
    ],
)

# Layer base classes (base_layer.py and the v1 variant base_layer_v1.py).
# Deps are kept in buildifier sort order (":" labels first, then "//" labels)
# like every other target in this file, so additions and duplicates are easy
# to spot in review.
py_library(
    name = "base_layer",
    srcs = [
        "base_layer.py",
        "base_layer_v1.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":base_layer_utils",
        ":input_spec",
        ":node",
        "//tensorflow/python:tf2",
        "//tensorflow/python/autograph/core:ag_ctx",
        "//tensorflow/python/autograph/impl:api",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:sharded_variable",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:execute",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/eager:monitoring",
        "//tensorflow/python/framework:auto_control_deps",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:errors",
        "//tensorflow/python/framework:func_graph",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:sparse_tensor",
        "//tensorflow/python/framework:tensor",
        "//tensorflow/python/framework:tensor_conversion",
        "//tensorflow/python/framework:tensor_spec",
        "//tensorflow/python/framework:tensor_util",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras:constraints",
        # TODO(keras-team): Fix the cyclic deps between layer and metrics.
        # "//tensorflow/python/keras:metrics",
        "//tensorflow/python/keras:regularizers",
        "//tensorflow/python/keras/initializers",
        "//tensorflow/python/keras/mixed_precision:autocast_variable",
        "//tensorflow/python/keras/mixed_precision:loss_scale_optimizer",
        "//tensorflow/python/keras/mixed_precision:policy",
        "//tensorflow/python/keras/saving",
        "//tensorflow/python/keras/utils:generic_utils",
        "//tensorflow/python/keras/utils:layer_utils",
        "//tensorflow/python/keras/utils:object_identity",
        "//tensorflow/python/keras/utils:tf_utils",
        "//tensorflow/python/keras/utils:version_utils",
        "//tensorflow/python/module",
        "//tensorflow/python/ops:array_ops",
        "//tensorflow/python/ops:math_ops",
        "//tensorflow/python/ops:resource_variable_ops",
        "//tensorflow/python/ops:variables",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/platform:tf_logging",
        "//tensorflow/python/trackable:autotrackable",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/trackable:data_structures",
        "//tensorflow/python/trackable:layer_utils",
        "//tensorflow/python/util:compat",
        "//tensorflow/python/util:nest",
        "//tensorflow/python/util:tf_export",
        "//tensorflow/tools/docs:doc_controls",
        "//third_party/py/numpy",
    ],
)

# data_adapter.py as a standalone target; :engine lists it in its deps.
py_library(
    name = "data_adapter",
    srcs = ["data_adapter.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/framework:tensor",
        "//tensorflow/python/framework:tensor_conversion",
        "//tensorflow/python/keras/utils:dataset_creator",
        "//tensorflow/python/keras/utils:engine_utils",
        "//tensorflow/python/keras/utils:tf_utils",
        "//tensorflow/python/platform:tf_logging",
        "//tensorflow/python/types:data",
        "//tensorflow/python/util:nest",
        "//tensorflow/python/util:tf_export",
    ],
)

# input_spec.py as a standalone target; both :engine and :base_layer list it
# in their deps.
py_library(
    name = "input_spec",
    srcs = ["input_spec.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:tensor_shape",
        "//tensorflow/python/framework:tensor_spec",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/lib/io:lib",
        "//tensorflow/python/util:nest",
        "//tensorflow/python/util:tf_export",
    ],
)

# keras_tensor.py as a standalone target; :engine lists it in its deps.
py_library(
    name = "keras_tensor",
    srcs = ["keras_tensor.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:tensor",
        "//tensorflow/python/keras/utils:object_identity",
        "//tensorflow/python/lib/io:lib",
        "//tensorflow/python/util:nest",
    ],
)

# base_preprocessing_layer.py; builds on :base_layer and is itself listed in
# the deps of :engine.
py_library(
    name = "base_preprocessing_layer",
    srcs = [
        "base_preprocessing_layer.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":base_layer",
        "//tensorflow/python/data",
        "//tensorflow/python/eager:monitoring",
        "//tensorflow/python/framework:tensor",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/module",
    ],
)

# node.py as a standalone target; both :engine and :base_layer list it in
# their deps.
py_library(
    name = "node",
    srcs = ["node.py"],
    srcs_version = "PY3",
    deps = [
        ":base_layer_utils",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:tensor",
        "//tensorflow/python/framework:tensor_util",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras/utils:tf_utils",
        "//tensorflow/python/util:nest",
        "//third_party/py/numpy",
    ],
)
