{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "A100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "aCl-IzLoDr2H" }, "outputs": [], "source": [ "!pip install -U transformers mamba-ssm" ] }, { "cell_type": "markdown", "source": [ "# Load Models" ], "metadata": { "id": "SpRo_KJIRsxv" } }, { "cell_type": "code", "source": [ "import torch\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "\n", "# Load tokenizer and model\n", "tokenizer = AutoTokenizer.from_pretrained(\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\")\n", "model = AutoModelForCausalLM.from_pretrained(\n", " \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\",\n", " torch_dtype=torch.bfloat16,\n", " trust_remote_code=True,\n", " device_map=\"auto\"\n", ")\n" ], "metadata": { "id": "waveliieEI1n" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Define Input with Tools" ], "metadata": { "id": "xjVkqaSdRx0_" } }, { "cell_type": "code", "source": [ "from transformers.utils import get_json_schema\n", "\n", "def multiply(a: float, b: float):\n", " \"\"\"\n", " A function that multiplies two numbers\n", "\n", " Args:\n", " a: The first number to multiply\n", " b: The second number to multiply\n", " \"\"\"\n", " return a * b\n", "\n", "messages = [\n", " {\"role\": \"user\", \"content\": \"what is 2.0909090923 x 0.897987987\"},\n", "]\n", "\n", "tokenized_chat = tokenizer.apply_chat_template(\n", " messages,\n", " tools=[\n", " multiply\n", " ],\n", " tokenize=True,\n", " add_generation_prompt=True,\n", " return_tensors=\"pt\"\n", ").to(model.device)\n" ], "metadata": { "id": "zxZZ7iMZETsw" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Inference" ], "metadata": { "id": "SVBAG3dLRw4v" } }, { "cell_type": "code", "source": [ "outputs = model.generate(\n", " tokenized_chat,\n", " max_new_tokens=1024,\n", " temperature=1.0,\n", " top_p=1.0,\n", " eos_token_id=tokenizer.eos_token_id\n", ")\n", "print(tokenizer.decode(outputs[0]))" ], "metadata": { "id": "BKYqPT5ORDx3" }, "execution_count": null, "outputs": [] } ] }