~/Projects/ChatGLM3
git clone https://code.lsong.org/ChatGLM3
Commit
- Commit
- 25728bdd6b8bd7e6ebe3cccae66f613a4ca99866
- Author
- zR <[email protected]>
- Date
- 2023-11-23 21:39:07 +0800 +0800
- Diffstat
cookbook/accurate_prompt.ipynb | 334 ++++++++++++++++++++++++++++++++++++
Create accurate_prompt.ipynb
diff --git a/cookbook/accurate_prompt.ipynb b/cookbook/accurate_prompt.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d7a59c5cfa1309e6fe80c0de51629fa15692d16a --- /dev/null +++ b/cookbook/accurate_prompt.ipynb @@ -0,0 +1,334 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# 使用精确的提示词完成符合格式输出的文本分类任务\n", + "\n", + "在本章节中,我们将带领开发者体验在 `文本分类` 任务中,在不进行任何模型微调的情况下,使用具有微小差异的提示词,并使用 `ChatGLM3-6B` 模型来完成符合格式输出的文本分类任务。\n", + "我们使用 [新闻标题分类](https://github.com/fateleak/toutiao-multilevel-text-classfication-dataset) 任务来体验模型的表现。这是一个经典的文本分类任务,我们将使用 `新闻信息` 作为输入,模型需要预测出这个标题属于哪个类别。\n", + "由于 `ChatGLM3-6B` 强大的能力,我们可以直接使用 `新闻标题` 作为输入,而不需要额外的信息,也不需要进行任何模型微调,就能完成这个任务。\n", + "我们的目标是,让模型能够成功输出原始数据集中的15种类别中的一种作为分类结果,而且不能输入任何冗余的对话。\n", + "在本章节中,用户将直观的对比两种不同细粒度的提示词下对模型分类造成的影响。\n", + "## 硬件要求\n", + "本实践手册需要使用 FP16 精度的模型进行推理,因此,我们推荐使用至少 16GB 显存的 英伟达 GPU 来完成本实践手册。" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "引入必要的库" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 19, + "outputs": [], + "source": [ + "from transformers import AutoModel, AutoTokenizer\n", + "import torch" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:03.002351Z", + "end_time": "2023-11-23T19:23:03.016701Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "设置好对应的参数,以保证模型推理的公平性。" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 20, + "outputs": [], + "source": [ + "max_new_tokens = 1024\n", + "temperature = 0.1\n", + "top_p = 0.9\n", + "device = \"cuda\"\n", + "model_path_chat = \"/Models/chatglm3-6b\"" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:03.004850Z", + "end_time": "2023-11-23T19:23:03.047154Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 21, + "outputs": [ + { + "data": { + "text/plain": "Loading checkpoint shards: 0%| | 0/7 [00:00<?, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "d6468f7889994e638ac754fde95b6e58" + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tokenizer_path_chat = model_path_chat\n", + "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_chat, trust_remote_code=True, encode_special_tokens=True)\n", + "model = AutoModel.from_pretrained(model_path_chat, load_in_8bit=False, trust_remote_code=True).to(device)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:03.047154Z", + "end_time": "2023-11-23T19:23:13.086148Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 22, + "outputs": [], + "source": [ + "def answer(prompt):\n", + " inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n", + " response = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=max_new_tokens, history=None,\n", + " temperature=temperature,\n", + " top_p=top_p, do_sample=True)\n", + " response = response[0, inputs[\"input_ids\"].shape[-1]:]\n", + " answer = tokenizer.decode(response, skip_special_tokens=True)\n", + " return answer" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:13.086148Z", + "end_time": "2023-11-23T19:23:13.086148Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "在样例中的标准答案应该为:\n", + "news_sports" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 23, + "outputs": [], + "source": [ + "PROMPT1 = \"\"\"\n", + "<|system|>\n", + "你是一个专业的新闻专家,请根据我提供的新闻信息,包括新闻标题,新闻关键词等信息,你需要在这些类中选择其中一个,它们分别是:\n", + "news_story\n", + "news_culture\n", + "news_sports\n", + "news_finance\n", + "news_house\n", + "news_car\n", + "news_edu\n", + "news_tech\n", + "news_military\n", + "news_travel\n", + "news_world\n", + "stock\n", + "news_agriculture\n", + "news_game\n", + "这是我的信息:\n", + "<|user|>\n", + "新闻标题: 女乒今天排兵布阵不合理,丁宁昨天刚打硬仗,今天应该打第一单打,你认为呢?\n", + "新闻关键词: 无\n", + "<|assistant|>\n", + "\"\"\"" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:13.086148Z", + "end_time": "2023-11-23T19:23:13.095864Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 24, + "outputs": [ + { + "data": { + "text/plain": "' 新闻类型:体育新闻\\n 新闻关键词:女乒,排兵布阵,丁宁,硬仗,第一单打'" + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "answer(PROMPT1)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:13.095864Z", + "end_time": "2023-11-23T19:23:14.034154Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 25, + "outputs": [], + "source": [ + "PROMPT2 = \"\"\"\n", + "<|system|>\n", + "请根据我提供的新闻信息,格式为\n", + "```\n", + "新闻标题: xxx\n", + "新闻关键词: xxx\n", + "```\n", + "你要对每一行新闻类别进行分类并告诉我结果,不要返回其他信息和多于的文字,这些类别是:\n", + "news_story\n", + "news_culture\n", + "news_sports\n", + "news_finance\n", + "news_house\n", + "news_car\n", + "news_edu\n", + "news_tech\n", + "news_military\n", + "news_travel\n", + "news_world\n", + "stock\n", + "news_agriculture\n", + "news_game\n", + "我将为你提供一些新闻标题和关键词,你需要根据这些信息,对这些新闻进行分类,每条一行,格式为\n", + "```\n", + "类别中的其中一个\n", + "```\n", + "不要返回其他内容,例如新闻标题,新闻关键词等等,只需要返回分类结果即可\n", + "<|user|>\n", + "新闻标题: 女乒今天排兵布阵不合理,丁宁昨天刚打硬仗,今天应该打第一单打,你认为呢?\n", + "新闻关键词: 无\n", + "<|assistant|>\n", + "\"\"\"" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:14.044595Z", + "end_time": "2023-11-23T19:23:14.044595Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 26, + "outputs": [ + { + "data": { + "text/plain": "'news_sports'" + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "answer(PROMPT2)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:14.044595Z", + "end_time": "2023-11-23T19:23:14.217801Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "删除模型" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 27, + "outputs": [], + "source": [ + "del model\n", + "torch.cuda.empty_cache()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-23T19:23:14.217801Z", + "end_time": "2023-11-23T19:23:14.373137Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 总结\n", + "在本实践手册中,我们让开发者体验了使用不同提示词下,`ChatGLM3-6B` 模型在 `新闻标题分类` 任务中的表现。\n", + "在使用第二个提示词时,模型的输出符合格式要求,而且没有冗余的对话。\n", + "因此,如果对于有格式要求的任务,可以使用更具有格式化的提示词来完成任务。同时,使用更低的 `temperature` 和更高的 `top_p` 可以提高模型的输出质量。\n", + "对于格式化的数据,由于模型在训练中大量使用MarkDown,因此,用 ``` 符号来规范提示词将能提升模型输出格式化数据的准确率\n", + "希望各位开发者在大模型的世界里玩的开心!" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}