February 24, 2024
This code base is inspired by the design of scikit-learn: for every machine learning model in scikit-learn, there are two basic functions, one for training and one for prediction. In this code base, each LLM module likewise provides two basic functions, `train` and `predict`, as shown in the following simple example:
```python
import fire

from lora import Llama_Lora


def main(
    task: str = "eval",
    base_model: str = "meta-llama/Llama-2-7b-hf",
):
    m = Llama_Lora(
        base_model=base_model,
    )
    if task == "train":
        m.train(
            train_file="data/sst2/train.json",
            val_file="data/sst2/val.json",
            output_dir="./ckp_sst_llama2_lora",
            train_batch_size=32,
            num_epochs=1,
        )
    elif task == "eval":
        m.predict(
            input_file="data/sst2/val.json",
            max_new_tokens=32,
            verbose=True,
        )
    else:
        raise ValueError(f"Unrecognized task: {task}")


if __name__ == "__main__":
    fire.Fire(main)
```
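To make the contract of these two functions concrete, here is a minimal, self-contained sketch of the interface shape. The class, its attributes, and its return value are illustrative stand-ins only, not the actual `Llama_Lora` implementation:

```python
class TwoFunctionModule:
    """Illustrative skeleton of the train/predict interface (not the real Llama_Lora)."""

    def __init__(self, base_model: str):
        self.base_model = base_model
        self.trained = False

    def train(self, train_file: str, val_file: str, output_dir: str,
              train_batch_size: int = 32, num_epochs: int = 1) -> None:
        # A real module would fine-tune base_model on train_file here and
        # save checkpoints to output_dir; this sketch only records state.
        self.trained = True

    def predict(self, input_file: str, max_new_tokens: int = 32,
                verbose: bool = False) -> list[str]:
        # A real module would run generation over input_file;
        # this sketch returns a placeholder.
        return ["<prediction>"]
```

The point of the pattern is that every LLM module, whatever its backbone, is driven through exactly these two entry points.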
With the classical SST2 dataset, this simple example demonstrates:

- How to instantiate an LLM module, here `Llama_Lora`:

  ```python
  from lora import Llama_Lora
  ...
  m = Llama_Lora(
      base_model=base_model,
  )
  ```

- How to fine-tune it, given a training set (`train_file`) and a validation set (`val_file`):

  ```python
  m.train(
      train_file="data/sst2/train.json",
      val_file="data/sst2/val.json",
      output_dir="./ckp_sst_llama2_lora",
      train_batch_size=32,
      num_epochs=1,
  )
  ```

- How to run prediction on an input file (`input_file`):

  ```python
  m.predict(
      input_file="data/sst2/val.json",
      max_new_tokens=32,
      verbose=True,
  )
  ```
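The exact schema of these JSON files is defined by the code base and not shown here; purely as a hypothetical illustration, an SST2-style file might pair sentences with sentiment labels:

```python
import json

# Hypothetical SST2-style records; the real schema expected by train/predict
# is defined by the code base, not by this sketch.
records = [
    {"sentence": "a charming and often affecting journey", "label": "positive"},
    {"sentence": "the plot is paper-thin and the jokes fall flat", "label": "negative"},
]

with open("toy_sst2.json", "w") as f:
    json.dump(records, f, indent=2)
```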
More details on `train` and `predict` are explained in the next section, and information specific to each LLM module is given in the sections after that.
As of February 24, 2024, the code base supports the following LLM modules:

- `Llama_Lora`
- `Falcon_Lora`
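Because both modules expose the same `train`/`predict` interface, switching between them amounts to choosing a different class. A hypothetical dispatch by name could look like the following sketch, where the stub classes stand in for the real `lora.Llama_Lora` and `lora.Falcon_Lora`:

```python
# Hypothetical stubs standing in for lora.Llama_Lora / lora.Falcon_Lora,
# used only to illustrate dispatch over a shared interface.
class Llama_Lora:
    def __init__(self, base_model: str):
        self.base_model = base_model

class Falcon_Lora:
    def __init__(self, base_model: str):
        self.base_model = base_model

MODULES = {
    "llama": Llama_Lora,
    "falcon": Falcon_Lora,
}

def make_module(name: str, base_model: str):
    # Map a short name to a module class; unknown names fail loudly.
    if name not in MODULES:
        raise ValueError(f"Unrecognized module: {name}")
    return MODULES[name](base_model=base_model)
```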