ScaledDotProductAttention

public struct ScaledDotProductAttention<Element, Device> : LayerType, Codable where Element : NumericType, Device : DeviceType

Computes scaled multi-head dot product attention as introduced in “Attention Is All You Need” (Vaswani et al., 2017).

  • The softmax temperature: raw attention scores are divided by this value before the softmax is applied. Typically set to the square root of the key dimension.

    Declaration

    Swift

    public var temperature: Element
  • Creates a scaled dot product attention layer with the given softmax temperature.

    Declaration

    Swift

    public init(temperature: Element)
  • Declaration

    Swift

    public var parameters: [Tensor<Element, Device>] { get }
  • Declaration

    Swift

    public var parameterPaths: [WritableKeyPath<`Self`, Tensor<Element, Device>>] { get }
  • Performs scaled dot product attention.

    Declaration

    Swift

    public func callAsFunction(_ inputs: (q: Tensor<Element, Device>, k: Tensor<Element, Device>, v: Tensor<Element, Device>, mask: Tensor<Element, Device>?)) -> Tensor<Element, Device>

    Parameters

    inputs

    Tuple containing queries of shape [batchSize, heads, queryCount, keyDim], keys of shape [batchSize, heads, keyCount, keyDim], and values of shape [batchSize, heads, valueCount, valueDim], as well as an optional mask that may be used to prevent attention to certain elements (e.g. padding or future timesteps). The mask must be broadcastable to shape [batchSize, heads, queryCount, keyCount]. Note that keyCount and valueCount must be equal.

    Return Value

    Attended values tensor of shape [batchSize, heads, queryCount, valueDim]
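
The computation performed by this layer is softmax(q · kᵀ / temperature) · v. A minimal sketch of that computation for a single head and batch, using plain Swift arrays rather than the library's tensor types (the function name and array-based shapes here are illustrative, not part of the library's API):

```swift
import Foundation

// Sketch of scaled dot product attention for one head.
// Shapes: q is [queryCount][keyDim], k is [keyCount][keyDim],
// v is [keyCount][valueDim]; result is [queryCount][valueDim].
func scaledDotProductAttention(
    q: [[Double]], k: [[Double]], v: [[Double]], temperature: Double
) -> [[Double]] {
    // Raw scores: q · kᵀ, divided by the temperature (typically sqrt(keyDim)).
    let scores = q.map { query in
        k.map { key in
            zip(query, key).map(*).reduce(0, +) / temperature
        }
    }
    // Row-wise softmax over the keys (max-subtracted for numerical stability).
    let weights = scores.map { row -> [Double] in
        let maxScore = row.max() ?? 0
        let exps = row.map { exp($0 - maxScore) }
        let sum = exps.reduce(0, +)
        return exps.map { $0 / sum }
    }
    // Attended values: softmax-weighted sum of the value rows.
    return weights.map { row in
        (0 ..< v[0].count).map { d in
            zip(row, v).map { $0 * $1[d] }.reduce(0, +)
        }
    }
}
```

With one-hot values, each output row is exactly the attention weight distribution, which makes the scaling and softmax behavior easy to inspect.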