/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.pooling;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.util.Preconditions;
import java.util.Objects;

/**
 * Static pooling operators over {@link NDArray}s, plus factory methods that wrap each operator in
 * a {@link LambdaBlock}.
 *
 * <p>The 1-D, 2-D and 3-D variants require rank-3, rank-4 and rank-5 inputs respectively. The
 * windowed (non-global) variants also require kernel, stride and padding shapes whose dimension
 * equals the number of spatial dimensions (1, 2 or 3). The actual computation is delegated to
 * {@code input.getNDArrayInternal()}.
 */
public final class Pool {
    private Pool() {
    }

    /**
     * Checks that {@code input} has the expected rank.
     *
     * @param input the array to check
     * @param expectedRank the required number of dimensions
     * @throws IllegalArgumentException if the rank differs
     */
    private static void checkInputRank(NDArray input, int expectedRank) {
        int actualRank = input.getShape().dimension();
        Preconditions.checkArgument(actualRank == expectedRank, "Expect input dimension is " + expectedRank + " but got " + actualRank);
    }

    /**
     * Validates the arguments shared by every windowed pooling operator, so that all operators
     * apply the same checks and report the same messages.
     *
     * @param input the array to pool; must have rank {@code spatialDims + 2}
     * @param kernelShape the pooling window; must be non-null with dimension {@code spatialDims}
     * @param stride the window stride; must be non-null with dimension {@code spatialDims}
     * @param padding the padding; must be non-null with dimension {@code spatialDims}
     * @param spatialDims the number of spatial dimensions (1, 2 or 3)
     * @param opName the operator name used in error messages
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if any rank or dimension does not match
     */
    private static void checkWindowedArgs(NDArray input, Shape kernelShape, Shape stride, Shape padding, int spatialDims, String opName) {
        Objects.requireNonNull(kernelShape, "kernelShape cannot be null for " + opName);
        checkInputRank(input, spatialDims + 2);
        // Checked here (after the rank check) because that is where the original code first
        // dereferenced them; now the NPE carries a message instead of being anonymous.
        Objects.requireNonNull(stride, "stride cannot be null for " + opName);
        Objects.requireNonNull(padding, "padding cannot be null for " + opName);
        Preconditions.checkArgument(kernelShape.dimension() == spatialDims && stride.dimension() == spatialDims && padding.dimension() == spatialDims, "kernelShape, Stride and Padding dimensions for " + opName + " should be " + spatialDims);
    }

    /**
     * Applies 1-D max pooling.
     *
     * @param input a rank-3 array
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray maxPool1d(NDArray input, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        checkWindowedArgs(input, kernelShape, stride, padding, 1, "maxPool1d");
        return input.getNDArrayInternal().maxPool(kernelShape, stride, padding, ceilMode);
    }

    /**
     * Applies 2-D max pooling.
     *
     * @param input a rank-4 array
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray maxPool2d(NDArray input, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        checkWindowedArgs(input, kernelShape, stride, padding, 2, "maxPool2d");
        return input.getNDArrayInternal().maxPool(kernelShape, stride, padding, ceilMode);
    }

    /**
     * Applies 3-D max pooling.
     *
     * @param input a rank-5 array
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray maxPool3d(NDArray input, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        checkWindowedArgs(input, kernelShape, stride, padding, 3, "maxPool3d");
        return input.getNDArrayInternal().maxPool(kernelShape, stride, padding, ceilMode);
    }

    /**
     * Applies global max pooling to a rank-3 input.
     *
     * @param input a rank-3 array
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 3
     */
    public static NDArray globalMaxPool1d(NDArray input) {
        checkInputRank(input, 3);
        return input.getNDArrayInternal().globalMaxPool();
    }

    /**
     * Applies global max pooling to a rank-4 input.
     *
     * @param input a rank-4 array
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 4
     */
    public static NDArray globalMaxPool2d(NDArray input) {
        checkInputRank(input, 4);
        return input.getNDArrayInternal().globalMaxPool();
    }

    /**
     * Applies global max pooling to a rank-5 input.
     *
     * @param input a rank-5 array
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 5
     */
    public static NDArray globalMaxPool3d(NDArray input) {
        checkInputRank(input, 5);
        return input.getNDArrayInternal().globalMaxPool();
    }

    /**
     * Applies 1-D average pooling.
     *
     * @param input a rank-3 array
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @param countIncludePad forwarded unchanged to the engine (conventionally: whether padded cells
     *     count toward the averaging divisor)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray avgPool1d(NDArray input, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        checkWindowedArgs(input, kernelShape, stride, padding, 1, "avgPool1d");
        return input.getNDArrayInternal().avgPool(kernelShape, stride, padding, ceilMode, countIncludePad);
    }

    /**
     * Applies 2-D average pooling.
     *
     * @param input a rank-4 array
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @param countIncludePad forwarded unchanged to the engine (conventionally: whether padded cells
     *     count toward the averaging divisor)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray avgPool2d(NDArray input, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        checkWindowedArgs(input, kernelShape, stride, padding, 2, "avgPool2d");
        return input.getNDArrayInternal().avgPool(kernelShape, stride, padding, ceilMode, countIncludePad);
    }

    /**
     * Applies 3-D average pooling.
     *
     * @param input a rank-5 array
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @param countIncludePad forwarded unchanged to the engine (conventionally: whether padded cells
     *     count toward the averaging divisor)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray avgPool3d(NDArray input, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        checkWindowedArgs(input, kernelShape, stride, padding, 3, "avgPool3d");
        return input.getNDArrayInternal().avgPool(kernelShape, stride, padding, ceilMode, countIncludePad);
    }

    /**
     * Applies global average pooling to a rank-3 input.
     *
     * @param input a rank-3 array
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 3
     */
    public static NDArray globalAvgPool1d(NDArray input) {
        checkInputRank(input, 3);
        return input.getNDArrayInternal().globalAvgPool();
    }

    /**
     * Applies global average pooling to a rank-4 input.
     *
     * @param input a rank-4 array
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 4
     */
    public static NDArray globalAvgPool2d(NDArray input) {
        checkInputRank(input, 4);
        return input.getNDArrayInternal().globalAvgPool();
    }

    /**
     * Applies global average pooling to a rank-5 input.
     *
     * @param input a rank-5 array
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 5
     */
    public static NDArray globalAvgPool3d(NDArray input) {
        checkInputRank(input, 5);
        return input.getNDArrayInternal().globalAvgPool();
    }

    /**
     * Applies 1-D Lp pooling.
     *
     * @param input a rank-3 array
     * @param normType the norm order, forwarded unchanged to the engine
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray lpPool1d(NDArray input, float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        checkWindowedArgs(input, kernelShape, stride, padding, 1, "lpPool1d");
        return input.getNDArrayInternal().lpPool(normType, kernelShape, stride, padding, ceilMode);
    }

    /**
     * Applies 2-D Lp pooling.
     *
     * @param input a rank-4 array
     * @param normType the norm order, forwarded unchanged to the engine
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2; previously unchecked, now validated like the others)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray lpPool2d(NDArray input, float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        checkWindowedArgs(input, kernelShape, stride, padding, 2, "lpPool2d");
        return input.getNDArrayInternal().lpPool(normType, kernelShape, stride, padding, ceilMode);
    }

    /**
     * Applies 3-D Lp pooling.
     *
     * @param input a rank-5 array
     * @param normType the norm order, forwarded unchanged to the engine
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @param ceilMode forwarded unchanged to the engine (conventionally: round the output size up)
     * @return the pooled array produced by the engine
     * @throws NullPointerException if {@code kernelShape}, {@code stride} or {@code padding} is null
     * @throws IllegalArgumentException if the input rank or any shape dimension is wrong
     */
    public static NDArray lpPool3d(NDArray input, float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        checkWindowedArgs(input, kernelShape, stride, padding, 3, "lpPool3d");
        return input.getNDArrayInternal().lpPool(normType, kernelShape, stride, padding, ceilMode);
    }

    /**
     * Applies global Lp pooling to a rank-3 input.
     *
     * @param input a rank-3 array
     * @param normType the norm order, forwarded unchanged to the engine
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 3
     */
    public static NDArray globalLpPool1d(NDArray input, float normType) {
        checkInputRank(input, 3);
        return input.getNDArrayInternal().globalLpPool(normType);
    }

    /**
     * Applies global Lp pooling to a rank-4 input.
     *
     * @param input a rank-4 array
     * @param normType the norm order, forwarded unchanged to the engine
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 4
     */
    public static NDArray globalLpPool2d(NDArray input, float normType) {
        checkInputRank(input, 4);
        return input.getNDArrayInternal().globalLpPool(normType);
    }

    /**
     * Applies global Lp pooling to a rank-5 input.
     *
     * @param input a rank-5 array
     * @param normType the norm order, forwarded unchanged to the engine
     * @return the pooled array produced by the engine
     * @throws IllegalArgumentException if the input rank is not 5
     */
    public static NDArray globalLpPool3d(NDArray input, float normType) {
        checkInputRank(input, 5);
        return input.getNDArrayInternal().globalLpPool(normType);
    }

    // ---------------------------------------------------------------------------------------
    // Block factories. Arguments are captured by the lambda and validated only when the block's
    // forward pass runs, not when the block is created.
    // ---------------------------------------------------------------------------------------

    /**
     * Creates a block named {@code "maxPool1d"} that applies {@link #maxPool1d}.
     *
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @param ceilMode forwarded to {@link #maxPool1d}
     * @return the pooling block
     */
    public static Block maxPool1dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return LambdaBlock.singleton(array -> Pool.maxPool1d(array, kernelShape, stride, padding, ceilMode), "maxPool1d");
    }

    /**
     * Same as {@link #maxPool1dBlock(Shape, Shape, Shape, boolean)} with {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @return the pooling block
     */
    public static Block maxPool1dBlock(Shape kernelShape, Shape stride, Shape padding) {
        return Pool.maxPool1dBlock(kernelShape, stride, padding, false);
    }

    /**
     * Same as {@link #maxPool1dBlock(Shape, Shape, Shape, boolean)} with zero padding and
     * {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @return the pooling block
     */
    public static Block maxPool1dBlock(Shape kernelShape, Shape stride) {
        return Pool.maxPool1dBlock(kernelShape, stride, new Shape(0L), false);
    }

    /**
     * Same as {@link #maxPool1dBlock(Shape, Shape, Shape, boolean)} with stride equal to
     * {@code kernelShape}, zero padding and {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 1), also used as the stride
     * @return the pooling block
     */
    public static Block maxPool1dBlock(Shape kernelShape) {
        return Pool.maxPool1dBlock(kernelShape, kernelShape, new Shape(0L), false);
    }

    /**
     * Creates a block named {@code "maxPool2d"} that applies {@link #maxPool2d}.
     *
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @param ceilMode forwarded to {@link #maxPool2d}
     * @return the pooling block
     */
    public static Block maxPool2dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return LambdaBlock.singleton(array -> Pool.maxPool2d(array, kernelShape, stride, padding, ceilMode), "maxPool2d");
    }

    /**
     * Same as {@link #maxPool2dBlock(Shape, Shape, Shape, boolean)} with {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @return the pooling block
     */
    public static Block maxPool2dBlock(Shape kernelShape, Shape stride, Shape padding) {
        return Pool.maxPool2dBlock(kernelShape, stride, padding, false);
    }

    /**
     * Same as {@link #maxPool2dBlock(Shape, Shape, Shape, boolean)} with zero padding and
     * {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @return the pooling block
     */
    public static Block maxPool2dBlock(Shape kernelShape, Shape stride) {
        return Pool.maxPool2dBlock(kernelShape, stride, new Shape(0L, 0L), false);
    }

    /**
     * Same as {@link #maxPool2dBlock(Shape, Shape, Shape, boolean)} with stride equal to
     * {@code kernelShape}, zero padding and {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 2), also used as the stride
     * @return the pooling block
     */
    public static Block maxPool2dBlock(Shape kernelShape) {
        return Pool.maxPool2dBlock(kernelShape, kernelShape, new Shape(0L, 0L), false);
    }

    /**
     * Creates a block named {@code "maxPool3d"} that applies {@link #maxPool3d}.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @param ceilMode forwarded to {@link #maxPool3d}
     * @return the pooling block
     */
    public static Block maxPool3dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return LambdaBlock.singleton(array -> Pool.maxPool3d(array, kernelShape, stride, padding, ceilMode), "maxPool3d");
    }

    /**
     * Same as {@link #maxPool3dBlock(Shape, Shape, Shape, boolean)} with {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @return the pooling block
     */
    public static Block maxPool3dBlock(Shape kernelShape, Shape stride, Shape padding) {
        return Pool.maxPool3dBlock(kernelShape, stride, padding, false);
    }

    /**
     * Same as {@link #maxPool3dBlock(Shape, Shape, Shape, boolean)} with zero padding and
     * {@code ceilMode = false}.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @return the pooling block
     */
    public static Block maxPool3dBlock(Shape kernelShape, Shape stride) {
        return Pool.maxPool3dBlock(kernelShape, stride, new Shape(0L, 0L, 0L), false);
    }

    /**
     * Same as {@link #maxPool3dBlock(Shape, Shape, Shape, boolean)} with stride {@code (1, 1, 1)},
     * zero padding and {@code ceilMode = false}.
     *
     * <p>Note: unlike the 1-D and 2-D single-argument overloads, the default stride here is 1, not
     * {@code kernelShape}. This is kept for backward compatibility.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @return the pooling block
     */
    public static Block maxPool3dBlock(Shape kernelShape) {
        return Pool.maxPool3dBlock(kernelShape, new Shape(1L, 1L, 1L), new Shape(0L, 0L, 0L), false);
    }

    /**
     * Creates a block named {@code "globalMaxPool1d"} that applies {@link #globalMaxPool1d}.
     *
     * @return the pooling block
     */
    public static Block globalMaxPool1dBlock() {
        return LambdaBlock.singleton(Pool::globalMaxPool1d, "globalMaxPool1d");
    }

    /**
     * Creates a block named {@code "globalMaxPool2d"} that applies {@link #globalMaxPool2d}.
     *
     * @return the pooling block
     */
    public static Block globalMaxPool2dBlock() {
        return LambdaBlock.singleton(Pool::globalMaxPool2d, "globalMaxPool2d");
    }

    /**
     * Creates a block named {@code "globalMaxPool3d"} that applies {@link #globalMaxPool3d}.
     *
     * @return the pooling block
     */
    public static Block globalMaxPool3dBlock() {
        return LambdaBlock.singleton(Pool::globalMaxPool3d, "globalMaxPool3d");
    }

    /**
     * Creates a block named {@code "avgPool1d"} that applies {@link #avgPool1d}.
     *
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @param ceilMode forwarded to {@link #avgPool1d}
     * @param countIncludePad forwarded to {@link #avgPool1d}
     * @return the pooling block
     */
    public static Block avgPool1dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        return LambdaBlock.singleton(array -> Pool.avgPool1d(array, kernelShape, stride, padding, ceilMode, countIncludePad), "avgPool1d");
    }

    /**
     * Same as {@link #avgPool1dBlock(Shape, Shape, Shape, boolean, boolean)} with
     * {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @param ceilMode forwarded to {@link #avgPool1d}
     * @return the pooling block
     */
    public static Block avgPool1dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return Pool.avgPool1dBlock(kernelShape, stride, padding, ceilMode, true);
    }

    /**
     * Same as {@link #avgPool1dBlock(Shape, Shape, Shape, boolean, boolean)} with
     * {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @return the pooling block
     */
    public static Block avgPool1dBlock(Shape kernelShape, Shape stride, Shape padding) {
        return Pool.avgPool1dBlock(kernelShape, stride, padding, false, true);
    }

    /**
     * Same as {@link #avgPool1dBlock(Shape, Shape, Shape, boolean, boolean)} with zero padding,
     * {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @return the pooling block
     */
    public static Block avgPool1dBlock(Shape kernelShape, Shape stride) {
        return Pool.avgPool1dBlock(kernelShape, stride, new Shape(0L), false, true);
    }

    /**
     * Same as {@link #avgPool1dBlock(Shape, Shape, Shape, boolean, boolean)} with stride equal to
     * {@code kernelShape}, zero padding, {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 1), also used as the stride
     * @return the pooling block
     */
    public static Block avgPool1dBlock(Shape kernelShape) {
        return Pool.avgPool1dBlock(kernelShape, kernelShape, new Shape(0L), false, true);
    }

    /**
     * Creates a block named {@code "avgPool2d"} that applies {@link #avgPool2d}.
     *
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @param ceilMode forwarded to {@link #avgPool2d}
     * @param countIncludePad forwarded to {@link #avgPool2d}
     * @return the pooling block
     */
    public static Block avgPool2dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        return LambdaBlock.singleton(array -> Pool.avgPool2d(array, kernelShape, stride, padding, ceilMode, countIncludePad), "avgPool2d");
    }

    /**
     * Same as {@link #avgPool2dBlock(Shape, Shape, Shape, boolean, boolean)} with
     * {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @param ceilMode forwarded to {@link #avgPool2d}
     * @return the pooling block
     */
    public static Block avgPool2dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return Pool.avgPool2dBlock(kernelShape, stride, padding, ceilMode, true);
    }

    /**
     * Same as {@link #avgPool2dBlock(Shape, Shape, Shape, boolean, boolean)} with
     * {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @return the pooling block
     */
    public static Block avgPool2dBlock(Shape kernelShape, Shape stride, Shape padding) {
        return Pool.avgPool2dBlock(kernelShape, stride, padding, false, true);
    }

    /**
     * Same as {@link #avgPool2dBlock(Shape, Shape, Shape, boolean, boolean)} with zero padding,
     * {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @return the pooling block
     */
    public static Block avgPool2dBlock(Shape kernelShape, Shape stride) {
        return Pool.avgPool2dBlock(kernelShape, stride, new Shape(0L, 0L), false, true);
    }

    /**
     * Same as {@link #avgPool2dBlock(Shape, Shape, Shape, boolean, boolean)} with stride equal to
     * {@code kernelShape}, zero padding, {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 2), also used as the stride
     * @return the pooling block
     */
    public static Block avgPool2dBlock(Shape kernelShape) {
        return Pool.avgPool2dBlock(kernelShape, kernelShape, new Shape(0L, 0L), false, true);
    }

    /**
     * Creates a block named {@code "avgPool3d"} that applies {@link #avgPool3d}.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @param ceilMode forwarded to {@link #avgPool3d}
     * @param countIncludePad forwarded to {@link #avgPool3d}
     * @return the pooling block
     */
    public static Block avgPool3dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        return LambdaBlock.singleton(array -> Pool.avgPool3d(array, kernelShape, stride, padding, ceilMode, countIncludePad), "avgPool3d");
    }

    /**
     * Same as {@link #avgPool3dBlock(Shape, Shape, Shape, boolean, boolean)} with
     * {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @param ceilMode forwarded to {@link #avgPool3d}
     * @return the pooling block
     */
    public static Block avgPool3dBlock(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return Pool.avgPool3dBlock(kernelShape, stride, padding, ceilMode, true);
    }

    /**
     * Same as {@link #avgPool3dBlock(Shape, Shape, Shape, boolean, boolean)} with
     * {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @return the pooling block
     */
    public static Block avgPool3dBlock(Shape kernelShape, Shape stride, Shape padding) {
        return Pool.avgPool3dBlock(kernelShape, stride, padding, false, true);
    }

    /**
     * Same as {@link #avgPool3dBlock(Shape, Shape, Shape, boolean, boolean)} with zero padding,
     * {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @return the pooling block
     */
    public static Block avgPool3dBlock(Shape kernelShape, Shape stride) {
        return Pool.avgPool3dBlock(kernelShape, stride, new Shape(0L, 0L, 0L), false, true);
    }

    /**
     * Same as {@link #avgPool3dBlock(Shape, Shape, Shape, boolean, boolean)} with stride equal to
     * {@code kernelShape}, zero padding, {@code ceilMode = false} and {@code countIncludePad = true}.
     *
     * @param kernelShape the pooling window (dimension 3), also used as the stride
     * @return the pooling block
     */
    public static Block avgPool3dBlock(Shape kernelShape) {
        return Pool.avgPool3dBlock(kernelShape, kernelShape, new Shape(0L, 0L, 0L), false, true);
    }

    /**
     * Creates a block named {@code "globalAvgPool1d"} that applies {@link #globalAvgPool1d}.
     *
     * @return the pooling block
     */
    public static Block globalAvgPool1dBlock() {
        return LambdaBlock.singleton(Pool::globalAvgPool1d, "globalAvgPool1d");
    }

    /**
     * Creates a block named {@code "globalAvgPool2d"} that applies {@link #globalAvgPool2d}.
     *
     * @return the pooling block
     */
    public static Block globalAvgPool2dBlock() {
        return LambdaBlock.singleton(Pool::globalAvgPool2d, "globalAvgPool2d");
    }

    /**
     * Creates a block named {@code "globalAvgPool3d"} that applies {@link #globalAvgPool3d}.
     *
     * @return the pooling block
     */
    public static Block globalAvgPool3dBlock() {
        return LambdaBlock.singleton(Pool::globalAvgPool3d, "globalAvgPool3d");
    }

    /**
     * Creates a block named {@code "lpPool1d"} that applies {@link #lpPool1d}.
     *
     * @param normType the norm order, forwarded to {@link #lpPool1d}
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @param ceilMode forwarded to {@link #lpPool1d}
     * @return the pooling block
     */
    public static Block lpPool1dBlock(float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return LambdaBlock.singleton(array -> Pool.lpPool1d(array, normType, kernelShape, stride, padding, ceilMode), "lpPool1d");
    }

    /**
     * Same as {@link #lpPool1dBlock(float, Shape, Shape, Shape, boolean)} with
     * {@code ceilMode = false}.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @param padding the padding (dimension 1)
     * @return the pooling block
     */
    public static Block lpPool1dBlock(float normType, Shape kernelShape, Shape stride, Shape padding) {
        return Pool.lpPool1dBlock(normType, kernelShape, stride, padding, false);
    }

    /**
     * Same as {@link #lpPool1dBlock(float, Shape, Shape, Shape, boolean)} with zero padding and
     * {@code ceilMode = false}. Mirrors the equivalent 2-D and 3-D overloads.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 1)
     * @param stride the window stride (dimension 1)
     * @return the pooling block
     */
    public static Block lpPool1dBlock(float normType, Shape kernelShape, Shape stride) {
        return Pool.lpPool1dBlock(normType, kernelShape, stride, new Shape(0L), false);
    }

    /**
     * Same as {@link #lpPool1dBlock(float, Shape, Shape, Shape, boolean)} with stride {@code (1)},
     * zero padding and {@code ceilMode = false}.
     *
     * <p>Note: the default stride here is 1, not {@code kernelShape}. This is kept for backward
     * compatibility.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 1)
     * @return the pooling block
     */
    public static Block lpPool1dBlock(float normType, Shape kernelShape) {
        return Pool.lpPool1dBlock(normType, kernelShape, new Shape(1L), new Shape(0L), false);
    }

    /**
     * Creates a block named {@code "lpPool2d"} that applies {@link #lpPool2d}.
     *
     * @param normType the norm order, forwarded to {@link #lpPool2d}
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @param ceilMode forwarded to {@link #lpPool2d}
     * @return the pooling block
     */
    public static Block lpPool2dBlock(float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return LambdaBlock.singleton(array -> Pool.lpPool2d(array, normType, kernelShape, stride, padding, ceilMode), "lpPool2d");
    }

    /**
     * Same as {@link #lpPool2dBlock(float, Shape, Shape, Shape, boolean)} with
     * {@code ceilMode = false}.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @param padding the padding (dimension 2)
     * @return the pooling block
     */
    public static Block lpPool2dBlock(float normType, Shape kernelShape, Shape stride, Shape padding) {
        return Pool.lpPool2dBlock(normType, kernelShape, stride, padding, false);
    }

    /**
     * Same as {@link #lpPool2dBlock(float, Shape, Shape, Shape, boolean)} with zero padding and
     * {@code ceilMode = false}.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 2)
     * @param stride the window stride (dimension 2)
     * @return the pooling block
     */
    public static Block lpPool2dBlock(float normType, Shape kernelShape, Shape stride) {
        return Pool.lpPool2dBlock(normType, kernelShape, stride, new Shape(0L, 0L), false);
    }

    /**
     * Same as {@link #lpPool2dBlock(float, Shape, Shape, Shape, boolean)} with stride
     * {@code (1, 1)}, zero padding and {@code ceilMode = false}.
     *
     * <p>Note: the default stride here is 1, not {@code kernelShape}. This is kept for backward
     * compatibility.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 2)
     * @return the pooling block
     */
    public static Block lpPool2dBlock(float normType, Shape kernelShape) {
        return Pool.lpPool2dBlock(normType, kernelShape, new Shape(1L, 1L), new Shape(0L, 0L));
    }

    /**
     * Creates a block named {@code "lpPool3d"} that applies {@link #lpPool3d}.
     *
     * @param normType the norm order, forwarded to {@link #lpPool3d}
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @param ceilMode forwarded to {@link #lpPool3d}
     * @return the pooling block
     */
    public static Block lpPool3dBlock(float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return LambdaBlock.singleton(array -> Pool.lpPool3d(array, normType, kernelShape, stride, padding, ceilMode), "lpPool3d");
    }

    /**
     * Same as {@link #lpPool3dBlock(float, Shape, Shape, Shape, boolean)} with
     * {@code ceilMode = false}.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @param padding the padding (dimension 3)
     * @return the pooling block
     */
    public static Block lpPool3dBlock(float normType, Shape kernelShape, Shape stride, Shape padding) {
        return Pool.lpPool3dBlock(normType, kernelShape, stride, padding, false);
    }

    /**
     * Same as {@link #lpPool3dBlock(float, Shape, Shape, Shape, boolean)} with zero padding and
     * {@code ceilMode = false}.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 3)
     * @param stride the window stride (dimension 3)
     * @return the pooling block
     */
    public static Block lpPool3dBlock(float normType, Shape kernelShape, Shape stride) {
        return Pool.lpPool3dBlock(normType, kernelShape, stride, new Shape(0L, 0L, 0L), false);
    }

    /**
     * Same as {@link #lpPool3dBlock(float, Shape, Shape, Shape, boolean)} with stride equal to
     * {@code kernelShape}, zero padding and {@code ceilMode = false}.
     *
     * @param normType the norm order
     * @param kernelShape the pooling window (dimension 3), also used as the stride
     * @return the pooling block
     */
    public static Block lpPool3dBlock(float normType, Shape kernelShape) {
        return Pool.lpPool3dBlock(normType, kernelShape, kernelShape, new Shape(0L, 0L, 0L), false);
    }

    /**
     * Creates a block named {@code "globalLpPool1d"} that applies {@link #globalLpPool1d}.
     *
     * @param normType the norm order, forwarded to {@link #globalLpPool1d}
     * @return the pooling block
     */
    public static Block globalLpPool1dBlock(float normType) {
        return LambdaBlock.singleton(array -> Pool.globalLpPool1d(array, normType), "globalLpPool1d");
    }

    /**
     * Creates a block named {@code "globalLpPool2d"} that applies {@link #globalLpPool2d}.
     *
     * @param normType the norm order, forwarded to {@link #globalLpPool2d}
     * @return the pooling block
     */
    public static Block globalLpPool2dBlock(float normType) {
        return LambdaBlock.singleton(array -> Pool.globalLpPool2d(array, normType), "globalLpPool2d");
    }

    /**
     * Creates a block named {@code "globalLpPool3d"} that applies {@link #globalLpPool3d}.
     *
     * @param normType the norm order, forwarded to {@link #globalLpPool3d}
     * @return the pooling block
     */
    public static Block globalLpPool3dBlock(float normType) {
        return LambdaBlock.singleton(array -> Pool.globalLpPool3d(array, normType), "globalLpPool3d");
    }
}

