diff --git a/examples/sagemaker-tensorflow/container/Dockerfile b/examples/sagemaker-tensorflow/container/Dockerfile index ddb3c9995..4be1fbdc3 100644 --- a/examples/sagemaker-tensorflow/container/Dockerfile +++ b/examples/sagemaker-tensorflow/container/Dockerfile @@ -1,4 +1,5 @@ - -FROM nvcr.io/nvidia/merlin/merlin-tensorflow:22.10 +FROM nvcr.io/nvidia/merlin/merlin-tensorflow:23.08 RUN pip3 install sagemaker-training + +COPY --chown=1000:1000 serve /usr/bin/serve diff --git a/examples/sagemaker-tensorflow/container/serve b/examples/sagemaker-tensorflow/container/serve new file mode 100755 index 000000000..887962904 --- /dev/null +++ b/examples/sagemaker-tensorflow/container/serve @@ -0,0 +1,136 @@ +#!/bin/bash +# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +SAGEMAKER_SINGLE_MODEL_REPO=/opt/ml/model/ + +# Use 'ready' for ping check in single-model endpoint mode, and use 'live' for ping check in multi-model endpoint mode +# https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/rest_predict_v2.yaml#L10-L26 +if [ -n "$SAGEMAKER_TRITON_OVERRIDE_PING_MODE" ]; then + SAGEMAKER_TRITON_PING_MODE=${SAGEMAKER_TRITON_OVERRIDE_PING_MODE} +else + SAGEMAKER_TRITON_PING_MODE="ready" +fi + +# Note: in Triton on SageMaker, each model url is registered as a separate repository +# e.g., /opt/ml/models/<model_name>/model.
Specifying MME model repo path as /opt/ml/models causes Triton +# to treat it as an additional empty repository and changes +# the state of all models to be UNAVAILABLE in the model repository +# https://github.com/triton-inference-server/core/blob/main/src/model_repository_manager.cc#L914,L922 +# On Triton, this path will be a dummy path as it's mandatory to specify a model repo when starting triton +SAGEMAKER_MULTI_MODEL_REPO=/tmp/sagemaker + +SAGEMAKER_MODEL_REPO=${SAGEMAKER_SINGLE_MODEL_REPO} +is_mme_mode=false + +if [ -n "$SAGEMAKER_MULTI_MODEL" ]; then + if [ "$SAGEMAKER_MULTI_MODEL" == "true" ]; then + mkdir -p ${SAGEMAKER_MULTI_MODEL_REPO} + SAGEMAKER_MODEL_REPO=${SAGEMAKER_MULTI_MODEL_REPO} + if [ -n "$SAGEMAKER_TRITON_OVERRIDE_PING_MODE" ]; then + SAGEMAKER_TRITON_PING_MODE=${SAGEMAKER_TRITON_OVERRIDE_PING_MODE} + else + SAGEMAKER_TRITON_PING_MODE="live" + fi + is_mme_mode=true + echo -e "Triton is running in SageMaker MME mode. Using Triton ping mode: \"${SAGEMAKER_TRITON_PING_MODE}\"" + fi +fi + +SAGEMAKER_ARGS="--model-repository=${SAGEMAKER_MODEL_REPO}" +#Set model namespacing to true, but allow disabling if required +if [ -n "$SAGEMAKER_TRITON_DISABLE_MODEL_NAMESPACING" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --model-namespacing=${SAGEMAKER_TRITON_DISABLE_MODEL_NAMESPACING}" +else + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --model-namespacing=true" +fi +if [ -n "$SAGEMAKER_BIND_TO_PORT" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --sagemaker-port=${SAGEMAKER_BIND_TO_PORT}" +fi +if [ -n "$SAGEMAKER_SAFE_PORT_RANGE" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --sagemaker-safe-port-range=${SAGEMAKER_SAFE_PORT_RANGE}" +fi +if [ -n "$SAGEMAKER_TRITON_ALLOW_GRPC" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-grpc=${SAGEMAKER_TRITON_ALLOW_GRPC}" +else + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-grpc=false" +fi +if [ -n "$SAGEMAKER_TRITON_ALLOW_METRICS" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-metrics=${SAGEMAKER_TRITON_ALLOW_METRICS}" +else + 
SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --allow-metrics=false" +fi +if [ -n "$SAGEMAKER_TRITON_METRICS_PORT" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --metrics-port=${SAGEMAKER_TRITON_METRICS_PORT}" +fi +if [ -n "$SAGEMAKER_TRITON_GRPC_PORT" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --grpc-port=${SAGEMAKER_TRITON_GRPC_PORT}" +fi +if [ -n "$SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --buffer-manager-thread-count=${SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT}" +fi +if [ -n "$SAGEMAKER_TRITON_THREAD_COUNT" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --sagemaker-thread-count=${SAGEMAKER_TRITON_THREAD_COUNT}" +fi +# Enable verbose logging by default. If env variable is specified, use value from env variable +if [ -n "$SAGEMAKER_TRITON_LOG_VERBOSE" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-verbose=${SAGEMAKER_TRITON_LOG_VERBOSE}" +else + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-verbose=true" +fi +if [ -n "$SAGEMAKER_TRITON_LOG_INFO" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-info=${SAGEMAKER_TRITON_LOG_INFO}" +fi +if [ -n "$SAGEMAKER_TRITON_LOG_WARNING" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-warning=${SAGEMAKER_TRITON_LOG_WARNING}" +fi +if [ -n "$SAGEMAKER_TRITON_LOG_ERROR" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-error=${SAGEMAKER_TRITON_LOG_ERROR}" +fi +if [ -n "$SAGEMAKER_TRITON_SHM_DEFAULT_BYTE_SIZE" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-default-byte-size=${SAGEMAKER_TRITON_SHM_DEFAULT_BYTE_SIZE}" +else + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-default-byte-size=16777216" #16MB +fi +if [ -n "$SAGEMAKER_TRITON_SHM_GROWTH_BYTE_SIZE" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-growth-byte-size=${SAGEMAKER_TRITON_SHM_GROWTH_BYTE_SIZE}" +else + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-growth-byte-size=1048576" #1MB +fi +if [ -n "$SAGEMAKER_TRITON_TENSORFLOW_VERSION" ]; then + 
SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=tensorflow,version=${SAGEMAKER_TRITON_TENSORFLOW_VERSION}" +fi +if [ -n "$SAGEMAKER_TRITON_MODEL_LOAD_GPU_LIMIT" ]; then + num_gpus=$(nvidia-smi -L | wc -l) + for ((i=0; i<${num_gpus}; i++)); do + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --model-load-gpu-limit ${i}:${SAGEMAKER_TRITON_MODEL_LOAD_GPU_LIMIT}" + done +fi +if [ -n "$SAGEMAKER_TRITON_ADDITIONAL_ARGS" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} ${SAGEMAKER_TRITON_ADDITIONAL_ARGS}" +fi + +tritonserver --allow-sagemaker=true --allow-http=false $SAGEMAKER_ARGS diff --git a/examples/sagemaker-tensorflow/sagemaker-merlin-tensorflow.ipynb b/examples/sagemaker-tensorflow/sagemaker-merlin-tensorflow.ipynb index 4118c5cfd..c912183b4 100644 --- a/examples/sagemaker-tensorflow/sagemaker-merlin-tensorflow.ipynb +++ b/examples/sagemaker-tensorflow/sagemaker-merlin-tensorflow.ipynb @@ -7,7 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Copyright (c) 2022, NVIDIA CORPORATION.\n", + "# Copyright (c) 2023, NVIDIA CORPORATION.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -19,8 +19,7 @@ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", - "# limitations under the License.", - "# ======================================================================\n", + "# limitations under the License.# ======================================================================\n", "\n", "# Each user is responsible for checking the content of datasets and the\n", "# applicable licenses and determining if suitable for the intended use." 
@@ -53,7 +52,7 @@ "in this repository or example notebooks in\n", "[Merlin Models](https://github.com/NVIDIA-Merlin/models/tree/stable/examples).\n", "\n", - "To run this notebook, you need to have [Amazon SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/) installed." + "To run this notebook, you need to have [Amazon SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/) installed. If you are *not* running this notebook in the [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container (e.g., in a Sagemaker notebook instance or on Sagemaker Studio), you will also need to install the merlin packages by uncommenting below. You do not need to install them again if you run this notebook in [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container." ] }, { @@ -66,68 +65,103 @@ "name": "stdout", "output_type": "stream", "text": [ - "Collecting sagemaker\n", - " Downloading sagemaker-2.116.0.tar.gz (592 kB)\n", - "\u001b[K |████████████████████████████████| 592 kB 4.4 MB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: attrs<23,>=20.3.0 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (22.1.0)\n", - "Requirement already satisfied: boto3<2.0,>=1.20.21 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (1.25.2)\n", - "Requirement already satisfied: google-pasta in /usr/local/lib/python3.8/dist-packages (from sagemaker) (0.2.0)\n", - "Collecting importlib-metadata<5.0,>=1.4.0\n", - " Downloading importlib_metadata-4.13.0-py3-none-any.whl (23 kB)\n", - "Requirement already satisfied: numpy<2.0,>=1.9.0 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (1.22.4)\n", - "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (21.3)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from sagemaker) 
(1.3.5)\n", - "Collecting pathos\n", - " Downloading pathos-0.3.0-py3-none-any.whl (79 kB)\n", - "\u001b[K |████████████████████████████████| 79 kB 10.5 MB/s eta 0:00:01\n", - "\u001b[?25hCollecting protobuf3-to-dict<1.0,>=0.1.5\n", - " Downloading protobuf3-to-dict-0.1.5.tar.gz (3.5 kB)\n", - "Requirement already satisfied: protobuf<4.0,>=3.1 in /usr/local/lib/python3.8/dist-packages (from sagemaker) (3.19.6)\n", - "Collecting schema\n", - " Downloading schema-0.7.5-py2.py3-none-any.whl (17 kB)\n", - "Collecting smdebug_rulesconfig==1.0.1\n", - " Downloading smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl (20 kB)\n", - "Requirement already satisfied: s3transfer<0.7.0,>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from boto3<2.0,>=1.20.21->sagemaker) (0.6.0)\n", - "Requirement already satisfied: botocore<1.29.0,>=1.28.2 in /usr/local/lib/python3.8/dist-packages (from boto3<2.0,>=1.20.21->sagemaker) (1.28.2)\n", - "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /usr/local/lib/python3.8/dist-packages (from boto3<2.0,>=1.20.21->sagemaker) (1.0.1)\n", - "Requirement already satisfied: six in /usr/lib/python3/dist-packages (from google-pasta->sagemaker) (1.14.0)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata<5.0,>=1.4.0->sagemaker) (3.10.0)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->sagemaker) (3.0.9)\n", - "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->sagemaker) (2022.5)\n", - "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->sagemaker) (2.8.2)\n", - "Collecting pox>=0.3.2\n", - " Downloading pox-0.3.2-py3-none-any.whl (29 kB)\n", - "Collecting dill>=0.3.6\n", - " Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n", - "\u001b[K |████████████████████████████████| 110 kB 17.3 MB/s eta 
0:00:01\n", - "\u001b[?25hCollecting multiprocess>=0.70.14\n", - " Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)\n", - "\u001b[K |████████████████████████████████| 132 kB 17.8 MB/s eta 0:00:01\n", - "\u001b[?25hCollecting ppft>=1.7.6.6\n", - " Downloading ppft-1.7.6.6-py3-none-any.whl (52 kB)\n", - "\u001b[K |████████████████████████████████| 52 kB 2.9 MB/s eta 0:00:01\n", - "\u001b[?25hCollecting contextlib2>=0.5.5\n", - " Downloading contextlib2-21.6.0-py2.py3-none-any.whl (13 kB)\n", - "Requirement already satisfied: urllib3<1.27,>=1.25.4 in /usr/local/lib/python3.8/dist-packages (from botocore<1.29.0,>=1.28.2->boto3<2.0,>=1.20.21->sagemaker) (1.26.12)\n", - "Building wheels for collected packages: sagemaker, protobuf3-to-dict\n", - " Building wheel for sagemaker (setup.py) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for sagemaker: filename=sagemaker-2.116.0-py2.py3-none-any.whl size=809052 sha256=f446dd6eed6d268b7f3f2709f8f11c1ba153e382fbea9b2caedd517c1fb71215\n", - " Stored in directory: /root/.cache/pip/wheels/3e/cb/b1/5b13ff7b150aa151e4a11030a6c41b1e457c31a52ea1ef11b0\n", - " Building wheel for protobuf3-to-dict (setup.py) ... 
\u001b[?25ldone\n", - "\u001b[?25h Created wheel for protobuf3-to-dict: filename=protobuf3_to_dict-0.1.5-py3-none-any.whl size=4029 sha256=8f99baaa875ba544d54f624f95dfbf4fd52ca96d52ce8af6d05c1ff2bb8435b2\n", - " Stored in directory: /root/.cache/pip/wheels/fc/10/27/2d1e23d8b9a9013a83fbb418a0b17b1e6f81c8db8f53b53934\n", - "Successfully built sagemaker protobuf3-to-dict\n", - "Installing collected packages: importlib-metadata, pox, dill, multiprocess, ppft, pathos, protobuf3-to-dict, contextlib2, schema, smdebug-rulesconfig, sagemaker\n", - " Attempting uninstall: importlib-metadata\n", - " Found existing installation: importlib-metadata 5.0.0\n", - " Uninstalling importlib-metadata-5.0.0:\n", - " Successfully uninstalled importlib-metadata-5.0.0\n", - "Successfully installed contextlib2-21.6.0 dill-0.3.6 importlib-metadata-4.13.0 multiprocess-0.70.14 pathos-0.3.0 pox-0.3.2 ppft-1.7.6.6 protobuf3-to-dict-0.1.5 sagemaker-2.116.0 schema-0.7.5 smdebug-rulesconfig-1.0.1\n" + "Requirement already satisfied: sagemaker in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (2.188.0)\n", + "Requirement already satisfied: attrs<24,>=23.1.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (23.1.0)\n", + "Requirement already satisfied: boto3<2.0,>=1.26.131 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (1.28.57)\n", + "Requirement already satisfied: cloudpickle==2.2.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (2.2.1)\n", + "Requirement already satisfied: google-pasta in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (0.2.0)\n", + "Requirement already satisfied: numpy<2.0,>=1.9.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (1.22.3)\n", + "Requirement already satisfied: protobuf<5.0,>=3.12 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from 
sagemaker) (3.20.3)\n", + "Requirement already satisfied: smdebug-rulesconfig==1.0.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (1.0.1)\n", + "Requirement already satisfied: importlib-metadata<7.0,>=1.4.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (6.8.0)\n", + "Requirement already satisfied: packaging>=20.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (21.3)\n", + "Requirement already satisfied: pandas in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (1.5.3)\n", + "Requirement already satisfied: pathos in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (0.3.1)\n", + "Requirement already satisfied: schema in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (0.7.5)\n", + "Requirement already satisfied: PyYAML~=6.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (6.0)\n", + "Requirement already satisfied: jsonschema in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (4.18.4)\n", + "Requirement already satisfied: platformdirs in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (3.9.1)\n", + "Requirement already satisfied: tblib==1.7.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from sagemaker) (1.7.0)\n", + "Requirement already satisfied: botocore<1.32.0,>=1.31.57 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from boto3<2.0,>=1.26.131->sagemaker) (1.31.57)\n", + "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from boto3<2.0,>=1.26.131->sagemaker) (1.0.1)\n", + "Requirement already satisfied: s3transfer<0.8.0,>=0.7.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from 
boto3<2.0,>=1.26.131->sagemaker) (0.7.0)\n", + "Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from importlib-metadata<7.0,>=1.4.0->sagemaker) (3.16.2)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from packaging>=20.0->sagemaker) (3.0.9)\n", + "Requirement already satisfied: six in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from google-pasta->sagemaker) (1.16.0)\n", + "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from jsonschema->sagemaker) (2023.7.1)\n", + "Requirement already satisfied: referencing>=0.28.4 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from jsonschema->sagemaker) (0.30.0)\n", + "Requirement already satisfied: rpds-py>=0.7.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from jsonschema->sagemaker) (0.9.2)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pandas->sagemaker) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pandas->sagemaker) (2023.3)\n", + "Requirement already satisfied: ppft>=1.7.6.7 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pathos->sagemaker) (1.7.6.7)\n", + "Requirement already satisfied: dill>=0.3.7 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pathos->sagemaker) (0.3.7)\n", + "Requirement already satisfied: pox>=0.3.3 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pathos->sagemaker) (0.3.3)\n", + "Requirement already satisfied: multiprocess>=0.70.15 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pathos->sagemaker) 
(0.70.15)\n", + "Requirement already satisfied: contextlib2>=0.5.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from schema->sagemaker) (21.6.0)\n", + "Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from botocore<1.32.0,>=1.31.57->boto3<2.0,>=1.26.131->sagemaker) (1.26.14)\n", + "Requirement already satisfied: merlin-core==23.08 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (23.8.0)\n", + "Requirement already satisfied: merlin-dataloader==23.08 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (23.8.0)\n", + "Requirement already satisfied: nvtabular==23.08 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (23.8.0)\n", + "Requirement already satisfied: merlin-models==23.08 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (23.8.0)\n", + "Requirement already satisfied: merlin-systems==23.08 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (23.8.0)\n", + "Requirement already satisfied: dask>=2022.11.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (2023.9.2)\n", + "Requirement already satisfied: dask-cuda>=22.12.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (23.10.0)\n", + "Requirement already satisfied: distributed>=2022.11.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (2023.9.2)\n", + "Requirement already satisfied: fsspec>=2022.7.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (2023.6.0)\n", + "Requirement already satisfied: numpy>=1.22.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (1.22.3)\n", + "Requirement already satisfied: pandas<1.6.0dev0,>=1.2.0 in 
/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (1.5.3)\n", + "Requirement already satisfied: numba>=0.54 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (0.57.1)\n", + "Requirement already satisfied: pyarrow>=5.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (12.0.1)\n", + "Requirement already satisfied: protobuf>=3.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (3.20.3)\n", + "Requirement already satisfied: tqdm>=4.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (4.65.0)\n", + "Requirement already satisfied: tensorflow-metadata>=1.2.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (1.14.0)\n", + "Requirement already satisfied: betterproto<2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (1.2.5)\n", + "Requirement already satisfied: packaging in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (21.3)\n", + "Requirement already satisfied: npy-append-array in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (0.9.16)\n", + "Requirement already satisfied: pynvml<11.5,>=11.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-core==23.08) (11.4.1)\n", + "Requirement already satisfied: scipy in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from nvtabular==23.08) (1.11.1)\n", + "Requirement already satisfied: requests<3,>=2.10 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-systems==23.08) (2.31.0)\n", + "Requirement already satisfied: treelite==2.4.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-systems==23.08) 
(2.4.0)\n", + "Requirement already satisfied: treelite-runtime==2.4.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from merlin-systems==23.08) (2.4.0)\n", + "Requirement already satisfied: grpclib in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from betterproto<2.0.0->merlin-core==23.08) (0.4.6)\n", + "Requirement already satisfied: stringcase in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from betterproto<2.0.0->merlin-core==23.08) (1.2.0)\n", + "Requirement already satisfied: click>=8.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from dask>=2022.11.1->merlin-core==23.08) (8.1.6)\n", + "Requirement already satisfied: cloudpickle>=1.5.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from dask>=2022.11.1->merlin-core==23.08) (2.2.1)\n", + "Requirement already satisfied: partd>=1.2.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from dask>=2022.11.1->merlin-core==23.08) (1.4.0)\n", + "Requirement already satisfied: pyyaml>=5.3.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from dask>=2022.11.1->merlin-core==23.08) (6.0)\n", + "Requirement already satisfied: toolz>=0.10.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from dask>=2022.11.1->merlin-core==23.08) (0.12.0)\n", + "Requirement already satisfied: importlib-metadata>=4.13.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from dask>=2022.11.1->merlin-core==23.08) (6.8.0)\n", + "Requirement already satisfied: zict>=2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from dask-cuda>=22.12.0->merlin-core==23.08) (3.0.0)\n", + "Requirement already satisfied: jinja2>=2.10.3 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (3.1.2)\n", + "Requirement already satisfied: locket>=1.0.0 in 
/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (1.0.0)\n", + "Requirement already satisfied: msgpack>=1.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (1.0.5)\n", + "Requirement already satisfied: psutil>=5.7.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (5.9.5)\n", + "Requirement already satisfied: sortedcontainers>=2.0.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (2.4.0)\n", + "Requirement already satisfied: tblib>=1.6.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (1.7.0)\n", + "Requirement already satisfied: tornado>=6.0.4 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (6.3.2)\n", + "Requirement already satisfied: urllib3>=1.24.3 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from distributed>=2022.11.1->merlin-core==23.08) (1.26.14)\n", + "Requirement already satisfied: llvmlite<0.41,>=0.40.0dev0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from numba>=0.54->merlin-core==23.08) (0.40.1)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from packaging->merlin-core==23.08) (3.0.9)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pandas<1.6.0dev0,>=1.2.0->merlin-core==23.08) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from pandas<1.6.0dev0,>=1.2.0->merlin-core==23.08) (2023.3)\n", + "Requirement already satisfied: 
charset-normalizer<4,>=2 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from requests<3,>=2.10->merlin-systems==23.08) (3.2.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from requests<3,>=2.10->merlin-systems==23.08) (3.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from requests<3,>=2.10->merlin-systems==23.08) (2023.5.7)\n", + "Requirement already satisfied: absl-py<2.0.0,>=0.9 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from tensorflow-metadata>=1.2.0->merlin-core==23.08) (1.4.0)\n", + "Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from tensorflow-metadata>=1.2.0->merlin-core==23.08) (1.61.0)\n", + "Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from importlib-metadata>=4.13.0->dask>=2022.11.1->merlin-core==23.08) (3.16.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from jinja2>=2.10.3->distributed>=2022.11.1->merlin-core==23.08) (2.1.3)\n", + "Requirement already satisfied: six>=1.5 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas<1.6.0dev0,>=1.2.0->merlin-core==23.08) (1.16.0)\n", + "Requirement already satisfied: h2<5,>=3.1.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from grpclib->betterproto<2.0.0->merlin-core==23.08) (4.1.0)\n", + "Requirement already satisfied: multidict in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from grpclib->betterproto<2.0.0->merlin-core==23.08) (6.0.4)\n", + "Requirement already satisfied: hyperframe<7,>=6.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages 
(from h2<5,>=3.1.0->grpclib->betterproto<2.0.0->merlin-core==23.08) (6.0.1)\n", + "Requirement already satisfied: hpack<5,>=4.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages (from h2<5,>=3.1.0->grpclib->betterproto<2.0.0->merlin-core==23.08) (4.0.0)\n" ] } ], "source": [ - "! python -m pip install sagemaker" + "! python -m pip install sagemaker\n", + "#! python -m pip install merlin-core==23.08 merlin-dataloader==23.08 nvtabular==23.08 merlin-models==23.08 merlin-systems==23.08" ] }, { @@ -160,9 +194,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. 
Please use the atomic versions of these tags, like [, ].\n", + "/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", + "/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/merlin/dtypes/mappings/torch.py:43: UserWarning: PyTorch dtype mappings did not load successfully due to an error: No module named 'torch'\n", + " warn(f\"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}\")\n", + "/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/merlin/io/dataset.py:267: UserWarning: Initializing an NVTabular Dataset in CPU mode.This is an experimental feature with extremely limited support!\n", " warnings.warn(\n" ] } @@ -172,7 +208,7 @@ "\n", "from merlin.datasets.synthetic import generate_data\n", "\n", - "DATA_FOLDER = os.environ.get(\"DATA_FOLDER\", \"/workspace/data/\")\n", + "DATA_FOLDER = os.environ.get(\"DATA_FOLDER\", \"./data/\")\n", "NUM_ROWS = os.environ.get(\"NUM_ROWS\", 1_000_000)\n", "SYNTHETIC_DATA = eval(os.environ.get(\"SYNTHETIC_DATA\", \"True\"))\n", "BATCH_SIZE = int(os.environ.get(\"BATCH_SIZE\", 512))\n", @@ -215,7 +251,7 @@ "source": [ "%%writefile train.py\n", "#\n", - "# Copyright (c) 2022, NVIDIA CORPORATION.\n", + "# Copyright (c) 2023 NVIDIA CORPORATION.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -284,14 +320,15 @@ "\n", "\n", "def create_nvtabular_workflow(train_path, valid_path):\n", - " user_id = [\"user_id\"] >> Categorify() >> TagAsUserID()\n", - " item_id = [\"item_id\"] >> Categorify() >> TagAsItemID()\n", - " targets = [\"click\"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, \"target\"])\n", + "\n", + " 
user_id_raw = [\"user_id\"] >> Rename(postfix='_raw') >> LambdaOp(lambda col: col.astype(\"int32\")) >> TagAsUserFeatures()\n", + " item_id_raw = [\"item_id\"] >> Rename(postfix='_raw') >> LambdaOp(lambda col: col.astype(\"int32\")) >> TagAsItemFeatures()\n", + "\n", + " user_id = [\"user_id\"] >> Categorify(dtype=\"int32\") >> TagAsUserID()\n", + " item_id = [\"item_id\"] >> Categorify(dtype=\"int32\") >> TagAsItemID()\n", "\n", " item_features = (\n", - " [\"item_category\", \"item_shop\", \"item_brand\"]\n", - " >> Categorify()\n", - " >> TagAsItemFeatures()\n", + " [\"item_category\", \"item_shop\", \"item_brand\"] >> Categorify(dtype=\"int32\") >> TagAsItemFeatures()\n", " )\n", "\n", " user_features = (\n", @@ -307,12 +344,15 @@ " \"user_intentions\",\n", " \"user_brands\",\n", " \"user_categories\",\n", - " ]\n", - " >> Categorify()\n", - " >> TagAsUserFeatures()\n", + " ] >> Categorify(dtype=\"int32\") >> TagAsUserFeatures()\n", " )\n", "\n", - " outputs = user_id + item_id + item_features + user_features + targets\n", + " targets = [\"click\"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, \"target\"])\n", + "\n", + " outputs = user_id + item_id + item_features + user_features + user_id_raw + item_id_raw + targets\n", + "\n", + " # add dropna op to filter rows with nulls\n", + " outputs = outputs >> Dropna()\n", "\n", " workflow = nvt.Workflow(outputs)\n", "\n", @@ -440,10 +480,11 @@ ], "source": [ "%%writefile container/Dockerfile\n", + "FROM nvcr.io/nvidia/merlin/merlin-tensorflow:23.08\n", "\n", - "FROM nvcr.io/nvidia/merlin/merlin-tensorflow:22.10\n", + "RUN pip3 install sagemaker-training\n", "\n", - "RUN pip3 install sagemaker-training" + "COPY --chown=1000:1000 serve /usr/bin/serve" ] }, { @@ -549,17 +590,16 @@ "id": "2b62f39e-af41-4aec-857f-0541038d9c5c", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Couldn't call 'get_role' to get Role ARN from role name AWSOS-AD-Engineer to get Role path.\n" - ] 
- }, { "name": "stdout", "output_type": "stream", "text": [ + "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml\n", "arn:aws:iam::843263297212:role/AWSOS-AD-Engineer\n" ] } @@ -673,26 +713,38 @@ "name": "stdout", "output_type": "stream", "text": [ - "2022-11-09 10:18:31 Starting - Starting the training job...\n", - "2022-11-09 10:18:54 Starting - Preparing the instances for trainingProfilerReport-1667989110: InProgress\n", - "......\n", - "2022-11-09 10:19:54 Downloading - Downloading input data...\n", - "2022-11-09 10:20:34 Training - Downloading the training image..................................\u001b[34m==================================\u001b[0m\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml\n", + "Using provided s3_resource\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:sagemaker:Creating training-job with name: sagemaker-merlin-tensorflow-2023-10-26-00-23-35-295\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-10-26 00:23:35 Starting - Starting the training job...\n", + "2023-10-26 00:23:51 Starting - Preparing the instances for training......\n", + "2023-10-26 00:25:02 Downloading - Downloading input data...\n", + "2023-10-26 
00:25:27 Training - Downloading the training image.......................\u001b[34m==================================\u001b[0m\n", "\u001b[34m== Triton Inference Server Base ==\u001b[0m\n", "\u001b[34m==================================\u001b[0m\n", - "\u001b[34mNVIDIA Release 22.08 (build 42766143)\u001b[0m\n", - "\u001b[34mCopyright (c) 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\u001b[0m\n", + "\u001b[34mNVIDIA Release 23.06 (build 62878575)\u001b[0m\n", + "\u001b[34mCopyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\u001b[0m\n", "\u001b[34mVarious files include modifications (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved.\u001b[0m\n", "\u001b[34mThis container image and its contents are governed by the NVIDIA Deep Learning Container License.\u001b[0m\n", "\u001b[34mBy pulling and using the container, you accept the terms and conditions of this license:\u001b[0m\n", "\u001b[34mhttps://developer.nvidia.com/ngc/nvidia-deep-learning-container-license\u001b[0m\n", - "\u001b[34mNOTE: CUDA Forward Compatibility mode ENABLED.\n", - " Using CUDA 11.7 driver version 515.65.01 with kernel driver version 510.47.03.\n", - " See https://docs.nvidia.com/deploy/cuda-compatibility/ for details.\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:03,405 sagemaker-training-toolkit INFO No Neurons detected (normal if no neurons installed)\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:03,438 sagemaker-training-toolkit INFO No Neurons detected (normal if no neurons installed)\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:03,473 sagemaker-training-toolkit INFO No Neurons detected (normal if no neurons installed)\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:03,485 sagemaker-training-toolkit INFO Invoking user script\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:21,913 sagemaker-training-toolkit INFO No Neurons detected (normal if no neurons installed)\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:21,947 sagemaker-training-toolkit INFO No Neurons 
detected (normal if no neurons installed)\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:21,979 sagemaker-training-toolkit INFO No Neurons detected (normal if no neurons installed)\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:21,992 sagemaker-training-toolkit INFO Invoking user script\u001b[0m\n", "\u001b[34mTraining Env:\u001b[0m\n", "\u001b[34m{\n", " \"additional_framework_parameters\": {},\n", @@ -746,11 +798,11 @@ " \"is_master\": true,\n", " \"is_modelparallel_enabled\": null,\n", " \"is_smddpmprun_installed\": false,\n", - " \"job_name\": \"sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376\",\n", + " \"job_name\": \"sagemaker-merlin-tensorflow-2023-10-26-00-23-35-295\",\n", " \"log_level\": 20,\n", " \"master_hostname\": \"algo-1\",\n", " \"model_dir\": \"/opt/ml/model\",\n", - " \"module_dir\": \"s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/source/sourcedir.tar.gz\",\n", + " \"module_dir\": \"s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2023-10-26-00-23-35-295/source/sourcedir.tar.gz\",\n", " \"module_name\": \"train\",\n", " \"network_interface_name\": \"eth0\",\n", " \"num_cpus\": 4,\n", @@ -807,79 +859,68 @@ "\u001b[34mSM_NUM_GPUS=1\u001b[0m\n", "\u001b[34mSM_NUM_NEURONS=0\u001b[0m\n", "\u001b[34mSM_MODEL_DIR=/opt/ml/model\u001b[0m\n", - "\u001b[34mSM_MODULE_DIR=s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/source/sourcedir.tar.gz\u001b[0m\n", - 
"\u001b[34mSM_TRAINING_ENV={\"additional_framework_parameters\":{},\"channel_input_dirs\":{\"train\":\"/opt/ml/input/data/train\",\"valid\":\"/opt/ml/input/data/valid\"},\"current_host\":\"algo-1\",\"current_instance_group\":\"homogeneousCluster\",\"current_instance_group_hosts\":[\"algo-1\"],\"current_instance_type\":\"ml.g4dn.xlarge\",\"distribution_hosts\":[],\"distribution_instance_groups\":[],\"framework_module\":null,\"hosts\":[\"algo-1\"],\"hyperparameters\":{\"batch_size\":1024,\"epoch\":10},\"input_config_dir\":\"/opt/ml/input/config\",\"input_data_config\":{\"train\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"},\"valid\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"}},\"input_dir\":\"/opt/ml/input\",\"instance_groups\":[\"homogeneousCluster\"],\"instance_groups_dict\":{\"homogeneousCluster\":{\"hosts\":[\"algo-1\"],\"instance_group_name\":\"homogeneousCluster\",\"instance_type\":\"ml.g4dn.xlarge\"}},\"is_hetero\":false,\"is_master\":true,\"is_modelparallel_enabled\":null,\"is_smddpmprun_installed\":false,\"job_name\":\"sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376\",\"log_level\":20,\"master_hostname\":\"algo-1\",\"model_dir\":\"/opt/ml/model\",\"module_dir\":\"s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/source/sourcedir.tar.gz\",\"module_name\":\"train\",\"network_interface_name\":\"eth0\",\"num_cpus\":4,\"num_gpus\":1,\"num_neurons\":0,\"output_data_dir\":\"/opt/ml/output/data\",\"output_dir\":\"/opt/ml/output\",\"output_intermediate_dir\":\"/opt/ml/output/intermediate\",\"resource_config\":{\"current_group_name\":\"homogeneousCluster\",\"current_host\":\"algo-1\",\"current_instance_type\":\"ml.g4dn.xlarge\",\"hosts\":[\"algo-1\"],\"instance_groups\":[{\"hosts\":[\"algo-1\"],\"instance_group_name\":\"homogeneousCluster\",\"instance_type\":\"ml.g4dn.xlarge\"}],\"network_interface_name
\":\"eth0\"},\"user_entry_point\":\"train.py\"}\u001b[0m\n", + "\u001b[34mSM_MODULE_DIR=s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2023-10-26-00-23-35-295/source/sourcedir.tar.gz\u001b[0m\n", + "\u001b[34mSM_TRAINING_ENV={\"additional_framework_parameters\":{},\"channel_input_dirs\":{\"train\":\"/opt/ml/input/data/train\",\"valid\":\"/opt/ml/input/data/valid\"},\"current_host\":\"algo-1\",\"current_instance_group\":\"homogeneousCluster\",\"current_instance_group_hosts\":[\"algo-1\"],\"current_instance_type\":\"ml.g4dn.xlarge\",\"distribution_hosts\":[],\"distribution_instance_groups\":[],\"framework_module\":null,\"hosts\":[\"algo-1\"],\"hyperparameters\":{\"batch_size\":1024,\"epoch\":10},\"input_config_dir\":\"/opt/ml/input/config\",\"input_data_config\":{\"train\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"},\"valid\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"}},\"input_dir\":\"/opt/ml/input\",\"instance_groups\":[\"homogeneousCluster\"],\"instance_groups_dict\":{\"homogeneousCluster\":{\"hosts\":[\"algo-1\"],\"instance_group_name\":\"homogeneousCluster\",\"instance_type\":\"ml.g4dn.xlarge\"}},\"is_hetero\":false,\"is_master\":true,\"is_modelparallel_enabled\":null,\"is_smddpmprun_installed\":false,\"job_name\":\"sagemaker-merlin-tensorflow-2023-10-26-00-23-35-295\",\"log_level\":20,\"master_hostname\":\"algo-1\",\"model_dir\":\"/opt/ml/model\",\"module_dir\":\"s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2023-10-26-00-23-35-295/source/sourcedir.tar.gz\",\"module_name\":\"train\",\"network_interface_name\":\"eth0\",\"num_cpus\":4,\"num_gpus\":1,\"num_neurons\":0,\"output_data_dir\":\"/opt/ml/output/data\",\"output_dir\":\"/opt/ml/output\",\"output_intermediate_dir\":\"/opt/ml/output/intermediate\",\"resource_config\":{\"current_group_name\":\"homogeneousCluster\",\"current_host\":\"algo-1\",\"curr
ent_instance_type\":\"ml.g4dn.xlarge\",\"hosts\":[\"algo-1\"],\"instance_groups\":[{\"hosts\":[\"algo-1\"],\"instance_group_name\":\"homogeneousCluster\",\"instance_type\":\"ml.g4dn.xlarge\"}],\"network_interface_name\":\"eth0\"},\"user_entry_point\":\"train.py\"}\u001b[0m\n", "\u001b[34mSM_USER_ARGS=[\"--batch_size\",\"1024\",\"--epoch\",\"10\"]\u001b[0m\n", "\u001b[34mSM_OUTPUT_INTERMEDIATE_DIR=/opt/ml/output/intermediate\u001b[0m\n", "\u001b[34mSM_CHANNEL_TRAIN=/opt/ml/input/data/train\u001b[0m\n", "\u001b[34mSM_CHANNEL_VALID=/opt/ml/input/data/valid\u001b[0m\n", "\u001b[34mSM_HP_BATCH_SIZE=1024\u001b[0m\n", "\u001b[34mSM_HP_EPOCH=10\u001b[0m\n", - "\u001b[34mPYTHONPATH=/opt/ml/code:/usr/local/bin:/opt/tritonserver:/usr/local/lib/python3.8/dist-packages:/usr/lib/python38.zip:/usr/lib/python3.8:/usr/lib/python3.8/lib-dynload:/usr/local/lib/python3.8/dist-packages/faiss-1.7.2-py3.8.egg:/usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg:/usr/local/lib/python3.8/dist-packages/merlin_hps-1.0.0-py3.8-linux-x86_64.egg:/usr/lib/python3/dist-packages\u001b[0m\n", + "\u001b[34mPYTHONPATH=/opt/ml/code:/usr/local/bin:/opt/tritonserver:/usr/local/lib/python3.10/dist-packages:/usr/lib/python310.zip:/usr/lib/python3.10:/usr/lib/python3.10/lib-dynload:/usr/local/lib/python3.10/dist-packages/faiss-1.7.2-py3.10.egg:/ptx:/usr/local/lib/python3.10/dist-packages/merlin_sok-1.2.0-py3.10-linux-x86_64.egg:/usr/local/lib/python3.10/dist-packages/merlin_hps-1.0.0-py3.10-linux-x86_64.egg:/usr/lib/python3/dist-packages:/usr/lib/python3.10/dist-packages\u001b[0m\n", "\u001b[34mInvoking script with the following command:\u001b[0m\n", "\u001b[34m/usr/bin/python3 train.py --batch_size 1024 --epoch 10\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:03,486 sagemaker-training-toolkit INFO Exceptions not imported for SageMaker Debugger as it is not installed.\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:21,993 sagemaker-training-toolkit INFO Exceptions not imported for 
SageMaker Debugger as it is not installed.\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:22.172037: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:22.226772: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\u001b[0m\n", + "\u001b[34mTo enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\u001b[0m\n", "\n", - "2022-11-09 10:27:16 Training - Training image download completed. Training in progress.\u001b[34m2022-11-09 10:27:08.761711: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:12.818302: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:12.819693: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:12.819906: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:12.894084: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\u001b[0m\n", - "\u001b[34mTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:12.895367: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:12.895631: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:12.895807: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 
10:27:16.651703: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:16.651981: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:16.652183: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\u001b[0m\n", - "\u001b[34m2022-11-09 10:27:16.653025: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10752 MB memory: -> device: 0, name: Tesla T4, pci bus id: 0000:00:1e.0, compute capability: 7.5\u001b[0m\n", - "\u001b[34mWorkflow saved to /tmp/tmp5fpdavsc/workflow.\u001b[0m\n", + "2023-10-26 00:29:13 Training - Training image download completed. Training in progress.\u001b[34m/usr/local/lib/python3.10/dist-packages/merlin/dtypes/mappings/torch.py:43: UserWarning: PyTorch dtype mappings did not load successfully due to an error: No module named 'torch'\n", + " warn(f\"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}\")\u001b[0m\n", + "\u001b[34mWARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.\u001b[0m\n", + "\u001b[34mWARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. 
The old module will be deleted in version 2.11.\u001b[0m\n", + "\u001b[34m2023-10-26 00:29:30.531929: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.\u001b[0m\n", + "\u001b[34m[INFO]: sparse_operation_kit is imported\u001b[0m\n", + "\u001b[34m[SOK INFO] Import /usr/local/lib/python3.10/dist-packages/merlin_sok-1.2.0-py3.10-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so\u001b[0m\n", + "\u001b[34m[SOK INFO] Import /usr/local/lib/python3.10/dist-packages/merlin_sok-1.2.0-py3.10-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so\u001b[0m\n", + "\u001b[34m[SOK INFO] Initialize finished, communication tool: horovod\u001b[0m\n", + "\u001b[34mWorkflow saved to /tmp/tmpo0dgd1_j/workflow.\u001b[0m\n", "\u001b[34mbatch_size = 1024, epochs = 10\u001b[0m\n", "\u001b[34mEpoch 1/10\u001b[0m\n", - "\u001b[34m684/684 - 14s - loss: 0.6932 - auc: 0.4998 - regularization_loss: 0.0000e+00 - val_loss: 0.6931 - val_auc: 0.5000 - val_regularization_loss: 0.0000e+00 - 14s/epoch - 20ms/step\u001b[0m\n", + "\u001b[34m684/684 - 15s - loss: 0.6932 - auc: 0.5000 - regularization_loss: 0.0000e+00 - loss_batch: 0.6931 - val_loss: 0.6931 - val_auc: 0.5000 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.6932 - 15s/epoch - 22ms/step\u001b[0m\n", "\u001b[34mEpoch 2/10\u001b[0m\n", - "\u001b[34m684/684 - 8s - loss: 0.6931 - auc: 0.5026 - regularization_loss: 0.0000e+00 - val_loss: 0.6932 - val_auc: 0.4990 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6932 - auc: 0.4992 - regularization_loss: 0.0000e+00 - loss_batch: 0.6932 - val_loss: 0.6932 - val_auc: 0.5007 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.6930 - 8s/epoch - 12ms/step\u001b[0m\n", "\u001b[34mEpoch 3/10\u001b[0m\n", - "\u001b[34m684/684 - 7s - loss: 0.6922 - auc: 0.5222 - 
regularization_loss: 0.0000e+00 - val_loss: 0.6941 - val_auc: 0.4989 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6931 - auc: 0.5043 - regularization_loss: 0.0000e+00 - loss_batch: 0.6930 - val_loss: 0.6932 - val_auc: 0.4992 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.6928 - 8s/epoch - 12ms/step\u001b[0m\n", "\u001b[34mEpoch 4/10\u001b[0m\n", - "\u001b[34m684/684 - 7s - loss: 0.6858 - auc: 0.5509 - regularization_loss: 0.0000e+00 - val_loss: 0.6991 - val_auc: 0.4994 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6916 - auc: 0.5279 - regularization_loss: 0.0000e+00 - loss_batch: 0.6920 - val_loss: 0.6945 - val_auc: 0.4992 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.6923 - 8s/epoch - 11ms/step\u001b[0m\n", "\u001b[34mEpoch 5/10\u001b[0m\n", - "\u001b[34m684/684 - 7s - loss: 0.6790 - auc: 0.5660 - regularization_loss: 0.0000e+00 - val_loss: 0.7052 - val_auc: 0.4993 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6843 - auc: 0.5551 - regularization_loss: 0.0000e+00 - loss_batch: 0.6859 - val_loss: 0.7006 - val_auc: 0.4997 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.6933 - 8s/epoch - 12ms/step\u001b[0m\n", "\u001b[34mEpoch 6/10\u001b[0m\n", - "\u001b[34m684/684 - 8s - loss: 0.6751 - auc: 0.5722 - regularization_loss: 0.0000e+00 - val_loss: 0.7096 - val_auc: 0.4994 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6780 - auc: 0.5685 - regularization_loss: 0.0000e+00 - loss_batch: 0.6825 - val_loss: 0.7065 - val_auc: 0.4995 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.6988 - 8s/epoch - 12ms/step\u001b[0m\n", "\u001b[34mEpoch 7/10\u001b[0m\n", - "\u001b[34m684/684 - 7s - loss: 0.6722 - auc: 0.5755 - regularization_loss: 0.0000e+00 - val_loss: 0.7184 - val_auc: 0.4991 - 
val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6748 - auc: 0.5739 - regularization_loss: 0.0000e+00 - loss_batch: 0.6718 - val_loss: 0.7130 - val_auc: 0.4990 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.7103 - 8s/epoch - 11ms/step\u001b[0m\n", "\u001b[34mEpoch 8/10\u001b[0m\n", - "\u001b[34m684/684 - 7s - loss: 0.6700 - auc: 0.5777 - regularization_loss: 0.0000e+00 - val_loss: 0.7289 - val_auc: 0.4990 - val_regularization_loss: 0.0000e+00 - 7s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6723 - auc: 0.5769 - regularization_loss: 0.0000e+00 - loss_batch: 0.6648 - val_loss: 0.7166 - val_auc: 0.4989 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.7082 - 8s/epoch - 11ms/step\u001b[0m\n", "\u001b[34mEpoch 9/10\u001b[0m\n", - "\u001b[34m684/684 - 8s - loss: 0.6687 - auc: 0.5792 - regularization_loss: 0.0000e+00 - val_loss: 0.7404 - val_auc: 0.4994 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step\u001b[0m\n", + "\u001b[34m684/684 - 8s - loss: 0.6702 - auc: 0.5789 - regularization_loss: 0.0000e+00 - loss_batch: 0.6652 - val_loss: 0.7227 - val_auc: 0.4987 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.7124 - 8s/epoch - 12ms/step\u001b[0m\n", "\u001b[34mEpoch 10/10\u001b[0m\n", - "\u001b[34m684/684 - 8s - loss: 0.6678 - auc: 0.5801 - regularization_loss: 0.0000e+00 - val_loss: 0.7393 - val_auc: 0.4988 - val_regularization_loss: 0.0000e+00 - 8s/epoch - 11ms/step\u001b[0m\n", - "\u001b[34m/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n", - " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\u001b[0m\n", - "\u001b[34m/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. 
Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\u001b[0m\n", - "\u001b[34m/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\u001b[0m\n", - "\u001b[34mWARNING:absl:Found untraced functions such as train_compute_metrics, model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, output_layer_layer_call_fn, output_layer_layer_call_and_return_conditional_losses while saving (showing 5 of 97). These functions will not be directly callable after loading.\u001b[0m\n", - "\u001b[34mINFO:__main__:Model saved to /tmp/tmp5fpdavsc/dlrm.\u001b[0m\n", - "\u001b[34mModel saved to /tmp/tmp5fpdavsc/dlrm.\u001b[0m\n", - "\u001b[34mWARNING:absl:Found untraced functions such as train_compute_metrics, model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, output_layer_layer_call_fn, output_layer_layer_call_and_return_conditional_losses while saving (showing 5 of 97). These functions will not be directly callable after loading.\u001b[0m\n", - "\u001b[34m/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\u001b[0m\n", - "\u001b[34m/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. 
Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\u001b[0m\n", + "\u001b[34m684/684 - 9s - loss: 0.6688 - auc: 0.5802 - regularization_loss: 0.0000e+00 - loss_batch: 0.6631 - val_loss: 0.7403 - val_auc: 0.4993 - val_regularization_loss: 0.0000e+00 - val_loss_batch: 0.7173 - 9s/epoch - 13ms/step\u001b[0m\n", + "\u001b[34mWARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, prepare_list_features_layer_call_fn, prepare_list_features_layer_call_and_return_conditional_losses, output_layer_layer_call_fn while saving (showing 5 of 98). These functions will not be directly callable after loading.\u001b[0m\n", + "\u001b[34mModel saved to /tmp/tmpo0dgd1_j/dlrm.\u001b[0m\n", + "\u001b[34mINFO:__main__:Model saved to /tmp/tmpo0dgd1_j/dlrm.\u001b[0m\n", + "\u001b[34mWARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, prepare_list_features_layer_call_fn, prepare_list_features_layer_call_and_return_conditional_losses, output_layer_layer_call_fn while saving (showing 5 of 98). These functions will not be directly callable after loading.\u001b[0m\n", + "\u001b[34mWARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, prepare_list_features_layer_call_fn, prepare_list_features_layer_call_and_return_conditional_losses, output_layer_layer_call_fn while saving (showing 5 of 98). These functions will not be directly callable after loading.\u001b[0m\n", "\u001b[34mWARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\u001b[0m\n", "\u001b[34mWARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. 
Compile it manually.\u001b[0m\n", - "\u001b[34mWARNING:absl:Found untraced functions such as train_compute_metrics, model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, output_layer_layer_call_fn, output_layer_layer_call_and_return_conditional_losses while saving (showing 5 of 97). These functions will not be directly callable after loading.\u001b[0m\n", "\u001b[34mEnsemble graph saved to /opt/ml/model.\u001b[0m\n", "\u001b[34mINFO:__main__:Ensemble graph saved to /opt/ml/model.\u001b[0m\n", - "\u001b[34m2022-11-09 10:29:21,498 sagemaker-training-toolkit INFO Reporting training SUCCESS\u001b[0m\n", + "\u001b[34m2023-10-26 00:31:48,854 sagemaker-training-toolkit INFO Reporting training SUCCESS\u001b[0m\n", "\n", - "2022-11-09 10:29:41 Uploading - Uploading generated training model\n", - "2022-11-09 10:29:41 Completed - Training job completed\n", - "Training seconds: 589\n", - "Billable seconds: 589\n" + "2023-10-26 00:32:05 Uploading - Uploading generated training model\n", + "2023-10-26 00:32:05 Completed - Training job completed\n", + "Training seconds: 423\n", + "Billable seconds: 423\n" ] } ], @@ -920,7 +961,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2022-11-09-10-18-29-376/output/model.tar.gz\n" + "s3://sagemaker-us-east-1-843263297212/sagemaker-merlin-tensorflow-2023-10-26-00-23-35-295/output/model.tar.gz\n" ] } ], @@ -933,7 +974,26 @@ "execution_count": 13, "id": "c5d6d979-1976-46dd-bf88-669bbad39ede", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml\n" + ] + }, + { + "data": { + "text/plain": [ + "['/tmp/ensemble/model.tar.gz']" + ] + }, + "execution_count": 13, + 
"metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from sagemaker.s3 import S3Downloader as s3down\n", "\n", @@ -950,43 +1010,96 @@ "name": "stdout", "output_type": "stream", "text": [ - "1_predicttensorflow/\n", - "1_predicttensorflow/config.pbtxt\n", - "1_predicttensorflow/1/\n", - "1_predicttensorflow/1/model.savedmodel/\n", - "1_predicttensorflow/1/model.savedmodel/assets/\n", - "1_predicttensorflow/1/model.savedmodel/variables/\n", - "1_predicttensorflow/1/model.savedmodel/variables/variables.index\n", - "1_predicttensorflow/1/model.savedmodel/variables/variables.data-00000-of-00001\n", - "1_predicttensorflow/1/model.savedmodel/saved_model.pb\n", - "1_predicttensorflow/1/model.savedmodel/keras_metadata.pb\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/config.pbtxt\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/variables/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/variables/variables.data-00000-of-00001\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/variables/variables.index\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/saved_model.pb\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/keras_metadata.pb\n", + 
"tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/assets/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/.merlin/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/.merlin/input_schema.json\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/.merlin/output_schema.json\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "1_predicttensorflowtriton/1/model.savedmodel/fingerprint.pb\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/config.pbtxt\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.item_shop.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_consumption_2.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_intentions.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + 
"0_transformworkflowtriton/1/workflow/categories/unique.user_profile.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.item_id.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_shops.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.item_brand.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_id.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_group.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.item_category.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_age.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_gender.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_geography.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_is_occupied.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/categories/unique.user_brands.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + 
"0_transformworkflowtriton/1/workflow/categories/unique.user_categories.parquet\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/workflow.pkl\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/workflow/metadata.json\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "0_transformworkflowtriton/1/model.py\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", "executor_model/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", "executor_model/config.pbtxt\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", "executor_model/1/\n", - "0_transformworkflow/\n", - "0_transformworkflow/config.pbtxt\n", - "0_transformworkflow/1/\n", - "0_transformworkflow/1/model.py\n", - "0_transformworkflow/1/workflow/\n", - "0_transformworkflow/1/workflow/categories/\n", - "0_transformworkflow/1/workflow/categories/unique.user_profile.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_age.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_group.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_intentions.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.item_brand.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_geography.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_is_occupied.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_id.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_gender.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_shops.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.item_category.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_brands.parquet\n", - 
"0_transformworkflow/1/workflow/categories/unique.user_consumption_2.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.item_id.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.item_shop.parquet\n", - "0_transformworkflow/1/workflow/categories/unique.user_categories.parquet\n", - "0_transformworkflow/1/workflow/workflow.pkl\n", - "0_transformworkflow/1/workflow/metadata.json\n" + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "executor_model/1/model.py\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "executor_model/1/ensemble/\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "executor_model/1/ensemble/ensemble.pkl\n", + "tar: Ignoring unknown extended header keyword `LIBARCHIVE.creationtime'\n", + "executor_model/1/ensemble/metadata.json\n" ] } ], @@ -1003,48 +1116,52 @@ "\n", "Although we use the Sagemaker Python SDK to train our model, here we will use `boto3` to launch our inference endpoint as it offers more low-level control than the Python SDK.\n", "\n", - "The model artificat `model.tar.gz` uploaded to S3 from the Sagemaker training job contained three directories: `0_transformworkflow` for the NVTabular workflow, `1_predicttensorflow` for the Tensorflow model, and `executor_model` for the ensemble graph that we can use in Triton.\n", + "The model artificat `model.tar.gz` uploaded to S3 from the Sagemaker training job contained three directories: `0_transformworkflowtriton` for the NVTabular workflow, `1_predicttensorflowtriton` for the Tensorflow model, and `executor_model` for the ensemble graph that we can use in Triton.\n", "\n", "```shell\n", "/tmp/ensemble/\n", - "├── 0_transformworkflow\n", - "│ ├── 1\n", - "│ │ ├── model.py\n", - "│ │ └── workflow\n", - "│ │ ├── categories\n", - "│ │ │ ├── unique.item_brand.parquet\n", - "│ │ │ ├── unique.item_category.parquet\n", - "│ │ │ ├── unique.item_id.parquet\n", - "│ │ │ ├── 
unique.item_shop.parquet\n", - "│ │ │ ├── unique.user_age.parquet\n", - "│ │ │ ├── unique.user_brands.parquet\n", - "│ │ │ ├── unique.user_categories.parquet\n", - "│ │ │ ├── unique.user_consumption_2.parquet\n", - "│ │ │ ├── unique.user_gender.parquet\n", - "│ │ │ ├── unique.user_geography.parquet\n", - "│ │ │ ├── unique.user_group.parquet\n", - "│ │ │ ├── unique.user_id.parquet\n", - "│ │ │ ├── unique.user_intentions.parquet\n", - "│ │ │ ├── unique.user_is_occupied.parquet\n", - "│ │ │ ├── unique.user_profile.parquet\n", - "│ │ │ └── unique.user_shops.parquet\n", - "│ │ ├── metadata.json\n", - "│ │ └── workflow.pkl\n", - "│ └── config.pbtxt\n", - "├── 1_predicttensorflow\n", - "│ ├── 1\n", - "│ │ └── model.savedmodel\n", - "│ │ ├── assets\n", - "│ │ ├── keras_metadata.pb\n", - "│ │ ├── saved_model.pb\n", - "│ │ └── variables\n", - "│ │ ├── variables.data-00000-of-00001\n", - "│ │ └── variables.index\n", - "│ └── config.pbtxt\n", - "├── executor_model\n", - "│ ├── 1\n", - "│ └── config.pbtxt\n", - "└── model.tar.gz\n", + "├── 0_transformworkflowtriton\n", + "│   ├── 1\n", + "│   │   ├── model.py\n", + "│   │   └── workflow\n", + "│   │   ├── categories\n", + "│   │   │   ├── unique.item_brand.parquet\n", + "│   │   │   ├── unique.item_category.parquet\n", + "│   │   │   ├── unique.item_id.parquet\n", + "│   │   │   ├── unique.item_shop.parquet\n", + "│   │   │   ├── unique.user_age.parquet\n", + "│   │   │   ├── unique.user_brands.parquet\n", + "│   │   │   ├── unique.user_categories.parquet\n", + "│   │   │   ├── unique.user_consumption_2.parquet\n", + "│   │   │   ├── unique.user_gender.parquet\n", + "│   │   │   ├── unique.user_geography.parquet\n", + "│   │   │   ├── unique.user_group.parquet\n", + "│   │   │   ├── unique.user_id.parquet\n", + "│   │   │   ├── unique.user_intentions.parquet\n", + "│   │   │   ├── unique.user_is_occupied.parquet\n", + "│   │   │   ├── unique.user_profile.parquet\n", + "│   │   │   └── unique.user_shops.parquet\n", + "│   │   
├── metadata.json\n", + "│   │   └── workflow.pkl\n", + "│   └── config.pbtxt\n", + "├── 1_predicttensorflowtriton\n", + "│   ├── 1\n", + "│   │   └── model.savedmodel\n", + "│   │   ├── assets\n", + "│   │   ├── fingerprint.pb\n", + "│   │   ├── keras_metadata.pb\n", + "│   │   ├── saved_model.pb\n", + "│   │   └── variables\n", + "│   │   ├── variables.data-00000-of-00001\n", + "│   │   └── variables.index\n", + "│   └── config.pbtxt\n", + "└── executor_model\n", + " ├── 1\n", + " │   ├── ensemble\n", + " │   │   ├── ensemble.pkl\n", + " │   │   └── metadata.json\n", + " │   └── model.py\n", + " └── config.pbtxt\n", "```\n", "\n", "We specify that we only want to use `executor_model` in Triton by passing the environment variable `SAGEMAKER_TRITON_DEFAULT_MODEL_NAME`." @@ -1060,7 +1177,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Model Arn: arn:aws:sagemaker:us-east-1:843263297212:model/model-triton-merlin-ensemble-2022-11-09-10-29-57\n" + "Model Arn: arn:aws:sagemaker:us-east-1:843263297212:model/model-triton-merlin-ensemble-2023-10-26-00-32-25\n" ] } ], @@ -1111,12 +1228,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "Endpoint Config Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint-config/endpoint-config-triton-merlin-ensemble-2022-11-09-10-29-58\n" + "Endpoint Config Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint-config/endpoint-config-triton-merlin-ensemble-2023-10-26-00-32-26\n" ] } ], "source": [ - "endpoint_instance_type = \"ml.g4dn.xlarge\"\n", + "endpoint_instance_type = \"ml.g4dn.2xlarge\"\n", "\n", "endpoint_config_name = \"endpoint-config-triton-merlin-ensemble-\" + time.strftime(\n", " \"%Y-%m-%d-%H-%M-%S\", time.gmtime()\n", @@ -1150,7 +1267,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Endpoint Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint/endpoint-triton-merlin-ensemble-2022-11-09-10-29-58\n" + "Endpoint Arn: 
arn:aws:sagemaker:us-east-1:843263297212:endpoint/endpoint-triton-merlin-ensemble-2023-10-26-00-32-27\n" ] } ], @@ -1185,9 +1302,8 @@ "Endpoint Creation Status: Creating\n", "Endpoint Creation Status: Creating\n", "Endpoint Creation Status: Creating\n", - "Endpoint Creation Status: Creating\n", "Endpoint Creation Status: InService\n", - "Endpoint Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint/endpoint-triton-merlin-ensemble-2022-11-09-10-29-58\n", + "Endpoint Arn: arn:aws:sagemaker:us-east-1:843263297212:endpoint/endpoint-triton-merlin-ensemble-2023-10-26-00-32-27\n", "Endpoint Status: InService\n" ] } @@ -1229,81 +1345,91 @@ "name": "stdout", "output_type": "stream", "text": [ - " user_id item_id item_category item_shop item_brand \\\n", - "__null_dask_index__ \n", - "700000 12 2 3 194 67 \n", - "700001 12 30 80 5621 1936 \n", - "700002 18 5 12 776 267 \n", - "700003 35 6 14 970 334 \n", - "700004 51 11 28 1939 668 \n", - "700005 22 83 226 15893 5474 \n", - "700006 13 38 102 7172 2470 \n", - "700007 10 7 17 1163 401 \n", - "700008 4 4 9 582 201 \n", - "700009 4 24 64 4458 1536 \n", + " item_category item_shop item_brand user_shops \\\n", + "__null_dask_index__ \n", + "700000 147 10342 3562 235 \n", + "700001 84 5876 2024 141 \n", + "700002 147 10342 3562 1361 \n", + "700003 244 17158 5909 1033 \n", + "700004 134 9402 3238 71 \n", + "700005 64 4466 1538 681 \n", + "700006 24 1646 567 165 \n", + "700007 14 941 324 188 \n", + "700008 14 941 324 1596 \n", + "700009 64 4466 1538 962 \n", "\n", - " user_shops user_profile user_group user_gender \\\n", - "__null_dask_index__ \n", - "700000 636 1 1 1 \n", - "700001 636 1 1 1 \n", - "700002 983 1 1 1 \n", - "700003 1965 2 1 1 \n", - "700004 2890 3 1 1 \n", - "700005 1214 2 1 1 \n", - "700006 694 1 1 1 \n", - "700007 521 1 1 1 \n", - "700008 174 1 1 1 \n", - "700009 174 1 1 1 \n", + " user_profile user_group user_gender user_age \\\n", + "__null_dask_index__ \n", + "700000 1 1 1 1 \n", + "700001 1 1 1 1 \n", + 
"700002 2 1 1 1 \n", + "700003 1 1 1 1 \n", + "700004 1 1 1 1 \n", + "700005 1 1 1 1 \n", + "700006 1 1 1 1 \n", + "700007 1 1 1 1 \n", + "700008 2 1 1 1 \n", + "700009 1 1 1 1 \n", "\n", - " user_age user_consumption_2 user_is_occupied \\\n", - "__null_dask_index__ \n", - "700000 1 1 1 \n", - "700001 1 1 1 \n", - "700002 1 1 1 \n", - "700003 1 1 1 \n", - "700004 1 1 1 \n", - "700005 1 1 1 \n", - "700006 1 1 1 \n", - "700007 1 1 1 \n", - "700008 1 1 1 \n", - "700009 1 1 1 \n", + " user_consumption_2 user_is_occupied user_geography \\\n", + "__null_dask_index__ \n", + "700000 1 1 1 \n", + "700001 1 1 1 \n", + "700002 1 1 1 \n", + "700003 1 1 1 \n", + "700004 1 1 1 \n", + "700005 1 1 1 \n", + "700006 1 1 1 \n", + "700007 1 1 1 \n", + "700008 1 1 1 \n", + "700009 1 1 1 \n", "\n", - " user_geography user_intentions user_brands \\\n", - "__null_dask_index__ \n", - "700000 1 184 316 \n", - "700001 1 184 316 \n", - "700002 1 285 489 \n", - "700003 1 569 977 \n", - "700004 1 837 1436 \n", - "700005 1 352 604 \n", - "700006 1 201 345 \n", - "700007 1 151 259 \n", - "700008 1 51 87 \n", - "700009 1 51 87 \n", + " user_intentions user_brands user_categories user_id \\\n", + "__null_dask_index__ \n", + "700000 68 117 13 11 \n", + "700001 41 70 8 7 \n", + "700002 394 677 71 59 \n", + "700003 299 513 54 45 \n", + "700004 21 35 4 4 \n", + "700005 197 339 36 30 \n", + "700006 48 82 9 8 \n", + "700007 55 94 10 9 \n", + "700008 462 793 84 69 \n", + "700009 279 479 51 42 \n", "\n", - " user_categories \n", - "__null_dask_index__ \n", - "700000 34 \n", - "700001 34 \n", - "700002 52 \n", - "700003 103 \n", - "700004 151 \n", - "700005 64 \n", - "700006 37 \n", - "700007 28 \n", - "700008 10 \n", - "700009 10 \n" + " item_id \n", + "__null_dask_index__ \n", + "700000 45 \n", + "700001 26 \n", + "700002 45 \n", + "700003 74 \n", + "700004 41 \n", + "700005 20 \n", + "700006 8 \n", + "700007 5 \n", + "700008 5 \n", + "700009 20 \n" + ] + }, + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/nvtabular/workflow/workflow.py:445: UserWarning: Loading workflow generated on GPU\n", + " warnings.warn(f\"Loading workflow generated on {expected}\")\n" ] } ], "source": [ "from merlin.schema.tags import Tags\n", "from merlin.core.dispatch import get_lib\n", + "from merlin.systems.dag.ensemble import Ensemble\n", "from nvtabular.workflow import Workflow\n", "\n", "df_lib = get_lib()\n", "\n", - "workflow = Workflow.load(\"/tmp/ensemble/0_transformworkflow/1/workflow/\")\n", + "workflow = Workflow.load(\"/tmp/ensemble/0_transformworkflowtriton/1/workflow/\")\n", + "ensemble = Ensemble.load(\"/tmp/ensemble/executor_model/1/ensemble\")\n", "\n", "label_columns = workflow.output_schema.select_by_tag(Tags.TARGET).column_names\n", "workflow.remove_inputs(label_columns)\n", @@ -1312,7 +1438,7 @@ "batch = df_lib.read_parquet(\n", " os.path.join(DATA_FOLDER, \"valid\", \"part.0.parquet\"),\n", " columns=workflow.input_schema.column_names,\n", - ")[:10]\n", + ").head(10)\n", "print(batch)" ] }, @@ -1334,7 +1460,8 @@ "name": "stdout", "output_type": "stream", "text": [ - 
"b'{\"inputs\":[{\"name\":\"user_id\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"item_id\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"item_category\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"item_shop\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"item_brand\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_shops\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_profile\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_group\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_gender\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_age\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_consumption_2\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_is_occupied\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_geography\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_intentions\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_brands\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_categories\",\"shape\":[10,1],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}}],\"parameters\":{\"binary_data_output\":true}}\\x0c\\x00\\x00\\x00\\x0c\\x00\\x00\\x00\\x12\\x00\\x00\\x00#\\x00\\x00\\x003\\x00\\x00\\x00\\x16\\x00\\x00\\x00\\r\\x00\\x00\\x00\\n\\x00\\x00\\x00\\x04\\x00\\x00\\x00\\x04\\x00\\x00\\x00\\x02\\x00\\
x00\\x00\\x1e\\x00\\x00\\x00\\x05\\x00\\x00\\x00\\x06\\x00\\x00\\x00\\x0b\\x00\\x00\\x00S\\x00\\x00\\x00&\\x00\\x00\\x00\\x07\\x00\\x00\\x00\\x04\\x00\\x00\\x00\\x18\\x00\\x00\\x00\\x03\\x00\\x00\\x00P\\x00\\x00\\x00\\x0c\\x00\\x00\\x00\\x0e\\x00\\x00\\x00\\x1c\\x00\\x00\\x00\\xe2\\x00\\x00\\x00f\\x00\\x00\\x00\\x11\\x00\\x00\\x00\\t\\x00\\x00\\x00@\\x00\\x00\\x00\\xc2\\x00\\x00\\x00\\xf5\\x15\\x00\\x00\\x08\\x03\\x00\\x00\\xca\\x03\\x00\\x00\\x93\\x07\\x00\\x00\\x15>\\x00\\x00\\x04\\x1c\\x00\\x00\\x8b\\x04\\x00\\x00F\\x02\\x00\\x00j\\x11\\x00\\x00C\\x00\\x00\\x00\\x90\\x07\\x00\\x00\\x0b\\x01\\x00\\x00N\\x01\\x00\\x00\\x9c\\x02\\x00\\x00b\\x15\\x00\\x00\\xa6\\t\\x00\\x00\\x91\\x01\\x00\\x00\\xc9\\x00\\x00\\x00\\x00\\x06\\x00\\x00|\\x02\\x00\\x00|\\x02\\x00\\x00\\xd7\\x03\\x00\\x00\\xad\\x07\\x00\\x00J\\x0b\\x00\\x00\\xbe\\x04\\x00\\x00\\xb6\\x02\\x00\\x00\\t\\x02\\x00\\x00\\xae\\x00\\x00\\x00\\xae\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x03\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x0
0\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\xb8\\x00\\x00\\x00\\xb8\\x00\\x00\\x00\\x1d\\x01\\x00\\x009\\x02\\x00\\x00E\\x03\\x00\\x00`\\x01\\x00\\x00\\xc9\\x00\\x00\\x00\\x97\\x00\\x00\\x003\\x00\\x00\\x003\\x00\\x00\\x00<\\x01\\x00\\x00<\\x01\\x00\\x00\\xe9\\x01\\x00\\x00\\xd1\\x03\\x00\\x00\\x9c\\x05\\x00\\x00\\\\\\x02\\x00\\x00Y\\x01\\x00\\x00\\x03\\x01\\x00\\x00W\\x00\\x00\\x00W\\x00\\x00\\x00\"\\x00\\x00\\x00\"\\x00\\x00\\x004\\x00\\x00\\x00g\\x00\\x00\\x00\\x97\\x00\\x00\\x00@\\x00\\x00\\x00%\\x00\\x00\\x00\\x1c\\x00\\x00\\x00\\n\\x00\\x00\\x00\\n\\x00\\x00\\x00'\n" + "b'{\"inputs\":[{\"name\":\"item_category\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"item_shop\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"item_brand\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_shops\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_profile\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_group\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_gender\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_age\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_consumption_2\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_is_occupied\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_geography\",\"shape\":[10],\"datat
ype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_intentions\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_brands\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_categories\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"user_id\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}},{\"name\":\"item_id\",\"shape\":[10],\"datatype\":\"INT32\",\"parameters\":{\"binary_data_size\":40}}],\"parameters\":{\"binary_data_output\":true}}\\x93\\x00\\x00\\x00T\\x00\\x00\\x00\\x93\\x00\\x00\\x00\\xf4\\x00\\x00\\x00\\x86\\x00\\x00\\x00@\\x00\\x00\\x00\\x18\\x00\\x00\\x00\\x0e\\x00\\x00\\x00\\x0e\\x00\\x00\\x00@\\x00\\x00\\x00f(\\x00\\x00\\xf4\\x16\\x00\\x00f(\\x00\\x00\\x06C\\x00\\x00\\xba$\\x00\\x00r\\x11\\x00\\x00n\\x06\\x00\\x00\\xad\\x03\\x00\\x00\\xad\\x03\\x00\\x00r\\x11\\x00\\x00\\xea\\r\\x00\\x00\\xe8\\x07\\x00\\x00\\xea\\r\\x00\\x00\\x15\\x17\\x00\\x00\\xa6\\x0c\\x00\\x00\\x02\\x06\\x00\\x007\\x02\\x00\\x00D\\x01\\x00\\x00D\\x01\\x00\\x00\\x02\\x06\\x00\\x00\\xeb\\x00\\x00\\x00\\x8d\\x00\\x00\\x00Q\\x05\\x00\\x00\\t\\x04\\x00\\x00G\\x00\\x00\\x00\\xa9\\x02\\x00\\x00\\xa5\\x00\\x00\\x00\\xbc\\x00\\x00\\x00<\\x06\\x00\\x00\\xc2\\x03\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x0
1\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00D\\x00\\x00\\x00)\\x00\\x00\\x00\\x8a\\x01\\x00\\x00+\\x01\\x00\\x00\\x15\\x00\\x00\\x00\\xc5\\x00\\x00\\x000\\x00\\x00\\x007\\x00\\x00\\x00\\xce\\x01\\x00\\x00\\x17\\x01\\x00\\x00u\\x00\\x00\\x00F\\x00\\x00\\x00\\xa5\\x02\\x00\\x00\\x01\\x02\\x00\\x00#\\x00\\x00\\x00S\\x01\\x00\\x00R\\x00\\x00\\x00^\\x00\\x00\\x00\\x19\\x03\\x00\\x00\\xdf\\x01\\x00\\x00\\r\\x00\\x00\\x00\\x08\\x00\\x00\\x00G\\x00\\x00\\x006\\x00\\x00\\x00\\x04\\x00\\x00\\x00$\\x00\\x00\\x00\\t\\x00\\x00\\x00\\n\\x00\\x00\\x00T\\x00\\x00\\x003\\x00\\x00\\x00\\x0b\\x00\\x00\\x00\\x07\\x00\\x00\\x00;\\x00\\x00\\x00-\\x00\\x00\\x00\\x04\\x00\\x00\\x00\\x1e\\x00\\x00\\x00\\x08\\x00\\x00\\x00\\t\\x00\\x00\\x00E\\x00\\x00\\x00*\\x00\\x00\\x00-\\x00\\x00\\x00\\x1a\\x00\\x00\\x00-\\x00\\x00\\x00J\\x00\\x00\\x00)\\x00\\x00\\x00\\x14\\x00\\x00\\x00\\x08\\x00\\x00\\x00\\x05\\x00\\x00\\x00\\x05\\x00\\x00\\x00\\x14\\x00\\x00\\x00'\n", + "1535\n" ] } ], @@ -1344,11 +1471,17 @@ "\n", "inputs = convert_df_to_triton_input(workflow.input_schema, batch, httpclient.InferInput)\n", "\n", - "request_body, header_length = httpclient.InferenceServerClient.generate_request_body(\n", - " inputs\n", - ")\n", + "output_cols = ensemble.graph.output_schema.column_names\n", "\n", - 
"print(request_body)" + "outputs = [\n", + " httpclient.InferRequestedOutput(col, binary_data=False)\n", + " for col in output_cols\n", + "]\n", + "\n", + "request_body, header_length = httpclient.InferenceServerClient.generate_request_body(inputs) #, outputs=outputs)\n", + "\n", + "print(request_body)\n", + "print(header_length)" ] }, { @@ -1373,16 +1506,16 @@ "output_type": "stream", "text": [ "predicted sigmoid result:\n", - " [[0.48595208]\n", - " [0.4647554 ]\n", - " [0.50048226]\n", - " [0.53553176]\n", - " [0.5209902 ]\n", - " [0.54944164]\n", - " [0.5032344 ]\n", - " [0.475241 ]\n", - " [0.5077254 ]\n", - " [0.5009623 ]]\n" + " [[0.48622972]\n", + " [0.52633965]\n", + " [0.46211788]\n", + " [0.6060424 ]\n", + " [0.46786073]\n", + " [0.47899282]\n", + " [0.48058835]\n", + " [0.5024603 ]\n", + " [0.565173 ]\n", + " [0.46900135]]\n" ] } ], @@ -1403,6 +1536,7 @@ "result = httpclient.InferenceServerClient.parse_response_body(\n", " response[\"Body\"].read(), header_length=int(header_length_str)\n", ")\n", + "\n", "output_data = result.as_numpy(\"click/binary_classification_task\")\n", "print(\"predicted sigmoid result:\\n\", output_data)" ] @@ -1426,12 +1560,12 @@ { "data": { "text/plain": [ - "{'ResponseMetadata': {'RequestId': '6ad24616-5c7c-4525-a63c-62d1b06ee8ad',\n", + "{'ResponseMetadata': {'RequestId': 'c12a4eca-a162-48cd-9787-cf5f3525bc7e',\n", " 'HTTPStatusCode': 200,\n", - " 'HTTPHeaders': {'x-amzn-requestid': '6ad24616-5c7c-4525-a63c-62d1b06ee8ad',\n", + " 'HTTPHeaders': {'x-amzn-requestid': 'c12a4eca-a162-48cd-9787-cf5f3525bc7e',\n", " 'content-type': 'application/x-amz-json-1.1',\n", " 'content-length': '0',\n", - " 'date': 'Wed, 09 Nov 2022 10:38:12 GMT'},\n", + " 'date': 'Thu, 26 Oct 2023 00:39:31 GMT'},\n", " 'RetryAttempts': 0}}" ] }, @@ -1449,9 +1583,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "conda_python3", "language": "python", - "name": "python3" + "name": "conda_python3" }, 
"language_info": { "codemirror_mode": { @@ -1463,7 +1597,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/sagemaker-tensorflow/train.py b/examples/sagemaker-tensorflow/train.py index a8e2dabb6..d414a1644 100644 --- a/examples/sagemaker-tensorflow/train.py +++ b/examples/sagemaker-tensorflow/train.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2023 NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -68,14 +68,15 @@ def parse_args(): def create_nvtabular_workflow(train_path, valid_path): - user_id = ["user_id"] >> Categorify() >> TagAsUserID() - item_id = ["item_id"] >> Categorify() >> TagAsItemID() - targets = ["click"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"]) + + user_id_raw = ["user_id"] >> Rename(postfix='_raw') >> LambdaOp(lambda col: col.astype("int32")) >> TagAsUserFeatures() + item_id_raw = ["item_id"] >> Rename(postfix='_raw') >> LambdaOp(lambda col: col.astype("int32")) >> TagAsItemFeatures() + + user_id = ["user_id"] >> Categorify(dtype="int32") >> TagAsUserID() + item_id = ["item_id"] >> Categorify(dtype="int32") >> TagAsItemID() item_features = ( - ["item_category", "item_shop", "item_brand"] - >> Categorify() - >> TagAsItemFeatures() + ["item_category", "item_shop", "item_brand"] >> Categorify(dtype="int32") >> TagAsItemFeatures() ) user_features = ( @@ -91,12 +92,15 @@ def create_nvtabular_workflow(train_path, valid_path): "user_intentions", "user_brands", "user_categories", - ] - >> Categorify() - >> TagAsUserFeatures() + ] >> Categorify(dtype="int32") >> TagAsUserFeatures() ) - outputs = user_id + item_id + item_features + user_features + targets + targets = ["click"] >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"]) + + outputs = user_id + item_id + 
item_features + user_features + user_id_raw + item_id_raw + targets + + # add dropna op to filter rows with nulls + outputs = outputs >> Dropna() workflow = nvt.Workflow(outputs)