diff --git a/TensorStack.TextGeneration/Pipelines/Qwen/QwenConfig.cs b/TensorStack.TextGeneration/Pipelines/Qwen/QwenConfig.cs
new file mode 100644
index 0000000..7f65582
--- /dev/null
+++ b/TensorStack.TextGeneration/Pipelines/Qwen/QwenConfig.cs
@@ -0,0 +1,8 @@
+using TensorStack.TextGeneration.Common;
+
+namespace TensorStack.TextGeneration.Pipelines.Qwen
+{
+    public record QwenConfig : TransformerConfig
+    {
+    }
+}
diff --git a/TensorStack.TextGeneration/Pipelines/Qwen/QwenPipeline.cs b/TensorStack.TextGeneration/Pipelines/Qwen/QwenPipeline.cs
new file mode 100644
index 0000000..a2b5dca
--- /dev/null
+++ b/TensorStack.TextGeneration/Pipelines/Qwen/QwenPipeline.cs
@@ -0,0 +1,237 @@
+// Copyright (c) TensorStack. All rights reserved.
+// Licensed under the Apache 2.0 License.
+
+using System;
+using System.IO;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using TensorStack.Common;
+using TensorStack.Common.Pipeline;
+using TensorStack.Common.Tensor;
+using TensorStack.TextGeneration.Cache;
+using TensorStack.TextGeneration.Common;
+using TensorStack.TextGeneration.Processing;
+using TensorStack.TextGeneration.Tokenizers;
+
+namespace TensorStack.TextGeneration.Pipelines.Qwen
+{
+    public class QwenPipeline : DecoderPipeline,
+        IPipeline<GenerateResult, GenerateOptions>,
+        IPipeline<GenerateResult[], SearchOptions>
+    {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="QwenPipeline"/> class.
+        /// </summary>
+        /// <param name="configuration">The pipeline configuration (tokenizer and decoder configuration).</param>
+        public QwenPipeline(QwenConfig configuration)
+            : base(configuration.Tokenizer, configuration.DecoderConfig)
+        {
+            Configuration = configuration;
+        }
+
+        public QwenConfig Configuration { get; }
+
+
+        /// <summary>
+        /// Runs the GreedySearch inference
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            await TokenizePromptAsync(options);
+            var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
+            using (sequence)
+            {
+                return new GenerateResult
+                {
+                    Score = sequence.Score,
+                    Result = Tokenizer.Decode(sequence.Tokens),
+                    Tokens = sequence.Tokens,
+                    LastHiddenState = sequence.LastHiddenState
+                };
+            }
+        }
+
+
+        /// <summary>
+        /// Runs the BeamSearch inference
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
+        public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            await TokenizePromptAsync(options);
+
+            var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
+            var results = new GenerateResult[sequences.Length];
+            for (int beam = 0; beam < sequences.Length; beam++)
+            {
+                var sequence = sequences[beam];
+                using (sequence)
+                {
+                    results[beam] = new GenerateResult
+                    {
+                        Beam = beam,
+                        Score = sequence.Score,
+                        PenaltyScore = sequence.PenaltyScore,
+                        Result = Tokenizer.Decode(sequence.Tokens),
+                        Tokens = sequence.Tokens,
+                        LastHiddenState = sequence.LastHiddenState
+                    };
+                }
+            }
+            return results;
+        }
+
+
+        /// <summary>
+        /// Tokenize the prompt
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <returns>A Task representing the asynchronous operation.</returns>
+        protected override async Task TokenizePromptAsync(GenerateOptions options)
+        {
+            var tokenizerResult = await Tokenizer.EncodeAsync(options.Prompt);
+            var inputIds = tokenizerResult.InputIds.Span.Pad(Tokenizer.EOS, options.MinLength);
+            var mask = tokenizerResult.Mask.Span.Pad(0, options.MinLength);
+            TokenizerOutput = new TokenizerResult(inputIds, mask);
+        }
+
+
+        /// <summary>
+        /// Gets the token processors.
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <returns>ITokenProcessor[].</returns>
+        protected override ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
+        {
+            return
+            [
+                new EOSTokenProcessor(options.MinLength, Tokenizer.EOS),
+                new MaxLengthTokenProcessor(options.MaxLength)
+            ];
+        }
+
+
+        /// <summary>
+        /// Initialize the Decoder cache
+        /// </summary>
+        /// <param name="options">The options.</param>
+        /// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
+        protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
+        {
+            var modelMetadata = await Decoder.LoadAsync();
+            var kvCache = new KVCacheDecoder(modelMetadata, DecoderConfig.NumHeads, DecoderConfig.NumLayers, DecoderConfig.HiddenSize, DecoderConfig.NumKVHeads, options.MaxLength);
+            var sequence = new Sequence(kvCache, Tokenizer.BOS);
+            sequence.Initialize(0);
+
+            var position = TokenizerOutput.Length;
+            var inputIds = TokenizerOutput.InputIds;
+            var positionIds = GetPositionIds(modelMetadata, 0, position);
+            var attentionMask = new Tensor([1, position], 1);
+            RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, false);
+            return sequence;
+        }
+
+
+        /// <summary>
+        /// Run decoder model
+        /// </summary>
+        /// <param name="sequence">The sequence.</param>
+        /// <returns>A Task&lt;Tensor`1&gt; representing the asynchronous operation.</returns>
+        protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
+        {
+            var modelMetadata = await Decoder.LoadAsync();
+            var position = TokenizerOutput.Length + sequence.Tokens.Count;
+            var inputIds = new Tensor([1, 1], sequence.Tokens[^1]);
+            var positionIds = GetPositionIds(modelMetadata, position);
+            var attentionMask = new Tensor([1, position], 1);
+            return RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, true);
+        }
+
+
+        /// <summary>
+        /// Runs the decoder
+        /// </summary>
+        /// <param name="modelMetadata">The model metadata.</param>
+        /// <param name="sequence">The sequence.</param>
+        /// <param name="inputIds">The input ids.</param>
+        /// <param name="positionIds">The position ids.</param>
+        /// <param name="attentionMask">The attention mask.</param>
+        /// <param name="useBranchCache">if set to <c>true</c> [use branch cache].</param>
+        private Tensor<float> RunDecoderInternal(ModelMetadata modelMetadata, Sequence sequence, Tensor inputIds, Tensor positionIds, Tensor attentionMask, bool useBranchCache)
+        {
+            using (var parameters = new ModelParameters(modelMetadata))
+            {
+                // Inputs
+                parameters.AddInput(inputIds);
+                parameters.AddInput(attentionMask);
+                if (positionIds != null)
+                    parameters.AddInput(positionIds);
+
+                foreach (var pastKeyValue in sequence.Cache)
+                    parameters.AddInput(pastKeyValue, false);
+
+                // Outputs
+                foreach (var output in modelMetadata.Outputs)
+                    parameters.AddOutput();
+
+                // Result
+                var modelResult = Decoder.RunInference(parameters);
+                using (var logitsResult = modelResult[0])
+                {
+                    var dimension = logitsResult.GetDimensions();
+                    var logits = logitsResult.ToTensor(dimension[1..]);
+                    var presentKeyValues = modelResult.ToArray()[1..];
+                    sequence.UpdateCache(presentKeyValues, useBranchCache);
+                    return logits;
+                }
+            }
+        }
+
+
+        /// <summary>
+        /// Creates the QwenPipeline
+        /// </summary>
+        /// <param name="provider">The provider.</param>
+        /// <param name="modelPath">The model path.</param>
+        /// <param name="model">The decoder model.</param>
+        /// <returns>QwenPipeline.</returns>
+        public static QwenPipeline Create(ExecutionProvider provider, string modelPath, string model = "model.onnx")
+        {
+            // Qwen-2.5 - https://huggingface.co/onnx-community/Qwen2.5-0.5B
+            var numHeads = 14;
+            var numLayers = 24;
+            var hiddenSize = 896;
+            var numKVHeads = 2;
+            var vocabSize = 151936;
+            var config = new QwenConfig
+            {
+                Tokenizer = new BPETokenizer(new TokenizerConfig
+                {
+                    BOS = 151643,
+                    EOS = 151643,
+                    Path = modelPath
+                }),
+                DecoderConfig = new DecoderConfig
+                {
+                    Path = Path.Combine(modelPath, model),
+                    VocabSize = vocabSize,
+                    NumHeads = numHeads,
+                    NumLayers = numLayers,
+                    HiddenSize = hiddenSize,
+                    NumKVHeads = numKVHeads
+                }
+            };
+
+            config.DecoderConfig.SetProvider(provider);
+            return new QwenPipeline(config);
+        }
+
+    }
+}
\ No newline at end of file
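Example usage, as a rough sketch: the snippet below drives both RunAsync overloads added above. Only Prompt, MinLength/MaxLength, SearchOptions, and the GenerateResult members (Result, Beam) are taken from the code in this diff; the ExecutionProvider construction, the model path, and the option values are illustrative assumptions.

using System;
using System.Threading.Tasks;
using TensorStack.Common;
using TensorStack.TextGeneration.Common;
using TensorStack.TextGeneration.Pipelines.Qwen;

public static class QwenExample
{
    public static async Task RunExampleAsync()
    {
        // Hypothetical provider; substitute the ExecutionProvider configured for your backend.
        var provider = new ExecutionProvider();

        // Create the pipeline from a local copy of onnx-community/Qwen2.5-0.5B (path is illustrative).
        var pipeline = QwenPipeline.Create(provider, @"models\Qwen2.5-0.5B");

        // Greedy search: single GenerateResult.
        var result = await pipeline.RunAsync(new GenerateOptions
        {
            Prompt = "Write a haiku about tensors.",
            MaxLength = 128
        });
        Console.WriteLine(result.Result);

        // Beam search: one GenerateResult per beam.
        var beams = await pipeline.RunAsync(new SearchOptions
        {
            Prompt = "Write a haiku about tensors.",
            MaxLength = 128
        });
        foreach (var beam in beams)
            Console.WriteLine($"[{beam.Beam}] {beam.Result}");
    }
}

Greedy search returns a single result, while the beam-search overload returns an array with one entry per beam, each carrying its Beam index, Score and PenaltyScore as populated in RunAsync(SearchOptions) above.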