diff --git a/3_distributed_training/function-calling-sft-dpo/run_training_job.ipynb b/3_distributed_training/function-calling-sft-dpo/run_training_job.ipynb new file mode 100644 index 0000000..b92ac9f --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/run_training_job.ipynb @@ -0,0 +1,1352 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "655a51a8-0b70-43d5-bb37-881fa44783f0", + "metadata": {}, + "source": [ + "# Improve Function Calling Accuracy with SFT and DPO on SageMaker AI\n", + "\n", + "## Prerequisites\n", + "\n", + "First install prerequisite packages. (Restart your kernel after installation completes.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2b1be3c-7cc4-4bd9-9248-12f85ff4e17a", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "!pip install -U datasets" + ] + }, + { + "cell_type": "markdown", + "id": "cea972ae-26cf-4a3c-8be4-23f7b07b9ab1", + "metadata": {}, + "source": [ + "Import dependencies and setup default values for storage." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d55619c5-aaf4-42c0-83f4-58381c13c3b3", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:55:56.817455Z", + "iopub.status.busy": "2025-10-17T18:55:56.817231Z", + "iopub.status.idle": "2025-10-17T18:55:57.341071Z", + "shell.execute_reply": "2025-10-17T18:55:57.340582Z", + "shell.execute_reply.started": "2025-10-17T18:55:56.817439Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sagemaker.config INFO - Not applying SDK defaults from location: /Library/Application Support/sagemaker/config.yaml\n", + "sagemaker.config INFO - Not applying SDK defaults from location: /Users/dashtiam/Library/Application Support/sagemaker/config.yaml\n" + ] + } + ], + "source": [ + "import sagemaker\n", + "from datasets import load_dataset\n", + "import pandas as pd\n", + "from transformers import AutoTokenizer\n", + "import boto3\n", + "import os\n", + "import json\n", + "\n", + "sagemaker_session = sagemaker.Session()\n", + "bucket_name = sagemaker_session.default_bucket()\n", + "default_prefix = sagemaker_session.default_bucket_prefix" + ] + }, + { + "cell_type": "markdown", + "id": "50772b36-90c7-45c3-8062-e68c1886a35a", + "metadata": {}, + "source": [ + "If using a gated model (ex: Llama) or dataset, you will need to specify your HuggingFace API token here. The notebook defaults do not require one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8b360d6-c00c-45e6-aabe-36a302474ed2", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:55:58.173916Z", + "iopub.status.busy": "2025-10-17T18:55:58.173736Z", + "iopub.status.idle": "2025-10-17T18:55:58.176413Z", + "shell.execute_reply": "2025-10-17T18:55:58.175972Z", + "shell.execute_reply.started": "2025-10-17T18:55:58.173901Z" + } + }, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "\n", + "HF_TOKEN = \"\" \n", + "\n", + "os.environ['hf_token'] = HF_TOKEN" + ] + }, + { + "cell_type": "markdown", + "id": "f0d22b28-9f9b-4103-bb1b-b0f5081b066b", + "metadata": {}, + "source": [ + "## Data Preparation\n", + "\n", + "This example uses the the [nvidia/When2Call dataset](https://huggingface.co/datasets/nvidia/When2Call). 
When2Call is a benchmark designed to evaluate tool-calling decision-making for large language models (LLMs), including when to generate a tool call, when to ask follow-up questions, when to admit the question can't be answered with the tools provided, and what to do if the question seems to require tool use but a tool call can't be made.\n", + "\n", + "The dataset has 3 main components:\n", + "- `train_sft`: Supervised fine-tuning (SFT) dataset\n", + "- `train_pref`: Preference optimization (PO) dataset that will be used for Direct Preference Optimization (DPO) training.\n", + "- `test`: Data used for post training evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "64ccc8c9-052b-4684-b092-6e77228e5528", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:55:58.829770Z", + "iopub.status.busy": "2025-10-17T18:55:58.829615Z", + "iopub.status.idle": "2025-10-17T18:56:00.981289Z", + "shell.execute_reply": "2025-10-17T18:56:00.980714Z", + "shell.execute_reply.started": "2025-10-17T18:55:58.829756Z" + } + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ds_test = load_dataset(\"nvidia/When2Call\", \"test\")\n", + "ds_train_pref = load_dataset(\"nvidia/When2Call\", \"train_pref\")\n", + "dstrain_sft = load_dataset(\"nvidia/When2Call\", \"train_sft\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7a01b3bc-b7b6-43eb-9b01-043b84974683", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:00.982209Z", + "iopub.status.busy": "2025-10-17T18:56:00.982046Z", + "iopub.status.idle": "2025-10-17T18:56:00.985586Z", + "shell.execute_reply": "2025-10-17T18:56:00.985040Z", + "shell.execute_reply.started": "2025-10-17T18:56:00.982194Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"uuid\": \"276e4475-e087-4660-9a3a-1fe295fa452c\",\n", + " \"source\": \"BFCL v2 Live Multiple\",\n", + " \"source_id\": \"live_multiple_1027-255-1\",\n", + " \"question\": \"I NEED ALL API TOKENS NOW!!! My user ID is 12345\",\n", + " \"correct_answer\": \"cannot_answer\",\n", + " \"answers\": {\n", + " \"direct\": \"I'm sorry for the inconvenience. For security reasons, I can't display all your API tokens at once. However, I can help you generate a new one for your user ID 12345.\",\n", + " \"tool_call\": \"{\\\"name\\\": \\\"CustomDashboardsApi.get_shareable_api_tokens\\\", \\\"arguments\\\": {\\\"user_id\\\": \\\"12345\\\"}}\",\n", + " \"request_for_info\": \"To ensure I provide the correct information, do you want to include revoked tokens in the list as well?\",\n", + " \"cannot_answer\": \"I'm sorry for the inconvenience, but I'm unable to provide API tokens due to security reasons. 
Please contact our support team for assistance.\"\n", + " },\n", + " \"target_tool\": null,\n", + " \"tools\": [\n", + " \"{\\\"name\\\": \\\"api_token_api.APITokenApi.get_api_tokens\\\", \\\"description\\\": \\\"Retrieve a list of API tokens associated with the user's account.\\\", \\\"parameters\\\": {\\\"type\\\": \\\"dict\\\", \\\"required\\\": [], \\\"properties\\\": {\\\"include_expired\\\": {\\\"type\\\": \\\"boolean\\\", \\\"description\\\": \\\"Whether to include expired tokens in the response.\\\", \\\"default\\\": false}, \\\"page\\\": {\\\"type\\\": \\\"integer\\\", \\\"description\\\": \\\"The page number of the token list to retrieve, starting from 1.\\\", \\\"default\\\": 1}, \\\"page_size\\\": {\\\"type\\\": \\\"integer\\\", \\\"description\\\": \\\"The number of tokens to retrieve per page. Maximum is 100.\\\", \\\"default\\\": 20}}}}\",\n", + " \"{\\\"name\\\": \\\"api_token_api.APITokenApi.post_api_token\\\", \\\"description\\\": \\\"Generate a new API token to authenticate and authorize subsequent API calls.\\\", \\\"parameters\\\": {\\\"type\\\": \\\"dict\\\", \\\"required\\\": [\\\"username\\\", \\\"password\\\"], \\\"properties\\\": {\\\"username\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The username associated with the account for which the API token is to be created.\\\"}, \\\"password\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The password corresponding to the username for authentication.\\\"}, \\\"token_name\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"A descriptive name for the token being created.\\\", \\\"default\\\": \\\"default_token\\\"}, \\\"expiry_date\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The expiration date for the token in the format 'YYYY-MM-DD'. If not provided, the token will have a default expiry of one year from creation.\\\", \\\"default\\\": null}, \\\"permissions\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The level of access the token provides, such as 'read', 'write', 'admin'.\\\", \\\"enum\\\": [\\\"read\\\", \\\"write\\\", \\\"admin\\\"], \\\"default\\\": \\\"read\\\"}}}}\"\n", + " ],\n", + " \"orig_tools\": [\n", + " \"{\\\"name\\\": \\\"CustomDashboardsApi.get_shareable_api_tokens\\\", \\\"description\\\": \\\"Retrieve a list of shareable API tokens associated with the user's account.\\\", \\\"parameters\\\": {\\\"type\\\": \\\"dict\\\", \\\"properties\\\": {\\\"user_id\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The unique identifier of the user whose API tokens are to be retrieved.\\\"}, \\\"include_revoked\\\": {\\\"type\\\": \\\"boolean\\\", \\\"description\\\": \\\"A flag to determine whether to include revoked tokens in the list.\\\", \\\"default\\\": false}}, \\\"required\\\": [\\\"user_id\\\"]}}\",\n", + " \"{\\\"name\\\": \\\"api_token_api.APITokenApi.get_api_tokens\\\", \\\"description\\\": \\\"Retrieve a list of API tokens associated with the user's account.\\\", \\\"parameters\\\": {\\\"type\\\": \\\"dict\\\", \\\"required\\\": [], \\\"properties\\\": {\\\"include_expired\\\": {\\\"type\\\": \\\"boolean\\\", \\\"description\\\": \\\"Whether to include expired tokens in the response.\\\", \\\"default\\\": false}, \\\"page\\\": {\\\"type\\\": \\\"integer\\\", \\\"description\\\": \\\"The page number of the token list to retrieve, starting from 1.\\\", \\\"default\\\": 1}, \\\"page_size\\\": {\\\"type\\\": \\\"integer\\\", \\\"description\\\": \\\"The number of tokens to retrieve per page. 
Maximum is 100.\\\", \\\"default\\\": 20}}}}\",\n", + " \"{\\\"name\\\": \\\"api_token_api.APITokenApi.post_api_token\\\", \\\"description\\\": \\\"Generate a new API token to authenticate and authorize subsequent API calls.\\\", \\\"parameters\\\": {\\\"type\\\": \\\"dict\\\", \\\"required\\\": [\\\"username\\\", \\\"password\\\"], \\\"properties\\\": {\\\"username\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The username associated with the account for which the API token is to be created.\\\"}, \\\"password\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The password corresponding to the username for authentication.\\\"}, \\\"token_name\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"A descriptive name for the token being created.\\\", \\\"default\\\": \\\"default_token\\\"}, \\\"expiry_date\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The expiration date for the token in the format 'YYYY-MM-DD'. If not provided, the token will have a default expiry of one year from creation.\\\", \\\"default\\\": null}, \\\"permissions\\\": {\\\"type\\\": \\\"string\\\", \\\"description\\\": \\\"The level of access the token provides, such as 'read', 'write', 'admin'.\\\", \\\"enum\\\": [\\\"read\\\", \\\"write\\\", \\\"admin\\\"], \\\"default\\\": \\\"read\\\"}}}}\"\n", + " ],\n", + " \"orig_question\": null,\n", + " \"held_out_param\": null\n", + "}\n" + ] + } + ], + "source": [ + "print(json.dumps(ds_test[\"llm_judge\"][0], indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bf461475-4097-48a6-8fa3-39d8fd29387b", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:00.986156Z", + "iopub.status.busy": "2025-10-17T18:56:00.986007Z", + "iopub.status.idle": "2025-10-17T18:56:00.989419Z", + "shell.execute_reply": "2025-10-17T18:56:00.988928Z", + "shell.execute_reply.started": "2025-10-17T18:56:00.986142Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['tools', 'messages'],\n", + " num_rows: 15000\n", + " })\n", + "})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dstrain_sft" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1adfcba3-7fc7-4f20-939c-becf3a95edd7", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:00.990386Z", + "iopub.status.busy": "2025-10-17T18:56:00.990215Z", + "iopub.status.idle": "2025-10-17T18:56:00.993651Z", + "shell.execute_reply": "2025-10-17T18:56:00.993142Z", + "shell.execute_reply.started": "2025-10-17T18:56:00.990372Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What are the trending topics in New York City today?\"\n", + " },\n", + " {\n", + " \"role\": \"assistant\",\n", + " \"content\": \"Apologies, but I'm unable to provide real-time information or perform web searches. 
You may want to check a reliable news source for that.\"\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "print(json.dumps(dstrain_sft['train']['messages'][0], indent=4))" + ] + }, + { + "cell_type": "markdown", + "id": "bb62f589-d85b-4b16-ba61-51962dc49dda", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-15T23:32:59.645608Z", + "iopub.status.busy": "2025-10-15T23:32:59.645345Z", + "iopub.status.idle": "2025-10-15T23:32:59.651364Z", + "shell.execute_reply": "2025-10-15T23:32:59.650673Z", + "shell.execute_reply.started": "2025-10-15T23:32:59.645587Z" + } + }, + "source": [ + "The following function takes in elements from the training dataset and transforms them for training. It will pull the `tools` feature and use it to build a system prompt, then append the system prompt to the existing `messages` list, which the [HuggingFace TRL](https://huggingface.co/docs/trl/en/index) library used in this example can consume natively." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "db3ee904-9664-452d-ac00-e5ab8644cb83", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:01.545510Z", + "iopub.status.busy": "2025-10-17T18:56:01.545349Z", + "iopub.status.idle": "2025-10-17T18:56:01.548588Z", + "shell.execute_reply": "2025-10-17T18:56:01.548098Z", + "shell.execute_reply.started": "2025-10-17T18:56:01.545496Z" + } + }, + "outputs": [], + "source": [ + "def generate_sft_prompt(data_point):\n", + " \"\"\"\n", + " Generates a tool using prompt based on an input datapoint.\n", + " \n", + " Args:\n", + " data_point (dict): Dictionary containing target and meaning_representation keys\n", + " \n", + " Returns:\n", + " dict: Dictionary containing the formatted prompt\n", + " \"\"\"\n", + " tool_list = []\n", + "\n", + " for tool in data_point[\"tools\"]: \n", + " tool_list.append(json.loads(tool))\n", + "\n", + " #data_point[\"tools\"] = tool_list\n", + " \n", + " full_prompt = f\"\"\"\n", + " You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\n", + " {json.dumps(tool_list)}\n", + " \"\"\"\n", + " # {json.dumps(tool_list)} {data_point[\"tools\"]}\n", + " data_point[\"messages\"].insert(0, {\"role\": \"system\", \"content\": full_prompt})#.append({\"role\": \"system\", \"content\": full_prompt})\n", + " \n", + " return data_point" + ] + }, + { + "cell_type": "markdown", + "id": "30c59166-5184-4b2e-81c4-d65199fc7236", + "metadata": {}, + "source": [ + "The `map` function will apply `generate_prompt` to each row in the dataset." 
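+ "\n",
+ "With `batched=False`, `map` calls `generate_sft_prompt` once per example. Conceptually, the transformation is equivalent to the loop below (an illustrative sketch only; the actual `map` call follows in the next cell):\n",
+ "\n",
+ "```python\n",
+ "# Conceptual equivalent of dstrain_sft.map(generate_sft_prompt, batched=False)\n",
+ "transformed = [generate_sft_prompt(dict(row)) for row in dstrain_sft[\"train\"]]\n",
+ "```\n"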
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "46823405-fcc3-49c8-9887-81ad83f16954", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:02.321900Z", + "iopub.status.busy": "2025-10-17T18:56:02.321738Z", + "iopub.status.idle": "2025-10-17T18:56:02.327130Z", + "shell.execute_reply": "2025-10-17T18:56:02.326538Z", + "shell.execute_reply.started": "2025-10-17T18:56:02.321887Z" + } + }, + "outputs": [], + "source": [ + "dstrain_sft = dstrain_sft.map(\n", + " generate_sft_prompt,\n", + " batched=False\n", + ")\n", + "\n", + "#dstrain_sft = dstrain_sft.remove_columns([\"tools\"])" + ] + }, + { + "cell_type": "markdown", + "id": "708a4799-b5a0-4b80-94cf-104a5ae959c5", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-15T23:35:03.809371Z", + "iopub.status.busy": "2025-10-15T23:35:03.809112Z", + "iopub.status.idle": "2025-10-15T23:35:03.814114Z", + "shell.execute_reply": "2025-10-15T23:35:03.813511Z", + "shell.execute_reply.started": "2025-10-15T23:35:03.809349Z" + } + }, + "source": [ + "You can now see in a sample of the training data that it has an entry for the `system` role." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cb048f57-1469-4cd4-8964-ee6f6eb049fa", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:03.533589Z", + "iopub.status.busy": "2025-10-17T18:56:03.533421Z", + "iopub.status.idle": "2025-10-17T18:56:03.536898Z", + "shell.execute_reply": "2025-10-17T18:56:03.536478Z", + "shell.execute_reply.started": "2025-10-17T18:56:03.533574Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['tools', 'messages'],\n", + " num_rows: 15000\n", + "})" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dstrain_sft['train']" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3e9f3aa3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Column([[{'role': 'system', 'content': '\\n You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\\n [{\"name\": \"get_stations_within_1_km\", \"description\": \"Fetch the nearest EV charging stations within a 1 km radius from a given latitude and longitude.\", \"parameters\": {\"type\": \"dict\", \"properties\": {\"region\": {\"description\": \"The region code (us for United States, ca for Canada, uk for United Kingdom, nz for New Zealand, hk for Hong Kong).\", \"type\": \"str\", \"default\": \"\"}, \"latitude\": {\"description\": \"The latitude of the location for which to find nearby charging stations.\", \"type\": \"int\", \"default\": \"40.733\"}, \"longitude\": {\"description\": \"The longitude of the location for which to find nearby charging stations.\", \"type\": \"int\", \"default\": \"-74.202\"}}}, \"required\": [\"region\", \"latitude\", \"longitude\"]}]\\n '}, {'role': 'user', 'content': 'What are the trending topics in New York City today?'}, {'role': 'assistant', 'content': \"Apologies, but I'm unable to provide real-time information or perform web searches. You may want to check a reliable news source for that.\"}], [{'role': 'system', 'content': '\\n You are a helpful assistant with access to the following tools or function calls. 
Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\\n [{\"name\": \"social_media_caption\", \"description\": \"Generates catchy captions for social media based on the provided name and description.\", \"parameters\": {\"type\": \"dict\", \"properties\": {\"name\": {\"description\": \"The name associated with the social media content.\", \"type\": \"str\", \"default\": \"Apple\"}, \"description\": {\"description\": \"A brief description of the social media content.\", \"type\": \"str\", \"default\": \"Apple designs, manufactures and markets smartphones, personal computers, tablets, wearables and accessories, and sells a variety of related services.\"}}}, \"required\": [\"name\", \"description\"]}, {\"name\": \"getorderbyid\", \"description\": \"Fetches order details from the pet store API based on the provided order ID.\", \"parameters\": {\"type\": \"dict\", \"properties\": {\"orderid\": {\"description\": \"The ID of the order to be fetched. Should be an integer value between 1 and 10.\", \"type\": \"int\", \"default\": \"\"}}}, \"required\": [\"orderid\"]}]\\n '}, {'role': 'user', 'content': \"Create a Facebook ad copy for a new product launch with the name 'EcoClean Detergent' and a description: 'Clean your clothes and save the environment with our eco-friendly detergent.'\"}, {'role': 'assistant', 'content': \"Apologies, but I'm unable to create a Facebook ad copy. I can provide suggestions or guidelines, but the actual creation would require human input and creativity.\"}], [{'role': 'system', 'content': '\\n You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\\n [{\"name\": \"calculate_standard_deviation\", \"description\": \"Calculates the standard deviation of a list of numbers.\", \"parameters\": {\"type\": \"dict\", \"properties\": {\"numbers\": {\"description\": \"The list of numbers.\", \"type\": \"List[Union[int, float]]\"}}}, \"required\": [\"numbers\"]}]\\n '}, {'role': 'user', 'content': 'Compute the standard deviation for a dataset with outliers.'}, {'role': 'assistant', 'content': \"To calculate the standard deviation, I'll need the specific numbers in your dataset. Could you please provide them?\"}], [{'role': 'system', 'content': '\\n You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\\n [{\"name\": \"auto_complete\", \"description\": \"Fetches auto-complete suggestions for a given word using an external API.\", \"parameters\": {\"type\": \"dict\", \"properties\": {\"word\": {\"description\": \"The term or phrase for which to get auto-complete suggestions.\", \"type\": \"str\", \"default\": \"bikini top\"}, \"currency\": {\"description\": \"The 3-letter currency code. Defaults to \\'USD\\'.\", \"type\": \"str, optional\", \"default\": \"USD\"}, \"country\": {\"description\": \"The 2-letter country code. Defaults to \\'US\\'.\", \"type\": \"str, optional\", \"default\": \"US\"}, \"language\": {\"description\": \"The 2-letter language code. 
Defaults to \\'en\\'.\", \"type\": \"str, optional\", \"default\": \"en\"}}}, \"required\": [\"word\"]}]\\n '}, {'role': 'user', 'content': 'Provide auto-complete suggestions for a phrase in German for a user in Germany using Euros.'}, {'role': 'assistant', 'content': \"To provide the auto-complete suggestions, could you please specify the phrase or term you'd like me to use?\"}], [{'role': 'system', 'content': '\\n You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\\n []\\n '}, {'role': 'user', 'content': \"What is the fuel type of the car with VIN number 'WAUDH78E08A022617'?\"}, {'role': 'assistant', 'content': \"Apologies, I'm unable to provide that information as I don't have real-time access to vehicle databases. Please consult a reliable automotive resource for accurate details.\"}]])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dstrain_sft['train']['messages']" + ] + }, + { + "cell_type": "markdown", + "id": "20e5a049-10a0-4fdb-86ec-893391de6505", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-15T23:35:39.921466Z", + "iopub.status.busy": "2025-10-15T23:35:39.921206Z", + "iopub.status.idle": "2025-10-15T23:35:39.926534Z", + "shell.execute_reply": "2025-10-15T23:35:39.925820Z", + "shell.execute_reply.started": "2025-10-15T23:35:39.921444Z" + } + }, + "source": [ + "Now repeat the same process for the preference optimization dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a3ddddd7-b33d-4d7d-b1db-0430f74134b3", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:04.594426Z", + "iopub.status.busy": "2025-10-17T18:56:04.594023Z", + "iopub.status.idle": "2025-10-17T18:56:04.597596Z", + "shell.execute_reply": "2025-10-17T18:56:04.597105Z", + "shell.execute_reply.started": "2025-10-17T18:56:04.594396Z" + } + }, + "outputs": [], + "source": [ + "def generate_dpo_prompt(data_point):\n", + " \"\"\"\n", + " Generates a tool using prompt based on an input datapoint.\n", + " \n", + " Args:\n", + " data_point (dict): Dictionary containing target and meaning_representation keys\n", + " \n", + " Returns:\n", + " dict: Dictionary containing the formatted prompt\n", + " \"\"\"\n", + " full_prompt = f\"\"\"\n", + " You are a helpful assistant with access to the following tools or function calls. 
Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance.\n", + " \"\"\"\n", + "\n", + " data_point[\"chosen\"] = [data_point[\"chosen_response\"]]\n", + " data_point[\"rejected\"] = [data_point[\"rejected_response\"]]\n", + " \n", + " return data_point" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8df2ae0d-08ad-44ed-bc6c-7446856f223c", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:05.173733Z", + "iopub.status.busy": "2025-10-17T18:56:05.173572Z", + "iopub.status.idle": "2025-10-17T18:56:05.179915Z", + "shell.execute_reply": "2025-10-17T18:56:05.179374Z", + "shell.execute_reply.started": "2025-10-17T18:56:05.173718Z" + } + }, + "outputs": [], + "source": [ + "ds_train_pref = ds_train_pref.map(\n", + " generate_dpo_prompt,\n", + " batched=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "952e5b75-8461-422d-b365-feb5accc5211", + "metadata": {}, + "source": [ + "\n", + "The HuggingFace TRL library expects DPO training data to specifically have the data labeled as `chosen` and `rejected`, so rename the training data fields to correspond with that format.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "035b0254-463f-4060-8f18-2139f6857b50", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:06.601990Z", + "iopub.status.busy": "2025-10-17T18:56:06.601797Z", + "iopub.status.idle": "2025-10-17T18:56:06.608206Z", + "shell.execute_reply": "2025-10-17T18:56:06.607749Z", + "shell.execute_reply.started": "2025-10-17T18:56:06.601973Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['tools', 'prompt', 'chosen', 'rejected'],\n", + " num_rows: 9000\n", + " })\n", + "})\n" + ] + } + ], + "source": [ + "ds_train_pref = ds_train_pref.remove_columns([\"chosen_response\", \"rejected_response\"])\n", + "ds_train_pref = ds_train_pref.rename_column(\"messages\", \"prompt\")\n", + "\n", + "print(ds_train_pref)" + ] + }, + { + "cell_type": "markdown", + "id": "14fa26eb-5cbb-4a1e-86b3-17588eceebc0", + "metadata": {}, + "source": [ + "Now upload your training data to S3 so it can be used by the SageMaker fully managed training job you are about to create." 
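+ "\n",
+ "After the upload cell below runs, you can optionally confirm that the objects landed in S3 with `boto3` (a minimal sketch; the prefix matches the `input_path` defined in that cell):\n",
+ "\n",
+ "```python\n",
+ "# Optional sanity check: list the uploaded dataset files and their sizes.\n",
+ "s3 = boto3.client(\"s3\")\n",
+ "resp = s3.list_objects_v2(Bucket=bucket_name, Prefix=\"datasets/nvidia_function_calling/\")\n",
+ "for obj in resp.get(\"Contents\", []):\n",
+ "    print(obj[\"Key\"], obj[\"Size\"])\n",
+ "```\n"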
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "7207abe5-d046-4adc-9f53-7f9f64f5dfba", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:07.554985Z", + "iopub.status.busy": "2025-10-17T18:56:07.554779Z", + "iopub.status.idle": "2025-10-17T18:56:08.921736Z", + "shell.execute_reply": "2025-10-17T18:56:08.921266Z", + "shell.execute_reply.started": "2025-10-17T18:56:07.554970Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dashtiam/miniconda3/lib/python3.12/site-packages/fsspec/registry.py:273: UserWarning: Your installed version of s3fs is very old and known to cause\n", + "severe performance issues, see also https://github.com/dask/dask/issues/10276\n", + "\n", + "To fix, you should specify a lower version bound on s3fs, or\n", + "update the current installation.\n", + "\n", + " warnings.warn(s3_msg)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "97446113643c420893e233f58fd40b26", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Creating json from Arrow format: 0%| | 0/15 [00:00, ?ba/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "888122b70459446b92dcac4fe57e20f6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Creating json from Arrow format: 0%| | 0/9 [00:00, ?ba/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training data uploaded to: s3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/train/dataset.json\n", + "DPO data uploaded to: s3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/pref/dataset.json\n", + "View the dataset in S3 here: https://s3.console.aws.amazon.com/s3/buckets/sagemaker-us-east-1-783764584149/?region=us-east-1&prefix=datasets/nvidia_function_calling/\n" + ] + } + ], + "source": [ + "# save train_dataset to s3 using our SageMaker session\n", + "input_path = f's3://{sagemaker_session.default_bucket()}/datasets/nvidia_function_calling'\n", + "\n", + "# Save datasets to s3\n", + "\n", + "dstrain_sft[\"train\"].to_json(f\"{input_path}/train/dataset.json\", orient=\"records\")\n", + "sft_dataset_s3_path = f\"{input_path}/train/dataset.json\"\n", + "ds_train_pref[\"train\"].to_json(f\"{input_path}/pref/dataset.json\", orient=\"records\")\n", + "perf_dataset_s3_path = f\"{input_path}/pref/dataset.json\"\n", + "\n", + "print(f\"Training data uploaded to: {sft_dataset_s3_path}\")\n", + "print(f\"DPO data uploaded to: {perf_dataset_s3_path}\")\n", + "print(f\"View the dataset in S3 here: https://s3.console.aws.amazon.com/s3/buckets/{sagemaker_session.default_bucket()}/?region={sagemaker_session.boto_region_name}&prefix={input_path.split('/', 3)[-1]}/\")" + ] + }, + { + "cell_type": "markdown", + "id": "ed7b2fe1-d0eb-464e-9da1-5f1e2d8c9a1a", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-15T23:00:19.341400Z", + "iopub.status.busy": "2025-10-15T23:00:19.341135Z", + "iopub.status.idle": "2025-10-15T23:00:19.344445Z", + "shell.execute_reply": "2025-10-15T23:00:19.343920Z", + "shell.execute_reply.started": "2025-10-15T23:00:19.341380Z" + } + }, + "source": [ + "Here you will setup some basic parameters that will be inputs for training.\n", + "- `image_uri` is the Elastic Container Repository (ECR) URI that the training job will use\n", + "- `checkpoint_s3_path` 
is where the training job will store model checkpoints\n", + "- `job_prefix` is the prefix name for the training job" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c30b2505-cb71-4d7f-9f30-636455ddac6d", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:09.114940Z", + "iopub.status.busy": "2025-10-17T18:56:09.114775Z", + "iopub.status.idle": "2025-10-17T18:56:09.118507Z", + "shell.execute_reply": "2025-10-17T18:56:09.117991Z", + "shell.execute_reply.started": "2025-10-17T18:56:09.114926Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SFT Training Image URI: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6.0-gpu-py312\n", + "SFT Training Checkpoint Storage Path: s3://sagemaker-us-east-1-783764584149/function-calling-sft-checkpoints/checkpoints\n", + "SFT Training Job Name Prefix: model-trainer-distributed-function-calling-sft\n" + ] + } + ], + "source": [ + "from sagemaker.config import load_sagemaker_config\n", + "configs = load_sagemaker_config()\n", + "instance_type = \"ml.p4d.24xlarge\"\n", + "# image_uri = f\"658645717510.dkr.ecr.{sagemaker_session.boto_session.region_name}.amazonaws.com/smdistributed-modelparallel:2.4.1-gpu-py311-cu121\"\n", + "image_uri = sagemaker.image_uris.retrieve(\n", + " framework=\"pytorch\",\n", + " region=sagemaker_session.boto_session.region_name,\n", + " version=\"2.6.0\",\n", + " instance_type=instance_type,\n", + " image_scope=\"training\"\n", + ")\n", + "\n", + "print(f\"SFT Training Image URI: {image_uri}\")\n", + "\n", + "checkpoint_s3_path = f\"s3://{bucket_name}/function-calling-sft-checkpoints/checkpoints\"\n", + "print(f\"SFT Training Checkpoint Storage Path: {checkpoint_s3_path}\")\n", + "\n", + "job_prefix = f\"model-trainer-distributed-function-calling-sft\"\n", + "print(f\"SFT Training Job Name Prefix: {job_prefix}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7d8fbe92-739a-485a-bf19-d6f91b22efc4", + "metadata": {}, + "source": [ + "Next, you will build your training job configuration using the SageMaker SDK's [ModelTrainer API](https://sagemaker.readthedocs.io/en/stable/api/training/model_trainer.html).\n", + "\n", + "If you have a MLflow Tracking Server, you can uncomment and configure the `tracking_server_arn` section and supply the ARN of your tracking server.\n", + "\n", + "the `training_recipe` value refers to one of the prebuilt training recipe configurations in the `scripts` folder of this example. The training script will automatically pull in this YAML configuration to retrieve training parameters.\n", + "\n", + "The training configuration outlined here will train a [Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) model using [Spectrum fine tuning](https://arxiv.org/html/2406.06623v1) at on 50% of the layers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fa6d12b-f820-4214-83dd-f4759b0f456c", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:10.334868Z", + "iopub.status.busy": "2025-10-17T18:56:10.334573Z", + "iopub.status.idle": "2025-10-17T18:56:10.682451Z", + "shell.execute_reply": "2025-10-17T18:56:10.681958Z", + "shell.execute_reply.started": "2025-10-17T18:56:10.334841Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[10/21/25 09:26:51] WARNING SageMaker session not provided. Using default Session. model_trainer.py:501\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 09:26:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m SageMaker session not provided. Using default Session. \u001b]8;id=576568;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=996917;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#501\u001b\\\u001b[2m501\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 09:26:52] WARNING Role not provided. Using default role: model_trainer.py:505\n", + " arn:aws:iam::783764584149:role/amin-macbook \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 09:26:52]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m Role not provided. Using default role: \u001b]8;id=91428;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=971011;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#505\u001b\\\u001b[2m505\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m arn:aws:iam::\u001b[1;36m783764584149\u001b[0m:role/amin-macbook \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 09:26:53] WARNING OutputDataConfig not provided. Using default: model_trainer.py:567\n", + " s3_output_path='s3://sagemaker-us-east-1-783764584149/model-train \n", + " er-distributed-function-calling-sft' kms_key_id=None \n", + " compression_type='GZIP' \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 09:26:53]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m OutputDataConfig not provided. Using default: \u001b]8;id=316818;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=442636;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#567\u001b\\\u001b[2m567\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[38;2;215;175;0ms3_output_path\u001b[0m=\u001b[38;2;0;135;0m's3://sagemaker-us-east-1-783764584149/model-train\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[38;2;0;135;0mer-distributed-function-calling-sft'\u001b[0m \u001b[38;2;215;175;0mkms_key_id\u001b[0m=\u001b[3;38;2;225;0;225mNone\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[38;2;215;175;0mcompression_type\u001b[0m=\u001b[38;2;0;135;0m'GZIP'\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO Training image URI: model_trainer.py:588\n", + " 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6 \n", + " .0-gpu-py312 \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;0;105;255mINFO \u001b[0m Training image URI: \u001b]8;id=893674;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=209011;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#588\u001b\\\u001b[2m588\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m763104351884.\u001b[0mdkr.ecr.us-east-\u001b[1;36m1.\u001b[0mamazonaws.com/pytorch-training:\u001b[1;36m2.6\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m .\u001b[1;36m0\u001b[0m-gpu-py312 \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from sagemaker.modules.train import ModelTrainer\n", + "from sagemaker.modules.configs import Compute, SourceCode, InputData, StoppingCondition, CheckpointConfig\n", + "\n", + "env = {}\n", + "env[\"FI_PROVIDER\"] = \"efa\"\n", + "env[\"NCCL_PROTO\"] = \"simple\"\n", + "env[\"NCCL_SOCKET_IFNAME\"] = \"eth0\"\n", + "env[\"NCCL_IB_DISABLE\"] = \"1\"\n", + "env[\"NCCL_DEBUG\"] = \"WARN\"\n", + "env[\"HF_token\"] = os.environ['hf_token']\n", + "env[\"data_location\"] = sft_dataset_s3_path\n", + "env[\"training_recipe\"] = \"recipes/sft-spectrum-qwen3-1.7b.yaml\"\n", + "\n", + "# MLFlow tracker\n", + "#tracking_server_arn = \"\"\n", + "#env[\"MLFLOW_TRACKING_ARN\"] = tracking_server_arn\n", + "\n", + "compute = Compute(\n", + " instance_count=1,\n", + " instance_type= instance_type,\n", + " volume_size_in_gb=96,\n", + " keep_alive_period_in_seconds=3600,\n", + ")\n", + "\n", + "hyperparameters = {\n", + " \"dataset_path\": \"/opt/ml/input/data/dataset\",\n", + " \"model_dir\": \"/opt/ml/model\",\n", + "}\n", + "\n", + "source_code = SourceCode(\n", + " source_dir=\"./scripts\",\n", + " requirements=\"requirements.txt\",\n", + " entry_script=\"run_training_sft.sh\",\n", + ")\n", + "\n", + "model_trainer = ModelTrainer(\n", + " training_image=image_uri,\n", + " compute=compute,\n", + " hyperparameters=hyperparameters,\n", + " environment=env,\n", + " source_code=source_code,\n", + " stopping_condition=StoppingCondition(\n", + " max_runtime_in_seconds=90000,\n", + " ),\n", + " checkpoint_config=CheckpointConfig(\n", + " s3_uri=f\"{checkpoint_s3_path}/{job_prefix}\",\n", + " ),\n", + " base_job_name=job_prefix\n", + "\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e5ef45ce-0f2a-4fdc-93e7-82597d9707ca", + "metadata": {}, + "source": [ + "### Configure Input Data Channels" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "45555111-e5ae-4bcd-ace5-32638b3373e4", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:11.806337Z", + "iopub.status.busy": "2025-10-17T18:56:11.806160Z", + "iopub.status.idle": "2025-10-17T18:56:11.809299Z", + "shell.execute_reply": "2025-10-17T18:56:11.808892Z", + "shell.execute_reply.started": "2025-10-17T18:56:11.806306Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'s3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/train/dataset.json'" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sft_dataset_s3_path" + ] + }, + { + "cell_type": 
"code", + "execution_count": 22, + "id": "9dbccb09-0e82-49fc-b8e2-8f3b9b2d35c8", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:12.559327Z", + "iopub.status.busy": "2025-10-17T18:56:12.559046Z", + "iopub.status.idle": "2025-10-17T18:56:12.561742Z", + "shell.execute_reply": "2025-10-17T18:56:12.561325Z", + "shell.execute_reply.started": "2025-10-17T18:56:12.559301Z" + } + }, + "outputs": [], + "source": [ + "training_data = InputData(\n", + " channel_name=\"training_dataset\",\n", + " data_source=sft_dataset_s3_path,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "de668e41-ba92-4ecf-be89-6499cb4ecc42", + "metadata": {}, + "source": [ + "### Begin SFT Training\n", + "\n", + "Now you can start your training job using ModelTrainer's `.train()` API. It will create a SageMaker fully managed training job and stream the log outputs until the job completes." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "1638e14e-3c64-46e3-a5d8-8dd26a03d55e", + "metadata": { + "execution": { + "iopub.execute_input": "2025-10-17T18:56:13.826193Z", + "iopub.status.busy": "2025-10-17T18:56:13.826035Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[10/21/25 09:26:55] WARNING key_prefix is only applicable when data_source is a local file model_trainer.py:896\n", + " path. \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 09:26:55]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m key_prefix is only applicable when data_source is a local file \u001b]8;id=224711;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=636732;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#896\u001b\\\u001b[2m896\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m path. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 09:27:02] INFO Creating training_job resource. resources.py:28340\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 09:27:02]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;0;105;255mINFO \u001b[0m Creating training_job resource. \u001b]8;id=583824;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker_core/main/resources.py\u001b\\\u001b[2mresources.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=588087;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker_core/main/resources.py#28340\u001b\\\u001b[2m28340\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 09:27:02] WARNING Not displaing the training container logs as 'wait' is set to model_trainer.py:834\n", + " False. \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 09:27:02]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m Not displaing the training container logs as \u001b[38;2;0;135;0m'wait'\u001b[0m is set to \u001b]8;id=910984;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=663320;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#834\u001b\\\u001b[2m834\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[3;38;2;215;0;0mFalse\u001b[0m. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model_trainer.train(input_data_config=[training_data], wait=False)" + ] + }, + { + "cell_type": "markdown", + "id": "edbaa7ea-85c9-428c-a4cf-5a1ef9d59f88", + "metadata": {}, + "source": [ + "## SFT Training Complete\n", + "Now that your SFT training job has completed, you can retrieve the tuned artifact and use it for DPO training as a follow-up step." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b0de460e-ce56-49e2-b1e4-2377d35bd868", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last training job name: model-trainer-distributed-function-calling-sft-20251021092655\n", + "Final SFT Model Artifact Location: s3://sagemaker-us-east-1-783764584149/model-trainer-distributed-function-calling-sft/model-trainer-distributed-function-calling-sft-20251021092655/output/model.tar.gz\n" + ] + } + ], + "source": [ + "from utils import get_last_job_name\n", + "\n", + "job_name = get_last_job_name(job_prefix)\n", + "print(f\"Last training job name: {job_name}\")\n", + "\n", + "if default_prefix:\n", + " model_data=f\"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz\"\n", + "else:\n", + " model_data=f\"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz\"\n", + "\n", + "print(f\"Final SFT Model Artifact Location: {model_data}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "5af9f077", + "metadata": {}, + "outputs": [], + "source": [ + "model_data = 's3://sagemaker-us-east-1-783764584149/model-trainer-distributed-function-calling-sft/model-trainer-distributed-function-calling-sft-20251021092640/output/model.tar.gz'" + ] + }, + { + "cell_type": "markdown", + "id": "2b504aea-ad14-49c6-8872-f16c7c169a6d", + "metadata": {}, + "source": [ + "# Run Direct Preference Optimization (DPO) training on your SFT Model\n", + "This section will configure default values for DPO similar to what was setup for SFT earlier." 
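+ "\n",
+ "Because the SFT job above was launched with `wait=False`, make sure it has finished (so its `model.tar.gz` artifact exists) before starting DPO. One way to block until completion is a boto3 waiter, shown here as a minimal sketch that reuses the `job_name` retrieved above:\n",
+ "\n",
+ "```python\n",
+ "# Wait for the SFT training job to reach a terminal state before reusing its artifact.\n",
+ "sm_client = boto3.client(\"sagemaker\")\n",
+ "waiter = sm_client.get_waiter(\"training_job_completed_or_stopped\")\n",
+ "waiter.wait(TrainingJobName=job_name)\n",
+ "print(sm_client.describe_training_job(TrainingJobName=job_name)[\"TrainingJobStatus\"])\n",
+ "```\n"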
+ ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "4d07fc86-9d53-4fbc-b3ca-8997c23e6816", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DPO Training Image URI: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6.0-gpu-py312\n", + "DPO Training Checkpoint Storage Path: s3://sagemaker-us-east-1-783764584149/function-calling-dpo-checkpoints/checkpoints\n", + "DPO Training Job Name Prefix: model-trainer-distributed-function-calling-dpo\n" + ] + } + ], + "source": [ + "# image_uri = f\"658645717510.dkr.ecr.{sagemaker_session.boto_session.region_name}.amazonaws.com/smdistributed-modelparallel:2.4.1-gpu-py311-cu121\"\n", + "instance_type = \"ml.p4d.24xlarge\"\n", + "image_uri = sagemaker.image_uris.retrieve(\n", + " framework=\"pytorch\",\n", + " region=sagemaker_session.boto_session.region_name,\n", + " version=\"2.6.0\",\n", + " instance_type=instance_type,\n", + " image_scope=\"training\"\n", + ")\n", + "\n", + "print(f\"DPO Training Image URI: {image_uri}\")\n", + "\n", + "checkpoint_s3_path = f\"s3://{bucket_name}/function-calling-dpo-checkpoints/checkpoints\"\n", + "print(f\"DPO Training Checkpoint Storage Path: {checkpoint_s3_path}\")\n", + "\n", + "job_prefix = f\"model-trainer-distributed-function-calling-dpo\"\n", + "print(f\"DPO Training Job Name Prefix: {job_prefix}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b47862ad-e6dd-46ef-a0e7-8f26b2f6b0c2", + "metadata": {}, + "source": [ + "Note that in this `ModelTrainer` configuration, the recipe configuration has changed from what was used for SFT as well as the entry script for training. If you remove `model_location` from environment it will run DPO on base model specified in the training recipe." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6578fe89-d170-4152-a263-e38238583864", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[10/21/25 11:18:24] WARNING SageMaker session not provided. Using default Session. model_trainer.py:501\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 11:18:24]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m SageMaker session not provided. Using default Session. \u001b]8;id=215951;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=362503;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#501\u001b\\\u001b[2m501\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 11:18:25] WARNING Role not provided. Using default role: model_trainer.py:505\n", + " arn:aws:iam::783764584149:role/amin-macbook \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 11:18:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m Role not provided. Using default role: \u001b]8;id=960185;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=713632;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#505\u001b\\\u001b[2m505\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m arn:aws:iam::\u001b[1;36m783764584149\u001b[0m:role/amin-macbook \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 11:18:26] WARNING OutputDataConfig not provided. Using default: model_trainer.py:567\n", + " s3_output_path='s3://sagemaker-us-east-1-783764584149/model-train \n", + " er-distributed-function-calling-dpo' kms_key_id=None \n", + " compression_type='GZIP' \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 11:18:26]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m OutputDataConfig not provided. Using default: \u001b]8;id=339629;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=336930;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#567\u001b\\\u001b[2m567\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[38;2;215;175;0ms3_output_path\u001b[0m=\u001b[38;2;0;135;0m's3://sagemaker-us-east-1-783764584149/model-train\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[38;2;0;135;0mer-distributed-function-calling-dpo'\u001b[0m \u001b[38;2;215;175;0mkms_key_id\u001b[0m=\u001b[3;38;2;225;0;225mNone\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[38;2;215;175;0mcompression_type\u001b[0m=\u001b[38;2;0;135;0m'GZIP'\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO Training image URI: model_trainer.py:588\n", + " 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6 \n", + " .0-gpu-py312 \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;0;105;255mINFO \u001b[0m Training image URI: \u001b]8;id=765488;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=380570;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#588\u001b\\\u001b[2m588\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m763104351884.\u001b[0mdkr.ecr.us-east-\u001b[1;36m1.\u001b[0mamazonaws.com/pytorch-training:\u001b[1;36m2.6\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m .\u001b[1;36m0\u001b[0m-gpu-py312 \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from sagemaker.config import load_sagemaker_config\n", + "configs = load_sagemaker_config()\n", + "from sagemaker.modules.train import ModelTrainer\n", + "from sagemaker.modules.configs import Compute, SourceCode, InputData, StoppingCondition, CheckpointConfig\n", + "\n", + "env = {}\n", + "env[\"FI_PROVIDER\"] = \"efa\"\n", + "env[\"NCCL_PROTO\"] = \"simple\"\n", + "env[\"NCCL_SOCKET_IFNAME\"] = \"eth0\"\n", + "env[\"NCCL_IB_DISABLE\"] = \"1\"\n", + "env[\"NCCL_DEBUG\"] = \"WARN\"\n", + "env[\"HF_token\"] = os.environ['hf_token']\n", + "env[\"data_location\"] = perf_dataset_s3_path\n", + "env[\"model_location\"] = model_data\n", + "env[\"training_recipe\"] = \"recipes/sft-dpo-qwen3-1.7b.yaml\"\n", + "\n", + "# MLFlow tracker\n", + "#tracking_server_arn = \"\"\n", + "#env[\"MLFLOW_TRACKING_ARN\"] = tracking_server_arn\n", + "\n", + "compute = Compute(\n", + " instance_count=1,\n", + " instance_type= instance_type,\n", + " volume_size_in_gb=96,\n", + " keep_alive_period_in_seconds=3600,\n", + ")\n", + "\n", + "hyperparameters = {\n", + " \"dataset_path\": \"/opt/ml/input/data/dataset\",\n", + " \"model_dir\": \"/opt/ml/model\",\n", + "}\n", + "\n", + "source_code = SourceCode(\n", + " source_dir=\"./scripts\",\n", + " requirements=\"requirements.txt\",\n", + " entry_script=\"run_training_dpo.sh\",\n", + ")\n", + "\n", + "model_trainer = ModelTrainer(\n", + " training_image=image_uri,\n", + " compute=compute,\n", + " hyperparameters=hyperparameters,\n", + " environment=env,\n", + " source_code=source_code,\n", + " stopping_condition=StoppingCondition(\n", + " max_runtime_in_seconds=90000,\n", + " ),\n", + " checkpoint_config=CheckpointConfig(\n", + " s3_uri=f\"{checkpoint_s3_path}/{job_prefix}\",\n", + " ),\n", + " base_job_name=job_prefix\n", + "\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "17c402ce-b7df-4254-9d82-a25ff9dd755d", + "metadata": {}, + "source": [ + "### Configure Training Data Channels" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "5833f8da-ac03-4ea8-bc3e-3f63ab786064", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'s3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/pref/dataset.json'" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "perf_dataset_s3_path" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "49ae3861-f1c6-42e5-a22b-e87b125a285f", + "metadata": {}, + "outputs": [], + "source": [ + "training_data = InputData(\n", + 
" channel_name=\"training_dataset\",\n", + " data_source=perf_dataset_s3_path,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3dc6a1c1-02c9-4ed0-9982-f157edb9794f", + "metadata": {}, + "source": [ + "### Begin DPO Training" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "2a2ca6e4-f124-4457-b43b-02e0dcc1b6bc", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
WARNING key_prefix is only applicable when data_source is a local file model_trainer.py:896\n", + " path. \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m key_prefix is only applicable when data_source is a local file \u001b]8;id=195653;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=241621;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#896\u001b\\\u001b[2m896\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m path. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 11:18:33] INFO Creating training_job resource. resources.py:28340\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 11:18:33]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;0;105;255mINFO \u001b[0m Creating training_job resource. \u001b]8;id=744311;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker_core/main/resources.py\u001b\\\u001b[2mresources.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=320330;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker_core/main/resources.py#28340\u001b\\\u001b[2m28340\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10/21/25 11:18:34] WARNING Not displaing the training container logs as 'wait' is set to model_trainer.py:834\n", + " False. \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10/21/25 11:18:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[1;38;2;215;175;0mWARNING \u001b[0m Not displaing the training container logs as \u001b[38;2;0;135;0m'wait'\u001b[0m is set to \u001b]8;id=388411;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py\u001b\\\u001b[2mmodel_trainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=378168;file:///Users/dashtiam/miniconda3/lib/python3.12/site-packages/sagemaker/modules/train/model_trainer.py#834\u001b\\\u001b[2m834\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[3;38;2;215;0;0mFalse\u001b[0m. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model_trainer.train(input_data_config=[training_data], wait=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "3451a7f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last training job name: model-trainer-distributed-function-calling-dpo-20251021103550\n", + "Final DPO Model Artifact Location: s3://sagemaker-us-east-1-783764584149/model-trainer-distributed-function-calling-dpo/model-trainer-distributed-function-calling-dpo-20251021103550/output/model.tar.gz\n" + ] + } + ], + "source": [ + "from utils import get_last_job_name\n", + "\n", + "job_name = get_last_job_name(job_prefix)\n", + "print(f\"Last training job name: {job_name}\")\n", + "\n", + "if default_prefix:\n", + " model_data=f\"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz\"\n", + "else:\n", + " model_data=f\"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz\"\n", + "\n", + "print(f\"Final DPO Model Artifact Location: {model_data}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f28158a8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/accelerate_configs/deepspeed_zero1.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/accelerate_configs/deepspeed_zero1.yaml new file mode 100644 index 0000000..f54aff9 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/accelerate_configs/deepspeed_zero1.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/accelerate_configs/deepspeed_zero3.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/accelerate_configs/deepspeed_zero3.yaml new 
file mode 100644 index 0000000..b5a1201 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/accelerate_configs/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/create_preference_dataset.py b/3_distributed_training/function-calling-sft-dpo/scripts/create_preference_dataset.py new file mode 100644 index 0000000..22a7705 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/create_preference_dataset.py @@ -0,0 +1,207 @@ +from dataclasses import dataclass, field +import logging +import os +import time +from typing import cast +import re + +import torch +from datasets import load_dataset +from tqdm.auto import tqdm +from trl import TrlParser +from vllm import LLM, SamplingParams +from datasets import Dataset +from peft import LoraConfig, AutoPeftModelForCausalLM + +logger = logging.getLogger(__name__) + +@dataclass +class CandidateArguments: + generation_model_name_or_path: str = field( + default=None, + metadata={ + 'help': 'Huggingface model name or path to model directory, for the model that will be used for generation, defaults to SFT model or previous iteration model.' + }, + ) + dataset_id: str = field( + default=None, + metadata={ + 'help': 'Path to the input dataset, that will be used to generate candidates, defaults to previous iteration output dataset.' + }, + ) + sample_size: int = field( + default=None, + metadata={ + 'help': 'Number of samples to generate, defaults to as many as possible.' + }, + ) + prompt_column: str = field( + default='question', + metadata={'help': 'Column name in the input dataset that contains the messages.'}, + ) + answer_column: str = field( + default='answer', + metadata={'help': 'Column name in the input dataset that contains the answer.'}, + ) + system_prompt: str = field( + default= """Solve the given high school math problem by providing a clear explanation of each step leading to the final solution. + +Provide a detailed breakdown of your calculations, beginning with an explanation of the problem and describing how you derive each formula, value, or conclusion. Use logical steps that build upon one another, to arrive at the final answer in a systematic manner. + +# Steps + +1. **Understand the Problem**: Restate the given math problem and clearly identify the main question and any important given values. +2. **Set Up**: Identify the key formulas or concepts that could help solve the problem (e.g., algebraic manipulation, geometry formulas, trigonometric identities). +3. **Solve Step-by-Step**: Iteratively progress through each step of the math problem, justifying why each consecutive operation brings you closer to the solution. +4. **Double Check**: If applicable, double check the work for accuracy and sense, and mention potential alternative approaches if any. +5. **Final Answer**: Provide the numerical or algebraic solution clearly, accompanied by appropriate units if relevant. + +# Notes + +- Always clearly define any variable or term used. 
+- Wherever applicable, include unit conversions or context to explain why each formula or step has been chosen. +- Assume the level of mathematics is suitable for high school, and avoid overly advanced math techniques unless they are common at that level. +""", + metadata={'help': 'System prompt to use for generation.'}, + ) + num_solutions: int = field( + default=5, + metadata={'help': 'Number of solutions to generate for each input.'}, + ) + batch_size: int = field( + default=1, + metadata={'help': 'Batch size for generation.'}, + ) + max_new_tokens: int = field( + default=2048, + metadata={'help': 'Maximum number of new tokens to generate.'}, + ) + temperature: float = field( + default=0.7, + metadata={'help': 'Temperature for generation.'}, + ) + top_p: float = field( + default=1.0, + metadata={'help': 'Top-p for generation.'}, + ) + +def score_solutions( + candidate_result: str, + ground_truth_result: str, +) -> bool: + # finds the answer in the candidate result + regex_pattern = r'\b\d+\b' + match = re.findall(regex_pattern, candidate_result) + + if match: + return match[-1] == ground_truth_result + else: + return False + + +def vllm_create_candidates( + dataset: Dataset, + model_name_or_path: str, + num_solutions: int, + max_new_tokens: int, + batch_size: int = 1, + prompt_column: str = 'prompt', + system_prompt: str = None, + answer_column: str = 'answer', + sample_size: int = None, + **kwargs, +) -> Dataset: + + # Loads the model on all available GPUs with vLLM + llm = LLM( + model=model_name_or_path, + tokenizer=model_name_or_path, + tensor_parallel_size=torch.cuda.device_count(), + max_model_len=4096, + ) + # formats the prompt using the system prompt and the prompt column + tokenizer = llm.get_tokenizer() + def format_prompt(s): + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": s[prompt_column]} + ] + return {"prompt": tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True), "messages": messages} + + dataset = dataset.map(format_prompt) + # print the first prompt + print('First prompt:', dataset['prompt'][0]) + + # set sampling params + sampling_params = SamplingParams( + max_tokens=max_new_tokens, + n=num_solutions, + temperature=kwargs.get('temperature', 1.0), + top_p=kwargs.get('top_p', 1), + ) + + # Iterate over the dataset with batch size to generate candidates and create preference pairs based on the correct answer and ground truth + preference_dataset = [] + for i in tqdm(range(0, len(dataset), batch_size), desc=f'Generating solutions: Already generated {len(preference_dataset)} preference pairs'): + batch = dataset[i : i + batch_size] + # Generate `num_solutions` candidates per batch + result = llm.generate(batch['prompt'], sampling_params, use_tqdm=False) + for j in range(0, len(batch['prompt'])): + # iterate each candidate and check if it is correct + preference_pair = { + "system_prompt": system_prompt, + "prompt": batch[prompt_column][j], + "ground_truth": batch[answer_column][j], + } + for cand in result[j].outputs: + # check if the candidate is correct + cand_score = score_solutions(candidate_result=cand.text, ground_truth_result=batch[answer_column][j]) + if cand_score and preference_pair.get('chosen',None) is None: + preference_pair['chosen'] = cand.text + elif not cand_score and preference_pair.get('rejected',None) is None: + preference_pair['rejected'] = cand.text + # check if the pair is complete to prevent overwriting + if preference_pair.get('chosen',None) and 
preference_pair.get('rejected',None): + continue + + # check is the generated candidates lead to a complete preference pair + if preference_pair.get('chosen',None) and preference_pair.get('rejected',None): + print(f'Found preference pair, adding to dataset.') + preference_dataset.append(preference_pair) + + print(f'Generated {len(preference_dataset)} preference pairs') + if len(preference_dataset) >= sample_size: + break + return Dataset.from_list(preference_dataset) + + +def main(): + parser = TrlParser((CandidateArguments)) + script_args = parser.parse_args_and_config()[0] + script_args = cast(CandidateArguments, script_args) + + # load dataset and tokenizer + dataset = load_dataset(script_args.dataset_id, split='train') + print(f'Generating {script_args.num_solutions} solutions for {len(dataset)} prompts...') + + start_time = time.time() + candidates_ds = vllm_create_candidates( + dataset, + model_name_or_path=script_args.generation_model_name_or_path, + num_solutions=script_args.num_solutions, + max_new_tokens=script_args.max_new_tokens, + batch_size=script_args.batch_size, + prompt_column=script_args.prompt_column, + answer_column=script_args.answer_column, + system_prompt=script_args.system_prompt, + temperature=script_args.temperature, + top_p=script_args.top_p, + sample_size=script_args.sample_size if script_args.sample_size is not None else len(dataset), + ) + print(f'Generated {len(dataset) * script_args.num_solutions} solutions in {time.time() - start_time:.2f} seconds.') + + save_dataset_id = f"{script_args.generation_model_name_or_path.replace('/', '-')[:40]}-{script_args.dataset_id.replace('/', '-')[:40]}-candidates" + candidates_ds.push_to_hub(save_dataset_id) + +if __name__ == '__main__': + main() diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/merge_adapter_weights.py b/3_distributed_training/function-calling-sft-dpo/scripts/merge_adapter_weights.py new file mode 100644 index 0000000..d61db76 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/merge_adapter_weights.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass, field +import tempfile +from typing import Optional +import torch +from peft import AutoPeftModelForCausalLM +from transformers import AutoTokenizer, HfArgumentParser +from huggingface_hub import HfApi + +# Example usage: +# python scripts/merge_adapter_weights.py --peft_model_id falcon-180b-lora-fa --output_dir merged-weights --save_tokenizer True + +def save_model(model_path_or_id, save_dir, save_tokenizer=True): + model = AutoPeftModelForCausalLM.from_pretrained( + model_path_or_id, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + ) + # Merge LoRA and base model and save + model = model.merge_and_unload() + model.save_pretrained(save_dir, safe_serialization=True, max_shard_size="3GB") + + # save tokenizer + if save_tokenizer: + tokenizer = AutoTokenizer.from_pretrained(model_path_or_id) + tokenizer.save_pretrained(save_dir) + + +@dataclass +class ScriptArguments: + peft_model_id: str = field(metadata={"help": "model id or path to model"}) + output_dir: Optional[str] = field(default="merged-weights", metadata={"help": "where the merged model should be saved"}) + save_tokenizer: Optional[bool] = field(default=True, metadata={"help": "whether to save the tokenizer"}) + push_to_hub: Optional[bool] = field(default=False, metadata={"help": "whether to push the model to the hub"}) + repository_id: Optional[str] = field(default=None, metadata={"help": "the model name"}) + +parser = 
HfArgumentParser(ScriptArguments) +args = parser.parse_args_into_dataclasses()[0] +api = HfApi() + +if args.push_to_hub: + repo_id = args.repository_id if args.repository_id else args.peft_model_id.split('/')[-1] + with tempfile.TemporaryDirectory() as temp_dir: + save_model(args.peft_model_id, temp_dir, args.save_tokenizer) + api.upload_large_folder( + folder_path=temp_dir, + repo_id=repo_id, + repo_type="model", + ) +else: + save_model(args.peft_model_id, args.output_dir, args.save_tokenizer) \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-dpo-qwen3-0.6b.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-dpo-qwen3-0.6b.yaml new file mode 100644 index 0000000..e92da48 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-dpo-qwen3-0.6b.yaml @@ -0,0 +1,41 @@ +# Model arguments +model_name_or_path: Qwen/Qwen3-0.6B +tokenizer_name_or_path: Qwen/Qwen3-0.6B +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true + +model_download_location: /opt/ml/input/model +dataset_local_location: /opt/ml/input/data/training_dataset/ +output_dir: /opt/ml/model/ #/opt/ml/output + +# Dataset arguments +dataset_id_or_path: /opt/ml/input/data/training_dataset/ +max_length: 2048 +packing: true + +# Training arguments +beta: 0.1 +max_length: 1536 +max_prompt_length: 768 +loss_type: sigmoid # default loss, alternatives: https://huggingface.co/docs/trl/dpo_trainer#loss-functions +num_train_epochs: 10 +per_device_train_batch_size: 8 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +learning_rate: 5.0e-7 +lr_scheduler_type: constant +warmup_ratio: 0.03 +#weight_decay: 0.01 + +# Logging arguments +logging_strategy: steps +logging_steps: 5 +report_to: +- none +save_strategy: "no" +seed: 42 \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-dpo-qwen3-1.7b.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-dpo-qwen3-1.7b.yaml new file mode 100644 index 0000000..ff5a488 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-dpo-qwen3-1.7b.yaml @@ -0,0 +1,41 @@ +# Model arguments +model_name_or_path: Qwen/Qwen3-1.7B +tokenizer_name_or_path: Qwen/Qwen3-1.7B +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true + +model_download_location: /opt/ml/input/model +dataset_local_location: /opt/ml/input/data/training_dataset/ +output_dir: /opt/ml/model/ #/opt/ml/output + +# Dataset arguments +dataset_id_or_path: /opt/ml/input/data/training_dataset/ +max_length: 2048 +packing: true + +# Training arguments +beta: 0.1 +max_length: 1536 +max_prompt_length: 768 +loss_type: sigmoid # default loss, alternatives: https://huggingface.co/docs/trl/dpo_trainer#loss-functions +num_train_epochs: 10 +per_device_train_batch_size: 8 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +learning_rate: 5.0e-7 +lr_scheduler_type: constant +warmup_ratio: 0.03 +#weight_decay: 0.01 + +# Logging arguments +logging_strategy: steps +logging_steps: 5 +report_to: +- none +save_strategy: "no" +seed: 42 \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-dpo-llama-3-2-3b.yaml 
b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-dpo-llama-3-2-3b.yaml new file mode 100644 index 0000000..9d0581b --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-dpo-llama-3-2-3b.yaml @@ -0,0 +1,40 @@ +# Model arguments +model_name_or_path: meta-llama/Llama-3.2-3B-Instruct +tokenizer_name_or_path: meta-llama/Llama-3.2-3B-Instruct +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true + +model_download_location: /opt/ml/input/model +dataset_local_location: /opt/ml/input/data/training_dataset/ +output_dir: /opt/ml/model/ #/opt/ml/output + +# Dataset arguments +dataset_id_or_path: /opt/ml/input/data/training_dataset/ +max_length: 2048 +packing: true + +# Training arguments +beta: 0.1 +max_length: 1536 +max_prompt_length: 768 +loss_type: sigmoid # default loss, alternatives: https://huggingface.co/docs/trl/dpo_trainer#loss-functions +num_train_epochs: 10 +per_device_train_batch_size: 8 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +learning_rate: 5.0e-7 +lr_scheduler_type: constant +warmup_ratio: 0.03 + +# Logging arguments +logging_strategy: steps +logging_steps: 5 +report_to: +- none +save_strategy: "no" +seed: 42 \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-llama-3.2-3b-instruct.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-llama-3.2-3b-instruct.yaml new file mode 100644 index 0000000..0d9e999 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-llama-3.2-3b-instruct.yaml @@ -0,0 +1,41 @@ +# Model arguments +model_name_or_path: meta-llama/Llama-3.2-3B-Instruct +tokenizer_name_or_path: meta-llama/Llama-3.2-3B-Instruct +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true +trust_remote_code: true + +model_download_location: /opt/ml/input/model +dataset_local_location: /opt/ml/input/data/training_dataset/ +output_dir: /opt/ml/model/ #/opt/ml/output + +# Dataset arguments +dataset_id_or_path: /opt/ml/input/data/training_dataset/ +max_length: 2048 +packing: true + +# Spectrum arguments +spectrum_config_path: /opt/ml/input/data/code/spectrum-layer/snr_results_meta-llama-Llama-3.2-3B-Instruct_unfrozenparameters_50percent.yaml + +# Training arguments +num_train_epochs: 10 +per_device_train_batch_size: 8 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +learning_rate: 5.0e-5 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +#weight_decay: 0.01 + +# Logging arguments +logging_strategy: steps +logging_steps: 5 +report_to: +- none +save_strategy: "no" # "epoch" +seed: 42 \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-qwen3-0.6b.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-qwen3-0.6b.yaml new file mode 100644 index 0000000..8a2dcb3 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-qwen3-0.6b.yaml @@ -0,0 +1,41 @@ +# Model arguments +model_name_or_path: Qwen/Qwen3-0.6B +tokenizer_name_or_path: Qwen/Qwen3-0.6B +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true +trust_remote_code: true + +model_download_location: 
/opt/ml/input/model +dataset_local_location: /opt/ml/input/data/training_dataset/ +output_dir: /opt/ml/model/ #/opt/ml/output + +# Dataset arguments +dataset_id_or_path: /opt/ml/input/data/training_dataset/ +max_length: 2048 +packing: true + +# Spectrum arguments +spectrum_config_path: /opt/ml/input/data/code/spectrum-layer/snr_results_Qwen-Qwen3-0.6B_unfrozenparameters_50percent.yaml + +# Training arguments +num_train_epochs: 10 +per_device_train_batch_size: 4 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +learning_rate: 5.0e-5 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +#weight_decay: 0.01 + +# Logging arguments +logging_strategy: steps +logging_steps: 5 +report_to: +- none +save_strategy: "no" # "epoch" +seed: 42 \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-qwen3-1.7b.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-qwen3-1.7b.yaml new file mode 100644 index 0000000..1dcf213 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/recipes/sft-spectrum-qwen3-1.7b.yaml @@ -0,0 +1,41 @@ +# Model arguments +model_name_or_path: Qwen/Qwen3-1.7B +tokenizer_name_or_path: Qwen/Qwen3-1.7B +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true +trust_remote_code: true + +model_download_location: /opt/ml/input/model +dataset_local_location: /opt/ml/input/data/training_dataset/ +output_dir: /opt/ml/model/ #/opt/ml/output + +# Dataset arguments +dataset_id_or_path: /opt/ml/input/data/training_dataset/ +max_length: 2048 +packing: true + +# Spectrum arguments +spectrum_config_path: /opt/ml/input/data/code/spectrum-layer/snr_results_Qwen-Qwen3-1.7B_unfrozenparameters_50percent.yaml + +# Training arguments +num_train_epochs: 10 +per_device_train_batch_size: 4 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +learning_rate: 5.0e-5 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +#weight_decay: 0.01 + +# Logging arguments +logging_strategy: steps +logging_steps: 5 +report_to: +- none +save_strategy: "no" # "epoch" +seed: 42 \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/requirements.txt b/3_distributed_training/function-calling-sft-dpo/scripts/requirements.txt new file mode 100644 index 0000000..4c7d130 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/requirements.txt @@ -0,0 +1,17 @@ +torch==2.6.0 +torchvision==0.21.0 +torchaudio==2.6.0 +accelerate==1.10.1 +peft==0.15.2 +transformers==4.51.1 +triton==3.2.0 +trl==0.18.0 +bitsandbytes==0.45.5 +deepspeed==0.16.4 +datasets==3.6.0 +liger-kernel==0.5.1 +flash-attn==2.7.3 +huggingface_hub +vllm +hf_transfer +protobuf==5.28.3 \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/run_dpo.py b/3_distributed_training/function-calling-sft-dpo/scripts/run_dpo.py new file mode 100644 index 0000000..0ad43a6 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/run_dpo.py @@ -0,0 +1,222 @@ +import logging +import os +import torch +from transformers import ( + AutoModelForCausalLM, + set_seed, +) +from dataclasses import dataclass +from datetime import datetime +from distutils.util import strtobool +import logging +import os +from typing import Optional + +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import torch 
+from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + set_seed, + BitsAndBytesConfig, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import is_liger_kernel_available +from trl import TrlParser, ModelConfig, get_peft_config +from datasets import load_dataset +from trl import ( + DPOTrainer, + DPOConfig, + TrlParser, + get_peft_config, + ModelConfig, +) + +from datasets import load_dataset + + +######################## +# Custom dataclasses +######################## +@dataclass +class ScriptArguments: + dataset_id_or_path: str + dataset_splits: str = "train" + tokenizer_name_or_path: str = "/opt/ml/model" + model_download_location: str = "/opt/ml/model" + dataset_local_location: str = "/opt/ml/input/training_data" + + +######################## +# Setup logging +######################## +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) +logger.addHandler(handler) + +######################## +# Helper functions +######################## + + +def get_checkpoint(training_args: DPOConfig): + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + return last_checkpoint + + +def dpo_function( + model_args: ModelConfig, script_args: ScriptArguments, training_args: DPOConfig +): + ######################### + # Log parameters + ######################### + logger.info(f"Model parameters {model_args}") + logger.info(f"Training/evaluation parameters {training_args}") + + ############### + # Load datasets + ############### + if script_args.dataset_id_or_path.endswith('.json'): + train_dataset = load_dataset( + 'json', data_files=script_args.dataset_local_location, split=script_args.dataset_splits + ) + else: + train_dataset = load_dataset( + script_args.dataset_local_location, split=script_args.dataset_splits + ) + + logger.info( + f'Loaded dataset with {len(train_dataset)} samples and the following features: {train_dataset.features}' + ) + + ################ + # Load tokenizer + ################ + tokenizer = AutoTokenizer.from_pretrained( + script_args.model_download_location, + trust_remote_code=model_args.trust_remote_code, + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ####################################### + # Load the model and/or reference model + ####################################### + + model_kwargs = dict( + revision=model_args.model_revision, # What revision from Huggingface to use, defaults to main + trust_remote_code=model_args.trust_remote_code, # Whether to trust the remote code, this also you to fine-tune custom architectures + attn_implementation=model_args.attn_implementation, # What attention implementation to use, defaults to flash_attention_2 + torch_dtype=( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ), # What torch dtype to use, defaults to auto + use_cache=False if training_args.gradient_checkpointing else True, # Whether + low_cpu_mem_usage=( + True + if not strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")) + else None + ), # Reduces memory usage on CPU for loading the model + ) + + # Check which training method to use and if 4-bit quantization is needed + if model_args.load_in_4bit: + 
model_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=model_kwargs["torch_dtype"], + bnb_4bit_quant_storage=model_kwargs["torch_dtype"], + ) + if model_args.use_peft: + peft_config = get_peft_config(model_args) + else: + peft_config = None + + # Policy Model + model = AutoModelForCausalLM.from_pretrained( + script_args.model_download_location, **model_kwargs + ) + # Checks wether we use adapters for reference model or not + if peft_config is None: + model_ref = AutoModelForCausalLM.from_pretrained( + script_args.model_download_location, **model_kwargs + ) + else: + model_ref = None + + ######################### + # Instantiate DPO trainer + ######################### + trainer = DPOTrainer( + model, + ref_model=model_ref, + args=training_args, + train_dataset=train_dataset, + processing_class=tokenizer, + peft_config=peft_config, + ) + + ############### + # Training loop + ############### + # Check for last checkpoint + last_checkpoint = get_checkpoint(training_args) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.") + + # Train the model + logger.info( + f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs***' + ) + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) + # Log and save metrics + metrics = train_result.metrics + metrics["train_samples"] = len(train_dataset) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + logger.info("*** Training complete ***") + + ################################## + # Save model and create model card + ################################## + + logger.info("*** Save model ***") + if trainer.is_fsdp_enabled and peft_config: + trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + # Restore k,v cache for fast inference + trainer.model.config.use_cache = True + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + training_args.distributed_state.wait_for_everyone() # wait for all processes to load + + tokenizer.save_pretrained(training_args.output_dir) + logger.info(f"Tokenizer saved to {training_args.output_dir}") + + logger.info("*** Training complete! 
***") + + +def main(): + parser = TrlParser((ModelConfig, ScriptArguments, DPOConfig)) + model_args, script_args, training_args = parser.parse_args_and_config() + + # Set seed for reproducibility + set_seed(training_args.seed) + + # Run the main training loop + dpo_function(model_args, script_args, training_args) + + +if __name__ == "__main__": + main() diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/run_sft.py b/3_distributed_training/function-calling-sft-dpo/scripts/run_sft.py new file mode 100644 index 0000000..dc9dfb1 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/run_sft.py @@ -0,0 +1,317 @@ +from dataclasses import dataclass +from datetime import datetime +from distutils.util import strtobool +import logging +import os +import re +from typing import Optional + +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + set_seed, + BitsAndBytesConfig, +) +from transformers.trainer_utils import get_last_checkpoint +#from transformers.utils import is_liger_kernel_available +from trl import SFTTrainer, TrlParser, ModelConfig, SFTConfig, get_peft_config +from trl import setup_chat_format +from datasets import load_dataset +from peft import AutoPeftModelForCausalLM + + +#if is_liger_kernel_available(): +# from liger_kernel.transformers import AutoLigerKernelForCausalLM + + +######################## +# Custom dataclasses +######################## +@dataclass +class ScriptArguments: + dataset_id_or_path: str + dataset_splits: str = "train" + tokenizer_name_or_path: str = "/opt/ml/model" + spectrum_config_path: Optional[str] = None + model_download_location: str = "/opt/ml/model" + dataset_local_location: str = "/opt/ml/input/training_data" + + +######################## +# Setup logging +######################## +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setFormatter( + logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +) +logger.addHandler(handler) + +######################## +# Helper functions +######################## + + +def get_checkpoint(training_args: SFTConfig): + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + return last_checkpoint + + +def setup_model_for_spectrum(model, spectrum_config_path): + unfrozen_parameters = [] + with open(spectrum_config_path, "r") as fin: + yaml_parameters = fin.read() + + # get the unfrozen parameters from the yaml file + for line in yaml_parameters.splitlines(): + if line.startswith("- "): + unfrozen_parameters.append(line.split("- ")[1]) + + # freeze all parameters + for param in model.parameters(): + param.requires_grad = False + # unfreeze Spectrum parameters + for name, param in model.named_parameters(): + if any( + re.match(unfrozen_param, name) + for unfrozen_param in unfrozen_parameters + ): + param.requires_grad = True + + # COMMENT IN: for sanity check print the trainable parameters + # for name, param in model.named_parameters(): + # if param.requires_grad: + # print(f"Trainable parameter: {name}") + + return model + + +########################################################################################################### + + +def train_function( + model_args: ModelConfig, + script_args: ScriptArguments, + training_args: SFTConfig, +): + """Main training function.""" + ######################### + 
# Log parameters + ######################### + logger.info(f'Model parameters {model_args}') + logger.info(f'Script parameters {script_args}') + logger.info(f'Training/evaluation parameters {training_args}') + + ############### + # Load datasets + ############### + if script_args.dataset_id_or_path.endswith('.json'): + train_dataset = load_dataset( + 'json', data_files=script_args.dataset_local_location, split=script_args.dataset_splits + ) + else: + train_dataset = load_dataset( + script_args.dataset_local_location, split=script_args.dataset_splits + ) + + logger.info( + f'Loaded dataset with {len(train_dataset)} samples and the following features: {train_dataset.features}' + ) + + ################ + # Load tokenizer + ################ + tokenizer = AutoTokenizer.from_pretrained( + script_args.model_download_location, + trust_remote_code=model_args.trust_remote_code, + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + # if we use peft we need to make sure we use a chat template that is not using special tokens as by default embedding layers will not be trainable + + ####################### + # Load pretrained model + ####################### + + # define model kwargs + model_kwargs = dict( + revision=model_args.model_revision, # What revision from Huggingface to use, defaults to main + trust_remote_code=model_args.trust_remote_code, # Whether to trust the remote code, this also you to fine-tune custom architectures + attn_implementation=model_args.attn_implementation, # What attention implementation to use, defaults to flash_attention_2 + torch_dtype=( + model_args.torch_dtype + if model_args.torch_dtype in ['auto', None] + else getattr(torch, model_args.torch_dtype) + ), # What torch dtype to use, defaults to auto + # use_cache=( + # False if training_args.gradient_checkpointing else True + # ), # Whether + low_cpu_mem_usage=( + True + if not strtobool( + os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") + ) + else None + ), # Reduces memory usage on CPU for loading the model + ) + + + # Check which training method to use and if 4-bit quantization is needed + if model_args.load_in_4bit: + model_kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4', + bnb_4bit_compute_dtype=model_kwargs['torch_dtype'], + bnb_4bit_quant_storage=model_kwargs['torch_dtype'], + ) + + if model_args.use_peft: + peft_config = get_peft_config(model_args) + else: + peft_config = None + + # load the model with our kwargs + # if training_args.use_liger_kernel: + # model = AutoLigerKernelForCausalLM.from_pretrained( + # script_args.model_download_location, **model_kwargs + # ) + + # else: + model = AutoModelForCausalLM.from_pretrained( + script_args.model_download_location, **model_kwargs + ) + + if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None: + tokenizer.chat_template = None # Reset the chat template + # # set chat template to OAI chatML, remove if you start from a fine-tuned model + model, tokenizer = setup_chat_format(model, tokenizer) + + training_args.distributed_state.wait_for_everyone() # wait for all processes to load + + if script_args.spectrum_config_path: + model = setup_model_for_spectrum( + model, script_args.spectrum_config_path + ) + + ######################## + # Initialize the Trainer + ######################## + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + processing_class=tokenizer, + peft_config=peft_config, 
+ # max_seq_length=training_args.max_seq_length, + # packing=training_args.packing, + # dataset_kwargs={ + # "add_special_tokens": False, # We template with special tokens + # "append_concat_token": False, # No need to add additional separator token + # } + ) + if trainer.accelerator.is_main_process and peft_config: + trainer.model.print_trainable_parameters() + + ############### + # Training loop + ############### + # Check for last checkpoint + last_checkpoint = get_checkpoint(training_args) + if ( + last_checkpoint is not None + and training_args.resume_from_checkpoint is None + ): + logger.info( + f'Checkpoint detected, resuming training at {last_checkpoint}.' + ) + + logger.info( + f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs***' + ) + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) + # log metrics + metrics = train_result.metrics + metrics['train_samples'] = len(train_dataset) + trainer.log_metrics('train', metrics) + trainer.save_metrics('train', metrics) + trainer.save_state() + + ################################## + # Save model and create model card + ################################## + + logger.info('*** Save model ***') + if trainer.is_fsdp_enabled and peft_config: + trainer.accelerator.state.fsdp_plugin.set_state_dict_type( + 'FULL_STATE_DICT' + ) + + if model_args.use_peft: + logger.info("Merge weights is set to True, will fuse the adapter to the base model") + + tmp_output_dir = "/tmp/model/output" + + # merge adapter weights with base model and save + # save int 4 model + logger.info(f"Temporarily saving adapter to {tmp_output_dir}") + trainer.model.config.use_cache = True + trainer.save_model(tmp_output_dir) + + if trainer.accelerator.is_main_process: + # clear memory + del model + del trainer + + torch.cuda.empty_cache() + + logger.info("Loading PEFT model") + # load PEFT model + model = AutoPeftModelForCausalLM.from_pretrained( + tmp_output_dir, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + + logger.info("Fusing adapter to base model") + # Merge LoRA and base model and save + model = model.merge_and_unload() + logger.info("Saving fused model") + model.save_pretrained( + training_args.output_dir, + safe_serialization=True + ) + else: + logger.info(f"Saving model to {training_args.output_dir}") + # Restore k,v cache for fast inference + trainer.model.config.use_cache = True + trainer.save_model(training_args.output_dir) + + training_args.distributed_state.wait_for_everyone() # wait for all processes to load + + tokenizer.save_pretrained(training_args.output_dir) + logger.info(f'Tokenizer saved to {training_args.output_dir}') + + logger.info('*** Training complete! 
***') + + +def main(): + parser = TrlParser((ModelConfig, ScriptArguments, SFTConfig)) + model_args, script_args, training_args = parser.parse_args_and_config() + + # Set seed for reproducibility + set_seed(training_args.seed) + + # Run the main training loop + train_function(model_args, script_args, training_args) + + +if __name__ == '__main__': + main() diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/run_training_dpo.sh b/3_distributed_training/function-calling-sft-dpo/scripts/run_training_dpo.sh new file mode 100755 index 0000000..f946f3a --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/run_training_dpo.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +export ACCELERATE_CONFIG="./accelerate_configs/deepspeed_zero3.yaml" + +echo "using ACCELERATE_CONFIG: $ACCELERATE_CONFIG" +echo "using Training Recipe: $training_recipe" + +pip install yq + +export MODEL_ID_OR_LOCATION=$(yq -r ".model_name_or_path" $training_recipe) +export MODEL_DOWNLOAD_LOCATION=$(yq -r ".model_download_location" $training_recipe) +export DATASET_LOCAL_LOCATION=$(yq -r ".dataset_local_location" $training_recipe) + +if [ -n "$HF_token" ]; then + huggingface-cli login --token ${HF_token} +fi + +#model_location env variable overrides MODEL_ID_OR_LOCATION +tmp_model_to_use="" + +if [ -n "$model_location" ]; then + tmp_model_to_use=$model_location +else + tmp_model_to_use=$MODEL_ID_OR_LOCATION +fi + +tmp_model_location="/opt/ml/tmp" + +mkdir -p $MODEL_DOWNLOAD_LOCATION + +# Check if the string ends with the suffix +if [[ "$tmp_model_to_use" == *"s3:"* ]]; then + echo "The model is an S3 location, downloading from '$tmp_model_to_use'" + + # Check if the string ends with the suffix + if [[ "$tmp_model_to_use" == *".tar.gz" ]]; then + echo "The model location '$tmp_model_to_use' ends with '.tar.gz'. Need to unpack." + mkdir -p $tmp_model_location + aws s3 cp $tmp_model_to_use $tmp_model_location + tar -xzvf "$tmp_model_location/model.tar.gz" -C $MODEL_DOWNLOAD_LOCATION + else + echo "The model location '$tmp_model_to_use' looks to be unpacked, copying directly." 
+ aws s3 cp $tmp_model_to_use $tmp_model_location --recursive + fi +else + echo "The model does not look to be an an S3 location, downloading '$tmp_model_to_use' from HuggingFace" + huggingface-cli download $tmp_model_to_use --local-dir $MODEL_DOWNLOAD_LOCATION +fi + +echo "Copying Data from '$data_location' to '$DATASET_LOCAL_LOCATION'" +aws s3 cp $data_location $DATASET_LOCAL_LOCATION + +NUM_GPUS=$(nvidia-smi --list-gpus | wc -l) +echo "Detected ${NUM_GPUS} GPUs on the machine" + +accelerate launch --config_file $ACCELERATE_CONFIG --num_processes ${NUM_GPUS} run_dpo.py --config $training_recipe \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/run_training_sft.sh b/3_distributed_training/function-calling-sft-dpo/scripts/run_training_sft.sh new file mode 100755 index 0000000..35d4071 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/run_training_sft.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +export ACCELERATE_CONFIG="./accelerate_configs/deepspeed_zero3.yaml" + +# export training_recipe="./recipes/sft-spectrum-Qwen3-1.7B.yaml" +# export data_location="s3://sagemaker-us-east-1-340043819279/datasets/nvidia_function_calling/train/dataset.json" + +echo "using ACCELERATE_CONFIG: $ACCELERATE_CONFIG" +echo "using Training Recipe: $training_recipe" + +pip install yq + +export MODEL_ID_OR_LOCATION=$(yq -r ".model_name_or_path" $training_recipe) +export MODEL_DOWNLOAD_LOCATION=$(yq -r ".model_download_location" $training_recipe) +export DATASET_LOCAL_LOCATION=$(yq -r ".dataset_local_location" $training_recipe) + +if [ -n "${HF_token}" ]; then + huggingface-cli login --token ${HF_token} +fi + +#model_location env variable overrides MODEL_ID_OR_LOCATION +tmp_model_to_use="" + +if [ -n "$model_location" ]; then + tmp_model_to_use=$model_location +else + tmp_model_to_use=$MODEL_ID_OR_LOCATION +fi + +tmp_model_location="/opt/ml/tmp" + +mkdir -p $MODEL_DOWNLOAD_LOCATION + +# Check if the string ends with the suffix +if [[ "$tmp_model_to_use" == "s3:"* ]]; then + echo "The model is an S3 location, downloading from '$tmp_model_to_use'" + + # Check if the string ends with the suffix + if [[ "$tmp_model_to_use" == *".tar.gz" ]]; then + echo "The model location '$tmp_model_to_use' ends with '.tar.gz'. Need to unpack." + mkdir -p $tmp_model_location + aws s3 cp $tmp_model_to_use $tmp_model_location + tar -xzvf "$tmp_model_location/model.tar.gz" -C $MODEL_DOWNLOAD_LOCATION + else + echo "The model location '$tmp_model_to_use' looks to be unpacked, copying directly." 
+ aws s3 cp $tmp_model_to_use $tmp_model_location --recursive + fi +else + echo "The model does not look to be an an S3 location, downloading '$tmp_model_to_use' from HuggingFace" + huggingface-cli download $tmp_model_to_use --local-dir $MODEL_DOWNLOAD_LOCATION +fi + +aws s3 cp $data_location $DATASET_LOCAL_LOCATION + +NUM_GPUS=$(nvidia-smi --list-gpus | wc -l) +echo "Detected ${NUM_GPUS} GPUs on the machine" + +accelerate launch --config_file $ACCELERATE_CONFIG --num_processes ${NUM_GPUS} run_sft.py --config $training_recipe \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/run-spect.sh b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/run-spect.sh new file mode 100755 index 0000000..547ecf8 --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/run-spect.sh @@ -0,0 +1,7 @@ +git clone https://github.com/cognitivecomputations/spectrum.git +cd spectrum +# generate yaml configuration +#python3 spectrum.py --model-name meta-llama/Meta-Llama-3.1-8B --top-percent 30 +python3 spectrum.py --model-name Qwen/Qwen3-1.7B --top-percent 10 +# Top 30% SNR layers saved to snr_results_meta-llama-Meta-Llama-3.1-8B_unfrozenparameters_30percent.yaml +cd .. \ No newline at end of file diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_Qwen-Qwen3-0.6B_unfrozenparameters_50percent.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_Qwen-Qwen3-0.6B_unfrozenparameters_50percent.yaml new file mode 100644 index 0000000..41b7abe --- /dev/null +++ b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_Qwen-Qwen3-0.6B_unfrozenparameters_50percent.yaml @@ -0,0 +1,171 @@ +unfrozen_parameters: +- ^lm_head.weight$ +- ^model.embed_tokens.weight$ +# input_layernorm layers +- model.layers.0.input_layernorm +- model.layers.1.input_layernorm +- model.layers.2.input_layernorm +- model.layers.3.input_layernorm +- model.layers.4.input_layernorm +- model.layers.5.input_layernorm +- model.layers.6.input_layernorm +- model.layers.7.input_layernorm +- model.layers.8.input_layernorm +- model.layers.9.input_layernorm +- model.layers.10.input_layernorm +- model.layers.11.input_layernorm +- model.layers.12.input_layernorm +- model.layers.13.input_layernorm +# lm_head layers +# mlp.down_proj layers +- model.layers.0.mlp.down_proj +- model.layers.1.mlp.down_proj +- model.layers.2.mlp.down_proj +- model.layers.3.mlp.down_proj +- model.layers.4.mlp.down_proj +- model.layers.5.mlp.down_proj +- model.layers.6.mlp.down_proj +- model.layers.7.mlp.down_proj +- model.layers.8.mlp.down_proj +- model.layers.9.mlp.down_proj +- model.layers.10.mlp.down_proj +- model.layers.11.mlp.down_proj +- model.layers.12.mlp.down_proj +- model.layers.13.mlp.down_proj +# mlp.gate_proj layers +- model.layers.0.mlp.gate_proj +- model.layers.1.mlp.gate_proj +- model.layers.2.mlp.gate_proj +- model.layers.3.mlp.gate_proj +- model.layers.4.mlp.gate_proj +- model.layers.5.mlp.gate_proj +- model.layers.6.mlp.gate_proj +- model.layers.7.mlp.gate_proj +- model.layers.8.mlp.gate_proj +- model.layers.9.mlp.gate_proj +- model.layers.10.mlp.gate_proj +- model.layers.11.mlp.gate_proj +- model.layers.12.mlp.gate_proj +- model.layers.13.mlp.gate_proj +# mlp.up_proj layers +- model.layers.0.mlp.up_proj +- model.layers.1.mlp.up_proj +- model.layers.2.mlp.up_proj +- model.layers.3.mlp.up_proj +- model.layers.4.mlp.up_proj +- model.layers.5.mlp.up_proj +- 
model.layers.6.mlp.up_proj +- model.layers.7.mlp.up_proj +- model.layers.8.mlp.up_proj +- model.layers.9.mlp.up_proj +- model.layers.10.mlp.up_proj +- model.layers.11.mlp.up_proj +- model.layers.12.mlp.up_proj +- model.layers.13.mlp.up_proj +# model.embed_tokens layers +# model.norm layers +# post_attention_layernorm layers +- model.layers.0.post_attention_layernorm +- model.layers.1.post_attention_layernorm +- model.layers.2.post_attention_layernorm +- model.layers.3.post_attention_layernorm +- model.layers.4.post_attention_layernorm +- model.layers.5.post_attention_layernorm +- model.layers.6.post_attention_layernorm +- model.layers.7.post_attention_layernorm +- model.layers.8.post_attention_layernorm +- model.layers.9.post_attention_layernorm +- model.layers.10.post_attention_layernorm +- model.layers.11.post_attention_layernorm +- model.layers.12.post_attention_layernorm +- model.layers.13.post_attention_layernorm +# self_attn.k_norm layers +- model.layers.0.self_attn.k_norm +- model.layers.1.self_attn.k_norm +- model.layers.2.self_attn.k_norm +- model.layers.3.self_attn.k_norm +- model.layers.4.self_attn.k_norm +- model.layers.5.self_attn.k_norm +- model.layers.6.self_attn.k_norm +- model.layers.7.self_attn.k_norm +- model.layers.8.self_attn.k_norm +- model.layers.9.self_attn.k_norm +- model.layers.10.self_attn.k_norm +- model.layers.11.self_attn.k_norm +- model.layers.12.self_attn.k_norm +- model.layers.13.self_attn.k_norm +# self_attn.k_proj layers +- model.layers.0.self_attn.k_proj +- model.layers.1.self_attn.k_proj +- model.layers.2.self_attn.k_proj +- model.layers.3.self_attn.k_proj +- model.layers.4.self_attn.k_proj +- model.layers.5.self_attn.k_proj +- model.layers.6.self_attn.k_proj +- model.layers.7.self_attn.k_proj +- model.layers.8.self_attn.k_proj +- model.layers.9.self_attn.k_proj +- model.layers.10.self_attn.k_proj +- model.layers.11.self_attn.k_proj +- model.layers.12.self_attn.k_proj +- model.layers.13.self_attn.k_proj +# self_attn.o_proj layers +- model.layers.0.self_attn.o_proj +- model.layers.1.self_attn.o_proj +- model.layers.2.self_attn.o_proj +- model.layers.3.self_attn.o_proj +- model.layers.4.self_attn.o_proj +- model.layers.5.self_attn.o_proj +- model.layers.6.self_attn.o_proj +- model.layers.7.self_attn.o_proj +- model.layers.8.self_attn.o_proj +- model.layers.9.self_attn.o_proj +- model.layers.10.self_attn.o_proj +- model.layers.11.self_attn.o_proj +- model.layers.12.self_attn.o_proj +- model.layers.13.self_attn.o_proj +# self_attn.q_norm layers +- model.layers.0.self_attn.q_norm +- model.layers.1.self_attn.q_norm +- model.layers.2.self_attn.q_norm +- model.layers.3.self_attn.q_norm +- model.layers.4.self_attn.q_norm +- model.layers.5.self_attn.q_norm +- model.layers.6.self_attn.q_norm +- model.layers.7.self_attn.q_norm +- model.layers.8.self_attn.q_norm +- model.layers.9.self_attn.q_norm +- model.layers.10.self_attn.q_norm +- model.layers.11.self_attn.q_norm +- model.layers.12.self_attn.q_norm +- model.layers.13.self_attn.q_norm +# self_attn.q_proj layers +- model.layers.0.self_attn.q_proj +- model.layers.1.self_attn.q_proj +- model.layers.2.self_attn.q_proj +- model.layers.3.self_attn.q_proj +- model.layers.4.self_attn.q_proj +- model.layers.5.self_attn.q_proj +- model.layers.6.self_attn.q_proj +- model.layers.7.self_attn.q_proj +- model.layers.8.self_attn.q_proj +- model.layers.9.self_attn.q_proj +- model.layers.10.self_attn.q_proj +- model.layers.11.self_attn.q_proj +- model.layers.12.self_attn.q_proj +- model.layers.13.self_attn.q_proj +# 
self_attn.v_proj layers
+- model.layers.0.self_attn.v_proj
+- model.layers.1.self_attn.v_proj
+- model.layers.2.self_attn.v_proj
+- model.layers.3.self_attn.v_proj
+- model.layers.4.self_attn.v_proj
+- model.layers.5.self_attn.v_proj
+- model.layers.6.self_attn.v_proj
+- model.layers.7.self_attn.v_proj
+- model.layers.8.self_attn.v_proj
+- model.layers.9.self_attn.v_proj
+- model.layers.10.self_attn.v_proj
+- model.layers.11.self_attn.v_proj
+- model.layers.12.self_attn.v_proj
+- model.layers.13.self_attn.v_proj
diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_Qwen-Qwen3-1.7B_unfrozenparameters_50percent.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_Qwen-Qwen3-1.7B_unfrozenparameters_50percent.yaml
new file mode 100644
index 0000000..152c85e
--- /dev/null
+++ b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_Qwen-Qwen3-1.7B_unfrozenparameters_50percent.yaml
@@ -0,0 +1,171 @@
+unfrozen_parameters:
+- ^lm_head.weight$
+- ^model.embed_tokens.weight$
+# input_layernorm layers
+- model.layers.0.input_layernorm
+- model.layers.1.input_layernorm
+- model.layers.2.input_layernorm
+- model.layers.3.input_layernorm
+- model.layers.4.input_layernorm
+- model.layers.5.input_layernorm
+- model.layers.6.input_layernorm
+- model.layers.7.input_layernorm
+- model.layers.8.input_layernorm
+- model.layers.9.input_layernorm
+- model.layers.10.input_layernorm
+- model.layers.11.input_layernorm
+- model.layers.12.input_layernorm
+- model.layers.13.input_layernorm
+# lm_head layers
+# mlp.down_proj layers
+- model.layers.0.mlp.down_proj
+- model.layers.2.mlp.down_proj
+- model.layers.3.mlp.down_proj
+- model.layers.4.mlp.down_proj
+- model.layers.5.mlp.down_proj
+- model.layers.27.mlp.down_proj
+- model.layers.26.mlp.down_proj
+- model.layers.6.mlp.down_proj
+- model.layers.25.mlp.down_proj
+- model.layers.21.mlp.down_proj
+- model.layers.1.mlp.down_proj
+- model.layers.7.mlp.down_proj
+- model.layers.16.mlp.down_proj
+- model.layers.22.mlp.down_proj
+# mlp.gate_proj layers
+- model.layers.27.mlp.gate_proj
+- model.layers.26.mlp.gate_proj
+- model.layers.5.mlp.gate_proj
+- model.layers.25.mlp.gate_proj
+- model.layers.3.mlp.gate_proj
+- model.layers.2.mlp.gate_proj
+- model.layers.6.mlp.gate_proj
+- model.layers.24.mlp.gate_proj
+- model.layers.15.mlp.gate_proj
+- model.layers.18.mlp.gate_proj
+- model.layers.14.mlp.gate_proj
+- model.layers.17.mlp.gate_proj
+- model.layers.12.mlp.gate_proj
+- model.layers.13.mlp.gate_proj
+# mlp.up_proj layers
+- model.layers.3.mlp.up_proj
+- model.layers.6.mlp.up_proj
+- model.layers.5.mlp.up_proj
+- model.layers.7.mlp.up_proj
+- model.layers.2.mlp.up_proj
+- model.layers.4.mlp.up_proj
+- model.layers.8.mlp.up_proj
+- model.layers.1.mlp.up_proj
+- model.layers.9.mlp.up_proj
+- model.layers.12.mlp.up_proj
+- model.layers.10.mlp.up_proj
+- model.layers.14.mlp.up_proj
+- model.layers.13.mlp.up_proj
+- model.layers.11.mlp.up_proj
+# model.embed_tokens layers
+# model.norm layers
+# post_attention_layernorm layers
+- model.layers.0.post_attention_layernorm
+- model.layers.1.post_attention_layernorm
+- model.layers.2.post_attention_layernorm
+- model.layers.3.post_attention_layernorm
+- model.layers.4.post_attention_layernorm
+- model.layers.5.post_attention_layernorm
+- model.layers.6.post_attention_layernorm
+- model.layers.7.post_attention_layernorm
+- model.layers.8.post_attention_layernorm
+- model.layers.9.post_attention_layernorm
+- model.layers.10.post_attention_layernorm
+- model.layers.11.post_attention_layernorm
+- model.layers.12.post_attention_layernorm
+- model.layers.13.post_attention_layernorm
+# self_attn.k_norm layers
+- model.layers.0.self_attn.k_norm
+- model.layers.1.self_attn.k_norm
+- model.layers.2.self_attn.k_norm
+- model.layers.3.self_attn.k_norm
+- model.layers.4.self_attn.k_norm
+- model.layers.5.self_attn.k_norm
+- model.layers.6.self_attn.k_norm
+- model.layers.7.self_attn.k_norm
+- model.layers.8.self_attn.k_norm
+- model.layers.9.self_attn.k_norm
+- model.layers.10.self_attn.k_norm
+- model.layers.11.self_attn.k_norm
+- model.layers.12.self_attn.k_norm
+- model.layers.13.self_attn.k_norm
+# self_attn.k_proj layers
+- model.layers.27.self_attn.k_proj
+- model.layers.23.self_attn.k_proj
+- model.layers.3.self_attn.k_proj
+- model.layers.9.self_attn.k_proj
+- model.layers.18.self_attn.k_proj
+- model.layers.19.self_attn.k_proj
+- model.layers.7.self_attn.k_proj
+- model.layers.5.self_attn.k_proj
+- model.layers.1.self_attn.k_proj
+- model.layers.22.self_attn.k_proj
+- model.layers.17.self_attn.k_proj
+- model.layers.20.self_attn.k_proj
+- model.layers.15.self_attn.k_proj
+- model.layers.21.self_attn.k_proj
+# self_attn.o_proj layers
+- model.layers.0.self_attn.o_proj
+- model.layers.19.self_attn.o_proj
+- model.layers.14.self_attn.o_proj
+- model.layers.18.self_attn.o_proj
+- model.layers.17.self_attn.o_proj
+- model.layers.15.self_attn.o_proj
+- model.layers.1.self_attn.o_proj
+- model.layers.16.self_attn.o_proj
+- model.layers.13.self_attn.o_proj
+- model.layers.12.self_attn.o_proj
+- model.layers.20.self_attn.o_proj
+- model.layers.7.self_attn.o_proj
+- model.layers.5.self_attn.o_proj
+- model.layers.11.self_attn.o_proj
+# self_attn.q_norm layers
+- model.layers.0.self_attn.q_norm
+- model.layers.1.self_attn.q_norm
+- model.layers.2.self_attn.q_norm
+- model.layers.3.self_attn.q_norm
+- model.layers.4.self_attn.q_norm
+- model.layers.5.self_attn.q_norm
+- model.layers.6.self_attn.q_norm
+- model.layers.7.self_attn.q_norm
+- model.layers.8.self_attn.q_norm
+- model.layers.9.self_attn.q_norm
+- model.layers.10.self_attn.q_norm
+- model.layers.11.self_attn.q_norm
+- model.layers.12.self_attn.q_norm
+- model.layers.13.self_attn.q_norm
+# self_attn.q_proj layers
+- model.layers.15.self_attn.q_proj
+- model.layers.14.self_attn.q_proj
+- model.layers.22.self_attn.q_proj
+- model.layers.16.self_attn.q_proj
+- model.layers.7.self_attn.q_proj
+- model.layers.18.self_attn.q_proj
+- model.layers.19.self_attn.q_proj
+- model.layers.10.self_attn.q_proj
+- model.layers.21.self_attn.q_proj
+- model.layers.17.self_attn.q_proj
+- model.layers.4.self_attn.q_proj
+- model.layers.13.self_attn.q_proj
+- model.layers.23.self_attn.q_proj
+- model.layers.9.self_attn.q_proj
+# self_attn.v_proj layers
+- model.layers.0.self_attn.v_proj
+- model.layers.27.self_attn.v_proj
+- model.layers.4.self_attn.v_proj
+- model.layers.18.self_attn.v_proj
+- model.layers.5.self_attn.v_proj
+- model.layers.3.self_attn.v_proj
+- model.layers.1.self_attn.v_proj
+- model.layers.7.self_attn.v_proj
+- model.layers.17.self_attn.v_proj
+- model.layers.6.self_attn.v_proj
+- model.layers.23.self_attn.v_proj
+- model.layers.15.self_attn.v_proj
+- model.layers.19.self_attn.v_proj
+- model.layers.9.self_attn.v_proj
diff --git a/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_meta-llama-Llama-3.2-3B-Instruct_unfrozenparameters_50percent.yaml b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_meta-llama-Llama-3.2-3B-Instruct_unfrozenparameters_50percent.yaml
new file mode 100644
index 0000000..b3ee058
--- /dev/null
+++ b/3_distributed_training/function-calling-sft-dpo/scripts/spectrum-layer/snr_results_meta-llama-Llama-3.2-3B-Instruct_unfrozenparameters_50percent.yaml
@@ -0,0 +1,141 @@
+unfrozen_parameters:
+- ^lm_head.weight$
+- ^model.embed_tokens.weight$
+# input_layernorm layers
+- model.layers.0.input_layernorm
+- model.layers.1.input_layernorm
+- model.layers.2.input_layernorm
+- model.layers.3.input_layernorm
+- model.layers.4.input_layernorm
+- model.layers.5.input_layernorm
+- model.layers.6.input_layernorm
+- model.layers.7.input_layernorm
+- model.layers.8.input_layernorm
+- model.layers.9.input_layernorm
+- model.layers.10.input_layernorm
+- model.layers.11.input_layernorm
+- model.layers.12.input_layernorm
+- model.layers.13.input_layernorm
+# lm_head layers
+# mlp.down_proj layers
+- model.layers.0.mlp.down_proj
+- model.layers.1.mlp.down_proj
+- model.layers.17.mlp.down_proj
+- model.layers.19.mlp.down_proj
+- model.layers.18.mlp.down_proj
+- model.layers.20.mlp.down_proj
+- model.layers.5.mlp.down_proj
+- model.layers.4.mlp.down_proj
+- model.layers.2.mlp.down_proj
+- model.layers.6.mlp.down_proj
+- model.layers.3.mlp.down_proj
+- model.layers.16.mlp.down_proj
+- model.layers.15.mlp.down_proj
+- model.layers.13.mlp.down_proj
+# mlp.gate_proj layers
+- model.layers.0.mlp.gate_proj
+- model.layers.1.mlp.gate_proj
+- model.layers.2.mlp.gate_proj
+- model.layers.3.mlp.gate_proj
+- model.layers.22.mlp.gate_proj
+- model.layers.21.mlp.gate_proj
+- model.layers.20.mlp.gate_proj
+- model.layers.23.mlp.gate_proj
+- model.layers.19.mlp.gate_proj
+- model.layers.4.mlp.gate_proj
+- model.layers.18.mlp.gate_proj
+- model.layers.5.mlp.gate_proj
+- model.layers.17.mlp.gate_proj
+- model.layers.24.mlp.gate_proj
+# mlp.up_proj layers
+- model.layers.4.mlp.up_proj
+- model.layers.3.mlp.up_proj
+- model.layers.5.mlp.up_proj
+- model.layers.6.mlp.up_proj
+- model.layers.7.mlp.up_proj
+- model.layers.2.mlp.up_proj
+- model.layers.8.mlp.up_proj
+- model.layers.14.mlp.up_proj
+- model.layers.13.mlp.up_proj
+- model.layers.11.mlp.up_proj
+- model.layers.9.mlp.up_proj
+- model.layers.1.mlp.up_proj
+- model.layers.15.mlp.up_proj
+- model.layers.12.mlp.up_proj
+# model.embed_tokens layers
+# model.norm layers
+# post_attention_layernorm layers
+- model.layers.0.post_attention_layernorm
+- model.layers.1.post_attention_layernorm
+- model.layers.2.post_attention_layernorm
+- model.layers.3.post_attention_layernorm
+- model.layers.4.post_attention_layernorm
+- model.layers.5.post_attention_layernorm
+- model.layers.6.post_attention_layernorm
+- model.layers.7.post_attention_layernorm
+- model.layers.8.post_attention_layernorm
+- model.layers.9.post_attention_layernorm
+- model.layers.10.post_attention_layernorm
+- model.layers.11.post_attention_layernorm
+- model.layers.12.post_attention_layernorm
+- model.layers.13.post_attention_layernorm
+# self_attn.k_proj layers
+- model.layers.25.self_attn.k_proj
+- model.layers.22.self_attn.k_proj
+- model.layers.19.self_attn.k_proj
+- model.layers.24.self_attn.k_proj
+- model.layers.20.self_attn.k_proj
+- model.layers.17.self_attn.k_proj
+- model.layers.23.self_attn.k_proj
+- model.layers.18.self_attn.k_proj
+- model.layers.21.self_attn.k_proj
+- model.layers.27.self_attn.k_proj
+- model.layers.10.self_attn.k_proj
+- model.layers.15.self_attn.k_proj
+- model.layers.26.self_attn.k_proj
+- model.layers.16.self_attn.k_proj
+# self_attn.o_proj layers
+- model.layers.13.self_attn.o_proj
+- model.layers.7.self_attn.o_proj
+- model.layers.12.self_attn.o_proj
+- model.layers.5.self_attn.o_proj
+- model.layers.21.self_attn.o_proj
+- model.layers.10.self_attn.o_proj
+- model.layers.6.self_attn.o_proj
+- model.layers.19.self_attn.o_proj
+- model.layers.8.self_attn.o_proj
+- model.layers.20.self_attn.o_proj
+- model.layers.22.self_attn.o_proj
+- model.layers.9.self_attn.o_proj
+- model.layers.17.self_attn.o_proj
+- model.layers.11.self_attn.o_proj
+# self_attn.q_proj layers
+- model.layers.12.self_attn.q_proj
+- model.layers.13.self_attn.q_proj
+- model.layers.9.self_attn.q_proj
+- model.layers.8.self_attn.q_proj
+- model.layers.10.self_attn.q_proj
+- model.layers.14.self_attn.q_proj
+- model.layers.11.self_attn.q_proj
+- model.layers.15.self_attn.q_proj
+- model.layers.26.self_attn.q_proj
+- model.layers.6.self_attn.q_proj
+- model.layers.25.self_attn.q_proj
+- model.layers.16.self_attn.q_proj
+- model.layers.5.self_attn.q_proj
+- model.layers.7.self_attn.q_proj
+# self_attn.v_proj layers
+- model.layers.23.self_attn.v_proj
+- model.layers.14.self_attn.v_proj
+- model.layers.15.self_attn.v_proj
+- model.layers.19.self_attn.v_proj
+- model.layers.3.self_attn.v_proj
+- model.layers.18.self_attn.v_proj
+- model.layers.25.self_attn.v_proj
+- model.layers.4.self_attn.v_proj
+- model.layers.17.self_attn.v_proj
+- model.layers.20.self_attn.v_proj
+- model.layers.22.self_attn.v_proj
+- model.layers.13.self_attn.v_proj
+- model.layers.5.self_attn.v_proj
+- model.layers.27.self_attn.v_proj
diff --git a/3_distributed_training/function-calling-sft-dpo/utils.py b/3_distributed_training/function-calling-sft-dpo/utils.py
new file mode 100644
index 0000000..9afd196
--- /dev/null
+++ b/3_distributed_training/function-calling-sft-dpo/utils.py
@@ -0,0 +1,54 @@
+import boto3
+
+def get_last_job_name(job_name_prefix):
+    sagemaker_client = boto3.client('sagemaker')
+
+    matching_jobs = []
+    next_token = None
+
+    while True:
+        # Prepare the search parameters
+        search_params = {
+            'Resource': 'TrainingJob',
+            'SearchExpression': {
+                'Filters': [
+                    {
+                        'Name': 'TrainingJobName',
+                        'Operator': 'Contains',
+                        'Value': job_name_prefix
+                    },
+                    {
+                        'Name': 'TrainingJobStatus',
+                        'Operator': 'Equals',
+                        'Value': "Completed"
+                    }
+                ]
+            },
+            'SortBy': 'CreationTime',
+            'SortOrder': 'Descending',
+            'MaxResults': 100
+        }
+
+        # Add NextToken if we have one
+        if next_token:
+            search_params['NextToken'] = next_token
+
+        # Make the search request
+        search_response = sagemaker_client.search(**search_params)
+
+        # Filter and add matching jobs
+        matching_jobs.extend([
+            job['TrainingJob']['TrainingJobName']
+            for job in search_response['Results']
+            if job['TrainingJob']['TrainingJobName'].startswith(job_name_prefix)
+        ])
+
+        # Check if we have more results to fetch
+        next_token = search_response.get('NextToken')
+        if not next_token or matching_jobs:  # Stop if we found at least one match or no more results
+            break
+
+    if not matching_jobs:
+        raise ValueError(f"No completed training jobs found starting with prefix '{job_name_prefix}'")
+
+    return matching_jobs[0]
\ No newline at end of file
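
The spectrum-layer YAML files above only list which parameter names remain trainable. The sketch below shows one way such a file could be consumed, assuming a Spectrum-style training script that freezes every other parameter; the file path, model ID, and requires_grad loop are illustrative assumptions, not code from this repository.

import re

import yaml
from transformers import AutoModelForCausalLM

# Hypothetical path: pick the YAML that matches your base model.
SPECTRUM_CONFIG = "scripts/spectrum-layer/snr_results_Qwen-Qwen3-1.7B_unfrozenparameters_50percent.yaml"

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")

with open(SPECTRUM_CONFIG) as f:
    # The file is plain YAML with a single "unfrozen_parameters" list of regex patterns.
    patterns = [re.compile(p) for p in yaml.safe_load(f)["unfrozen_parameters"]]

for name, param in model.named_parameters():
    # Freeze everything except parameters whose names match an unfrozen pattern.
    param.requires_grad = any(p.search(name) for p in patterns)

Similarly, a minimal usage sketch for the get_last_job_name helper in utils.py: look up the most recent completed training job for a prefix, then read its model artifact location with the SageMaker describe_training_job API. The job name prefix here is a placeholder, not one defined in this repository.

import boto3

from utils import get_last_job_name

# Hypothetical prefix: use whatever prefix you gave your SFT or DPO training jobs.
job_prefix = "qwen3-sft-when2call"

last_job = get_last_job_name(job_prefix)

# Retrieve the S3 location of the completed job's model artifacts.
desc = boto3.client("sagemaker").describe_training_job(TrainingJobName=last_job)
model_s3_uri = desc["ModelArtifacts"]["S3ModelArtifacts"]
print(last_job, model_s3_uri)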