February 24, 2024


This code base is inspired by the design of Sklearn (scikit-learn): every machine learning model in Sklearn exposes two basic functions, one for training and one for prediction.
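To illustrate the pattern, here is a minimal sketch (a toy classifier, not part of this code base or of Sklearn) of an estimator that exposes exactly these two functions:

```python
# Toy illustration of the two-method (train/predict) design pattern.
from collections import Counter

class MajorityClassifier:
    def train(self, X, y):
        # Remember the most frequent label seen during training.
        self.label_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        # Predict the stored majority label for every input.
        return [self.label_ for _ in X]

m = MajorityClassifier()
m.train([[0], [1], [2]], ["pos", "pos", "neg"])
print(m.predict([[3], [4]]))  # prints ['pos', 'pos']
```

Everything else (model architecture, optimization, decoding) is hidden behind these two calls, which is the convention the LLM modules below follow as well.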

1. Example

In this code base, for each LLM module we likewise provide two basic functions, train and predict, as shown in the following simple example:

import fire

from lora import Llama_Lora

def main(
        task: str = "eval",
        base_model: str = "meta-llama/Llama-2-7b-hf",
):
    m = Llama_Lora(
        base_model=base_model,
    )
    if task == "train":
        m.train(
            train_file = "data/sst2/train.json",
            val_file = "data/sst2/val.json",
            output_dir = "./ckp_sst_llama2_lora",
            train_batch_size = 32,
            num_epochs = 1,
        )
    elif task == "eval":
        m.predict(
            input_file = "data/sst2/val.json",
            max_new_tokens = 32,
            verbose = True,
        )
    else:
        raise ValueError(f"Unrecognized task: {task}")

if __name__ == "__main__":
    fire.Fire(main)
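Because the entry point is wrapped with fire.Fire, each keyword argument of main becomes a command-line flag. Assuming the script above is saved as main.py (the filename is not given in this document), it could be invoked like this:

```shell
# Fine-tune on SST2 (flags map to main()'s keyword arguments)
python main.py --task=train --base_model=meta-llama/Llama-2-7b-hf

# Run evaluation with the defaults (task defaults to "eval")
python main.py
```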

With the classical SST2 dataset, this simple example demonstrates the core workflow: initialize a model, fine-tune it with train, and run inference with predict:

from lora import Llama_Lora
...

    m = Llama_Lora(
        base_model=base_model,
    )
    m.train(
        train_file = "data/sst2/train.json",
        val_file = "data/sst2/val.json",
        output_dir = "./ckp_sst_llama2_lora",
        train_batch_size = 32,
        num_epochs = 1,
    )
    m.predict(
        input_file = "data/sst2/val.json",
        max_new_tokens = 32,
        verbose = True,
    )

More details on train and predict are given in the next section, and each supported LLM module is described in the sections after that.


2. Supported LLM Modules

As of February 24, 2024, the code base supports the following LLM modules: