
TextRL - use reinforcement learning to adjust text generation results.


TextRL

Text generation with reinforcement learning, built on Hugging Face's transformers library.

Introduction

This project uses reinforcement learning to adjust text generation results. It works with any text-generation model from Hugging Face's transformers library, together with PFRL and OpenAI Gym.

Installation

pip install

pip install textrl

Build from source

Clone the repository with git, cd into the project directory, then run:

pip install -e .

Usage

initialize agent and environment

from textrl import TextRLEnv,TextRLActor

from transformers import AutoTokenizer, AutoModelWithLMHead  
tokenizer = AutoTokenizer.from_pretrained("any models")  
model = AutoModelWithLMHead.from_pretrained("any models")
model.eval()

set up reward function for environment

  • predicted_list (list[str]): the list of predicted tokens
  • finish (bool): whether the end of the sentence has been reached
class MyRLEnv(TextRLEnv):
    def get_reward(self, input_text, predicted_list, finish):  # predicted_list is the list of predicted tokens
        if "[UNK]" in predicted_list:
            reward = -1
        else:
            reward = 1
        return reward
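
The reward rule can be prototyped as a plain function before it is placed inside `get_reward`. The following is only a sketch under stated assumptions: the helper name `score_prediction` is hypothetical and not part of TextRL, and the choice to return 0 until the sentence has finished is an illustrative variation, not the behaviour of the example above. Only the `[UNK]` penalty and the `predicted_list` / `finish` arguments come from this page.

```python
from typing import List


def score_prediction(predicted_list: List[str], finish: bool) -> float:
    """Hypothetical reward helper; not part of the TextRL API."""
    # Same rule as the example above: penalize unknown tokens.
    if "[UNK]" in predicted_list:
        return -1.0
    # Assumption: stay neutral until the sentence has actually ended.
    if not finish:
        return 0.0
    return 1.0


print(score_prediction(["hello", "[UNK]"], finish=False))
print(score_prediction(["hello", "world"], finish=True))
```

A function like this could then be called from `get_reward` with the same `predicted_list` and `finish` values, which keeps the scoring logic testable without loading a model.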

prepare for training

  • observation_input should be a list of all possible input strings for model training
env = MyRLEnv(model, tokenizer, observation_input=observation_list)  # observation_list: your list of input strings
actor = TextRLActor(env, model, tokenizer)
agent = actor.agent_ppo(update_interval=10, minibatch_size=2000, epochs=20)
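
Since `observation_input` expects a list of input strings, it may help to clean the prompts before passing them in. The sketch below is one possible way to do that; `build_observation_list` and the sample prompts are hypothetical and are not provided by TextRL.

```python
from typing import Iterable, List


def build_observation_list(raw_prompts: Iterable[str]) -> List[str]:
    """Hypothetical helper: strip whitespace, drop empty prompts, de-duplicate in order."""
    seen = set()
    cleaned = []
    for prompt in raw_prompts:
        text = prompt.strip()
        if text and text not in seen:
            seen.add(text)
            cleaned.append(text)
    return cleaned


# Made-up sample prompts, for illustration only.
observation_list = build_observation_list(
    ["Once upon a time", "  The weather today is ", "", "Once upon a time"]
)
print(observation_list)
```

The resulting list can be passed as `observation_input` when constructing the environment.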

Train

n_episodes = 1000
max_episode_len = 200 # max sentence length

for i in range(1, n_episodes + 1):
    obs = env.reset()
    R = 0  # return (sum of rewards) for this episode
    t = 0  # time step within this episode
    while True:
        action = agent.act(obs)
        obs, reward, done, pred = env.step(action)
        R += reward
        t += 1
        reset = t == max_episode_len
        agent.observe(obs, reward, done, reset)
        if done or reset:
            break
    if i % 10 == 0:
        print('episode:', i, 'R:', R)
    if i % 50 == 0:
        print('statistics:', agent.get_statistics())
print('Finished.')
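
To see how `done` and `reset` interact in the loop above without loading a model, the same control flow can be run against stand-in objects. This is a sketch only: `ToyEnv`, `ToyAgent` and `run_episode` are hypothetical names invented for illustration and only mimic the method names used in the loop (`reset`, `step`, `act`, `observe`); they are not TextRL or PFRL classes.

```python
class ToyEnv:
    """Stand-in environment: ends an episode after `finish_at` steps."""

    def __init__(self, finish_at):
        self.finish_at = finish_at
        self.steps = 0

    def reset(self):
        self.steps = 0
        return "obs-0"

    def step(self, action):
        self.steps += 1
        done = self.steps >= self.finish_at
        # Same 4-tuple shape as in the loop above: obs, reward, done, extra.
        return f"obs-{self.steps}", 1, done, None


class ToyAgent:
    """Stand-in agent that records the (done, reset) flags it is shown."""

    def __init__(self):
        self.observed = []

    def act(self, obs):
        return 0

    def observe(self, obs, reward, done, reset):
        self.observed.append((done, reset))


def run_episode(env, agent, max_episode_len):
    obs = env.reset()
    R, t = 0, 0
    while True:
        action = agent.act(obs)
        obs, reward, done, _ = env.step(action)
        R += reward
        t += 1
        reset = t == max_episode_len  # episode cut off by the length cap
        agent.observe(obs, reward, done, reset)
        if done or reset:
            return R, t


# The environment finishes (done) before the length cap is reached.
print(run_episode(ToyEnv(finish_at=3), ToyAgent(), max_episode_len=5))   # (3, 3)
# The environment does not finish in time, so the cap (reset) ends the episode.
print(run_episode(ToyEnv(finish_at=10), ToyAgent(), max_episode_len=5))  # (5, 5)
```

In the second case the agent's last observation has `done=False` and `reset=True`, which is how the loop signals a truncated episode rather than a finished sentence.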

another way to train

import logging
import sys

import pfrl

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=1000,
    eval_n_steps=None,
    eval_n_episodes=1500,       
    train_max_episode_len=50,  
    eval_interval=10000,
    outdir='somewhere', 
)

prediction

actor.predict("input text")

Download files

Source Distribution

textrl-0.0.3.tar.gz (6.2 kB)

Built Distributions

textrl-0.0.3-py3.7.egg (7.9 kB)

textrl-0.0.3-py3-none-any.whl (4.5 kB)
