wuyukai0403's blog

By wuyukai0403, history, 13 days ago, In English
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
import torch.nn as nn
import torch

class MSOTConfig(PretrainedConfig):
    """Configuration object for the MSOT models.

    Attributes:
        vocab_size: Size passed to the embedding table and, in the causal-LM
            wrapper, the output width of the final linear layer.
        hidden_size: Width of the embeddings and of every square linear layer.
    """

    model_type = "msot"

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16, **kwargs):
        # The model-specific fields are stored first; every remaining keyword
        # argument is then handed to the base configuration class.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

class MSOTModel(PreTrainedModel):
    """Token embedding followed by a product of three linear projections.

    The forward pass embeds the token ids, projects the embeddings through
    three independent square linear layers, and multiplies the results as
    ``l1(x) @ l2(x)^T @ l3(x)``. No mask and no normalisation are applied
    to the intermediate product.
    """

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        width = config.hidden_size
        # Creation order is kept (emb, l1, l2, l3) so parameter registration
        # and random initialisation order stay the same.
        self.emb = nn.Embedding(config.vocab_size, width)
        self.l1 = nn.Linear(width, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, width)

    def forward(self, input_ids, return_dict=None, **kwargs):
        """Run the model on ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids. It needs at least two
                dimensions because the last two axes of one projection are
                transposed.
            return_dict: When truthy, a ``BaseModelOutput`` is returned;
                otherwise (including the default ``None``) a 1-tuple.
            **kwargs: Accepted and ignored.

        Returns:
            The result tensor, which has the same shape as the embedded
            input, wrapped as described for ``return_dict``.
        """
        embedded = self.emb(input_ids)
        left = self.l1(embedded)
        # Swap the last two axes so the product below is (seq x seq) first.
        middle = self.l2(embedded).transpose(-2, -1)
        right = self.l3(embedded)
        # Same left-to-right association as ``left @ middle @ right``.
        mixed = torch.matmul(torch.matmul(left, middle), right)
        if return_dict:
            return BaseModelOutput(last_hidden_state=mixed)
        return (mixed,)

class MSOTModelForCausalLM(PreTrainedModel, GenerationMixin):
    """``MSOTModel`` with a linear head that maps hidden states to vocabulary logits."""

    config_class = MSOTConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.model = MSOTModel(config, **kwargs)
        # Projects each hidden vector to one score per vocabulary entry.
        self.l = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, return_dict=None, **kwargs):
        """Compute logits for ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids, forwarded unchanged to
                the wrapped ``MSOTModel``.
            return_dict: When truthy, a ``CausalLMOutput`` is returned;
                otherwise (including the default ``None``) a 1-tuple.
            **kwargs: Accepted and ignored; no loss is computed here.

        Returns:
            The logits tensor, wrapped as described for ``return_dict``.
        """
        # The backbone is called without ``return_dict`` and so returns a
        # tuple; element 0 is the hidden-state tensor.
        hidden = self.model(input_ids)[0]
        logits = self.l(hidden)
        if return_dict:
            return CausalLMOutput(logits=logits)
        return (logits,)

    @classmethod
    def can_generate(cls):
        """Report that this model supports ``generate()``.

        Declared as a classmethod to match the base-class signature, so the
        check works when invoked on the class as well as on an instance.
        """
        return True

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the forward-pass inputs for one generation step.

        Only ``input_ids`` is passed on; every other keyword argument is
        dropped.
        """
        return {"input_ids": input_ids}

def gen128(model, input):
    """Generate text from an ASCII prompt using one token per byte.

    Args:
        model: Object exposing ``generate(tokens, max_new_tokens=...)`` that
            returns an indexable batch of integer token sequences.
        input: Prompt string; it is encoded as ASCII, so non-ASCII characters
            raise ``UnicodeEncodeError``. (The name shadows the builtin but
            is kept because it is part of the signature.)

    Returns:
        The first sequence returned by ``model.generate``, interpreted as
        bytes and decoded as UTF-8.
    """
    prompt_bytes = bytes(input, "ascii")
    # Extra outer list: a batch holding a single sequence.
    prompt_ids = torch.tensor([list(prompt_bytes)])
    generated = model.generate(prompt_ids, max_new_tokens=50)
    first_sequence = generated[0].tolist()
    return bytes(first_sequence).decode("utf-8")

def gen65536(model, input):
    """Generate text from a prompt using one token per Unicode code point.

    Args:
        model: Object exposing ``generate(tokens, max_new_tokens=...)`` that
            returns an indexable batch of integer token sequences.
        input: Prompt string. Characters whose code point is 65536 or above
            are silently left out of the prompt. (The name shadows the
            builtin but is kept because it is part of the signature.)

    Returns:
        The first sequence returned by ``model.generate``, with every token
        id converted back to a character via ``chr``.
    """
    code_points = [cp for cp in map(ord, input) if cp < 65536]
    # Extra outer list: a batch holding a single sequence.
    prompt_ids = torch.tensor([code_points])
    generated = model.generate(prompt_ids, max_new_tokens=50)
    return "".join(map(chr, generated[0].tolist()))

if __name__ == "__main__":
    # When run as a script, register the custom classes with the matching
    # transformers auto classes. The config uses the method's default target.
    MSOTConfig.register_for_auto_class()
    for model_cls, auto_name in (
        (MSOTModel, "AutoModel"),
        (MSOTModelForCausalLM, "AutoModelForCausalLM"),
    ):
        model_cls.register_for_auto_class(auto_name)
  • Vote: I like it
  • -12
  • Vote: I do not like it