CANN/ops-transformer局部旋转位置编码梯度算子 InplacePartialRotaryMulGrad【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品/Atlas A3 推理系列产品×Atlas A2 训练系列产品/Atlas A2 推理系列产品×Atlas 200I/500 A2 推理产品×Atlas 推理系列产品×Atlas 训练系列产品×功能说明算子功能执行局部旋转位置编码InplacePartialRotaryMul的反向计算。该算子对输入dy的D维度上切片[start, end)区域执行旋转位置编码梯度计算dx的结果inplace写回dy的[start, end)区间。dy的[start, end)之外的数据保持不变。计算公式取旋转位置编码的正向计算中broadcast的轴列表为dims在D维度上的切片范围为[start, end)令参与计算的切片数据为 $$ dy dy[..., start:end] $$ $$ cos cos[..., start:end] $$ $$ sin sin[..., start:end] $$ 则梯度计算公式可表达如下1half模式rotary_mode等于0 $$ dy1, dy2 chunk(dy, chunks2, dim-1) $$$$ cos1, cos2 chunk(cos, chunks2, dim-1) $$$$ sin1, sin2 chunk(sin, chunks2, dim-1) $$$$ dx cat((cos1 * dy1 sin2 * dy2, cos2 * dy2 - sin1 * dy1), dim-1) $$2interleave模式rotary_mode等于1 $$ dy1, dy2 dy[..., :: 2], dy[..., 1 :: 2] $$$$ cos1, cos2 cos[..., :: 2], cos[..., 1 :: 2] $$$$ sin1, sin2 sin[..., :: 2], sin[..., 1 :: 2] $$$$ dx stack((cos1 * dy1 sin2 * dy2, cos2 * dy2 - sin1 * dy1), dim-1).reshape(dy.shape) $$3quarter模式rotary_mode等于2 $$ dy1, dy2, dy3, dy4 chunk(dy, chunks4, dim-1) $$$$ cos1, cos2, cos3, cos4 chunk(cos, chunks4, dim-1) $$$$ sin1, sin2, sin3, sin4 chunk(sin, chunks4, dim-1) $$$$ dx cat((cos1 * dy1 sin2 * dy2, cos2 * dy2 - sin1 * dy1, cos3 * dy3 sin4 * dy4, cos4 * dy4 - sin3 * dy3), dim-1) $$4interleave-half模式rotary_mode等于3 $$ dy1, dy2 chunk(dy, chunks2, dim-1) $$$$ cos1, cos2 chunk(cos, chunks2, dim-1) $$$$ sin1, sin2 chunk(sin, chunks2, dim-1) $$$$ dx stack((cos1 * dy1 sin2 * dy2, cos2 * dy2 - sin1 * dy1), dim-1).reshape(dy.shape) $$参数说明参数名输入/输出/属性描述数据类型数据格式dy输入公式中的dy表示正向计算输出y的导数inplace更新为正向输入x的导数。Inplace模式dy同时作为输出写入结果。BFLOAT16、FLOAT16、FLOAT32NDcos输入公式中的cos正向计算输入需与sin数据类型一致。BFLOAT16、FLOAT16、FLOAT32NDsin输入公式中的sin正向计算输入需与cos数据类型一致。BFLOAT16、FLOAT16、FLOAT32NDrotary_mode属性旋转模式0half1interleave2quarter3interleave-half。当前仅支持interleave模式rotary_mode1。INT64-partial_slice属性D维度上的切片范围[start, end)默认{0, 0}表示不做有效计算。start须在[0, D]内end须在[start, D]内。IntArray-约束说明该算子仅支持Ascend 950 AI Processor。该算子仅支持连续Tensor不支持非连续Tensor。该算子当前版本仅支持 interleave 模式rotary_mode1。其他模式暂不支持。Inplace执行输入dy和输出共享同一个Tensor计算结果直接写回输入dy。输入dy当前只支持BSND排布输入cos/sin的shape必须与dy满足B/S/N维度的广播关系如BSND、111D、1SND、B1ND、BS1D、11ND、B11D、1S1D等。各参数的shape约束可以描述如下输入张量dy的最后一维大小D必须小于等于1024。输入张量cos、sin的最后一维大小必须等于partial_slice的切片长度即partial_slice[1] - partial_slice[0]。输入张量cos和sin的shape必须完全相同cos和sin的B、S、N维度需要与dy满足broadcast关系且广播后的B、S、N必须等于dy的B、S、N。half、interleave和interleave-half模式下partial_slice切片长度即partial_slice[1] - partial_slice[0]必须能被2整除。quarter模式下partial_slice切片长度即partial_slice[1] - partial_slice[0]必须能被4整除。当start等于end时算子不执行有效计算直接返回。输入张量cos和sin的数据类型必须相同。调用说明调用方式调用样例说明aclnn调用test_aclnn_inplace_partial_rotary_mul_grad通过aclnnInplacePartialRotaryMulGrad接口方式调用InplacePartialRotaryMulGrad算子。图模式调用test_geir_inplace_partial_rotary_mul_grad通过算子IR构图方式调用InplacePartialRotaryMulGrad算子。【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考