Copyright 2025 Google LLC | Apache License 2.0

Fine-tuning do TxGemma com Hugging Face

Colab Executar no Google Colab GitHub Ver no GitHub Hugging Face Ver no Hugging Face

Este notebook demonstra como fazer o ajuste fino (fine-tuning) dos modelos TxGemma para generalizar novas tarefas de desenvolvimento terapêutico usando as bibliotecas do Hugging Face.

A demonstração utiliza a biblioteca Transformer Reinforcement Learning (TRL) do Hugging Face para treinar o modelo com Supervised Fine-Tuning (SFT), utilizando Parameter-Efficient Fine-Tuning (PEFT) com Low-Rank Adaptation (LoRA) para reduzir os custos computacionais. Os dados de treinamento incluem um subconjunto do dataset TrialBench para ajustar o TxGemma na previsão de eventos adversos em ensaios clínicos.

Configuração

Para completar este tutorial, você precisará de um ambiente Colab com recursos suficientes para realizar o ajuste fino e executar o modelo TxGemma. Neste caso, você pode usar uma GPU T4:

  1. No canto superior direito da janela do Colab, selecione ▾ (Opções de conexão adicionais).
  2. Selecione Alterar tipo de tempo de execução.
  3. Em Acelerador de hardware, selecione T4 GPU.

Obter acesso ao TxGemma

Antes de começar, certifique-se de ter acesso aos modelos TxGemma no Hugging Face:

  1. Se você ainda não tem uma conta no Hugging Face, pode criar uma gratuitamente clicando aqui.
  2. Vá até a página do modelo TxGemma e aceite as condições de uso.

Configurar seu token HF

Gere um token de acesso de leitura (read) do Hugging Face clicando aqui e adicione seu token ao gerenciador de Segredos (Secrets) do Colab para armazená-lo com segurança.

  1. Abra seu notebook no Google Colab e clique na aba 🔑 Segredos no painel esquerdo.
  2. Crie um novo segredo com o nome HF_TOKEN.
  3. Copie/cole sua chave de token na caixa de entrada Valor de HF_TOKEN.
  4. Ative o botão à esquerda para permitir o acesso do notebook ao segredo.
import os
from google.colab import userdata
# Nota: `userdata.get` é uma API do Colab. Se você não estiver usando o Colab, 
# defina as variáveis de ambiente conforme apropriado para seu sistema.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

Instalar dependências

! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl

Carregar modelo do Hugging Face Hub

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/txgemma-2b-predict"

# Use quantização de 4 bits para reduzir o uso de memória
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

Carregar dataset

Este notebook utiliza dados de previsão de eventos adversos do TrialBench para ajustar o TxGemma. O conjunto de dados foi pré-processado em um formato de ajuste de instrução (instruction-tuning) e está disponível no Cloud Storage.

Carregue o conjunto de dados usando a biblioteca datasets do Hugging Face.

from datasets import load_dataset

! wget -nc https://storage.googleapis.com/healthai-us/txgemma/datasets/trialbench_adverse-event-rate-prediction_train.jsonl
data = load_dataset(
    "json",
    data_files="/content/trialbench_adverse-event-rate-prediction_train.jsonl",
    split="train",
)

# Exibir detalhes do dataset
data

Cada ponto de dado inclui:

Defina uma função que formata corretamente cada exemplo no conjunto de dados. Em uma seção posterior, ela será passada para o SFTTrainer, que aplica a função de formatação ao conjunto de dados antes da tokenização.

def formatting_func(example):
    text = f"{example['input_text']} {example['output_text']}<eos>"
    return text

# Exibir exemplo de dados de treinamento formatados
print(formatting_func(data[0]))

Experimentar o modelo pré-treinado

Solicite ao modelo pré-treinado para ver como ele se sai em uma tarefa de amostra de previsão de evento adverso. Antes do ajuste fino, o modelo não entende a instrução e fornece uma resposta inadequada.

prompt = "From the following information about a clinical trial, predict whether it would have an adverse event.\n\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Ajuste fino (Fine-tune) do modelo com LoRA

O ajuste fino tradicional de grandes modelos de linguagem (LLMs) consome muitos recursos porque requer o ajuste de bilhões de parâmetros. O PEFT (Parameter-Efficient Fine-Tuning) aborda isso treinando um número menor de parâmetros, usando técnicas como Low-Rank Adaptation (LoRA). O LoRA adapta eficientemente grandes modelos de linguagem treinando pequenas matrizes de baixa classificação que são adicionadas ao modelo original, em vez de atualizar as matrizes de peso total.

Primeiro, defina a LoraConfig, incluindo a classificação das matrizes de adaptação e as camadas do modelo onde adicionar os adaptadores LoRA.

from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

Prepare o modelo para treinamento.

from peft import prepare_model_for_kbit_training, get_peft_model

# Pré-processar modelo quantizado para treinamento
model = prepare_model_for_kbit_training(model)

# Criar PeftModel a partir do modelo quantizado e configuração
model = get_peft_model(model, lora_config)

Este exemplo usa o método Supervised Fine-Tuning (SFT) para treinar o modelo TxGemma.

Aqui, construímos o SFTTrainer que lida com o loop de treinamento completo. Especifique a configuração LoRA e a função de formatação do conjunto de dados definidas anteriormente e o SFTConfig com parâmetros de treinamento.

import transformers
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=data,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=50,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=5,
        max_seq_length=512,
        output_dir="/content/outputs",
        optim="paged_adamw_8bit",
        report_to="none",
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)

# Iniciar processo de fine-tuning
trainer.train()

Testar o modelo ajustado

Solicite ao modelo ajustado para ver como ele se sai em uma tarefa de amostra de previsão de evento adverso. Após o ajuste fino, o modelo aprendeu a responder com uma resposta apropriada ao prompt.

prompt = "From the following information about a clinical trial, predict whether it would have an adverse event.\n\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Aviso de Licença: Conteúdo original Copyright 2025 Google LLC. Licenciado sob Apache License, Version 2.0. Tradução não oficial para fins educacionais.