from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
import torch.nn as nn
import torch
class MSOTConfig(PretrainedConfig):
    """Configuration object for the MSOT models.

    Holds the two sizes the model layers are built from and hands every other
    keyword argument to ``PretrainedConfig``.

    Args:
        vocab_size: Number of token ids; used as the embedding table size and
            as the output width of the language-model head. Defaults to 128.
        hidden_size: Width of the embeddings and of every square linear
            layer. Defaults to 16.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "msot"

    def __init__(self, vocab_size=128, hidden_size=16, **kwargs):
        # Model-specific fields are stored before the base initialiser runs.
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        super().__init__(**kwargs)
class MSOTModel(PreTrainedModel):
    """Backbone that mixes token embeddings with a triple matrix product.

    The embedded input is projected by three independent linear layers and
    the projections are combined as ``l1(h) @ l2(h)^T @ l3(h)``. No
    normalisation or masking is applied to the product.
    """

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        """Build the embedding table and the three square projections.

        Args:
            config: An ``MSOTConfig`` providing ``vocab_size`` and
                ``hidden_size``.
            **kwargs: Forwarded to ``PreTrainedModel.__init__``.
        """
        super().__init__(config, **kwargs)
        self.config = config
        width = config.hidden_size
        self.emb = nn.Embedding(config.vocab_size, width)
        self.l1 = nn.Linear(width, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, width)

    def forward(self, input_ids, return_dict=None, **kwargs):
        """Embed ``input_ids`` and return the mixed hidden states.

        Args:
            input_ids: Integer tensor of token ids (presumably shaped
                ``(batch, seq)`` -- TODO confirm against callers).
            return_dict: When truthy a ``BaseModelOutput`` is returned;
                otherwise (including the default ``None``) a 1-tuple.
            **kwargs: Accepted and ignored.

        Returns:
            ``(hidden,)`` or ``BaseModelOutput`` wrapping the same tensor.
        """
        embedded = self.emb(input_ids)
        left = self.l1(embedded)
        # Swap the last two axes so the product below is well-formed.
        middle = self.l2(embedded).transpose(-2, -1)
        right = self.l3(embedded)
        mixed = left @ middle @ right
        if return_dict:
            return BaseModelOutput(mixed)
        return (mixed,)
class MSOTModelForCausalLM(PreTrainedModel, GenerationMixin):
    """``MSOTModel`` backbone followed by a linear head over the vocabulary."""

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        """Create the backbone and the projection to vocabulary logits.

        Args:
            config: An ``MSOTConfig`` providing ``vocab_size`` and
                ``hidden_size``.
            **kwargs: Forwarded both to ``PreTrainedModel.__init__`` and to
                the ``MSOTModel`` constructor.
        """
        super().__init__(config, **kwargs)
        self.model = MSOTModel(config, **kwargs)
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, return_dict=None, **kwargs):
        """Compute vocabulary logits for ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids.
            return_dict: When truthy a ``CausalLMOutput`` is returned;
                otherwise (including the default ``None``) a 1-tuple.
            **kwargs: Accepted and ignored.

        Returns:
            ``(logits,)`` or ``CausalLMOutput(logits=logits)``.
        """
        # The backbone is called without return_dict, so it yields a tuple.
        backbone_out = self.model(input_ids)
        logits = self.l(backbone_out[0])
        if return_dict:
            return CausalLMOutput(logits=logits)
        return (logits,)

    def can_generate(self):
        """Always report that this model supports generation."""
        return True

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the forward inputs for one step: only ``input_ids``.

        Any extra keyword arguments are dropped.
        """
        return {"input_ids": input_ids}
def gen128(model, input):
    """Generate text from an ASCII prompt, using byte values as token ids.

    The prompt is encoded as ASCII, wrapped into a batch of one sequence and
    passed to ``model.generate`` with ``max_new_tokens=50``. The first
    returned sequence is interpreted as raw bytes and decoded as UTF-8.

    Args:
        model: Any object exposing ``generate(tokens, max_new_tokens=...)``
            that returns an indexable batch of integer tensors.
        input: Prompt string; must contain only ASCII characters.

    Returns:
        The decoded text of the first returned sequence.

    Raises:
        UnicodeEncodeError: If ``input`` contains a non-ASCII character.
        ValueError: If a returned token id is outside ``range(256)``.
        UnicodeDecodeError: If the returned bytes are not valid UTF-8.
    """
    prompt_ids = list(bytes(input, "ascii"))
    # Pin the dtype: for an empty prompt torch would otherwise infer float32
    # from the empty list, which is not a valid dtype for token indices.
    tokens = torch.tensor([prompt_ids], dtype=torch.long)
    # A single tolist() yields plain Python ints instead of 0-d tensors.
    generated = model.generate(tokens, max_new_tokens=50)[0].tolist()
    return bytes(generated).decode("utf-8")
def gen65536(model, input):
    """Generate text from a prompt, using Unicode code points as token ids.

    Characters whose code point is 65536 or above are silently dropped from
    the prompt. The remaining code points are wrapped into a batch of one
    sequence and passed to ``model.generate`` with ``max_new_tokens=50``.
    Each id of the first returned sequence is mapped back with ``chr``.

    Args:
        model: Any object exposing ``generate(tokens, max_new_tokens=...)``
            that returns an indexable batch of integer tensors.
        input: Prompt string.

    Returns:
        The characters corresponding to the first returned sequence.

    Raises:
        ValueError: If a returned token id is not a valid code point.
    """
    prompt_ids = [code for code in map(ord, input) if code < 65536]
    # Pin the dtype: when every character is filtered out the list is empty
    # and torch would otherwise infer float32, an invalid dtype for indices.
    tokens = torch.tensor([prompt_ids], dtype=torch.long)
    # A single tolist() yields plain Python ints instead of 0-d tensors.
    generated = model.generate(tokens, max_new_tokens=50)[0].tolist()
    return "".join(map(chr, generated))
if __name__ == "__main__":
    # Only when run as a script: register the custom classes with the
    # transformers auto-class machinery. The config uses the method's default
    # target; each model class names its auto class explicitly.
    MSOTConfig.register_for_auto_class()
    for model_cls, auto_name in (
        (MSOTModel, "AutoModel"),
        (MSOTModelForCausalLM, "AutoModelForCausalLM"),
    ):
        model_cls.register_for_auto_class(auto_name)