class SHAInet::MultiHeadAttention

Defined in:

shainet/transformer/multi_head_attention.cr

Constructors

Instance Method Summary

Constructor Detail

def self.new(d_model : Int32, num_heads : Int32)

[View source]

Instance Method Detail

def apply_gradients(lr : Float64)

[View source]
def backward(d_out : SimpleMatrix)

[View source]
def d_model : Int32

[View source]
def forward(x : SimpleMatrix)

[View source]
def grads_w_k : SimpleMatrix

[View source]
def grads_w_k=(grads_w_k : SimpleMatrix)

[View source]
def grads_w_o : SimpleMatrix

[View source]
def grads_w_o=(grads_w_o : SimpleMatrix)

[View source]
def grads_w_q : SimpleMatrix

[View source]
def grads_w_q=(grads_w_q : SimpleMatrix)

[View source]
def grads_w_v : SimpleMatrix

[View source]
def grads_w_v=(grads_w_v : SimpleMatrix)

[View source]
def head_dim : Int32

[View source]
def num_heads : Int32

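[View source]
(method signature missing from extracted page — TODO restore from shainet/transformer/multi_head_attention.cr)

[View source]
(method signature missing from extracted page — TODO restore from shainet/transformer/multi_head_attention.cr)

[View source]
(method signature missing from extracted page — TODO restore from shainet/transformer/multi_head_attention.cr)

[View source]
(method signature missing from extracted page — TODO restore from shainet/transformer/multi_head_attention.cr)

[View source]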
def zero_gradients

[View source]