diff --git a/cmake/FindJulia.cmake b/cmake/FindJulia.cmake new file mode 100644 index 00000000000..c9d6fd25385 --- /dev/null +++ b/cmake/FindJulia.cmake @@ -0,0 +1,67 @@ +# Copyright 2010-2024 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#[=======================================================================[.rst: +FindJulia +--------- + +This module determines the Julia interpreter of the system. + +IMPORTED Targets +^^^^^^^^^^^^^^^^ + +This module does not yet define the :prop_tgt:`IMPORTED` target +``Julia::Interpreter``; only the result variables below are set. + +Result Variables +^^^^^^^^^^^^^^^^ + +This module defines the following variables: + +:: + +Julia_FOUND - True if Julia found. +Julia_BIN - The path to the Julia executable. +Julia_VERSION - The version of the Julia executable if found. + +Hints +^^^^^ + +A user may set ``JULIA_BINDIR`` to a folder containing the Julia binary +to tell this module where to look. 
+#]=======================================================================] + +include(FindPackageHandleStandardArgs) + +if(DEFINED ENV{JULIA_BINDIR}) + message(STATUS "JULIA_BINDIR: $ENV{JULIA_BINDIR}") +endif() + +set(Julia_FOUND FALSE) + +if(DEFINED ENV{JULIA_BINDIR}) + find_program(Julia_BIN julia PATHS $ENV{JULIA_BINDIR} DOC "Julia executable") +else() + find_program(Julia_BIN julia DOC "Julia executable") +endif() +message(STATUS "Julia_BIN: ${Julia_BIN}") + +if(Julia_BIN) + execute_process( + COMMAND "${Julia_BIN}" --startup-file=no --version + OUTPUT_VARIABLE Julia_VERSION OUTPUT_STRIP_TRAILING_WHITESPACE + ) + message(STATUS "Julia_VERSION: ${Julia_VERSION}") + set(Julia_FOUND TRUE) +endif() + diff --git a/cmake/Makefile b/cmake/Makefile index 848e9bd1c9f..a2e09550c23 100644 --- a/cmake/Makefile +++ b/cmake/Makefile @@ -143,7 +143,9 @@ help: @echo -e "\t${BOLD}clean_vms${RESET}: Remove ALL vagrant box." @echo @echo -e "\tWith ${BOLD}${RESET}:" - @echo -e "\t\t${BOLD}freebsd${RESET} (FreeBSD)" + @echo -e "\t\t${BOLD}freebsd${RESET} (FreeBSD 14)" + @echo -e "\t\t${BOLD}netbsd${RESET} (NetBSD 9)" + @echo -e "\t\t${BOLD}openbsd${RESET} (OpenBSD 7)" @echo -e "\te.g. 'make freebsd_cpp'" @echo @echo -e "\t${BOLD}glop_${RESET}: Build Glop using an Ubuntu:rolling docker image." @@ -717,30 +719,43 @@ clean_web: $(addprefix clean_web_, $(WEB_STAGES)) ############# ## VAGRANT ## ############# -VMS := freebsd - -freebsd_targets = $(addprefix freebsd_, $(LANGUAGES)) -.PHONY: freebsd $(freebsd_targets) -freebsd: $(freebsd_targets) -$(freebsd_targets): freebsd_%: vagrant/freebsd/%/Vagrantfile - @cd vagrant/freebsd/$* && vagrant destroy -f - cd vagrant/freebsd/$* && vagrant box update - cd vagrant/freebsd/$* && vagrant up - -# SSH to a freebsd_ vagrant machine (debug). 
-sh_freebsd_targets = $(addprefix sh_freebsd_, $(LANGUAGES)) -.PHONY: $(sh_freebsd_targets) -$(sh_freebsd_targets): sh_freebsd_%: - cd vagrant/freebsd/$* && vagrant up - cd vagrant/freebsd/$* && vagrant ssh - -# Clean FreeBSD vagrant machine -clean_freebsd_targets = $(addprefix clean_freebsd_, $(LANGUAGES)) -.PHONY: clean_freebsd $(clean_freebsd_targets) -clean_freebsd: $(clean_freebsd_targets) -$(clean_freebsd_targets): clean_freebsd_%: - cd vagrant/freebsd/$* && vagrant destroy -f - -rm -rf vagrant/freebsd/$*/.vagrant +VAGRANT_VMS := \ + freebsd \ + netbsd \ + openbsd + +define make-vagrant-target = +#$$(info VMS: $1) +#$$(info Create target: $1_.) +$1_targets = $(addprefix $1_, $(LANGUAGES)) +.PHONY: $1 $$($1_targets) +$1: $$($1_targets) +$$($1_targets): $1_%: vagrant/$1/%/Vagrantfile + @cd vagrant/$1/$$* && vagrant destroy -f + cd vagrant/$1/$$* && vagrant box update + cd vagrant/$1/$$* && vagrant up + +#$$(info Create targets: sh_$1_ vagrant machine (debug).) +sh_$1_targets = $(addprefix sh_$1_, $(LANGUAGES)) +.PHONY: $$(sh_$1_targets) +$$(sh_$1_targets): sh_$1_%: + cd vagrant/$1/$$* && vagrant up + cd vagrant/$1/$$* && vagrant ssh + +#$$(info Create targets: clean_$1) +clean_$1_targets = $(addprefix clean_$1_, $(LANGUAGES)) +.PHONY: clean_$1 $$(clean_$1_targets) +clean_$1: $$(clean_$1_targets) +$$(clean_$1_targets): clean_$1_%: + cd vagrant/$1/$$* && vagrant destroy -f + -rm -rf vagrant/$1/$$*/.vagrant +endef + +$(foreach vms,$(VAGRANT_VMS),$(eval $(call make-vagrant-target,$(vms)))) + +## MERGE ## +.PHONY: clean_vagrant +clean_vagrant: $(addprefix clean_, $(VAGRANT_VMS)) ########## ## GLOP ## @@ -779,7 +794,7 @@ clean_glop: $(addprefix clean_glop_, $(STAGES)) ## CLEAN ## ########### .PHONY: clean -clean: clean_all clean_platforms clean_toolchains clean_web clean_freebsd clean_glop +clean: clean_all clean_platforms clean_toolchains clean_web clean_vagrant clean_glop docker container prune -f docker image prune -f -rmdir cache diff --git 
a/cmake/vagrant/netbsd/cpp/Vagrantfile b/cmake/vagrant/netbsd/cpp/Vagrantfile new file mode 100644 index 00000000000..0378b3d5a80 --- /dev/null +++ b/cmake/vagrant/netbsd/cpp/Vagrantfile @@ -0,0 +1,115 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :netbsd + config.vm.box = "generic/netbsd9" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_netbsd_cpp" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. + # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. 
+ # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: NetBSD uses pkgin (pkgsrc) for binary packages; base compiler is gcc + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkgin -y update + pkgin -y install git cmake + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + cd project + cmake -S. 
-Bbuild -DBUILD_DEPS=ON + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build --target test -v + SHELL +end diff --git a/cmake/vagrant/netbsd/dotnet/Vagrantfile b/cmake/vagrant/netbsd/dotnet/Vagrantfile new file mode 100644 index 00000000000..bceb231d8bb --- /dev/null +++ b/cmake/vagrant/netbsd/dotnet/Vagrantfile @@ -0,0 +1,118 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :netbsd + config.vm.box = "generic/netbsd9" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_netbsd_dotnet" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. 
+ # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. + # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: clang installed by default + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkg update -f + pkg install -y git cmake + kldload linux64 + pkg install -y swig linux-dotnet-sdk + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../tools/doc/orLogo.png", destination: "$HOME/project/tools/doc/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + cd project + cmake -S. 
-Bbuild -DBUILD_DOTNET=ON -DBUILD_CXX_SAMPLES=OFF -DBUILD_CXX_EXAMPLES=OFF + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build --target test -v + SHELL +end diff --git a/cmake/vagrant/netbsd/java/Vagrantfile b/cmake/vagrant/netbsd/java/Vagrantfile new file mode 100644 index 00000000000..050e73496e8 --- /dev/null +++ b/cmake/vagrant/netbsd/java/Vagrantfile @@ -0,0 +1,119 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :netbsd + config.vm.box = "generic/netbsd9" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_netbsd_java" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. 
+ # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. + # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: clang installed by default + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkg update -f + pkg install -y git cmake + pkg install -y swig openjdk11 maven + mount -t fdescfs fdesc /dev/fd + mount -t procfs proc /proc + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + cd project + export JAVA_HOME=/usr/local/openjdk11 + cmake -S. 
-Bbuild -DBUILD_JAVA=ON -DBUILD_CXX_SAMPLES=OFF -DBUILD_CXX_EXAMPLES=OFF + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build --target test -v + SHELL +end diff --git a/cmake/vagrant/netbsd/python/Vagrantfile b/cmake/vagrant/netbsd/python/Vagrantfile new file mode 100644 index 00000000000..6cfb5783f6d --- /dev/null +++ b/cmake/vagrant/netbsd/python/Vagrantfile @@ -0,0 +1,122 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :netbsd + config.vm.box = "generic/netbsd9" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_netbsd_python" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. 
+ # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. + # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: clang installed by default + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkg update -f + pkg install -y git cmake + pkg install -y swig python39 py39-wheel py39-pip py39-pytest-virtualenv + pkg install -y py39-numpy py39-pandas py39-matplotlib + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../LICENSE", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + cmake -S. 
-Bbuild -DBUILD_PYTHON=ON -DVENV_USE_SYSTEM_SITE_PACKAGES=ON -DBUILD_CXX_SAMPLES=OFF -DBUILD_CXX_EXAMPLES=OFF + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + cmake --build build --target test -v + SHELL +end diff --git a/cmake/vagrant/openbsd/cpp/Vagrantfile b/cmake/vagrant/openbsd/cpp/Vagrantfile new file mode 100644 index 00000000000..85a99d367b4 --- /dev/null +++ b/cmake/vagrant/openbsd/cpp/Vagrantfile @@ -0,0 +1,115 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :openbsd + config.vm.box = "generic/openbsd7" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_openbsd_cpp" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. 
+ # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. + # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: clang installed by default; OpenBSD uses pkg_add for binary packages + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkg_add -u + pkg_add git cmake + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + cd project + cmake -S. 
-Bbuild -DBUILD_DEPS=ON + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build --target test -v + SHELL +end diff --git a/cmake/vagrant/openbsd/dotnet/Vagrantfile b/cmake/vagrant/openbsd/dotnet/Vagrantfile new file mode 100644 index 00000000000..b129bca007f --- /dev/null +++ b/cmake/vagrant/openbsd/dotnet/Vagrantfile @@ -0,0 +1,118 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :openbsd + config.vm.box = "generic/openbsd7" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_openbsd_dotnet" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. 
+ # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. + # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: clang installed by default + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkg update -f + pkg install -y git cmake + kldload linux64 + pkg install -y swig linux-dotnet-sdk + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../tools/doc/orLogo.png", destination: "$HOME/project/tools/doc/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + cd project + cmake -S. 
-Bbuild -DBUILD_DOTNET=ON -DBUILD_CXX_SAMPLES=OFF -DBUILD_CXX_EXAMPLES=OFF + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build --target test -v + SHELL +end diff --git a/cmake/vagrant/openbsd/java/Vagrantfile b/cmake/vagrant/openbsd/java/Vagrantfile new file mode 100644 index 00000000000..c0298674717 --- /dev/null +++ b/cmake/vagrant/openbsd/java/Vagrantfile @@ -0,0 +1,119 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :openbsd + config.vm.box = "generic/openbsd7" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_openbsd_java" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. 
+ # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. + # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: clang installed by default + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkg update -f + pkg install -y git cmake + pkg install -y swig openjdk11 maven + mount -t fdescfs fdesc /dev/fd + mount -t procfs proc /proc + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + cd project + export JAVA_HOME=/usr/local/openjdk11 + cmake -S. 
-Bbuild -DBUILD_JAVA=ON -DBUILD_CXX_SAMPLES=OFF -DBUILD_CXX_EXAMPLES=OFF + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + cd project + cmake --build build --target test -v + SHELL +end diff --git a/cmake/vagrant/openbsd/python/Vagrantfile b/cmake/vagrant/openbsd/python/Vagrantfile new file mode 100644 index 00000000000..d2083458d33 --- /dev/null +++ b/cmake/vagrant/openbsd/python/Vagrantfile @@ -0,0 +1,122 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +# All Vagrant configuration is done below. The "2" in Vagrant.configure +# configures the configuration version (we support older styles for +# backwards compatibility). Please don't change it unless you know what +# you're doing. +Vagrant.configure("2") do |config| + # The most common configuration options are documented and commented below. + # For a complete reference, please see the online documentation at + # https://docs.vagrantup.com. + + # Every Vagrant development environment requires a box. You can search for + # boxes at https://vagrantcloud.com/search. + config.vm.guest = :openbsd + config.vm.box = "generic/openbsd7" + config.vm.provider "virtualbox" do |v| + v.name = "ortools_openbsd_python" + end + config.ssh.shell = "sh" + + # Disable automatic box update checking. If you disable this, then + # boxes will only be checked for updates when the user runs + # `vagrant box outdated`. This is not recommended. + # config.vm.box_check_update = false + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine. In the example below, + # accessing "localhost:8080" will access port 80 on the guest machine. 
+ # NOTE: This will enable public access to the opened port + # config.vm.network "forwarded_port", guest: 80, host: 8080 + + # Create a forwarded port mapping which allows access to a specific port + # within the machine from a port on the host machine and only allow access + # via 127.0.0.1 to disable public access + # config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" + + # Create a private network, which allows host-only access to the machine + # using a specific IP. + # config.vm.network "private_network", ip: "192.168.33.10" + + # Create a public network, which generally matched to bridged network. + # Bridged networks make the machine appear as another physical device on + # your network. + # config.vm.network "public_network" + + # Share an additional folder to the guest VM. The first argument is + # the path on the host to the actual folder. The second argument is + # the path on the guest to mount the folder. And the optional third + # argument is a set of non-required options. + #config.vm.synced_folder "../../..", "/home/vagrant/project" + config.vm.synced_folder ".", "/vagrant", id: "vagrant-root", disabled: true + + + # Provider-specific configuration so you can fine-tune various + # backing providers for Vagrant. These expose provider-specific options. + # Example for VirtualBox: + # + # config.vm.provider "virtualbox" do |vb| + # # Display the VirtualBox GUI when booting the machine + # vb.gui = true + # + # # Customize the amount of memory on the VM: + # vb.memory = "1024" + # end + # + # View the documentation for the provider you are using for more + # information on available options. + + # Enable provisioning with a shell script. Additional provisioners such as + # Ansible, Chef, Docker, Puppet and Salt are also available. Please see the + # documentation for more information about their specific syntax and use. 
+ # note: clang installed by default + config.vm.provision "env", type: "shell", inline:<<-SHELL + set -x + pkg update -f + pkg install -y git cmake + pkg install -y swig python39 py39-wheel py39-pip py39-pytest-virtualenv + pkg install -y py39-numpy py39-pandas py39-matplotlib + SHELL + + config.vm.provision "file", source: "../../../../CMakeLists.txt", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../cmake", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../ortools", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../examples/contrib", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/cpp", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/dotnet", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/java", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/python", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../examples/tests", destination: "$HOME/project/examples/" + config.vm.provision "file", source: "../../../../patches", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../LICENSE", destination: "$HOME/project/" + config.vm.provision "file", source: "../../../../Version.txt", destination: "$HOME/project/" + + config.vm.provision "devel", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + ls + SHELL + + config.vm.provision "configure", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + cmake -S. 
-Bbuild -DBUILD_PYTHON=ON -DVENV_USE_SYSTEM_SITE_PACKAGES=ON -DBUILD_CXX_SAMPLES=OFF -DBUILD_CXX_EXAMPLES=OFF + SHELL + + config.vm.provision "build", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + cmake --build build -v + SHELL + + config.vm.provision "test", type: "shell", inline:<<-SHELL + set -x + export PATH=${HOME}/.local/bin:"$PATH" + cd project + cmake --build build --target test -v + SHELL +end diff --git a/ortools/algorithms/BUILD.bazel b/ortools/algorithms/BUILD.bazel index be5f372620c..3d7f2284f3f 100644 --- a/ortools/algorithms/BUILD.bazel +++ b/ortools/algorithms/BUILD.bazel @@ -275,6 +275,7 @@ cc_library( ":set_cover_model", "//ortools/base:threadpool", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/synchronization", ], ) @@ -286,9 +287,10 @@ cc_library( ":set_cover_cc_proto", "//ortools/base:intops", "//ortools/base:strong_vector", - "//ortools/util:aligned_memory", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:distributions", "@com_google_absl//absl/strings", ], ) diff --git a/ortools/algorithms/python/set_cover.cc b/ortools/algorithms/python/set_cover.cc index 2d107fdfc05..a4263078c15 100644 --- a/ortools/algorithms/python/set_cover.cc +++ b/ortools/algorithms/python/set_cover.cc @@ -13,91 +13,235 @@ // A pybind11 wrapper for set_cover_*. 
+#include +#include +#include #include +#include #include "absl/base/nullability.h" #include "ortools/algorithms/set_cover_heuristics.h" #include "ortools/algorithms/set_cover_invariant.h" #include "ortools/algorithms/set_cover_model.h" #include "ortools/algorithms/set_cover_reader.h" +#include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" +using ::operations_research::BaseInt; +using ::operations_research::ClearRandomSubsets; using ::operations_research::ElementDegreeSolutionGenerator; +using ::operations_research::ElementIndex; using ::operations_research::GreedySolutionGenerator; using ::operations_research::GuidedLocalSearch; +using ::operations_research::GuidedTabuSearch; using ::operations_research::Preprocessor; using ::operations_research::RandomSolutionGenerator; using ::operations_research::ReadBeasleySetCoverProblem; using ::operations_research::ReadRailSetCoverProblem; +using ::operations_research::SetCoverDecision; using ::operations_research::SetCoverInvariant; using ::operations_research::SetCoverModel; +using ::operations_research::SparseColumn; +using ::operations_research::SparseRow; using ::operations_research::SteepestSearch; +using ::operations_research::SubsetBoolVector; +using ::operations_research::SubsetCostVector; using ::operations_research::SubsetIndex; +using ::operations_research::TabuList; using ::operations_research::TrivialSolutionGenerator; namespace py = pybind11; using ::py::arg; +using ::py::make_iterator; -// General note about TODOs: the corresponding functions/classes/methods are -// more complex to wrap, as they use nonstandard types, and are less important, -// as they are not as useful to most users (mostly useful to write some custom -// Python heuristics). 
+std::vector<SubsetIndex> VectorIntToVectorSubsetIndex(
+    const std::vector<BaseInt>& ints) {
+  std::vector<SubsetIndex> subs(ints.size());
+  std::transform(ints.begin(), ints.end(), subs.begin(),
+                 [](int subset) -> SubsetIndex { return SubsetIndex(subset); });
+  return subs;
+}
+
+SubsetCostVector VectorDoubleToSubsetCostVector(
+    const std::vector<double>& doubles) {
+  SubsetCostVector costs(doubles.begin(), doubles.end());
+  return costs;
+}
+
+class IntIterator {
+ public:
+  using value_type = int;
+  using difference_type = std::ptrdiff_t;
+  using pointer = int*;
+  using reference = int&;
+  using iterator_category = std::input_iterator_tag;
+
+  explicit IntIterator(int max_value)
+      : max_value_(max_value), current_value_(0) {}
+
+  int operator*() const { return current_value_; }
+  IntIterator& operator++() {
+    ++current_value_;
+    return *this;
+  }
+
+  static IntIterator begin(int max_value) { return IntIterator{max_value}; }
+  static IntIterator end(int max_value) { return {max_value, max_value}; }
+
+  friend bool operator==(const IntIterator& lhs, const IntIterator& rhs) {
+    return lhs.max_value_ == rhs.max_value_ &&
+           lhs.current_value_ == rhs.current_value_;
+  }
+
+ private:
+  IntIterator(int max_value, int current_value)
+      : max_value_(max_value), current_value_(current_value) {}
+
+  const int max_value_;
+  int current_value_;
+};
 
 PYBIND11_MODULE(set_cover, m) {
   pybind11_protobuf::ImportNativeProtoCasters();
 
   // set_cover_model.h
+  py::class_<SetCoverModel::Stats>(m, "SetCoverModelStats")
+      .def_readwrite("min", &SetCoverModel::Stats::min)
+      .def_readwrite("max", &SetCoverModel::Stats::max)
+      .def_readwrite("median", &SetCoverModel::Stats::median)
+      .def_readwrite("mean", &SetCoverModel::Stats::mean)
+      .def_readwrite("stddev", &SetCoverModel::Stats::stddev)
+      .def("debug_string", &SetCoverModel::Stats::DebugString);
+
   py::class_<SetCoverModel>(m, "SetCoverModel")
       .def(py::init<>())
       .def_property_readonly("num_elements", &SetCoverModel::num_elements)
       .def_property_readonly("num_subsets", &SetCoverModel::num_subsets)
       .def_property_readonly("num_nonzeros", &SetCoverModel::num_nonzeros)
       .def_property_readonly("fill_rate", &SetCoverModel::FillRate)
+      .def_property_readonly(
+          "subset_costs",
+          [](SetCoverModel& model) -> const std::vector<double>& {
+            return model.subset_costs().get();
+          })
+      .def("columns",
+           [](SetCoverModel& model) -> std::vector<std::vector<BaseInt>> {
+             // Due to the inner StrongVector, make a deep copy. Anyway,
+             // columns() returns a const ref, so this keeps the semantics, not
+             // the efficiency.
+             std::vector<std::vector<BaseInt>> columns(model.columns().size());
+             std::transform(
+                 model.columns().begin(), model.columns().end(),
+                 columns.begin(),
+                 [](const SparseColumn& column) -> std::vector<BaseInt> {
+                   std::vector<BaseInt> col(column.size());
+                   std::transform(column.begin(), column.end(), col.begin(),
+                                  [](ElementIndex element) -> BaseInt {
+                                    return element.value();
+                                  });
+                   return col;
+                 });
+             return columns;
+           })
+      .def("rows",
+           [](SetCoverModel& model) -> std::vector<std::vector<BaseInt>> {
+             // Due to the inner StrongVector, make a deep copy. Anyway,
+             // rows() returns a const ref, so this keeps the semantics, not
+             // the efficiency.
+             std::vector<std::vector<BaseInt>> rows(model.rows().size());
+             std::transform(
+                 model.rows().begin(), model.rows().end(), rows.begin(),
+                 [](const SparseRow& row) -> std::vector<BaseInt> {
+                   std::vector<BaseInt> r(row.size());
+                   std::transform(row.begin(), row.end(), r.begin(),
+                                  [](SubsetIndex element) -> BaseInt {
+                                    return element.value();
+                                  });
+                   return r;
+                 });
+             return rows;
+           })
+      .def("row_view_is_valid", &SetCoverModel::row_view_is_valid)
+      .def("SubsetRange",
+           [](SetCoverModel& model) {
+             return make_iterator<>(IntIterator::begin(model.num_subsets()),
+                                    IntIterator::end(model.num_subsets()));
+           })
+      .def("ElementRange",
+           [](SetCoverModel& model) {
+             return make_iterator<>(IntIterator::begin(model.num_elements()),
+                                    IntIterator::end(model.num_elements()));
+           })
+      .def_property_readonly("all_subsets",
+                             [](SetCoverModel& model) -> std::vector<BaseInt> {
+                               std::vector<BaseInt> subsets(model.all_subsets().size());
+                               std::transform(
+                                   model.all_subsets().begin(),
+                                   model.all_subsets().end(), subsets.begin(),
+                                   [](const SubsetIndex element) -> BaseInt {
+                                     return element.value();
+                                   });
+                               return subsets;
+                             })
       .def("add_empty_subset", &SetCoverModel::AddEmptySubset, arg("cost"))
       .def(
           "add_element_to_last_subset",
-          [](SetCoverModel& model, int element) {
+          [](SetCoverModel& model, BaseInt element) {
             model.AddElementToLastSubset(element);
           },
           arg("element"))
       .def(
           "set_subset_cost",
-          [](SetCoverModel& model, int subset, double cost) {
+          [](SetCoverModel& model, BaseInt subset, double cost) {
             model.SetSubsetCost(subset, cost);
           },
           arg("subset"), arg("cost"))
       .def(
           "add_element_to_subset",
-          [](SetCoverModel& model, int element, int subset) {
+          [](SetCoverModel& model, BaseInt element, BaseInt subset) {
             model.AddElementToSubset(element, subset);
           },
           arg("subset"), arg("cost"))
+      .def("create_sparse_row_view", &SetCoverModel::CreateSparseRowView)
       .def("compute_feasibility", &SetCoverModel::ComputeFeasibility)
       .def(
           "reserve_num_subsets",
-          [](SetCoverModel& model, int num_subsets) {
+          [](SetCoverModel& model, BaseInt num_subsets) {
             model.ReserveNumSubsets(num_subsets);
           },
           arg("num_subsets"))
       .def(
           "reserve_num_elements_in_subset",
-          [](SetCoverModel& model, int num_elements, int subset) {
+          [](SetCoverModel& model, BaseInt num_elements, BaseInt subset) {
             model.ReserveNumElementsInSubset(num_elements, subset);
           },
           arg("num_elements"), arg("subset"))
       .def("export_model_as_proto", &SetCoverModel::ExportModelAsProto)
-      .def("import_model_from_proto", &SetCoverModel::ImportModelFromProto);
-  // TODO(user): add support for subset_costs, columns, rows,
-  // row_view_is_valid, SubsetRange, ElementRange, all_subsets,
-  // CreateSparseRowView, ComputeCostStats, ComputeRowStats,
-  // ComputeColumnStats, ComputeRowDeciles, ComputeColumnDeciles.
+      .def("import_model_from_proto", &SetCoverModel::ImportModelFromProto)
+      .def("compute_cost_stats", &SetCoverModel::ComputeCostStats)
+      .def("compute_row_stats", &SetCoverModel::ComputeRowStats)
+      .def("compute_column_stats", &SetCoverModel::ComputeColumnStats)
+      .def("compute_row_deciles", &SetCoverModel::ComputeRowDeciles)
+      .def("compute_column_deciles", &SetCoverModel::ComputeColumnDeciles);
   // TODO(user): wrap IntersectingSubsetsIterator.
// set_cover_invariant.h + py::class_(m, "SetCoverDecision") + .def(py::init<>()) + .def(py::init([](BaseInt subset, bool value) -> SetCoverDecision* { + return new SetCoverDecision(SubsetIndex(subset), value); + }), + arg("subset"), arg("value")) + .def("subset", + [](const SetCoverDecision& decision) -> BaseInt { + return decision.subset().value(); + }) + .def("decision", &SetCoverDecision::decision); + py::class_(m, "SetCoverInvariant") .def(py::init()) .def("initialize", &SetCoverInvariant::Initialize) @@ -118,44 +262,90 @@ PYBIND11_MODULE(set_cover, m) { }) .def("cost", &SetCoverInvariant::cost) .def("num_uncovered_elements", &SetCoverInvariant::num_uncovered_elements) + .def("is_selected", + [](SetCoverInvariant& invariant) -> std::vector { + return invariant.is_selected().get(); + }) + .def("num_free_elements", + [](SetCoverInvariant& invariant) -> std::vector { + return invariant.num_free_elements().get(); + }) + .def("num_coverage_le_1_elements", + [](SetCoverInvariant& invariant) -> std::vector { + return invariant.num_coverage_le_1_elements().get(); + }) + .def("coverage", + [](SetCoverInvariant& invariant) -> std::vector { + return invariant.coverage().get(); + }) + .def( + "compute_coverage_in_focus", + [](SetCoverInvariant& invariant, + const std::vector& focus) -> std::vector { + return invariant + .ComputeCoverageInFocus(VectorIntToVectorSubsetIndex(focus)) + .get(); + }, + arg("focus")) + .def("is_redundant", + [](SetCoverInvariant& invariant) -> std::vector { + return invariant.is_redundant().get(); + }) + .def("trace", &SetCoverInvariant::trace) .def("clear_trace", &SetCoverInvariant::ClearTrace) .def("clear_removability_information", &SetCoverInvariant::ClearRemovabilityInformation) + .def("new_removable_subsets", &SetCoverInvariant::new_removable_subsets) + .def("new_non_removable_subsets", + &SetCoverInvariant::new_non_removable_subsets) .def("compress_trace", &SetCoverInvariant::CompressTrace) + .def("load_solution", + [](SetCoverInvariant& 
invariant, + const std::vector& solution) -> void { + SubsetBoolVector sol(solution.begin(), solution.end()); + return invariant.LoadSolution(sol); + }) .def("check_consistency", &SetCoverInvariant::CheckConsistency) + .def( + "compute_is_redundant", + [](SetCoverInvariant& invariant, BaseInt subset) -> bool { + return invariant.ComputeIsRedundant(SubsetIndex(subset)); + }, + arg("subset")) + .def("make_fully_updated", &SetCoverInvariant::MakeFullyUpdated) .def( "flip", - [](SetCoverInvariant& invariant, int subset) { + [](SetCoverInvariant& invariant, BaseInt subset) { invariant.Flip(SubsetIndex(subset)); }, arg("subset")) .def( "flip_and_fully_update", - [](SetCoverInvariant& invariant, int subset) { + [](SetCoverInvariant& invariant, BaseInt subset) { invariant.FlipAndFullyUpdate(SubsetIndex(subset)); }, arg("subset")) .def( "select", - [](SetCoverInvariant& invariant, int subset) { + [](SetCoverInvariant& invariant, BaseInt subset) { invariant.Select(SubsetIndex(subset)); }, arg("subset")) .def( "select_and_fully_update", - [](SetCoverInvariant& invariant, int subset) { + [](SetCoverInvariant& invariant, BaseInt subset) { invariant.SelectAndFullyUpdate(SubsetIndex(subset)); }, arg("subset")) .def( "deselect", - [](SetCoverInvariant& invariant, int subset) { + [](SetCoverInvariant& invariant, BaseInt subset) { invariant.Deselect(SubsetIndex(subset)); }, arg("subset")) .def( "deselect_and_fully_update", - [](SetCoverInvariant& invariant, int subset) { + [](SetCoverInvariant& invariant, BaseInt subset) { invariant.DeselectAndFullyUpdate(SubsetIndex(subset)); }, arg("subset")) @@ -163,10 +353,6 @@ PYBIND11_MODULE(set_cover, m) { &SetCoverInvariant::ExportSolutionAsProto) .def("import_solution_from_proto", &SetCoverInvariant::ImportSolutionFromProto); - // TODO(user): add support for is_selected, num_free_elements, - // num_coverage_le_1_elements, coverage, ComputeCoverageInFocus, - // is_redundant, trace, new_removable_subsets, new_non_removable_subsets, - // 
LoadSolution, ComputeIsRedundant. // set_cover_heuristics.h py::class_(m, "Preprocessor") @@ -175,30 +361,57 @@ PYBIND11_MODULE(set_cover, m) { [](Preprocessor& heuristic) -> bool { return heuristic.NextSolution(); }) + .def("next_solution", + [](Preprocessor& heuristic, + const std::vector& focus) -> bool { + return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus)); + }) .def("num_columns_fixed_by_singleton_row", &Preprocessor::num_columns_fixed_by_singleton_row); - // TODO(user): add support for focus argument. py::class_(m, "TrivialSolutionGenerator") .def(py::init()) - .def("next_solution", [](TrivialSolutionGenerator& heuristic) -> bool { - return heuristic.NextSolution(); - }); - // TODO(user): add support for focus argument. + .def("next_solution", + [](TrivialSolutionGenerator& heuristic) -> bool { + return heuristic.NextSolution(); + }) + .def("next_solution", + [](TrivialSolutionGenerator& heuristic, + const std::vector& focus) -> bool { + return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus)); + }); py::class_(m, "RandomSolutionGenerator") .def(py::init()) - .def("next_solution", [](RandomSolutionGenerator& heuristic) -> bool { - return heuristic.NextSolution(); - }); - // TODO(user): add support for focus argument. + .def("next_solution", + [](RandomSolutionGenerator& heuristic) -> bool { + return heuristic.NextSolution(); + }) + .def("next_solution", + [](RandomSolutionGenerator& heuristic, + const std::vector& focus) -> bool { + return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus)); + }); py::class_(m, "GreedySolutionGenerator") .def(py::init()) - .def("next_solution", [](GreedySolutionGenerator& heuristic) -> bool { - return heuristic.NextSolution(); - }); - // TODO(user): add support for focus and cost arguments. 
+ .def("next_solution", + [](GreedySolutionGenerator& heuristic) -> bool { + return heuristic.NextSolution(); + }) + .def("next_solution", + [](GreedySolutionGenerator& heuristic, + const std::vector& focus) -> bool { + return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus)); + }) + .def("next_solution", + [](GreedySolutionGenerator& heuristic, + const std::vector& focus, + const std::vector& costs) -> bool { + return heuristic.NextSolution( + VectorIntToVectorSubsetIndex(focus), + VectorDoubleToSubsetCostVector(costs)); + }); py::class_(m, "ElementDegreeSolutionGenerator") @@ -206,16 +419,40 @@ PYBIND11_MODULE(set_cover, m) { .def("next_solution", [](ElementDegreeSolutionGenerator& heuristic) -> bool { return heuristic.NextSolution(); + }) + .def("next_solution", + [](ElementDegreeSolutionGenerator& heuristic, + const std::vector& focus) -> bool { + return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus)); + }) + .def("next_solution", + [](ElementDegreeSolutionGenerator& heuristic, + const std::vector& focus, + const std::vector& costs) -> bool { + return heuristic.NextSolution( + VectorIntToVectorSubsetIndex(focus), + VectorDoubleToSubsetCostVector(costs)); }); - // TODO(user): add support for focus and cost arguments. py::class_(m, "SteepestSearch") .def(py::init()) .def("next_solution", [](SteepestSearch& heuristic, int num_iterations) -> bool { return heuristic.NextSolution(num_iterations); + }) + .def("next_solution", + [](SteepestSearch& heuristic, const std::vector& focus, + int num_iterations) -> bool { + return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus), + num_iterations); + }) + .def("next_solution", + [](SteepestSearch& heuristic, const std::vector& focus, + const std::vector& costs, int num_iterations) -> bool { + return heuristic.NextSolution( + VectorIntToVectorSubsetIndex(focus), + VectorDoubleToSubsetCostVector(costs), num_iterations); }); - // TODO(user): add support for focus and cost arguments. 
   py::class_<GuidedLocalSearch>(m, "GuidedLocalSearch")
       .def(py::init<SetCoverInvariant*>())
@@ -223,12 +460,92 @@
       .def("next_solution",
            [](GuidedLocalSearch& heuristic, int num_iterations) -> bool {
              return heuristic.NextSolution(num_iterations);
+           })
+      .def("next_solution",
+           [](GuidedLocalSearch& heuristic, const std::vector<BaseInt>& focus,
+              int num_iterations) -> bool {
+             return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus),
+                                           num_iterations);
            });
-  // TODO(user): add support for focus and cost arguments.
-  // TODO(user): add support for ClearRandomSubsets, ClearRandomSubsets,
-  // ClearMostCoveredElements, ClearMostCoveredElements, TabuList,
-  // GuidedTabuSearch.
+  // Specialization for T = SubsetIndex ~= BaseInt (aka int for Python, whatever
+  // the size of BaseInt).
+  // A base type doesn't work, because TabuList uses `T::value` in the
+  // constructor.
+  py::class_<TabuList<SubsetIndex>>(m, "TabuList")
+      .def(py::init([](int size) -> TabuList<SubsetIndex>* {
+             return new TabuList<SubsetIndex>(SubsetIndex(size));
+           }),
+           arg("size"))
+      .def("size", &TabuList<SubsetIndex>::size)
+      .def("init", &TabuList<SubsetIndex>::Init, arg("size"))
+      .def(
+          "add",
+          [](TabuList<SubsetIndex>& list, BaseInt t) -> void {
+            return list.Add(SubsetIndex(t));
+          },
+          arg("t"))
+      .def(
+          "contains",
+          [](TabuList<SubsetIndex>& list, BaseInt t) -> bool {
+            return list.Contains(SubsetIndex(t));
+          },
+          arg("t"));
+
+  py::class_<GuidedTabuSearch>(m, "GuidedTabuSearch")
+      .def(py::init<SetCoverInvariant*>())
+      .def("initialize", &GuidedTabuSearch::Initialize)
+      .def("next_solution",
+           [](GuidedTabuSearch& heuristic, int num_iterations) -> bool {
+             return heuristic.NextSolution(num_iterations);
+           })
+      .def("next_solution",
+           [](GuidedTabuSearch& heuristic, const std::vector<BaseInt>& focus,
+              int num_iterations) -> bool {
+             return heuristic.NextSolution(VectorIntToVectorSubsetIndex(focus),
+                                           num_iterations);
+           })
+      .def("get_lagrangian_factor", &GuidedTabuSearch::GetLagrangianFactor)
+      .def("set_lagrangian_factor", &GuidedTabuSearch::SetLagrangianFactor,
+           arg("factor"))
+      .def("set_epsilon", &GuidedTabuSearch::SetEpsilon,
arg("r")) + .def("get_epsilon", &GuidedTabuSearch::GetEpsilon) + .def("set_penalty_factor", &GuidedTabuSearch::SetPenaltyFactor, + arg("factor")) + .def("get_penalty_factor", &GuidedTabuSearch::GetPenaltyFactor) + .def("set_tabu_list_size", &GuidedTabuSearch::SetTabuListSize, + arg("size")) + .def("get_tabu_list_size", &GuidedTabuSearch::GetTabuListSize); + + m.def( + "clear_random_subsets", + [](BaseInt num_subsets, SetCoverInvariant* inv) -> std::vector { + const std::vector cleared = + ClearRandomSubsets(num_subsets, inv); + return {cleared.begin(), cleared.end()}; + }); + m.def("clear_random_subsets", + [](const std::vector& focus, BaseInt num_subsets, + SetCoverInvariant* inv) -> std::vector { + const std::vector cleared = ClearRandomSubsets( + VectorIntToVectorSubsetIndex(focus), num_subsets, inv); + return {cleared.begin(), cleared.end()}; + }); + + m.def( + "clear_most_covered_elements", + [](BaseInt num_subsets, SetCoverInvariant* inv) -> std::vector { + const std::vector cleared = + ClearMostCoveredElements(num_subsets, inv); + return {cleared.begin(), cleared.end()}; + }); + m.def("clear_most_covered_elements", + [](const std::vector& focus, BaseInt num_subsets, + SetCoverInvariant* inv) -> std::vector { + const std::vector cleared = ClearMostCoveredElements( + VectorIntToVectorSubsetIndex(focus), num_subsets, inv); + return {cleared.begin(), cleared.end()}; + }); // set_cover_reader.h m.def("read_beasly_set_cover_problem", &ReadBeasleySetCoverProblem); diff --git a/ortools/algorithms/samples/code_samples.bzl b/ortools/algorithms/samples/code_samples.bzl index 5f8039a325b..8960d993a1b 100644 --- a/ortools/algorithms/samples/code_samples.bzl +++ b/ortools/algorithms/samples/code_samples.bzl @@ -14,10 +14,12 @@ """Helper macro to compile and test code samples.""" load("@pip_deps//:requirements.bzl", "requirement") +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_test") +load("@rules_java//java:defs.bzl", "java_test") load("@rules_python//python:defs.bzl", 
"py_binary", "py_test") def code_sample_cc(name): - native.cc_binary( + cc_binary( name = name + "_cc", srcs = [name + ".cc"], deps = [ @@ -25,7 +27,7 @@ def code_sample_cc(name): ], ) - native.cc_test( + cc_test( name = name + "_cc_test", size = "small", srcs = [name + ".cc"], @@ -68,7 +70,7 @@ def code_sample_cc_py(name): code_sample_py(name = name) def code_sample_java(name): - native.java_test( + java_test( name = name + "_java_test", size = "small", srcs = [name + ".java"], diff --git a/ortools/algorithms/samples/set_cover.cc b/ortools/algorithms/samples/set_cover.cc new file mode 100644 index 00000000000..c9feae012d1 --- /dev/null +++ b/ortools/algorithms/samples/set_cover.cc @@ -0,0 +1,65 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// [START program] +// [START import] +#include + +#include "ortools/algorithms/set_cover_heuristics.h" +#include "ortools/algorithms/set_cover_invariant.h" +#include "ortools/algorithms/set_cover_model.h" +#include "ortools/base/logging.h" +// [END import] + +namespace operations_research { + +void SimpleSetCoverProgram() { + // [START data] + SetCoverModel model; + model.AddEmptySubset(2.0); + model.AddElementToLastSubset(0); + model.AddEmptySubset(2.0); + model.AddElementToLastSubset(1); + model.AddEmptySubset(1.0); + model.AddElementToLastSubset(0); + model.AddElementToLastSubset(1); + // [END data] + + // [START solve] + SetCoverInvariant inv(&model); + GreedySolutionGenerator greedy(&inv); + bool found_solution = greedy.NextSolution(); + if (!found_solution) { + LOG(INFO) << "No solution found by the greedy heuristic."; + return; + } + SetCoverSolutionResponse solution = inv.ExportSolutionAsProto(); + // [END solve] + + // [START print_solution] + LOG(INFO) << "Total cost: " << solution.cost(); // == inv.cost() + LOG(INFO) << "Total number of selected subsets: " << solution.num_subsets(); + LOG(INFO) << "Chosen subsets:"; + for (int i = 0; i < solution.subset_size(); ++i) { + LOG(INFO) << " " << solution.subset(i); + } + // [END print_solution] +} + +} // namespace operations_research + +int main() { + operations_research::SimpleSetCoverProgram(); + return EXIT_SUCCESS; +} +// [END program] diff --git a/ortools/algorithms/samples/set_cover.py b/ortools/algorithms/samples/set_cover.py new file mode 100755 index 00000000000..aab98c6b0fa --- /dev/null +++ b/ortools/algorithms/samples/set_cover.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2010-2024 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A simple set-covering problem.""" + +# [START program] +# [START import] +from ortools.algorithms.python import set_cover +# [END import] + + +def main(): + # [START data] + model = set_cover.SetCoverModel() + model.add_empty_subset(2.0) + model.add_element_to_last_subset(0) + model.add_empty_subset(2.0) + model.add_element_to_last_subset(1) + model.add_empty_subset(1.0) + model.add_element_to_last_subset(0) + model.add_element_to_last_subset(1) + # [END data] + + # [START solve] + inv = set_cover.SetCoverInvariant(model) + greedy = set_cover.GreedySolutionGenerator(inv) + has_found = greedy.next_solution() + if not has_found: + print("No solution found by the greedy heuristic.") + return + solution = inv.export_solution_as_proto() + # [END solve] + + # [START print_solution] + print(f"Total cost: {solution.cost}") # == inv.cost() + print(f"Total number of selected subsets: {solution.num_subsets}") + print("Chosen subsets:") + for subset in solution.subset: + print(f" {subset}") + # [END print_solution] + + +if __name__ == "__main__": + main() +# [END program] diff --git a/ortools/algorithms/samples/simple_knapsack_program.cc b/ortools/algorithms/samples/simple_knapsack_program.cc index 64a80d3a46d..d9fe9e70912 100644 --- a/ortools/algorithms/samples/simple_knapsack_program.cc +++ b/ortools/algorithms/samples/simple_knapsack_program.cc @@ -15,12 +15,14 @@ // [START import] #include #include +#include #include #include #include #include #include "ortools/algorithms/knapsack_solver.h" +#include "ortools/base/logging.h" // [END import] namespace 
operations_research { diff --git a/ortools/algorithms/set_cover_heuristics.cc b/ortools/algorithms/set_cover_heuristics.cc index 63c7fe30665..4208c90e17a 100644 --- a/ortools/algorithms/set_cover_heuristics.cc +++ b/ortools/algorithms/set_cover_heuristics.cc @@ -544,23 +544,23 @@ bool GuidedLocalSearch::NextSolution(absl::Span focus, } namespace { -void SampleSubsets(std::vector* list, std::size_t num_subsets) { - num_subsets = std::min(num_subsets, list->size()); +void SampleSubsets(std::vector* list, BaseInt num_subsets) { + num_subsets = std::min(num_subsets, static_cast(list->size())); CHECK_GE(num_subsets, 0); std::shuffle(list->begin(), list->end(), absl::BitGen()); list->resize(num_subsets); } } // namespace -std::vector ClearRandomSubsets(std::size_t num_subsets, +std::vector ClearRandomSubsets(BaseInt num_subsets, SetCoverInvariant* inv) { return ClearRandomSubsets(inv->model()->all_subsets(), num_subsets, inv); } std::vector ClearRandomSubsets(absl::Span focus, - std::size_t num_subsets, + BaseInt num_subsets, SetCoverInvariant* inv) { - num_subsets = std::min(num_subsets, focus.size()); + num_subsets = std::min(num_subsets, static_cast(focus.size())); CHECK_GE(num_subsets, 0); std::vector chosen_indices; for (const SubsetIndex subset : focus) { @@ -569,7 +569,7 @@ std::vector ClearRandomSubsets(absl::Span focus, } } SampleSubsets(&chosen_indices, num_subsets); - std::size_t num_deselected = 0; + BaseInt num_deselected = 0; for (const SubsetIndex subset : chosen_indices) { inv->Deselect(subset); ++num_deselected; @@ -585,14 +585,14 @@ std::vector ClearRandomSubsets(absl::Span focus, return chosen_indices; } -std::vector ClearMostCoveredElements(std::size_t max_num_subsets, +std::vector ClearMostCoveredElements(BaseInt max_num_subsets, SetCoverInvariant* inv) { return ClearMostCoveredElements(inv->model()->all_subsets(), max_num_subsets, inv); } std::vector ClearMostCoveredElements( - absl::Span focus, std::size_t max_num_subsets, + absl::Span focus, 
BaseInt max_num_subsets, SetCoverInvariant* inv) { // This is the vector we will return. std::vector sampled_subsets; @@ -625,7 +625,8 @@ std::vector ClearMostCoveredElements( // Actually *sample* sampled_subset. // TODO(user): find another algorithm if necessary. std::shuffle(sampled_subsets.begin(), sampled_subsets.end(), absl::BitGen()); - sampled_subsets.resize(std::min(sampled_subsets.size(), max_num_subsets)); + sampled_subsets.resize( + std::min(static_cast(sampled_subsets.size()), max_num_subsets)); // Testing has shown that sorting sampled_subsets is not necessary. // Now, un-select the subset in sampled_subsets. diff --git a/ortools/algorithms/set_cover_heuristics.h b/ortools/algorithms/set_cover_heuristics.h index 04d3b31b9f8..154724cf616 100644 --- a/ortools/algorithms/set_cover_heuristics.h +++ b/ortools/algorithms/set_cover_heuristics.h @@ -477,12 +477,12 @@ class GuidedLocalSearch { // solution. There can be more than num_subsets variables cleared because the // intersecting subsets are also removed from the solution. Returns a list of // subset indices that can be reused as a focus. -std::vector ClearRandomSubsets(std::size_t num_subsets, +std::vector ClearRandomSubsets(BaseInt num_subsets, SetCoverInvariant* inv); // Same as above, but clears the subset indices in focus. std::vector ClearRandomSubsets(absl::Span focus, - std::size_t num_subsets, + BaseInt num_subsets, SetCoverInvariant* inv); // Clears the variables (subsets) that cover the most covered elements. This is @@ -490,12 +490,12 @@ std::vector ClearRandomSubsets(absl::Span focus, // randomly. // Returns the list of the chosen subset indices. // This indices can then be used ax a focus. -std::vector ClearMostCoveredElements(std::size_t num_subsets, +std::vector ClearMostCoveredElements(BaseInt num_subsets, SetCoverInvariant* inv); // Same as above, but clears the subset indices in focus. 
std::vector ClearMostCoveredElements( - absl::Span focus, std::size_t num_subsets, + absl::Span focus, BaseInt num_subsets, SetCoverInvariant* inv); } // namespace operations_research diff --git a/ortools/algorithms/set_cover_lagrangian.cc b/ortools/algorithms/set_cover_lagrangian.cc index 18dea932dfd..f20a98a0516 100644 --- a/ortools/algorithms/set_cover_lagrangian.cc +++ b/ortools/algorithms/set_cover_lagrangian.cc @@ -20,6 +20,7 @@ #include #include "absl/log/check.h" +#include "absl/synchronization/blocking_counter.h" #include "ortools/algorithms/adjustable_k_ary_heap.h" #include "ortools/algorithms/set_cover_invariant.h" #include "ortools/algorithms/set_cover_model.h" @@ -74,7 +75,7 @@ namespace { // TODO(user): Investigate. Cost ScalarProduct(const SparseColumn& column, const ElementCostVector& dual) { Cost result = 0.0; - for (ColumnEntryIndex pos(0); pos.value() < column.size(); ++pos) { + for (const ColumnEntryIndex pos : column.index_range()) { result += dual[column[pos]]; } return result; @@ -82,52 +83,49 @@ Cost ScalarProduct(const SparseColumn& column, const ElementCostVector& dual) { // Computes the reduced costs for a subset of subsets. // This is a helper function for ParallelComputeReducedCosts(). -// It is called on a slice of subsets, defined by start and end. +// It is called on a slice of subsets, defined by slice_start and slice_end. // The reduced costs are computed using the multipliers vector. // The columns of the subsets are given by the columns view. // The result is stored in reduced_costs. 
-void FillReducedCostsSlice(SubsetIndex start, SubsetIndex end, +void FillReducedCostsSlice(SubsetIndex slice_start, SubsetIndex slice_end, const SubsetCostVector& costs, const ElementCostVector& multipliers, const SparseColumnView& columns, SubsetCostVector* reduced_costs) { - for (SubsetIndex subset = start; subset < end; ++subset) { + for (SubsetIndex subset = slice_start; subset < slice_end; ++subset) { (*reduced_costs)[subset] = costs[subset] - ScalarProduct(columns[subset], multipliers); } } + +BaseInt BlockSize(BaseInt size, int num_threads) { + return 1 + (size - 1) / num_threads; +} } // namespace // Computes the reduced costs for all subsets in parallel using ThreadPool. SubsetCostVector SetCoverLagrangian::ParallelComputeReducedCosts( const SubsetCostVector& costs, const ElementCostVector& multipliers) const { const SubsetIndex num_subsets(model_.num_subsets()); - // TODO(user): compute a close-to-optimal k-subset partitioning. - const SubsetIndex block_size = - SubsetIndex(1) + num_subsets / num_threads_; // [***] Arbitrary choice. const SparseColumnView& columns = model_.columns(); SubsetCostVector reduced_costs(num_subsets); - ThreadPool thread_pool("ParallelComputeReducedCosts", num_threads_); - thread_pool.StartWorkers(); - { - // TODO(user): check how costly it is to create a new ThreadPool. - // TODO(user): use a queue of subsets to process? instead of a fixed range. - - // This parallelization is not very efficient, because all the threads - // use the same costs vector. Maybe it should be local to the thread. - // It's unclear whether sharing columns and costs is better than having - // each thread use its own partial copy. - // Finally, it might be better to use a queue of subsets to process, instead - // of a fixed range. 
- for (SubsetIndex start(0); start < num_subsets; start += block_size) { - thread_pool.Schedule([start, block_size, num_subsets, &costs, - &multipliers, &columns, &reduced_costs]() { - const SubsetIndex end = std::min(start + block_size, num_subsets); - FillReducedCostsSlice(start, end, costs, multipliers, columns, - &reduced_costs); - }); - } - } // Synchronize all the threads. This is equivalent to a wait. + // TODO(user): compute a close-to-optimal k-subset partitioning of the columns + // based on their sizes. [***] + const SubsetIndex block_size(BlockSize(num_subsets.value(), num_threads_)); + absl::BlockingCounter num_threads_running(num_threads_); + SubsetIndex slice_start(0); + for (int thread_index = 0; thread_index < num_threads_; ++thread_index) { + const SubsetIndex slice_end = + std::min(slice_start + block_size, num_subsets); + thread_pool_->Schedule([&num_threads_running, slice_start, slice_end, + &costs, &multipliers, &columns, &reduced_costs]() { + FillReducedCostsSlice(slice_start, slice_end, costs, multipliers, columns, + &reduced_costs); + num_threads_running.DecrementCount(); + }); + slice_start = slice_end; + } + num_threads_running.Wait(); return reduced_costs; } @@ -147,14 +145,14 @@ SubsetCostVector SetCoverLagrangian::ComputeReducedCosts( namespace { // Helper function to compute the subgradient. -// It fills a slice of the subgradient vector from indices start to end. -// This is a helper function for ParallelComputeSubgradient(). -// The subgradient is computed using the reduced costs vector. -void FillSubgradientSlice(SubsetIndex start, SubsetIndex end, +// It fills a slice of the subgradient vector from indices slice_start to +// slice_end. This is a helper function for ParallelComputeSubgradient(). The +// subgradient is computed using the reduced costs vector. 
+void FillSubgradientSlice(SubsetIndex slice_start, SubsetIndex slice_end, const SparseColumnView& columns, const SubsetCostVector& reduced_costs, ElementCostVector* subgradient) { - for (SubsetIndex subset(start); subset < end; ++subset) { + for (SubsetIndex subset(slice_start); subset < slice_end; ++subset) { if (reduced_costs[subset] < 0.0) { for (const ElementIndex element : columns[subset]) { (*subgradient)[element] -= 1.0; @@ -181,8 +179,6 @@ ElementCostVector SetCoverLagrangian::ComputeSubgradient( ElementCostVector SetCoverLagrangian::ParallelComputeSubgradient( const SubsetCostVector& reduced_costs) const { const SubsetIndex num_subsets(model_.num_subsets()); - const SubsetIndex block_size = - SubsetIndex(1) + num_subsets / num_threads_; // [***] const SparseColumnView& columns = model_.columns(); ElementCostVector subgradient(model_.num_elements(), 1.0); // The subgradient has one component per element, each thread processes @@ -191,20 +187,22 @@ ElementCostVector SetCoverLagrangian::ParallelComputeSubgradient( // although this might be less well-balanced. std::vector subgradients( num_threads_, ElementCostVector(model_.num_elements())); - ThreadPool thread_pool("ParallelComputeSubgradient", num_threads_); - thread_pool.StartWorkers(); - { - int thread_index = 0; - for (SubsetIndex start(0); start < num_subsets; - start += block_size, ++thread_index) { - thread_pool.Schedule([start, block_size, num_subsets, &reduced_costs, - &columns, &subgradients, thread_index]() { - const SubsetIndex end = std::min(start + block_size, num_subsets); - FillSubgradientSlice(start, end, columns, reduced_costs, - &subgradients[thread_index]); - }); - } - } // Synchronize all the threads. 
+ absl::BlockingCounter num_threads_running(num_threads_); + const SubsetIndex block_size(BlockSize(num_subsets.value(), num_threads_)); + SubsetIndex slice_start(0); + for (int thread_index = 0; thread_index < num_threads_; ++thread_index) { + const SubsetIndex slice_end = + std::min(slice_start + block_size, num_subsets); + thread_pool_->Schedule([&num_threads_running, slice_start, slice_end, + &reduced_costs, &columns, &subgradients, + thread_index]() { + FillSubgradientSlice(slice_start, slice_end, columns, reduced_costs, + &subgradients[thread_index]); + num_threads_running.DecrementCount(); + }); + slice_start = slice_end; + } + num_threads_running.Wait(); for (int thread_index = 0; thread_index < num_threads_; ++thread_index) { for (const ElementIndex element : model_.ElementRange()) { subgradient[element] += subgradients[thread_index][element]; @@ -216,17 +214,17 @@ ElementCostVector SetCoverLagrangian::ParallelComputeSubgradient( namespace { // Helper function to compute the value of the Lagrangian. // This is a helper function for ParallelComputeLagrangianValue(). -// It is called on a slice of elements, defined by start and end. +// It is called on a slice of elements, defined by slice_start and slice_end. // The value of the Lagrangian is computed using the reduced costs vector and // the multipliers vector. // The result is stored in lagrangian_value. -void FillLagrangianValueSlice(SubsetIndex start, SubsetIndex end, +void FillLagrangianValueSlice(SubsetIndex slice_start, SubsetIndex slice_end, const SubsetCostVector& reduced_costs, Cost* lagrangian_value) { - // This is min \sum_{j \in N} c_j(u) x_j. This captures the remark above (**), - // taking into account the possible values for x_j, and using them to minimize - // the terms. - for (SubsetIndex subset(start); subset < end; ++subset) { + // This is min \sum_{j \in N} c_j(u) x_j. 
This captures the remark above + // (**), taking into account the possible values for x_j, and using them to + // minimize the terms. + for (SubsetIndex subset(slice_start); subset < slice_end; ++subset) { if (reduced_costs[subset] < 0.0) { *lagrangian_value += reduced_costs[subset]; } @@ -258,30 +256,31 @@ Cost SetCoverLagrangian::ComputeLagrangianValue( Cost SetCoverLagrangian::ParallelComputeLagrangianValue( const SubsetCostVector& reduced_costs, const ElementCostVector& multipliers) const { - const SubsetIndex num_subsets(model_.num_subsets()); - const SubsetIndex block_size = - SubsetIndex(1) + num_subsets / num_threads_; // [***] Arbitrary. Cost lagrangian_value = 0.0; // This is \sum{i \in M} u_i. - for (const Cost u : multipliers) { lagrangian_value += u; } std::vector lagrangian_values(num_threads_, 0.0); - ThreadPool thread_pool("ParallelComputeLagrangianValue", num_threads_); - thread_pool.StartWorkers(); - { - int thread_index = 0; - for (SubsetIndex start(0); start < num_subsets; start += block_size) { - thread_pool.Schedule([start, block_size, num_subsets, thread_index, - &reduced_costs, &lagrangian_values]() { - const SubsetIndex end = std::min(start + block_size, num_subsets); - FillLagrangianValueSlice(start, end, reduced_costs, - &lagrangian_values[thread_index]); - }); - ++thread_index; - } - } // Synchronize all the threads. 
+ absl::BlockingCounter num_threads_running(num_threads_); + const SubsetIndex block_size(BlockSize(model_.num_subsets(), num_threads_)); + const SubsetIndex num_subsets(model_.num_subsets()); + SubsetIndex slice_start(0); + for (int thread_index = 0; thread_index < num_threads_; ++thread_index) { + const SubsetIndex slice_end = + std::min(slice_start + block_size, num_subsets); + thread_pool_->Schedule([&num_threads_running, slice_start, block_size, + num_subsets, thread_index, &reduced_costs, + &lagrangian_values]() { + const SubsetIndex slice_end = + std::min(slice_start + block_size, num_subsets); + FillLagrangianValueSlice(slice_start, slice_end, reduced_costs, + &lagrangian_values[thread_index]); + num_threads_running.DecrementCount(); + }); + slice_start = slice_end; + } + num_threads_running.Wait(); for (const Cost l : lagrangian_values) { lagrangian_value += l; } @@ -290,8 +289,8 @@ Cost SetCoverLagrangian::ParallelComputeLagrangianValue( // Perform a subgradient step. // In the general case, for an Integer Program A.x <=b, the Lagragian -// multipliers vector at step k+1 is defined as: u^{k+1} = u^k + t_k (A x^k - b) -// with term t_k = lambda_k * (UB - L(u^k)) / |A x^k - b|^2. +// multipliers vector at step k+1 is defined as: u^{k+1} = u^k + t_k (A x^k - +// b) with term t_k = lambda_k * (UB - L(u^k)) / |A x^k - b|^2. // |.| is the 2-norm (i.e. Euclidean) // In our case, the problem A x <= b is in the form A x >= 1. We need to // replace A x - b by s_i(u) = 1 - sum_{j \in J_i} x_j(u). @@ -343,9 +342,9 @@ void SetCoverLagrangian::ParallelUpdateMultipliers( step_size * (upper_bound - lagrangian_value) / subgradient_square_norm; for (const ElementIndex element : model_.ElementRange()) { // Avoid multipliers to go negative and to go through the roof. 1e6 chosen - // arbitrarily. [***] + const Cost kRoof = 1e6; // Arbitrary value, from [1]. 
(*multipliers)[element] = std::clamp( - (*multipliers)[element] + factor * subgradient[element], 0.0, 1e6); + (*multipliers)[element] + factor * subgradient[element], 0.0, kRoof); } } @@ -503,9 +502,9 @@ SetCoverLagrangian::ComputeLowerBound(const SubsetCostVector& costs, for (int iter = 0; iter < 1000; ++iter) { reduced_costs = ParallelComputeReducedCosts(costs, multipliers); const Cost lagrangian_value = - ComputeLagrangianValue(reduced_costs, multipliers); - UpdateMultipliers(step_size, lagrangian_value, upper_bound, reduced_costs, - &multipliers); + ParallelComputeLagrangianValue(reduced_costs, multipliers); + ParallelUpdateMultipliers(step_size, lagrangian_value, upper_bound, + reduced_costs, &multipliers); lower_bound = std::max(lower_bound, lagrangian_value); // step_size should be updated like this. For the time besing, we keep the // step size, because the implementation of the rest is not adequate yet diff --git a/ortools/algorithms/set_cover_lagrangian.h b/ortools/algorithms/set_cover_lagrangian.h index aa63627ad15..9e946b173a8 100644 --- a/ortools/algorithms/set_cover_lagrangian.h +++ b/ortools/algorithms/set_cover_lagrangian.h @@ -15,6 +15,7 @@ #define OR_TOOLS_ALGORITHMS_SET_COVER_LAGRANGIAN_H_ #include +#include #include #include @@ -44,7 +45,12 @@ namespace operations_research { class SetCoverLagrangian { public: explicit SetCoverLagrangian(SetCoverInvariant* inv, int num_threads = 1) - : inv_(inv), model_(*inv->model()), num_threads_(num_threads) {} + : inv_(inv), + model_(*inv->model()), + num_threads_(num_threads), + thread_pool_(new ThreadPool(num_threads)) { + thread_pool_->StartWorkers(); + } // Returns true if a solution was found. // TODO(user): Add time-outs and exit with a partial solution. This seems @@ -137,6 +143,9 @@ class SetCoverLagrangian { // The number of threads to use for parallelization. int num_threads_; + // The thread pool used for parallelization. + std::unique_ptr thread_pool_; + // Total (scalar) Lagrangian cost. 
Cost lagrangian_; diff --git a/ortools/algorithms/set_cover_model.cc b/ortools/algorithms/set_cover_model.cc index f7b5faf24c4..e67718b197d 100644 --- a/ortools/algorithms/set_cover_model.cc +++ b/ortools/algorithms/set_cover_model.cc @@ -17,17 +17,161 @@ #include #include #include +#include #include #include +#include #include #include #include "absl/log/check.h" +#include "absl/random/discrete_distribution.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" #include "ortools/algorithms/set_cover.pb.h" #include "ortools/base/logging.h" namespace operations_research { +namespace { + +// Returns a value in [min, min + scaling_factor * (raw_value - min + +// random_term)], where raw_value is drawn from a discrete distribution, and +// random_term is a double drawn uniformly in [0, 1]. +BaseInt DiscreteAffine(absl::BitGen& bitgen, + absl::discrete_distribution& dist, BaseInt min, + double scaling_factor) { + const BaseInt raw_value = dist(bitgen); + const double random_term = absl::Uniform(bitgen, 0, 1.0); + const BaseInt affine_value = + static_cast( + std::floor((raw_value - min + random_term) * scaling_factor)) + + min; + return affine_value; +} + +// For a given view (SparseColumnView or SparseRowView), returns the +// distribution of the sizes of the vectors in the view, which can be used in +// an absl::discrete_distribution. 
+template +std::tuple> ComputeSizeHistogram( + const View& view) { + BaseInt max_size = 0; + BaseInt min_size = std::numeric_limits::max(); + for (const auto& vec : view) { + const BaseInt size = vec.size(); + min_size = std::min(min_size, size); + max_size = std::max(max_size, size); + } + std::vector weights(max_size + 1, 0); + for (const auto& vec : view) { + const BaseInt size = vec.size(); + ++weights[size]; + } + return {min_size, weights}; +} + +template +std::tuple> +ComputeSizeDistribution(const View& view) { + const auto [min_size, weights] = ComputeSizeHistogram(view); + absl::discrete_distribution dist(weights.begin(), weights.end()); + return {min_size, dist}; +} +} // namespace + +SetCoverModel SetCoverModel::GenerateRandomModelFrom( + const SetCoverModel& seed_model, BaseInt num_elements, BaseInt num_subsets, + double row_scale, double column_scale, double cost_scale) { + SetCoverModel model; + DCHECK_GT(row_scale, 0.0); + DCHECK_GT(column_scale, 0.0); + DCHECK_GT(cost_scale, 0.0); + model.num_elements_ = num_elements; + model.num_nonzeros_ = 0; + model.ReserveNumSubsets(num_subsets); + model.UpdateAllSubsetsList(); + absl::BitGen bitgen; + + // Create the distribution of the cardinalities of the subsets based on the + // histogram of column sizes in the seed model. + auto [min_column_size, column_dist] = + ComputeSizeDistribution(seed_model.columns()); + + // Create the distribution of the degrees of the elements based on the + // histogram of row sizes in the seed model. + auto [min_row_size, row_dist] = ComputeSizeDistribution(seed_model.rows()); + + // Prepare the degrees of the elements in the generated model, and use them + // in a distribution to generate the columns. This ponderates the columns + // towards the elements with higher degrees. ??? 
+ ElementToIntVector degrees(num_elements); + for (ElementIndex element(0); element.value() < num_elements; ++element) { + degrees[element] = + DiscreteAffine(bitgen, row_dist, min_row_size, row_scale); + } + absl::discrete_distribution degree_dist(degrees.begin(), + degrees.end()); + + // Vector indicating whether the generated model covers an element. + ElementBoolVector contains_element(num_elements, false); + + // Number of elements in the generated model, using the above vector. + BaseInt num_elements_covered(0); + + // Loop-local vector indicating whether the currently generated subset + // contains an element. + ElementBoolVector subset_contains_element(num_elements, false); + + for (SubsetIndex subset(0); subset.value() < num_subsets; ++subset) { + const BaseInt cardinality = + DiscreteAffine(bitgen, column_dist, min_column_size, column_scale); + model.columns_[subset].reserve(cardinality); + for (BaseInt iter = 0; iter < cardinality; ++iter) { + int num_tries = 0; + ElementIndex element; + // Choose an element that is not yet in the subset at random with a + // distribution that is proportional to the degree of the element. + do { + element = ElementIndex(degree_dist(bitgen)); + CHECK_LT(element.value(), num_elements); + ++num_tries; + if (num_tries > 10) { + return SetCoverModel(); + } + } while (subset_contains_element[element]); + ++model.num_nonzeros_; + model.columns_[subset].push_back(element); + subset_contains_element[element] = true; + if (!contains_element[element]) { + contains_element[element] = true; + ++num_elements_covered; + } + } + for (const ElementIndex element : model.columns_[subset]) { + subset_contains_element[element] = false; + } + } + CHECK_EQ(num_elements_covered, num_elements); + + // TODO(user): if necessary, use a better distribution for the costs. + // The generation of the costs is done in two steps. First, compute the + // minimum and maximum costs. 
+ Cost min_cost = std::numeric_limits::infinity(); + Cost max_cost = -min_cost; + for (const Cost cost : seed_model.subset_costs()) { + min_cost = std::min(min_cost, cost); + max_cost = std::max(max_cost, cost); + } + // Then, generate random numbers in [min_cost, min_cost + cost_range], where + // cost_range is defined as: + const Cost cost_range = cost_scale * (max_cost - min_cost); + for (Cost& cost : model.subset_costs_) { + cost = min_cost + absl::Uniform(bitgen, 0, cost_range); + } + return model; +} + void SetCoverModel::UpdateAllSubsetsList() { const BaseInt old_size = all_subsets_.size(); DCHECK_LE(old_size, num_subsets()); @@ -92,8 +236,8 @@ void SetCoverModel::AddElementToSubset(ElementIndex element, } // Reserves num_subsets columns in the model. -void SetCoverModel::ReserveNumSubsets(BaseInt number_of_subsets) { - num_subsets_ = std::max(num_subsets_, number_of_subsets); +void SetCoverModel::ReserveNumSubsets(BaseInt num_subsets) { + num_subsets_ = std::max(num_subsets_, num_subsets); columns_.resize(num_subsets_, SparseColumn()); subset_costs_.resize(num_subsets_, 0.0); } @@ -121,8 +265,8 @@ void SetCoverModel::CreateSparseRowView() { rows_.resize(num_elements_, SparseRow()); ElementToIntVector row_sizes(num_elements_, 0); for (const SubsetIndex subset : SubsetRange()) { - // Sort the columns. It's not super-critical to improve performance here as - // this needs to be done only once. + // Sort the columns. It's not super-critical to improve performance here + // as this needs to be done only once. 
std::sort(columns_[subset].begin(), columns_[subset].end()); for (const ElementIndex element : columns_[subset]) { ++row_sizes[element]; @@ -256,7 +400,7 @@ SetCoverModel::Stats SetCoverModel::ComputeCostStats() { } SetCoverModel::Stats SetCoverModel::ComputeRowStats() { - std::vector row_sizes(num_elements(), 0); + std::vector row_sizes(num_elements(), 0); for (const SparseColumn& column : columns_) { for (const ElementIndex element : column) { ++row_sizes[element.value()]; @@ -266,15 +410,15 @@ SetCoverModel::Stats SetCoverModel::ComputeRowStats() { } SetCoverModel::Stats SetCoverModel::ComputeColumnStats() { - std::vector column_sizes(columns_.size()); + std::vector column_sizes(columns_.size()); for (const SubsetIndex subset : SubsetRange()) { column_sizes[subset.value()] = columns_[subset].size(); } return ComputeStats(std::move(column_sizes)); } -std::vector SetCoverModel::ComputeRowDeciles() const { - std::vector row_sizes(num_elements(), 0); +std::vector SetCoverModel::ComputeRowDeciles() const { + std::vector row_sizes(num_elements(), 0); for (const SparseColumn& column : columns_) { for (const ElementIndex element : column) { ++row_sizes[element.value()]; @@ -283,8 +427,8 @@ std::vector SetCoverModel::ComputeRowDeciles() const { return ComputeDeciles(std::move(row_sizes)); } -std::vector SetCoverModel::ComputeColumnDeciles() const { - std::vector column_sizes(columns_.size()); +std::vector SetCoverModel::ComputeColumnDeciles() const { + std::vector column_sizes(columns_.size()); for (const SubsetIndex subset : SubsetRange()) { column_sizes[subset.value()] = columns_[subset].size(); } diff --git a/ortools/algorithms/set_cover_model.h b/ortools/algorithms/set_cover_model.h index fa3f55430b8..ea21f82444d 100644 --- a/ortools/algorithms/set_cover_model.h +++ b/ortools/algorithms/set_cover_model.h @@ -14,13 +14,7 @@ #ifndef OR_TOOLS_ALGORITHMS_SET_COVER_MODEL_H_ #define OR_TOOLS_ALGORITHMS_SET_COVER_MODEL_H_ -#if defined(_MSC_VER) -#include -typedef SSIZE_T 
ssize_t; -#else -#include -#endif // defined(_MSC_VER) - +#include #include #include @@ -29,7 +23,6 @@ typedef SSIZE_T ssize_t; #include "ortools/algorithms/set_cover.pb.h" #include "ortools/base/strong_int.h" #include "ortools/base/strong_vector.h" -#include "ortools/util/aligned_memory.h" // Representation class for the weighted set-covering problem. // @@ -65,7 +58,7 @@ using Cost = double; // (2e9) elements and subsets. If need arises one day, BaseInt can be split // into SubsetBaseInt and ElementBaseInt. // Quick testing has shown a slowdown of about 20-25% when using int64_t. -using BaseInt = int; +using BaseInt = int32_t; // We make heavy use of strong typing to avoid obvious mistakes. // Subset index. @@ -84,32 +77,14 @@ using SubsetRange = util_intops::StrongIntRange; using ElementRange = util_intops::StrongIntRange; using ColumnEntryRange = util_intops::StrongIntRange; -// SIMD operations require vectors to be aligned at 64-bytes on x86-64 -// processors as of 2024-05-03. -// TODO(user): improve the code to make it possible to use unaligned memory. -constexpr int kSetCoverAlignmentInBytes = 64; - -using CostAllocator = AlignedAllocator; -using ElementAllocator = - AlignedAllocator; -using SubsetAllocator = - AlignedAllocator; - -using SubsetCostVector = - util_intops::StrongVector; -using ElementCostVector = - util_intops::StrongVector; - -using SparseColumn = - util_intops::StrongVector; -using SparseRow = - util_intops::StrongVector; - -using IntAllocator = AlignedAllocator; -using ElementToIntVector = - util_intops::StrongVector; -using SubsetToIntVector = - util_intops::StrongVector; +using SubsetCostVector = util_intops::StrongVector; +using ElementCostVector = util_intops::StrongVector; + +using SparseColumn = util_intops::StrongVector; +using SparseRow = util_intops::StrongVector; + +using ElementToIntVector = util_intops::StrongVector; +using SubsetToIntVector = util_intops::StrongVector; // Views of the sparse vectors. 
These need not be aligned as it's their contents // that need to be aligned. @@ -117,6 +92,13 @@ using SparseColumnView = util_intops::StrongVector; using SparseRowView = util_intops::StrongVector; using SubsetBoolVector = util_intops::StrongVector; +using ElementBoolVector = util_intops::StrongVector; + +// Useful for representing permutations, +using ElementToElementVector = + util_intops::StrongVector; +using SubsetToSubsetVector = + util_intops::StrongVector; // Main class for describing a weighted set-covering problem. class SetCoverModel { @@ -132,6 +114,34 @@ class SetCoverModel { rows_(), all_subsets_() {} + // Constructs a weighted set-covering problem from a seed model, with + // num_elements elements and num_subsets subsets. + // - The distributions of the degrees of the elements and the cardinalities of + // the subsets are based on those of the seed model. They are scaled + // affininely by row_scale and column_scale respectively. + // - By affine scaling, we mean that the minimum value of the distribution is + // not scaled, but the variation above this minimum value is. + // - For a given subset with a given cardinality in the generated model, its + // elements are sampled from the distribution of the degrees as computed + // above. + // - The costs of the subsets in the new model are sampled from the + // distribution of the costs of the subsets in the seed model, scaled by + // cost_scale. + // IMPORTANT NOTICE: The algorithm may not succeed in generating a model + // where all the elements can be covered. In that case, the model will be + // empty. + + static SetCoverModel GenerateRandomModelFrom(const SetCoverModel& seed_model, + BaseInt num_elements, + BaseInt num_subsets, + double row_scale, + double column_scale, + double cost_scale); + + // Returns true if the model is empty, i.e. has no elements, no subsets, and + // no nonzeros. 
+ bool IsEmpty() const { return rows_.empty() || columns_.empty(); } + // Current number of elements to be covered in the model, i.e. the number of // elements in S. In matrix terms, this is the number of rows. BaseInt num_elements() const { return num_elements_; } @@ -141,7 +151,7 @@ class SetCoverModel { BaseInt num_subsets() const { return num_subsets_; } // Current number of nonzeros in the matrix. - ssize_t num_nonzeros() const { return num_nonzeros_; } + int64_t num_nonzeros() const { return num_nonzeros_; } double FillRate() const { return 1.0 * num_nonzeros() / (1.0 * num_elements() * num_subsets()); @@ -243,10 +253,10 @@ class SetCoverModel { Stats ComputeColumnStats(); // Computes deciles on rows and returns a vector of deciles. - std::vector ComputeRowDeciles() const; + std::vector ComputeRowDeciles() const; // Computes deciles on columns and returns a vector of deciles. - std::vector ComputeColumnDeciles() const; + std::vector ComputeColumnDeciles() const; private: // Updates the all_subsets_ vector so that it always contains 0 to @@ -259,8 +269,9 @@ class SetCoverModel { // Number of subsets. Maintained for ease of access. BaseInt num_subsets_; - // Number of nonzeros in the matrix. - ssize_t num_nonzeros_; + // Number of nonzeros in the matrix. The value is an int64_t because there can + // be more than 1 << 31 nonzeros even with BaseInt = int32_t. + int64_t num_nonzeros_; // True when the SparseRowView is up-to-date. 
bool row_view_is_valid_; diff --git a/ortools/algorithms/set_cover_orlib_test.cc b/ortools/algorithms/set_cover_orlib_test.cc index 8dd153e44bb..849e2b76add 100644 --- a/ortools/algorithms/set_cover_orlib_test.cc +++ b/ortools/algorithms/set_cover_orlib_test.cc @@ -149,7 +149,7 @@ void ComputeLagrangianLowerBound(std::string name, SetCoverInvariant* inv) { const SetCoverModel* model = inv->model(); WallTimer timer; timer.Start(); - SetCoverLagrangian lagrangian(inv, /*num_threads=*/4); + SetCoverLagrangian lagrangian(inv, /*num_threads=*/8); const auto [lower_bound, reduced_costs, multipliers] = lagrangian.ComputeLowerBound(model->subset_costs(), inv->cost()); LogCostAndTiming(name, "LagrangianLowerBound", lower_bound, @@ -197,11 +197,11 @@ double RunSolver(std::string name, SetCoverModel* model) { global_timer.Start(); RunChvatalAndSteepest(name, model); // SetCoverInvariant inv = ComputeLPLowerBound(name, model); - RunMip(name, model); + // RunMip(name, model); RunChvatalAndGLS(name, model); SetCoverInvariant inv = RunElementDegreeGreedyAndSteepest(name, model); ComputeLagrangianLowerBound(name, &inv); - // IterateClearAndMip(name, inv); + // IterateClearAndMip(name, inv); IterateClearElementDegreeAndSteepest(name, &inv); return inv.cost(); } @@ -407,4 +407,22 @@ RAIL_TEST("rail4872.txt", 1527, 1861, MANYSECONDS); // [2] #undef SCP_TEST #undef RAIL_TEST +TEST(SetCoverHugeTest, GenerateProblem) { + SetCoverModel seed_model = + ReadRailSetCoverProblem(file::JoinPathRespectAbsolute( + ::testing::SrcDir(), data_dir, "rail4284.txt")); + seed_model.CreateSparseRowView(); + const BaseInt num_wanted_subsets(100'000'000); + const BaseInt num_wanted_elements(40'000); + const double row_scale = 1.1; + const double column_scale = 1.1; + const double cost_scale = 10.0; + SetCoverModel model = SetCoverModel::GenerateRandomModelFrom( + seed_model, num_wanted_elements, num_wanted_subsets, row_scale, + column_scale, cost_scale); + SetCoverInvariant inv = + 
RunElementDegreeGreedyAndSteepest("rail4284_huge.txt", &model); + LOG(INFO) << "Cost: " << inv.cost(); +} + } // namespace operations_research diff --git a/ortools/algorithms/sparse_permutation.cc b/ortools/algorithms/sparse_permutation.cc index ab6801fef14..9d4b6a466d6 100644 --- a/ortools/algorithms/sparse_permutation.cc +++ b/ortools/algorithms/sparse_permutation.cc @@ -81,4 +81,30 @@ std::string SparsePermutation::DebugString() const { return out; } +int SparsePermutation::Image(int element) const { + for (int c = 0; c < NumCycles(); ++c) { + int cur_element = LastElementInCycle(c); + for (int image : Cycle(c)) { + if (cur_element == element) { + return image; + } + cur_element = image; + } + } + return element; +} + +int SparsePermutation::InverseImage(int element) const { + for (int c = 0; c < NumCycles(); ++c) { + int cur_element = LastElementInCycle(c); + for (int image : Cycle(c)) { + if (image == element) { + return cur_element; + } + cur_element = image; + } + } + return element; +} + } // namespace operations_research diff --git a/ortools/algorithms/sparse_permutation.h b/ortools/algorithms/sparse_permutation.h index 0cbee4f9ecf..ee9b70db5c2 100644 --- a/ortools/algorithms/sparse_permutation.h +++ b/ortools/algorithms/sparse_permutation.h @@ -59,6 +59,11 @@ class SparsePermutation { // information with the loop above. Not sure it is needed though. int LastElementInCycle(int i) const; + // Returns the image of the given element or `element` itself if it is stable + // under the permutation. + int Image(int element) const; + int InverseImage(int element) const; + // To add a cycle to the permutation, repeatedly call AddToCurrentCycle() // with the cycle's orbit, then call CloseCurrentCycle(); // This shouldn't be called on trivial cycles (of length 1). @@ -76,6 +81,9 @@ class SparsePermutation { // Example: "(1 4 3) (5 9) (6 8 7)". 
std::string DebugString() const; + template + void ApplyToDenseCollection(Collection& span) const; + private: const int size_; std::vector cycles_; @@ -129,6 +137,24 @@ inline int SparsePermutation::LastElementInCycle(int i) const { return cycles_[cycle_ends_[i] - 1]; } +template +void SparsePermutation::ApplyToDenseCollection(Collection& span) const { + using T = typename Collection::value_type; + for (int c = 0; c < NumCycles(); ++c) { + const int last_element_idx = LastElementInCycle(c); + int element = last_element_idx; + T last_element = span[element]; + for (int image : Cycle(c)) { + if (image == last_element_idx) { + span[element] = last_element; + } else { + span[element] = span[image]; + } + element = image; + } + } +} + } // namespace operations_research #endif // OR_TOOLS_ALGORITHMS_SPARSE_PERMUTATION_H_ diff --git a/ortools/algorithms/sparse_permutation_test.cc b/ortools/algorithms/sparse_permutation_test.cc index a31927b2b83..44aead1cf8c 100644 --- a/ortools/algorithms/sparse_permutation_test.cc +++ b/ortools/algorithms/sparse_permutation_test.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include "absl/container/flat_hash_set.h" @@ -73,6 +74,20 @@ TEST(SparsePermutationTest, Identity) { EXPECT_EQ(0, permutation.NumCycles()); } +TEST(SparsePermutationTest, ApplyToVector) { + std::vector v = {"0", "1", "2", "3", "4", "5", "6", "7", "8"}; + SparsePermutation permutation(v.size()); + permutation.AddToCurrentCycle(4); + permutation.AddToCurrentCycle(2); + permutation.AddToCurrentCycle(7); + permutation.CloseCurrentCycle(); + permutation.AddToCurrentCycle(6); + permutation.AddToCurrentCycle(1); + permutation.CloseCurrentCycle(); + permutation.ApplyToDenseCollection(v); + EXPECT_THAT(v, ElementsAre("0", "6", "7", "3", "2", "5", "1", "4", "8")); +} + // Generate a bunch of permutation on a 'huge' space, but that have very few // displacements. This would OOM if the implementation was O(N); we verify // that it doesn't. 
diff --git a/ortools/base/BUILD.bazel b/ortools/base/BUILD.bazel index c57c0d22894..9e3a855c25e 100644 --- a/ortools/base/BUILD.bazel +++ b/ortools/base/BUILD.bazel @@ -367,6 +367,14 @@ cc_library( deps = [], ) +cc_library( + name = "memutil", + hdrs = ["memutil.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "murmur", hdrs = ["murmur.h"], diff --git a/ortools/base/memutil.h b/ortools/base/memutil.h new file mode 100644 index 00000000000..d2bdbab628f --- /dev/null +++ b/ortools/base/memutil.h @@ -0,0 +1,31 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OR_TOOLS_BASE_MEMUTIL_H_ +#define OR_TOOLS_BASE_MEMUTIL_H_ + +#include +#include + +#include "absl/strings/internal/memutil.h" + +namespace strings { +char* memdup(const char* s, size_t slen) { + void* copy; + if ((copy = malloc(slen)) == nullptr) return nullptr; + memcpy(copy, s, slen); + return reinterpret_cast(copy); +} +} // namespace strings + +#endif // OR_TOOLS_BASE_MEMUTIL_H_ diff --git a/ortools/base/sysinfo.cc b/ortools/base/sysinfo.cc index bb019f1a4cb..bff963d5193 100644 --- a/ortools/base/sysinfo.cc +++ b/ortools/base/sysinfo.cc @@ -17,7 +17,8 @@ #if defined(__APPLE__) && defined(__GNUC__) // MacOS #include #include -#elif (defined(__FreeBSD__) || defined(__OpenBSD__)) // FreeBSD or OpenBSD +#elif (defined(__FreeBSD__) || defined(__NetBSD__) || \ + defined(__OpenBSD__)) // [Free,Net,Open]BSD #include #include // Windows @@ -48,8 +49,9 @@ int64_t GetProcessMemoryUsage() { int64_t resident_memory = t_info.resident_size; return resident_memory; } -#elif defined(__GNUC__) && !defined(__FreeBSD__) && !defined(__OpenBSD__) && \ - !defined(__EMSCRIPTEN__) && !defined(_WIN32) // Linux +#elif defined(__GNUC__) && !defined(__FreeBSD__) && !defined(__NetBSD__) && \ + !defined(__OpenBSD__) && !defined(__EMSCRIPTEN__) && \ + !defined(_WIN32) // Linux int64_t GetProcessMemoryUsage() { unsigned size = 0; char buf[30]; @@ -61,7 +63,8 @@ int64_t GetProcessMemoryUsage() { fclose(pf); return int64_t{1024} * size; } -#elif (defined(__FreeBSD__) || defined(__OpenBSD__)) // FreeBSD or OpenBSD +#elif (defined(__FreeBSD__) || defined(__NetBSD__) || \ + defined(__OpenBSD__)) // [Free,Net,Open]BSD int64_t GetProcessMemoryUsage() { int who = RUSAGE_SELF; struct rusage rusage; diff --git a/ortools/base/top_n.h b/ortools/base/top_n.h index 885735d5d27..f5023d44613 100644 --- a/ortools/base/top_n.h +++ b/ortools/base/top_n.h @@ -106,6 +106,9 @@ class TopN { } // Peeks the bottom result without calling Extract() const T& peek_bottom(); + // Destructively extract the 
elements as a vector, sorted in descending order. + // Leaves TopN in an empty state. + std::vector Take(); // Extract the elements as a vector sorted in descending order. The caller // assumes ownership of the vector and must delete it when done. This is a // destructive operation. The only method that can be called immediately @@ -250,6 +253,19 @@ const T& TopN::peek_bottom() { } return elements_.front(); } +template +std::vector TopN::Take() { + std::vector out = std::move(elements_); + if (state_ != State::HEAP_SORTED) { + std::sort(out.begin(), out.end(), cmp_); + } else { + out.pop_back(); + std::sort_heap(out.begin(), out.end(), cmp_); + } + Reset(); + return out; +} + template std::vector* TopN::Extract() { auto out = new std::vector; diff --git a/ortools/flatzinc/challenge/Makefile b/ortools/flatzinc/challenge/Makefile index f696954955a..e299acdfa89 100644 --- a/ortools/flatzinc/challenge/Makefile +++ b/ortools/flatzinc/challenge/Makefile @@ -18,7 +18,7 @@ DOCKER_BUILD_CMD := docker build endif DOCKER_RUN_CMD := docker run --rm --init -MZN_SUFFIX=2024v4 +MZN_SUFFIX=2024v5 DOCKER_NAME=cp-sat-minizinc-challenge MZN_TAG=${DOCKER_NAME}:${MZN_SUFFIX} MZN_LS_TAG=${DOCKER_NAME}-ls:${MZN_SUFFIX} diff --git a/ortools/flatzinc/presolve.cc b/ortools/flatzinc/presolve.cc index fb985006458..7b5984dee4e 100644 --- a/ortools/flatzinc/presolve.cc +++ b/ortools/flatzinc/presolve.cc @@ -183,7 +183,7 @@ bool IsIncreasingAndContiguous(absl::Span values) { return true; } -bool AreOnesFollowedByMinusOne(const std::vector& coeffs) { +bool AreOnesFollowedByMinusOne(absl::Span coeffs) { CHECK(!coeffs.empty()); for (int i = 0; i < coeffs.size() - 1; ++i) { if (coeffs[i] != 1) { diff --git a/ortools/glop/lu_factorization.cc b/ortools/glop/lu_factorization.cc index de30ea0f226..8deed0b61c3 100644 --- a/ortools/glop/lu_factorization.cc +++ b/ortools/glop/lu_factorization.cc @@ -405,7 +405,8 @@ bool LuFactorization::LeftSolveLWithNonZeros( ClearAndResizeVectorWithNonZeros(x->size(), 
result_before_permutation); x->swap(result_before_permutation->values); if (nz->empty()) { - for (RowIndex row(0); row < inverse_row_perm_.size(); ++row) { + const RowIndex num_rows = inverse_row_perm_.size(); + for (RowIndex row(0); row < num_rows; ++row) { const Fractional value = (*result_before_permutation)[row]; if (value != 0.0) { const RowIndex permuted_row = inverse_row_perm_[row]; diff --git a/ortools/glop/revised_simplex.cc b/ortools/glop/revised_simplex.cc index 28ef73d4c97..88549afa300 100644 --- a/ortools/glop/revised_simplex.cc +++ b/ortools/glop/revised_simplex.cc @@ -451,7 +451,7 @@ Status RevisedSimplex::Solve(const LinearProgram& lp, TimeLimit* time_limit) { "PRIMAL_UNBOUNDED was reported, but the tolerance are good " "and the unbounded ray is not great."); SOLVER_LOG(logger_, - "The difference between unbounded and optimal can depends " + "The difference between unbounded and optimal can depend " "on a slight change of tolerance, trying to see if we are " "at OPTIMAL after postsolve."); problem_status_ = ProblemStatus::OPTIMAL; @@ -1087,7 +1087,9 @@ bool RevisedSimplex::InitializeObjectiveAndTestIfUnchanged( // This function work whether the lp is in equation form (with slack) or // without, since the objective of the slacks are always zero. DCHECK_GE(num_cols_, lp.num_variables()); - for (ColIndex col(lp.num_variables()); col < num_cols_; ++col) { + + const auto obj_coeffs = lp.objective_coefficients().const_view(); + for (ColIndex col(obj_coeffs.size()); col < num_cols_; ++col) { if (objective_[col] != 0.0) { objective_is_unchanged = false; objective_[col] = 0.0; @@ -1096,8 +1098,8 @@ bool RevisedSimplex::InitializeObjectiveAndTestIfUnchanged( if (lp.IsMaximizationProblem()) { // Note that we use the minimization version of the objective internally. 
- for (ColIndex col(0); col < lp.num_variables(); ++col) { - const Fractional coeff = -lp.objective_coefficients()[col]; + for (ColIndex col(0); col < obj_coeffs.size(); ++col) { + const Fractional coeff = -obj_coeffs[col]; if (objective_[col] != coeff) { objective_is_unchanged = false; objective_[col] = coeff; @@ -1106,8 +1108,8 @@ bool RevisedSimplex::InitializeObjectiveAndTestIfUnchanged( objective_offset_ = -lp.objective_offset(); objective_scaling_factor_ = -lp.objective_scaling_factor(); } else { - for (ColIndex col(0); col < lp.num_variables(); ++col) { - const Fractional coeff = lp.objective_coefficients()[col]; + for (ColIndex col(0); col < obj_coeffs.size(); ++col) { + const Fractional coeff = obj_coeffs[col]; if (objective_[col] != coeff) { objective_is_unchanged = false; objective_[col] = coeff; @@ -1120,7 +1122,7 @@ bool RevisedSimplex::InitializeObjectiveAndTestIfUnchanged( return objective_is_unchanged; } -void RevisedSimplex::InitializeObjectiveLimit(const LinearProgram& lp) { +void RevisedSimplex::InitializeObjectiveLimit() { objective_limit_reached_ = false; DCHECK(std::isfinite(objective_offset_)); DCHECK(std::isfinite(objective_scaling_factor_)); @@ -1418,7 +1420,7 @@ Status RevisedSimplex::Initialize(const LinearProgram& lp) { } } - InitializeObjectiveLimit(lp); + InitializeObjectiveLimit(); // Computes the variable name as soon as possible for logging. // TODO(user): do we really need to store them? we could just compute them diff --git a/ortools/glop/revised_simplex.h b/ortools/glop/revised_simplex.h index 3f8d9c8b4ea..a4fa4b51adb 100644 --- a/ortools/glop/revised_simplex.h +++ b/ortools/glop/revised_simplex.h @@ -414,7 +414,7 @@ class RevisedSimplex { bool InitializeObjectiveAndTestIfUnchanged(const LinearProgram& lp); // Computes the stopping criterion on the problem objective value. - void InitializeObjectiveLimit(const LinearProgram& lp); + void InitializeObjectiveLimit(); // Initializes the starting basis. 
In most cases it starts by the all slack // basis and tries to apply some heuristics to replace fixed variables. diff --git a/ortools/glop/samples/code_samples.bzl b/ortools/glop/samples/code_samples.bzl index f0d55d3b0d0..f2163edba6f 100644 --- a/ortools/glop/samples/code_samples.bzl +++ b/ortools/glop/samples/code_samples.bzl @@ -13,8 +13,10 @@ """Helper macro to compile and test code samples.""" +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_test") + def code_sample_cc(name): - native.cc_binary( + cc_binary( name = name + "_cc", srcs = [name + ".cc"], deps = [ @@ -24,7 +26,7 @@ def code_sample_cc(name): ], ) - native.cc_test( + cc_test( name = name + "_cc_test", size = "small", srcs = [name + ".cc"], diff --git a/ortools/glop/variables_info.cc b/ortools/glop/variables_info.cc index 3830178e80e..d100f15eea1 100644 --- a/ortools/glop/variables_info.cc +++ b/ortools/glop/variables_info.cc @@ -120,13 +120,13 @@ void VariablesInfo::InitializeFromBasisState(ColIndex first_slack_col, // Compute the status for all the columns (note that the slack variables are // already added at the end of the matrix at this stage). + const int state_size = state.statuses.size().value(); for (ColIndex col(0); col < num_cols; ++col) { // Start with the given "warm" status from the BasisState if it exists. 
VariableStatus status; - if (col < first_new_col && col < state.statuses.size()) { + if (col < first_new_col && col < state_size) { status = state.statuses[col]; - } else if (col >= first_slack_col && - col - num_new_cols < state.statuses.size()) { + } else if (col >= first_slack_col && col - num_new_cols < state_size) { status = state.statuses[col - num_new_cols]; } else { UpdateToNonBasicStatus(col, DefaultVariableStatus(col)); diff --git a/ortools/graph/BUILD.bazel b/ortools/graph/BUILD.bazel index fe0f5883b5d..824d6cb89bc 100644 --- a/ortools/graph/BUILD.bazel +++ b/ortools/graph/BUILD.bazel @@ -68,6 +68,7 @@ cc_library( ":graph", "//ortools/base:iterator_adaptors", "//ortools/base:threadpool", + "//ortools/base:top_n", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", @@ -85,6 +86,25 @@ cc_library( ], ) +cc_test( + name = "multi_dijkstra_test", + size = "small", + srcs = ["multi_dijkstra_test.cc"], + deps = [ + ":connected_components", + ":graph", + ":multi_dijkstra", + ":random_graph", + ":util", + "//ortools/base:gmock_main", + "//ortools/base:map_util", + "//ortools/base:types", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/random:distributions", + ], +) + cc_library( name = "bidirectional_dijkstra", hdrs = ["bidirectional_dijkstra.h"], @@ -99,6 +119,23 @@ cc_library( ], ) +cc_test( + name = "bidirectional_dijkstra_test", + size = "small", + srcs = ["bidirectional_dijkstra_test.cc"], + deps = [ + ":bidirectional_dijkstra", + ":bounded_dijkstra", + ":graph", + "//ortools/base:gmock_main", + "//ortools/base:iterator_adaptors", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "cliques", srcs = ["cliques.cc"], @@ -107,8 +144,10 @@ cc_library( 
"//ortools/base", "//ortools/base:int_type", "//ortools/base:strong_vector", + "//ortools/util:bitset", "//ortools/util:time_limit", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", ], ) @@ -126,6 +165,21 @@ cc_library( ], ) +cc_test( + name = "hamiltonian_path_test", + size = "medium", + timeout = "long", + srcs = ["hamiltonian_path_test.cc"], + deps = [ + ":hamiltonian_path", + "//ortools/base", + "//ortools/base:gmock_main", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "christofides", hdrs = ["christofides.h"], @@ -144,6 +198,20 @@ cc_library( ], ) +cc_test( + name = "christofides_test", + srcs = ["christofides_test.cc"], + deps = [ + ":christofides", + "//ortools/base", + "//ortools/base:gmock_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", + ], +) + cc_library( name = "eulerian_path", hdrs = ["eulerian_path.h"], @@ -152,6 +220,18 @@ cc_library( ], ) +cc_test( + name = "eulerian_path_test", + srcs = ["eulerian_path_test.cc"], + deps = [ + ":eulerian_path", + ":graph", + "//ortools/base", + "//ortools/base:gmock_main", + "@com_google_benchmark//:benchmark", + ], +) + cc_library( name = "minimum_spanning_tree", hdrs = ["minimum_spanning_tree.h"], @@ -164,6 +244,21 @@ cc_library( ], ) +cc_test( + name = "minimum_spanning_tree_test", + srcs = ["minimum_spanning_tree_test.cc"], + deps = [ + ":graph", + ":minimum_spanning_tree", + "//ortools/base:gmock_main", + "//ortools/base:types", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", + ], +) + cc_library( name = "one_tree_lower_bound", hdrs = ["one_tree_lower_bound.h"], @@ -185,6 +280,23 @@ 
cc_library( "//ortools/util:permutation", "//ortools/util:zvector", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_prod", + ], +) + +cc_test( + name = "ebert_graph_test", + size = "small", + srcs = ["ebert_graph_test.cc"], + deps = [ + ":ebert_graph", + "//ortools/base", + "//ortools/base:gmock_main", + "//ortools/util:permutation", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/strings", + "@com_google_benchmark//:benchmark", ], ) @@ -209,6 +321,23 @@ cc_library( ], ) +cc_test( + name = "shortest_paths_test", + size = "medium", + srcs = ["shortest_paths_test.cc"], + tags = ["noasan"], # Times out occasionally in ASAN mode. + deps = [ + ":ebert_graph", + ":shortest_paths", + ":strongly_connected_components", + "//ortools/base:gmock_main", + "//ortools/util:zvector", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + ], +) + cc_library( name = "k_shortest_paths", hdrs = ["k_shortest_paths.h"], @@ -248,9 +377,9 @@ cc_library( ":graph", ":graphs", "//ortools/base", - "//ortools/base:types", "//ortools/util:stats", "//ortools/util:zvector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -271,6 +400,7 @@ cc_test( "//ortools/base:path", "//ortools/linear_solver", "//ortools/util:file_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/random", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -343,6 +473,16 @@ cc_library( ], ) +cc_test( + name = "assignment_test", + size = "small", + srcs = ["assignment_test.cc"], + deps = [ + ":assignment", + "//ortools/base:gmock_main", + ], +) + # Linear Assignment with full-featured interface and efficient # implementation. 
cc_library( @@ -357,6 +497,23 @@ cc_library( "//ortools/util:zvector", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_prod", + ], +) + +cc_test( + name = "linear_assignment_test", + size = "small", + srcs = ["linear_assignment_test.cc"], + deps = [ + ":ebert_graph", + ":graph", + ":linear_assignment", + "//ortools/base", + "//ortools/base:gmock_main", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", ], ) @@ -428,6 +585,76 @@ cc_library( ], ) +cc_test( + name = "rooted_tree_test", + srcs = ["rooted_tree_test.cc"], + deps = [ + ":graph", + ":rooted_tree", + "//ortools/base:gmock_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/status", + "@com_google_benchmark//:benchmark", + ], +) + +cc_test( + name = "perfect_matching_test", + size = "small", + srcs = ["perfect_matching_test.cc"], + deps = [ + ":perfect_matching", + "//ortools/base:gmock_main", + "//ortools/linear_solver:linear_solver_cc_proto", + "//ortools/linear_solver:solve_mp_model", + "@com_google_absl//absl/random", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "dag_shortest_path_test", + size = "small", + srcs = ["dag_shortest_path_test.cc"], + deps = [ + ":dag_shortest_path", + ":graph", + ":io", + "//ortools/base:dump_vars", + "//ortools/base:gmock_main", + "//ortools/util:flat_matrix", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", + ], +) + +cc_test( + name = "dag_constrained_shortest_path_test", + srcs = ["dag_constrained_shortest_path_test.cc"], + deps = [ + ":dag_constrained_shortest_path", + ":dag_shortest_path", + ":graph", + ":io", + 
"//ortools/base:dump_vars", + "//ortools/base:gmock_main", + "//ortools/math_opt/cpp:math_opt", + "//ortools/math_opt/solvers:cp_sat_solver", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", + ], +) + # From util/graph cc_library( name = "connected_components", diff --git a/ortools/graph/CMakeLists.txt b/ortools/graph/CMakeLists.txt index 109cf30b019..f76659bd0ce 100644 --- a/ortools/graph/CMakeLists.txt +++ b/ortools/graph/CMakeLists.txt @@ -12,28 +12,9 @@ # limitations under the License. file(GLOB _SRCS "*.h" "*.cc") +list(FILTER _SRCS EXCLUDE REGEX ".*/.*_test.cc") list(REMOVE_ITEM _SRCS - ${CMAKE_CURRENT_SOURCE_DIR}/assignment_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/bidirectional_dijkstra_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/bounded_dijkstra_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/christofides_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/cliques_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/dag_constrained_shortest_path_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/dag_shortest_path_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/ebert_graph_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/eulerian_path_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/hamiltonian_path_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/k_shortest_paths_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/linear_assignment_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/max_flow_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/min_cost_flow_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/minimum_spanning_tree_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/multi_dijkstra_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/one_tree_lower_bound_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/perfect_matching_test.cc - ${CMAKE_CURRENT_SOURCE_DIR}/rooted_tree_test.cc ${CMAKE_CURRENT_SOURCE_DIR}/shortest_paths_benchmarks.cc - ${CMAKE_CURRENT_SOURCE_DIR}/shortest_paths_test.cc ${CMAKE_CURRENT_SOURCE_DIR}/solve_flow_model.cc ) diff --git 
a/ortools/graph/bidirectional_dijkstra_test.cc b/ortools/graph/bidirectional_dijkstra_test.cc index 5de9f168ab8..f71bbc350f8 100644 --- a/ortools/graph/bidirectional_dijkstra_test.cc +++ b/ortools/graph/bidirectional_dijkstra_test.cc @@ -27,6 +27,7 @@ #include "absl/types/span.h" #include "gtest/gtest.h" #include "ortools/base/gmock.h" +#include "ortools/base/iterator_adaptors.h" #include "ortools/graph/bounded_dijkstra.h" #include "ortools/graph/graph.h" @@ -202,7 +203,7 @@ TEST(BidirectionalDijkstraTest, RandomizedCorrectnessTest) { ref_dijkstra.ArcPathToNode(ref_dests[0]); const auto path = tested_dijkstra.SetToSetShortestPath(srcs, dsts); std::vector arc_path = path.forward_arc_path; - for (const int arc : gtl::reversed_view(path.backward_arc_path)) { + for (const int arc : ::gtl::reversed_view(path.backward_arc_path)) { arc_path.push_back(forward_arc_of_backward_arc[arc]); } ASSERT_THAT(arc_path, ElementsAreArray(ref_arc_path)) diff --git a/ortools/graph/christofides.h b/ortools/graph/christofides.h index 38f0c907bf2..92bd1bfcce5 100644 --- a/ortools/graph/christofides.h +++ b/ortools/graph/christofides.h @@ -28,7 +28,7 @@ #include #include -#include +#include #include #include @@ -84,7 +84,17 @@ class ChristofidesPathSolver { bool Solve(); private: - int64_t SafeAdd(int64_t a, int64_t b) { return CapAdd(a, b); } + // Safe addition operator to avoid overflows when possible. + template + T SafeAdd(T a, T b) { + // TODO(user): use std::remove_cvref_t once C++20 is available. + if constexpr (std::is_same_v>, + int64_t> == true) { + return CapAdd(a, b); + } else { + return a + b; + } + } // Matching algorithm to use. 
MatchingAlgorithm matching_; diff --git a/ortools/graph/cliques.cc b/ortools/graph/cliques.cc index ed3ccc1b937..e9fcff0dd16 100644 --- a/ortools/graph/cliques.cc +++ b/ortools/graph/cliques.cc @@ -20,6 +20,8 @@ #include #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "ortools/util/bitset.h" namespace operations_research { namespace { @@ -262,4 +264,129 @@ void CoverArcsByCliques(std::function graph, int node_count, initial_candidates.get(), 0, node_count, &actual, &stop); } +void WeightedBronKerboschBitsetAlgorithm::Initialize(int num_nodes) { + work_ = 0; + weights_.assign(num_nodes, 0.0); + + // We need +1 in case the graph is complete and form a clique. + clique_.resize(num_nodes + 1); + clique_weight_.resize(num_nodes + 1); + left_to_process_.resize(num_nodes + 1); + x_.resize(num_nodes + 1); + + // Initialize to empty graph. + graph_.resize(num_nodes); + for (Bitset64& bitset : graph_) { + bitset.ClearAndResize(num_nodes); + } +} + +void WeightedBronKerboschBitsetAlgorithm:: + TakeTransitiveClosureOfImplicationGraph() { + // We use Floyd-Warshall algorithm. + const int num_nodes = weights_.size(); + for (int k = 0; k < num_nodes; ++k) { + // Loop over all the i => k, we can do that by looking at the not(k) => + // not(i). + for (const int i : graph_[k ^ 1]) { + // Now i also implies all the literals implied by k. + graph_[i].Union(graph_[k]); + } + } +} + +std::vector> WeightedBronKerboschBitsetAlgorithm::Run() { + clique_index_and_weight_.clear(); + std::vector> cliques; + + const int num_nodes = weights_.size(); + in_clique_.ClearAndResize(num_nodes); + + queue_.clear(); + + int depth = 0; + left_to_process_[0].ClearAndResize(num_nodes); + x_[0].ClearAndResize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + left_to_process_[0].Set(i); + queue_.push_back(i); + } + + // We run an iterative DFS where we push all possible next node to + // queue_. We just abort brutally if we hit the work limit. 
+ while (!queue_.empty() && work_ <= work_limit_) { + const int node = queue_.back(); + if (!in_clique_[node]) { + // We add this node to the clique. + in_clique_.Set(node); + clique_[depth] = node; + left_to_process_[depth].Clear(node); + x_[depth].Set(node); + + // Note that it might seems we don't need to keep both set since we + // only process nodes in order, but because of the pivot optim, while + // both set are sorted, they can be "interleaved". + ++depth; + work_ += num_nodes; + const double current_weight = weights_[node] + clique_weight_[depth - 1]; + clique_weight_[depth] = current_weight; + left_to_process_[depth].SetToIntersectionOf(left_to_process_[depth - 1], + graph_[node]); + x_[depth].SetToIntersectionOf(x_[depth - 1], graph_[node]); + + // Choose a pivot. We use the vertex with highest weight according to: + // Samuel Souza Britoa, Haroldo Gambini Santosa, "Preprocessing and + // Cutting Planes with Conflict Graphs", + // https://arxiv.org/pdf/1909.07780 + // but maybe random is more robust? + int pivot = -1; + double pivot_weight = -1.0; + for (const int candidate : x_[depth]) { + const double candidate_weight = weights_[candidate]; + if (candidate_weight > pivot_weight) { + pivot = candidate; + pivot_weight = candidate_weight; + } + } + double total_weight_left = 0.0; + for (const int candidate : left_to_process_[depth]) { + const double candidate_weight = weights_[candidate]; + if (candidate_weight > pivot_weight) { + pivot = candidate; + pivot_weight = candidate_weight; + } + total_weight_left += candidate_weight; + } + + // Heuristic: We can abort early if there is no way to reach the + // threshold here. + if (current_weight + total_weight_left < weight_threshold_) { + continue; + } + + if (pivot == -1 && current_weight >= weight_threshold_) { + // This clique is maximal. 
+ clique_index_and_weight_.push_back({cliques.size(), current_weight}); + cliques.emplace_back(clique_.begin(), clique_.begin() + depth); + continue; + } + + for (const int next : left_to_process_[depth]) { + if (graph_[pivot][next]) continue; // skip. + queue_.push_back(next); + } + } else { + // We finished exploring node. + // backtrack. + --depth; + DCHECK_GE(depth, 0); + DCHECK_EQ(clique_[depth], node); + in_clique_.Clear(node); + queue_.pop_back(); + } + } + + return cliques; +} + } // namespace operations_research diff --git a/ortools/graph/cliques.h b/ortools/graph/cliques.h index 7901566bf28..702a0a09096 100644 --- a/ortools/graph/cliques.h +++ b/ortools/graph/cliques.h @@ -36,6 +36,7 @@ #include "ortools/base/int_type.h" #include "ortools/base/logging.h" #include "ortools/base/strong_vector.h" +#include "ortools/util/bitset.h" #include "ortools/util/time_limit.h" namespace operations_research { @@ -358,6 +359,87 @@ class BronKerboschAlgorithm { TimeLimit* time_limit_; }; +// More specialized version used to separate clique-cuts in MIP solver. +// This finds all maximal clique with a weight greater than a given threshold. +// It also has computation limit. +// +// This implementation assumes small graph since we use a dense bitmask +// representation to encode the graph adjacency. So it shouldn't really be used +// with more than a few thousands nodes. +class WeightedBronKerboschBitsetAlgorithm { + public: + // Resets the class to an empty graph will all weights of zero. + // This also reset the work done. + void Initialize(int num_nodes); + + // Set the weight of a given node, must be in [0, num_nodes). + // Weights are assumed to be non-negative. + void SetWeight(int i, double weight) { weights_[i] = weight; } + + // Add an edge in the graph. + void AddEdge(int a, int b) { + graph_[a].Set(b); + graph_[b].Set(a); + } + + // We count the number of basic operations, and stop when we reach this limit. 
+ void SetWorkLimit(int64_t limit) { work_limit_ = limit; } + + // Set the minimum weight of the maximal cliques we are looking for. + void SetMinimumWeight(double min_weight) { weight_threshold_ = min_weight; } + + // This function is quite specific. It interprets node i as the negated + // literal of node i ^ 1. And all j in graph[i] as literal that are in at most + // two relation. So i implies all not(j) for all j in graph[i]. + // + // The transitive close runs in O(num_nodes ^ 3) in the worst case, but since + // we process 64 bits at the time, it is okay to run it for graph up to 1k + // nodes. + void TakeTransitiveClosureOfImplicationGraph(); + + // Runs the algo and returns all maximal clique with a weight above the + // configured thrheshold via SetMinimumWeight(). It is possible we reach the + // work limit before that. + std::vector> Run(); + + // Specific API where the index refer in the last result of Run(). + // This allows to select cliques when they are many. + std::vector>& GetMutableIndexAndWeight() { + return clique_index_and_weight_; + } + + int64_t WorkDone() const { return work_; } + + bool HasEdge(int i, int j) const { return graph_[i][j]; } + + private: + int64_t work_ = 0; + int64_t work_limit_ = std::numeric_limits::max(); + double weight_threshold_ = 0.0; + + std::vector weights_; + std::vector> graph_; + + // Iterative DFS queue. + std::vector queue_; + + // Current clique we are constructing. + // Note this is always of size num_nodes, the clique is in [0, depth) + Bitset64 in_clique_; + std::vector clique_; + + // We maintain the weight of the clique. We use a stack to avoid floating + // point issue with +/- weights many times. So clique_weight_[i] is the sum of + // weight from [0, i) of element of the cliques. + std::vector clique_weight_; + + // Correspond to P and X in BronKerbosch description. 
+ std::vector> left_to_process_; + std::vector> x_; + + std::vector> clique_index_and_weight_; +}; + template void BronKerboschAlgorithm::InitializeState(State* state) { DCHECK(state != nullptr); diff --git a/ortools/graph/cliques_test.cc b/ortools/graph/cliques_test.cc index d846be1d651..623d104ba3c 100644 --- a/ortools/graph/cliques_test.cc +++ b/ortools/graph/cliques_test.cc @@ -14,6 +14,7 @@ #include "ortools/graph/cliques.h" #include +#include #include #include #include @@ -24,6 +25,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/functional/bind_front.h" +#include "absl/log/check.h" #include "absl/random/distributions.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" @@ -412,6 +414,44 @@ TEST(BronKerbosch, CompleteGraphCover) { EXPECT_EQ(10, all_cliques[0].size()); } +TEST(WeightedBronKerboschBitsetAlgorithmTest, CompleteGraph) { + const int num_nodes = 1000; + WeightedBronKerboschBitsetAlgorithm algo; + algo.Initialize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + for (int j = i + 1; j < num_nodes; ++j) { + algo.AddEdge(i, j); + } + } + std::vector> cliques = algo.Run(); + EXPECT_EQ(cliques.size(), 1); + for (const std::vector& clique : cliques) { + EXPECT_EQ(num_nodes, clique.size()); + } +} + +TEST(WeightedBronKerboschBitsetAlgorithmTest, ImplicationGraphClosure) { + const int num_nodes = 10; + WeightedBronKerboschBitsetAlgorithm algo; + algo.Initialize(num_nodes); + for (int i = 0; i + 2 < num_nodes; i += 2) { + const int j = i + 2; + algo.AddEdge(i, j ^ 1); // i => j + } + algo.TakeTransitiveClosureOfImplicationGraph(); + for (int i = 0; i < num_nodes; ++i) { + for (int j = 0; j < num_nodes; ++j) { + if (i % 2 == 0 && j % 2 == 0) { + if (j > i) { + EXPECT_TRUE(algo.HasEdge(i, j ^ 1)); + } else { + EXPECT_FALSE(algo.HasEdge(i, j ^ 1)); + } + } + } + } +} + TEST(BronKerbosch, EmptyGraphCover) { auto graph = EmptyGraph; CliqueReporter reporter; @@ -477,6 +517,50 @@ TEST(BronKerboschAlgorithmTest, FullKPartiteGraph) { } } 
+TEST(WeightedBronKerboschBitsetAlgorithmTest, FullKPartiteGraph) { + const int kNumPartitions[] = {2, 3, 4, 5, 6, 7}; + for (const int num_partitions : kNumPartitions) { + SCOPED_TRACE(absl::StrCat("num_partitions = ", num_partitions)); + WeightedBronKerboschBitsetAlgorithm algo; + + const int num_nodes = num_partitions * num_partitions; + algo.Initialize(num_nodes); + + for (int i = 0; i < num_nodes; ++i) { + for (int j = i + 1; j < num_nodes; ++j) { + if (FullKPartiteGraph(num_partitions, i, j)) algo.AddEdge(i, j); + } + } + + std::vector> cliques = algo.Run(); + EXPECT_EQ(cliques.size(), pow(num_partitions, num_partitions)); + for (const std::vector& clique : cliques) { + EXPECT_EQ(num_partitions, clique.size()); + } + } +} + +TEST(WeightedBronKerboschBitsetAlgorithmTest, ModuloGraph) { + int num_partitions = 50; + int partition_size = 100; + WeightedBronKerboschBitsetAlgorithm algo; + + const int num_nodes = num_partitions * partition_size; + algo.Initialize(num_nodes); + + for (int i = 0; i < num_nodes; ++i) { + for (int j = i + 1; j < num_nodes; ++j) { + if (ModuloGraph(num_partitions, i, j)) algo.AddEdge(i, j); + } + } + + std::vector> cliques = algo.Run(); + EXPECT_EQ(cliques.size(), num_partitions); + for (const std::vector& clique : cliques) { + EXPECT_EQ(partition_size, clique.size()); + } +} + // The following two tests run the Bron-Kerbosch algorithm with wall time // limit and deterministic time limit. They use a full 15-partite graph with // a one second time limit. 
@@ -590,6 +674,37 @@ BENCHMARK(BM_FindCliquesInModuloGraphWithBronKerboschAlgorithm) ->ArgPair(500, 10) ->ArgPair(1000, 5); +void BM_FindCliquesInModuloGraphWithBitsetBK(benchmark::State& state) { + int num_partitions = state.range(0); + int partition_size = state.range(1); + const int kExpectedNumCliques = num_partitions; + const int kExpectedCliqueSize = partition_size; + const int num_nodes = num_partitions * partition_size; + for (auto _ : state) { + WeightedBronKerboschBitsetAlgorithm algo; + algo.Initialize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + for (int j = i + 1; j < num_nodes; ++j) { + if (ModuloGraph(num_partitions, i, j)) algo.AddEdge(i, j); + } + } + + std::vector> cliques = algo.Run(); + EXPECT_EQ(cliques.size(), kExpectedNumCliques); + for (const std::vector& clique : cliques) { + EXPECT_EQ(kExpectedCliqueSize, clique.size()); + } + } +} + +BENCHMARK(BM_FindCliquesInModuloGraphWithBitsetBK) + ->ArgPair(5, 1000) + ->ArgPair(10, 500) + ->ArgPair(50, 100) + ->ArgPair(100, 50) + ->ArgPair(500, 10) + ->ArgPair(1000, 5); + // A benchmark that finds all maximal cliques in a 7-partite graph (a graph // where the nodes are divided into 7 groups of size 7; each node is connected // to all nodes in other groups but to no node in the same group). This graph diff --git a/ortools/graph/ebert_graph.h b/ortools/graph/ebert_graph.h index e3541a995f2..812b876718d 100644 --- a/ortools/graph/ebert_graph.h +++ b/ortools/graph/ebert_graph.h @@ -1109,6 +1109,16 @@ class EbertGraphBase }; #endif // SWIG + // Using the SetHead() method implies that the BuildRepresentation() + // method must be called to restore consistency before the graph is + // used. + // + // Visible for testing. 
+ void SetHead(const ArcIndexType arc, const NodeIndexType head) { + representation_clean_ = false; + head_.Set(arc, head); + } + protected: EbertGraphBase() : next_adjacent_arc_(), representation_clean_(true) {} @@ -1175,14 +1185,6 @@ class EbertGraphBase } bool RepresentationClean() const { return representation_clean_; } - - // Using the SetHead() method implies that the BuildRepresentation() - // method must be called to restore consistency before the graph is - // used. - void SetHead(const ArcIndexType arc, const NodeIndexType head) { - representation_clean_ = false; - head_.Set(arc, head); - } }; // Most users should only use StarGraph, which is EbertGraph, diff --git a/ortools/graph/ebert_graph_test.cc b/ortools/graph/ebert_graph_test.cc index 5dfa4073ffc..7409781f7ff 100644 --- a/ortools/graph/ebert_graph_test.cc +++ b/ortools/graph/ebert_graph_test.cc @@ -20,13 +20,11 @@ #include "absl/base/macros.h" #include "absl/random/distributions.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "benchmark/benchmark.h" #include "gtest/gtest.h" #include "ortools/base/macros.h" #include "ortools/util/permutation.h" -#include "testing/base/public/test_utils.h" namespace operations_research { @@ -1033,10 +1031,6 @@ TYPED_TEST(TinyEbertGraphTest, CheckDeathOnBadBounds) { int num_nodes = SmallStarGraph::kMaxNumNodes; int num_arcs = SmallStarGraph::kMaxNumArcs; SmallStarGraph(num_nodes, num_arcs); // Construct an unused graph. All fine. 
- EXPECT_DFATAL(SmallStarGraph(num_nodes + 1, num_arcs), - "Could not reserve memory for -128 nodes and 127 arcs."); - EXPECT_DFATAL(SmallStarGraph(num_nodes, num_arcs + 1), - "Could not reserve memory for 127 nodes and -128 arcs."); } // An empty fixture to collect the types of small graphs for which we want to do diff --git a/ortools/graph/k_shortest_paths.h b/ortools/graph/k_shortest_paths.h index 89011bc9891..108c393de5c 100644 --- a/ortools/graph/k_shortest_paths.h +++ b/ortools/graph/k_shortest_paths.h @@ -165,14 +165,14 @@ std::tuple, PathDistance> ComputeShortestPath( // This case only happens when some arcs have an infinite length (i.e. // larger than `kMaxDistance`): `BoundedDijkstraWrapper::NodePathTo` fails // to return a path, even empty. - return {{}, kDisconnectedDistance}; + return {std::vector{}, kDisconnectedDistance}; } if (std::vector path = std::move(dijkstra.NodePathTo(destination)); !path.empty()) { return {std::move(path), path_length}; } else { - return {{}, kDisconnectedDistance}; + return {std::vector{}, kDisconnectedDistance}; } } diff --git a/ortools/graph/linear_assignment.h b/ortools/graph/linear_assignment.h index dba3e28b732..635fd275c9b 100644 --- a/ortools/graph/linear_assignment.h +++ b/ortools/graph/linear_assignment.h @@ -378,6 +378,12 @@ class LinearSumAssignment { typename GraphType::NodeIndex node_iterator_; }; + // Returns true if and only if the current pseudoflow is + // epsilon-optimal. To be used in a DCHECK. + // + // Visible for testing. + bool EpsilonOptimal() const; + private: struct Stats { Stats() : pushes_(0), double_pushes_(0), relabelings_(0), refinements_(0) {} @@ -462,10 +468,6 @@ class LinearSumAssignment { // right-side nodes during DoublePush operations. typedef std::pair ImplicitPriceSummary; - // Returns true if and only if the current pseudoflow is - // epsilon-optimal. To be used in a DCHECK. - bool EpsilonOptimal() const; - // Checks that all nodes are matched. // To be used in a DCHECK. 
bool AllMatched() const; @@ -515,7 +517,7 @@ class LinearSumAssignment { // definition of admissibility, this action is different from // saturating all admissible arcs (which we never do). All negative // arcs are admissible, but not all admissible arcs are negative. It - // is alwsys enough to saturate only the negative ones. + // is always enough to saturate only the negative ones. void SaturateNegativeArcs(); // Performs an optimized sequence of pushing a unit of excess out of diff --git a/ortools/graph/max_flow.cc b/ortools/graph/max_flow.cc index c9022ddc424..cbe1a725b42 100644 --- a/ortools/graph/max_flow.cc +++ b/ortools/graph/max_flow.cc @@ -19,9 +19,11 @@ #include #include +#include "absl/log/check.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "ortools/graph/ebert_graph.h" #include "ortools/graph/graph.h" #include "ortools/graph/graphs.h" @@ -1025,4 +1027,83 @@ template class GenericMaxFlow<::util::ReverseArcListGraph<>>; template class GenericMaxFlow<::util::ReverseArcStaticGraph<>>; template class GenericMaxFlow<::util::ReverseArcMixedGraph<>>; +std::vector BipartiteMinimumVertexCover( + const std::vector>& left_to_right_arcs, int num_right) { + // This algorithm first uses the maximum flow to find a maximum matching. Then + // it uses the same method outlined in the proof of Konig's theorem to + // transform the maximum matching into a minimum vertex cover. + // + // More concretely, it uses a DFS starting with unmatched nodes and + // alternating matched/unmatched edges to find a minimum vertex cover. 
+ SimpleMaxFlow max_flow; + const int num_left = left_to_right_arcs.size(); + std::vector arcs; + for (int i = 0; i < num_left; ++i) { + for (const int right_node : left_to_right_arcs[i]) { + DCHECK_GE(right_node, num_left); + DCHECK_LT(right_node, num_right + num_left); + arcs.push_back(max_flow.AddArcWithCapacity(i, right_node, 1)); + } + } + std::vector> adj_list = left_to_right_arcs; + adj_list.resize(num_left + num_right); + for (int i = 0; i < num_left; ++i) { + for (const int right_node : left_to_right_arcs[i]) { + adj_list[right_node].push_back(i); + } + } + const int sink = num_left + num_right; + const int source = num_left + num_right + 1; + for (int i = 0; i < num_left; ++i) { + max_flow.AddArcWithCapacity(source, i, 1); + } + for (int i = 0; i < num_right; ++i) { + max_flow.AddArcWithCapacity(i + num_left, sink, 1); + } + CHECK(max_flow.Solve(source, sink) == SimpleMaxFlow::OPTIMAL); + std::vector maximum_matching(num_left + num_right, -1); + for (const ArcIndex arc : arcs) { + if (max_flow.Flow(arc) > 0) { + maximum_matching[max_flow.Tail(arc)] = max_flow.Head(arc); + maximum_matching[max_flow.Head(arc)] = max_flow.Tail(arc); + } + } + // We do a DFS starting with unmatched nodes and alternating matched/unmatched + // edges. 
+ std::vector in_alternating_path(num_left + num_right, false); + std::vector to_visit; + for (int i = 0; i < num_left; ++i) { + if (maximum_matching[i] == -1) { + to_visit.push_back(i); + } + } + while (!to_visit.empty()) { + const int current = to_visit.back(); + to_visit.pop_back(); + if (in_alternating_path[current]) { + continue; + } + in_alternating_path[current] = true; + for (const int j : adj_list[current]) { + if (current < num_left && maximum_matching[current] != j) { + to_visit.push_back(j); + } else if (current >= num_left && maximum_matching[current] == j) { + to_visit.push_back(j); + } + } + } + std::vector minimum_vertex_cover(num_left + num_right, false); + for (int i = 0; i < num_left; ++i) { + if (!in_alternating_path[i]) { + minimum_vertex_cover[i] = true; + } + } + for (int i = num_left; i < num_left + num_right; ++i) { + if (in_alternating_path[i]) { + minimum_vertex_cover[i] = true; + } + } + return minimum_vertex_cover; +} + } // namespace operations_research diff --git a/ortools/graph/max_flow.h b/ortools/graph/max_flow.h index a5d961e5396..b9cfa38d692 100644 --- a/ortools/graph/max_flow.h +++ b/ortools/graph/max_flow.h @@ -679,6 +679,18 @@ extern template class GenericMaxFlow<::util::ReverseArcListGraph<>>; extern template class GenericMaxFlow<::util::ReverseArcStaticGraph<>>; extern template class GenericMaxFlow<::util::ReverseArcMixedGraph<>>; +// This method computes a minimum vertex cover for the bipartite graph. +// +// If we define num_left=left_to_right_arcs.size(), the "left" nodes are +// integers in [0, num_left), and the "right" nodes are integers in [num_left, +// num_left + num_right). +// +// Returns a vector of size num_left+num_right, such that element #l is true if +// it is part of the minimum vertex cover and false if it is part of the maximum +// independent set (one is the complement of the other). 
+std::vector BipartiteMinimumVertexCover( + const std::vector>& left_to_right_arcs, int num_right); + // Default instance MaxFlow that uses StarGraph. Note that we cannot just use a // typedef because of dependent code expecting MaxFlow to be a real class. // TODO(user): Modify this code and remove it. diff --git a/ortools/graph/max_flow_test.cc b/ortools/graph/max_flow_test.cc index 05784665bbb..5b9e0fbcceb 100644 --- a/ortools/graph/max_flow_test.cc +++ b/ortools/graph/max_flow_test.cc @@ -21,6 +21,7 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/random/random.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -264,7 +265,7 @@ class GenericMaxFlowTest : public ::testing::Test {}; typedef ::testing::Types, util::ReverseArcStaticGraph<>, - util::ReverseArcMixedGraph<> > + util::ReverseArcMixedGraph<>> GraphTypes; TYPED_TEST_SUITE(GenericMaxFlowTest, GraphTypes); @@ -559,7 +560,7 @@ void FullRandomAssignment(typename MaxFlowSolver::Solver f, GenerateCompleteGraph(num_tails, num_heads, &graph); Graphs::Build(&graph); std::vector arc_capacity(graph.num_arcs(), 1); - std::unique_ptr > max_flow(new GenericMaxFlow( + std::unique_ptr> max_flow(new GenericMaxFlow( &graph, graph.num_nodes() - 2, graph.num_nodes() - 1)); SetUpNetworkData(arc_capacity, max_flow.get()); FlowQuantity flow = f(max_flow.get()); @@ -578,7 +579,7 @@ void PartialRandomAssignment(typename MaxFlowSolver::Solver f, Graphs::Build(&graph); CHECK_EQ(graph.num_arcs(), num_tails * kDegree + num_tails + num_heads); std::vector arc_capacity(graph.num_arcs(), 1); - std::unique_ptr > max_flow(new GenericMaxFlow( + std::unique_ptr> max_flow(new GenericMaxFlow( &graph, graph.num_nodes() - 2, graph.num_nodes() - 1)); SetUpNetworkData(arc_capacity, max_flow.get()); FlowQuantity flow = f(max_flow.get()); @@ -613,7 +614,7 @@ void PartialRandomFlow(typename MaxFlowSolver::Solver f, Graphs::Build(&graph, &permutation); util::Permute(permutation, &arc_capacity); - 
std::unique_ptr > max_flow(new GenericMaxFlow( + std::unique_ptr> max_flow(new GenericMaxFlow( &graph, graph.num_nodes() - 2, graph.num_nodes() - 1)); SetUpNetworkData(arc_capacity, max_flow.get()); FlowQuantity flow = f(max_flow.get()); @@ -642,7 +643,7 @@ void FullRandomFlow(typename MaxFlowSolver::Solver f, Graphs::Build(&graph, &permutation); util::Permute(permutation, &arc_capacity); - std::unique_ptr > max_flow(new GenericMaxFlow( + std::unique_ptr> max_flow(new GenericMaxFlow( &graph, graph.num_nodes() - 2, graph.num_nodes() - 1)); SetUpNetworkData(arc_capacity, max_flow.get()); FlowQuantity flow = f(max_flow.get()); @@ -672,10 +673,10 @@ void FullRandomFlow(typename MaxFlowSolver::Solver f, expected_flow2); \ } -#define FLOW_ONLY_TEST_SG(test_name, size, expected_flow1, expected_flow2) \ - TEST(MaxFlowTestStaticGraph, test_name##size) { \ - test_name >(SolveMaxFlow, size, size, \ - expected_flow1, expected_flow2); \ +#define FLOW_ONLY_TEST_SG(test_name, size, expected_flow1, expected_flow2) \ + TEST(MaxFlowTestStaticGraph, test_name##size) { \ + test_name>(SolveMaxFlow, size, size, \ + expected_flow1, expected_flow2); \ } LP_AND_FLOW_TEST(FullRandomAssignment, 300, 300, 300); @@ -838,6 +839,26 @@ TEST(PriorityQueueWithRestrictedPushTest, RandomPushPop) { } } +TEST(BipartiteMinimumVertexCoverTest, BasicBehavior) { + const int num_right = 4; + const std::vector> left_to_right = { + {5}, {4, 5, 6}, {5}, {5, 6, 7}}; + EXPECT_EQ(absl::c_count(BipartiteMinimumVertexCover(left_to_right, num_right), + true), + 3); + EXPECT_EQ(absl::c_count(BipartiteMinimumVertexCover(left_to_right, num_right), + false), + 5); +} + +TEST(BipartiteMinimumVertexCoverTest, Empty) { + const int num_right = 4; + const std::vector> left_to_right = {{}, {}}; + EXPECT_EQ(absl::c_count(BipartiteMinimumVertexCover(left_to_right, num_right), + false), + 6); +} + TEST(PriorityQueueWithRestrictedPushDeathTest, DCHECK) { // Don't run this test in opt mode. 
if (!DEBUG_MODE) GTEST_SKIP(); diff --git a/ortools/graph/rooted_tree_test.cc b/ortools/graph/rooted_tree_test.cc index d2160c24578..d8033bcb61f 100644 --- a/ortools/graph/rooted_tree_test.cc +++ b/ortools/graph/rooted_tree_test.cc @@ -219,12 +219,12 @@ TYPED_TEST_P(RootedTreeTest, AllDistancesToRoot) { // 0 3 // | // 2 - const int root = 1; + const Node root = 1; std::vector parents = {1, this->kNullParent, 3, 1}; const std::vector arc_lengths = {1, 0, 10, 100}; ASSERT_OK_AND_ASSIGN(const auto tree, RootedTree::Create(root, parents)); - EXPECT_THAT(tree.AllDistancesToRoot(arc_lengths), + EXPECT_THAT(tree.template AllDistancesToRoot(arc_lengths), ElementsAre(1.0, 0.0, 110.0, 100.0)); } diff --git a/ortools/graph/samples/code_samples.bzl b/ortools/graph/samples/code_samples.bzl index cda07eb817e..fcd019c54e7 100644 --- a/ortools/graph/samples/code_samples.bzl +++ b/ortools/graph/samples/code_samples.bzl @@ -14,10 +14,12 @@ """Helper macro to compile and test code samples.""" load("@pip_deps//:requirements.bzl", "requirement") +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_test") +load("@rules_java//java:defs.bzl", "java_test") load("@rules_python//python:defs.bzl", "py_binary", "py_test") def code_sample_cc(name): - native.cc_binary( + cc_binary( name = name + "_cc", srcs = [name + ".cc"], deps = [ @@ -38,7 +40,7 @@ def code_sample_cc(name): ], ) - native.cc_test( + cc_test( name = name + "_cc_test", size = "small", srcs = [name + ".cc"], @@ -98,7 +100,7 @@ def code_sample_cc_py(name): code_sample_py(name = name) def code_sample_java(name): - native.java_test( + java_test( name = name + "_java_test", size = "small", srcs = [name + ".java"], diff --git a/ortools/gurobi/environment.cc b/ortools/gurobi/environment.cc index f2f73cbac22..1e606fc9b8e 100644 --- a/ortools/gurobi/environment.cc +++ b/ortools/gurobi/environment.cc @@ -346,8 +346,8 @@ void LoadGurobiFunctions(DynamicLibrary* gurobi_dynamic_library) { std::vector GurobiDynamicLibraryPotentialPaths() { 
std::vector potential_paths; const std::vector kGurobiVersions = { - "1103", "1102", "1101", "1100", "1003", "1002", "1001", "1000", "952", "951", - "950", "911", "910", "903", "902", "811", "801", "752"}; + "1103", "1102", "1101", "1100", "1003", "1002", "1001", "1000", "952", + "951", "950", "911", "910", "903", "902", "811", "801", "752"}; potential_paths.reserve(kGurobiVersions.size() * 3); // Look for libraries pointed by GUROBI_HOME first. @@ -406,8 +406,8 @@ std::vector GurobiDynamicLibraryPotentialPaths() { #if defined(__GNUC__) // path in linux64 gurobi/optimizer docker image. for (const std::string& version : - {"11.0.3", "11.0.2", "11.0.1", "11.0.0", "10.0.3", "10.0.2", "10.0.1", "10.0.0", - "9.5.2", "9.5.1", "9.5.0"}) { + {"11.0.3", "11.0.2", "11.0.1", "11.0.0", "10.0.3", "10.0.2", "10.0.1", + "10.0.0", "9.5.2", "9.5.1", "9.5.0"}) { potential_paths.push_back( absl::StrCat("/opt/gurobi/linux64/lib/libgurobi.so.", version)); } diff --git a/ortools/julia/ORTools.jl/LICENSE b/ortools/julia/ORTools.jl/LICENSE new file mode 100644 index 00000000000..cbe9ed1ae26 --- /dev/null +++ b/ortools/julia/ORTools.jl/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ortools/julia/ORTools.jl/Project.toml b/ortools/julia/ORTools.jl/Project.toml new file mode 100644 index 00000000000..33f069be367 --- /dev/null +++ b/ortools/julia/ORTools.jl/Project.toml @@ -0,0 +1,15 @@ +name = "ORTools" +uuid = "b7d69b34-a827-4671-8cfa-f7e1eec930c7" +version = "1.0.0-DEV" + +[deps] +MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" + +[compat] +julia = "1.6.7" + +[extras] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test"] diff --git a/ortools/julia/ORTools.jl/README.md b/ortools/julia/ORTools.jl/README.md new file mode 100644 index 00000000000..941908eba9f --- /dev/null +++ b/ortools/julia/ORTools.jl/README.md @@ -0,0 +1,6 @@ +# ORTools + +This is the +[MathOptInterface.jl](https://github.com/jump-dev/MathOptInterface.jl) Julia +wrapper for Google's +[MathOpt](https://developers.google.com/optimization/math_opt). diff --git a/ortools/julia/ORTools.jl/src/ORTools.jl b/ortools/julia/ORTools.jl/src/ORTools.jl new file mode 100644 index 00000000000..dcdf2dca762 --- /dev/null +++ b/ortools/julia/ORTools.jl/src/ORTools.jl @@ -0,0 +1,5 @@ +module ORTools + +# Write your package code here. 
+ +end diff --git a/ortools/julia/ORTools.jl/test/runtests.jl b/ortools/julia/ORTools.jl/test/runtests.jl new file mode 100644 index 00000000000..8b2e109cf6e --- /dev/null +++ b/ortools/julia/ORTools.jl/test/runtests.jl @@ -0,0 +1,6 @@ +using ORTools +using Test + +@testset "ORTools.jl" begin + # Write your tests here. +end diff --git a/ortools/linear_solver/proto_solver/gurobi_proto_solver.cc b/ortools/linear_solver/proto_solver/gurobi_proto_solver.cc index 9f4cbf0734d..a13e20cdd9d 100644 --- a/ortools/linear_solver/proto_solver/gurobi_proto_solver.cc +++ b/ortools/linear_solver/proto_solver/gurobi_proto_solver.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -274,7 +275,7 @@ absl::Status SetSolverSpecificParameters(absl::string_view parameters, absl::StatusOr GurobiSolveProto( LazyMutableCopy request, GRBenv* gurobi_env) { MPSolutionResponse response; - const absl::optional> optional_model = + const std::optional> optional_model = GetMPModelOrPopulateResponse(request, &response); if (!optional_model) return response; const MPModelProto& model = **optional_model; diff --git a/ortools/linear_solver/proto_solver/highs_proto_solver.cc b/ortools/linear_solver/proto_solver/highs_proto_solver.cc index d015b5df31e..9e8d79c8c17 100644 --- a/ortools/linear_solver/proto_solver/highs_proto_solver.cc +++ b/ortools/linear_solver/proto_solver/highs_proto_solver.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -45,7 +46,7 @@ absl::Status SetSolverSpecificParameters(const std::string& parameters, absl::StatusOr HighsSolveProto( LazyMutableCopy request) { MPSolutionResponse response; - const absl::optional> optional_model = + const std::optional> optional_model = GetMPModelOrPopulateResponse(request, &response); if (!optional_model) return response; const MPModelProto& model = **optional_model; diff --git a/ortools/linear_solver/proto_solver/scip_proto_solver.cc b/ortools/linear_solver/proto_solver/scip_proto_solver.cc 
index f07b419a58f..5665adee003 100644 --- a/ortools/linear_solver/proto_solver/scip_proto_solver.cc +++ b/ortools/linear_solver/proto_solver/scip_proto_solver.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -692,7 +693,7 @@ std::string FindErrorInMPModelForScip(const MPModelProto& model, SCIP* scip) { absl::StatusOr ScipSolveProto( LazyMutableCopy request) { MPSolutionResponse response; - const absl::optional> optional_model = + const std::optional> optional_model = GetMPModelOrPopulateResponse(request, &response); if (!optional_model) return response; const MPModelProto& model = **optional_model; diff --git a/ortools/lp_data/mps_reader_template.h b/ortools/lp_data/mps_reader_template.h index b8e8e618993..439487d608b 100644 --- a/ortools/lp_data/mps_reader_template.h +++ b/ortools/lp_data/mps_reader_template.h @@ -501,7 +501,7 @@ class MPSReaderTemplate { // Parses a file in MPS format; if successful, returns the type of MPS // format detected (one of `kFree` or `kFixed`). If `form` is either `kFixed` // or `kFree`, the function will either return `kFixed` (or `kFree` - // respectivelly) if the input data satisfies the format, or an + // respectively) if the input data satisfies the format, or an // `absl::InvalidArgumentError` otherwise. absl::StatusOr ParseFile( absl::string_view file_name, DataWrapper* data, @@ -510,7 +510,7 @@ class MPSReaderTemplate { // Parses a string in MPS format; if successful, returns the type of MPS // format detected (one of `kFree` or `kFixed`). If `form` is either `kFixed` // or `kFree`, the function will either return `kFixed` (or `kFree` - // respectivelly) if the input data satisfies the format, or an + // respectively) if the input data satisfies the format, or an // `absl::InvalidArgumentError` otherwise. 
absl::StatusOr ParseString( absl::string_view source, DataWrapper* data, @@ -720,6 +720,7 @@ absl::Status MPSReaderTemplate::ProcessLine(absl::string_view line, } else { return line_info.InvalidArgumentError("Unknown section."); } + if (section_ == internal::MPSSectionId::kName) { // NOTE(user): The name may differ between fixed and free forms. In // fixed form, the name has at most 8 characters, and starts at a specific @@ -746,6 +747,21 @@ absl::Status MPSReaderTemplate::ProcessLine(absl::string_view line, data->SetName(fixed_name); } } + + // Supports the case where the direction is on the same line as the + // OBJSENSE keyword. + if (section_ == internal::MPSSectionId::kObjsense && + line_info.GetFieldsSize() == 2 && free_form_) { + if (absl::StrContains(line_info.GetField(1), "MIN")) { + data->SetObjectiveDirection(/*maximize=*/false); + } else if (absl::StrContains(line_info.GetField(1), "MAX")) { + data->SetObjectiveDirection(/*maximize=*/true); + } else { + return line_info.InvalidArgumentError( + "Invalid inline objective direction."); + } + } + return absl::OkStatus(); } switch (section_) { diff --git a/ortools/lp_data/sparse.cc b/ortools/lp_data/sparse.cc index 41ff043ee60..1928aa7d55a 100644 --- a/ortools/lp_data/sparse.cc +++ b/ortools/lp_data/sparse.cc @@ -463,15 +463,16 @@ void CompactSparseMatrix::PopulateFromMatrixView(const MatrixView& input) { void CompactSparseMatrix::PopulateFromSparseMatrixAndAddSlacks( const SparseMatrix& input) { - num_cols_ = input.num_cols() + RowToColIndex(input.num_rows()); + const int input_num_cols = input.num_cols().value(); + num_cols_ = input_num_cols + RowToColIndex(input.num_rows()); num_rows_ = input.num_rows(); const EntryIndex num_entries = input.num_entries() + EntryIndex(num_rows_.value()); starts_.assign(num_cols_ + 1, EntryIndex(0)); - coefficients_.assign(num_entries, 0.0); - rows_.assign(num_entries, RowIndex(0)); + coefficients_.resize(num_entries, 0.0); + rows_.resize(num_entries, RowIndex(0)); 
EntryIndex index(0); - for (ColIndex col(0); col < input.num_cols(); ++col) { + for (ColIndex col(0); col < input_num_cols; ++col) { starts_[col] = index; for (const SparseColumn::Entry e : input.column(col)) { coefficients_[index] = e.coefficient(); @@ -480,11 +481,12 @@ void CompactSparseMatrix::PopulateFromSparseMatrixAndAddSlacks( } } for (RowIndex row(0); row < num_rows_; ++row) { - starts_[input.num_cols() + RowToColIndex(row)] = index; + starts_[input_num_cols + RowToColIndex(row)] = index; coefficients_[index] = 1.0; rows_[index] = row; ++index; } + DCHECK_EQ(index, num_entries); starts_[num_cols_] = index; } @@ -496,11 +498,12 @@ void CompactSparseMatrix::PopulateFromTranspose( // Fill the starts_ vector by computing the number of entries of each rows and // then doing a cumulative sum. After this step starts_[col + 1] will be the // actual start of the column col when we are done. - starts_.assign(num_cols_ + 2, EntryIndex(0)); + const ColIndex start_size = num_cols_ + 2; + starts_.assign(start_size, EntryIndex(0)); for (const RowIndex row : input.rows_) { ++starts_[RowToColIndex(row) + 2]; } - for (ColIndex col(2); col < starts_.size(); ++col) { + for (ColIndex col(2); col < start_size; ++col) { starts_[col] += starts_[col - 1]; } coefficients_.resize(starts_.back(), 0.0); @@ -662,12 +665,13 @@ void TriangularMatrix::CloseCurrentColumn(Fractional diagonal_value) { // TODO(user): This is currently not used by all matrices. It will be good // to fill it only when needed. 
DCHECK_LT(num_cols_, pruned_ends_.size()); - pruned_ends_[num_cols_] = coefficients_.size(); + const EntryIndex num_entries = coefficients_.size(); + pruned_ends_[num_cols_] = num_entries; ++num_cols_; DCHECK_LT(num_cols_, starts_.size()); - starts_[num_cols_] = coefficients_.size(); - if (first_non_identity_column_ == num_cols_ - 1 && coefficients_.empty() && - diagonal_value == 1.0) { + starts_[num_cols_] = num_entries; + if (first_non_identity_column_ == num_cols_ - 1 && diagonal_value == 1.0 && + num_entries == 0) { first_non_identity_column_ = num_cols_; } all_diagonal_coefficients_are_one_ = diff --git a/ortools/pdlp/BUILD.bazel b/ortools/pdlp/BUILD.bazel index 5b68856f23b..d1e2b010988 100644 --- a/ortools/pdlp/BUILD.bazel +++ b/ortools/pdlp/BUILD.bazel @@ -19,8 +19,23 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "scheduler", + srcs = ["scheduler.cc"], hdrs = ["scheduler.h"], deps = [ + ":solvers_cc_proto", + "//ortools/base:threadpool", + "@com_google_absl//absl/functional:any_invocable", + "@eigen//:eigen3", + ], +) + +cc_test( + name = "scheduler_test", + srcs = ["scheduler_test.cc"], + deps = [ + ":gtest_main", + ":scheduler", + ":solvers_cc_proto", "@com_google_absl//absl/functional:any_invocable", ], ) @@ -258,9 +273,10 @@ cc_library( hdrs = ["sharded_quadratic_program.h"], deps = [ ":quadratic_program", + ":scheduler", ":sharder", + ":solvers_cc_proto", "//ortools/base", - "//ortools/base:threadpool", "//ortools/util:logging", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -286,9 +302,9 @@ cc_library( srcs = ["sharder.cc"], hdrs = ["sharder.h"], deps = [ + ":scheduler", "//ortools/base", "//ortools/base:mathutil", - "//ortools/base:threadpool", "//ortools/base:timer", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", @@ -302,10 +318,10 @@ cc_test( srcs = ["sharder_test.cc"], deps = [ ":gtest_main", + ":scheduler", ":sharder", "//ortools/base", "//ortools/base:mathutil", - 
"//ortools/base:threadpool", "@com_google_absl//absl/random:distributions", "@eigen//:eigen3", ], diff --git a/ortools/pdlp/primal_dual_hybrid_gradient.cc b/ortools/pdlp/primal_dual_hybrid_gradient.cc index 7b2a8013aa4..28daaf173b8 100644 --- a/ortools/pdlp/primal_dual_hybrid_gradient.cc +++ b/ortools/pdlp/primal_dual_hybrid_gradient.cc @@ -718,7 +718,8 @@ PreprocessSolver::PreprocessSolver(QuadraticProgram qp, : num_threads_( NumThreads(params.num_threads(), params.num_shards(), qp, *logger)), num_shards_(NumShards(num_threads_, params.num_shards())), - sharded_qp_(std::move(qp), num_threads_, num_shards_), + sharded_qp_(std::move(qp), num_threads_, num_shards_, + params.scheduler_type(), nullptr), logger_(*logger) {} SolverResult ErrorSolverResult(const TerminationReason reason, diff --git a/ortools/pdlp/samples/code_samples.bzl b/ortools/pdlp/samples/code_samples.bzl index 12d7480e283..9dc8cab8e82 100644 --- a/ortools/pdlp/samples/code_samples.bzl +++ b/ortools/pdlp/samples/code_samples.bzl @@ -13,8 +13,10 @@ """Helper macro to compile and test code samples.""" +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_test") + def code_sample_cc(name): - native.cc_binary( + cc_binary( name = name + "_cc", srcs = [name + ".cc"], deps = [ @@ -28,7 +30,7 @@ def code_sample_cc(name): ], ) - native.cc_test( + cc_test( name = name + "_cc_test", size = "small", srcs = [name + ".cc"], diff --git a/ortools/pdlp/scheduler.cc b/ortools/pdlp/scheduler.cc new file mode 100644 index 00000000000..c24dad4a0c0 --- /dev/null +++ b/ortools/pdlp/scheduler.cc @@ -0,0 +1,34 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/pdlp/scheduler.h" + +#include + +#include "ortools/pdlp/solvers.pb.h" + +namespace operations_research::pdlp { + +// Convenience factory function. +std::unique_ptr MakeScheduler(SchedulerType type, int num_threads) { + switch (type) { + case SchedulerType::SCHEDULER_TYPE_GOOGLE_THREADPOOL: + return std::make_unique(num_threads); + case SchedulerType::SCHEDULER_TYPE_EIGEN_THREADPOOL: + return std::make_unique(num_threads); + default: + return nullptr; + } +} + +} // namespace operations_research::pdlp diff --git a/ortools/pdlp/scheduler.h b/ortools/pdlp/scheduler.h index 8df1c44bdce..ebc791bbe1d 100644 --- a/ortools/pdlp/scheduler.h +++ b/ortools/pdlp/scheduler.h @@ -14,9 +14,19 @@ #ifndef PDLP_SCHEDULER_H_ #define PDLP_SCHEDULER_H_ +// Eigen defaults to using TensorFlow's scheduler, unless we add this line. +#ifndef EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_CUSTOM_THREAD_POOL +#endif + +#include #include #include "absl/functional/any_invocable.h" +#include "absl/synchronization/blocking_counter.h" +#include "ortools/base/threadpool.h" +#include "ortools/pdlp/solvers.pb.h" +#include "unsupported/Eigen/CXX11/ThreadPool" namespace operations_research::pdlp { @@ -32,6 +42,63 @@ class Scheduler { absl::AnyInvocable do_func) = 0; }; +// Google3 ThreadPool scheduler with barrier synchronization. 
+class GoogleThreadPoolScheduler : public Scheduler { + public: + GoogleThreadPoolScheduler(int num_threads) + : num_threads_(num_threads), + threadpool_(std::make_unique("pdlp", num_threads)) { + threadpool_->StartWorkers(); + } + int num_threads() const override { return num_threads_; }; + std::string info_string() const override { return "google_threadpool"; }; + + void ParallelFor(int start, int end, + absl::AnyInvocable do_func) override { + absl::BlockingCounter counter(end - start); + for (int i = start; i < end; ++i) { + threadpool_->Schedule([&, i]() { + do_func(i); + counter.DecrementCount(); + }); + } + counter.Wait(); + } + + private: + const int num_threads_; + std::unique_ptr threadpool_ = nullptr; +}; + +// Eigen ThreadPool scheduler with barrier synchronization. +class EigenThreadPoolScheduler : public Scheduler { + public: + EigenThreadPoolScheduler(int num_threads) + : num_threads_(num_threads), + eigen_threadpool_(std::make_unique(num_threads)) {} + int num_threads() const override { return num_threads_; }; + std::string info_string() const override { return "eigen_threadpool"; }; + + void ParallelFor(int start, int end, + absl::AnyInvocable do_func) override { + Eigen::Barrier eigen_barrier(end - start); + for (int i = start; i < end; ++i) { + eigen_threadpool_->Schedule([&, i]() { + do_func(i); + eigen_barrier.Notify(); + }); + } + eigen_barrier.Wait(); + } + + private: + const int num_threads_; + std::unique_ptr eigen_threadpool_ = nullptr; +}; + +// Makes a scheduler of a given type. 
+std::unique_ptr MakeScheduler(SchedulerType type, int num_threads); + } // namespace operations_research::pdlp #endif // PDLP_SCHEDULER_H_ diff --git a/ortools/pdlp/scheduler_test.cc b/ortools/pdlp/scheduler_test.cc new file mode 100644 index 00000000000..de33007fda4 --- /dev/null +++ b/ortools/pdlp/scheduler_test.cc @@ -0,0 +1,103 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/pdlp/scheduler.h" + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/pdlp/solvers.pb.h" + +namespace operations_research::pdlp { + +namespace { + +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::NotNull; +using ::testing::TestWithParam; + +struct SchedulerTestCase { + std::string test_name; + SchedulerType type; + int num_threads; +}; + +using SchedulerTest = TestWithParam; + +TEST(SchedulerTest, CheckUnspecifiedSchedulerReturnsNullptr) { + std::unique_ptr scheduler = + MakeScheduler(SCHEDULER_TYPE_UNSPECIFIED, 1); + EXPECT_THAT(scheduler, IsNull()); +} + +TEST_P(SchedulerTest, CheckThreadCount) { + const SchedulerTestCase& test_case = GetParam(); + std::unique_ptr scheduler = + MakeScheduler(test_case.type, test_case.num_threads); + ASSERT_THAT(scheduler, NotNull()); + EXPECT_THAT(scheduler->num_threads(), Eq(test_case.num_threads)); +} + +TEST_P(SchedulerTest, CheckInfoString) { + const SchedulerTestCase& test_case = GetParam(); 
+ std::unique_ptr scheduler = + MakeScheduler(test_case.type, test_case.num_threads); + ASSERT_THAT(scheduler, NotNull()); + if (test_case.type == SchedulerType::SCHEDULER_TYPE_GOOGLE_THREADPOOL) { + EXPECT_THAT(scheduler->info_string(), Eq("google_threadpool")); + } else if (test_case.type == SchedulerType::SCHEDULER_TYPE_EIGEN_THREADPOOL) { + EXPECT_THAT(scheduler->info_string(), Eq("eigen_threadpool")); + } else { + FAIL() << "Invalid test_case type: " << test_case.type; + } +} + +TEST_P(SchedulerTest, CheckParallelVectorSum) { + const SchedulerTestCase& test_case = GetParam(); + const int num_shards = 100000; // High enough to catch race conditions. + std::unique_ptr scheduler = + MakeScheduler(test_case.type, test_case.num_threads); + ASSERT_THAT(scheduler, NotNull()); + const std::vector data(num_shards, 1.0); + std::atomic sum = 0.0; + // Adds `x` to `sum` using a CAS loop. + std::function do_fn = [&](int i) { + for (double new_sum = sum; + !sum.compare_exchange_weak(new_sum, new_sum + data[i]);) { + }; + }; + scheduler->ParallelFor(0, num_shards, do_fn); + EXPECT_THAT(sum, Eq(num_shards)); +} + +INSTANTIATE_TEST_SUITE_P( + SchedulerTests, SchedulerTest, + testing::ValuesIn({ + {"GoogleThreadPool2", SCHEDULER_TYPE_GOOGLE_THREADPOOL, 2}, + {"GoogleThreadPool4", SCHEDULER_TYPE_GOOGLE_THREADPOOL, 4}, + {"EigenThreadPool2", SCHEDULER_TYPE_EIGEN_THREADPOOL, 2}, + {"EigenThreadPool4", SCHEDULER_TYPE_EIGEN_THREADPOOL, 4}, + }), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace + +} // namespace operations_research::pdlp diff --git a/ortools/pdlp/sharded_quadratic_program.cc b/ortools/pdlp/sharded_quadratic_program.cc index d4f72b31060..d00198f735e 100644 --- a/ortools/pdlp/sharded_quadratic_program.cc +++ b/ortools/pdlp/sharded_quadratic_program.cc @@ -14,6 +14,7 @@ #include "ortools/pdlp/sharded_quadratic_program.h" #include +#include #include #include #include @@ -23,9 +24,10 @@ #include "absl/log/check.h" #include 
"absl/strings/string_view.h" #include "ortools/base/logging.h" -#include "ortools/base/threadpool.h" #include "ortools/pdlp/quadratic_program.h" +#include "ortools/pdlp/scheduler.h" #include "ortools/pdlp/sharder.h" +#include "ortools/pdlp/solvers.pb.h" #include "ortools/util/logging.h" namespace operations_research::pdlp { @@ -76,24 +78,22 @@ void WarnIfMatrixUnbalanced( ShardedQuadraticProgram::ShardedQuadraticProgram( QuadraticProgram qp, const int num_threads, const int num_shards, - operations_research::SolverLogger* logger) + SchedulerType scheduler_type, operations_research::SolverLogger* logger) : qp_(std::move(qp)), transposed_constraint_matrix_(qp_.constraint_matrix.transpose()), - thread_pool_(num_threads == 1 - ? nullptr - : std::make_unique("PDLP", num_threads)), + scheduler_(num_threads == 1 ? nullptr + : MakeScheduler(scheduler_type, num_threads)), constraint_matrix_sharder_(qp_.constraint_matrix, num_shards, - thread_pool_.get()), + scheduler_.get()), transposed_constraint_matrix_sharder_(transposed_constraint_matrix_, - num_shards, thread_pool_.get()), + num_shards, scheduler_.get()), primal_sharder_(qp_.variable_lower_bounds.size(), num_shards, - thread_pool_.get()), + scheduler_.get()), dual_sharder_(qp_.constraint_lower_bounds.size(), num_shards, - thread_pool_.get()) { + scheduler_.get()) { CHECK_GE(num_threads, 1); CHECK_GE(num_shards, num_threads); if (num_threads > 1) { - thread_pool_->StartWorkers(); const int64_t work_per_iteration = qp_.constraint_matrix.nonZeros() + qp_.variable_lower_bounds.size() + qp_.constraint_lower_bounds.size(); diff --git a/ortools/pdlp/sharded_quadratic_program.h b/ortools/pdlp/sharded_quadratic_program.h index 9a80ec5f3a6..fa5c34a91e6 100644 --- a/ortools/pdlp/sharded_quadratic_program.h +++ b/ortools/pdlp/sharded_quadratic_program.h @@ -17,13 +17,13 @@ #include #include #include -#include #include "Eigen/Core" #include "Eigen/SparseCore" -#include "ortools/base/threadpool.h" #include 
"ortools/pdlp/quadratic_program.h" +#include "ortools/pdlp/scheduler.h" #include "ortools/pdlp/sharder.h" +#include "ortools/pdlp/solvers.pb.h" #include "ortools/util/logging.h" namespace operations_research::pdlp { @@ -31,7 +31,7 @@ namespace operations_research::pdlp { // This class stores: // - A `QuadraticProgram` (QP) // - A transposed version of the QP's constraint matrix -// - A thread pool +// - A thread scheduler // - Various `Sharder` objects for doing sharded matrix and vector // computations. class ShardedQuadraticProgram { @@ -40,8 +40,10 @@ class ShardedQuadraticProgram { // Note that the `qp` is intentionally passed by value. // If `logger` is not nullptr, warns about unbalanced matrices using it; // otherwise warns via Google standard logging. - ShardedQuadraticProgram(QuadraticProgram qp, int num_threads, int num_shards, - operations_research::SolverLogger* logger = nullptr); + ShardedQuadraticProgram( + QuadraticProgram qp, int num_threads, int num_shards, + SchedulerType scheduler_type = SCHEDULER_TYPE_GOOGLE_THREADPOOL, + operations_research::SolverLogger* logger = nullptr); // Movable but not copyable. 
ShardedQuadraticProgram(const ShardedQuadraticProgram&) = delete; @@ -114,7 +116,7 @@ class ShardedQuadraticProgram { QuadraticProgram qp_; Eigen::SparseMatrix transposed_constraint_matrix_; - std::unique_ptr thread_pool_; + std::unique_ptr scheduler_; Sharder constraint_matrix_sharder_; Sharder transposed_constraint_matrix_sharder_; Sharder primal_sharder_; diff --git a/ortools/pdlp/sharder.cc b/ortools/pdlp/sharder.cc index 308b7e0bf52..2ff72ce80a9 100644 --- a/ortools/pdlp/sharder.cc +++ b/ortools/pdlp/sharder.cc @@ -26,17 +26,17 @@ #include "absl/time/time.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" -#include "ortools/base/threadpool.h" #include "ortools/base/timer.h" +#include "ortools/pdlp/scheduler.h" namespace operations_research::pdlp { using ::Eigen::VectorXd; Sharder::Sharder(const int64_t num_elements, const int num_shards, - ThreadPool* const thread_pool, + Scheduler* const scheduler, const std::function& element_mass) - : thread_pool_(thread_pool) { + : scheduler_(scheduler) { CHECK_GE(num_elements, 0); if (num_elements == 0) { shard_starts_.push_back(0); @@ -70,8 +70,8 @@ Sharder::Sharder(const int64_t num_elements, const int num_shards, } Sharder::Sharder(const int64_t num_elements, const int num_shards, - ThreadPool* const thread_pool) - : thread_pool_(thread_pool) { + Scheduler* const scheduler) + : scheduler_(scheduler) { CHECK_GE(num_elements, 0); if (num_elements == 0) { shard_starts_.push_back(0); @@ -104,34 +104,30 @@ Sharder::Sharder(const Sharder& other_sharder, const int64_t num_elements) // The `std::max()` protects against `other_sharder.NumShards() == 0`, which // will happen if `other_sharder` had `num_elements == 0`. 
: Sharder(num_elements, std::max(1, other_sharder.NumShards()), - other_sharder.thread_pool_) {} + other_sharder.scheduler_) {} void Sharder::ParallelForEachShard( const std::function& func) const { - if (thread_pool_) { + if (scheduler_) { absl::BlockingCounter counter(NumShards()); VLOG(2) << "Starting ParallelForEachShard()"; - for (int shard_num = 0; shard_num < NumShards(); ++shard_num) { - thread_pool_->Schedule([&, shard_num]() { - WallTimer timer; - if (VLOG_IS_ON(2)) { - timer.Start(); - } - func(Shard(shard_num, this)); - if (VLOG_IS_ON(2)) { - timer.Stop(); - VLOG(2) << "Shard " << shard_num << " with " << ShardSize(shard_num) - << " elements and " << ShardMass(shard_num) - << " mass finished with " - << ShardMass(shard_num) / - std::max(int64_t{1}, absl::ToInt64Microseconds( - timer.GetDuration())) - << " mass/usec."; - } - counter.DecrementCount(); - }); - } - counter.Wait(); + scheduler_->ParallelFor(0, NumShards(), [&](int shard_num) { + WallTimer timer; + if (VLOG_IS_ON(2)) { + timer.Start(); + } + func(Shard(shard_num, this)); + if (VLOG_IS_ON(2)) { + timer.Stop(); + VLOG(2) << "Shard " << shard_num << " with " << ShardSize(shard_num) + << " elements and " << ShardMass(shard_num) + << " mass finished with " + << ShardMass(shard_num) / + std::max(int64_t{1}, + absl::ToInt64Microseconds(timer.GetDuration())) + << " mass/usec."; + } + }); VLOG(2) << "Done ParallelForEachShard()"; } else { for (int shard_num = 0; shard_num < NumShards(); ++shard_num) { diff --git a/ortools/pdlp/sharder.h b/ortools/pdlp/sharder.h index 2187be37950..877d5436205 100644 --- a/ortools/pdlp/sharder.h +++ b/ortools/pdlp/sharder.h @@ -22,7 +22,7 @@ #include "Eigen/Core" #include "Eigen/SparseCore" #include "absl/log/check.h" -#include "ortools/base/threadpool.h" +#include "ortools/pdlp/scheduler.h" namespace operations_research::pdlp { @@ -141,26 +141,26 @@ class Sharder { // Creates a `Sharder` for problems with `num_elements` elements and mass of // each element given by 
`element_mass`. Each shard will have roughly the same // mass. The number of shards in the resulting `Sharder` will be approximately - // `num_shards` but may differ. The `thread_pool` will be used for parallel - // operations executed by e.g. `ParallelForEachShard()`. The `thread_pool` may + // `num_shards` but may differ. The `scheduler` will be used for parallel + // operations executed by e.g. `ParallelForEachShard()`. The `scheduler` may // be nullptr, which means work will be executed in the same thread. If - // `thread_pool` is not nullptr, the underlying object is not owned and must + // `scheduler` is not nullptr, the underlying object is not owned and must // outlive the `Sharder`. - Sharder(int64_t num_elements, int num_shards, ThreadPool* thread_pool, + Sharder(int64_t num_elements, int num_shards, Scheduler* scheduler, const std::function& element_mass); // Creates a `Sharder` for problems with `num_elements` elements and unit // mass. This constructor exploits having all element mass equal to 1 to take // time proportional to `num_shards` instead of `num_elements`. Also see the // comments above the first constructor. - Sharder(int64_t num_elements, int num_shards, ThreadPool* thread_pool); + Sharder(int64_t num_elements, int num_shards, Scheduler* scheduler); // Creates a `Sharder` for processing `matrix`. The elements correspond to // columns of `matrix` and have mass linear in the number of non-zeros. Also // see the comments above the first constructor. Sharder(const Eigen::SparseMatrix& matrix, - int num_shards, ThreadPool* thread_pool) - : Sharder(matrix.cols(), num_shards, thread_pool, [&matrix](int64_t col) { + int num_shards, Scheduler* scheduler) + : Sharder(matrix.cols(), num_shards, scheduler, [&matrix](int64_t col) { return 1 + 1 * matrix.col(col).nonZeros(); }) {} @@ -227,7 +227,7 @@ class Sharder { // Size: `NumShards()`. The mass of each shard. std::vector shard_masses_; // NOT owned. May be nullptr. 
- ThreadPool* thread_pool_; + Scheduler* scheduler_; }; // Like `matrix.transpose() * vector` but executed in parallel using `sharder`. diff --git a/ortools/pdlp/sharder_test.cc b/ortools/pdlp/sharder_test.cc index fb37b34881c..e77274e2781 100644 --- a/ortools/pdlp/sharder_test.cc +++ b/ortools/pdlp/sharder_test.cc @@ -27,7 +27,7 @@ #include "ortools/base/gmock.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" -#include "ortools/base/threadpool.h" +#include "ortools/pdlp/scheduler.h" namespace operations_research::pdlp { namespace { @@ -434,9 +434,8 @@ TEST_P(VariousSizesTest, LargeMatVec) { LargeSparseMatrix(size); const int num_threads = 5; const int shards_per_thread = 3; - ThreadPool pool("MatrixVectorProductTest", num_threads); - pool.StartWorkers(); - Sharder sharder(mat, shards_per_thread * num_threads, &pool); + GoogleThreadPoolScheduler scheduler(num_threads); + Sharder sharder(mat, shards_per_thread * num_threads, &scheduler); VectorXd rhs = VectorXd::Random(size); VectorXd direct = mat.transpose() * rhs; VectorXd threaded = TransposedMatrixVectorProduct(mat, rhs, sharder); @@ -446,9 +445,8 @@ TEST_P(VariousSizesTest, LargeMatVec) { TEST_P(VariousSizesTest, LargeVectors) { const int64_t size = GetParam(); const int num_threads = 5; - ThreadPool pool("SquaredNormTest", num_threads); - pool.StartWorkers(); - Sharder sharder(size, num_threads, &pool); + GoogleThreadPoolScheduler scheduler(num_threads); + Sharder sharder(size, num_threads, &scheduler); VectorXd vec = VectorXd::Random(size); const double direct = vec.squaredNorm(); const double threaded = SquaredNorm(vec, sharder); diff --git a/ortools/pdlp/solvers.proto b/ortools/pdlp/solvers.proto index 215dfd594db..fefbdfe4091 100644 --- a/ortools/pdlp/solvers.proto +++ b/ortools/pdlp/solvers.proto @@ -40,6 +40,16 @@ enum OptimalityNorm { OPTIMALITY_NORM_L_INF_COMPONENTWISE = 3; } +// The type of system used to schedule CPU threads to do work in parallel. 
+enum SchedulerType { + SCHEDULER_TYPE_UNSPECIFIED = 0; + // Google ThreadPool with barrier synchronization. + SCHEDULER_TYPE_GOOGLE_THREADPOOL = 1; + // Eigen non-blocking ThreadPool with barrier synchronization (see + // ). + SCHEDULER_TYPE_EIGEN_THREADPOOL = 3; +} + // A description of solver termination criteria. The criteria are defined in // terms of the quantities recorded in IterationStats in solve_log.proto. @@ -285,6 +295,11 @@ message PrimalDualHybridGradientParams { // Otherwise a default that depends on num_threads will be used. optional int32 num_shards = 27 [default = 0]; + // The type of scheduler used for CPU multi-threading. See the documentation + // of the corresponding enum for more details. + optional SchedulerType scheduler_type = 32 + [default = SCHEDULER_TYPE_GOOGLE_THREADPOOL]; + // If true, the iteration_stats field of the SolveLog output will be populated // at every iteration. Note that we only compute solution statistics at // termination checks. Setting this parameter to true may substantially diff --git a/ortools/port/proto_utils.h b/ortools/port/proto_utils.h index 535d6343126..ba0464b8ca6 100644 --- a/ortools/port/proto_utils.h +++ b/ortools/port/proto_utils.h @@ -58,7 +58,7 @@ std::string ProtoEnumToString(ProtoEnumType enum_value) { "Invalid enum value of: ", enum_value, " for enum type: ", google::protobuf::GetEnumDescriptor()->name()); } - return enum_value_descriptor->name(); + return std::string(enum_value_descriptor->name()); #endif // !defined(__PORTABLE_PLATFORM__) } diff --git a/ortools/python/setup.py.in b/ortools/python/setup.py.in index e0627a3aea5..e70211f9335 100644 --- a/ortools/python/setup.py.in +++ b/ortools/python/setup.py.in @@ -137,6 +137,7 @@ setup( 'Operating System :: Unix', 'Operating System :: POSIX :: Linux', 'Operating System :: POSIX :: BSD :: FreeBSD', + 'Operating System :: POSIX :: BSD :: NetBSD', 'Operating System :: POSIX :: BSD :: OpenBSD', 'Operating System :: MacOS', 'Operating System :: MacOS 
:: MacOS X', diff --git a/ortools/sat/2d_orthogonal_packing_testing.cc b/ortools/sat/2d_orthogonal_packing_testing.cc index 597e718db36..4fc9789b1c6 100644 --- a/ortools/sat/2d_orthogonal_packing_testing.cc +++ b/ortools/sat/2d_orthogonal_packing_testing.cc @@ -158,10 +158,12 @@ std::vector MakeItemsFromRectangles( ranges.reserve(rectangles.size()); const int max_slack_x = slack_factor * size_max_x.value(); const int max_slack_y = slack_factor * size_max_y.value(); + int count = 0; for (const Rectangle& rec : rectangles) { RectangleInRange range; range.x_size = rec.x_max - rec.x_min; range.y_size = rec.y_max - rec.y_min; + range.box_index = count++; range.bounding_area = { .x_min = rec.x_min - IntegerValue(absl::Uniform(random, 0, max_slack_x)), diff --git a/ortools/sat/2d_rectangle_presolve.cc b/ortools/sat/2d_rectangle_presolve.cc index e454c86cd9e..8db1a1e2f84 100644 --- a/ortools/sat/2d_rectangle_presolve.cc +++ b/ortools/sat/2d_rectangle_presolve.cc @@ -14,6 +14,7 @@ #include "ortools/sat/2d_rectangle_presolve.h" #include +#include #include #include #include @@ -22,6 +23,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -29,6 +31,7 @@ #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/base/stl_util.h" +#include "ortools/graph/max_flow.h" #include "ortools/graph/strongly_connected_components.h" #include "ortools/sat/diffn_util.h" #include "ortools/sat/integer.h" @@ -147,7 +150,7 @@ bool PresolveFixed2dRectangles( if (!new_box.IsDisjoint(existing_box)) { is_disjoint = false; for (const Rectangle& disjoint_box : - new_box.SetDifference(existing_box)) { + new_box.RegionDifference(existing_box)) { to_add.push_back(disjoint_box); } break; @@ -207,7 +210,30 @@ bool PresolveFixed2dRectangles( optional_boxes.erase(optional_boxes.begin(), optional_boxes.begin() + 
num_optional_boxes_to_remove); - if (ReduceNumberofBoxes(fixed_boxes, &optional_boxes)) { + // TODO(user): instead of doing the greedy algorithm first with optional + // boxes, and then the one that is exact for mandatory boxes but weak for + // optional ones, refactor the second algorithm. One possible way of doing + // that would be to follow the shape boundary of optional+mandatory boxes and + // look whether we can shave off some turns. For example, if we have a shape + // like below, with the "+" representing area covered by optional boxes, we + // can replace the turns by a straight line. + // + // --> + // ^ ++++ + // . ++++ . + // . ++++ . => + // ++++ \/ + // --> ++++ --> --> + // *********** *********** + // *********** *********** + // + // Since less turns means less edges, this should be a good way to reduce the + // number of boxes. + if (ReduceNumberofBoxesGreedy(fixed_boxes, &optional_boxes)) { + changed = true; + } + const int num_after_first_pass = fixed_boxes->size(); + if (ReduceNumberOfBoxesExactMandatory(fixed_boxes, &optional_boxes)) { changed = true; } if (changed && VLOG_IS_ON(1)) { @@ -217,8 +243,8 @@ bool PresolveFixed2dRectangles( } VLOG_EVERY_N_SEC(1, 1) << "Presolved " << original_num_boxes << " fixed rectangles (area=" << original_area - << ") into " << fixed_boxes->size() - << " (area=" << area << ")"; + << ") into " << num_after_first_pass << " then " + << fixed_boxes->size() << " (area=" << area << ")"; VLOG_EVERY_N_SEC(2, 2) << "Presolved rectangles:\n" << RenderDot(bounding_box, fixed_boxes_copy) @@ -281,18 +307,10 @@ struct Edge { }; } // namespace -bool ReduceNumberofBoxes(std::vector* mandatory_rectangles, - std::vector* optional_rectangles) { +bool ReduceNumberofBoxesGreedy(std::vector* mandatory_rectangles, + std::vector* optional_rectangles) { // The current implementation just greedly merge rectangles that shares an - // edge. 
This is far from optimal, and it exists a polynomial optimal - // algorithm (see page 3 of [1]) for this problem at least for the case where - // optional_rectangles is empty. - // - // TODO(user): improve - // - // [1] Eppstein, David. "Graph-theoretic solutions to computational geometry - // problems." International Workshop on Graph-Theoretic Concepts in Computer - // Science. Berlin, Heidelberg: Springer Berlin Heidelberg, 2009. + // edge. std::vector> rectangle_storage; enum class OptionalEnum { OPTIONAL, MANDATORY }; // bool for is_optional @@ -534,357 +552,889 @@ std::vector> SplitInConnectedComponents( return components; } -struct ContourPoint { - IntegerValue x; - IntegerValue y; - int next_box_index; - EdgePosition next_direction; - - bool operator!=(const ContourPoint& other) const { - return x != other.x || y != other.y || - next_box_index != other.next_box_index || - next_direction != other.next_direction; +namespace { +IntegerValue GetClockwiseStart(EdgePosition edge, const Rectangle& rectangle) { + switch (edge) { + case EdgePosition::LEFT: + return rectangle.y_min; + case EdgePosition::RIGHT: + return rectangle.y_max; + case EdgePosition::BOTTOM: + return rectangle.x_max; + case EdgePosition::TOP: + return rectangle.x_min; } -}; +} -// This function runs in O(log N). -ContourPoint NextByClockwiseOrder(const ContourPoint& point, - absl::Span rectangles, - const Neighbours& neighbours) { - // This algorithm is very verbose, but it is about handling four cases. In the - // schema below, "-->" is the current direction, "X" the next point and - // the dashed arrow the next direction. 
- // - // Case 1: - // ++++++++ - // ^ ++++++++ - // : ++++++++ - // : ++++++++ - // ++++++++ - // ---> X ++++++++ - // ****************** - // ****************** - // ****************** - // ****************** - // - // Case 2: - // ^ ++++++++ - // : ++++++++ - // : ++++++++ - // ++++++++ - // ---> X ++++++++ - // *************++++++++ - // *************++++++++ - // ************* - // ************* - // - // Case 3: - // ---> X ...> - // *************++++++++ - // *************++++++++ - // *************++++++++ - // *************++++++++ - // - // Case 4: - // ---> X - // ************* : - // ************* : - // ************* : - // ************* \/ - ContourPoint result; - const Rectangle& cur_rectangle = rectangles[point.next_box_index]; - - EdgePosition cur_edge; - bool clockwise; - // Much of the code below need to know two things: in which direction we are - // going and what edge of which rectangle we are touching. For example, in the - // "Case 4" drawing above we are going RIGHT and touching the TOP edge of the - // current rectangle. This switch statement finds this `cur_edge`. - switch (point.next_direction) { +IntegerValue GetClockwiseEnd(EdgePosition edge, const Rectangle& rectangle) { + switch (edge) { + case EdgePosition::LEFT: + return rectangle.y_max; + case EdgePosition::RIGHT: + return rectangle.y_min; + case EdgePosition::BOTTOM: + return rectangle.x_min; case EdgePosition::TOP: - if (cur_rectangle.x_max == point.x) { - cur_edge = EdgePosition::RIGHT; - clockwise = false; - } else { - cur_edge = EdgePosition::LEFT; - clockwise = true; + return rectangle.x_max; + } +} + +// Given a list of rectangles and their neighbours graph, find the list of +// vertical and horizontal segments that touches a single rectangle edge. Or, +// view in another way, the pieces of an edge that is touching the empty space. 
+// For example, this corresponds to the "0" segments in the example below: +// +// 000000 +// 0****0 000000 +// 0****0 0****0 +// 0****0 0****0 +// 00******00000****00000 +// 0********************0 +// 0********************0 +// 0000000000000000000000 +void GetAllSegmentsTouchingVoid( + absl::Span rectangles, const Neighbours& neighbours, + std::vector>& vertical_edges_on_boundary, + std::vector>& horizontal_edges_on_boundary) { + for (int i = 0; i < rectangles.size(); ++i) { + const Rectangle& rectangle = rectangles[i]; + for (int edge_int = 0; edge_int < 4; ++edge_int) { + const EdgePosition edge = static_cast(edge_int); + const auto box_neighbors = neighbours.GetSortedNeighbors(i, edge); + if (box_neighbors.empty()) { + if (edge == EdgePosition::LEFT || edge == EdgePosition::RIGHT) { + vertical_edges_on_boundary.push_back( + {Edge::GetEdge(rectangle, edge), i}); + } else { + horizontal_edges_on_boundary.push_back( + {Edge::GetEdge(rectangle, edge), i}); + } + continue; } - break; - case EdgePosition::BOTTOM: - if (cur_rectangle.x_min == point.x) { - cur_edge = EdgePosition::LEFT; - clockwise = false; - } else { - cur_edge = EdgePosition::RIGHT; - clockwise = true; + IntegerValue previous_pos = GetClockwiseStart(edge, rectangle); + for (int n = 0; n <= box_neighbors.size(); ++n) { + IntegerValue neighbor_start; + const Rectangle* neighbor; + if (n == box_neighbors.size()) { + // On the last iteration we consider instead of the next neighbor the + // end of the current box. 
+ neighbor_start = GetClockwiseEnd(edge, rectangle); + } else { + const int neighbor_idx = box_neighbors[n]; + neighbor = &rectangles[neighbor_idx]; + neighbor_start = GetClockwiseStart(edge, *neighbor); + } + switch (edge) { + case EdgePosition::LEFT: + if (neighbor_start > previous_pos) { + vertical_edges_on_boundary.push_back( + {Edge{.x_start = rectangle.x_min, + .y_start = previous_pos, + .size = neighbor_start - previous_pos}, + i}); + } + break; + case EdgePosition::RIGHT: + if (neighbor_start < previous_pos) { + vertical_edges_on_boundary.push_back( + {Edge{.x_start = rectangle.x_max, + .y_start = neighbor_start, + .size = previous_pos - neighbor_start}, + i}); + } + break; + case EdgePosition::BOTTOM: + if (neighbor_start < previous_pos) { + horizontal_edges_on_boundary.push_back( + {Edge{.x_start = neighbor_start, + .y_start = rectangle.y_min, + .size = previous_pos - neighbor_start}, + i}); + } + break; + case EdgePosition::TOP: + if (neighbor_start > previous_pos) { + horizontal_edges_on_boundary.push_back( + {Edge{.x_start = previous_pos, + .y_start = rectangle.y_max, + .size = neighbor_start - previous_pos}, + i}); + } + break; + } + if (n != box_neighbors.size()) { + previous_pos = GetClockwiseEnd(edge, *neighbor); + } } + } + } +} + +// Trace a boundary (interior or exterior) that contains the edge described by +// starting_edge_position and starting_step_point. This method removes the edges +// that were added to the boundary from `segments_to_follow`. +ShapePath TraceBoundary( + const EdgePosition& starting_edge_position, + std::pair starting_step_point, + std::array, + std::pair>, + 4>& segments_to_follow) { + // The boundary is composed of edges on the `segments_to_follow` map. So all + // we need is to find and glue them together on the right order. 
+ ShapePath path; + + auto extracted = + segments_to_follow[starting_edge_position].extract(starting_step_point); + CHECK(!extracted.empty()); + const int first_index = extracted.mapped().second; + + std::pair cur = starting_step_point; + int cur_index = first_index; + // Now we navigate from one edge to the next. To avoid going back, we remove + // used edges from the hash map. + while (true) { + path.step_points.push_back(cur); + + bool can_go[4] = {false, false, false, false}; + EdgePosition direction_to_take = EdgePosition::LEFT; + for (int edge_int = 0; edge_int < 4; ++edge_int) { + const EdgePosition edge = static_cast(edge_int); + if (segments_to_follow[edge].contains(cur)) { + can_go[edge] = true; + direction_to_take = edge; + } + } + + if (can_go == absl::Span{false, false, false, false}) { + // Cannot move anywhere, we closed the loop. break; - case EdgePosition::LEFT: - if (cur_rectangle.y_max == point.y) { - cur_edge = EdgePosition::TOP; - clockwise = false; + } + + // Handle one pathological case. + if (can_go[EdgePosition::LEFT] && can_go[EdgePosition::RIGHT]) { + // Corner case (literally): + // ******** + // ******** + // ******** + // ******** + // ^ +++++++++ + // | +++++++++ + // | +++++++++ + // +++++++++ + // + // In this case we keep following the same box. 
+ auto it_x = segments_to_follow[EdgePosition::LEFT].find(cur); + if (cur_index == it_x->second.second) { + direction_to_take = EdgePosition::LEFT; } else { - cur_edge = EdgePosition::BOTTOM; - clockwise = true; + direction_to_take = EdgePosition::RIGHT; } - break; - case EdgePosition::RIGHT: - if (cur_rectangle.y_min == point.y) { - cur_edge = EdgePosition::BOTTOM; - clockwise = false; + } else if (can_go[EdgePosition::TOP] && can_go[EdgePosition::BOTTOM]) { + auto it_y = segments_to_follow[EdgePosition::TOP].find(cur); + if (cur_index == it_y->second.second) { + direction_to_take = EdgePosition::TOP; } else { - cur_edge = EdgePosition::TOP; - clockwise = true; + direction_to_take = EdgePosition::BOTTOM; } - break; + } + + auto extracted = segments_to_follow[direction_to_take].extract(cur); + cur_index = extracted.mapped().second; + switch (direction_to_take) { + case EdgePosition::LEFT: + cur.first -= extracted.mapped().first; + segments_to_follow[EdgePosition::RIGHT].erase( + cur); // Forbid going back + break; + case EdgePosition::RIGHT: + cur.first += extracted.mapped().first; + segments_to_follow[EdgePosition::LEFT].erase(cur); // Forbid going back + break; + case EdgePosition::TOP: + cur.second += extracted.mapped().first; + segments_to_follow[EdgePosition::BOTTOM].erase( + cur); // Forbid going back + break; + case EdgePosition::BOTTOM: + cur.second -= extracted.mapped().first; + segments_to_follow[EdgePosition::TOP].erase(cur); // Forbid going back + break; + } + path.touching_box_index.push_back(cur_index); + } + path.touching_box_index.push_back(cur_index); + + return path; +} +} // namespace + +std::vector BoxesToShapes(absl::Span rectangles, + const Neighbours& neighbours) { + std::vector> vertical_edges_on_boundary; + std::vector> horizontal_edges_on_boundary; + GetAllSegmentsTouchingVoid(rectangles, neighbours, vertical_edges_on_boundary, + horizontal_edges_on_boundary); + + std::array, + std::pair>, + 4> + segments_to_follow; + + for (const auto& 
[edge, box_index] : vertical_edges_on_boundary) { + segments_to_follow[EdgePosition::TOP][{edge.x_start, edge.y_start}] = { + edge.size, box_index}; + segments_to_follow[EdgePosition::BOTTOM][{ + edge.x_start, edge.y_start + edge.size}] = {edge.size, box_index}; + } + for (const auto& [edge, box_index] : horizontal_edges_on_boundary) { + segments_to_follow[EdgePosition::RIGHT][{edge.x_start, edge.y_start}] = { + edge.size, box_index}; + segments_to_follow[EdgePosition::LEFT][{ + edge.x_start + edge.size, edge.y_start}] = {edge.size, box_index}; + } + + const auto components = SplitInConnectedComponents(neighbours); + std::vector result(components.size()); + std::vector box_to_component(rectangles.size()); + for (int i = 0; i < components.size(); ++i) { + for (const int box_index : components[i]) { + box_to_component[box_index] = i; + } } + while (!segments_to_follow[EdgePosition::LEFT].empty()) { + // Get edge most to the bottom left + const int box_index = + segments_to_follow[EdgePosition::RIGHT].begin()->second.second; + const std::pair starting_step_point = + segments_to_follow[EdgePosition::RIGHT].begin()->first; + const int component_index = box_to_component[box_index]; + + // The left-most vertical edge of the connected component must be of its + // exterior boundary. So we must always see the exterior boundary before + // seeing any holes. + const bool is_hole = !result[component_index].boundary.step_points.empty(); + ShapePath& path = is_hole ? result[component_index].holes.emplace_back() + : result[component_index].boundary; + path = TraceBoundary(EdgePosition::RIGHT, starting_step_point, + segments_to_follow); + if (is_hole) { + // Follow the usual convention that holes are in the inverse orientation + // of the external boundary. 
+ absl::c_reverse(path.step_points); + absl::c_reverse(path.touching_box_index); + } + } + return result; +} + +namespace { +struct PolygonCut { + std::pair start; + std::pair end; + int start_index; + int end_index; + + struct CmpByStartY { + bool operator()(const PolygonCut& a, const PolygonCut& b) const { + return std::tie(a.start.second, a.start.first) < + std::tie(b.start.second, b.start.first); + } + }; + + struct CmpByEndY { + bool operator()(const PolygonCut& a, const PolygonCut& b) const { + return std::tie(a.end.second, a.end.first) < + std::tie(b.end.second, b.end.first); + } + }; + + struct CmpByStartX { + bool operator()(const PolygonCut& a, const PolygonCut& b) const { + return a.start < b.start; + } + }; + + struct CmpByEndX { + bool operator()(const PolygonCut& a, const PolygonCut& b) const { + return a.end < b.end; + } + }; + + template + friend void AbslStringify(Sink& sink, const PolygonCut& diagonal) { + absl::Format(&sink, "(%v,%v)-(%v,%v)", diagonal.start.first, + diagonal.start.second, diagonal.end.first, + diagonal.end.second); + } +}; + +// A different representation of a shape. The two vectors must have the same +// size. The first one contains the points of the shape and the second one +// contains the index of the next point in the shape. +// +// Note that we code in this file is only correct for shapes with points +// connected only by horizontal or vertical lines. +struct FlatShape { + std::vector> points; + std::vector next; +}; + +EdgePosition GetSegmentDirection( + const std::pair& curr_segment, + const std::pair& next_segment) { + if (curr_segment.first == next_segment.first) { + return next_segment.second > curr_segment.second ? EdgePosition::TOP + : EdgePosition::BOTTOM; + } else { + return next_segment.first > curr_segment.first ? 
EdgePosition::RIGHT + : EdgePosition::LEFT; + } +} + +// Given a polygon, this function returns all line segments that start on a +// concave vertex and follow horizontally or vertically until it reaches the +// border of the polygon. This function returns all such segments grouped on the +// direction the line takes after starting in the concave vertex. Some of those +// segments start and end on a convex vertex, so they will appear twice in the +// output. This function modifies the shape by splitting some of the path +// segments in two. This is needed to make sure that `PolygonCut.start_index` +// and `PolygonCut.end_index` always corresponds to points in the FlatShape, +// even if they are not edges. +std::array, 4> GetPotentialPolygonCuts( + FlatShape& shape) { + std::array, 4> cuts; + + // First, for each concave vertex we create a cut that starts at it and + // crosses the polygon until infinite (in practice, int_max/int_min). + for (int i = 0; i < shape.points.size(); i++) { + const auto& it = &shape.points[shape.next[i]]; + const auto& previous = &shape.points[i]; + const auto& next_segment = &shape.points[shape.next[shape.next[i]]]; + const EdgePosition previous_dir = GetSegmentDirection(*previous, *it); + const EdgePosition next_dir = GetSegmentDirection(*it, *next_segment); + + if ((previous_dir == EdgePosition::TOP && next_dir == EdgePosition::LEFT) || + (previous_dir == EdgePosition::RIGHT && + next_dir == EdgePosition::TOP)) { + cuts[EdgePosition::RIGHT].push_back( + {.start = *it, + .end = {std::numeric_limits::max(), it->second}, + .start_index = shape.next[i]}); + } + if ((previous_dir == EdgePosition::BOTTOM && + next_dir == EdgePosition::RIGHT) || + (previous_dir == EdgePosition::LEFT && + next_dir == EdgePosition::BOTTOM)) { + cuts[EdgePosition::LEFT].push_back( + {.start = {std::numeric_limits::min(), it->second}, + .end = *it, + .end_index = shape.next[i]}); + } + if ((previous_dir == EdgePosition::RIGHT && + next_dir == EdgePosition::TOP) 
|| + (previous_dir == EdgePosition::BOTTOM && + next_dir == EdgePosition::RIGHT)) { + cuts[EdgePosition::BOTTOM].push_back( + {.start = {it->first, std::numeric_limits::min()}, + .end = *it, + .end_index = shape.next[i]}); + } + if ((previous_dir == EdgePosition::TOP && next_dir == EdgePosition::LEFT) || + (previous_dir == EdgePosition::LEFT && + next_dir == EdgePosition::BOTTOM)) { + cuts[EdgePosition::TOP].push_back( + {.start = *it, + .end = {it->first, std::numeric_limits::max()}, + .start_index = shape.next[i]}); + } + } + + // Now that we have one of the points of the segment (the one starting on a + // vertex), we need to find the other point. This is basically finding the + // first path segment that crosses each cut connecting edge->infinity we + // collected above. We do a rather naive implementation of that below and its + // complexity is O(N^2) even if it should be fast in most cases. If it + // turns out to be costly on profiling we can use a more sophisticated + // algorithm for finding the first intersection. - // Test case 1. We need to find the next box after the current point in the - // edge we are following in the current direction. - const auto cur_edge_neighbors = - neighbours.GetSortedNeighbors(point.next_box_index, cur_edge); - - const Rectangle fake_box_for_lower_bound = { - .x_min = point.x, .x_max = point.x, .y_min = point.y, .y_max = point.y}; - const auto clockwise_cmp = Neighbours::CompareClockwise(cur_edge); - auto it = absl::c_lower_bound( - cur_edge_neighbors, -1, - [&fake_box_for_lower_bound, rectangles, clockwise_cmp, clockwise](int a, - int b) { - const Rectangle& rectangle_a = - (a == -1 ? fake_box_for_lower_bound : rectangles[a]); - const Rectangle& rectangle_b = - (b == -1 ? fake_box_for_lower_bound : rectangles[b]); - if (clockwise) { - return clockwise_cmp(rectangle_a, rectangle_b); + // We need to sort the cuts so we can use binary search to quickly find cuts + // that cross a segment. 
+ std::sort(cuts[EdgePosition::RIGHT].begin(), cuts[EdgePosition::RIGHT].end(), + PolygonCut::CmpByStartY()); + std::sort(cuts[EdgePosition::LEFT].begin(), cuts[EdgePosition::LEFT].end(), + PolygonCut::CmpByEndY()); + std::sort(cuts[EdgePosition::BOTTOM].begin(), + cuts[EdgePosition::BOTTOM].end(), PolygonCut::CmpByEndX()); + std::sort(cuts[EdgePosition::TOP].begin(), cuts[EdgePosition::TOP].end(), + PolygonCut::CmpByStartX()); + + // This function cuts a segment in two if it crosses a cut. In any case, it + // returns the index of a point `point_idx` so that `shape.points[point_idx] + // == point_to_cut`. + const auto cut_segment_if_necessary = + [&shape](int segment_idx, + std::pair point_to_cut) { + const auto& cur = shape.points[segment_idx]; + const auto& next = shape.points[shape.next[segment_idx]]; + if (cur.second == next.second) { + DCHECK_EQ(point_to_cut.second, cur.second); + // We have a horizontal segment + const IntegerValue edge_start = std::min(cur.first, next.first); + const IntegerValue edge_end = std::max(cur.first, next.first); + + if (edge_start < point_to_cut.first && + point_to_cut.first < edge_end) { + shape.points.push_back(point_to_cut); + const int next_idx = shape.next[segment_idx]; + shape.next[segment_idx] = shape.points.size() - 1; + shape.next.push_back(next_idx); + return static_cast(shape.points.size() - 1); + } + return (shape.points[segment_idx] == point_to_cut) + ? 
segment_idx + : shape.next[segment_idx]; } else { - return clockwise_cmp(rectangle_b, rectangle_a); + DCHECK_EQ(cur.first, next.first); + DCHECK_EQ(point_to_cut.first, cur.first); + // We have a vertical segment + const IntegerValue edge_start = std::min(cur.second, next.second); + const IntegerValue edge_end = std::max(cur.second, next.second); + + if (edge_start < point_to_cut.second && + point_to_cut.second < edge_end) { + shape.points.push_back(point_to_cut); + const int next_idx = shape.next[segment_idx]; + shape.next[segment_idx] = shape.points.size() - 1; + shape.next.push_back(next_idx); + return static_cast(shape.points.size() - 1); + } + return (shape.points[segment_idx] == point_to_cut) + ? segment_idx + : shape.next[segment_idx]; } - }); + }; - if (it != cur_edge_neighbors.end()) { - // We found box in the current edge. We are in case 1. - result.next_box_index = *it; - const Rectangle& next_rectangle = rectangles[*it]; - switch (point.next_direction) { - case EdgePosition::TOP: - result.x = point.x; - result.y = next_rectangle.y_min; - if (cur_edge == EdgePosition::LEFT) { - result.next_direction = EdgePosition::LEFT; - } else { - result.next_direction = EdgePosition::RIGHT; + for (int i = 0; i < shape.points.size(); i++) { + const auto* cur_point_ptr = &shape.points[shape.next[i]]; + const auto* previous = &shape.points[i]; + DCHECK(cur_point_ptr->first == previous->first || + cur_point_ptr->second == previous->second) + << "found a segment that is neither horizontal nor vertical"; + const EdgePosition direction = + GetSegmentDirection(*previous, *cur_point_ptr); + + if (direction == EdgePosition::BOTTOM) { + const auto cut_start = absl::c_lower_bound( + cuts[EdgePosition::RIGHT], + PolygonCut{.start = {std::numeric_limits::min(), + cur_point_ptr->second}}, + PolygonCut::CmpByStartY()); + auto cut_end = absl::c_upper_bound( + cuts[EdgePosition::RIGHT], + PolygonCut{.start = {std::numeric_limits::max(), + previous->second}}, + 
PolygonCut::CmpByStartY()); + + for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { + PolygonCut& diagonal = *cut_it; + const IntegerValue diagonal_start_x = diagonal.start.first; + const IntegerValue diagonal_cur_end_x = diagonal.end.first; + // Our binary search guarantees those two conditions. + DCHECK_LE(cur_point_ptr->second, diagonal.start.second); + DCHECK_LE(diagonal.start.second, previous->second); + + // Let's test if the diagonal crosses the current boundary segment + if (diagonal_start_x <= previous->first && + diagonal_cur_end_x > cur_point_ptr->first) { + DCHECK_LT(diagonal_start_x, cur_point_ptr->first); + DCHECK_LE(previous->first, diagonal_cur_end_x); + + diagonal.end.first = cur_point_ptr->first; + + diagonal.end_index = cut_segment_if_necessary(i, diagonal.end); + DCHECK(shape.points[diagonal.end_index] == diagonal.end); + + // Subtle: cut_segment_if_necessary might add new points to the vector + // of the shape, so the pointers computed from it might become + // invalid. Moreover, the current segment now is shorter, so we need + // to update our upper bound. 
+ cur_point_ptr = &shape.points[shape.next[i]]; + previous = &shape.points[i]; + cut_end = absl::c_upper_bound( + cuts[EdgePosition::RIGHT], + PolygonCut{.start = {std::numeric_limits::max(), + previous->second}}, + PolygonCut::CmpByStartY()); } - break; - case EdgePosition::BOTTOM: - result.x = point.x; - result.y = next_rectangle.y_max; - if (cur_edge == EdgePosition::LEFT) { - result.next_direction = EdgePosition::LEFT; - } else { - result.next_direction = EdgePosition::RIGHT; + } + } + + if (direction == EdgePosition::TOP) { + const auto cut_start = absl::c_lower_bound( + cuts[EdgePosition::LEFT], + PolygonCut{.end = {std::numeric_limits::min(), + previous->second}}, + PolygonCut::CmpByEndY()); + auto cut_end = absl::c_upper_bound( + cuts[EdgePosition::LEFT], + PolygonCut{.end = {std::numeric_limits::max(), + cur_point_ptr->second}}, + PolygonCut::CmpByEndY()); + for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { + PolygonCut& diagonal = *cut_it; + const IntegerValue diagonal_start_x = diagonal.start.first; + const IntegerValue diagonal_cur_end_x = diagonal.end.first; + // Our binary search guarantees those two conditions. 
+ DCHECK_LE(diagonal.end.second, cur_point_ptr->second); + DCHECK_LE(previous->second, diagonal.end.second); + + // Let's test if the diagonal crosses the current boundary segment + if (diagonal_start_x < cur_point_ptr->first && + previous->first <= diagonal_cur_end_x) { + DCHECK_LT(cur_point_ptr->first, diagonal_cur_end_x); + DCHECK_LE(diagonal_start_x, previous->first); + + diagonal.start.first = cur_point_ptr->first; + diagonal.start_index = cut_segment_if_necessary(i, diagonal.start); + DCHECK(shape.points[diagonal.start_index] == diagonal.start); + cur_point_ptr = &shape.points[shape.next[i]]; + previous = &shape.points[i]; + cut_end = absl::c_upper_bound( + cuts[EdgePosition::LEFT], + PolygonCut{.end = {std::numeric_limits::max(), + cur_point_ptr->second}}, + PolygonCut::CmpByEndY()); } - break; - case EdgePosition::LEFT: - result.y = point.y; - result.x = next_rectangle.x_max; - if (cur_edge == EdgePosition::TOP) { - result.next_direction = EdgePosition::TOP; - } else { - result.next_direction = EdgePosition::BOTTOM; + } + } + + if (direction == EdgePosition::LEFT) { + const auto cut_start = absl::c_lower_bound( + cuts[EdgePosition::BOTTOM], + PolygonCut{.end = {cur_point_ptr->first, + std::numeric_limits::min()}}, + PolygonCut::CmpByEndX()); + auto cut_end = absl::c_upper_bound( + cuts[EdgePosition::BOTTOM], + PolygonCut{.end = {previous->first, + std::numeric_limits::max()}}, + PolygonCut::CmpByEndX()); + for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { + PolygonCut& diagonal = *cut_it; + const IntegerValue diagonal_start_y = diagonal.start.second; + const IntegerValue diagonal_cur_end_y = diagonal.end.second; + + // Our binary search guarantees those two conditions. 
+ DCHECK_LE(cur_point_ptr->first, diagonal.end.first); + DCHECK_LE(diagonal.end.first, previous->first); + + // Let's test if the diagonal crosses the current boundary segment + if (diagonal_start_y < cur_point_ptr->second && + cur_point_ptr->second <= diagonal_cur_end_y) { + DCHECK_LE(diagonal_start_y, previous->second); + DCHECK_LT(cur_point_ptr->second, diagonal_cur_end_y); + + diagonal.start.second = cur_point_ptr->second; + diagonal.start_index = cut_segment_if_necessary(i, diagonal.start); + DCHECK(shape.points[diagonal.start_index] == diagonal.start); + cur_point_ptr = &shape.points[shape.next[i]]; + previous = &shape.points[i]; + cut_end = absl::c_upper_bound( + cuts[EdgePosition::BOTTOM], + PolygonCut{.end = {previous->first, + std::numeric_limits::max()}}, + PolygonCut::CmpByEndX()); } - break; - case EdgePosition::RIGHT: - result.y = point.y; - result.x = next_rectangle.x_min; - if (cur_edge == EdgePosition::TOP) { - result.next_direction = EdgePosition::TOP; - } else { - result.next_direction = EdgePosition::BOTTOM; + } + } + + if (direction == EdgePosition::RIGHT) { + const auto cut_start = absl::c_lower_bound( + cuts[EdgePosition::TOP], + PolygonCut{.start = {previous->first, + std::numeric_limits::min()}}, + PolygonCut::CmpByStartX()); + auto cut_end = absl::c_upper_bound( + cuts[EdgePosition::TOP], + PolygonCut{.start = {cur_point_ptr->first, + std::numeric_limits::max()}}, + PolygonCut::CmpByStartX()); + for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { + PolygonCut& diagonal = *cut_it; + const IntegerValue diagonal_start_y = diagonal.start.second; + const IntegerValue diagonal_cur_end_y = diagonal.end.second; + + // Our binary search guarantees those two conditions. 
+ DCHECK_LE(previous->first, diagonal.start.first); + DCHECK_LE(diagonal.start.first, cur_point_ptr->first); + + // Let's test if the diagonal crosses the current boundary segment + if (diagonal_start_y <= cur_point_ptr->second && + cur_point_ptr->second < diagonal_cur_end_y) { + DCHECK_LT(diagonal_start_y, previous->second); + DCHECK_LE(cur_point_ptr->second, diagonal_cur_end_y); + + diagonal.end.second = cur_point_ptr->second; + diagonal.end_index = cut_segment_if_necessary(i, diagonal.end); + DCHECK(shape.points[diagonal.end_index] == diagonal.end); + cur_point_ptr = &shape.points[shape.next[i]]; + cut_end = absl::c_upper_bound( + cuts[EdgePosition::TOP], + PolygonCut{.start = {cur_point_ptr->first, + std::numeric_limits::max()}}, + PolygonCut::CmpByStartX()); + previous = &shape.points[i]; } - break; + } } - return result; } + return cuts; +} - // We now know we are not in Case 1, so know the next (x, y) position: it is - // the corner of the current rectangle in the direction we are going. - switch (point.next_direction) { - case EdgePosition::TOP: - result.x = point.x; - result.y = cur_rectangle.y_max; - break; - case EdgePosition::BOTTOM: - result.x = point.x; - result.y = cur_rectangle.y_min; - break; - case EdgePosition::LEFT: - result.x = cur_rectangle.x_min; - result.y = point.y; - break; - case EdgePosition::RIGHT: - result.x = cur_rectangle.x_max; - result.y = point.y; - break; +void CutShapeWithPolygonCuts(FlatShape& shape, + absl::Span cuts) { + std::vector previous(shape.points.size(), -1); + for (int i = 0; i < shape.points.size(); i++) { + previous[shape.next[i]] = i; } - // Case 2 and 3. - const auto next_edge_neighbors = - neighbours.GetSortedNeighbors(point.next_box_index, point.next_direction); - if (!next_edge_neighbors.empty()) { - // We are looking for the neighbor on the edge of the current box. - const int candidate_index = - clockwise ? 
next_edge_neighbors.front() : next_edge_neighbors.back(); - const Rectangle& next_rectangle = rectangles[candidate_index]; - switch (point.next_direction) { - case EdgePosition::TOP: - case EdgePosition::BOTTOM: - if (next_rectangle.x_min < point.x && point.x < next_rectangle.x_max) { - // Case 2 - result.next_box_index = candidate_index; - if (cur_edge == EdgePosition::LEFT) { - result.next_direction = EdgePosition::LEFT; - } else { - result.next_direction = EdgePosition::RIGHT; - } - return result; - } else if (next_rectangle.x_min == point.x && - cur_edge == EdgePosition::LEFT) { - // Case 3 - result.next_box_index = candidate_index; - result.next_direction = point.next_direction; - return result; - } else if (next_rectangle.x_max == point.x && - cur_edge == EdgePosition::RIGHT) { - // Case 3 - result.next_box_index = candidate_index; - result.next_direction = point.next_direction; - return result; - } - break; - case EdgePosition::LEFT: - case EdgePosition::RIGHT: - if (next_rectangle.y_min < point.y && point.y < next_rectangle.y_max) { - result.next_box_index = candidate_index; - if (cur_edge == EdgePosition::TOP) { - result.next_direction = EdgePosition::TOP; - } else { - result.next_direction = EdgePosition::BOTTOM; - } - return result; - } else if (next_rectangle.y_max == point.y && - cur_edge == EdgePosition::TOP) { - result.next_box_index = candidate_index; - result.next_direction = point.next_direction; - return result; - } else if (next_rectangle.y_min == point.y && - cur_edge == EdgePosition::BOTTOM) { - result.next_box_index = candidate_index; - result.next_direction = point.next_direction; - return result; + std::vector> cut_previous_index(cuts.size(), {-1, -1}); + for (int i = 0; i < cuts.size(); i++) { + DCHECK(cuts[i].start == shape.points[cuts[i].start_index]); + DCHECK(cuts[i].end == shape.points[cuts[i].end_index]); + + cut_previous_index[i].first = previous[cuts[i].start_index]; + cut_previous_index[i].second = previous[cuts[i].end_index]; + } 
+ + for (const auto& [i, j] : cut_previous_index) { + const int prev_start_next = shape.next[i]; + const int prev_end_next = shape.next[j]; + const std::pair start = + shape.points[prev_start_next]; + const std::pair end = + shape.points[prev_end_next]; + + shape.points.push_back(start); + shape.next[i] = shape.points.size() - 1; + shape.next.push_back(prev_end_next); + + shape.points.push_back(end); + shape.next[j] = shape.points.size() - 1; + shape.next.push_back(prev_start_next); + } +} +} // namespace + +// This function applies the method described in page 3 of [1]. +// +// [1] Eppstein, David. "Graph-theoretic solutions to computational geometry +// problems." International Workshop on Graph-Theoretic Concepts in Computer +// Science. Berlin, Heidelberg: Springer Berlin Heidelberg, 2009. +std::vector CutShapeIntoRectangles(SingleShape shape) { + auto is_aligned = [](const std::pair& p1, + const std::pair& p2, + const std::pair& p3) { + return ((p1.first == p2.first) == (p2.first == p3.first)) && + ((p1.second == p2.second) == (p2.second == p3.second)); + }; + const auto add_segment = + [&is_aligned](const std::pair& segment, + const int start_index, + std::vector>& points, + std::vector& next) { + if (points.size() > 1 + start_index && + is_aligned(points[points.size() - 1], points[points.size() - 2], + segment)) { + points.back() = segment; + } else { + points.push_back(segment); + next.push_back(points.size()); } - break; + }; + + // To cut our polygon into rectangles, we first put it into a data structure + // that is easier to manipulate. 
+ FlatShape flat_shape; + for (int i = 0; 1 + i < shape.boundary.step_points.size(); ++i) { + const std::pair& segment = + shape.boundary.step_points[i]; + add_segment(segment, 0, flat_shape.points, flat_shape.next); + } + flat_shape.next.back() = 0; + for (const ShapePath& hole : shape.holes) { + const int start = flat_shape.next.size(); + if (hole.step_points.size() < 2) continue; + for (int i = 0; i + 1 < hole.step_points.size(); ++i) { + const std::pair& segment = + hole.step_points[i]; + add_segment(segment, start, flat_shape.points, flat_shape.next); } + flat_shape.next.back() = start; } - // Now we must be in the case 4. - result.next_box_index = point.next_box_index; - switch (point.next_direction) { - case EdgePosition::TOP: - case EdgePosition::BOTTOM: - if (cur_edge == EdgePosition::LEFT) { - result.next_direction = EdgePosition::RIGHT; - } else { - result.next_direction = EdgePosition::LEFT; - } - break; - case EdgePosition::LEFT: - case EdgePosition::RIGHT: - if (cur_edge == EdgePosition::TOP) { - result.next_direction = EdgePosition::BOTTOM; - } else { - result.next_direction = EdgePosition::TOP; + std::array, 4> all_cuts = + GetPotentialPolygonCuts(flat_shape); + + // Some cuts connect two concave edges and will be duplicated in all_cuts. + // Those are important: since they "fix" two concavities with a single cut, + // they are called "good diagonals" in the literature. Note that in + // computational geometry jargon, a diagonal of a polygon is a line segment + // that connects two non-adjacent vertices of a polygon, even in cases like + // ours that we are only talking of diagonals that are not "diagonal" in the + // usual meaning of the word: ie., horizontal or vertical segments connecting + // two vertices of the polygon). 
+ std::array, 2> good_diagonals; + for (const auto& d : all_cuts[EdgePosition::BOTTOM]) { + if (absl::c_binary_search(all_cuts[EdgePosition::TOP], d, + PolygonCut::CmpByStartX())) { + good_diagonals[0].push_back(d); + } + } + for (const auto& d : all_cuts[EdgePosition::LEFT]) { + if (absl::c_binary_search(all_cuts[EdgePosition::RIGHT], d, + PolygonCut::CmpByStartY())) { + good_diagonals[1].push_back(d); + } + } + + // The "good diagonals" are only more optimal that any cut if they are not + // crossed by other cuts. To maximize their usefulness, we build a graph where + // the good diagonals are the vertices and we add an edge every time a + // vertical and horizontal diagonal cross. The minimum vertex cover of this + // graph is the minimal set of good diagonals that are not crossed by other + // cuts. + std::vector> arcs(good_diagonals[0].size()); + for (int i = 0; i < good_diagonals[0].size(); ++i) { + for (int j = 0; j < good_diagonals[1].size(); ++j) { + const PolygonCut& vertical = good_diagonals[0][i]; + const PolygonCut& horizontal = good_diagonals[1][j]; + const IntegerValue vertical_x = vertical.start.first; + const IntegerValue horizontal_y = horizontal.start.second; + if (horizontal.start.first <= vertical_x && + vertical_x <= horizontal.end.first && + vertical.start.second <= horizontal_y && + horizontal_y <= vertical.end.second) { + arcs[i].push_back(good_diagonals[0].size() + j); } - break; + } + } + + const std::vector minimum_cover = + BipartiteMinimumVertexCover(arcs, good_diagonals[1].size()); + + std::vector minimum_cover_horizontal_diagonals; + for (int i = good_diagonals[0].size(); + i < good_diagonals[0].size() + good_diagonals[1].size(); ++i) { + if (minimum_cover[i]) continue; + minimum_cover_horizontal_diagonals.push_back( + good_diagonals[1][i - good_diagonals[0].size()]); } + + // Since our data structure only allow to cut the shape according to a list + // of vertical or horizontal cuts, but not a list mixing both, we cut first + // on 
the chosen horizontal good diagonals. + CutShapeWithPolygonCuts(flat_shape, minimum_cover_horizontal_diagonals); + + // We need to recompute the cuts after we applied the good diagonals, since + // the geometry has changed. + all_cuts = GetPotentialPolygonCuts(flat_shape); + + // Now that we did all horizontal good diagonals, we need to cut on all + // vertical good diagonals and then cut arbitrarily to remove all concave + // edges. To make things simple, just apply all vertical cuts, since they + // include all the vertical good diagonals and also fully slice the shape into + // rectangles. + + // Remove duplicates coming from good diagonals first. + std::vector cuts = all_cuts[EdgePosition::TOP]; + for (const auto& cut : all_cuts[EdgePosition::BOTTOM]) { + if (!absl::c_binary_search(all_cuts[EdgePosition::TOP], cut, + PolygonCut::CmpByStartX())) { + cuts.push_back(cut); + } + } + + CutShapeWithPolygonCuts(flat_shape, cuts); + + // Now every connected component of the shape is a rectangle. Build the final + // result. 
+ std::vector result; + std::vector seen(flat_shape.points.size(), false); + for (int i = 0; i < flat_shape.points.size(); ++i) { + if (seen[i]) continue; + Rectangle& rectangle = result.emplace_back(Rectangle{ + .x_min = std::numeric_limits::max(), + .x_max = std::numeric_limits::min(), + .y_min = std::numeric_limits::max(), + .y_max = std::numeric_limits::min(), + }); + int cur = i; + do { + seen[cur] = true; + rectangle.GrowToInclude({.x_min = flat_shape.points[cur].first, + .x_max = flat_shape.points[cur].first, + .y_min = flat_shape.points[cur].second, + .y_max = flat_shape.points[cur].second}); + cur = flat_shape.next[cur]; + DCHECK_LT(cur, flat_shape.next.size()); + } while (cur != i); + } + return result; } -ShapePath TraceBoundary( - const std::pair& starting_step_point, - int starting_box_index, absl::Span rectangles, - const Neighbours& neighbours) { - // First find which direction we need to go to follow the border in the - // clockwise order. - const Rectangle& initial_rec = rectangles[starting_box_index]; - bool touching_edge[4]; - touching_edge[EdgePosition::LEFT] = - initial_rec.x_min == starting_step_point.first; - touching_edge[EdgePosition::RIGHT] = - initial_rec.x_max == starting_step_point.first; - touching_edge[EdgePosition::TOP] = - initial_rec.y_max == starting_step_point.second; - touching_edge[EdgePosition::BOTTOM] = - initial_rec.y_min == starting_step_point.second; - - EdgePosition next_direction; - if (touching_edge[EdgePosition::LEFT]) { - if (touching_edge[EdgePosition::TOP]) { - next_direction = EdgePosition::RIGHT; - } else { - next_direction = EdgePosition::TOP; - } - } else if (touching_edge[EdgePosition::RIGHT]) { - if (touching_edge[EdgePosition::BOTTOM]) { - next_direction = EdgePosition::LEFT; - } else { - next_direction = EdgePosition::BOTTOM; +bool ReduceNumberOfBoxesExactMandatory( + std::vector* mandatory_rectangles, + std::vector* optional_rectangles) { + if (mandatory_rectangles->empty()) return false; + std::vector 
result = *mandatory_rectangles; + std::vector new_optional_rectangles = *optional_rectangles; + + Rectangle mandatory_bounding_box = (*mandatory_rectangles)[0]; + for (const Rectangle& box : *mandatory_rectangles) { + mandatory_bounding_box.GrowToInclude(box); + } + const std::vector mandatory_empty_holes = + FindEmptySpaces(mandatory_bounding_box, *mandatory_rectangles); + const std::vector> mandatory_holes_components = + SplitInConnectedComponents(BuildNeighboursGraph(mandatory_empty_holes)); + + // Now for every connected component of the holes in the mandatory area, see + // if we can fill them with optional boxes. + std::vector holes_in_component; + for (const std::vector& component : mandatory_holes_components) { + holes_in_component.clear(); + holes_in_component.reserve(component.size()); + for (const int index : component) { + holes_in_component.push_back(mandatory_empty_holes[index]); } - } else if (touching_edge[EdgePosition::TOP]) { - next_direction = EdgePosition::LEFT; - } else if (touching_edge[EdgePosition::BOTTOM]) { - next_direction = EdgePosition::RIGHT; - } else { - LOG(FATAL) - << "TraceBoundary() got a `starting_step_point` that is not in an edge " - "of the rectangle of `starting_box_index`. This is not allowed."; - } - const ContourPoint starting_point = {.x = starting_step_point.first, - .y = starting_step_point.second, - .next_box_index = starting_box_index, - .next_direction = next_direction}; - - ShapePath result; - for (ContourPoint point = starting_point; - result.step_points.empty() || point != starting_point; - point = NextByClockwiseOrder(point, rectangles, neighbours)) { - if (!result.step_points.empty() && - point.x == result.step_points.back().first && - point.y == result.step_points.back().second) { - // There is a special corner-case of the algorithm using the neighbours. 
- // Consider the following set-up: - // - // ******** | - // ******** | - // ******** +----> - // ########++++++++ - // ########++++++++ - // ########++++++++ - // - // In this case, the only way the algorithm could reach the "+" box is via - // the "#" box, but which is doesn't contribute to the path. The algorithm - // returns a technically correct zero-size interval, which might be useful - // for callers that want to count the "#" box as visited, but this is not - // our case. - result.touching_box_index.back() = point.next_box_index; - } else { - result.touching_box_index.push_back(point.next_box_index); - result.step_points.push_back({point.x, point.y}); + if (RegionIncludesOther(new_optional_rectangles, holes_in_component)) { + // Fill the hole. + result.insert(result.end(), holes_in_component.begin(), + holes_in_component.end()); + // We can modify `optional_rectangles` here since we know that if we + // remove a hole this function will return true. + new_optional_rectangles = PavedRegionDifference( + new_optional_rectangles, std::move(holes_in_component)); } } - result.touching_box_index.push_back(result.touching_box_index.front()); - result.step_points.push_back(result.step_points.front()); - return result; + const Neighbours neighbours = BuildNeighboursGraph(result); + std::vector shapes = BoxesToShapes(result, neighbours); + + result.clear(); + for (SingleShape& shape : shapes) { + // This is the function that applies the algorithm described in [1]. + const std::vector cut_rectangles = + CutShapeIntoRectangles(std::move(shape)); + result.insert(result.end(), cut_rectangles.begin(), cut_rectangles.end()); + } + // It is possible that the algorithm actually increases the number of boxes. + // See the "Problematic2" test. 
+ if (result.size() >= mandatory_rectangles->size()) return false; + mandatory_rectangles->swap(result); + optional_rectangles->swap(new_optional_rectangles); + return true; } } // namespace sat diff --git a/ortools/sat/2d_rectangle_presolve.h b/ortools/sat/2d_rectangle_presolve.h index 28821877e80..a7f1c6f07e6 100644 --- a/ortools/sat/2d_rectangle_presolve.h +++ b/ortools/sat/2d_rectangle_presolve.h @@ -38,17 +38,29 @@ bool PresolveFixed2dRectangles( absl::Span non_fixed_boxes, std::vector* fixed_boxes); -// Given a set of non-overlapping rectangles split in two groups, mandatory and -// optional, try to build a set of as few non-overlapping rectangles as -// possible defining a region R that satisfy: +// Given two vectors of non-overlapping rectangles defining two regions of the +// space: one mandatory region that must be occupied and one optional region +// that can be occupied, try to build a vector of as few non-overlapping +// rectangles as possible defining a region R that satisfy: // - R \subset (mandatory \union optional); // - mandatory \subset R. // -// The function updates the set of `mandatory_rectangles` with `R` and +// The function updates the vector of `mandatory_rectangles` with `R` and // `optional_rectangles` with `optional_rectangles \setdiff R`. It returns // true if the `mandatory_rectangles` was updated. -bool ReduceNumberofBoxes(std::vector* mandatory_rectangles, - std::vector* optional_rectangles); +// +// This function uses a greedy algorithm that merge rectangles that share an +// edge. +bool ReduceNumberofBoxesGreedy(std::vector* mandatory_rectangles, + std::vector* optional_rectangles); + +// Same as above, but this implementation returns the optimal solution in +// minimizing the number of boxes if `optional_rectangles` is empty. On the +// other hand, its handling of optional boxes is rather limited. It simply fills +// the holes in the mandatory boxes with optional boxes, if possible. 
+bool ReduceNumberOfBoxesExactMandatory( + std::vector* mandatory_rectangles, + std::vector* optional_rectangles); enum EdgePosition { TOP = 0, RIGHT = 1, BOTTOM = 2, LEFT = 3 }; @@ -162,23 +174,17 @@ struct ShapePath { std::vector touching_box_index; }; -// Returns a path delimiting a boundary of the union of a set of rectangles. It -// should work for both the exterior boundary and the boundaries of the holes -// inside the union. The path will start on `starting_point` and follow the -// boundary on clockwise order. -// -// `starting_point` should be a point in the boundary and `starting_box_index` -// the index of a rectangle with one edge containing `starting_point`. -// -// The resulting `path` satisfy: -// - path.step_points.front() == path.step_points.back() == starting_point -// - path.touching_box_index.front() == path.touching_box_index.back() == -// == starting_box_index -// -ShapePath TraceBoundary( - const std::pair& starting_step_point, - int starting_box_index, absl::Span rectangles, - const Neighbours& neighbours); +struct SingleShape { + ShapePath boundary; + std::vector holes; +}; + +// Given a set of rectangles, split it into connected components and transform +// each individual set into a shape described by its boundary and holes paths. +std::vector BoxesToShapes(absl::Span rectangles, + const Neighbours& neighbours); + +std::vector CutShapeIntoRectangles(SingleShape shapes); } // namespace sat } // namespace operations_research diff --git a/ortools/sat/2d_rectangle_presolve_test.cc b/ortools/sat/2d_rectangle_presolve_test.cc new file mode 100644 index 00000000000..899dc9cd376 --- /dev/null +++ b/ortools/sat/2d_rectangle_presolve_test.cc @@ -0,0 +1,985 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/2d_rectangle_presolve.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "absl/strings/str_split.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/logging.h" +#include "ortools/sat/2d_orthogonal_packing_testing.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +std::vector BuildFromAsciiArt(std::string_view input) { + std::vector rectangles; + std::vector lines = absl::StrSplit(input, '\n'); + for (int i = 0; i < lines.size(); i++) { + for (int j = 0; j < lines[i].size(); j++) { + if (lines[i][j] != ' ') { + rectangles.push_back({.x_min = j, + .x_max = j + 1, + .y_min = 2 * lines.size() - 2 * i, + .y_max = 2 * lines.size() - 2 * i + 2}); + } + } + } + std::vector empty; + ReduceNumberofBoxesGreedy(&rectangles, &empty); + return rectangles; +} + +TEST(RectanglePresolve, Basic) { + std::vector input = BuildFromAsciiArt(R"( + *********** *********** + *********** *********** + *********** *********** + + + *********** *********** + *********** *********** + *********** *********** + )"); + // Note that a single naive pass over the fixed 
rectangles' gaps would not + // fill the middle region. + std::vector input_in_range; + // Add a single object that is too large to fit between the fixed boxes. + input_in_range.push_back( + {.box_index = 0, + .bounding_area = {.x_min = 0, .x_max = 80, .y_min = 0, .y_max = 80}, + .x_size = 5, + .y_size = 5}); + + EXPECT_TRUE(PresolveFixed2dRectangles(input_in_range, &input)); + EXPECT_EQ(input.size(), 1); +} + +TEST(RectanglePresolve, Trim) { + std::vector input = { + {.x_min = 0, .x_max = 5, .y_min = 0, .y_max = 5}}; + std::vector input_in_range; + input_in_range.push_back( + {.box_index = 0, + .bounding_area = {.x_min = 1, .x_max = 80, .y_min = 1, .y_max = 80}, + .x_size = 5, + .y_size = 5}); + + EXPECT_TRUE(PresolveFixed2dRectangles(input_in_range, &input)); + EXPECT_THAT(input, ElementsAre(Rectangle{ + .x_min = 1, .x_max = 5, .y_min = 1, .y_max = 5})); +} + +TEST(RectanglePresolve, FillBoundingBoxEdge) { + std::vector input = { + {.x_min = 1, .x_max = 5, .y_min = 1, .y_max = 5}}; + std::vector input_in_range; + input_in_range.push_back( + {.box_index = 0, + .bounding_area = {.x_min = 0, .x_max = 80, .y_min = 0, .y_max = 80}, + .x_size = 5, + .y_size = 5}); + + EXPECT_TRUE(PresolveFixed2dRectangles(input_in_range, &input)); + EXPECT_THAT(input, ElementsAre(Rectangle{ + .x_min = 0, .x_max = 5, .y_min = 0, .y_max = 5})); +} + +TEST(RectanglePresolve, UseAreaNotOccupiable) { + std::vector input = { + {.x_min = 20, .x_max = 25, .y_min = 0, .y_max = 5}}; + std::vector input_in_range; + input_in_range.push_back( + {.box_index = 0, + .bounding_area = {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}, + .x_size = 5, + .y_size = 5}); + input_in_range.push_back( + {.box_index = 1, + .bounding_area = {.x_min = 0, .x_max = 15, .y_min = 0, .y_max = 10}, + .x_size = 5, + .y_size = 5}); + input_in_range.push_back( + {.box_index = 1, + .bounding_area = {.x_min = 25, .x_max = 100, .y_min = 0, .y_max = 10}, + .x_size = 5, + .y_size = 5}); + + 
EXPECT_TRUE(PresolveFixed2dRectangles(input_in_range, &input)); + EXPECT_THAT(input, ElementsAre(Rectangle{ + .x_min = 15, .x_max = 25, .y_min = 0, .y_max = 10})); +} + +TEST(RectanglePresolve, RemoveOutsideBB) { + std::vector input = { + {.x_min = 0, .x_max = 5, .y_min = 0, .y_max = 5}}; + std::vector input_in_range; + input_in_range.push_back( + {.box_index = 0, + .bounding_area = {.x_min = 5, .x_max = 80, .y_min = 5, .y_max = 80}, + .x_size = 5, + .y_size = 5}); + + EXPECT_TRUE(PresolveFixed2dRectangles(input_in_range, &input)); + EXPECT_THAT(input, IsEmpty()); +} + +TEST(RectanglePresolve, RandomTest) { + constexpr int kFixedRectangleSize = 10; + constexpr int kNumRuns = 1000; + absl::BitGen bit_gen; + + for (int run = 0; run < kNumRuns; ++run) { + // Start by generating a feasible problem that we know the solution with + // some items fixed. + std::vector input = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 40, bit_gen); + std::shuffle(input.begin(), input.end(), bit_gen); + absl::Span fixed_rectangles = + absl::MakeConstSpan(input).subspan(0, kFixedRectangleSize); + absl::Span other_rectangles = + absl::MakeSpan(input).subspan(kFixedRectangleSize); + std::vector new_fixed_rectangles(fixed_rectangles.begin(), + fixed_rectangles.end()); + const std::vector input_in_range = + MakeItemsFromRectangles(other_rectangles, 0.6, bit_gen); + + // Presolve the fixed items. + PresolveFixed2dRectangles(input_in_range, &new_fixed_rectangles); + if (run == 0) { + LOG(INFO) << "Presolved:\n" + << RenderDot(std::nullopt, fixed_rectangles) << "To:\n" + << RenderDot(std::nullopt, new_fixed_rectangles); + } + + if (new_fixed_rectangles.size() > fixed_rectangles.size()) { + LOG(FATAL) << "Presolved:\n" + << RenderDot(std::nullopt, fixed_rectangles) << "To:\n" + << RenderDot(std::nullopt, new_fixed_rectangles); + } + CHECK_LE(new_fixed_rectangles.size(), fixed_rectangles.size()); + + // Check if the original solution is still a solution. 
+ std::vector all_rectangles(new_fixed_rectangles.begin(), + new_fixed_rectangles.end()); + all_rectangles.insert(all_rectangles.end(), other_rectangles.begin(), + other_rectangles.end()); + for (int i = 0; i < all_rectangles.size(); ++i) { + for (int j = i + 1; j < all_rectangles.size(); ++j) { + CHECK(all_rectangles[i].IsDisjoint(all_rectangles[j])) + << RenderDot(std::nullopt, {all_rectangles[i], all_rectangles[j]}); + } + } + } +} + +Neighbours NaiveBuildNeighboursGraph(const std::vector& rectangles) { + auto interval_intersect = [](IntegerValue begin1, IntegerValue end1, + IntegerValue begin2, IntegerValue end2) { + return std::max(begin1, begin2) < std::min(end1, end2); + }; + std::vector> neighbors; + for (int i = 0; i < rectangles.size(); ++i) { + for (int j = 0; j < rectangles.size(); ++j) { + if (i == j) continue; + const Rectangle& r1 = rectangles[i]; + const Rectangle& r2 = rectangles[j]; + if (r1.x_min == r2.x_max && + interval_intersect(r1.y_min, r1.y_max, r2.y_min, r2.y_max)) { + neighbors.push_back({i, EdgePosition::LEFT, j}); + neighbors.push_back({j, EdgePosition::RIGHT, i}); + } + if (r1.y_min == r2.y_max && + interval_intersect(r1.x_min, r1.x_max, r2.x_min, r2.x_max)) { + neighbors.push_back({i, EdgePosition::BOTTOM, j}); + neighbors.push_back({j, EdgePosition::TOP, i}); + } + } + } + return Neighbours(rectangles, neighbors); +} + +std::string RenderNeighborsGraph(std::optional bb, + absl::Span rectangles, + const Neighbours& neighbours) { + const absl::flat_hash_map edge_colors = { + {EdgePosition::TOP, "red"}, + {EdgePosition::BOTTOM, "green"}, + {EdgePosition::LEFT, "blue"}, + {EdgePosition::RIGHT, "cyan"}}; + std::stringstream ss; + ss << " edge[headclip=false, tailclip=false, penwidth=30];\n"; + for (int box_index = 0; box_index < neighbours.NumRectangles(); ++box_index) { + for (int edge_int = 0; edge_int < 4; ++edge_int) { + const EdgePosition edge = static_cast(edge_int); + const auto edge_neighbors = + 
neighbours.GetSortedNeighbors(box_index, edge); + for (int neighbor : edge_neighbors) { + ss << " " << box_index << "->" << neighbor << " [color=\"" + << edge_colors.find(edge)->second << "\"];\n"; + } + } + } + return RenderDot(bb, rectangles, ss.str()); +} + +std::string RenderContour(std::optional bb, + absl::Span rectangles, + const ShapePath& path) { + const std::vector colors = {"red", "green", "blue", + "cyan", "yellow", "purple"}; + std::stringstream ss; + ss << " edge[headclip=false, tailclip=false, penwidth=30];\n"; + for (int i = 0; i < path.step_points.size(); ++i) { + std::pair p = path.step_points[i]; + ss << " p" << i << "[pos=\"" << 2 * p.first << "," << 2 * p.second + << "!\" shape=point]\n"; + if (i != path.step_points.size() - 1) { + ss << " p" << i << "->p" << i + 1 << "\n"; + } + } + return RenderDot(bb, rectangles, ss.str()); +} + +TEST(BuildNeighboursGraphTest, Simple) { + std::vector rectangles = { + {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}, + {.x_min = 10, .x_max = 20, .y_min = 0, .y_max = 10}, + {.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}}; + const Neighbours neighbours = BuildNeighboursGraph(rectangles); + EXPECT_THAT(neighbours.GetSortedNeighbors(0, EdgePosition::RIGHT), + ElementsAre(1)); + EXPECT_THAT(neighbours.GetSortedNeighbors(0, EdgePosition::TOP), + ElementsAre(2)); + EXPECT_THAT(neighbours.GetSortedNeighbors(1, EdgePosition::LEFT), + ElementsAre(0)); + EXPECT_THAT(neighbours.GetSortedNeighbors(2, EdgePosition::BOTTOM), + ElementsAre(0)); +} + +TEST(BuildNeighboursGraphTest, NeighborsAroundCorner) { + std::vector rectangles = { + {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}, + {.x_min = 10, .x_max = 20, .y_min = 10, .y_max = 20}}; + const Neighbours neighbours = BuildNeighboursGraph(rectangles); + for (int i = 0; i < 4; ++i) { + const EdgePosition edge = static_cast(i); + EXPECT_THAT(neighbours.GetSortedNeighbors(0, edge), IsEmpty()); + EXPECT_THAT(neighbours.GetSortedNeighbors(1, edge), IsEmpty()); + } +} 
+ +TEST(BuildNeighboursGraphTest, RandomTest) { + constexpr int kNumRuns = 100; + absl::BitGen bit_gen; + + for (int run = 0; run < kNumRuns; ++run) { + // Start by generating a feasible problem that we know the solution with + // some items fixed. + std::vector input = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, bit_gen); + std::shuffle(input.begin(), input.end(), bit_gen); + auto neighbours = BuildNeighboursGraph(input); + auto expected_neighbours = NaiveBuildNeighboursGraph(input); + for (int box_index = 0; box_index < neighbours.NumRectangles(); + ++box_index) { + for (int edge_int = 0; edge_int < 4; ++edge_int) { + const EdgePosition edge = static_cast(edge_int); + if (neighbours.GetSortedNeighbors(box_index, edge) != + expected_neighbours.GetSortedNeighbors(box_index, edge)) { + LOG(FATAL) << "Got:\n" + << RenderNeighborsGraph(std::nullopt, input, neighbours) + << "Expected:\n" + << RenderNeighborsGraph(std::nullopt, input, + expected_neighbours); + } + } + } + } +} + +struct ContourPoint { + IntegerValue x; + IntegerValue y; + int next_box_index; + EdgePosition next_direction; + + bool operator!=(const ContourPoint& other) const { + return x != other.x || y != other.y || + next_box_index != other.next_box_index || + next_direction != other.next_direction; + } +}; + +// This function runs in O(log N). +ContourPoint NextByClockwiseOrder(const ContourPoint& point, + absl::Span rectangles, + const Neighbours& neighbours) { + // This algorithm is very verbose, but it is about handling four cases. In the + // schema below, "-->" is the current direction, "X" the next point and + // the dashed arrow the next direction. 
+ // + // Case 1: + // ++++++++ + // ^ ++++++++ + // : ++++++++ + // : ++++++++ + // ++++++++ + // ---> X ++++++++ + // ****************** + // ****************** + // ****************** + // ****************** + // + // Case 2: + // ^ ++++++++ + // : ++++++++ + // : ++++++++ + // ++++++++ + // ---> X ++++++++ + // *************++++++++ + // *************++++++++ + // ************* + // ************* + // + // Case 3: + // ---> X ...> + // *************++++++++ + // *************++++++++ + // *************++++++++ + // *************++++++++ + // + // Case 4: + // ---> X + // ************* : + // ************* : + // ************* : + // ************* \/ + ContourPoint result; + const Rectangle& cur_rectangle = rectangles[point.next_box_index]; + + EdgePosition cur_edge; + bool clockwise; + // Much of the code below need to know two things: in which direction we are + // going and what edge of which rectangle we are touching. For example, in the + // "Case 4" drawing above we are going RIGHT and touching the TOP edge of the + // current rectangle. This switch statement finds this `cur_edge`. + switch (point.next_direction) { + case EdgePosition::TOP: + if (cur_rectangle.x_max == point.x) { + cur_edge = EdgePosition::RIGHT; + clockwise = false; + } else { + cur_edge = EdgePosition::LEFT; + clockwise = true; + } + break; + case EdgePosition::BOTTOM: + if (cur_rectangle.x_min == point.x) { + cur_edge = EdgePosition::LEFT; + clockwise = false; + } else { + cur_edge = EdgePosition::RIGHT; + clockwise = true; + } + break; + case EdgePosition::LEFT: + if (cur_rectangle.y_max == point.y) { + cur_edge = EdgePosition::TOP; + clockwise = false; + } else { + cur_edge = EdgePosition::BOTTOM; + clockwise = true; + } + break; + case EdgePosition::RIGHT: + if (cur_rectangle.y_min == point.y) { + cur_edge = EdgePosition::BOTTOM; + clockwise = false; + } else { + cur_edge = EdgePosition::TOP; + clockwise = true; + } + break; + } + + // Test case 1. 
We need to find the next box after the current point in the + // edge we are following in the current direction. + const auto cur_edge_neighbors = + neighbours.GetSortedNeighbors(point.next_box_index, cur_edge); + + const Rectangle fake_box_for_lower_bound = { + .x_min = point.x, .x_max = point.x, .y_min = point.y, .y_max = point.y}; + const auto clockwise_cmp = Neighbours::CompareClockwise(cur_edge); + auto it = absl::c_lower_bound( + cur_edge_neighbors, -1, + [&fake_box_for_lower_bound, rectangles, clockwise_cmp, clockwise](int a, + int b) { + const Rectangle& rectangle_a = + (a == -1 ? fake_box_for_lower_bound : rectangles[a]); + const Rectangle& rectangle_b = + (b == -1 ? fake_box_for_lower_bound : rectangles[b]); + if (clockwise) { + return clockwise_cmp(rectangle_a, rectangle_b); + } else { + return clockwise_cmp(rectangle_b, rectangle_a); + } + }); + + if (it != cur_edge_neighbors.end()) { + // We found box in the current edge. We are in case 1. + result.next_box_index = *it; + const Rectangle& next_rectangle = rectangles[*it]; + switch (point.next_direction) { + case EdgePosition::TOP: + result.x = point.x; + result.y = next_rectangle.y_min; + if (cur_edge == EdgePosition::LEFT) { + result.next_direction = EdgePosition::LEFT; + } else { + result.next_direction = EdgePosition::RIGHT; + } + break; + case EdgePosition::BOTTOM: + result.x = point.x; + result.y = next_rectangle.y_max; + if (cur_edge == EdgePosition::LEFT) { + result.next_direction = EdgePosition::LEFT; + } else { + result.next_direction = EdgePosition::RIGHT; + } + break; + case EdgePosition::LEFT: + result.y = point.y; + result.x = next_rectangle.x_max; + if (cur_edge == EdgePosition::TOP) { + result.next_direction = EdgePosition::TOP; + } else { + result.next_direction = EdgePosition::BOTTOM; + } + break; + case EdgePosition::RIGHT: + result.y = point.y; + result.x = next_rectangle.x_min; + if (cur_edge == EdgePosition::TOP) { + result.next_direction = EdgePosition::TOP; + } else { + 
result.next_direction = EdgePosition::BOTTOM; + } + break; + } + return result; + } + + // We now know we are not in Case 1, so know the next (x, y) position: it is + // the corner of the current rectangle in the direction we are going. + switch (point.next_direction) { + case EdgePosition::TOP: + result.x = point.x; + result.y = cur_rectangle.y_max; + break; + case EdgePosition::BOTTOM: + result.x = point.x; + result.y = cur_rectangle.y_min; + break; + case EdgePosition::LEFT: + result.x = cur_rectangle.x_min; + result.y = point.y; + break; + case EdgePosition::RIGHT: + result.x = cur_rectangle.x_max; + result.y = point.y; + break; + } + + // Case 2 and 3. + const auto next_edge_neighbors = + neighbours.GetSortedNeighbors(point.next_box_index, point.next_direction); + if (!next_edge_neighbors.empty()) { + // We are looking for the neighbor on the edge of the current box. + const int candidate_index = + clockwise ? next_edge_neighbors.front() : next_edge_neighbors.back(); + const Rectangle& next_rectangle = rectangles[candidate_index]; + switch (point.next_direction) { + case EdgePosition::TOP: + case EdgePosition::BOTTOM: + if (next_rectangle.x_min < point.x && point.x < next_rectangle.x_max) { + // Case 2 + result.next_box_index = candidate_index; + if (cur_edge == EdgePosition::LEFT) { + result.next_direction = EdgePosition::LEFT; + } else { + result.next_direction = EdgePosition::RIGHT; + } + return result; + } else if (next_rectangle.x_min == point.x && + cur_edge == EdgePosition::LEFT) { + // Case 3 + result.next_box_index = candidate_index; + result.next_direction = point.next_direction; + return result; + } else if (next_rectangle.x_max == point.x && + cur_edge == EdgePosition::RIGHT) { + // Case 3 + result.next_box_index = candidate_index; + result.next_direction = point.next_direction; + return result; + } + break; + case EdgePosition::LEFT: + case EdgePosition::RIGHT: + if (next_rectangle.y_min < point.y && point.y < next_rectangle.y_max) { + 
result.next_box_index = candidate_index; + if (cur_edge == EdgePosition::TOP) { + result.next_direction = EdgePosition::TOP; + } else { + result.next_direction = EdgePosition::BOTTOM; + } + return result; + } else if (next_rectangle.y_max == point.y && + cur_edge == EdgePosition::TOP) { + result.next_box_index = candidate_index; + result.next_direction = point.next_direction; + return result; + } else if (next_rectangle.y_min == point.y && + cur_edge == EdgePosition::BOTTOM) { + result.next_box_index = candidate_index; + result.next_direction = point.next_direction; + return result; + } + break; + } + } + + // Now we must be in the case 4. + result.next_box_index = point.next_box_index; + switch (point.next_direction) { + case EdgePosition::TOP: + case EdgePosition::BOTTOM: + if (cur_edge == EdgePosition::LEFT) { + result.next_direction = EdgePosition::RIGHT; + } else { + result.next_direction = EdgePosition::LEFT; + } + break; + case EdgePosition::LEFT: + case EdgePosition::RIGHT: + if (cur_edge == EdgePosition::TOP) { + result.next_direction = EdgePosition::BOTTOM; + } else { + result.next_direction = EdgePosition::TOP; + } + break; + } + return result; +} + +// Returns a path delimiting a boundary of the union of a set of rectangles. It +// should work for both the exterior boundary and the boundaries of the holes +// inside the union. The path will start on `starting_point` and follow the +// boundary on clockwise order. +// +// `starting_point` should be a point in the boundary and `starting_box_index` +// the index of a rectangle with one edge containing `starting_point`. 
+// +// The resulting `path` satisfy: +// - path.step_points.front() == path.step_points.back() == starting_point +// - path.touching_box_index.front() == path.touching_box_index.back() == +// == starting_box_index +// +ShapePath TraceBoundary( + const std::pair& starting_step_point, + int starting_box_index, absl::Span rectangles, + const Neighbours& neighbours) { + // First find which direction we need to go to follow the border in the + // clockwise order. + const Rectangle& initial_rec = rectangles[starting_box_index]; + bool touching_edge[4]; + touching_edge[EdgePosition::LEFT] = + initial_rec.x_min == starting_step_point.first; + touching_edge[EdgePosition::RIGHT] = + initial_rec.x_max == starting_step_point.first; + touching_edge[EdgePosition::TOP] = + initial_rec.y_max == starting_step_point.second; + touching_edge[EdgePosition::BOTTOM] = + initial_rec.y_min == starting_step_point.second; + + EdgePosition next_direction; + if (touching_edge[EdgePosition::LEFT]) { + if (touching_edge[EdgePosition::TOP]) { + next_direction = EdgePosition::RIGHT; + } else { + next_direction = EdgePosition::TOP; + } + } else if (touching_edge[EdgePosition::RIGHT]) { + if (touching_edge[EdgePosition::BOTTOM]) { + next_direction = EdgePosition::LEFT; + } else { + next_direction = EdgePosition::BOTTOM; + } + } else if (touching_edge[EdgePosition::TOP]) { + next_direction = EdgePosition::LEFT; + } else if (touching_edge[EdgePosition::BOTTOM]) { + next_direction = EdgePosition::RIGHT; + } else { + LOG(FATAL) + << "TraceBoundary() got a `starting_step_point` that is not in an edge " + "of the rectangle of `starting_box_index`. 
This is not allowed.";
+  const ContourPoint starting_point = {.x = starting_step_point.first,
+                                       .y = starting_step_point.second,
+                                       .next_box_index = starting_box_index,
+                                       .next_direction = next_direction};
+  ShapePath result;
+  for (ContourPoint point = starting_point; true;
+       point = NextByClockwiseOrder(point, rectangles, neighbours)) {
+    if (result.step_points.size() > 3 &&
+        result.step_points.back() == result.step_points.front() &&
+        point.x == result.step_points[1].first &&
+        point.y == result.step_points[1].second) {
+      break;
+    }
+    if (!result.step_points.empty() &&
+        point.x == result.step_points.back().first &&
+        point.y == result.step_points.back().second) {
+      // There is a special corner-case of the algorithm using the neighbours.
+      // Consider the following set-up:
+      //
+      // ******** |
+      // ******** |
+      // ******** +---->
+      // ########++++++++
+      // ########++++++++
+      // ########++++++++
+      //
+      // In this case, the only way the algorithm could reach the "+" box is via
+      // the "#" box, which doesn't contribute to the path. The algorithm
+      // returns a technically correct zero-size interval, which might be useful
+      // for callers that want to count the "#" box as visited, but this is not
+      // our case. 
+ result.touching_box_index.back() = point.next_box_index; + } else { + result.touching_box_index.push_back(point.next_box_index); + result.step_points.push_back({point.x, point.y}); + } + } + return result; +} + +std::string RenderShapes(std::optional bb, + absl::Span rectangles, + const std::vector& shapes) { + const std::vector colors = {"black", "white", "orange", + "cyan", "yellow", "purple"}; + std::stringstream ss; + ss << " edge[headclip=false, tailclip=false, penwidth=40];\n"; + int count = 0; + for (int i = 0; i < shapes.size(); ++i) { + std::string_view shape_color = colors[i % colors.size()]; + for (int j = 0; j < shapes[i].boundary.step_points.size(); ++j) { + std::pair p = + shapes[i].boundary.step_points[j]; + ss << " p" << count << "[pos=\"" << 2 * p.first << "," << 2 * p.second + << "!\" shape=point]\n"; + if (j != shapes[i].boundary.step_points.size() - 1) { + ss << " p" << count << "->p" << count + 1 << " [color=\"" + << shape_color << "\"];\n"; + } + ++count; + } + for (const ShapePath& hole : shapes[i].holes) { + for (int j = 0; j < hole.step_points.size(); ++j) { + std::pair p = hole.step_points[j]; + ss << " p" << count << "[pos=\"" << 2 * p.first << "," << 2 * p.second + << "!\" shape=point]\n"; + if (j != hole.step_points.size() - 1) { + ss << " p" << count << "->p" << count + 1 << " [color=\"" + << shape_color << "\", penwidth=20];\n"; + } + ++count; + } + } + } + return RenderDot(bb, rectangles, ss.str()); +} + +TEST(ContourTest, Random) { + constexpr int kNumRuns = 100; + absl::BitGen bit_gen; + + for (int run = 0; run < kNumRuns; ++run) { + // Start by generating a feasible problem that we know the solution with + // some items fixed. 
+ std::vector input = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, bit_gen); + std::shuffle(input.begin(), input.end(), bit_gen); + const int num_fixed_rectangles = input.size() * 2 / 3; + const absl::Span fixed_rectangles = + absl::MakeConstSpan(input).subspan(0, num_fixed_rectangles); + const absl::Span other_rectangles = + absl::MakeSpan(input).subspan(num_fixed_rectangles); + const std::vector input_in_range = + MakeItemsFromRectangles(other_rectangles, 0.6, bit_gen); + + const Neighbours neighbours = BuildNeighboursGraph(fixed_rectangles); + const auto components = SplitInConnectedComponents(neighbours); + const Rectangle bb = {.x_min = 0, .x_max = 100, .y_min = 0, .y_max = 100}; + int min_index = -1; + std::pair min_coord = { + std::numeric_limits::max(), + std::numeric_limits::max()}; + for (const int box_index : components[0]) { + const Rectangle& rectangle = fixed_rectangles[box_index]; + if (std::make_pair(rectangle.x_min, rectangle.y_min) < min_coord) { + min_coord = {rectangle.x_min, rectangle.y_min}; + min_index = box_index; + } + } + + const std::vector shapes = + BoxesToShapes(fixed_rectangles, neighbours); + for (const SingleShape& shape : shapes) { + const ShapePath& boundary = shape.boundary; + const ShapePath expected_shape = + TraceBoundary(boundary.step_points[0], boundary.touching_box_index[0], + fixed_rectangles, neighbours); + if (boundary.step_points != expected_shape.step_points) { + LOG(ERROR) << "Fast algo:\n" + << RenderContour(bb, fixed_rectangles, boundary); + LOG(ERROR) << "Naive algo:\n" + << RenderContour(bb, fixed_rectangles, expected_shape); + LOG(FATAL) << "Found different solutions between naive and fast algo!"; + } + EXPECT_EQ(boundary.step_points, expected_shape.step_points); + EXPECT_EQ(boundary.touching_box_index, expected_shape.touching_box_index); + } + + if (run == 0) { + LOG(INFO) << RenderShapes(bb, fixed_rectangles, shapes); + } + } +} + +TEST(ContourTest, SimpleShapes) { + std::vector rectangles = { 
+ {.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}, + {.x_min = 3, .x_max = 8, .y_min = 0, .y_max = 10}}; + ShapePath shape = + TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 0, 1, 1, 1, 0, 0, 0)); + EXPECT_THAT(shape.step_points, + ElementsAre(std::make_pair(0, 20), std::make_pair(10, 20), + std::make_pair(10, 10), std::make_pair(8, 10), + std::make_pair(8, 0), std::make_pair(3, 0), + std::make_pair(3, 10), std::make_pair(0, 10), + std::make_pair(0, 20))); + + rectangles = {{.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}, + {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}}; + shape = + TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 1, 1, 1, 0, 0)); + EXPECT_THAT(shape.step_points, + ElementsAre(std::make_pair(0, 20), std::make_pair(10, 20), + std::make_pair(10, 10), std::make_pair(10, 0), + std::make_pair(0, 0), std::make_pair(0, 10), + std::make_pair(0, 20))); + + rectangles = {{.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}, + {.x_min = 0, .x_max = 15, .y_min = 0, .y_max = 10}}; + shape = + TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 1, 1, 1, 1, 0, 0)); + EXPECT_THAT(shape.step_points, + ElementsAre(std::make_pair(0, 20), std::make_pair(10, 20), + std::make_pair(10, 10), std::make_pair(15, 10), + std::make_pair(15, 0), std::make_pair(0, 0), + std::make_pair(0, 10), std::make_pair(0, 20))); + + rectangles = {{.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}, + {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}, + {.x_min = 10, .x_max = 20, .y_min = 0, .y_max = 10}}; + shape = + TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 2, 2, 2, 1, 1, 0, 0)); + EXPECT_THAT(shape.step_points, + ElementsAre(std::make_pair(0, 20), 
std::make_pair(10, 20), + std::make_pair(10, 10), std::make_pair(20, 10), + std::make_pair(20, 0), std::make_pair(10, 0), + std::make_pair(0, 0), std::make_pair(0, 10), + std::make_pair(0, 20))); +} + +TEST(ContourTest, ExampleFromPaper) { + const std::vector input = BuildFromAsciiArt(R"( + ******************* + ******************* + ********** ******************* + ********** ******************* + *************************************** + *************************************** + *************************************** + *************************************** + *********** ************** **** + *********** ************** **** + *********** ******* *** **** + *********** ******* *** **** + *********** ************** **** + *********** ************** **** + *********** ************** **** + *************************************** + *************************************** + *************************************** + ************************************** + ************************************** + ************************************** + ******************************* + *************************************** + *************************************** + **************** **************** + **************** **************** + ****** *** + ****** *** + ****** *** + ****** + )"); + const Neighbours neighbours = BuildNeighboursGraph(input); + auto s = BoxesToShapes(input, neighbours); + LOG(INFO) << RenderDot(std::nullopt, input); + const std::vector output = CutShapeIntoRectangles(s[0]); + LOG(INFO) << RenderDot(std::nullopt, output); + EXPECT_THAT(output.size(), 16); +} + +bool RectanglesCoverSameArea(absl::Span a, + absl::Span b) { + return RegionIncludesOther(a, b) && RegionIncludesOther(b, a); +} + +TEST(ReduceNumberOfBoxes, RandomTestNoOptional) { + constexpr int kNumRuns = 1000; + absl::BitGen bit_gen; + + for (int run = 0; run < kNumRuns; ++run) { + // Start by generating a feasible problem that we know the solution with + // some items fixed. 
+ std::vector input = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, bit_gen); + std::shuffle(input.begin(), input.end(), bit_gen); + + std::vector output = input; + std::vector optional_rectangles_empty; + ReduceNumberOfBoxesExactMandatory(&output, &optional_rectangles_empty); + if (run == 0) { + LOG(INFO) << "Presolved:\n" << RenderDot(std::nullopt, input); + LOG(INFO) << "To:\n" << RenderDot(std::nullopt, output); + } + + if (output.size() > input.size()) { + LOG(INFO) << "Presolved:\n" << RenderDot(std::nullopt, input); + LOG(INFO) << "To:\n" << RenderDot(std::nullopt, output); + ADD_FAILURE() << "ReduceNumberofBoxes() increased the number of boxes, " + "but it should be optimal in reducing them!"; + } + CHECK(RectanglesCoverSameArea(output, input)); + } +} + +TEST(ReduceNumberOfBoxes, Problematic) { + // This example shows that we must consider diagonals that touches only on its + // extremities as "intersecting" for the bipartite graph. + const std::vector input = { + {.x_min = 26, .x_max = 51, .y_min = 54, .y_max = 81}, + {.x_min = 51, .x_max = 78, .y_min = 44, .y_max = 67}, + {.x_min = 51, .x_max = 62, .y_min = 67, .y_max = 92}, + {.x_min = 78, .x_max = 98, .y_min = 24, .y_max = 54}, + }; + std::vector output = input; + std::vector optional_rectangles_empty; + ReduceNumberOfBoxesExactMandatory(&output, &optional_rectangles_empty); + LOG(INFO) << "Presolved:\n" << RenderDot(std::nullopt, input); + LOG(INFO) << "To:\n" << RenderDot(std::nullopt, output); +} + +// This example shows that sometimes the best solution with respect to minimum +// number of boxes requires *not* filling a hole. Actually this follows from the +// formula that the minimum number of rectangles in a partition of a polygon +// with n vertices and h holes is n/2 + h − g − 1, where g is the number of +// non-intersecting good diagonals. This test-case shows a polygon with 4 +// internal vertices, 1 hole and 4 non-intersecting good diagonals that includes +// the hole. 
Removing the hole reduces the n/2 term by 2, decrease the h term by +// 1, but decrease the g term by 4. +// +// *********************** +// *********************** +// ***********************..................... +// ***********************..................... +// ***********************..................... +// ***********************..................... +// ***********************..................... +// ++++++++++++++++++++++ ..................... +// ++++++++++++++++++++++ ..................... +// ++++++++++++++++++++++ ..................... +// ++++++++++++++++++++++000000000000000000000000 +// ++++++++++++++++++++++000000000000000000000000 +// ++++++++++++++++++++++000000000000000000000000 +// 000000000000000000000000 +// 000000000000000000000000 +// 000000000000000000000000 +// 000000000000000000000000 +// +TEST(ReduceNumberOfBoxes, Problematic2) { + const std::vector input = { + {.x_min = 64, .x_max = 82, .y_min = 76, .y_max = 98}, + {.x_min = 39, .x_max = 59, .y_min = 63, .y_max = 82}, + {.x_min = 59, .x_max = 78, .y_min = 61, .y_max = 76}, + {.x_min = 44, .x_max = 64, .y_min = 82, .y_max = 100}, + }; + std::vector output = input; + std::vector optional_rectangles = { + {.x_min = 59, .x_max = 64, .y_min = 76, .y_max = 82}, + }; + ReduceNumberOfBoxesExactMandatory(&output, &optional_rectangles); + LOG(INFO) << "Presolving:\n" << RenderDot(std::nullopt, input); + + // Presolve will refuse to do anything since removing the hole will increase + // the number of boxes. + CHECK(input == output); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/2d_try_edge_propagator.cc b/ortools/sat/2d_try_edge_propagator.cc new file mode 100644 index 00000000000..0d847090b8b --- /dev/null +++ b/ortools/sat/2d_try_edge_propagator.cc @@ -0,0 +1,307 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/2d_try_edge_propagator.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "ortools/base/logging.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/model.h" +#include "ortools/sat/synchronization.h" + +namespace operations_research { +namespace sat { + +int TryEdgeRectanglePropagator::RegisterWith(GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + x_.WatchAllTasks(id); + y_.WatchAllTasks(id); + watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; +} + +TryEdgeRectanglePropagator::~TryEdgeRectanglePropagator() { + if (!VLOG_IS_ON(1)) return; + std::vector> stats; + stats.push_back({"TryEdgeRectanglePropagator/called", num_calls_}); + stats.push_back({"TryEdgeRectanglePropagator/conflicts", num_conflicts_}); + stats.push_back( + {"TryEdgeRectanglePropagator/propagations", num_propagations_}); + stats.push_back({"TryEdgeRectanglePropagator/cache_hits", num_cache_hits_}); + stats.push_back( + {"TryEdgeRectanglePropagator/cache_misses", num_cache_misses_}); + + shared_stats_->AddStats(stats); +} + +void TryEdgeRectanglePropagator::PopulateActiveBoxRanges() { + const int num_boxes = x_.NumTasks(); + active_box_ranges_.clear(); + active_box_ranges_.reserve(num_boxes); + for (int box = 0; box < num_boxes; ++box) { + if (x_.SizeMin(box) == 0 || y_.SizeMin(box) == 0) continue; + if (!x_.IsPresent(box) || !y_.IsPresent(box)) 
continue; + + active_box_ranges_.push_back(RectangleInRange{ + .box_index = box, + .bounding_area = {.x_min = x_.StartMin(box), + .x_max = x_.StartMax(box) + x_.SizeMin(box), + .y_min = y_.StartMin(box), + .y_max = y_.StartMax(box) + y_.SizeMin(box)}, + .x_size = x_.SizeMin(box), + .y_size = y_.SizeMin(box)}); + } + max_box_index_ = num_boxes - 1; +} + +bool TryEdgeRectanglePropagator::CanPlace( + int box_index, + const std::pair& position) const { + const Rectangle placed_box = { + .x_min = position.first, + .x_max = position.first + active_box_ranges_[box_index].x_size, + .y_min = position.second, + .y_max = position.second + active_box_ranges_[box_index].y_size}; + for (int i = 0; i < active_box_ranges_.size(); ++i) { + if (i == box_index) continue; + const RectangleInRange& box_reason = active_box_ranges_[i]; + const Rectangle mandatory_region = box_reason.GetMandatoryRegion(); + if (mandatory_region != Rectangle::GetEmpty() && + !mandatory_region.IsDisjoint(placed_box)) { + return false; + } + } + return true; +} + +bool TryEdgeRectanglePropagator::Propagate() { + if (!x_.SynchronizeAndSetTimeDirection(x_is_forward_)) return false; + if (!y_.SynchronizeAndSetTimeDirection(y_is_forward_)) return false; + + num_calls_++; + + PopulateActiveBoxRanges(); + + if (cached_y_hint_.size() <= max_box_index_) { + cached_y_hint_.resize(max_box_index_ + 1, + std::numeric_limits::max()); + } + + if (active_box_ranges_.size() < 2) { + return true; + } + + // Our algo is quadratic, so we don't want to run it on really large problems. 
+  if (active_box_ranges_.size() > 1000) {
+    return true;
+  }
+
+  potential_x_positions_.clear();
+  potential_y_positions_.clear();
+  std::vector<std::pair<int, std::optional<IntegerValue>>> found_propagations;
+  for (const RectangleInRange& box : active_box_ranges_) {
+    const Rectangle mandatory_region = box.GetMandatoryRegion();
+    if (mandatory_region == Rectangle::GetEmpty()) {
+      continue;
+    }
+    potential_x_positions_.push_back(mandatory_region.x_max);
+    potential_y_positions_.push_back(mandatory_region.y_max);
+  }
+  std::sort(potential_x_positions_.begin(), potential_x_positions_.end());
+  std::sort(potential_y_positions_.begin(), potential_y_positions_.end());
+
+  for (int i = 0; i < active_box_ranges_.size(); ++i) {
+    const RectangleInRange& box = active_box_ranges_[i];
+
+    // For each box, we need to answer whether there exists some y for which
+    // (x_min, y) is not in conflict with any other box. If there is no such y,
+    // we can propagate a larger lower bound on x. Now, for the vast majority of
+    // cases there is nothing to propagate, so we want to find the y that makes
+    // (x_min, y) a valid placement as fast as possible. Since things don't
+    // change that often, we try the last y value that was a valid placement for
+    // this box. This is just a hint: if it is not a valid placement, we will
+    // try all "interesting" y values before concluding that no such y exists. 
+ const IntegerValue cached_y_hint = cached_y_hint_[box.box_index]; + if (cached_y_hint >= box.bounding_area.y_min && + cached_y_hint <= box.bounding_area.y_max - box.y_size) { + if (CanPlace(i, {box.bounding_area.x_min, cached_y_hint})) { + num_cache_hits_++; + continue; + } + } + num_cache_misses_++; + if (CanPlace(i, {box.bounding_area.x_min, box.bounding_area.y_min})) { + cached_y_hint_[box.box_index] = box.bounding_area.y_min; + continue; + } + + bool placed_at_x_min = false; + const int y_start = + absl::c_lower_bound(potential_y_positions_, box.bounding_area.y_min) - + potential_y_positions_.begin(); + for (int j = y_start; j < potential_y_positions_.size(); ++j) { + if (potential_y_positions_[j] > box.bounding_area.y_max - box.y_size) { + // potential_y_positions is sorted, so we can stop here. + break; + } + if (CanPlace(i, {box.bounding_area.x_min, potential_y_positions_[j]})) { + placed_at_x_min = true; + cached_y_hint_[box.box_index] = potential_y_positions_[j]; + break; + } + } + if (placed_at_x_min) continue; + + // We could not find any placement of the box at its current lower bound! + // Thus, we are sure we have something to propagate. Let's find the new + // lower bound (or a conflict). Note that the code below is much less + // performance critical than the code above, since it only triggers on + // propagations. 
+ std::optional new_x_min = std::nullopt; + for (int j = 0; j < potential_x_positions_.size(); ++j) { + if (potential_x_positions_[j] < box.bounding_area.x_min) { + continue; + } + if (potential_x_positions_[j] > box.bounding_area.x_max - box.x_size) { + continue; + } + if (CanPlace(i, {potential_x_positions_[j], box.bounding_area.y_min})) { + new_x_min = potential_x_positions_[j]; + break; + } + for (int k = y_start; k < potential_y_positions_.size(); ++k) { + const IntegerValue potential_y_position = potential_y_positions_[k]; + if (potential_y_position > box.bounding_area.y_max - box.y_size) { + break; + } + if (CanPlace(i, {potential_x_positions_[j], potential_y_position})) { + // potential_x_positions is sorted, so the first we found is the + // lowest one. + new_x_min = potential_x_positions_[j]; + break; + } + } + if (new_x_min.has_value()) { + break; + } + } + found_propagations.push_back({i, new_x_min}); + } + return ExplainAndPropagate(found_propagations); +} + +bool TryEdgeRectanglePropagator::ExplainAndPropagate( + const std::vector>>& + found_propagations) { + for (const auto& [box_index, new_x_min] : found_propagations) { + const RectangleInRange& box = active_box_ranges_[box_index]; + x_.ClearReason(); + y_.ClearReason(); + for (int j = 0; j < active_box_ranges_.size(); ++j) { + // Important: we also add to the reason the actual box we are changing the + // x_min. This is important, since we don't check if there are any + // feasible placement before its current x_min, so it needs to be part of + // the reason. + const RectangleInRange& box_reason = active_box_ranges_[j]; + if (j != box_index) { + const Rectangle mandatory_region = box_reason.GetMandatoryRegion(); + if (mandatory_region == Rectangle::GetEmpty()) { + continue; + } + // Don't add to the reason any box that was not participating in the + // placement decision. Ie., anything before the old x_min or after the + // new x_max. 
+ if (new_x_min.has_value() && + mandatory_region.x_min > *new_x_min + box_reason.x_size) { + continue; + } + if (new_x_min.has_value() && + mandatory_region.x_max < box.bounding_area.x_min) { + continue; + } + if (mandatory_region.y_min > box.bounding_area.y_max || + mandatory_region.y_max < box.bounding_area.y_min) { + continue; + } + } + + const int b = box_reason.box_index; + + x_.AddStartMinReason(b, box_reason.bounding_area.x_min); + y_.AddStartMinReason(b, box_reason.bounding_area.y_min); + + x_.AddStartMaxReason(b, + box_reason.bounding_area.x_max - box_reason.x_size); + y_.AddStartMaxReason(b, + box_reason.bounding_area.y_max - box_reason.y_size); + + x_.AddSizeMinReason(b); + y_.AddSizeMinReason(b); + + x_.AddPresenceReason(b); + y_.AddPresenceReason(b); + } + x_.ImportOtherReasons(y_); + if (new_x_min.has_value()) { + num_propagations_++; + if (!x_.IncreaseStartMin(box.box_index, *new_x_min)) { + return false; + } + } else { + num_conflicts_++; + return x_.ReportConflict(); + } + } + return true; +} + +void CreateAndRegisterTryEdgePropagator(SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y, + Model* model, + GenericLiteralWatcher* watcher) { + TryEdgeRectanglePropagator* try_edge_propagator = + new TryEdgeRectanglePropagator(true, true, x, y, model); + watcher->SetPropagatorPriority(try_edge_propagator->RegisterWith(watcher), 5); + model->TakeOwnership(try_edge_propagator); + + TryEdgeRectanglePropagator* try_edge_propagator_mirrored = + new TryEdgeRectanglePropagator(false, true, x, y, model); + watcher->SetPropagatorPriority( + try_edge_propagator_mirrored->RegisterWith(watcher), 5); + model->TakeOwnership(try_edge_propagator_mirrored); + + TryEdgeRectanglePropagator* try_edge_propagator_swap = + new TryEdgeRectanglePropagator(true, true, y, x, model); + watcher->SetPropagatorPriority( + try_edge_propagator_swap->RegisterWith(watcher), 5); + model->TakeOwnership(try_edge_propagator_swap); + + TryEdgeRectanglePropagator* 
try_edge_propagator_swap_mirrored = + new TryEdgeRectanglePropagator(false, true, y, x, model); + watcher->SetPropagatorPriority( + try_edge_propagator_swap_mirrored->RegisterWith(watcher), 5); + model->TakeOwnership(try_edge_propagator_swap_mirrored); +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/2d_try_edge_propagator.h b/ortools/sat/2d_try_edge_propagator.h new file mode 100644 index 00000000000..526ab040b5a --- /dev/null +++ b/ortools/sat/2d_try_edge_propagator.h @@ -0,0 +1,96 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_2D_TRY_EDGE_PROPAGATOR_H_ +#define OR_TOOLS_SAT_2D_TRY_EDGE_PROPAGATOR_H_ + +#include <cstdint> +#include <optional> +#include <utility> +#include <vector> + +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/model.h" +#include "ortools/sat/synchronization.h" + +namespace operations_research { +namespace sat { + +// Propagator that for each box participating in a no_overlap_2d constraint +// tries to find the leftmost valid position that is compatible with all the +// other boxes. If none is found, it will propagate a conflict. Otherwise, if +// it is different from the current x_min, it will propagate the new x_min. +void CreateAndRegisterTryEdgePropagator(SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y, + Model* model, + GenericLiteralWatcher* watcher); + +// Exposed for testing.
+class TryEdgeRectanglePropagator : public PropagatorInterface { + public: + TryEdgeRectanglePropagator(bool x_is_forward, bool y_is_forward, + SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y, Model* model) + : x_(*x), + y_(*y), + shared_stats_(model->GetOrCreate<SharedStatistics>()), + x_is_forward_(x_is_forward), + y_is_forward_(y_is_forward) {} + + ~TryEdgeRectanglePropagator() override; + + bool Propagate() final; + int RegisterWith(GenericLiteralWatcher* watcher); + + protected: + std::vector<RectangleInRange> active_box_ranges_; + int max_box_index_ = 0; + + // Must also populate max_box_index_. + virtual void PopulateActiveBoxRanges(); + + virtual bool ExplainAndPropagate( + const std::vector<std::pair<int, std::optional<IntegerValue>>>& + found_propagations); + + private: + SchedulingConstraintHelper& x_; + SchedulingConstraintHelper& y_; + SharedStatistics* shared_stats_; + bool x_is_forward_; + bool y_is_forward_; + std::vector<IntegerValue> cached_y_hint_;  // NOTE(review): element type reconstructed — confirm. + + std::vector<IntegerValue> potential_x_positions_; + std::vector<IntegerValue> potential_y_positions_; + + int64_t num_conflicts_ = 0; + int64_t num_propagations_ = 0; + int64_t num_calls_ = 0; + int64_t num_cache_hits_ = 0; + int64_t num_cache_misses_ = 0; + + bool CanPlace(int box_index, + const std::pair<IntegerValue, IntegerValue>& position) const; + + TryEdgeRectanglePropagator(const TryEdgeRectanglePropagator&) = delete; + TryEdgeRectanglePropagator& operator=(const TryEdgeRectanglePropagator&) = + delete; +}; + +} // namespace sat +} // namespace operations_research + +#endif  // OR_TOOLS_SAT_2D_TRY_EDGE_PROPAGATOR_H_ diff --git a/ortools/sat/2d_try_edge_propagator_test.cc b/ortools/sat/2d_try_edge_propagator_test.cc new file mode 100644 index 00000000000..f200431553c --- /dev/null +++ b/ortools/sat/2d_try_edge_propagator_test.cc @@ -0,0 +1,152 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/2d_try_edge_propagator.h" + +#include +#include +#include +#include + +#include "absl/random/random.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/2d_orthogonal_packing_testing.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/model.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::_; +using ::testing::Each; +using ::testing::Eq; +using ::testing::Not; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class TryEdgeRectanglePropagatorForTest : public TryEdgeRectanglePropagator { + public: + explicit TryEdgeRectanglePropagatorForTest( + Model* model, std::vector active_box_ranges) + : TryEdgeRectanglePropagator(true, true, GetHelperFromModel(model), + GetHelperFromModel(model), model) { + active_box_ranges_ = std::move(active_box_ranges); + } + + void PopulateActiveBoxRanges() override { + max_box_index_ = 0; + for (const RectangleInRange& range : active_box_ranges_) { + if (range.box_index > max_box_index_) { + max_box_index_ = range.box_index; + } + } + } + + bool ExplainAndPropagate( + const std::vector>>& + found_propagations) override { + propagations_ = found_propagations; + return false; + } + + const std::vector>>& propagations() + const { + return propagations_; + } + + private: + static SchedulingConstraintHelper* GetHelperFromModel(Model* model) { + return model->GetOrCreate()->GetOrCreateHelper({}); + } + + Model model_; + 
IntervalsRepository* repository_ = model_.GetOrCreate(); + + std::vector>> propagations_; +}; + +TEST(TryEdgeRectanglePropagatorTest, Simple) { + // ********** + // ********** To place: + // ********** ++++++++ + // ********** ++++++++ + // ++++++++ + // ++++++++++ ++++++++ + // ++++++++++ + // ++++++++++ + // ++++++++++ + // + // The object to place can only be on the right of the two placed ones. + std::vector active_box_ranges = { + {.box_index = 0, + .bounding_area = {.x_min = 0, .x_max = 5, .y_min = 0, .y_max = 5}, + .x_size = 5, + .y_size = 5}, + {.box_index = 1, + .bounding_area = {.x_min = 0, .x_max = 5, .y_min = 6, .y_max = 11}, + .x_size = 5, + .y_size = 5}, + {.box_index = 2, + .bounding_area = {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}, + .x_size = 5, + .y_size = 5}, + }; + Model model; + TryEdgeRectanglePropagatorForTest propagator(&model, active_box_ranges); + propagator.Propagate(); + EXPECT_THAT(propagator.propagations(), + UnorderedElementsAre(Pair(2, IntegerValue(5)))); + + // Now the same thing, but makes it a conflict + active_box_ranges[2].bounding_area.x_min = 0; + active_box_ranges[2].bounding_area.x_max = 5; + TryEdgeRectanglePropagatorForTest propagator2(&model, active_box_ranges); + propagator2.Propagate(); + EXPECT_THAT(propagator2.propagations(), + UnorderedElementsAre(Pair(2, std::nullopt))); +} + +TEST(TryEdgeRectanglePropagatorTest, NoConflictForFeasible) { + constexpr int kNumRuns = 100; + absl::BitGen bit_gen; + Model model; + + for (int run = 0; run < kNumRuns; ++run) { + // Start by generating a feasible problem that we know the solution with + // some items fixed. 
+ std::vector rectangles = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, bit_gen); + std::shuffle(rectangles.begin(), rectangles.end(), bit_gen); + const std::vector input_in_range = + MakeItemsFromRectangles(rectangles, 0.6, bit_gen); + + TryEdgeRectanglePropagatorForTest propagator(&model, input_in_range); + propagator.Propagate(); + EXPECT_THAT(propagator.propagations(), + Each(Pair(_, Not(Eq(std::nullopt))))); + + // Now check that the propagations are not in conflict with the initial + // solution. + for (const auto& [box_index, new_x_min] : propagator.propagations()) { + EXPECT_LE(*new_x_min, rectangles[box_index].x_min); + } + } +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index a6781b4ccdf..c21059e05e9 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -13,8 +13,8 @@ # Home of CP/SAT solver (which includes SAT, max-SAT and PB problems). -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library") load("@rules_java//java:defs.bzl", "java_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") load("@rules_python//python:proto.bzl", "py_proto_library") @@ -52,6 +52,16 @@ cc_library( ], ) +cc_test( + name = "model_test", + size = "small", + srcs = ["model_test.cc"], + deps = [ + ":model", + "//ortools/base:gmock_main", + ], +) + proto_library( name = "sat_parameters_proto", srcs = ["sat_parameters.proto"], @@ -62,12 +72,6 @@ cc_proto_library( deps = [":sat_parameters_proto"], ) -go_proto_library( - name = "sat_parameters_go_proto", - proto = ":sat_parameters_proto", - importpath = "github.com/google/or-tools/ortools/sat/proto/satparameters" -) - py_proto_library( name = "sat_parameters_py_pb2", deps = [":sat_parameters_proto"], @@ -78,6 +82,12 @@ java_proto_library( deps = 
[":sat_parameters_proto"], ) +go_proto_library( + name = "sat_parameters_go_proto", + importpath = "github.com/google/or-tools/ortools/sat/proto/satparameters", + protos = [":sat_parameters_proto"], +) + proto_library( name = "cp_model_proto", srcs = ["cp_model.proto"], @@ -88,12 +98,6 @@ cc_proto_library( deps = [":cp_model_proto"], ) -go_proto_library( - name = "cp_model_go_proto", - importpath = "github.com/google/or-tools/ortools/sat/proto/cpmodel", - proto = ":cp_model_proto", -) - py_proto_library( name = "cp_model_py_pb2", deps = [":cp_model_proto"], @@ -104,6 +108,12 @@ java_proto_library( deps = [":cp_model_proto"], ) +go_proto_library( + name = "cp_model_go_proto", + importpath = "github.com/google/or-tools/ortools/sat/proto/cpmodel", + protos = [":cp_model_proto"], +) + cc_library( name = "cp_model_utils", srcs = ["cp_model_utils.cc"], @@ -252,6 +262,15 @@ cc_library( ], ) +cc_test( + name = "feasibility_jump_test", + srcs = ["feasibility_jump_test.cc"], + deps = [ + ":feasibility_jump", + "//ortools/base:gmock_main", + ], +) + cc_library( name = "linear_model", srcs = ["linear_model.cc"], @@ -278,6 +297,18 @@ cc_library( ], ) +cc_test( + name = "parameters_validation_test", + size = "small", + srcs = ["parameters_validation_test.cc"], + deps = [ + ":parameters_validation", + ":sat_parameters_cc_proto", + "//ortools/base:gmock_main", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "cp_model_search", srcs = ["cp_model_search.cc"], @@ -626,7 +657,6 @@ cc_library( hdrs = ["presolve_context.h"], deps = [ ":cp_model_cc_proto", - ":cp_model_checker", ":cp_model_loader", ":cp_model_mapping", ":cp_model_utils", @@ -637,6 +667,7 @@ cc_library( ":sat_parameters_cc_proto", ":sat_solver", ":util", + "//ortools/algorithms:sparse_permutation", "//ortools/base", "//ortools/base:mathutil", "//ortools/port:proto_utils", @@ -756,10 +787,8 @@ cc_library( ":util", "//ortools/base", "//ortools/base:stl_util", - "//ortools/base:types", 
"//ortools/port:proto_utils", "//ortools/util:logging", - "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -768,6 +797,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) @@ -790,6 +820,17 @@ cc_library( ], ) +cc_test( + name = "sat_base_test", + size = "small", + srcs = ["sat_base_test.cc"], + deps = [ + ":sat_base", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + ], +) + # Enable a warning to check for floating point to integer conversions. # In GCC-4.8, this was "-Wreal-conversion", but was removed in 4.9 # In Clang, this warning is "-Wfloat-conversion" @@ -824,6 +865,7 @@ cc_library( "//ortools/util:stats", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -853,6 +895,19 @@ cc_library( ], ) +cc_test( + name = "restart_test", + size = "small", + srcs = ["restart_test.cc"], + deps = [ + ":model", + ":restart", + ":sat_parameters_cc_proto", + "//ortools/base:gmock_main", + "@com_google_absl//absl/base:core_headers", + ], +) + cc_library( name = "probing", srcs = ["probing.cc"], @@ -883,6 +938,22 @@ cc_library( ], ) +cc_test( + name = "probing_test", + size = "small", + srcs = ["probing_test.cc"], + deps = [ + ":integer", + ":model", + ":probing", + ":sat_base", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/util:sorted_interval_list", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "sat_inprocessing", srcs = ["sat_inprocessing.cc"], @@ -916,6 +987,23 @@ cc_library( ], ) +cc_test( + name = "sat_inprocessing_test", + size = "small", + srcs 
= ["sat_inprocessing_test.cc"], + deps = [ + ":clause", + ":model", + ":sat_base", + ":sat_inprocessing", + ":sat_solver", + "//ortools/base:gmock_main", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "sat_decision", srcs = ["sat_decision.cc"], @@ -951,6 +1039,7 @@ cc_library( "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/base:timer", + "//ortools/graph:cliques", "//ortools/graph:strongly_connected_components", "//ortools/util:bitset", "//ortools/util:stats", @@ -1021,6 +1110,23 @@ cc_library( ], ) +cc_test( + name = "pb_constraint_test", + size = "small", + srcs = ["pb_constraint_test.cc"], + deps = [ + ":model", + ":pb_constraint", + ":sat_base", + "//ortools/base:gmock_main", + "//ortools/base:strong_vector", + "//ortools/util:strong_integers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "symmetry", srcs = ["symmetry.cc"], @@ -1036,6 +1142,19 @@ cc_library( ], ) +cc_test( + name = "symmetry_test", + size = "small", + srcs = ["symmetry_test.cc"], + deps = [ + ":sat_base", + ":symmetry", + "//ortools/algorithms:sparse_permutation", + "//ortools/base:gmock_main", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "symmetry_util", srcs = ["symmetry_util.cc"], @@ -1044,6 +1163,7 @@ cc_library( "//ortools/algorithms:dynamic_partition", "//ortools/algorithms:sparse_permutation", "//ortools/base", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", ], @@ -1057,6 +1177,7 @@ cc_test( ":symmetry_util", "//ortools/algorithms:sparse_permutation", "//ortools/base:gmock_main", + "@com_google_absl//absl/types:span", ], ) @@ -1207,6 +1328,22 @@ cc_library( ], ) +cc_test( + name = "pseudo_costs_test", + size = "small", + srcs = 
["pseudo_costs_test.cc"], + deps = [ + ":integer", + ":model", + ":pseudo_costs", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + ], +) + cc_library( name = "intervals", srcs = ["intervals.cc"], @@ -1236,6 +1373,22 @@ cc_library( ], ) +cc_test( + name = "intervals_test", + size = "small", + srcs = ["intervals_test.cc"], + deps = [ + ":integer", + ":intervals", + ":linear_constraint", + ":model", + ":sat_base", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + ], +) + cc_library( name = "precedences", srcs = ["precedences.cc"], @@ -1271,6 +1424,46 @@ cc_library( ], ) +cc_test( + name = "precedences_test", + size = "small", + srcs = ["precedences_test.cc"], + deps = [ + ":integer", + ":integer_search", + ":model", + ":precedences", + ":sat_base", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/base:types", + "//ortools/util:sorted_interval_list", + "//ortools/util:strong_integers", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "integer_test", + size = "small", + srcs = ["integer_test.cc"], + deps = [ + ":integer", + ":integer_search", + ":model", + ":sat_base", + ":sat_solver", + "//ortools/base", + "//ortools/base:gmock_main", + "//ortools/base:types", + "//ortools/util:sorted_interval_list", + "//ortools/util:strong_integers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", + ], +) + cc_library( name = "integer_expr", srcs = ["integer_expr.cc"], @@ -1330,6 +1523,23 @@ cc_library( ], ) +cc_test( + name = "linear_propagation_test", + size = "small", + srcs = ["linear_propagation_test.cc"], + deps = [ + ":integer", + ":linear_propagation", + ":model", + ":sat_base", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name 
= "all_different", srcs = ["all_different.cc"], @@ -1351,6 +1561,22 @@ cc_library( ], ) +cc_test( + name = "all_different_test", + srcs = ["all_different_test.cc"], + deps = [ + ":all_different", + ":integer", + ":integer_search", + ":model", + ":sat_solver", + "//ortools/base", + "//ortools/base:gmock_main", + "//ortools/base:types", + "//ortools/util:sorted_interval_list", + ], +) + cc_library( name = "theta_tree", srcs = ["theta_tree.cc"], @@ -1362,6 +1588,20 @@ cc_library( ], ) +cc_test( + name = "theta_tree_test", + size = "small", + srcs = ["theta_tree_test.cc"], + deps = [ + ":integer", + ":theta_tree", + "//ortools/base:gmock_main", + "//ortools/util:random_engine", + "//ortools/util:strong_integers", + "@com_google_benchmark//:benchmark", + ], +) + cc_library( name = "disjunctive", srcs = ["disjunctive.cc"], @@ -1388,32 +1628,80 @@ cc_library( ], ) -cc_library( - name = "timetable", - srcs = ["timetable.cc"], - hdrs = ["timetable.h"], +cc_test( + name = "disjunctive_test", + size = "small", + srcs = ["disjunctive_test.cc"], deps = [ + ":disjunctive", ":integer", + ":integer_search", ":intervals", ":model", + ":precedences", ":sat_base", - "//ortools/util:rev", + ":sat_solver", + "//ortools/base", + "//ortools/base:gmock_main", "//ortools/util:strong_integers", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) cc_library( - name = "timetable_edgefinding", - srcs = ["timetable_edgefinding.cc"], - hdrs = ["timetable_edgefinding.h"], + name = "timetable", + srcs = ["timetable.cc"], + hdrs = ["timetable.h"], deps = [ ":integer", ":intervals", ":model", ":sat_base", - "//ortools/base:iterator_adaptors", + "//ortools/util:rev", + "//ortools/util:strong_integers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "timetable_test", + size = "medium", + srcs = 
["timetable_test.cc"], + deps = [ + ":all_different", + ":cumulative", + ":integer", + ":integer_search", + ":intervals", + ":model", + ":precedences", + ":sat_base", + ":sat_solver", + ":timetable", + "//ortools/base", + "//ortools/base:gmock_main", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "timetable_edgefinding", + srcs = ["timetable_edgefinding.cc"], + hdrs = ["timetable_edgefinding.h"], + deps = [ + ":integer", + ":intervals", + ":model", + ":sat_base", + "//ortools/base:iterator_adaptors", "//ortools/util:strong_integers", "@com_google_absl//absl/log:check", ], @@ -1445,6 +1733,29 @@ cc_library( ], ) +cc_test( + name = "cumulative_test", + size = "large", + srcs = ["cumulative_test.cc"], + shard_count = 32, + deps = [ + ":cumulative", + ":integer", + ":integer_search", + ":intervals", + ":model", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "cumulative_energy", srcs = ["cumulative_energy.cc"], @@ -1467,6 +1778,37 @@ cc_library( ], ) +cc_test( + name = "cumulative_energy_test", + size = "medium", + srcs = ["cumulative_energy_test.cc"], + deps = [ + ":2d_orthogonal_packing_testing", + ":cp_model_solver", + ":cumulative", + ":cumulative_energy", + ":diffn_util", + ":integer", + ":integer_search", + ":intervals", + ":linear_constraint", + ":model", + ":precedences", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/random:distributions", + 
"@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "boolean_problem", srcs = ["boolean_problem.cc"], @@ -1564,6 +1906,21 @@ cc_library( ], ) +cc_test( + name = "linear_constraint_test", + srcs = ["linear_constraint_test.cc"], + deps = [ + ":integer", + ":linear_constraint", + ":model", + ":sat_base", + "//ortools/base:gmock_main", + "//ortools/base:strong_vector", + "//ortools/util:strong_integers", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "linear_programming_constraint", srcs = ["linear_programming_constraint.cc"], @@ -1640,6 +1997,24 @@ cc_library( ], ) +cc_test( + name = "linear_constraint_manager_test", + srcs = ["linear_constraint_manager_test.cc"], + deps = [ + ":integer", + ":linear_constraint", + ":linear_constraint_manager", + ":model", + ":sat_base", + ":sat_parameters_cc_proto", + "//ortools/base:gmock_main", + "//ortools/base:strong_vector", + "//ortools/glop:variables_info", + "//ortools/lp_data:base", + "//ortools/util:strong_integers", + ], +) + cc_library( name = "cuts", srcs = ["cuts.cc"], @@ -1673,6 +2048,29 @@ cc_library( ], ) +cc_test( + name = "cuts_test", + srcs = ["cuts_test.cc"], + deps = [ + ":cuts", + ":implied_bounds", + ":integer", + ":linear_constraint", + ":linear_constraint_manager", + ":model", + ":sat_base", + ":sat_parameters_cc_proto", + "//ortools/base:gmock_main", + "//ortools/base:strong_vector", + "//ortools/util:fp_utils", + "//ortools/util:sorted_interval_list", + "//ortools/util:strong_integers", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "routing_cuts", srcs = ["routing_cuts.cc"], @@ -1700,6 +2098,25 @@ cc_library( ], ) +cc_test( + name = "routing_cuts_test", + srcs = ["routing_cuts_test.cc"], + deps = [ + ":cuts", + ":integer", + ":linear_constraint", + ":linear_constraint_manager", + ":model", + ":routing_cuts", + ":sat_base", + 
"//ortools/base:gmock_main", + "//ortools/base:strong_vector", + "//ortools/graph:max_flow", + "//ortools/util:strong_integers", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "scheduling_cuts", srcs = ["scheduling_cuts.cc"], @@ -1730,6 +2147,30 @@ cc_library( ], ) +cc_test( + name = "scheduling_cuts_test", + srcs = ["scheduling_cuts_test.cc"], + deps = [ + ":cp_model", + ":cp_model_cc_proto", + ":cp_model_solver", + ":cuts", + ":integer", + ":intervals", + ":linear_constraint", + ":linear_constraint_manager", + ":model", + ":sat_base", + ":scheduling_cuts", + "//ortools/base:gmock_main", + "//ortools/base:strong_vector", + "//ortools/util:strong_integers", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/random", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "diffn_cuts", srcs = ["diffn_cuts.cc"], @@ -1776,6 +2217,17 @@ cc_library( ], ) +cc_test( + name = "zero_half_cuts_test", + srcs = ["zero_half_cuts_test.cc"], + deps = [ + ":integer", + ":zero_half_cuts", + "//ortools/base:gmock_main", + "//ortools/lp_data:base", + ], +) + cc_library( name = "lp_utils", srcs = ["lp_utils.cc"], @@ -1891,6 +2343,29 @@ cc_library( ], ) +cc_test( + name = "optimization_test", + srcs = ["optimization_test.cc"], + deps = [ + ":boolean_problem", + ":boolean_problem_cc_proto", + ":integer", + ":integer_search", + ":model", + ":optimization", + ":pb_constraint", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/util:strong_integers", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "util", srcs = ["util.cc"], @@ -1979,6 +2454,23 @@ cc_library( ], ) +cc_test( + name = "cp_constraints_test", + srcs = ["cp_constraints_test.cc"], + deps = [ + ":cp_constraints", + ":integer", + ":integer_search", + ":model", + 
":precedences", + ":sat_base", + ":sat_solver", + "//ortools/base", + "//ortools/base:gmock_main", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "diffn_util", srcs = ["diffn_util.cc"], @@ -2043,8 +2535,10 @@ cc_library( ":diffn_util", ":integer", "//ortools/base:stl_util", + "//ortools/graph:max_flow", "//ortools/graph:strongly_connected_components", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -2055,6 +2549,27 @@ cc_library( ], ) +cc_test( + name = "2d_rectangle_presolve_test", + srcs = ["2d_rectangle_presolve_test.cc"], + deps = [ + ":2d_orthogonal_packing_testing", + ":2d_rectangle_presolve", + ":diffn_util", + ":integer", + "//ortools/base:gmock_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "2d_orthogonal_packing_testing", testonly = 1, @@ -2070,12 +2585,66 @@ cc_library( ], ) +cc_library( + name = "2d_try_edge_propagator", + srcs = ["2d_try_edge_propagator.cc"], + hdrs = ["2d_try_edge_propagator.h"], + deps = [ + ":diffn_util", + ":integer", + ":intervals", + ":model", + ":synchronization", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + ], +) + +cc_test( + name = "2d_try_edge_propagator_test", + srcs = ["2d_try_edge_propagator_test.cc"], + deps = [ + ":2d_orthogonal_packing_testing", + ":2d_try_edge_propagator", + ":diffn_util", + ":integer", + ":intervals", + ":model", + "//ortools/base:gmock_main", + "@com_google_absl//absl/random", + ], +) + 
+cc_test( + name = "diffn_util_test", + size = "small", + srcs = ["diffn_util_test.cc"], + deps = [ + ":2d_orthogonal_packing_testing", + ":diffn_util", + ":integer", + "//ortools/base", + "//ortools/base:gmock_main", + "//ortools/graph:connected_components", + "//ortools/util:strong_integers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_benchmark//:benchmark", + ], +) + cc_library( name = "diffn", srcs = ["diffn.cc"], hdrs = ["diffn.h"], deps = [ ":2d_orthogonal_packing", + ":2d_try_edge_propagator", ":cumulative_energy", ":diffn_util", ":disjunctive", @@ -2099,6 +2668,28 @@ cc_library( ], ) +cc_test( + name = "diffn_test", + size = "small", + srcs = ["diffn_test.cc"], + deps = [ + ":cp_model", + ":cp_model_cc_proto", + ":cp_model_solver", + ":diffn", + ":integer", + ":integer_search", + ":intervals", + ":model", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base", + "//ortools/base:gmock_main", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "circuit", srcs = ["circuit.cc"], @@ -2121,6 +2712,23 @@ cc_library( ], ) +cc_test( + name = "circuit_test", + srcs = ["circuit_test.cc"], + deps = [ + ":circuit", + ":integer", + ":integer_search", + ":model", + ":sat_base", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/graph:strongly_connected_components", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "encoding", srcs = ["encoding.cc"], @@ -2141,6 +2749,19 @@ cc_library( ], ) +cc_test( + name = "encoding_test", + srcs = ["encoding_test.cc"], + deps = [ + ":encoding", + ":pb_constraint", + ":sat_base", + ":sat_solver", + "//ortools/base:gmock_main", + "@com_google_absl//absl/random:distributions", 
+ ], +) + cc_library( name = "cp_model_lns", srcs = ["cp_model_lns.cc"], @@ -2255,6 +2876,17 @@ cc_library( ], ) +cc_test( + name = "subsolver_test", + size = "small", + srcs = ["subsolver_test.cc"], + deps = [ + ":subsolver", + "//ortools/base:gmock_main", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "drat_proof_handler", srcs = ["drat_proof_handler.cc"], @@ -2407,6 +3039,7 @@ cc_library( ":cp_model_utils", ":model", ":sat_parameters_cc_proto", + ":util", "//ortools/util:logging", "//ortools/util:sorted_interval_list", "//ortools/util:time_limit", @@ -2444,6 +3077,27 @@ cc_library( ], ) +cc_test( + name = "implied_bounds_test", + size = "small", + srcs = ["implied_bounds_test.cc"], + deps = [ + ":implied_bounds", + ":integer", + ":linear_constraint", + ":model", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + "//ortools/base:gmock_main", + "//ortools/base:strong_vector", + "//ortools/lp_data:base", + "//ortools/util:strong_integers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "inclusion", hdrs = ["inclusion.h"], @@ -2497,3 +3151,16 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_test( + name = "inclusion_test", + size = "small", + srcs = ["inclusion_test.cc"], + deps = [ + ":inclusion", + ":util", + "//ortools/base:gmock_main", + "@com_google_absl//absl/random", + "@com_google_absl//absl/types:span", + ], +) diff --git a/ortools/sat/CMakeLists.txt b/ortools/sat/CMakeLists.txt index 99f2113edb4..c2a1d46fa2a 100644 --- a/ortools/sat/CMakeLists.txt +++ b/ortools/sat/CMakeLists.txt @@ -47,6 +47,7 @@ if(BUILD_TESTING) FILE_NAME ${FILE_NAME} DEPS + benchmark::benchmark GTest::gmock GTest::gtest_main ) diff --git a/ortools/sat/all_different_test.cc b/ortools/sat/all_different_test.cc new file mode 100644 index 00000000000..bdc91b60900 --- /dev/null +++ b/ortools/sat/all_different_test.cc @@ -0,0 +1,159 @@ +// Copyright 2010-2024 Google 
LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/all_different.h" + +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ortools/base/logging.h" +#include "ortools/base/types.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { +namespace { + +class AllDifferentTest : public ::testing::TestWithParam { + public: + std::function AllDifferent( + const std::vector& vars) { + return [=](Model* model) { + if (GetParam() == "binary") { + model->Add(AllDifferentBinary(vars)); + } else if (GetParam() == "ac") { + model->Add(AllDifferentBinary(vars)); + model->Add(AllDifferentAC(vars)); + } else if (GetParam() == "bounds") { + model->Add(AllDifferentOnBounds(vars)); + } else { + LOG(FATAL) << "Unknown implementation " << GetParam(); + } + }; + } +}; + +INSTANTIATE_TEST_SUITE_P(All, AllDifferentTest, + ::testing::Values("binary", "ac", "bounds")); + +TEST_P(AllDifferentTest, BasicBehavior) { + Model model; + std::vector vars; + vars.push_back(model.Add(NewIntegerVariable(1, 3))); + vars.push_back(model.Add(NewIntegerVariable(0, 2))); + vars.push_back(model.Add(NewIntegerVariable(1, 3))); + vars.push_back(model.Add(NewIntegerVariable(0, 2))); + model.Add(AllDifferent(vars)); + EXPECT_EQ(SatSolver::FEASIBLE, 
SolveIntegerProblemWithLazyEncoding(&model)); + + std::vector value_seen(5, false); + for (const IntegerVariable var : vars) { + const int64_t value = model.Get(Value(var)); + EXPECT_FALSE(value_seen[value]); + value_seen[value] = true; + } +} + +TEST_P(AllDifferentTest, PerfectMatching) { + Model model; + std::vector vars; + for (int i = 0; i < 4; ++i) { + vars.push_back(model.Add(NewIntegerVariable(0, 10))); + } + IntegerTrail* integer_trail = model.GetOrCreate(); + integer_trail->UpdateInitialDomain(vars[0], Domain::FromValues({3, 9})); + integer_trail->UpdateInitialDomain(vars[1], Domain::FromValues({3, 8})); + integer_trail->UpdateInitialDomain(vars[2], Domain::FromValues({1, 8})); + integer_trail->UpdateInitialDomain(vars[3], Domain(1)); + model.Add(AllDifferent(vars)); + EXPECT_EQ(SatSolver::FEASIBLE, SolveIntegerProblemWithLazyEncoding(&model)); + EXPECT_EQ(1, model.Get(Value(vars[3]))); + EXPECT_EQ(8, model.Get(Value(vars[2]))); + EXPECT_EQ(3, model.Get(Value(vars[1]))); + EXPECT_EQ(9, model.Get(Value(vars[0]))); +} + +TEST_P(AllDifferentTest, EnumerateAllPermutations) { + const int n = 6; + Model model; + std::vector vars; + for (int i = 0; i < n; ++i) { + vars.push_back(model.Add(NewIntegerVariable(0, n - 1))); + } + model.Add(AllDifferent(vars)); + + std::vector> solutions; + while (true) { + const auto status = SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + solutions.emplace_back(n); + for (int i = 0; i < n; ++i) solutions.back()[i] = model.Get(Value(vars[i])); + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + // Test that we do have all the permutations (but in a random order). 
+ std::sort(solutions.begin(), solutions.end()); + std::vector expected(n); + std::iota(expected.begin(), expected.end(), 0); + for (int i = 0; i < solutions.size(); ++i) { + EXPECT_EQ(expected, solutions[i]); + if (i + 1 < solutions.size()) { + EXPECT_TRUE(std::next_permutation(expected.begin(), expected.end())); + } else { + // We enumerated all the permutations. + EXPECT_FALSE(std::next_permutation(expected.begin(), expected.end())); + } + } +} + +int Factorial(int n) { return n ? n * Factorial(n - 1) : 1; } + +TEST_P(AllDifferentTest, EnumerateAllInjections) { + const int n = 5; + const int m = n + 2; + Model model; + std::vector vars; + for (int i = 0; i < n; ++i) { + vars.push_back(model.Add(NewIntegerVariable(0, m - 1))); + } + model.Add(AllDifferent(vars)); + + std::vector solution(n); + int num_solutions = 0; + while (true) { + const auto status = SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + for (int i = 0; i < n; i++) solution[i] = model.Get(Value(vars[i])); + std::sort(solution.begin(), solution.end()); + for (int i = 1; i < n; i++) { + EXPECT_LT(solution[i - 1], solution[i]); + } + num_solutions++; + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + // Count the number of solutions, it should be m!/(m-n)!. + EXPECT_EQ(num_solutions, Factorial(m) / Factorial(m - n)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/circuit_test.cc b/ortools/sat/circuit_test.cc new file mode 100644 index 00000000000..d6d3cf98f48 --- /dev/null +++ b/ortools/sat/circuit_test.cc @@ -0,0 +1,334 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/circuit.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/graph/strongly_connected_components.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +std::function DenseCircuitConstraint( + int num_nodes, bool allow_subcircuit, + bool allow_multiple_subcircuit_through_zero) { + return [=](Model* model) { + std::vector tails; + std::vector heads; + std::vector literals; + for (int tail = 0; tail < num_nodes; ++tail) { + for (int head = 0; head < num_nodes; ++head) { + if (!allow_subcircuit && tail == head) continue; + tails.push_back(tail); + heads.push_back(head); + literals.push_back(Literal(model->Add(NewBooleanVariable()), true)); + } + } + LoadSubcircuitConstraint(num_nodes, tails, heads, literals, model, + allow_multiple_subcircuit_through_zero); + }; +} + +int CountSolutions(Model* model) { + int num_solutions = 0; + while (true) { + const SatSolver::Status status = SolveIntegerProblemWithLazyEncoding(model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + ++num_solutions; + + // Loop to the next solution. + model->Add(ExcludeCurrentSolutionAndBacktrack()); + } + return num_solutions; +} + +int Factorial(int n) { return n ? 
n * Factorial(n - 1) : 1; } + +TEST(ReindexArcTest, BasicCase) { + const int num_nodes = 1000; + std::vector tails(num_nodes); + std::vector heads(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + tails[i] = 100 * i; + heads[i] = 100 * i; + } + ReindexArcs(&tails, &heads); + for (int i = 0; i < num_nodes; ++i) { + EXPECT_EQ(i, tails[i]); + EXPECT_EQ(i, heads[i]); + } +} + +TEST(ReindexArcTest, NegativeNumbering) { + const int num_nodes = 1000; + std::vector tails(num_nodes); + std::vector heads(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + tails[i] = -100 * i; + heads[i] = -100 * i; + } + ReindexArcs(&tails, &heads); + for (int i = 0; i < num_nodes; ++i) { + EXPECT_EQ(i, tails[num_nodes - 1 - i]); + EXPECT_EQ(i, heads[num_nodes - 1 - i]); + } +} + +TEST(CircuitConstraintTest, NodeWithNoArcsIsUnsat) { + static const int kNumNodes = 2; + Model model; + std::vector tails; + std::vector heads; + std::vector literals; + tails.push_back(0); + heads.push_back(1); + literals.push_back(Literal(model.Add(NewBooleanVariable()), true)); + LoadSubcircuitConstraint(kNumNodes, tails, heads, literals, &model); + EXPECT_TRUE(model.GetOrCreate()->ModelIsUnsat()); +} + +TEST(CircuitConstraintTest, AllCircuits) { + static const int kNumNodes = 4; + Model model; + model.Add( + DenseCircuitConstraint(kNumNodes, /*allow_subcircuit=*/false, + /*allow_multiple_subcircuit_through_zero=*/false)); + + const int num_solutions = CountSolutions(&model); + EXPECT_EQ(num_solutions, Factorial(kNumNodes - 1)); +} + +TEST(CircuitConstraintTest, AllSubCircuits) { + static const int kNumNodes = 4; + + Model model; + model.Add( + DenseCircuitConstraint(kNumNodes, /*allow_subcircuit=*/true, + /*allow_multiple_subcircuit_through_zero=*/false)); + + const int num_solutions = CountSolutions(&model); + int expected = 1; // No circuit at all. 
+ for (int circuit_size = 2; circuit_size <= kNumNodes; ++circuit_size) { + // The number of circuit of a given size is: + // - n for the first element + // - times (n-1) for the second + // - ... + // - times (n - (circuit_size - 1)) for the last. + // That is n! / (n - circuit_size)!, and like this we count circuit_size + // times the same circuit, so we have to divide by circuit_size in the end. + expected += Factorial(kNumNodes) / + (circuit_size * Factorial(kNumNodes - circuit_size)); + } + EXPECT_EQ(num_solutions, expected); +} + +TEST(CircuitConstraintTest, AllVehiculeRoutes) { + static const int kNumNodes = 4; + Model model; + + model.Add( + DenseCircuitConstraint(kNumNodes, /*allow_subcircuit=*/false, + /*allow_multiple_subcircuit_through_zero=*/true)); + + const int num_solutions = CountSolutions(&model); + int expected = 1; // 3 outgoing arcs from zero. + expected += 2 * 3; // 2 outgoing arcs from zero. 3 pairs, 2 direction. + expected += 6; // full circuit. + EXPECT_EQ(num_solutions, expected); +} + +TEST(CircuitConstraintTest, AllCircuitCoverings) { + // This test counts the number of circuit coverings of the clique on + // num_nodes with num_distinguished distinguished nodes, i.e. graphs that are + // vertex-disjoint circuits where every circuit must contain exactly one + // distinguished node. + // + // When writing n the number of nodes and k the number of distinguished nodes, + // and the number of such coverings T(n, k), we have: + // T(n,1) = (n-1)!, T(k,k) = 1, T(n,k) = (n-1)!/(k-1)! for n >= k >= 1. + // Indeed, we can enumerate canonical representations, e.g. [1]64[2]35, + // by starting with [1][2]...[k], and place every node in turn at its final + // place w.r.t. existing neighbours. To generate the above example, we go + // though [1][2], [1][2]3, [1]4[2]3, [1]4[2]35, [1]64[2]35. + // At the first iteration, there are k choices, then k+1 ... n-1. 
+ for (int num_nodes = 1; num_nodes <= 6; num_nodes++) { + for (int num_distinguished = 1; num_distinguished <= num_nodes; + num_distinguished++) { + Model model; + std::vector distinguished(num_distinguished); + std::iota(distinguished.begin(), distinguished.end(), 0); + std::vector> graph(num_nodes); + std::vector arcs; + for (int i = 0; i < num_nodes; i++) { + graph[i].resize(num_nodes); + for (int j = 0; j < num_nodes; j++) { + const auto var = model.Add(NewBooleanVariable()); + graph[i][j] = Literal(var, true); + arcs.emplace_back(graph[i][j]); + } + if (i >= num_distinguished) { + model.Add(ClauseConstraint({graph[i][i].Negated()})); + } + } + model.Add(ExactlyOnePerRowAndPerColumn(graph)); + model.Add(CircuitCovering(graph, distinguished)); + const int64_t num_solutions = CountSolutions(&model); + EXPECT_EQ(num_solutions * Factorial(num_distinguished - 1), + Factorial(num_nodes - 1)); + } + } +} + +TEST(CircuitConstraintTest, InfeasibleBecauseOfMissingArcs) { + Model model; + std::vector tails; + std::vector heads; + std::vector literals; + for (const auto arcs : + std::vector>{{0, 1}, {1, 1}, {0, 2}, {2, 2}}) { + tails.push_back(arcs.first); + heads.push_back(arcs.second); + literals.push_back(Literal(model.Add(NewBooleanVariable()), true)); + } + LoadSubcircuitConstraint(3, tails, heads, literals, &model, false); + const SatSolver::Status status = SolveIntegerProblemWithLazyEncoding(&model); + EXPECT_EQ(status, SatSolver::Status::INFEASIBLE); +} + +// The graph looks like this with a self-loop at 2. If 2 is not selected +// (self-loop) then there is one solution (0,1,3,0) and (0,3,5,0). Otherwise, +// there are 2 more solutions with 2 inserted in one of the two routes.
+// +// 0 ---> 1 ---> 4 ------------- +// | | ^ | +// | -----> 2* --> 5 ---> 0 +// | ^ ^ +// | | | +// -------------> 3 ------ +// +TEST(CircuitConstraintTest, RouteConstraint) { + Model model; + std::vector tails; + std::vector heads; + std::vector literals; + for (const auto arcs : std::vector>{{0, 1}, + {0, 3}, + {1, 2}, + {1, 4}, + {2, 2}, + {2, 4}, + {2, 5}, + {3, 2}, + {3, 5}, + {4, 0}, + {5, 0}}) { + tails.push_back(arcs.first); + heads.push_back(arcs.second); + literals.push_back(Literal(model.Add(NewBooleanVariable()), true)); + } + LoadSubcircuitConstraint(6, tails, heads, literals, &model, true); + const int64_t num_solutions = CountSolutions(&model); + EXPECT_EQ(num_solutions, 3); +} + +TEST(NoCyclePropagatorTest, CountAllSolutions) { + // We create a 2 * 2 grid with diagonal arcs. + Model model; + int num_nodes = 0; + const int num_x = 2; + const int num_y = 2; + const auto get_index = [&num_nodes](int x, int y) { + const int index = x * num_y + y; + num_nodes = std::max(num_nodes, index + 1); + return index; + }; + + std::vector tails; + std::vector heads; + std::vector literals; + for (int x = 0; x < num_x; ++x) { + for (int y = 0; y < num_y; ++y) { + for (const int x_dir : {-1, 0, 1}) { + for (const int y_dir : {-1, 0, 1}) { + const int head_x = x + x_dir; + const int head_y = y + y_dir; + if (x_dir == 0 && y_dir == 0) continue; + if (head_x < 0 || head_x >= num_x) continue; + if (head_y < 0 || head_y >= num_y) continue; + tails.push_back(get_index(x, y)); + heads.push_back(get_index(head_x, head_y)); + literals.push_back(Literal(model.Add(NewBooleanVariable()), true)); + } + } + } + } + model.TakeOwnership( + new NoCyclePropagator(num_nodes, tails, heads, literals, &model)); + + // Graph is small enough. + CHECK_EQ(num_nodes, 4); + CHECK_EQ(tails.size(), 12); + + // Counts solution with brute-force algo. 
+ int num_expected_solutions = 0; + std::vector> subgraph(num_nodes); + std::vector> components; + const int num_cases = 1 << tails.size(); + for (int mask = 0; mask < num_cases; ++mask) { + for (int n = 0; n < num_nodes; ++n) { + subgraph[n].clear(); + } + for (int a = 0; a < tails.size(); ++a) { + if ((1 << a) & mask) { + subgraph[tails[a]].push_back(heads[a]); + } + } + components.clear(); + FindStronglyConnectedComponents(num_nodes, subgraph, &components); + bool has_cycle = false; + for (const std::vector compo : components) { + if (compo.size() > 1) { + has_cycle = true; + break; + } + } + if (!has_cycle) ++num_expected_solutions; + } + EXPECT_EQ(num_expected_solutions, 543); + + // There is 12 arcs. + // So out of 2^12 solution, we have to exclude all the one with cycles. + EXPECT_EQ(CountSolutions(&model), 543); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index 1b952d62bb7..8241b153292 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -500,6 +500,7 @@ void ClauseManager::DeleteRemovedClauses() { void BinaryImplicationGraph::Resize(int num_variables) { SCOPED_TIME_STAT(&stats_); + bfs_stack_.resize(num_variables << 1); implications_.resize(num_variables << 1); implies_something_.resize(num_variables << 1); might_have_dups_.resize(num_variables << 1); @@ -948,7 +949,7 @@ void BinaryImplicationGraph::MinimizeConflictFirst( for (const LiteralIndex i : is_marked_.PositionsSetAtLeastOnce()) { // TODO(user): if this is false, then we actually have a conflict of size 2. // This can only happen if the binary clause was not propagated properly - // if for instance we do chronological bactracking without re-enqueing the + // if for instance we do chronological bactracking without re-enqueuing the // consequence of a binary clause. 
if (trail.Assignment().LiteralIsTrue(Literal(i))) { marked->Set(Literal(i).Variable()); @@ -1534,11 +1535,12 @@ bool BinaryImplicationGraph::ComputeTransitiveReduction(bool log_info) { // Also mark all the ones reachable through the root AMOs. if (root < at_most_ones_.size()) { + auto is_marked = is_marked_.BitsetView(); for (const int start : at_most_ones_[root]) { for (const Literal l : AtMostOne(start)) { if (l.Index() == root) continue; - if (!is_marked_[l.Negated()] && !is_redundant_[l.Negated()]) { - is_marked_.SetUnsafe(l.Negated()); + if (!is_marked[l.Negated()] && !is_redundant_[l.Negated()]) { + is_marked_.SetUnsafe(is_marked, l.Negated()); MarkDescendants(l.Negated()); } } @@ -1805,13 +1807,7 @@ std::vector BinaryImplicationGraph::ExpandAtMostOneWithWeight( const util_intops::StrongVector& expanded_lp_values) { std::vector clique(at_most_one.begin(), at_most_one.end()); std::vector intersection; - double clique_weight = 0.0; const int64_t old_work = work_done_in_mark_descendants_; - if (use_weight) { - for (const Literal l : clique) { - clique_weight += expanded_lp_values[l]; - } - } for (int i = 0; i < clique.size(); ++i) { // Do not spend too much time here. if (work_done_in_mark_descendants_ - old_work > 1e8) break; @@ -1828,26 +1824,15 @@ std::vector BinaryImplicationGraph::ExpandAtMostOneWithWeight( } int new_size = 0; - double intersection_weight = 0.0; is_marked_.Clear(clique[i]); is_marked_.Clear(clique[i].NegatedIndex()); for (const LiteralIndex index : intersection) { if (!is_marked_[index]) continue; intersection[new_size++] = index; - if (use_weight) { - intersection_weight += expanded_lp_values[index]; - } } intersection.resize(new_size); if (intersection.empty()) break; - // We can't generate a violated cut this way. This is because intersection - // contains all the possible ways to extend the current clique. - if (use_weight && clique_weight + intersection_weight <= 1.0) { - clique.clear(); - return clique; - } - // Expand? 
The negation of any literal in the intersection is a valid way // to extend the clique. if (i + 1 == clique.size()) { @@ -1857,9 +1842,10 @@ std::vector BinaryImplicationGraph::ExpandAtMostOneWithWeight( for (int j = 0; j < intersection.size(); ++j) { // If we don't use weight, we prefer variable that comes first. const double lp = - use_weight ? 1.0 - expanded_lp_values[intersection[j]] + - absl::Uniform(*random_, 0.0, 1e-4) - : can_be_included.size() - intersection[j].value(); + use_weight + ? expanded_lp_values[Literal(intersection[j]).NegatedIndex()] + + absl::Uniform(*random_, 0.0, 1e-4) + : can_be_included.size() - intersection[j].value(); if (index == -1 || lp > max_lp) { index = j; max_lp = lp; @@ -1869,9 +1855,6 @@ std::vector BinaryImplicationGraph::ExpandAtMostOneWithWeight( clique.push_back(Literal(intersection[index]).Negated()); std::swap(intersection.back(), intersection[index]); intersection.pop_back(); - if (use_weight) { - clique_weight += expanded_lp_values[clique.back()]; - } } } } @@ -1890,16 +1873,26 @@ BinaryImplicationGraph::ExpandAtMostOneWithWeight( const util_intops::StrongVector& can_be_included, const util_intops::StrongVector& expanded_lp_values); +// This function and the generated cut serves two purpose: +// 1/ If a new clause of size 2 was discovered and not included in the LP, we +// will add it. +// 2/ The more classical clique cut separation algorithm +// +// Note that once 1/ Is performed, any literal close to 1.0 in the lp shouldn't +// be in the same clique as a literal with positive weight. So for step 2, we +// only really need to consider fractional variables. const std::vector>& BinaryImplicationGraph::GenerateAtMostOnesWithLargeWeight( - const std::vector& literals, - const std::vector& lp_values) { + absl::Span literals, absl::Span lp_values, + absl::Span reduced_costs) { // We only want to generate a cut with literals from the LP, not extra ones. 
const int num_literals = implications_.size(); util_intops::StrongVector can_be_included(num_literals, false); util_intops::StrongVector expanded_lp_values( num_literals, 0.0); + util_intops::StrongVector heuristic_weights( + num_literals, 0.0); const int size = literals.size(); for (int i = 0; i < size; ++i) { const Literal l = literals[i]; @@ -1909,6 +1902,23 @@ BinaryImplicationGraph::GenerateAtMostOnesWithLargeWeight( const double value = lp_values[i]; expanded_lp_values[l] = value; expanded_lp_values[l.NegatedIndex()] = 1.0 - value; + + // This is used for extending clique-cuts. + // In most situation, we will only extend the cuts with literal at zero, + // and we prefer "low" reduced cost first, so we negate it. Variable with + // high reduced costs will likely stay that way and are of less interest in + // a clique cut. At least that is my interpretation. + // + // TODO(user): For large problems or when we base the clique from a newly + // added and violated 2-clique, we might consider only a subset of + // fractional variables, so we might need to include fractional variable + // first, but then their rc should be zero, so it should be already kind of + // doing that. + // + // Remark: This seems to have a huge impact to the cut performance! + const double rc = reduced_costs[i]; + heuristic_weights[l] = -rc; + heuristic_weights[l.NegatedIndex()] = rc; } // We want highest sum first. @@ -1926,6 +1936,7 @@ BinaryImplicationGraph::GenerateAtMostOnesWithLargeWeight( // currently still statically add the initial implications, this will only add // cut based on newly learned binary clause. Or the one that were not added // to the relaxation in the first place. 
+ std::vector fractional_literals; for (int i = 0; i < size; ++i) { Literal current_literal = literals[i]; double current_value = lp_values[i]; @@ -1937,6 +1948,10 @@ BinaryImplicationGraph::GenerateAtMostOnesWithLargeWeight( current_value = 1.0 - current_value; } + if (current_value < 0.99) { + fractional_literals.push_back(current_literal); + } + // We consider only one candidate for each current_literal. LiteralIndex best = kNoLiteralIndex; double best_value = 0.0; @@ -1967,12 +1982,101 @@ BinaryImplicationGraph::GenerateAtMostOnesWithLargeWeight( // Expand to a maximal at most one each candidates before returning them. // Note that we only expand using literal from the LP. tmp_cuts_.clear(); - std::vector at_most_one; for (const Candidate& candidate : candidates) { - at_most_one = ExpandAtMostOneWithWeight( - {candidate.a, candidate.b}, can_be_included, expanded_lp_values); - if (!at_most_one.empty()) tmp_cuts_.push_back(at_most_one); + tmp_cuts_.push_back(ExpandAtMostOneWithWeight( + {candidate.a, candidate.b}, can_be_included, heuristic_weights)); } + + // Once we processed new implications, also add "proper" clique cuts. + // We can generate a small graph and separate cut efficiently there. + if (fractional_literals.size() > 1) { + // Lets permute this randomly and truncate if we have too many variables. + // Since we use bitset it is good to have a multiple of 64 there. + // + // TODO(user): Prefer more fractional variables. + const int max_graph_size = 1024; + if (fractional_literals.size() > max_graph_size) { + std::shuffle(fractional_literals.begin(), fractional_literals.end(), + *random_); + fractional_literals.resize(max_graph_size); + } + + bron_kerbosch_.Initialize(fractional_literals.size() * 2); + + // Prepare a dense mapping. 
+ int i = 0; + tmp_mapping_.resize(implications_.size(), -1); + for (const Literal l : fractional_literals) { + bron_kerbosch_.SetWeight(i, expanded_lp_values[l]); + tmp_mapping_[l] = i++; + bron_kerbosch_.SetWeight(i, expanded_lp_values[l.Negated()]); + tmp_mapping_[l.Negated()] = i++; + } + + // Copy the implication subgraph and remap it to a dense indexing. + // + // TODO(user): Treat at_most_one more efficiently. We can collect them + // and scan each of them just once. + for (const Literal base : fractional_literals) { + for (const Literal l : {base, base.Negated()}) { + const int from = tmp_mapping_[l]; + for (const Literal next : DirectImplications(l)) { + // l => next so (l + not(next) <= 1). + const int to = tmp_mapping_[next.Negated()]; + if (to != -1) { + bron_kerbosch_.AddEdge(from, to); + } + } + } + } + + // Before running the algo, compute the transitive closure. + // The graph shouldn't be too large, so this should be fast enough. + bron_kerbosch_.TakeTransitiveClosureOfImplicationGraph(); + + bron_kerbosch_.SetWorkLimit(1e8); + bron_kerbosch_.SetMinimumWeight(1.001); + std::vector> cliques = bron_kerbosch_.Run(); + + // If we have many candidates, we will only expand the first few with + // maximum weights. + const int max_num_per_batch = 5; + std::vector> with_weight = + bron_kerbosch_.GetMutableIndexAndWeight(); + if (with_weight.size() > max_num_per_batch) { + std::sort( + with_weight.begin(), with_weight.end(), + [](const std::pair& a, const std::pair& b) { + return a.second > b.second; + }); + with_weight.resize(max_num_per_batch); + } + + std::vector at_most_one; + for (const auto [index, weight] : with_weight) { + // Convert. + at_most_one.clear(); + for (const int i : cliques[index]) { + const Literal l = fractional_literals[i / 2]; + at_most_one.push_back(i % 2 == 1 ? l.Negated() : l); + } + + // Expand and add clique. + // + // TODO(user): Expansion is pretty slow. 
Given that the base clique can + // share literal beeing part of the same amo, we should be able to speed + // that up, we don't want to scan an amo twice basically. + tmp_cuts_.push_back(ExpandAtMostOneWithWeight( + at_most_one, can_be_included, heuristic_weights)); + } + + // Clear the dense mapping + for (const Literal l : fractional_literals) { + tmp_mapping_[l] = -1; + tmp_mapping_[l.Negated()] = -1; + } + } + return tmp_cuts_; } @@ -2053,23 +2157,23 @@ BinaryImplicationGraph::HeuristicAmoPartition(std::vector* literals) { } void BinaryImplicationGraph::MarkDescendants(Literal root) { - bfs_stack_.resize(implications_.size()); auto* const stack = bfs_stack_.data(); - const int amo_size = static_cast(at_most_ones_.size()); - auto is_marked = is_marked_.const_view(); + auto is_marked = is_marked_.BitsetView(); auto is_redundant = is_redundant_.const_view(); if (is_redundant[root]) return; int stack_size = 1; stack[0] = root; is_marked_.Set(root); + const int amo_size = static_cast(at_most_ones_.size()); + auto implies_something = implies_something_.const_view(); for (int j = 0; j < stack_size; ++j) { const Literal current = stack[j]; - if (!implies_something_[current]) continue; + if (!implies_something[current]) continue; for (const Literal l : implications_[current]) { if (!is_marked[l] && !is_redundant[l]) { - is_marked_.SetUnsafe(l); + is_marked_.SetUnsafe(is_marked, l); stack[stack_size++] = l; } } @@ -2079,7 +2183,7 @@ void BinaryImplicationGraph::MarkDescendants(Literal root) { for (const Literal l : AtMostOne(start)) { if (l == current) continue; if (!is_marked[l.NegatedIndex()] && !is_redundant[l.NegatedIndex()]) { - is_marked_.SetUnsafe(l.NegatedIndex()); + is_marked_.SetUnsafe(is_marked, l.NegatedIndex()); stack[stack_size++] = l.Negated(); } } @@ -2094,8 +2198,8 @@ std::vector BinaryImplicationGraph::ExpandAtMostOne( std::vector clique(at_most_one.begin(), at_most_one.end()); // Optim. 
- for (int i = 0; i < clique.size(); ++i) { - if (implications_[clique[i]].empty() || is_redundant_[clique[i]]) { + for (const Literal l : clique) { + if (implications_[l].empty() || is_redundant_[l]) { return clique; } } diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index 4a30915cdce..3f1c0c8f84c 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -33,6 +33,7 @@ #include "absl/random/bit_gen_ref.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/graph/cliques.h" #include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" @@ -652,8 +653,8 @@ class BinaryImplicationGraph : public SatPropagator { // // TODO(user): Refine the heuristic and unit test! const std::vector>& GenerateAtMostOnesWithLargeWeight( - const std::vector& literals, - const std::vector& lp_values); + absl::Span literals, absl::Span lp_values, + absl::Span reduced_costs); // Heuristically identify "at most one" between the given literals, swap // them around and return these amo as span inside the literals vector. @@ -920,6 +921,10 @@ class BinaryImplicationGraph : public SatPropagator { int64_t work_done_in_mark_descendants_ = 0; std::vector bfs_stack_; + // For clique cuts. + util_intops::StrongVector tmp_mapping_; + WeightedBronKerboschBitsetAlgorithm bron_kerbosch_; + // Used by ComputeTransitiveReduction() in case we abort early to maintain // the invariant checked by InvariantsAreOk(). Some of our algo // relies on this to be always true. diff --git a/ortools/sat/cp_constraints_test.cc b/ortools/sat/cp_constraints_test.cc new file mode 100644 index 00000000000..bceee075ad2 --- /dev/null +++ b/ortools/sat/cp_constraints_test.cc @@ -0,0 +1,120 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_constraints.h" + +#include + +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/logging.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(LiteralXorIsTest, OneVariable) { + Model model; + const BooleanVariable a = model.Add(NewBooleanVariable()); + const BooleanVariable b = model.Add(NewBooleanVariable()); + model.Add(LiteralXorIs({Literal(a, true)}, true)); + model.Add(LiteralXorIs({Literal(b, true)}, false)); + SatSolver* solver = model.GetOrCreate(); + EXPECT_TRUE(solver->Propagate()); + EXPECT_TRUE(solver->Assignment().LiteralIsTrue(Literal(a, true))); + EXPECT_TRUE(solver->Assignment().LiteralIsFalse(Literal(b, true))); +} + +// A simple macro to make the code more readable. 
+#define EXPECT_BOUNDS_EQ(var, lb, ub) \ + EXPECT_EQ(model.Get(LowerBound(var)), lb); \ + EXPECT_EQ(model.Get(UpperBound(var)), ub) + +TEST(PartialIsOneOfVarTest, MinMaxPropagation) { + Model model; + const IntegerVariable target_var = model.Add(NewIntegerVariable(-10, 20)); + std::vector vars; + std::vector selectors; + for (int i = 0; i < 10; ++i) { + vars.push_back(model.Add(ConstantIntegerVariable(i))); + selectors.push_back(Literal(model.Add(NewBooleanVariable()), true)); + } + model.Add(PartialIsOneOfVar(target_var, vars, selectors)); + + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_BOUNDS_EQ(target_var, 0, 9); + + model.Add(ClauseConstraint({selectors[0].Negated()})); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_BOUNDS_EQ(target_var, 1, 9); + + model.Add(ClauseConstraint({selectors[8].Negated()})); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_BOUNDS_EQ(target_var, 1, 9); + + model.Add(ClauseConstraint({selectors[9].Negated()})); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_BOUNDS_EQ(target_var, 1, 7); +} + +TEST(GreaterThanAtLeastOneOfPropagatorTest, BasicTest) { + for (int i = 0; i < 2; ++i) { + Model model; + + // We create a simple model with 3 variables and 2 conditional precedences. + // We only add the GreaterThanAtLeastOneOfPropagator() for i == 1. 
+ const IntegerVariable a = model.Add(NewIntegerVariable(0, 3)); + const IntegerVariable b = model.Add(NewIntegerVariable(0, 3)); + const IntegerVariable c = model.Add(NewIntegerVariable(0, 3)); + const Literal ac = Literal(model.Add(NewBooleanVariable()), true); + const Literal bc = Literal(model.Add(NewBooleanVariable()), true); + model.Add(ConditionalLowerOrEqualWithOffset(a, c, 3, ac)); + model.Add(ConditionalLowerOrEqualWithOffset(b, c, 2, bc)); + model.Add(ClauseConstraint({ac, bc})); + if (i == 1) { + model.Add(GreaterThanAtLeastOneOf( + c, {a, b}, {IntegerValue(3), IntegerValue(2)}, {ac, bc}, {})); + } + + // Test that we do propagate more with the extra propagator. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_EQ(model.Get(LowerBound(c)), i == 0 ? 0 : 2); + + // Test that we find all solutions. + int num_solutions = 0; + while (true) { + const auto status = SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + ++num_solutions; + VLOG(1) << model.Get(Value(a)) << " " << model.Get(Value(b)) << " " + << model.Get(Value(c)); + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + EXPECT_EQ(num_solutions, 18); + } +} + +#undef EXPECT_BOUNDS_EQ + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model.cc b/ortools/sat/cp_model.cc index 94e097d4f8f..392db732eaa 100644 --- a/ortools/sat/cp_model.cc +++ b/ortools/sat/cp_model.cc @@ -763,9 +763,8 @@ void CpModelBuilder::FixVariable(BoolVar var, bool value) { Constraint CpModelBuilder::AddBoolOr(absl::Span literals) { ConstraintProto* const proto = cp_model_.add_constraints(); - for (const BoolVar& lit : literals) { - proto->mutable_bool_or()->add_literals(lit.index_); - } + BoolArgumentProto* const bool_or = proto->mutable_bool_or(); + for (const BoolVar& lit : literals) bool_or->add_literals(lit.index_); return Constraint(proto); } @@ -783,9 +782,8 @@ Constraint CpModelBuilder::AddAtMostOne(absl::Span 
literals) { Constraint CpModelBuilder::AddExactlyOne(absl::Span literals) { ConstraintProto* const proto = cp_model_.add_constraints(); - for (const BoolVar& lit : literals) { - proto->mutable_exactly_one()->add_literals(lit.index_); - } + BoolArgumentProto* const exactly_one = proto->mutable_exactly_one(); + for (const BoolVar& lit : literals) exactly_one->add_literals(lit.index_); return Constraint(proto); } diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index b76dc95aa1a..b57d2056ba1 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -18,6 +18,7 @@ syntax = "proto3"; package operations_research.sat; option csharp_namespace = "Google.OrTools.Sat"; +option go_package = "github.com/google/or-tools/ortools/sat/proto/cpmodel"; option java_package = "com.google.ortools.sat"; option java_multiple_files = true; option java_outer_classname = "CpModelProtobuf"; diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index 53d0b3f60e2..e9cceb9d4dd 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -29,6 +29,7 @@ #include "absl/log/check.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "google/protobuf/message.h" #include "ortools/base/logging.h" #include "ortools/base/stl_util.h" @@ -126,7 +127,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, { circuit->add_tails(num_events); circuit->add_heads(num_events); - circuit->add_literals(context->NewBoolVar()); + circuit->add_literals(context->NewBoolVar("reservoir expansion")); } for (int i = 0; i < num_events; ++i) { @@ -141,7 +142,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, // We use the available index 'num_events'. { // Circuit starts at i, level_vars[i] == demand_expr[i]. 
- const int start_var = context->NewBoolVar(); + const int start_var = context->NewBoolVar("reservoir expansion"); circuit->add_tails(num_events); circuit->add_heads(i); circuit->add_literals(start_var); @@ -163,7 +164,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, // Circuit ends at i, no extra constraint there. circuit->add_tails(i); circuit->add_heads(num_events); - circuit->add_literals(context->NewBoolVar()); + circuit->add_literals(context->NewBoolVar("reservoir expansion")); } for (int j = 0; j < num_events; ++j) { @@ -179,7 +180,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, // reservoir except if the set of time point is exactly the same! // otherwise if we miss one, then A "after" B in one circuit do not // implies that there is no C in between in another! - const int arc_i_j = context->NewBoolVar(); + const int arc_i_j = context->NewBoolVar("reservoir expansion"); circuit->add_tails(i); circuit->add_heads(j); circuit->add_literals(arc_i_j); @@ -755,13 +756,13 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { std::vector enforcement_literals; enforcement_literals.reserve(num_exprs); if (num_exprs == 2) { - const int new_bool = context->NewBoolVar(); + const int new_bool = context->NewBoolVar("lin max expansion"); enforcement_literals.push_back(new_bool); enforcement_literals.push_back(NegatedRef(new_bool)); } else { ConstraintProto* exactly_one = context->working_model->add_constraints(); for (int i = 0; i < num_exprs; ++i) { - const int new_bool = context->NewBoolVar(); + const int new_bool = context->NewBoolVar("lin max expansion"); exactly_one->mutable_exactly_one()->add_literals(new_bool); enforcement_literals.push_back(new_bool); } @@ -1194,7 +1195,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { out_encoding.clear(); if (states.size() == 2) { - const int var = context->NewBoolVar(); + const int var = context->NewBoolVar("automaton expansion"); 
out_encoding[states[0]] = var; out_encoding[states[1]] = NegatedRef(var); } else if (states.size() > 2) { @@ -1243,7 +1244,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { } } - out_encoding[state] = context->NewBoolVar(); + out_encoding[state] = context->NewBoolVar("automaton expansion"); } } } @@ -1302,7 +1303,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { // expand this small table with 3 columns (i.e. compress, negate, etc...). std::vector tuple_literals; if (num_tuples == 2) { - const int bool_var = context->NewBoolVar(); + const int bool_var = context->NewBoolVar("automaton expansion"); tuple_literals.push_back(bool_var); tuple_literals.push_back(NegatedRef(bool_var)); } else { @@ -1320,7 +1321,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { } else if (out_count[out_states[i]] == 1 && !out_encoding.empty()) { tuple_literal = out_encoding[out_states[i]]; } else { - tuple_literal = context->NewBoolVar(); + tuple_literal = context->NewBoolVar("automaton expansion"); } tuple_literals.push_back(tuple_literal); @@ -1818,7 +1819,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, if (ct->enforcement_literal().size() == 1) { table_is_active_literal = ct->enforcement_literal(0); } else if (ct->enforcement_literal().size() > 1) { - table_is_active_literal = context->NewBoolVar(); + table_is_active_literal = context->NewBoolVar("table expansion"); // Adds table_is_active <=> and(enforcement_literals). 
BoolArgumentProto* bool_or = @@ -1850,7 +1851,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, break; } if (create_new_var) { - tuple_literals[i] = context->NewBoolVar(); + tuple_literals[i] = context->NewBoolVar("table expansion"); } exactly_one->add_literals(tuple_literals[i]); } @@ -2134,6 +2135,36 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, if (ct->linear().domain().size() <= 2) return; if (ct->linear().vars().size() == 1) return; + // If we have a hint for all variables of this linear constraint, finds in + // which bucket it fall. + int hint_bucket = -1; + bool set_hint_of_bucket_variables = false; + if (context->HintIsLoaded()) { + set_hint_of_bucket_variables = true; + int64_t hint_activity = 0; + const int num_terms = ct->linear().vars().size(); + const absl::Span hint = context->SolutionHint(); + for (int i = 0; i < num_terms; ++i) { + const int var = ct->linear().vars(i); + DCHECK_LT(var, hint.size()); + if (!context->VarHasSolutionHint(var)) { + set_hint_of_bucket_variables = false; + break; + } + hint_activity += ct->linear().coeffs(i) * hint[var]; + } + if (set_hint_of_bucket_variables) { + for (int i = 0; i < ct->linear().domain_size(); i += 2) { + const int64_t lb = ct->linear().domain(i); + const int64_t ub = ct->linear().domain(i + 1); + if (hint_activity >= lb && hint_activity <= ub) { + hint_bucket = i; + break; + } + } + } + } + const SatParameters& params = context->params(); if (params.encode_complex_linear_constraint_with_integer()) { // Integer encoding. @@ -2155,7 +2186,10 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, if (ct->enforcement_literal().empty() && ct->linear().domain_size() == 4) { // We cover the special case of no enforcement and two choices by creating // a single Boolean. 
- single_bool = context->NewBoolVar(); + single_bool = context->NewBoolVar("complex linear expansion"); + if (set_hint_of_bucket_variables) { + context->SetNewVariableHint(single_bool, hint_bucket == 0); + } } else { clause = context->working_model->add_constraints()->mutable_bool_or(); for (const int ref : ct->enforcement_literal()) { @@ -2173,7 +2207,10 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, int subdomain_literal; if (clause != nullptr) { - subdomain_literal = context->NewBoolVar(); + subdomain_literal = context->NewBoolVar("complex linear expansion"); + if (set_hint_of_bucket_variables) { + context->SetNewVariableHint(subdomain_literal, hint_bucket == i); + } clause->add_literals(subdomain_literal); domain_literals.push_back(subdomain_literal); } else { @@ -2196,7 +2233,7 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, if (enforcement_literals.size() == 1) { linear_is_enforced = enforcement_literals[0]; } else { - linear_is_enforced = context->NewBoolVar(); + linear_is_enforced = context->NewBoolVar("complex linear expansion"); BoolArgumentProto* maintain_linear_is_enforced = context->working_model->add_constraints()->mutable_bool_or(); for (const int e_lit : enforcement_literals) { diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index f4d3a3728b5..5cd81bfc2c2 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -1212,6 +1211,7 @@ void SplitAndLoadIntermediateConstraints(bool lb_required, bool ub_required, void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); + if (ct.linear().vars().empty()) { const Domain rhs = ReadDomainFromProto(ct.linear()); if (rhs.Contains(0)) return; @@ -1429,22 +1429,27 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { // We have a linear with a complex Domain, we need to create extra 
Booleans. - // In this case, we can create just one Boolean instead of two since one - // is the negation of the other. - const bool special_case = - ct.enforcement_literal().empty() && ct.linear().domain_size() == 4; - // For enforcement => var \in domain, we can potentially reuse the encoding // literal directly rather than creating new ones. - const bool is_linear1 = !special_case && vars.size() == 1 && coeffs[0] == 1; + const bool is_linear1 = vars.size() == 1 && coeffs[0] == 1; + bool special_case = false; std::vector clause; std::vector for_enumeration; auto* encoding = m->GetOrCreate(); - for (int i = 0; i < ct.linear().domain_size(); i += 2) { + const int domain_size = ct.linear().domain_size(); + for (int i = 0; i < domain_size; i += 2) { const int64_t lb = ct.linear().domain(i); const int64_t ub = ct.linear().domain(i + 1); + // Skip non-reachable intervals. + if (min_sum > ub) continue; + if (max_sum < lb) continue; + + // Skip trivial constraint. Note that when this happens, all the intervals + // before where non-reachable. + if (min_sum >= lb && max_sum <= ub) return; + if (is_linear1) { if (lb == ub) { clause.push_back( @@ -1461,9 +1466,17 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { } } + // If there is just two terms and no enforcement, we don't need to create an + // extra boolean as the second case can be controlled by the negation of the + // first. + if (ct.enforcement_literal().empty() && clause.size() == 1 && + i + 1 == domain_size) { + special_case = true; + } + const Literal subdomain_literal( - special_case && i > 0 ? clause.back().Negated() - : Literal(m->Add(NewBooleanVariable()), true)); + special_case ? 
clause.back().Negated() + : Literal(m->Add(NewBooleanVariable()), true)); clause.push_back(subdomain_literal); for_enumeration.push_back(subdomain_literal); diff --git a/ortools/sat/cp_model_mapping.h b/ortools/sat/cp_model_mapping.h index 58a6849f74e..530a0a21b55 100644 --- a/ortools/sat/cp_model_mapping.h +++ b/ortools/sat/cp_model_mapping.h @@ -172,6 +172,13 @@ class CpModelMapping { return reverse_integer_map_[var]; } + // This one should only be used when we have a mapping. + int GetProtoLiteralFromLiteral(sat::Literal lit) const { + const int proto_var = GetProtoVariableFromBooleanVariable(lit.Variable()); + DCHECK_NE(proto_var, -1); + return lit.IsPositive() ? proto_var : NegatedRef(proto_var); + } + const std::vector& GetVariableMapping() const { return integers_; } diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 754fcc16713..4b6720e1064 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -2192,7 +2192,7 @@ bool CpModelPresolver::RemoveSingletonInLinear(ConstraintProto* ct) { if (ct->enforcement_literal().size() == 1) { indicator = ct->enforcement_literal(0); } else { - indicator = context_->NewBoolVar(); + indicator = context_->NewBoolVar("indicator"); auto* new_ct = context_->working_model->add_constraints(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); new_ct->mutable_bool_or()->add_literals(indicator); @@ -2508,7 +2508,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { context_->UpdateRuleStats("linear1: infeasible"); return MarkConstraintAsFalse(ct); } - if (rhs == context_->DomainOf(var)) { + if (rhs == var_domain) { context_->UpdateRuleStats("linear1: always true"); return RemoveConstraint(ct); } @@ -2544,16 +2544,28 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { } // Detect encoding. 
+ bool changed = false; if (ct->enforcement_literal().size() == 1) { // If we already have an encoding literal, this constraint is really // an implication. - const int lit = ct->enforcement_literal(0); + int lit = ct->enforcement_literal(0); + + // For correctness below, it is important lit is the canonical literal, + // otherwise we might remove the constraint even though it is the one + // defining an encoding literal. + const int representative = context_->GetLiteralRepresentative(lit); + if (lit != representative) { + lit = representative; + ct->set_enforcement_literal(0, lit); + context_->UpdateRuleStats("linear1: remapped enforcement literal"); + changed = true; + } if (rhs.IsFixed()) { const int64_t value = rhs.FixedValue(); int encoding_lit; if (context_->HasVarValueEncoding(var, value, &encoding_lit)) { - if (lit == encoding_lit) return false; + if (lit == encoding_lit) return changed; context_->AddImplication(lit, encoding_lit); context_->UpdateNewConstraintsVariableUsage(); ct->Clear(); @@ -2567,7 +2579,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { } context_->UpdateNewConstraintsVariableUsage(); } - return false; + return changed; } const Domain complement = rhs.Complement().IntersectionWith(var_domain); @@ -2575,7 +2587,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { const int64_t value = complement.FixedValue(); int encoding_lit; if (context_->HasVarValueEncoding(var, value, &encoding_lit)) { - if (NegatedRef(lit) == encoding_lit) return false; + if (NegatedRef(lit) == encoding_lit) return changed; context_->AddImplication(lit, NegatedRef(encoding_lit)); context_->UpdateNewConstraintsVariableUsage(); ct->Clear(); @@ -2589,11 +2601,11 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { } context_->UpdateNewConstraintsVariableUsage(); } - return false; + return changed; } } - return false; + return changed; } bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) 
{ @@ -7110,9 +7122,6 @@ void CpModelPresolver::Probe() { } probing_timer->AddCounter("fixed_bools", num_fixed); - DetectDuplicateConstraintsWithDifferentEnforcements( - mapping, implication_graph, model.GetOrCreate()); - int num_equiv = 0; int num_changed_bounds = 0; const int num_variables = context_->working_model->variables().size(); @@ -7148,6 +7157,12 @@ void CpModelPresolver::Probe() { probing_timer->AddCounter("new_binary_clauses", prober->num_new_binary_clauses()); + // Note that we prefer to run this after we exported all equivalence to the + // context, so that our enforcement list can be presolved to the best of our + // knowledge. + DetectDuplicateConstraintsWithDifferentEnforcements( + mapping, implication_graph, model.GetOrCreate()); + // Stop probing timer now and display info. probing_timer.reset(); @@ -8758,7 +8773,7 @@ bool CpModelPresolver::ProcessEncodingFromLinear( // All false means associated_lit is false too. // But not for the rhs case if we are not in exactly one. if (in_exactly_one || value != rhs) { - // TODO(user): Insted of bool_or + implications, we could add an + // TODO(user): Instead of bool_or + implications, we could add an // exactly one! Experiment with this. In particular it might capture // more structure for later heuristic to add the exactly one instead. // This also applies to automata/table/element expansion. @@ -8888,37 +8903,20 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( for (const auto& [dup, rep] : duplicates_without_enforcement) { auto* dup_ct = context_->working_model->mutable_constraints(dup); auto* rep_ct = context_->working_model->mutable_constraints(rep); - if (rep_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { - continue; + + // Make sure our enforcement list are up to date: nothing fixed and that + // its uses the literal representatives. 
+ if (PresolveEnforcementLiteral(dup_ct)) { + context_->UpdateConstraintVariableUsage(dup); + } + if (PresolveEnforcementLiteral(rep_ct)) { + context_->UpdateConstraintVariableUsage(rep); } - // If we have a trail, we can check if any variable of the enforcement is - // fixed to false. This is useful for what follows since calling - // implication_graph->DirectImplications() is invalid for fixed variables. - if (trail != nullptr) { - bool found_false_enforcement = false; - for (const int c : {dup, rep}) { - for (const int l : - context_->working_model->constraints(c).enforcement_literal()) { - if (trail->Assignment().LiteralIsFalse(mapping->Literal(l))) { - found_false_enforcement = true; - break; - } - } - if (found_false_enforcement) { - context_->UpdateRuleStats("enforcement: false literal"); - if (c == rep) { - rep_ct->Swap(dup_ct); - context_->UpdateConstraintVariableUsage(rep); - } - dup_ct->Clear(); - context_->UpdateConstraintVariableUsage(dup); - break; - } - } - if (found_false_enforcement) { - continue; - } + // Skip this pair if one of the constraint was simplified + if (rep_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET || + dup_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { + continue; } // If one of them has no enforcement, then the other can be ignored. @@ -8936,10 +8934,7 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // Special case. This looks specific but users might reify with a cost // a duplicate constraint. In this case, no need to have two variables, // we can make them equal by duality argument. - const int a = rep_ct->enforcement_literal(0); - const int b = dup_ct->enforcement_literal(0); - if (context_->IsFixed(a) || context_->IsFixed(b)) continue; - + // // TODO(user): Deal with more general situation? 
Note that we already // do something similar in dual_bound_strengthening.Strengthen() were we // are more general as we just require an unique blocking constraint rather @@ -8949,6 +8944,8 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // we can also add the equality. Alternatively, we can just introduce a new // variable and merge all duplicate constraint into 1 + bunch of boolean // constraints liking enforcements. + const int a = rep_ct->enforcement_literal(0); + const int b = dup_ct->enforcement_literal(0); if (context_->VariableWithCostIsUniqueAndRemovable(a) && context_->VariableWithCostIsUniqueAndRemovable(b)) { // Both these case should be presolved before, but it is easy to deal with @@ -9007,19 +9004,19 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // B, then constraint A is redundant and we can remove it. const int c_a = i == 0 ? dup : rep; const int c_b = i == 0 ? rep : dup; + const auto& ct_a = context_->working_model->constraints(c_a); + const auto& ct_b = context_->working_model->constraints(c_b); enforcement_vars.clear(); implications_used.clear(); - for (const int proto_lit : - context_->working_model->constraints(c_b).enforcement_literal()) { + for (const int proto_lit : ct_b.enforcement_literal()) { const Literal lit = mapping->Literal(proto_lit); - if (trail->Assignment().LiteralIsTrue(lit)) continue; + DCHECK(!trail->Assignment().LiteralIsAssigned(lit)); enforcement_vars.insert(lit); } - for (const int proto_lit : - context_->working_model->constraints(c_a).enforcement_literal()) { + for (const int proto_lit : ct_a.enforcement_literal()) { const Literal lit = mapping->Literal(proto_lit); - if (trail->Assignment().LiteralIsTrue(lit)) continue; + DCHECK(!trail->Assignment().LiteralIsAssigned(lit)); for (const Literal implication_lit : implication_graph->DirectImplications(lit)) { auto extracted = enforcement_vars.extract(implication_lit); @@ -9029,6 +9026,71 @@ void 
CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( } } if (enforcement_vars.empty()) { + // Tricky: Because we keep track of literal <=> var == value, we + // cannot easily simplify linear1 here. This is because a scenario + // like this can happen: + // + // We have registered the fact that a <=> X=1 because we saw two + // constraints a => X=1 and not(a) => X!= 1 + // + // Now, we are here and we have: + // a => X=1, b => X=1, a => b + // So we rewrite this as + // a => b, b => X=1 + // + // But later, the PresolveLinearOfSizeOne() see + // b => X=1 and just rewrite this as b => a since (a <=> X=1). + // This is wrong because the constraint "b => X=1" is needed for the + // equivalence (a <=> X=1), but we lost that fact. + // + // Note(user): In the scenario above we can see that a <=> b, and if + // we know that fact, then the transformation is correctly handled. + // The bug was triggered when the Probing finished early due to time + // limit and we never detected that equivalence. + // + // TODO(user): Try to find a cleaner way to handle this. We could + // query our HasVarValueEncoding() directly here and directly detect a + // <=> b. However we also need to figure the case of + // half-implications. + { + if (ct_a.constraint_case() == ConstraintProto::kLinear && + ct_a.linear().vars().size() == 1 && + ct_a.enforcement_literal().size() == 1) { + const int var = ct_a.linear().vars(0); + const Domain var_domain = context_->DomainOf(var); + const Domain rhs = + ReadDomainFromProto(ct_a.linear()) + .InverseMultiplicationBy(ct_a.linear().coeffs(0)) + .IntersectionWith(var_domain); + + // IsFixed() do not work on empty domain. 
+ if (rhs.IsEmpty()) { + context_->UpdateRuleStats("duplicate: linear1 infeasible"); + if (!MarkConstraintAsFalse(rep_ct)) return; + if (!MarkConstraintAsFalse(dup_ct)) return; + context_->UpdateConstraintVariableUsage(rep); + context_->UpdateConstraintVariableUsage(dup); + continue; + } + if (rhs == var_domain) { + context_->UpdateRuleStats("duplicate: linear1 always true"); + rep_ct->Clear(); + dup_ct->Clear(); + context_->UpdateConstraintVariableUsage(rep); + context_->UpdateConstraintVariableUsage(dup); + continue; + } + + // We skip if it is a var == value or var != value constraint. + if (rhs.IsFixed() || + rhs.Complement().IntersectionWith(var_domain).IsFixed()) { + context_->UpdateRuleStats( + "TODO duplicate: skipped identical encoding constraints"); + continue; + } + } + } + context_->UpdateRuleStats( "duplicate: identical constraint with implied enforcements"); if (c_a == rep) { @@ -9043,12 +9105,8 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // graph. This is because in some case the implications are only true // in the presence of the "duplicated" constraints. for (const auto& [a, b] : implications_used) { - const int var_a = - mapping->GetProtoVariableFromBooleanVariable(a.Variable()); - const int proto_lit_a = a.IsPositive() ? var_a : NegatedRef(var_a); - const int var_b = - mapping->GetProtoVariableFromBooleanVariable(b.Variable()); - const int proto_lit_b = b.IsPositive() ? var_b : NegatedRef(var_b); + const int proto_lit_a = mapping->GetProtoLiteralFromLiteral(a); + const int proto_lit_b = mapping->GetProtoLiteralFromLiteral(b); context_->AddImplication(proto_lit_a, proto_lit_b); } context_->UpdateNewConstraintsVariableUsage(); @@ -12910,6 +12968,26 @@ CpSolverStatus CpModelPresolver::Presolve() { context_->WriteObjectiveToProto(); } + // Now that everything that could possibly be fixed was fixed, make sure we + // don't leave any linear constraint with fixed variables. 
+ for (int c = 0; c < context_->working_model->constraints_size(); ++c) { + ConstraintProto& ct = *context_->working_model->mutable_constraints(c); + bool need_canonicalize = false; + if (ct.constraint_case() == ConstraintProto::kLinear) { + for (const int v : ct.linear().vars()) { + if (context_->IsFixed(v)) { + need_canonicalize = true; + break; + } + } + } + if (need_canonicalize) { + if (CanonicalizeLinear(&ct)) { + context_->UpdateConstraintVariableUsage(c); + } + } + } + // Take care of linear constraint with a complex rhs. FinalExpansionForLinearConstraint(context_); diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 03df0ebc0d1..f47ec0aa4da 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -118,6 +118,10 @@ ABSL_FLAG(bool, cp_model_ignore_hints, false, "If true, ignore any supplied hints."); ABSL_FLAG(bool, cp_model_fingerprint_model, true, "Fingerprint the model."); +ABSL_FLAG(bool, cp_model_check_intermediate_solutions, false, + "When true, all intermediate solutions found by the solver will be " + "checked. This can be expensive, therefore it is off by default."); + namespace operations_research { namespace sat { @@ -1611,7 +1615,9 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { helper, name_filter.LastName()), lns_params, helper, shared)); } - if (params.use_lb_relax_lns() && name_filter.Keep("lb_relax_lns")) { + if (params.use_lb_relax_lns() && + params.num_workers() >= params.lb_relax_num_workers_threshold() && + name_filter.Keep("lb_relax_lns")) { reentrant_interleaved_subsolvers.push_back(std::make_unique( std::make_unique( helper, name_filter.LastName(), shared->time_limit), @@ -1727,14 +1733,14 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { // Compared to LNS, these are not re-entrant, so we need to schedule the // correct number for parallelism. 
if (shared->model_proto.has_objective()) { - // If not forced by the parameters, we want one LS every two threads that + // If not forced by the parameters, we want one LS every 3 threads that // work on interleaved stuff. Note that by default they are many LNS, so // that shouldn't be too many. const int num_thread_for_interleaved_workers = params.num_workers() - full_worker_subsolvers.size(); int num_violation_ls = params.has_num_violation_ls() ? params.num_violation_ls() - : (num_thread_for_interleaved_workers + 1) / 2; + : (num_thread_for_interleaved_workers + 2) / 3; // If there is no rentrant solver, maybe increase the number to reach max // parallelism. @@ -1749,7 +1755,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { const absl::string_view lin_ls_name = "ls_lin"; const int num_ls_lin = - name_filter.Keep(lin_ls_name) ? num_violation_ls / 3 : 0; + name_filter.Keep(lin_ls_name) ? (num_violation_ls + 1) / 3 : 0; const int num_ls_default = name_filter.Keep(ls_name) ? num_violation_ls - num_ls_lin : 0; @@ -2405,24 +2411,41 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { // We either check all solutions, or only the last one. // Checking all solution might be expensive if we creates many. auto check_solution = [&model_proto, ¶ms, mapping_proto, - &postsolve_mapping](CpSolverResponse* response) { - if (response->solution().empty()) return; + &postsolve_mapping](const CpSolverResponse& response) { + if (response.solution().empty()) return; + + bool solution_is_feasible = true; if (params.cp_model_presolve()) { // We pass presolve data for more informative message in case the solution // is not feasible. 
- CHECK(SolutionIsFeasible(model_proto, response->solution(), mapping_proto, - &postsolve_mapping)); + solution_is_feasible = SolutionIsFeasible( + model_proto, response.solution(), mapping_proto, &postsolve_mapping); } else { - CHECK(SolutionIsFeasible(model_proto, response->solution())); + solution_is_feasible = + SolutionIsFeasible(model_proto, response.solution()); + } + + // We dump the response when infeasible, this might help debugging. + if (!solution_is_feasible) { + const std::string file = absl::StrCat( + absl::GetFlag(FLAGS_cp_model_dump_prefix), "wrong_response.pb.txt"); + LOG(INFO) << "Dumping infeasible response proto to '" << file << "'."; + CHECK(WriteModelProtoToFile(response, file)); + + // Crash. + LOG(FATAL) << "Infeasible solution!" + << " source': " << response.solution_info() << "'" + << " dumped CpSolverResponse to '" << file << "'."; } }; if (DEBUG_MODE || absl::GetFlag(FLAGS_cp_model_check_intermediate_solutions)) { - shared_response_manager->AddResponsePostprocessor( - std::move(check_solution)); + shared_response_manager->AddSolutionCallback(std::move(check_solution)); } else { shared_response_manager->AddFinalResponsePostprocessor( - std::move(check_solution)); + [checker = std::move(check_solution)](CpSolverResponse* response) { + checker(*response); + }); } // Solution postsolving. diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 9e76433e666..3e51e066ece 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -108,10 +108,6 @@ ABSL_FLAG( "we will interpret this as an internal solution which can be used for " "debugging. For instance we use it to identify wrong cuts/reasons."); -ABSL_FLAG(bool, cp_model_check_intermediate_solutions, false, - "When true, all intermediate solutions found by the solver will be " - "checked. 
This can be expensive, therefore it is off by default."); - namespace operations_research { namespace sat { @@ -307,12 +303,6 @@ std::vector GetSolutionValues(const CpModelProto& model_proto, } } } - - if (DEBUG_MODE || - absl::GetFlag(FLAGS_cp_model_check_intermediate_solutions)) { - // TODO(user): Checks against initial model. - CHECK(SolutionIsFeasible(model_proto, solution)); - } return solution; } @@ -1006,15 +996,21 @@ void LoadBaseModel(const CpModelProto& model_proto, Model* model) { VLOG(3) << num_ignored_constraints << " constraints were skipped."; } if (!unsupported_types.empty()) { - VLOG(1) << "There is unsupported constraints types in this model: "; + auto* logger = model->GetOrCreate(); + SOLVER_LOG(logger, + "There is unsupported constraints types in this model: "); std::vector names; for (const ConstraintProto::ConstraintCase type : unsupported_types) { names.push_back(ConstraintCaseName(type)); } std::sort(names.begin(), names.end()); for (const absl::string_view name : names) { - VLOG(1) << " - " << name; + SOLVER_LOG(logger, " - ", name); } + + // TODO(user): This is wrong. We should support a MODEL_INVALID end of solve + // in the SharedResponseManager. 
+ SOLVER_LOG(logger, "BUG: We will wrongly report INFEASIBLE now."); return unsat(); } diff --git a/ortools/sat/cp_model_solver_helpers.h b/ortools/sat/cp_model_solver_helpers.h index 14e34ab310b..a2220f3b956 100644 --- a/ortools/sat/cp_model_solver_helpers.h +++ b/ortools/sat/cp_model_solver_helpers.h @@ -32,7 +32,6 @@ #include "ortools/util/logging.h" ABSL_DECLARE_FLAG(bool, cp_model_dump_models); -ABSL_DECLARE_FLAG(bool, cp_model_check_intermediate_solutions); ABSL_DECLARE_FLAG(std::string, cp_model_dump_prefix); ABSL_DECLARE_FLAG(bool, cp_model_dump_submodels); diff --git a/ortools/sat/cp_model_symmetries.cc b/ortools/sat/cp_model_symmetries.cc index 1c53e17e9e2..b489d655b6d 100644 --- a/ortools/sat/cp_model_symmetries.cc +++ b/ortools/sat/cp_model_symmetries.cc @@ -895,6 +895,47 @@ std::vector BuildInequalityCoeffsForOrbitope( return out; } +void UpdateHintAfterFixingBoolToBreakSymmetry( + PresolveContext* context, int var, bool fixed_value, + const std::vector>& generators) { + if (!context->VarHasSolutionHint(var)) { + return; + } + const int64_t hinted_value = context->SolutionHint(var); + if (hinted_value == static_cast(fixed_value)) { + return; + } + + std::vector schrier_vector; + std::vector orbit; + GetSchreierVectorAndOrbit(var, generators, &schrier_vector, &orbit); + + bool found_target = false; + int target_var; + for (int v : orbit) { + if (context->VarHasSolutionHint(v) && + context->SolutionHint(v) == static_cast(fixed_value)) { + found_target = true; + target_var = v; + break; + } + } + if (!found_target) { + context->UpdateRuleStats( + "hint: couldn't transform infeasible hint properly"); + return; + } + + const std::vector generator_idx = + TracePoint(target_var, schrier_vector, generators); + for (const int i : generator_idx) { + context->PermuteHintValues(*generators[i]); + } + + DCHECK(context->VarHasSolutionHint(var)); + DCHECK_EQ(context->SolutionHint(var), fixed_value); +} + } // namespace bool 
DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { @@ -1010,6 +1051,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { // fixing do not exploit the full structure of these symmeteries. Note // however that the fixing via propagation above close cod105 even more // efficiently. + std::vector var_can_be_true_per_orbit(num_vars, -1); { std::vector tmp_to_clear; std::vector tmp_sizes(num_vars, 0); @@ -1050,7 +1092,11 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { } // We push all but the first one in each orbit. - if (tmp_sizes[rep] == 0) can_be_fixed_to_false.push_back(var); + if (tmp_sizes[rep] == 0) { + can_be_fixed_to_false.push_back(var); + } else { + var_can_be_true_per_orbit[rep] = var; + } tmp_sizes[rep] = 0; } } else { @@ -1131,7 +1177,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { } } - // Supper simple heuristic to use the orbitope or not. + // Super simple heuristic to use the orbitope or not. // // In an orbitope with an at most one on each row, we can fix the upper right // triangle. We could use a formula, but the loop is fast enough. @@ -1153,6 +1199,19 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { const int var = can_be_fixed_to_false[i]; if (orbits[var] == orbit_index) ++num_in_orbit; context->UpdateRuleStats("symmetry: fixed to false in general orbit"); + if (context->VarHasSolutionHint(var) && context->SolutionHint(var) == 1 && + var_can_be_true_per_orbit[orbits[var]] != -1) { + // We are breaking the symmetry in a way that makes the hint invalid. + // We want `var` to be false, so we would naively pick a symmetry to + // enforce that. But that will be wrong if we do this twice: after we + // permute the hint to fix the first one we would look for a symmetry + // group element that fixes the second one to false. But there are many + // of those, and picking the wrong one would risk making the first one + // true again. 
Since this is an AMO, fixing the one that is true doesn't + // have this problem. + UpdateHintAfterFixingBoolToBreakSymmetry( + context, var_can_be_true_per_orbit[orbits[var]], true, generators); + } if (!context->SetLiteralToFalse(var)) return false; } diff --git a/ortools/sat/cumulative_energy_test.cc b/ortools/sat/cumulative_energy_test.cc new file mode 100644 index 00000000000..8e58b53a28a --- /dev/null +++ b/ortools/sat/cumulative_energy_test.cc @@ -0,0 +1,562 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "ortools/sat/cumulative_energy.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/logging.h" +#include "ortools/sat/2d_orthogonal_packing_testing.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/cumulative.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +// An instance is a set of energy tasks and a capacity. +struct EnergyTask { + int start_min; + int end_max; + int energy_min; + int energy_max; + int duration_min; + int duration_max; + bool is_optional; +}; + +struct EnergyInstance { + std::vector tasks; + int capacity; +}; + +std::string InstanceDebugString(const EnergyInstance& instance) { + std::string result; + absl::StrAppend(&result, "Instance capacity:", instance.capacity, "\n"); + for (const EnergyTask& task : instance.tasks) { + absl::StrAppend(&result, "[", task.start_min, ", ", task.end_max, + "] duration:", task.duration_min, "..", task.duration_max, + " energy:", task.energy_min, "..", task.energy_max, + " is_optional:", task.is_optional, "\n"); + } + return result; +} + +// Satisfiability using the constraint. 
+bool SolveUsingConstraint(const EnergyInstance& instance) { + Model model; + std::vector intervals; + std::vector energies; + for (const auto& task : instance.tasks) { + LinearExpression energy; + energy.vars.push_back( + model.Add(NewIntegerVariable(task.energy_min, task.energy_max))); + energy.coeffs.push_back(IntegerValue(1)); + energies.push_back(energy); + if (task.is_optional) { + const Literal is_present = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable start = + model.Add(NewIntegerVariable(task.start_min, task.end_max)); + const IntegerVariable end = + model.Add(NewIntegerVariable(task.start_min, task.end_max)); + const IntegerVariable duration = + model.Add(NewIntegerVariable(task.duration_min, task.duration_max)); + intervals.push_back( + model.Add(NewOptionalInterval(start, end, duration, is_present))); + } else { + intervals.push_back(model.Add(NewIntervalWithVariableSize( + task.start_min, task.end_max, task.duration_min, task.duration_max))); + } + } + + const AffineExpression capacity( + model.Add(ConstantIntegerVariable(instance.capacity))); + + SchedulingConstraintHelper* helper = + new SchedulingConstraintHelper(intervals, &model); + model.TakeOwnership(helper); + SchedulingDemandHelper* demands_helper = + new SchedulingDemandHelper({}, helper, &model); + demands_helper->OverrideLinearizedEnergies(energies); + model.TakeOwnership(demands_helper); + + AddCumulativeOverloadChecker(capacity, helper, demands_helper, &model); + + return SolveIntegerProblemWithLazyEncoding(&model) == + SatSolver::Status::FEASIBLE; +} + +// One task by itself is infeasible. +TEST(CumulativeEnergyTest, UnfeasibleFixedCharacteristics) { + EnergyInstance instance = {{{0, 100, 11, 11, 2, 2, false}}, 5}; + EXPECT_FALSE(SolveUsingConstraint(instance)) << InstanceDebugString(instance); +} + +// Tasks are feasible iff all are at energy min. 
+TEST(CumulativeEnergyTest, FeasibleEnergyMin) { + EnergyInstance instance = {{ + {-10, 10, 10, 15, 0, 20, false}, + {-10, 10, 15, 20, 0, 20, false}, + {-10, 10, 5, 10, 0, 20, false}, + }, + 3}; + EXPECT_TRUE(SolveUsingConstraint(instance)) << InstanceDebugString(instance); +} + +// Tasks are feasible iff optional tasks are removed. +TEST(CumulativeEnergyTest, FeasibleRemoveOptionals) { + EnergyInstance instance = {{ + {-10, 10, 1, 1, 1, 1, true}, + {-10, 10, 5, 10, 7, 7, true}, + {-10, 10, 10, 15, 0, 20, false}, + {-10, 10, 15, 20, 0, 20, false}, + {-10, 10, 5, 10, 0, 20, false}, + }, + 3}; + EXPECT_TRUE(SolveUsingConstraint(instance)) << InstanceDebugString(instance); +} + +// This instance was problematic. +TEST(CumulativeEnergyTest, Problematic1) { + EnergyInstance instance = {{ + {2, 18, 6, 7, 5, 10, false}, + {2, 25, 6, 9, 14, 17, false}, + {-4, 19, 6, 9, 10, 20, false}, + {-9, 7, 6, 15, 9, 16, false}, + {-1, 19, 6, 12, 6, 14, false}, + }, + 1}; + EXPECT_TRUE(SolveUsingConstraint(instance)) << InstanceDebugString(instance); +} + +// Satisfiability using a naive model: one task per unit of energy. +// Force energy-based reasoning in Cumulative() and add symmetry breaking, +// or the solver has a much harder time.
+bool SolveUsingNaiveModel(const EnergyInstance& instance) { + Model model; + std::vector intervals; + std::vector consumptions; + IntegerVariable one = model.Add(ConstantIntegerVariable(1)); + IntervalsRepository* intervals_repository = + model.GetOrCreate(); + + for (const auto& task : instance.tasks) { + if (task.is_optional) { + const Literal is_present = Literal(model.Add(NewBooleanVariable()), true); + for (int i = 0; i < task.energy_min; i++) { + const IntegerVariable start = + model.Add(NewIntegerVariable(task.start_min, task.end_max)); + const IntegerVariable end = + model.Add(NewIntegerVariable(task.start_min, task.end_max)); + + intervals.push_back( + model.Add(NewOptionalInterval(start, end, one, is_present))); + consumptions.push_back(AffineExpression(IntegerValue(1))); + } + } else { + IntegerVariable first_start = kNoIntegerVariable; + IntegerVariable previous_start = kNoIntegerVariable; + for (int i = 0; i < task.energy_min; i++) { + IntervalVariable interval = + model.Add(NewInterval(task.start_min, task.end_max, 1)); + intervals.push_back(interval); + consumptions.push_back(AffineExpression(IntegerValue(1))); + const AffineExpression start_expr = + intervals_repository->Start(interval); + CHECK_EQ(start_expr.coeff, 1); + CHECK_EQ(start_expr.constant, 0); + CHECK_NE(start_expr.var, kNoIntegerVariable); + const IntegerVariable start = start_expr.var; + if (previous_start != kNoIntegerVariable) { + model.Add(LowerOrEqual(previous_start, start)); + } else { + first_start = start; + } + previous_start = start; + } + // start[last] <= start[0] + duration_max - 1 + if (previous_start != kNoIntegerVariable) { + model.Add(LowerOrEqualWithOffset(previous_start, first_start, + -task.duration_max + 1)); + } + } + } + + SatParameters params = + model.Add(NewSatParameters("use_overload_checker_in_cumulative:true")); + model.Add(Cumulative(intervals, consumptions, + AffineExpression(IntegerValue(instance.capacity)))); + + return 
SolveIntegerProblemWithLazyEncoding(&model) == + SatSolver::Status::FEASIBLE; +} + +// Generates random instances, fill the schedule to try and make a tricky case. +EnergyInstance GenerateRandomInstance(int num_tasks, + absl::BitGenRef randomizer) { + const int capacity = absl::Uniform(randomizer, 1, 12); + std::vector tasks; + for (int i = 0; i < num_tasks; i++) { + int start_min = absl::Uniform(randomizer, -10, 10); + int duration_min = absl::Uniform(randomizer, 1, 21); + int duration_max = absl::Uniform(randomizer, 1, 21); + if (duration_min > duration_max) std::swap(duration_min, duration_max); + int end_max = start_min + duration_max + absl::Uniform(randomizer, 0, 10); + int energy_min = (capacity * 30) / num_tasks; + int energy_max = energy_min + absl::Uniform(randomizer, 1, 10); + tasks.push_back({start_min, end_max, energy_min, energy_max, duration_min, + duration_max, false}); + } + + return {tasks, capacity}; +} + +// Compare constraint to naive model. +TEST(CumulativeEnergyTest, CompareToNaiveModel) { + const int num_tests = 10; + std::mt19937 randomizer(12345); + for (int test = 0; test < num_tests; test++) { + EnergyInstance instance = + GenerateRandomInstance(absl::Uniform(randomizer, 4, 7), randomizer); + bool result_constraint = SolveUsingConstraint(instance); + bool result_naive = SolveUsingNaiveModel(instance); + EXPECT_EQ(result_naive, result_constraint) << InstanceDebugString(instance); + LOG(INFO) << result_constraint; + } +} + +struct CumulativeTasks { + int64_t duration; + int64_t demand; + int64_t min_start; + int64_t max_end; +}; + +enum class PropagatorChoice { + OVERLOAD, + OVERLOAD_DFF, +}; +bool TestOverloadCheckerPropagation( + absl::Span tasks, int capacity_min_before, + int capacity_min_after, int capacity_max, + PropagatorChoice propagator_choice = PropagatorChoice::OVERLOAD) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* precedences = + model.GetOrCreate(); + + const int num_tasks = 
tasks.size(); + std::vector interval_vars(num_tasks); + std::vector demands(num_tasks); + const AffineExpression capacity = + AffineExpression(integer_trail->AddIntegerVariable( + IntegerValue(capacity_min_before), IntegerValue(capacity_max))); + + // Build the task variables. + for (int t = 0; t < num_tasks; ++t) { + interval_vars[t] = model.Add( + NewInterval(tasks[t].min_start, tasks[t].max_end, tasks[t].duration)); + demands[t] = AffineExpression(IntegerValue(tasks[t].demand)); + } + + // Propagate properly the other bounds of the intervals. + EXPECT_TRUE(precedences->Propagate()); + + // Propagator responsible for filtering the capacity variable. + SchedulingConstraintHelper* helper = + new SchedulingConstraintHelper(interval_vars, &model); + model.TakeOwnership(helper); + SchedulingDemandHelper* demands_helper = + new SchedulingDemandHelper(demands, helper, &model); + model.TakeOwnership(demands_helper); + + if (propagator_choice == PropagatorChoice::OVERLOAD) { + AddCumulativeOverloadChecker(capacity, helper, demands_helper, &model); + } else if (propagator_choice == PropagatorChoice::OVERLOAD_DFF) { + AddCumulativeOverloadCheckerDff(capacity, helper, demands_helper, &model); + } else { + LOG(FATAL) << "Unknown propagator choice!"; + } + + // Check initial satisfiability. + auto* sat_solver = model.GetOrCreate(); + if (!sat_solver->Propagate()) return false; + + // Check capacity. + EXPECT_EQ(capacity_min_after, integer_trail->LowerBound(capacity)); + return true; +} + +// This is a trivially infeasible instance. +TEST(OverloadCheckerTest, UNSAT1) { + EXPECT_FALSE( + TestOverloadCheckerPropagation({{4, 2, 0, 7}, {4, 2, 0, 7}}, 2, 2, 2)); +} + +// This is an infeasible instance on which timetabling finds nothing. The +// overload checker finds the contradiction. 
+TEST(OverloadCheckerTest, UNSAT2) { + EXPECT_FALSE(TestOverloadCheckerPropagation( + {{4, 2, 0, 8}, {4, 2, 0, 8}, {4, 2, 0, 8}}, 2, 2, 2)); +} + +// This is the same instance as in UNSAT1 but here the capacity can increase. +TEST(OverloadCheckerTest, IncreaseCapa1) { + EXPECT_TRUE( + TestOverloadCheckerPropagation({{4, 2, 2, 9}, {4, 2, 2, 9}}, 2, 3, 10)); +} + +// This is an instance in which tasks can be perfectly packed in a rectangle of +// size 5 to 6. OverloadChecker increases the capacity from 3 to 5. +TEST(OverloadCheckerTest, IncreaseCapa2) { + EXPECT_TRUE(TestOverloadCheckerPropagation({{5, 2, 2, 8}, + {2, 3, 2, 8}, + {2, 1, 2, 8}, + {1, 3, 2, 8}, + {1, 3, 2, 8}, + {3, 2, 2, 8}}, + 3, 5, 10)); +} + +// This is an instance in which OverloadChecker increases the capacity. +TEST(OverloadCheckerTest, IncreaseCapa3) { + EXPECT_TRUE(TestOverloadCheckerPropagation( + {{1, 3, 3, 6}, {1, 3, 3, 6}, {1, 1, 3, 8}}, 0, 2, 10)); +} + +// This is a trivially infeasible instance with negative times. +TEST(OverloadCheckerTest, UNSATNeg1) { + EXPECT_FALSE( + TestOverloadCheckerPropagation({{4, 2, -7, 0}, {4, 2, -7, 0}}, 2, 2, 2)); +} + +// This is an infeasible instance with negative times on which timetabling finds +// nothing. The overload checker finds the contradiction. +TEST(OverloadCheckerTest, UNSATNeg2) { + EXPECT_FALSE(TestOverloadCheckerPropagation( + {{4, 2, -4, 4}, {4, 2, -4, 4}, {4, 2, -4, 4}}, 2, 2, 2)); +} + +// This is the same instance as in UNSATNeg1 but here the capacity can increase. +TEST(OverloadCheckerTest, IncreaseCapaNeg1) { + EXPECT_TRUE(TestOverloadCheckerPropagation({{4, 2, -10, -3}, {4, 2, -10, -3}}, + 2, 3, 10)); +} + +// This is an instance with negative times in which tasks can be perfectly +// packed in a rectangle of size 5 to 6. OverloadChecker increases the capacity +// from 3 to 5. 
+TEST(OverloadCheckerTest, IncreaseCapaNeg2) { + EXPECT_TRUE(TestOverloadCheckerPropagation({{5, 2, -2, 4}, + {2, 3, -2, 4}, + {2, 1, -2, 4}, + {1, 3, -2, 4}, + {1, 3, -2, 4}, + {3, 2, -2, 4}}, + 3, 5, 10)); +} + +// This is an instance with negative times in which OverloadChecker increases +// the capacity. +TEST(OverloadCheckerTest, IncreaseCapaNeg3) { + EXPECT_TRUE(TestOverloadCheckerPropagation( + {{1, 3, -3, 0}, {1, 3, -3, 0}, {1, 1, -3, 2}}, 0, 2, 10)); +} + +TEST(OverloadCheckerTest, OptionalTaskPropagatedToAbsent) { + Model model; + const Literal is_present = Literal(model.Add(NewBooleanVariable()), true); + + // TODO(user): Fix the code! the propagation is dependent on the order of + // tasks. If we use the proper theta-lambda tree, this will be fixed. + const IntervalVariable i2 = model.Add(NewInterval(0, 10, /*size=*/8)); + const IntervalVariable i1 = + model.Add(NewOptionalInterval(0, 10, /*size=*/8, is_present)); + + SchedulingConstraintHelper* helper = + new SchedulingConstraintHelper({i1, i2}, &model); + model.TakeOwnership(helper); + const AffineExpression cte(IntegerValue(2)); + SchedulingDemandHelper* demands_helper = + new SchedulingDemandHelper({cte, cte}, helper, &model); + model.TakeOwnership(demands_helper); + + AddCumulativeOverloadChecker(cte, helper, demands_helper, &model); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.Get(Value(is_present))); +} + +TEST(OverloadCheckerTest, OptionalTaskMissedPropagationCase) { + Model model; + const Literal is_present = Literal(model.Add(NewBooleanVariable()), true); + const IntervalVariable i1 = + model.Add(NewOptionalInterval(0, 10, /*size=*/8, is_present)); + const IntervalVariable i2 = + model.Add(NewOptionalInterval(0, 10, /*size=*/8, is_present)); + + SchedulingConstraintHelper* helper = + new SchedulingConstraintHelper({i1, i2}, &model); + model.TakeOwnership(helper); + const AffineExpression cte(IntegerValue(2)); + SchedulingDemandHelper* demands_helper = + new 
SchedulingDemandHelper({cte, cte}, helper, &model); + model.TakeOwnership(demands_helper); + + AddCumulativeOverloadChecker(cte, helper, demands_helper, &model); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().VariableIsAssigned( + is_present.Variable())); +} + +TEST(OverloadCheckerDffTest, DffIsNeeded) { + const std::vector tasks = { + {.duration = 10, .demand = 5, .min_start = 0, .max_end = 22}, + {.duration = 10, .demand = 5, .min_start = 0, .max_end = 22}, + {.duration = 10, .demand = 5, .min_start = 0, .max_end = 22}, + {.duration = 10, .demand = 5, .min_start = 0, .max_end = 22}, + }; + EXPECT_FALSE(TestOverloadCheckerPropagation(tasks, /*capacity_min_before=*/9, + /*capacity_min_after=*/9, + /*capacity_max=*/9, + PropagatorChoice::OVERLOAD_DFF)); +} + +TEST(OverloadCheckerDffTest, NoConflictRandomFeasibleProblem) { + absl::BitGen random; + for (int i = 0; i < 100; ++i) { + const std::vector rectangles = GenerateNonConflictingRectangles( + absl::Uniform(random, 6, 20), random); + Rectangle bounding_box; + for (const auto& item : rectangles) { + bounding_box.x_min = std::min(bounding_box.x_min, item.x_min); + bounding_box.x_max = std::max(bounding_box.x_max, item.x_max); + bounding_box.y_min = std::min(bounding_box.y_min, item.y_min); + bounding_box.y_max = std::max(bounding_box.y_max, item.y_max); + } + const std::vector range_items = + MakeItemsFromRectangles(rectangles, 0.3, random); + std::vector tasks(range_items.size()); + + for (int i = 0; i < range_items.size(); ++i) { + tasks[i] = {.duration = range_items[i].x_size.value(), + .demand = range_items[i].y_size.value(), + .min_start = range_items[i].bounding_area.x_min.value(), + .max_end = range_items[i].bounding_area.x_max.value()}; + } + EXPECT_TRUE(TestOverloadCheckerPropagation( + tasks, /*capacity_min_before=*/bounding_box.SizeY().value(), + /*capacity_min_after=*/bounding_box.SizeY().value(), + /*capacity_max=*/bounding_box.SizeY().value(), + 
PropagatorChoice::OVERLOAD_DFF)); + } +} + +bool TestIsAfterCumulative(absl::Span tasks, + int capacity_max, int expected_end_min) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* precedences = + model.GetOrCreate(); + + const int num_tasks = tasks.size(); + std::vector interval_vars(num_tasks); + std::vector demands(num_tasks); + const AffineExpression capacity = + AffineExpression(integer_trail->AddIntegerVariable( + IntegerValue(capacity_max), IntegerValue(capacity_max))); + + // Build the task variables. + std::vector subtasks; + for (int t = 0; t < num_tasks; ++t) { + interval_vars[t] = model.Add( + NewInterval(tasks[t].min_start, tasks[t].max_end, tasks[t].duration)); + demands[t] = AffineExpression(IntegerValue(tasks[t].demand)); + subtasks.push_back(t); + } + + // Propagate properly the other bounds of the intervals. + EXPECT_TRUE(precedences->Propagate()); + + // Propagator responsible for filtering the capacity variable. + SchedulingConstraintHelper* helper = + new SchedulingConstraintHelper(interval_vars, &model); + model.TakeOwnership(helper); + SchedulingDemandHelper* demands_helper = + new SchedulingDemandHelper(demands, helper, &model); + model.TakeOwnership(demands_helper); + + const IntegerVariable var = + integer_trail->AddIntegerVariable(IntegerValue(0), IntegerValue(100)); + + std::vector offsets(subtasks.size(), IntegerValue(0)); + CumulativeIsAfterSubsetConstraint* propag = + new CumulativeIsAfterSubsetConstraint(var, capacity, subtasks, offsets, + helper, demands_helper, &model); + propag->RegisterWith(model.GetOrCreate()); + model.TakeOwnership(propag); + + // Check initial satisfiability. + auto* sat_solver = model.GetOrCreate(); + if (!sat_solver->Propagate()) return false; + + // Check bound + EXPECT_EQ(expected_end_min, integer_trail->LowerBound(var)); + return true; +} + +// We detect that the interval cannot overlap. 
+TEST(IsAfterCumulativeTest, BasicCase1) { + // duration, demand, start_min, end_max + EXPECT_TRUE(TestIsAfterCumulative({{4, 2, 0, 8}, {4, 2, 0, 10}}, + /*capacity_max=*/3, + /*expected_end_min=*/8)); +} + +// Now, one interval can overlap. It is also after the other, so the best bound +// we get is not that great: energy = 2 + 8 + 8 = 18, with capa = 3, we get 6. +// +// TODO(user): Maybe we can do more advanced reasoning to recover the 8 here. +TEST(IsAfterCumulativeTest, BasicCase2) { + // duration, demand, start_min, end_max. + EXPECT_TRUE(TestIsAfterCumulative({{2, 1, 3, 8}, {4, 2, 0, 8}, {4, 2, 0, 10}}, + /*capacity_max=*/3, + /*expected_end_min=*/6)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cumulative_test.cc b/ortools/sat/cumulative_test.cc new file mode 100644 index 00000000000..e14f2c63ce6 --- /dev/null +++ b/ortools/sat/cumulative_test.cc @@ -0,0 +1,421 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/cumulative.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "gtest/gtest.h" +#include "ortools/base/logging.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +// RcpspInstance contains the data to define an instance of the Resource +// Constrained Project Scheduling Problem (RCPSP). We only consider a restricted +// variant of the RCPSP which is the problem of scheduling a set of +// non-preemptive tasks that consume a given quantity of a resource without +// exceeding the resource's capacity. We assume that the duration of a task, its +// demand, and the resource capacity are fixed.
+struct RcpspInstance { + RcpspInstance() : capacity(0), min_start(0), max_end(0) {} + std::vector durations; + std::vector optional; + std::vector demands; + int64_t capacity; + int64_t min_start; + int64_t max_end; + std::string DebugString() const { + std::string result = "RcpspInstance {\n"; + result += " demands: {" + absl::StrJoin(demands, ", ") + "}\n"; + result += " durations: {" + absl::StrJoin(durations, ", ") + "}\n"; + result += " optional: {" + absl::StrJoin(optional, ", ") + "}\n"; + result += " min_start: " + absl::StrCat(min_start) + "\n"; + result += " max_end: " + absl::StrCat(max_end) + "\n"; + result += " capacity: " + absl::StrCat(capacity) + "\n}"; + return result; + } +}; + +// Generates a random RcpspInstance with num_tasks tasks such that: +// - the duration of a task is a fixed random number in +// [min_duration, max_durations]; +// - tasks can be optional if enable_optional is true; +// - the demand of a task is a fixed random number in [min_demand, max_demand]; +// - the resource capacity is a fixed random number in +// [max_demand - 1, max_capacity]. This allows the capacity to be lower than +// the highest demand to generate trivially unfeasible instances. +// - the energy (i.e. surface) of the resource is 120% of the total energy of +// the tasks. This allows the generation of infeasible instances. +RcpspInstance GenerateRandomInstance(int num_tasks, int min_duration, + int max_duration, int min_demand, + int max_demand, int max_capacity, + int min_start, bool enable_optional) { + absl::BitGen random; + RcpspInstance instance; + int energy = 0; + + // Generate task demands and durations. 
+ int max_of_all_durations = 0; + for (int t = 0; t < num_tasks; ++t) { + const int duration = absl::Uniform(random, min_duration, max_duration + 1); + const int demand = absl::Uniform(random, min_demand, max_demand + 1); + energy += duration * demand; + max_of_all_durations = std::max(max_of_all_durations, duration); + instance.durations.push_back(duration); + instance.demands.push_back(demand); + instance.optional.push_back(enable_optional && + absl::Bernoulli(random, 0.5)); + } + + // Generate the resource capacity. + instance.capacity = absl::Uniform(random, max_demand, max_capacity + 1); + + // Generate the time window. + instance.min_start = min_start; + instance.max_end = + min_start + + std::max(static_cast(std::round(energy * 1.2 / instance.capacity)), + max_of_all_durations); + return instance; +} + +template +int CountAllSolutions(const RcpspInstance& instance, SatParameters parameters, + const Cumulative& cumulative) { + Model model; + parameters.set_use_disjunctive_constraint_in_cumulative(false); + model.GetOrCreate()->SetParameters(parameters); + + DCHECK_EQ(instance.demands.size(), instance.durations.size()); + DCHECK_LE(instance.min_start, instance.max_end); + + const int num_tasks = instance.demands.size(); + std::vector intervals(num_tasks); + std::vector demands(num_tasks); + const AffineExpression capacity = IntegerValue(instance.capacity); + + for (int t = 0; t < num_tasks; ++t) { + if (instance.optional[t]) { + const Literal is_present = Literal(model.Add(NewBooleanVariable()), true); + intervals[t] = + model.Add(NewOptionalInterval(instance.min_start, instance.max_end, + instance.durations[t], is_present)); + } else { + intervals[t] = model.Add(NewInterval(instance.min_start, instance.max_end, + instance.durations[t])); + } + demands[t] = IntegerValue(instance.demands[t]); + } + + model.Add(cumulative(intervals, demands, capacity, nullptr)); + + // Make sure that every Boolean variable is considered as a decision variable + // to be fixed. 
+ if (parameters.search_branching() == SatParameters::FIXED_SEARCH) { + SatSolver* sat_solver = model.GetOrCreate(); + for (int i = 0; i < sat_solver->NumVariables(); ++i) { + model.Add( + NewIntegerVariableFromLiteral(Literal(BooleanVariable(i), true))); + } + } + + int num_solutions_found = 0; + // Loop until there is no remaining solution to find. + while (true) { + // Try to find a solution. + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + // Leave the loop if there is no solution left. + if (status != SatSolver::Status::FEASIBLE) break; + num_solutions_found++; + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + return num_solutions_found; +} + +TEST(CumulativeTimeDecompositionTest, AllPermutations) { + RcpspInstance instance; + instance.demands = {1, 1, 1, 1, 1}; + instance.durations = {1, 1, 1, 1, 1}; + instance.optional = {false, false, false, false, false}; + instance.capacity = 1; + instance.min_start = 0; + instance.max_end = 5; + ASSERT_EQ(120, CountAllSolutions(instance, {}, CumulativeTimeDecomposition)); +} + +TEST(CumulativeTimeDecompositionTest, FindAll) { + RcpspInstance instance; + instance.demands = {1, 1, 1, 1, 4, 4}; + instance.durations = {1, 2, 3, 3, 3, 3}; + instance.optional = {false, false, false, false, false, false}; + instance.capacity = 4; + instance.min_start = 0; + instance.max_end = 11; + ASSERT_EQ(2040, CountAllSolutions(instance, {}, CumulativeTimeDecomposition)); + ASSERT_EQ(2040, CountAllSolutions(instance, {}, CumulativeUsingReservoir)); +} + +TEST(CumulativeTimeDecompositionTest, OptionalTasks1) { + RcpspInstance instance; + instance.demands = {3, 3, 3}; + instance.durations = {1, 1, 1}; + instance.optional = {true, true, true}; + instance.capacity = 7; + instance.min_start = 0; + instance.max_end = 2; + ASSERT_EQ(25, CountAllSolutions(instance, {}, Cumulative)); + ASSERT_EQ(25, CountAllSolutions(instance, {}, CumulativeUsingReservoir)); +} + +// Up to two tasks can be scheduled at the 
same time. +TEST(CumulativeTimeDecompositionTest, OptionalTasks2) { + RcpspInstance instance; + instance.demands = {3, 3, 3}; + instance.durations = {3, 3, 3}; + instance.optional = {true, true, true}; + instance.capacity = 7; + instance.min_start = 0; + instance.max_end = 3; + ASSERT_EQ(7, CountAllSolutions(instance, {}, CumulativeTimeDecomposition)); + ASSERT_EQ(7, CountAllSolutions(instance, {}, CumulativeUsingReservoir)); +} + +TEST(CumulativeTimeDecompositionTest, RegressionTest1) { + RcpspInstance instance; + instance.demands = {5, 4, 1}; + instance.durations = {1, 1, 2}; + instance.optional = {false, false, false}; + instance.capacity = 5; + instance.min_start = 0; + instance.max_end = 2; + ASSERT_EQ(0, CountAllSolutions(instance, {}, CumulativeTimeDecomposition)); +} + +// Cumulative was pruning too many solutions on that instance. +TEST(CumulativeTimeDecompositionTest, RegressionTest2) { + SatParameters parameters; + parameters.set_use_overload_checker_in_cumulative(false); + parameters.set_use_timetable_edge_finding_in_cumulative(false); + RcpspInstance instance; + instance.demands = {4, 4, 3}; + instance.durations = {2, 2, 3}; + instance.optional = {true, true, true}; + instance.capacity = 6; + instance.min_start = 0; + instance.max_end = 5; + ASSERT_EQ( + 22, CountAllSolutions(instance, parameters, CumulativeTimeDecomposition)); +} + +bool CheckCumulative(const SatParameters& parameters, + const RcpspInstance& instance) { + const int64_t num_solutions_ref = + CountAllSolutions(instance, parameters, CumulativeTimeDecomposition); + const int64_t num_solutions_test = + CountAllSolutions(instance, parameters, Cumulative); + if (num_solutions_ref != num_solutions_test) { + LOG(INFO) << "Want: " << num_solutions_ref + << " solutions, got: " << num_solutions_test << " solutions."; + LOG(INFO) << instance.DebugString(); + return false; + } + const int64_t num_solutions_reservoir = + CountAllSolutions(instance, parameters, CumulativeUsingReservoir); + if 
(num_solutions_ref != num_solutions_reservoir) { + LOG(INFO) << "Want: " << num_solutions_ref + << " solutions, got: " << num_solutions_reservoir + << " solutions."; + LOG(INFO) << instance.DebugString(); + return false; + } + return true; +} + +// Checks that the cumulative constraint performs trivial propagation by +// updating the capacity and demand variables. +TEST(CumulativeTest, CapacityAndDemand) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + const IntervalVariable interval = model.Add(NewInterval(-1000, 1000, 1)); + const IntegerVariable demand = model.Add(NewIntegerVariable(5, 15)); + const IntegerVariable capacity = model.Add(NewIntegerVariable(0, 10)); + const IntegerTrail* integer_trail = model.GetOrCreate(); + model.Add(Cumulative({interval}, {AffineExpression(demand)}, + AffineExpression(capacity))); + ASSERT_TRUE(sat_solver->Propagate()); + ASSERT_EQ(integer_trail->LowerBound(capacity), 5); + ASSERT_EQ(integer_trail->UpperBound(capacity), 10); + ASSERT_EQ(integer_trail->LowerBound(demand), 5); + ASSERT_EQ(integer_trail->UpperBound(demand), 10); +} + +// Checks that the cumulative constraint adapts the demand of the task to +// prevent the capacity overload.
+TEST(CumulativeTest, CapacityAndZeroDemand) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + const IntegerVariable start = model.Add(NewIntegerVariable(-1000, 1000)); + const IntegerVariable size = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable end = model.Add(NewIntegerVariable(-1000, 1000)); + const IntervalVariable interval = model.Add(NewInterval(start, end, size)); + const IntegerVariable demand = model.Add(NewIntegerVariable(11, 15)); + const IntegerVariable capacity = model.Add(NewIntegerVariable(0, 10)); + const IntegerTrail* integer_trail = model.GetOrCreate(); + model.Add(Cumulative({interval}, {AffineExpression(demand)}, + AffineExpression(capacity))); + ASSERT_TRUE(sat_solver->Propagate()); + ASSERT_EQ(integer_trail->LowerBound(capacity), 0); + ASSERT_EQ(integer_trail->UpperBound(capacity), 10); + ASSERT_EQ(integer_trail->LowerBound(demand), 11); + ASSERT_EQ(integer_trail->UpperBound(demand), 15); + ASSERT_EQ(integer_trail->UpperBound(size), 0); +} + +// Checks that the cumulative constraint removes the task to prevent the +// capacity overload. +TEST(CumulativeTest, CapacityAndOptionalTask) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + const Literal l = Literal(model.Add(NewBooleanVariable()), true); + const IntervalVariable interval = + model.Add(NewOptionalInterval(-1000, 1000, 1, l)); + const IntegerVariable demand = model.Add(ConstantIntegerVariable(15)); + const IntegerVariable capacity = model.Add(ConstantIntegerVariable(10)); + model.Add(Cumulative({interval}, {AffineExpression(demand)}, + AffineExpression(capacity))); + ASSERT_TRUE(sat_solver->Propagate()); + ASSERT_FALSE(model.Get(Value(l))); +} + +// Cumulative was pruning too many solutions on that instance. 
+TEST(CumulativeTest, RegressionTest1) { + SatParameters parameters; + parameters.set_use_overload_checker_in_cumulative(false); + parameters.set_use_timetable_edge_finding_in_cumulative(false); + RcpspInstance instance; + instance.demands = {4, 4, 3}; + instance.durations = {2, 2, 3}; + instance.optional = {true, true, true}; + instance.capacity = 6; + instance.min_start = 0; + instance.max_end = 5; + ASSERT_EQ(22, CountAllSolutions(instance, parameters, Cumulative)); +} + +// Cumulative was pruning too many solutions on that instance. +TEST(CumulativeTest, RegressionTest2) { + SatParameters parameters; + parameters.set_use_overload_checker_in_cumulative(false); + parameters.set_use_timetable_edge_finding_in_cumulative(false); + RcpspInstance instance; + instance.demands = {5, 4}; + instance.durations = {4, 4}; + instance.optional = {true, true}; + instance.capacity = 6; + instance.min_start = 0; + instance.max_end = 7; + ASSERT_EQ(9, CountAllSolutions(instance, parameters, Cumulative)); +} + +// ======================================================================== +// All the test belows check that the cumulative propagator finds the exact +// same number of solutions than its time point decomposition. +// ======================================================================== + +// Param1: Number of tasks. +// Param3: Enable overload checking. +// Param4: Enable timetable edge finding. 
+typedef ::testing::tuple CumulativeTestParams; + +class RandomCumulativeTest + : public ::testing::TestWithParam { + protected: + int GetNumTasks() { return ::testing::get<0>(GetParam()); } + + SatParameters GetSatParameters() { + SatParameters parameters; + parameters.set_use_disjunctive_constraint_in_cumulative(false); + parameters.set_use_overload_checker_in_cumulative( + ::testing::get<1>(GetParam())); + parameters.set_use_timetable_edge_finding_in_cumulative( + ::testing::get<2>(GetParam())); + return parameters; + } +}; + +class FastRandomCumulativeTest : public RandomCumulativeTest {}; +class SlowRandomCumulativeTest : public RandomCumulativeTest {}; + +TEST_P(FastRandomCumulativeTest, FindAll) { + ASSERT_TRUE(CheckCumulative( + GetSatParameters(), + GenerateRandomInstance(GetNumTasks(), 1, 4, 1, 5, 7, 0, false))); +} + +TEST_P(FastRandomCumulativeTest, FindAllNegativeTime) { + ASSERT_TRUE(CheckCumulative( + GetSatParameters(), + GenerateRandomInstance(GetNumTasks(), 1, 4, 1, 5, 7, -100, false))); +} + +TEST_P(SlowRandomCumulativeTest, FindAllZeroDuration) { + ASSERT_TRUE(CheckCumulative( + GetSatParameters(), + GenerateRandomInstance(GetNumTasks(), 0, 4, 1, 5, 7, 0, false))); +} + +TEST_P(SlowRandomCumulativeTest, FindAllZeroDemand) { + ASSERT_TRUE(CheckCumulative( + GetSatParameters(), + GenerateRandomInstance(GetNumTasks(), 1, 4, 0, 5, 7, 0, false))); +} + +TEST_P(SlowRandomCumulativeTest, FindAllOptionalTasks) { + ASSERT_TRUE(CheckCumulative( + GetSatParameters(), + GenerateRandomInstance(GetNumTasks(), 1, 4, 0, 5, 7, 0, true))); +} + +INSTANTIATE_TEST_SUITE_P( + All, FastRandomCumulativeTest, + ::testing::Combine(::testing::Range(3, DEBUG_MODE ? 4 : 6), + ::testing::Bool(), ::testing::Bool())); + +INSTANTIATE_TEST_SUITE_P( + All, SlowRandomCumulativeTest, + ::testing::Combine(::testing::Range(3, DEBUG_MODE ? 
4 : 5), + ::testing::Bool(), ::testing::Bool())); + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cuts.cc b/ortools/sat/cuts.cc index da89753b9e2..e65af77517f 100644 --- a/ortools/sat/cuts.cc +++ b/ortools/sat/cuts.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -232,11 +233,10 @@ bool CutData::AllCoefficientsArePositive() const { return true; } -void CutData::Canonicalize() { +void CutData::SortRelevantEntries() { num_relevant_entries = 0; max_magnitude = 0; - for (int i = 0; i < terms.size(); ++i) { - CutTerm& entry = terms[i]; + for (CutTerm& entry : terms) { max_magnitude = std::max(max_magnitude, IntTypeAbs(entry.coeff)); if (entry.HasRelevantLpValue()) { std::swap(terms[num_relevant_entries], entry); @@ -270,92 +270,78 @@ double CutData::ComputeEfficacy() const { return violation / std::sqrt(norm); } -void CutDataBuilder::ClearIndices() { - num_merges_ = 0; - constraint_is_indexed_ = false; - bool_index_.clear(); - secondary_bool_index_.clear(); -} - -void CutDataBuilder::RegisterAllBooleanTerms(const CutData& cut) { - constraint_is_indexed_ = true; - const int size = cut.terms.size(); - for (int i = 0; i < size; ++i) { - const CutTerm& term = cut.terms[i]; - if (term.bound_diff != 1) continue; - if (!term.IsSimple()) continue; - - // Initially we shouldn't have duplicate bools and (1 - bools). - // So we just fill bool_index_. - bool_index_[term.expr_vars[0]] = i; +// We can only merge the term if term.coeff + old_coeff do not overflow and +// if t * new_coeff do not overflow. +// +// If we cannot merge the term, we will keep them separate. The produced cut +// will be less strong, but can still be used. 
+bool CutDataBuilder::MergeIfPossible(IntegerValue t, CutTerm& to_add, + CutTerm& target) { + DCHECK_EQ(to_add.expr_vars[0], target.expr_vars[0]); + DCHECK_EQ(to_add.expr_coeffs[0], target.expr_coeffs[0]); + + const IntegerValue new_coeff = CapAddI(to_add.coeff, target.coeff); + if (AtMinOrMaxInt64I(new_coeff) || ProdOverflow(t, new_coeff)) { + return false; } -} -void CutDataBuilder::AddOrMergeTerm(const CutTerm& term, IntegerValue t, - CutData* cut) { - if (!constraint_is_indexed_) { - RegisterAllBooleanTerms(*cut); - } + to_add.coeff = 0; // Clear since we merge it. + target.coeff = new_coeff; + return true; +} - DCHECK(term.IsSimple()); - const IntegerVariable var = term.expr_vars[0]; - const bool is_positive = (term.expr_coeffs[0] > 0); - const int new_index = cut->terms.size(); - const auto [it, inserted] = bool_index_.insert({var, new_index}); - if (inserted) { - cut->terms.push_back(term); - return; - } +// We only deal with coeff * Bool or coeff * (1 - Bool) +// +// TODO(user): Because of merges, we might have entry with a coefficient of +// zero than are not useful. Remove them? +int CutDataBuilder::AddOrMergeBooleanTerms(absl::Span new_terms, + IntegerValue t, CutData* cut) { + if (new_terms.empty()) return 0; - // If the referred var is not right, replace the entry. - int entry_index = it->second; - if (entry_index >= new_index || cut->terms[entry_index].expr_vars[0] != var) { - it->second = new_index; - cut->terms.push_back(term); - return; + bool_index_.clear(); + secondary_bool_index_.clear(); + int num_merges = 0; + + // Fill the maps. + int i = 0; + for (CutTerm& term : new_terms) { + const IntegerVariable var = term.expr_vars[0]; + auto& map = term.expr_coeffs[0] > 0 ? bool_index_ : secondary_bool_index_; + const auto [it, inserted] = map.insert({var, i}); + if (!inserted) { + if (MergeIfPossible(t, term, new_terms[it->second])) { + ++num_merges; + } + } + ++i; } - // If the sign is not right, look into secondary hash_map for opposite sign. 
- if ((cut->terms[entry_index].expr_coeffs[0] > 0) != is_positive) { - const auto [it, inserted] = secondary_bool_index_.insert({var, new_index}); - if (inserted) { - cut->terms.push_back(term); - return; - } + // Loop over the cut now. Note that we loop with indices as we might add new + // terms in the middle of the loop. + for (CutTerm& term : cut->terms) { + if (term.bound_diff != 1) continue; + if (!term.IsSimple()) continue; - // If the referred var is not right, replace the entry. - entry_index = it->second; - if (entry_index >= new_index || - cut->terms[entry_index].expr_vars[0] != var) { - it->second = new_index; - cut->terms.push_back(term); - return; - } + const IntegerVariable var = term.expr_vars[0]; + auto& map = term.expr_coeffs[0] > 0 ? bool_index_ : secondary_bool_index_; + auto it = map.find(var); + if (it == map.end()) continue; - // If the sign is not right, replace the entry. - if ((cut->terms[entry_index].expr_coeffs[0] > 0) != is_positive) { - it->second = new_index; - cut->terms.push_back(term); - return; + // We found a match, try to merge the map entry into the cut. + // Note that we don't waste time erasing this entry from the map since + // we should have no duplicates in the original cut. + if (MergeIfPossible(t, new_terms[it->second], term)) { + ++num_merges; } } - DCHECK_EQ(cut->terms[entry_index].expr_vars[0], var); - DCHECK_EQ((cut->terms[entry_index].expr_coeffs[0] > 0), is_positive); - // We can only merge the term if term.coeff + old_coeff do not overflow and - // if t * new_coeff do not overflow. - // - // If we cannot merge the term, we will keep them separate. The produced cut - // will be less strong, but can still be used. - const IntegerValue new_coeff = - CapAddI(cut->terms[entry_index].coeff, term.coeff); - if (AtMinOrMaxInt64I(new_coeff) || ProdOverflow(t, new_coeff)) { - // If we cannot merge the term, we keep them separate. + // Finally add the terms we couldn't merge. 
+ for (const CutTerm& term : new_terms) { + if (term.coeff == 0) continue; cut->terms.push_back(term); - } else { - ++num_merges_; - cut->terms[entry_index].coeff = new_coeff; } + + return num_merges; } // TODO(user): Divide by gcd first to avoid possible overflow in the @@ -789,40 +775,38 @@ bool IntegerRoundingCutHelper::ComputeCut( // This should be better except it can mess up the norm and the divisors. cut_ = base_ct; if (options.use_ib_before_heuristic && ib_processor != nullptr) { - ib_processor->BaseCutBuilder()->ClearNumMerges(); - const int old_size = static_cast(cut_.terms.size()); - bool abort = true; - for (int i = 0; i < old_size; ++i) { - if (cut_.terms[i].bound_diff <= 1) continue; - if (!cut_.terms[i].HasRelevantLpValue()) continue; - - if (options.prefer_positive_ib && cut_.terms[i].coeff < 0) { + std::vector* new_bool_terms = + ib_processor->ClearedMutableTempTerms(); + for (CutTerm& term : cut_.terms) { + if (term.bound_diff <= 1) continue; + if (!term.HasRelevantLpValue()) continue; + + if (options.prefer_positive_ib && term.coeff < 0) { // We complement the term before trying the implied bound. 
- cut_.terms[i].Complement(&cut_.rhs); + term.Complement(&cut_.rhs); if (ib_processor->TryToExpandWithLowerImpliedbound( - IntegerValue(1), i, - /*complement=*/true, &cut_, ib_processor->BaseCutBuilder())) { + IntegerValue(1), + /*complement=*/true, &term, &cut_.rhs, new_bool_terms)) { ++total_num_initial_ibs_; - abort = false; continue; } - cut_.terms[i].Complement(&cut_.rhs); + term.Complement(&cut_.rhs); } if (ib_processor->TryToExpandWithLowerImpliedbound( - IntegerValue(1), i, - /*complement=*/true, &cut_, ib_processor->BaseCutBuilder())) { - abort = false; + IntegerValue(1), + /*complement=*/true, &term, &cut_.rhs, new_bool_terms)) { ++total_num_initial_ibs_; } } - total_num_initial_merges_ += - ib_processor->BaseCutBuilder()->NumMergesSinceLastClear(); // TODO(user): We assume that this is called with and without the option // use_ib_before_heuristic, so that we can abort if no IB has been applied // since then we will redo the computation. This is not really clean. - if (abort) return false; + if (new_bool_terms->empty()) return false; + total_num_initial_merges_ += + ib_processor->MutableCutBuilder()->AddOrMergeBooleanTerms( + absl::MakeSpan(*new_bool_terms), IntegerValue(1), &cut_); } // Our heuristic will try to generate a few different cuts, and we will keep @@ -842,7 +826,7 @@ bool IntegerRoundingCutHelper::ComputeCut( // // TODO(user): If the rhs is small and close to zero, we might want to // consider different way of complementing the variables. - cut_.Canonicalize(); + cut_.SortRelevantEntries(); const IntegerValue remainder_threshold( std::max(IntegerValue(1), cut_.max_magnitude / 1000)); if (cut_.rhs >= 0 && cut_.rhs < remainder_threshold.value()) { @@ -997,11 +981,11 @@ bool IntegerRoundingCutHelper::ComputeCut( // This should lead to stronger cuts even if the norms might be worse. 
num_ib_used_ = 0; if (ib_processor != nullptr) { - const auto [num_lb, num_ub] = ib_processor->PostprocessWithImpliedBound( - f, factor_t, &cut_, &cut_builder_); + const auto [num_lb, num_ub, num_merges] = + ib_processor->PostprocessWithImpliedBound(f, factor_t, &cut_); total_num_pos_lifts_ += num_lb; total_num_neg_lifts_ += num_ub; - total_num_merges_ += cut_builder_.NumMergesSinceLastClear(); + total_num_merges_ += num_merges; num_ib_used_ = num_lb + num_ub; } @@ -1297,21 +1281,23 @@ bool CoverCutHelper::TrySimpleKnapsack(const CutData& input_ct, // Tricky: This only work because the cut absl128 rhs is not changed by these // operations. if (ib_processor != nullptr) { - ib_processor->BaseCutBuilder()->ClearNumMerges(); - const int old_size = static_cast(cut_.terms.size()); - for (int i = 0; i < old_size; ++i) { + std::vector* new_bool_terms = + ib_processor->ClearedMutableTempTerms(); + for (CutTerm& term : cut_.terms) { // We only look at non-Boolean with an lp value not close to the upper // bound. 
- const CutTerm& term = cut_.terms[i]; if (term.bound_diff <= 1) continue; if (term.lp_value + 1e-4 > AsDouble(term.bound_diff)) continue; if (ib_processor->TryToExpandWithLowerImpliedbound( - IntegerValue(1), i, - /*complement=*/false, &cut_, ib_processor->BaseCutBuilder())) { + IntegerValue(1), + /*complement=*/false, &term, &cut_.rhs, new_bool_terms)) { ++cover_stats_.num_initial_ibs; } } + + ib_processor->MutableCutBuilder()->AddOrMergeBooleanTerms( + absl::MakeSpan(*new_bool_terms), IntegerValue(1), &cut_); } bool has_relevant_int = false; @@ -1387,11 +1373,11 @@ bool CoverCutHelper::TrySimpleKnapsack(const CutData& input_ct, } if (ib_processor != nullptr) { - const auto [num_lb, num_ub] = ib_processor->PostprocessWithImpliedBound( - f, /*factor_t=*/1, &cut_, &cut_builder_); + const auto [num_lb, num_ub, num_merges] = + ib_processor->PostprocessWithImpliedBound(f, /*factor_t=*/1, &cut_); cover_stats_.num_lb_ibs += num_lb; cover_stats_.num_ub_ibs += num_ub; - cover_stats_.num_merges += cut_builder_.NumMergesSinceLastClear(); + cover_stats_.num_merges += num_merges; } cover_stats_.num_bumps += ApplyWithPotentialBump(f, best_coeff, &cut_); @@ -1467,11 +1453,11 @@ bool CoverCutHelper::TrySingleNodeFlow(const CutData& input_ct, min_magnitude); if (ib_processor != nullptr) { - const auto [num_lb, num_ub] = ib_processor->PostprocessWithImpliedBound( - f, /*factor_t=*/1, &cut_, &cut_builder_); + const auto [num_lb, num_ub, num_merges] = + ib_processor->PostprocessWithImpliedBound(f, /*factor_t=*/1, &cut_); flow_stats_.num_lb_ibs += num_lb; flow_stats_.num_ub_ibs += num_ub; - flow_stats_.num_merges += cut_builder_.NumMergesSinceLastClear(); + flow_stats_.num_merges += num_merges; } // Lifting. @@ -1526,16 +1512,19 @@ bool CoverCutHelper::TryWithLetchfordSouliLifting( // // TODO(user): Merge Boolean terms that are complement of each other. 
if (ib_processor != nullptr) { - ib_processor->BaseCutBuilder()->ClearNumMerges(); - const int old_size = static_cast(cut_.terms.size()); - for (int i = 0; i < old_size; ++i) { - if (cut_.terms[i].bound_diff <= 1) continue; + std::vector* new_bool_terms = + ib_processor->ClearedMutableTempTerms(); + for (CutTerm& term : cut_.terms) { + if (term.bound_diff <= 1) continue; if (ib_processor->TryToExpandWithLowerImpliedbound( - IntegerValue(1), i, - /*complement=*/false, &cut_, ib_processor->BaseCutBuilder())) { + IntegerValue(1), + /*complement=*/false, &term, &cut_.rhs, new_bool_terms)) { ++ls_stats_.num_initial_ibs; } } + + ib_processor->MutableCutBuilder()->AddOrMergeBooleanTerms( + absl::MakeSpan(*new_bool_terms), IntegerValue(1), &cut_); } // TODO(user): we currently only deal with Boolean in the cover. Fix. @@ -2192,9 +2181,9 @@ bool ImpliedBoundsProcessor::DecomposeWithImpliedUpperBound( return true; } -std::pair ImpliedBoundsProcessor::PostprocessWithImpliedBound( +std::tuple ImpliedBoundsProcessor::PostprocessWithImpliedBound( const std::function& f, IntegerValue factor_t, - CutData* cut, CutDataBuilder* builder) { + CutData* cut) { int num_applied_lb = 0; int num_applied_ub = 0; @@ -2202,10 +2191,9 @@ std::pair ImpliedBoundsProcessor::PostprocessWithImpliedBound( CutTerm slack_term; CutTerm ub_bool_term; CutTerm ub_slack_term; - builder->ClearIndices(); - const int initial_size = cut->terms.size(); - for (int i = 0; i < initial_size; ++i) { - CutTerm& term = cut->terms[i]; + + tmp_terms_.clear(); + for (CutTerm& term : cut->terms) { if (term.bound_diff <= 1) continue; if (!term.IsSimple()) continue; @@ -2255,30 +2243,31 @@ std::pair ImpliedBoundsProcessor::PostprocessWithImpliedBound( // loose more, so we prefer to be a bit defensive. if (score > base_score + 1e-2) { ++num_applied_ub; - term = ub_slack_term; // Override first before push_back() ! 
- builder->AddOrMergeTerm(ub_bool_term, factor_t, cut); + term = ub_slack_term; + tmp_terms_.push_back(ub_bool_term); continue; } } if (expand) { ++num_applied_lb; - term = slack_term; // Override first before push_back() ! - builder->AddOrMergeTerm(bool_term, factor_t, cut); + term = slack_term; + tmp_terms_.push_back(bool_term); } } - return {num_applied_lb, num_applied_ub}; + + const int num_merges = cut_builder_.AddOrMergeBooleanTerms( + absl::MakeSpan(tmp_terms_), factor_t, cut); + + return {num_applied_lb, num_applied_ub, num_merges}; } -// Important: The cut_builder_ must have been reset. bool ImpliedBoundsProcessor::TryToExpandWithLowerImpliedbound( - IntegerValue factor_t, int i, bool complement, CutData* cut, - CutDataBuilder* builder) { - CutTerm& term = cut->terms[i]; - + IntegerValue factor_t, bool complement, CutTerm* term, absl::int128* rhs, + std::vector* new_bool_terms) { CutTerm bool_term; CutTerm slack_term; - if (!DecomposeWithImpliedLowerBound(term, factor_t, bool_term, slack_term)) { + if (!DecomposeWithImpliedLowerBound(*term, factor_t, bool_term, slack_term)) { return false; } @@ -2287,26 +2276,22 @@ bool ImpliedBoundsProcessor::TryToExpandWithLowerImpliedbound( // It is always good to complement such variable. // // Note that here we do more and just complement anything closer to UB. - // - // TODO(user): Because of merges, we might have entry with a coefficient of - // zero than are not useful. Remove them. 
if (complement) { if (bool_term.lp_value > 0.5) { - bool_term.Complement(&cut->rhs); + bool_term.Complement(rhs); } if (slack_term.lp_value > 0.5 * AsDouble(slack_term.bound_diff)) { - slack_term.Complement(&cut->rhs); + slack_term.Complement(rhs); } } - term = slack_term; - builder->AddOrMergeTerm(bool_term, factor_t, cut); + *term = slack_term; + new_bool_terms->push_back(bool_term); return true; } bool ImpliedBoundsProcessor::CacheDataForCut(IntegerVariable first_slack, CutData* cut) { - base_cut_builder_.ClearIndices(); cached_data_.clear(); const int size = cut->terms.size(); @@ -2746,17 +2731,21 @@ CutGenerator CreateCliqueCutGenerator( CutGenerator result; result.vars = variables; auto* implication_graph = model->GetOrCreate(); + result.only_run_at_level_zero = true; result.generate_cuts = [variables, literals, implication_graph, positive_map, negative_map, model](LinearConstraintManager* manager) { std::vector packed_values; + std::vector packed_reduced_costs; const auto& lp_values = manager->LpValues(); + const auto& reduced_costs = manager->ReducedCosts(); for (int i = 0; i < literals.size(); ++i) { packed_values.push_back(lp_values[variables[i]]); + packed_reduced_costs.push_back(reduced_costs[variables[i]]); } const std::vector> at_most_ones = - implication_graph->GenerateAtMostOnesWithLargeWeight(literals, - packed_values); + implication_graph->GenerateAtMostOnesWithLargeWeight( + literals, packed_values, packed_reduced_costs); for (const std::vector& at_most_one : at_most_ones) { // We need to express such "at most one" in term of the initial diff --git a/ortools/sat/cuts.h b/ortools/sat/cuts.h index 7e975cc6d4e..245a440c87a 100644 --- a/ortools/sat/cuts.h +++ b/ortools/sat/cuts.h @@ -134,6 +134,10 @@ struct CutData { double ComputeViolation() const; double ComputeEfficacy() const; + // This sorts terms by decreasing lp values and fills both + // num_relevant_entries and max_magnitude. 
+ void SortRelevantEntries(); + std::string DebugString() const; // Note that we use a 128 bit rhs so we can freely complement variable without @@ -141,8 +145,7 @@ struct CutData { absl::int128 rhs; std::vector terms; - // This sorts terms and fill both num_relevant_entries and max_magnitude. - void Canonicalize(); + // Only filled after SortRelevantEntries(). IntegerValue max_magnitude; int num_relevant_entries; }; @@ -150,24 +153,21 @@ struct CutData { // Stores temporaries used to build or manipulate a CutData. class CutDataBuilder { public: + // Returns false if we encounter an integer overflow. + bool ConvertToLinearConstraint(const CutData& cut, LinearConstraint* output); + // These function allow to merges entries corresponding to the same variable // and complementation. That is (X - lb) and (ub - X) are NOT merged and kept // as separate terms. Note that we currently only merge Booleans since this // is the only case we need. - void ClearIndices(); - void AddOrMergeTerm(const CutTerm& term, IntegerValue t, CutData* cut); - - void ClearNumMerges() { num_merges_ = 0; } - int NumMergesSinceLastClear() const { return num_merges_; } - - // Returns false if we encounter an integer overflow. - bool ConvertToLinearConstraint(const CutData& cut, LinearConstraint* output); + // + // Return num_merges. + int AddOrMergeBooleanTerms(absl::Span terms, IntegerValue t, + CutData* cut); private: - void RegisterAllBooleanTerms(const CutData& cut); + bool MergeIfPossible(IntegerValue t, CutTerm& to_add, CutTerm& target); - int num_merges_ = 0; - bool constraint_is_indexed_ = false; absl::flat_hash_map bool_index_; absl::flat_hash_map secondary_bool_index_; absl::btree_map tmp_map_; @@ -219,27 +219,31 @@ class ImpliedBoundsProcessor { // We are about to apply the super-additive function f() to the CutData. Use // implied bound information to eventually substitute and make the cut - // stronger. Returns the number of {lb_ib, ub_ib} applied. + // stronger. 
Returns the number of {lb_ib, ub_ib, merges} applied. // // This should lead to stronger cuts even if the norms migth be worse. - std::pair PostprocessWithImpliedBound( + std::tuple PostprocessWithImpliedBound( const std::function& f, IntegerValue factor_t, - CutData* cut, CutDataBuilder* builder); + CutData* cut); // Precomputes quantities used by all cut generation. // This allows to do that once rather than 6 times. // Return false if there are no exploitable implied bounds. bool CacheDataForCut(IntegerVariable first_slack, CutData* cut); - // All our cut code use the same base cut (modulo complement), so we reuse the - // hash-map of where boolean are in the cut. Note that even if we add new - // entry that are no longer there for another cut algo, we can still reuse the - // same hash-map. - CutDataBuilder* BaseCutBuilder() { return &base_cut_builder_; } + bool TryToExpandWithLowerImpliedbound(IntegerValue factor_t, bool complement, + CutTerm* term, absl::int128* rhs, + std::vector* new_bool_terms); - bool TryToExpandWithLowerImpliedbound(IntegerValue factor_t, int i, - bool complement, CutData* cut, - CutDataBuilder* builder); + // This can be used to share the hash-map memory. + CutDataBuilder* MutableCutBuilder() { return &cut_builder_; } + + // This can be used as a temporary storage for + // TryToExpandWithLowerImpliedbound(). + std::vector* ClearedMutableTempTerms() { + tmp_terms_.clear(); + return &tmp_terms_; + } // Add a new variable that could be used in the new cuts. // Note that the cache must be computed to take this into account. @@ -283,7 +287,8 @@ class ImpliedBoundsProcessor { mutable absl::flat_hash_map cache_; // Temporary data used by CacheDataForCut(). 
- CutDataBuilder base_cut_builder_; + std::vector tmp_terms_; + CutDataBuilder cut_builder_; std::vector cached_data_; TopNCuts ib_cut_pool_ = TopNCuts(50); @@ -431,7 +436,6 @@ class IntegerRoundingCutHelper { std::vector best_rs_; int64_t num_ib_used_ = 0; - CutDataBuilder cut_builder_; CutData cut_; std::vector> adjusted_coeffs_; @@ -531,7 +535,6 @@ class CoverCutHelper { // Here to reuse memory, cut_ is both the input and the output. CutData cut_; CutData temp_cut_; - CutDataBuilder cut_builder_; // Hack to not sort twice. bool has_bool_base_ct_ = false; diff --git a/ortools/sat/cuts_test.cc b/ortools/sat/cuts_test.cc new file mode 100644 index 00000000000..d8d78d153c1 --- /dev/null +++ b/ortools/sat/cuts_test.cc @@ -0,0 +1,1171 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/cuts.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/numeric/int128.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/strong_vector.h" +#include "ortools/sat/implied_bounds.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/linear_constraint_manager.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/util/fp_utils.h" +#include "ortools/util/sorted_interval_list.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::EndsWith; +using ::testing::StartsWith; + +std::vector IntegerValueVector(absl::Span values) { + std::vector result; + for (const int v : values) result.push_back(IntegerValue(v)); + return result; +} + +TEST(GetSuperAdditiveRoundingFunctionTest, AllSmallValues) { + const int max_divisor = 25; + for (IntegerValue max_t(1); max_t <= 9; ++max_t) { + for (IntegerValue max_scaling(1); max_scaling <= 9; max_scaling++) { + for (IntegerValue divisor(1); divisor <= max_divisor; ++divisor) { + for (IntegerValue rhs_remainder(1); rhs_remainder < divisor; + ++rhs_remainder) { + const std::string info = absl::StrCat( + " rhs_remainder = ", rhs_remainder.value(), + " divisor = ", divisor.value(), " max_t = ", max_t.value(), + " max_scaling = ", max_scaling.value()); + const auto f = GetSuperAdditiveRoundingFunction( + rhs_remainder, divisor, + std::min(max_t, + GetFactorT(rhs_remainder, divisor, IntegerValue(100))), + max_scaling); + ASSERT_EQ(f(IntegerValue(0)), 0) << info; + ASSERT_GE(f(divisor), 1) << info; + ASSERT_LE(f(divisor), max_scaling * max_t) << info; + for (IntegerValue a(0); a < divisor; ++a) { + IntegerValue min_diff = kMaxIntegerValue; + for (IntegerValue b(1); b < divisor; 
++b) { + min_diff = std::min(min_diff, f(a + b) - f(a) - f(b)); + ASSERT_GE(min_diff, 0) + << info << ", f(" << a << ")=" << f(a) << " + f(" << b + << ")=" << f(b) << " <= f(" << a + b << ")=" << f(a + b); + } + + // TODO(user): Our discretized "mir" function is not always + // maximal. Try to fix it? + if (a <= rhs_remainder || max_scaling != 2) continue; + if (rhs_remainder * max_t < divisor / 2) continue; + + // min_diff > 0 shows that our function is dominated (i.e. not + // maximal) since f(a) could be increased by 1/2. + ASSERT_EQ(min_diff, 0) + << "Not maximal at " << info << " f(" << a << ") = " << f(a) + << " min_diff:" << min_diff; + } + } + } + } + } +} + +TEST(GetSuperAdditiveStrengtheningFunction, AllSmallValues) { + for (const int64_t rhs : {13, 14}) { // test odd/even + for (int64_t min_magnitude = 1; min_magnitude <= rhs; ++min_magnitude) { + const auto f = GetSuperAdditiveStrengtheningFunction(rhs, min_magnitude); + + // Check super additivity in -[50, 50] + for (int a = -50; a <= 50; ++a) { + for (int b = -50; b <= 50; ++b) { + ASSERT_LE(f(a) + f(b), f(a + b)) + << " a=" << a << " b=" << b << " min=" << min_magnitude + << " rhs=" << rhs; + } + } + } + } +} + +TEST(GetSuperAdditiveStrengtheningMirFunction, AllSmallValues) { + for (const int64_t rhs : {13, 14}) { // test odd/even + for (int64_t scaling = 1; scaling <= rhs; ++scaling) { + const auto f = GetSuperAdditiveStrengtheningMirFunction(rhs, scaling); + + // Check super additivity in -[50, 50] + for (int a = -50; a <= 50; ++a) { + for (int b = -50; b <= 50; ++b) { + ASSERT_LE(f(a) + f(b), f(a + b)) + << " a=" << a << " b=" << b << " scaling=" << scaling + << " rhs=" << rhs; + } + } + } + } +} + +TEST(CutDataTest, ComputeViolation) { + CutData cut; + cut.rhs = 2; + cut.terms.push_back({.lp_value = 1.2, .coeff = 1}); + cut.terms.push_back({.lp_value = 0.5, .coeff = 2}); + EXPECT_COMPARABLE(cut.ComputeViolation(), 0.2, 1e-10); +} + +template +std::string GetCutString(const Helper& helper) { + 
LinearConstraint ct; + CutDataBuilder builder; + EXPECT_TRUE(builder.ConvertToLinearConstraint(helper.cut(), &ct)); + return ct.DebugString(); +} + +TEST(CoverCutHelperTest, SimpleExample) { + // 6x0 + 4x1 + 10x2 <= 9. + std::vector vars = {IntegerVariable(0), IntegerVariable(2), + IntegerVariable(4)}; + std::vector coeffs = IntegerValueVector({6, 4, 10}); + std::vector lbs = IntegerValueVector({0, 0, 0}); + std::vector lp_values{1.0, 0.5, 0.1}; // Tight. + + // Note(user): the ub of the last variable is not used. But the first two + // are even though only the second one is required for the validity of the + // cut. + std::vector ubs = IntegerValueVector({1, 1, 10}); + + CutData data; + data.FillFromParallelVectors(IntegerValue(9), vars, coeffs, lp_values, lbs, + ubs); + data.ComplementForPositiveCoefficients(); + CoverCutHelper helper; + EXPECT_TRUE(helper.TrySimpleKnapsack(data)); + EXPECT_EQ(GetCutString(helper), "1*X0 1*X1 1*X2 <= 1"); + EXPECT_EQ(helper.Info(), "lift=1"); +} + +// I tried to reproduce bug 169094958, but if the base constraint is tight, +// the bug was triggered only due to numerical imprecision. A simple way to +// trigger it is like with this test if the given LP value just violate the +// initial constraint. +TEST(CoverCutHelperTest, WeirdExampleWithViolatedConstraint) { + // x0 + x1 <= 9. + std::vector vars = {IntegerVariable(0), IntegerVariable(2)}; + std::vector coeffs = IntegerValueVector({1, 1}); + std::vector lbs = IntegerValueVector({ + 0, + 0, + }); + std::vector ubs = IntegerValueVector({10, 13}); + std::vector lp_values{0.0, 12.6}; // violated. 
+ + CutData data; + data.FillFromParallelVectors(IntegerValue(9), vars, coeffs, lp_values, lbs, + ubs); + data.ComplementForPositiveCoefficients(); + CoverCutHelper helper; + EXPECT_TRUE(helper.TrySimpleKnapsack(data)); + EXPECT_EQ(GetCutString(helper), "1*X0 1*X1 <= 9"); + EXPECT_EQ(helper.Info(), "lift=1"); +} + +TEST(CoverCutHelperTest, LetchfordSouliLifting) { + const int n = 10; + const IntegerValue rhs = IntegerValue(16); + std::vector vars; + std::vector coeffs = + IntegerValueVector({5, 5, 5, 5, 15, 13, 9, 8, 8, 8}); + for (int i = 0; i < n; ++i) { + vars.push_back(IntegerVariable(2 * i)); + } + std::vector lbs(n, IntegerValue(0)); + std::vector ubs(n, IntegerValue(1)); + std::vector lps(n, 0.0); + for (int i = 0; i < 4; ++i) { + lps[i] = 0.9; + } + + CutData data; + data.FillFromParallelVectors(rhs, vars, coeffs, lps, lbs, ubs); + data.ComplementForPositiveCoefficients(); + + CoverCutHelper helper; + EXPECT_TRUE(helper.TryWithLetchfordSouliLifting(data)); + EXPECT_EQ(GetCutString(helper), + "1*X0 1*X1 1*X2 1*X3 3*X4 3*X5 2*X6 1*X7 1*X8 1*X9 <= 3"); + + // For now, we only support Booleans in the cover. + // Note that we don't care for variable not in the cover though. 
+ data.terms[3].bound_diff = IntegerValue(2); + EXPECT_FALSE(helper.TryWithLetchfordSouliLifting(data)); +} + +LinearConstraint IntegerRoundingCutWithBoundsFromTrail( + const RoundingOptions& options, IntegerValue rhs, + absl::Span vars, + absl::Span coeffs, absl::Span lp_values, + const Model& model) { + std::vector lbs; + std::vector ubs; + auto* integer_trail = model.Get(); + for (int i = 0; i < vars.size(); ++i) { + lbs.push_back(integer_trail->LowerBound(vars[i])); + ubs.push_back(integer_trail->UpperBound(vars[i])); + } + + CutData data; + data.FillFromParallelVectors(rhs, vars, coeffs, lp_values, lbs, ubs); + data.ComplementForSmallerLpValues(); + + IntegerRoundingCutHelper helper; + EXPECT_TRUE(helper.ComputeCut(options, data, nullptr)); + + CutDataBuilder builder; + LinearConstraint constraint; + EXPECT_TRUE(builder.ConvertToLinearConstraint(helper.cut(), &constraint)); + return constraint; +} + +TEST(IntegerRoundingCutTest, LetchfordLodiExample1) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(0, 10)); + + // 6x0 + 4x1 <= 9. + const IntegerValue rhs = IntegerValue(9); + std::vector vars = {x0, x1}; + std::vector coeffs = {IntegerValue(6), IntegerValue(4)}; + + std::vector lp_values{1.5, 0.0}; + RoundingOptions options; + options.max_scaling = 2; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + options, rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), "2*X0 1*X1 <= 2"); +} + +TEST(IntegerRoundingCutTest, LetchfordLodiExample1Modified) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(0, 1)); + + // 6x0 + 4x1 <= 9. + const IntegerValue rhs = IntegerValue(9); + + std::vector vars = {x0, x1}; + std::vector coeffs = {IntegerValue(6), IntegerValue(4)}; + + // x1 is at its upper bound here. 
+ std::vector lp_values{5.0 / 6.0, 1.0}; + + // Note that the cut is only valid because the bound of x1 is one here. + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + RoundingOptions(), rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), "1*X0 1*X1 <= 1"); +} + +TEST(IntegerRoundingCutTest, LetchfordLodiExample2) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(0, 10)); + + // 6x0 + 4x1 <= 9. + const IntegerValue rhs = IntegerValue(9); + std::vector vars = {x0, x1}; + std::vector coeffs = {IntegerValue(6), IntegerValue(4)}; + + std::vector lp_values{0.0, 2.25}; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + RoundingOptions(), rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), "3*X0 2*X1 <= 4"); +} + +TEST(IntegerRoundingCutTest, LetchfordLodiExample2WithNegatedCoeff) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(-3, 0)); + + // 6x0 - 4x1 <= 9. + const IntegerValue rhs = IntegerValue(9); + std::vector vars = {x0, x1}; + std::vector coeffs = {IntegerValue(6), IntegerValue(-4)}; + + std::vector lp_values{0.0, -2.25}; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + RoundingOptions(), rhs, vars, coeffs, lp_values, model); + + // We actually do not return like in the example "3*X0 -2*X1 <= 4" + // But the simpler X0 - X1 <= 2 which has the same violation (0.25) but a + // better norm. + EXPECT_EQ(constraint.DebugString(), "1*X0 -1*X1 <= 2"); +} + +// This used to trigger a failure with a wrong implied bound code path. +TEST(IntegerRoundingCutTest, TestCaseUsedForDebugging) { + Model model; + // Variable values are in comment. 
+ const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 3)); // 1 + const IntegerVariable x1 = model.Add(NewIntegerVariable(0, 4)); // 0 + const IntegerVariable x2 = model.Add(NewIntegerVariable(0, 2)); // 1 + const IntegerVariable x3 = model.Add(NewIntegerVariable(0, 1)); // 0 + const IntegerVariable x4 = model.Add(NewIntegerVariable(0, 3)); // 1 + + // The constraint is tight with value above (-5 - 4 + 7 == -2). + const IntegerValue rhs = IntegerValue(-2); + std::vector vars = {x0, x1, x2, x3, x4}; + std::vector coeffs = IntegerValueVector({-5, -1, -4, -7, 7}); + + // The constraint is tight under LP (-5 * 0.4 == -2). + std::vector lp_values{0.4, 0.0, -1e-16, 0.0, 0.0}; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + RoundingOptions(), rhs, vars, coeffs, lp_values, model); + + EXPECT_EQ(constraint.DebugString(), "-2*X0 -1*X1 -2*X2 -2*X3 2*X4 <= -2"); +} + +// The algo should find a "divisor" 2 when it lead to a good cut. +// +// TODO(user): Double check that such divisor will always be found? Of course, +// if the initial constraint coefficient are too high, then it will not, but +// that is okay since such cut efficacity will be bad anyway. +TEST(IntegerRoundingCutTest, ZeroHalfCut) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x2 = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x3 = model.Add(NewIntegerVariable(0, 10)); + + // 6x0 + 4x1 + 8x2 + 7x3 <= 9. 
+ const IntegerValue rhs = IntegerValue(9); + std::vector vars = {x0, x1, x2, x3}; + std::vector coeffs = {IntegerValue(6), IntegerValue(4), + IntegerValue(8), IntegerValue(7)}; + + std::vector lp_values{0.25, 1.25, 0.3125, 0.0}; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + RoundingOptions(), rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), "3*X0 2*X1 4*X2 3*X3 <= 4"); +} + +TEST(IntegerRoundingCutTest, LargeCoeffWithSmallImprecision) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(0, 5)); + + // 1e6 x0 - x1 <= 1.5e6. + const IntegerValue rhs = IntegerValue(1.5e6); + std::vector vars = {x0, x1}; + std::vector coeffs = {IntegerValue(1e6), IntegerValue(-1)}; + + // Note thate without adjustement, this returns 2 * X0 - X1 <= 2. + // TODO(user): expose parameters so this can be verified other than manually? + std::vector lp_values{1.5, 0.1}; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + RoundingOptions(), rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), "1*X0 <= 1"); +} + +TEST(IntegerRoundingCutTest, LargeCoeffWithSmallImprecision2) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(0, 5)); + + // 1e6 x0 + 999999 * x1 <= 1.5e6. + const IntegerValue rhs = IntegerValue(1.5e6); + std::vector vars = {x0, x1}; + std::vector coeffs = {IntegerValue(1e6), IntegerValue(999999)}; + + // Note thate without adjustement, this returns 2 * X0 + X1 <= 2. + // TODO(user): expose parameters so this can be verified other than manually? 
+ std::vector lp_values{1.49, 0.1}; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + RoundingOptions(), rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), "1*X0 1*X1 <= 1"); +} + +TEST(IntegerRoundingCutTest, MirOnLargerConstraint) { + Model model; + std::vector vars(10); + for (int i = 0; i < 10; ++i) { + vars[i] = model.Add(NewIntegerVariable(0, 5)); + } + + // sum (i + 1) x_i <= 16. + const IntegerValue rhs = IntegerValue(16); + std::vector coeffs; + for (int i = 0; i < vars.size(); ++i) { + coeffs.push_back(IntegerValue(i + 1)); + } + + std::vector lp_values(vars.size(), 0.0); + lp_values[9] = 1.6; // 10 * 1.6 == 16 + + RoundingOptions options; + options.max_scaling = 4; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + options, rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), "1*X6 2*X7 3*X8 4*X9 <= 4"); +} + +TEST(IntegerRoundingCutTest, MirOnLargerConstraint2) { + Model model; + std::vector vars(10); + for (int i = 0; i < 10; ++i) vars[i] = model.Add(NewIntegerVariable(0, 5)); + + // sum (i + 1) x_i <= 16. 
+ const IntegerValue rhs = IntegerValue(16); + std::vector coeffs; + for (int i = 0; i < vars.size(); ++i) { + coeffs.push_back(IntegerValue(i + 1)); + } + + std::vector lp_values(vars.size(), 0.0); + lp_values[4] = 5.5 / 5.0; + lp_values[9] = 1.05; + + RoundingOptions options; + options.max_scaling = 4; + LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( + options, rhs, vars, coeffs, lp_values, model); + EXPECT_EQ(constraint.DebugString(), + "2*X1 3*X2 4*X3 6*X4 6*X5 8*X6 9*X7 10*X8 12*X9 <= 18"); +} + +std::vector ToIntegerValues(const std::vector input) { + std::vector output; + for (const int64_t v : input) output.push_back(IntegerValue(v)); + return output; +} + +std::vector ToIntegerVariables( + const std::vector input) { + std::vector output; + for (const int64_t v : input) output.push_back(IntegerVariable(v)); + return output; +} + +// This used to fail as I was coding the CL when I was trying to force t==1 +// in the GetSuperAdditiveRoundingFunction() code. +TEST(IntegerRoundingCutTest, RegressionTest) { + RoundingOptions options; + options.max_scaling = 4; + + const IntegerValue rhs = int64_t{7469520585651099083}; + std::vector vars = ToIntegerVariables( + {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, + 28, 30, 32, 34, 36, 38, 42, 44, 46, 48, 50, 52, 54, 56}); + std::vector coeffs = ToIntegerValues( + {22242929208935956LL, 128795791007031270LL, 64522773588815932LL, + 106805487542181976LL, 136903984044996548LL, 177476314670499137LL, + 364043443034395LL, 28002509947960647LL, 310965596097558939LL, + 103949088324014599LL, 41400520193055115LL, 50111468002532494LL, + 53821870865384327LL, 68690238549704032LL, 75189534851923882LL, + 136250652059774801LL, 169776580612315087LL, 172493907306536826LL, + 13772608007357656LL, 74052819842959090LL, 134400722410234077LL, + 5625133860678171LL, 299572729577293761LL, 81099235700461109LL, + 178989907222373586LL, 16642124499479353LL, 110378717916671350LL, + 41703587448036910LL}); + std::vector 
lp_values = { + 0, 0, 2.51046, 0.0741114, 0.380072, 5.17238, 0, + 0, 13.2214, 0, 0.635977, 0, 0, 3.39859, + 1.15936, 0.165207, 2.29673, 2.19505, 0, 0, 2.31191, + 0, 0.785149, 0.258119, 2.26978, 0, 0.970046, 0}; + std::vector lbs(28, IntegerValue(0)); + std::vector ubs(28, IntegerValue(99)); + ubs[8] = 17; + std::vector solution = + ToIntegerValues({0, 3, 0, 2, 2, 2, 0, 1, 5, 1, 1, 1, 1, 2, + 0, 2, 1, 3, 1, 1, 4, 1, 6, 2, 3, 0, 1, 1}); + + EXPECT_EQ(coeffs.size(), vars.size()); + EXPECT_EQ(lp_values.size(), vars.size()); + EXPECT_EQ(lbs.size(), vars.size()); + EXPECT_EQ(ubs.size(), vars.size()); + EXPECT_EQ(solution.size(), vars.size()); + + // The solution is a valid integer solution of the inequality. + { + IntegerValue activity(0); + for (int i = 0; i < vars.size(); ++i) { + activity += solution[i] * coeffs[i]; + } + EXPECT_LE(activity, rhs); + } + + CutData data; + data.FillFromParallelVectors(rhs, vars, coeffs, lp_values, lbs, ubs); + IntegerRoundingCutHelper helper; + + // TODO(user): Actually this fail, so we don't compute a cut here. 
+ EXPECT_FALSE(helper.ComputeCut(options, data, nullptr)); +} + +void InitializeLpValues(absl::Span values, Model* model) { + auto* lp_values = model->GetOrCreate(); + lp_values->resize(2 * values.size()); + for (int i = 0; i < values.size(); ++i) { + (*lp_values)[IntegerVariable(2 * i)] = values[i]; + (*lp_values)[IntegerVariable(2 * i + 1)] = -values[i]; + } +} + +TEST(SquareCutGeneratorTest, TestBelowCut) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(0, 25)); + InitializeLpValues({2.0, 12.0}, &model); + + CutGenerator square = CreateSquareCutGenerator(y, x, 1, &model); + auto* manager = model.GetOrCreate(); + square.generate_cuts(manager); + EXPECT_EQ(1, manager->num_cuts()); + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("-5*X0 1*X1 <= 0")); +} + +TEST(SquareCutGeneratorTest, TestBelowCutWithOffset) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(1, 25)); + InitializeLpValues({2.0, 12.0}, &model); + + CutGenerator square = CreateSquareCutGenerator(y, x, 1, &model); + auto* manager = model.GetOrCreate(); + square.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("-6*X0 1*X1 <= -5")); +} + +TEST(SquareCutGeneratorTest, TestNoBelowCut) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(1, 25)); + InitializeLpValues({2.0, 6.0}, &model); + + CutGenerator square = CreateSquareCutGenerator(y, x, 1, &model); + auto* manager = model.GetOrCreate(); + square.generate_cuts(manager); + ASSERT_EQ(0, manager->num_cuts()); +} + +TEST(SquareCutGeneratorTest, TestAboveCut) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(1, 25)); + 
InitializeLpValues({2.5, 6.25}, &model); + + CutGenerator square = CreateSquareCutGenerator(y, x, 1, &model); + auto* manager = model.GetOrCreate(); + square.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + StartsWith("-6 <= -5*X0 1*X1")); +} + +TEST(SquareCutGeneratorTest, TestNearlyAboveCut) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(1, 25)); + InitializeLpValues({2.4, 5.99999}, &model); + + CutGenerator square = CreateSquareCutGenerator(y, x, 1, &model); + auto* manager = model.GetOrCreate(); + square.generate_cuts(manager); + ASSERT_EQ(0, manager->num_cuts()); +} + +TEST(MultiplicationCutGeneratorTest, TestCut1) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(2, 3)); + IntegerVariable z = model.Add(NewIntegerVariable(1, 15)); + InitializeLpValues({1.2, 2.1, 2.1}, &model); + + CutGenerator mult = + CreatePositiveMultiplicationCutGenerator(z, x, y, 1, &model); + auto* manager = model.GetOrCreate(); + mult.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("2*X0 1*X1 -1*X2 <= 2")); +} + +TEST(MultiplicationCutGeneratorTest, TestCut2) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(2, 3)); + IntegerVariable z = model.Add(NewIntegerVariable(1, 15)); + InitializeLpValues({4.9, 2.8, 12.0}, &model); + + CutGenerator mult = + CreatePositiveMultiplicationCutGenerator(z, x, y, 1, &model); + auto* manager = model.GetOrCreate(); + mult.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("3*X0 5*X1 -1*X2 <= 15")); +} + +TEST(MultiplicationCutGeneratorTest, TestCut3) 
{ + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 5)); + IntegerVariable y = model.Add(NewIntegerVariable(2, 3)); + IntegerVariable z = model.Add(NewIntegerVariable(1, 15)); + InitializeLpValues({1.2, 2.1, 4.4}, &model); + + CutGenerator mult = + CreatePositiveMultiplicationCutGenerator(z, x, y, 1, &model); + auto* manager = model.GetOrCreate(); + mult.generate_cuts(manager); + ASSERT_EQ(2, manager->num_cuts()); + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + StartsWith("3 <= 3*X0 1*X1 -1*X2")); + EXPECT_THAT(manager->AllConstraints().back().constraint.DebugString(), + StartsWith("10 <= 2*X0 5*X1 -1*X2")); +} + +TEST(MultiplicationCutGeneratorTest, TestNoCut1) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 50)); + IntegerVariable y = model.Add(NewIntegerVariable(2, 30)); + IntegerVariable z = model.Add(NewIntegerVariable(1, 1500)); + InitializeLpValues({40.0, 20.0, 799.0}, &model); + + CutGenerator mult = + CreatePositiveMultiplicationCutGenerator(z, x, y, 1, &model); + auto* manager = model.GetOrCreate(); + mult.generate_cuts(manager); + ASSERT_EQ(0, manager->num_cuts()); +} + +TEST(MultiplicationCutGeneratorTest, TestNoCut2) { + Model model; + IntegerVariable x = model.Add(NewIntegerVariable(1, 50)); + IntegerVariable y = model.Add(NewIntegerVariable(2, 30)); + IntegerVariable z = model.Add(NewIntegerVariable(1, 1500)); + InitializeLpValues({40.0, 20.0, 801.0}, &model); + + CutGenerator mult = + CreatePositiveMultiplicationCutGenerator(z, x, y, 1, &model); + auto* manager = model.GetOrCreate(); + mult.generate_cuts(manager); + ASSERT_EQ(0, manager->num_cuts()); +} + +TEST(AllDiffCutGeneratorTest, TestCut) { + Model model; + Domain domain(10); + domain = domain.UnionWith(Domain(15)); + domain = domain.UnionWith(Domain(25)); + IntegerVariable x = model.Add(NewIntegerVariable(domain)); + IntegerVariable y = model.Add(NewIntegerVariable(domain)); + IntegerVariable z = 
model.Add(NewIntegerVariable(domain)); + InitializeLpValues({15.0, 15.0, 15.0}, &model); + + CutGenerator all_diff = CreateAllDifferentCutGenerator({x, y, z}, &model); + auto* manager = model.GetOrCreate(); + all_diff.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + EXPECT_EQ(manager->AllConstraints().front().constraint.DebugString(), + "50 <= 1*X0 1*X1 1*X2 <= 50"); +} + +TEST(AllDiffCutGeneratorTest, TestCut2) { + Model model; + Domain domain(10); + domain = domain.UnionWith(Domain(15)); + domain = domain.UnionWith(Domain(25)); + IntegerVariable x = model.Add(NewIntegerVariable(domain)); + IntegerVariable y = model.Add(NewIntegerVariable(domain)); + IntegerVariable z = model.Add(NewIntegerVariable(domain)); + InitializeLpValues({13.0, 10.0, 12.0}, &model); + + CutGenerator all_diff = CreateAllDifferentCutGenerator({x, y, z}, &model); + auto* manager = model.GetOrCreate(); + all_diff.generate_cuts(manager); + ASSERT_EQ(2, manager->num_cuts()); + EXPECT_EQ(manager->AllConstraints().front().constraint.DebugString(), + "25 <= 1*X1 1*X2 <= 40"); + EXPECT_EQ(manager->AllConstraints().back().constraint.DebugString(), + "50 <= 1*X0 1*X1 1*X2 <= 50"); +} + +// We model the maximum of 3 affine functions: +// f0(x) = 1 +// f1(x) = -x0 - 2x1 +// f2(x) = -x0 + x1 +// over the box domain -1 <= x0, x1 <= 1. For this data, there are 9 possible +// maximum corner cuts. 
I denote each by noting which function f^i each input +// variable x_j gets assigned: +// (1) x0 -> f0, x1 -> f0: y <= 0x0 + 0x1 + 1z_0 + 3z_1 + 2z_2 +// (2) x0 -> f0, x1 -> f1: y <= 0x0 - 2x1 + 3z_0 + 1z_1 + 4z_2 +// (3) x0 -> f0, x1 -> f2: y <= 0x0 + x1 + 2z_0 + 4z_1 + 1z_2 +// (4) x0 -> f1, x1 -> f0: y <= -x0 + 0x1 + 2z_0 + 2z_1 + 1z_2 +// (5) x0 -> f1, x1 -> f1: y <= -x0 - 2x1 + 4z_0 + 0z_1 + 3z_2 +// (6) x0 -> f1, x1 -> f2: y <= -x0 + x1 + 3z_0 + 3z_1 + 0z_2 +// (7) x0 -> f2, x1 -> f0: y <= -x0 + 0x1 + 2z_0 + 2z_1 + 1z_2 +// (8) x0 -> f2, x1 -> f1: y <= -x0 - 2x1 + 4z_0 + 0z_1 + 3z_2 +// (9) x0 -> f2, x1 -> f2: y <= -x0 + x1 + 3z_0 + 3z_1 + 0z_2 +TEST(LinMaxCutsTest, BasicCuts1) { + Model model; + IntegerVariable x0 = model.Add(NewIntegerVariable(-1, 1)); + IntegerVariable x1 = model.Add(NewIntegerVariable(-1, 1)); + IntegerVariable target = model.Add(NewIntegerVariable(-100, 100)); + LinearExpression f0; + f0.offset = IntegerValue(1); + LinearExpression f1; + f1.vars = {x0, x1}; + f1.coeffs = {IntegerValue(-1), IntegerValue(-2)}; + LinearExpression f2; + f2.vars = {x0, x1}; + f2.coeffs = {IntegerValue(-1), IntegerValue(1)}; + + std::vector exprs = {f0, f1, f2}; + std::vector z_vars; + for (int i = 0; i < exprs.size(); ++i) { + IntegerVariable z = model.Add(NewIntegerVariable(0, 1)); + z_vars.push_back(z); + } + + CutGenerator max_cuts = + CreateLinMaxCutGenerator(target, exprs, z_vars, &model); + + auto* manager = model.GetOrCreate(); + InitializeLpValues({-1.0, 1.0, 2.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0}, &model); + + max_cuts.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + + // x vars are X0,X1 respectively, target is X2, z_vars are X3,X4,X5 + // respectively. + // Most violated inequality is 2. 
+ EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + StartsWith("0 <= -2*X1 -1*X2 3*X3 1*X4 4*X5")); + + InitializeLpValues({-1.0, -1.0, 2.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0}, + &model); + max_cuts.generate_cuts(manager); + ASSERT_EQ(2, manager->num_cuts()); + // Most violated inequality is 3. + EXPECT_THAT(manager->AllConstraints().back().constraint.DebugString(), + StartsWith("0 <= 1*X1 -1*X2 2*X3 4*X4 1*X5")); +} + +// We model the maximum of 3 affine functions: +// f0(x) = 1 +// f1(x) = x +// f2(x) = -x +// target = max(f0, f1, f2) +// x in [-10, 10] +TEST(LinMaxCutsTest, AffineCuts1) { + Model model; + const IntegerValue zero(0); + const IntegerValue one(1); + IntegerVariable x = model.Add(NewIntegerVariable(-10, 10)); + IntegerVariable target = model.Add(NewIntegerVariable(1, 100)); + LinearExpression target_expr; + target_expr.vars.push_back(target); + target_expr.coeffs.push_back(one); + + std::vector> affines = { + {zero, one}, {one, zero}, {-one, zero}}; + + LinearConstraintBuilder builder(&model); + ASSERT_TRUE( + BuildMaxAffineUpConstraint(target_expr, x, affines, &model, &builder)); + + // Note, the cut is not normalized. 
+ EXPECT_EQ(builder.Build().DebugString(), "20*X1 <= 200"); +} + +// We model the maximum of 3 affine functions: +// f0(x) = 1 +// f1(x) = x +// f2(x) = -x +// target = max(f0, f1, f2) +// x in [-1, 10] +TEST(LinMaxCutsTest, AffineCuts2) { + Model model; + const IntegerValue zero(0); + const IntegerValue one(1); + IntegerVariable x = model.Add(NewIntegerVariable(-1, 10)); + IntegerVariable target = model.Add(NewIntegerVariable(1, 100)); + LinearExpression target_expr; + target_expr.vars.push_back(target); + target_expr.coeffs.push_back(one); + + std::vector> affines = { + {zero, one}, {one, zero}, {-one, zero}}; + + LinearConstraintBuilder builder(&model); + ASSERT_TRUE( + BuildMaxAffineUpConstraint(target_expr, x, affines, &model, &builder)); + + EXPECT_EQ(builder.Build().DebugString(), "-9*X0 11*X1 <= 20"); +} + +// We model the maximum of 3 affine functions: +// f0(x) = 1 +// f1(x) = x +// f2(x) = -x +// target = max(f0, f1, f2) +// x fixed +TEST(LinMaxCutsTest, AffineCutsFixedVar) { + Model model; + const IntegerValue zero(0); + const IntegerValue one(1); + IntegerVariable x = model.Add(NewIntegerVariable(2, 2)); + IntegerVariable target = model.Add(NewIntegerVariable(0, 100)); + LinearExpression target_expr; + target_expr.vars.push_back(target); + target_expr.coeffs.push_back(one); + + std::vector> affines = { + {zero, one}, {one, zero}, {-one, zero}}; + + CutGenerator max_cuts = + CreateMaxAffineCutGenerator(target_expr, x, affines, "test", &model); + + auto* manager = model.GetOrCreate(); + InitializeLpValues({2.0, 8.0}, &model); + max_cuts.generate_cuts(manager); + EXPECT_EQ(0, manager->num_cuts()); +} + +TEST(ImpliedBoundsProcessorTest, PositiveBasicTest) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const BooleanVariable b = model.Add(NewBooleanVariable()); + const IntegerVariable b_view = model.Add(NewIntegerVariable(0, 1)); + const IntegerVariable x = model.Add(NewIntegerVariable(2, 9)); + + auto* integer_encoder = 
model.GetOrCreate(); + auto* integer_trail = model.GetOrCreate(); + auto* implied_bounds = model.GetOrCreate(); + + integer_encoder->AssociateToIntegerEqualValue(Literal(b, true), b_view, + IntegerValue(1)); + implied_bounds->Add(Literal(b, true), + IntegerLiteral::GreaterOrEqual(x, IntegerValue(5))); + + // Lp solution. + ImpliedBoundsProcessor processor({x, b_view}, integer_trail, implied_bounds); + + util_intops::StrongVector lp_values(1000); + lp_values[x] = 4.0; + lp_values[b_view] = 2.0 / 3.0; // 2.0 + b_view_value * (5-2) == 4.0 + processor.RecomputeCacheAndSeparateSomeImpliedBoundCuts(lp_values); + + // Lets look at the term X. + CutData data; + + CutTerm X; + X.coeff = 1; + X.lp_value = 2.0; + X.bound_diff = 7; + X.expr_vars[0] = x; + X.expr_coeffs[0] = 1; + X.expr_coeffs[1] = 0; + X.expr_offset = -2; + data.terms.push_back(X); + + processor.CacheDataForCut(IntegerVariable(100), &data); + const IntegerValue t(1); + std::vector new_terms; + EXPECT_TRUE(processor.TryToExpandWithLowerImpliedbound( + t, /*complement=*/false, &data.terms[0], &data.rhs, &new_terms)); + + EXPECT_EQ(0, processor.MutableCutBuilder()->AddOrMergeBooleanTerms( + absl::MakeSpan(new_terms), t, &data)); + + EXPECT_EQ(data.terms.size(), 2); + EXPECT_THAT(data.terms[0].DebugString(), + ::testing::StartsWith("coeff=1 lp=0 range=7")); + EXPECT_THAT(data.terms[1].DebugString(), + ::testing::StartsWith("coeff=3 lp=0.666667 range=1")); + EXPECT_EQ(data.terms[1].expr_offset, 0); +} + +// Same as above but with b.Negated() +TEST(ImpliedBoundsProcessorTest, NegativeBasicTest) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const BooleanVariable b = model.Add(NewBooleanVariable()); + const IntegerVariable b_view = model.Add(NewIntegerVariable(0, 1)); + const IntegerVariable x = model.Add(NewIntegerVariable(2, 9)); + + auto* integer_encoder = model.GetOrCreate(); + auto* integer_trail = model.GetOrCreate(); + auto* implied_bounds = model.GetOrCreate(); + + 
integer_encoder->AssociateToIntegerEqualValue(Literal(b, true), b_view, + IntegerValue(1)); + implied_bounds->Add(Literal(b, false), // False here. + IntegerLiteral::GreaterOrEqual(x, IntegerValue(5))); + + // Lp solution. + ImpliedBoundsProcessor processor({x, b_view}, integer_trail, implied_bounds); + + util_intops::StrongVector lp_values(1000); + lp_values[x] = 4.0; + lp_values[b_view] = 1.0 - 2.0 / 3.0; // 1 - value above. + processor.RecomputeCacheAndSeparateSomeImpliedBoundCuts(lp_values); + + // Lets look at the term X. + CutData data; + + CutTerm X; + X.coeff = 1; + X.lp_value = 2.0; + X.bound_diff = 7; + X.expr_vars[0] = x; + X.expr_coeffs[0] = 1; + X.expr_coeffs[1] = 0; + X.expr_offset = -2; + data.terms.push_back(X); + + processor.CacheDataForCut(IntegerVariable(100), &data); + + const IntegerValue t(1); + std::vector new_terms; + EXPECT_TRUE(processor.TryToExpandWithLowerImpliedbound( + t, /*complement=*/false, &data.terms[0], &data.rhs, &new_terms)); + EXPECT_EQ(0, processor.MutableCutBuilder()->AddOrMergeBooleanTerms( + absl::MakeSpan(new_terms), t, &data)); + + EXPECT_EQ(data.terms.size(), 2); + EXPECT_THAT(data.terms[0].DebugString(), + ::testing::StartsWith("coeff=1 lp=0 range=7")); + EXPECT_THAT(data.terms[1].DebugString(), + ::testing::StartsWith("coeff=3 lp=0.666667 range=1")); + + // This is the only change, we have 1 - bool there actually. 
+ EXPECT_EQ(data.terms[1].expr_offset, 1); + EXPECT_EQ(data.terms[1].expr_coeffs[0], -1); + EXPECT_EQ(data.terms[1].expr_vars[0], b_view); +} + +TEST(ImpliedBoundsProcessorTest, DecompositionTest) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const BooleanVariable b = model.Add(NewBooleanVariable()); + const IntegerVariable b_view = model.Add(NewIntegerVariable(0, 1)); + const BooleanVariable c = model.Add(NewBooleanVariable()); + const IntegerVariable c_view = model.Add(NewIntegerVariable(0, 1)); + const IntegerVariable x = model.Add(NewIntegerVariable(2, 9)); + + auto* integer_encoder = model.GetOrCreate(); + auto* integer_trail = model.GetOrCreate(); + auto* implied_bounds = model.GetOrCreate(); + + integer_encoder->AssociateToIntegerEqualValue(Literal(b, true), b_view, + IntegerValue(1)); + integer_encoder->AssociateToIntegerEqualValue(Literal(c, true), c_view, + IntegerValue(1)); + implied_bounds->Add(Literal(b, true), + IntegerLiteral::GreaterOrEqual(x, IntegerValue(5))); + implied_bounds->Add(Literal(c, true), + IntegerLiteral::LowerOrEqual(x, IntegerValue(2))); + + // Lp solution. + ImpliedBoundsProcessor processor({x, b_view, c_view}, integer_trail, + implied_bounds); + + util_intops::StrongVector lp_values(1000); + lp_values[x] = 4.0; + lp_values[NegationOf(x)] = -4.0; + lp_values[b_view] = 2.0 / 3.0; // 2.0 + b_view_value * (5-2) == 4.0 + lp_values[c_view] = 0.5; + processor.RecomputeCacheAndSeparateSomeImpliedBoundCuts(lp_values); + + // Lets look at the term X. 
+ CutTerm X; + X.coeff = 1; + X.lp_value = 2.0; + X.bound_diff = 7; + X.expr_vars[0] = x; + X.expr_coeffs[0] = 1; + X.expr_coeffs[1] = 0; + X.expr_offset = -2; + + CutData data; + data.terms.push_back(X); + processor.CacheDataForCut(IntegerVariable(100), &data); + X = data.terms[0]; + + // X - 2 = 3 * B + slack; + CutTerm bool_term; + CutTerm slack_term; + EXPECT_TRUE(processor.DecomposeWithImpliedLowerBound(X, IntegerValue(1), + bool_term, slack_term)); + EXPECT_THAT(bool_term.DebugString(), + ::testing::StartsWith("coeff=3 lp=0.666667 range=1")); + EXPECT_THAT(slack_term.DebugString(), + ::testing::StartsWith("coeff=1 lp=0 range=7")); + + // (9 - X) = 7 * C + slack; + CutTerm Y = X; + absl::int128 unused; + Y.Complement(&unused); + Y.coeff = -Y.coeff; + EXPECT_TRUE(processor.DecomposeWithImpliedLowerBound(Y, IntegerValue(1), + bool_term, slack_term)); + EXPECT_THAT(bool_term.DebugString(), + ::testing::StartsWith("coeff=7 lp=0.5 range=1")); + EXPECT_THAT(slack_term.DebugString(), + ::testing::StartsWith("coeff=1 lp=1.5 range=7")); + + // X - 2 = 7 * (1 - C) - slack; + EXPECT_TRUE(processor.DecomposeWithImpliedUpperBound(X, IntegerValue(1), + bool_term, slack_term)); + EXPECT_THAT(bool_term.DebugString(), + ::testing::StartsWith("coeff=7 lp=0.5 range=1")); + EXPECT_THAT(slack_term.DebugString(), + ::testing::StartsWith("coeff=-1 lp=1.5 range=7")); +} + +TEST(CutDataTest, SimpleExample) { + Model model; + const IntegerVariable x0 = model.Add(NewIntegerVariable(7, 10)); + const IntegerVariable x1 = model.Add(NewIntegerVariable(-3, 20)); + + // 6x0 - 4x1 <= 9. 
+ const IntegerValue rhs = IntegerValue(9); + std::vector vars = {x0, x1}; + std::vector coeffs = {IntegerValue(6), IntegerValue(-4)}; + std::vector lp_values = {7.5, 4.5}; + + CutData cut; + std::vector lbs; + std::vector ubs; + auto* integer_trail = model.Get(); + for (int i = 0; i < vars.size(); ++i) { + lbs.push_back(integer_trail->LowerBound(vars[i])); + ubs.push_back(integer_trail->UpperBound(vars[i])); + } + cut.FillFromParallelVectors(rhs, vars, coeffs, lp_values, lbs, ubs); + cut.ComplementForSmallerLpValues(); + + // 6 (X0' + 7) - 4 (X1' - 3) <= 9 + ASSERT_EQ(cut.terms.size(), 2); + EXPECT_EQ(cut.rhs, 9 - 4 * 3 - 6 * 7); + EXPECT_EQ(cut.terms[0].coeff, 6); + EXPECT_EQ(cut.terms[0].lp_value, 0.5); + EXPECT_EQ(cut.terms[0].bound_diff, 3); + EXPECT_EQ(cut.terms[1].coeff, -4); + EXPECT_EQ(cut.terms[1].lp_value, 7.5); + EXPECT_EQ(cut.terms[1].bound_diff, 23); + + // Lets complement. + const absl::int128 old_rhs = cut.rhs; + cut.terms[0].Complement(&cut.rhs); + EXPECT_EQ(cut.rhs, old_rhs - 3 * 6); + EXPECT_EQ(cut.terms[0].coeff, -6); + EXPECT_EQ(cut.terms[0].lp_value, 3 - 0.5); + EXPECT_EQ(cut.terms[0].bound_diff, 3); + + // Encode back. + LinearConstraint new_constraint; + CutDataBuilder builder; + EXPECT_TRUE(builder.ConvertToLinearConstraint(cut, &new_constraint)); + + // We have a division by GCD in there. 
+ const IntegerValue gcd = 2; + EXPECT_EQ(vars.size(), new_constraint.num_terms); + for (int i = 0; i < new_constraint.num_terms; ++i) { + EXPECT_EQ(vars[i], new_constraint.vars[i]); + EXPECT_EQ(coeffs[i] / gcd, new_constraint.coeffs[i]); + } +} + +TEST(SumOfAllDiffLowerBounderTest, ContinuousVariables) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + IntegerVariable x1 = model.Add(NewIntegerVariable(1, 10)); + IntegerVariable x2 = model.Add(NewIntegerVariable(1, 10)); + IntegerVariable x3 = model.Add(NewIntegerVariable(1, 10)); + + SumOfAllDiffLowerBounder helper; + helper.Add(x1, 3, *integer_trail); + helper.Add(x2, 3, *integer_trail); + helper.Add(x3, 3, *integer_trail); + EXPECT_EQ(3, helper.size()); + EXPECT_EQ(6, helper.SumOfMinDomainValues()); + EXPECT_EQ(6, helper.SumOfDifferentMins()); + std::string suffix; + EXPECT_EQ(6, helper.GetBestLowerBound(suffix)); + EXPECT_EQ("e", suffix); + helper.Clear(); + EXPECT_EQ(0, helper.size()); +} + +TEST(SumOfAllDiffLowerBounderTest, DisjointVariables) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + IntegerVariable x1 = model.Add(NewIntegerVariable(1, 10)); + IntegerVariable x2 = model.Add(NewIntegerVariable(1, 10)); + IntegerVariable x3 = model.Add(NewIntegerVariable(1, 10)); + + SumOfAllDiffLowerBounder helper; + helper.Add(x1, 3, *integer_trail); + helper.Add(x2, 3, *integer_trail); + helper.Add(AffineExpression(x3, 1, 10), 3, *integer_trail); + EXPECT_EQ(3, helper.size()); + EXPECT_EQ(6, helper.SumOfMinDomainValues()); + EXPECT_EQ(14, helper.SumOfDifferentMins()); + std::string suffix; + EXPECT_EQ(14, helper.GetBestLowerBound(suffix)); + EXPECT_EQ("a", suffix); +} + +TEST(SumOfAllDiffLowerBounderTest, DiscreteDomains) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + IntegerVariable x1 = model.Add(NewIntegerVariable(1, 10)); + IntegerVariable x2 = model.Add(NewIntegerVariable(1, 10)); + IntegerVariable x3 = model.Add(NewIntegerVariable(1, 10)); + + 
SumOfAllDiffLowerBounder helper; + helper.Add(AffineExpression(x1, 3, 0), 3, *integer_trail); + helper.Add(AffineExpression(x2, 3, 0), 3, *integer_trail); + helper.Add(AffineExpression(x3, 3, 0), 3, *integer_trail); + EXPECT_EQ(3, helper.size()); + EXPECT_EQ(18, helper.SumOfMinDomainValues()); + EXPECT_EQ(12, helper.SumOfDifferentMins()); + std::string suffix; + EXPECT_EQ(18, helper.GetBestLowerBound(suffix)); + EXPECT_EQ("d", suffix); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index 7faf54896ab..d4617abad7e 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -31,6 +31,7 @@ #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/sat/2d_orthogonal_packing.h" +#include "ortools/sat/2d_try_edge_propagator.h" #include "ortools/sat/cumulative_energy.h" #include "ortools/sat/diffn_util.h" #include "ortools/sat/disjunctive.h" @@ -233,6 +234,10 @@ void AddNonOverlappingRectangles(const std::vector& x, watcher->SetPropagatorPriority(energy_constraint->RegisterWith(watcher), 5); model->TakeOwnership(energy_constraint); } + + if (params.use_try_edge_reasoning_in_no_overlap_2d()) { + CreateAndRegisterTryEdgePropagator(x_helper, y_helper, model, watcher); + } } #define RETURN_IF_FALSE(f) \ diff --git a/ortools/sat/diffn_test.cc b/ortools/sat/diffn_test.cc new file mode 100644 index 00000000000..a46d78d29f1 --- /dev/null +++ b/ortools/sat/diffn_test.cc @@ -0,0 +1,176 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/diffn.h" + +#include + +#include + +#include "absl/strings/str_join.h" +#include "gtest/gtest.h" +#include "ortools/base/logging.h" +#include "ortools/sat/cp_model.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +// Counts how many ways we can put two squares of minimal size 1 in an n x n +// square. +// +// For n = 1, infeasible. +// For n = 2, should be 4 * 3. +// For n = 3: +// - 9 * 8 for two size 1. +// - 4 * 5 for size 2 + size 1. Times 2 for the permutation. +int CountAllTwoBoxesSolutions(int n) { + Model model; + std::vector x; + std::vector y; + for (int i = 0; i < 2; ++i) { + // Create a square shaped box of minimum size 1. + const IntegerVariable size = model.Add(NewIntegerVariable(1, n)); + x.push_back( + model.Add(NewInterval(model.Add(NewIntegerVariable(0, n)), + model.Add(NewIntegerVariable(0, n)), size))); + y.push_back( + model.Add(NewInterval(model.Add(NewIntegerVariable(0, n)), + model.Add(NewIntegerVariable(0, n)), size))); + } + + // The cumulative relaxation adds extra variables that are not completely + // fixed. So to not count too many solutions with our code here, we disable + // that. Note that alternatively, we could have used the cp_model.proto API + // to do the same, and that should work even with this on. 
+ AddNonOverlappingRectangles(x, y, &model); + + int num_solutions_found = 0; + auto* integer_trail = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); + auto start_value = [repository, integer_trail](IntervalVariable i) { + return integer_trail->LowerBound(repository->Start(i)).value(); + }; + auto end_value = [repository, integer_trail](IntervalVariable i) { + return integer_trail->LowerBound(repository->End(i)).value(); + }; + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Display the first few solutions. + if (num_solutions_found < 30) { + LOG(INFO) << "R1: " << start_value(x[0]) << "," << start_value(y[0]) + << " " << end_value(x[0]) << "," << end_value(y[0]) + << " R2: " << start_value(x[1]) << "," << start_value(y[1]) + << " " << end_value(x[1]) << "," << end_value(y[1]); + } + + num_solutions_found++; + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + return num_solutions_found; +} + +TEST(NonOverlappingRectanglesTest, SimpleCounting) { + EXPECT_EQ(CountAllTwoBoxesSolutions(1), 0); + EXPECT_EQ(CountAllTwoBoxesSolutions(2), 3 * 4); + EXPECT_EQ(CountAllTwoBoxesSolutions(3), 9 * 8 + 4 * 5 * 2); + EXPECT_EQ(CountAllTwoBoxesSolutions(4), + /*2 1x1 square*/ 16 * 15 + + /*2 2x2 square*/ 2 * (5 + 3 + 4 + 4) + + /*3x3 and 1x1*/ 2 * 4 * 7 + + /*2x2 and 1x1*/ 2 * 9 * 12); +} + +TEST(NonOverlappingRectanglesTest, SimpleCountingWithOptional) { + Model model; + IntervalsRepository* interval_repository = + model.GetOrCreate(); + std::vector x; + std::vector y; + const Literal l1(model.Add(NewBooleanVariable()), true); + x.push_back(interval_repository->CreateInterval( + IntegerValue(0), IntegerValue(5), IntegerValue(5), l1.Index(), false)); + y.push_back(interval_repository->CreateInterval( + IntegerValue(0), IntegerValue(2), IntegerValue(2), l1.Index(), false)); + + const Literal l2(model.Add(NewBooleanVariable()), true); + 
x.push_back(interval_repository->CreateInterval( + IntegerValue(4), IntegerValue(6), IntegerValue(2), l2.Index(), false)); + y.push_back(interval_repository->CreateInterval( + IntegerValue(3), IntegerValue(4), IntegerValue(1), l2.Index(), false)); + + // The cumulative relaxation adds extra variables that are not completely + // fixed. So to not count too many solutions with our code here, we disable + // that. Note that alternatively, we could have used the cp_model.proto API + // to do the same, and that should work even with this on. + // TODO(user): Fix and run with add_cumulative_relaxation = true. + AddNonOverlappingRectangles(x, y, &model); + + int num_solutions_found = 0; + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Display the first few solutions. + if (num_solutions_found < 30) { + LOG(INFO) << "R1: " << interval_repository->IsPresent(x[0]) << " " + << " R2: " << interval_repository->IsPresent(x[1]) << " "; + } + + num_solutions_found++; + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + EXPECT_EQ(4, num_solutions_found); +} + +TEST(NonOverlappingRectanglesTest, CountSolutionsWithZeroAreaBoxes) { + CpModelBuilder cp_model; + IntVar v1 = cp_model.NewIntVar({1, 2}); + IntVar v2 = cp_model.NewIntVar({0, 1}); + IntervalVar x1 = cp_model.NewIntervalVar(2, v2, 2 + v2); + IntervalVar x2 = cp_model.NewFixedSizeIntervalVar(1, 2); + IntervalVar y1 = cp_model.NewIntervalVar(1, v1, v1 + 1); + IntervalVar y2 = cp_model.NewFixedSizeIntervalVar(2, 0); + NoOverlap2DConstraint diffn = cp_model.AddNoOverlap2D(); + diffn.AddRectangle(x1, y1); + diffn.AddRectangle(x2, y2); + + Model model; + model.Add(NewSatParameters("enumerate_all_solutions:true")); + int count = 0; + model.Add( + NewFeasibleSolutionObserver([&count](const CpSolverResponse& response) { + LOG(INFO) << absl::StrJoin(response.solution(), " "); + ++count; + })); + const CpSolverResponse 
response = SolveCpModel(cp_model.Build(), &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(count, 2); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index ade9c12322d..2c7d93d3659 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -51,7 +51,7 @@ bool Rectangle::IsDisjoint(const Rectangle& other) const { other.y_min >= y_max; } -absl::InlinedVector Rectangle::SetDifference( +absl::InlinedVector Rectangle::RegionDifference( const Rectangle& other) const { const Rectangle intersect = Intersect(other); if (intersect.SizeX() == 0) { @@ -155,8 +155,8 @@ bool ReportEnergyConflict(Rectangle bounding_box, absl::Span boxes, return x->ReportConflict(); } -bool BoxesAreInEnergyConflict(const std::vector& rectangles, - const std::vector& energies, +bool BoxesAreInEnergyConflict(absl::Span rectangles, + absl::Span energies, absl::Span boxes, Rectangle* conflict) { // First consider all relevant intervals along the x axis. @@ -412,11 +412,6 @@ absl::Span FilterBoxesThatAreTooLarge( return boxes.subspan(0, new_size); } -std::ostream& operator<<(std::ostream& out, const IndexedInterval& interval) { - return out << "[" << interval.start << ".." 
<< interval.end << " (#" - << interval.index << ")]"; -} - void ConstructOverlappingSets(bool already_sorted, std::vector* intervals, std::vector>* result) { @@ -1547,18 +1542,17 @@ std::string RenderDot(std::optional bb, std::stringstream ss; ss << "digraph {\n"; ss << " graph [ bgcolor=lightgray ]\n"; - ss << " node [style=filled]\n"; + ss << " node [style=filled shape=box]\n"; if (bb.has_value()) { ss << " bb [fillcolor=\"grey\" pos=\"" << 2 * bb->x_min + bb->SizeX() - << "," << 2 * bb->y_min + bb->SizeY() - << "!\" shape=box width=" << 2 * bb->SizeX() + << "," << 2 * bb->y_min + bb->SizeY() << "!\" width=" << 2 * bb->SizeX() << " height=" << 2 * bb->SizeY() << "]\n"; } for (int i = 0; i < solution.size(); ++i) { ss << " " << i << " [fillcolor=\"" << colors[i % colors.size()] << "\" pos=\"" << 2 * solution[i].x_min + solution[i].SizeX() << "," << 2 * solution[i].y_min + solution[i].SizeY() - << "!\" shape=box width=" << 2 * solution[i].SizeX() + << "!\" width=" << 2 * solution[i].SizeX() << " height=" << 2 * solution[i].SizeY() << "]\n"; } ss << extra_dot_payload; @@ -1568,27 +1562,30 @@ std::string RenderDot(std::optional bb, std::vector FindEmptySpaces( const Rectangle& bounding_box, std::vector ocupied_rectangles) { - std::vector empty_spaces = {bounding_box}; - std::vector new_empty_spaces; // Sorting is not necessary for correctness but makes it faster. 
std::sort(ocupied_rectangles.begin(), ocupied_rectangles.end(), [](const Rectangle& a, const Rectangle& b) { return std::tuple(a.x_min, -a.x_max, a.y_min) < std::tuple(b.x_min, -b.x_max, b.y_min); }); - for (const Rectangle& ocupied_rectangle : ocupied_rectangles) { - new_empty_spaces.clear(); - for (const auto& empty_space : empty_spaces) { - for (Rectangle& r : empty_space.SetDifference(ocupied_rectangle)) { - new_empty_spaces.push_back(std::move(r)); - } - } - empty_spaces.swap(new_empty_spaces); - if (empty_spaces.empty()) { - break; + return PavedRegionDifference({bounding_box}, ocupied_rectangles); +} + +std::vector PavedRegionDifference( + std::vector original_region, + absl::Span area_to_remove) { + std::vector new_area_to_cover; + for (const Rectangle& rectangle : area_to_remove) { + new_area_to_cover.clear(); + for (const Rectangle& r : original_region) { + const auto& new_rectangles = r.RegionDifference(rectangle); + new_area_to_cover.insert(new_area_to_cover.end(), new_rectangles.begin(), + new_rectangles.end()); } + original_region.swap(new_area_to_cover); + if (original_region.empty()) break; } - return empty_spaces; + return original_region; } } // namespace sat diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index 5b93c82621d..0fc5df8ee2b 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -63,7 +64,8 @@ struct Rectangle { // Returns `this \ other` as a set of disjoint rectangles of non-empty area. // The resulting vector will have at most four elements. 
- absl::InlinedVector SetDifference(const Rectangle& other) const; + absl::InlinedVector RegionDifference( + const Rectangle& other) const; template friend void AbslStringify(Sink& sink, const Rectangle& r) { @@ -76,6 +78,8 @@ struct Rectangle { std::tie(other.x_min, other.x_max, other.y_min, other.y_max); } + bool operator!=(const Rectangle& other) const { return !(other == *this); } + static Rectangle GetEmpty() { return Rectangle{.x_min = IntegerValue(0), .x_max = IntegerValue(0), @@ -123,8 +127,8 @@ std::vector> GetOverlappingRectangleComponents( // Visible for testing. The algo is in O(n^4) so shouldn't be used directly. // Returns true if there exist a bounding box with too much energy. -bool BoxesAreInEnergyConflict(const std::vector& rectangles, - const std::vector& energies, +bool BoxesAreInEnergyConflict(absl::Span rectangles, + absl::Span energies, absl::Span boxes, Rectangle* conflict = nullptr); @@ -190,8 +194,13 @@ struct IndexedInterval { return a.start < b.start; } }; + + template + friend void AbslStringify(Sink& sink, const IndexedInterval& interval) { + absl::Format(&sink, "[%v..%v] (#%v)", interval.start, interval.end, + interval.index); + } }; -std::ostream& operator<<(std::ostream& out, const IndexedInterval& interval); // Given n fixed intervals, returns the subsets of intervals that overlap during // at least one time unit. Note that we only return "maximal" subset and filter @@ -435,6 +444,18 @@ struct RectangleInRange { containing_area.y_max); } + Rectangle GetMandatoryRegion() const { + // Weird math to avoid overflow. 
+ if (bounding_area.SizeX() - x_size >= x_size || + bounding_area.SizeY() - y_size >= y_size) { + return Rectangle::GetEmpty(); + } + return Rectangle{.x_min = bounding_area.x_max - x_size, + .x_max = bounding_area.x_min + x_size, + .y_min = bounding_area.y_max - y_size, + .y_max = bounding_area.y_min + y_size}; + } + static RectangleInRange BiggestWithMinIntersection( const Rectangle& containing_area, const RectangleInRange& original, const IntegerValue& min_intersect_x, @@ -608,6 +629,20 @@ std::string RenderDot(std::optional bb, std::vector FindEmptySpaces( const Rectangle& bounding_box, std::vector ocupied_rectangles); +// Given two regions, each one of them defined by a vector of non-overlapping +// rectangles paving them, returns a vector of non-overlapping rectangles that +// paves the points that were part of the first region but not of the second. +// This can also be seen as the set difference of the points of the regions. +std::vector PavedRegionDifference( + std::vector original_region, + absl::Span area_to_remove); + +// The two regions must be defined by non-overlapping rectangles. +inline bool RegionIncludesOther(absl::Span region, + absl::Span other) { + return PavedRegionDifference({other.begin(), other.end()}, region).empty(); +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/diffn_util_test.cc b/ortools/sat/diffn_util_test.cc new file mode 100644 index 00000000000..b0d9a112797 --- /dev/null +++ b/ortools/sat/diffn_util_test.cc @@ -0,0 +1,988 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/diffn_util.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/logging.h" +#include "ortools/graph/connected_components.h" +#include "ortools/sat/2d_orthogonal_packing_testing.h" +#include "ortools/sat/integer.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::UnorderedElementsAre; +using ::testing::UnorderedElementsAreArray; + +TEST(GetOverlappingRectangleComponentsTest, NoComponents) { + EXPECT_TRUE(GetOverlappingRectangleComponents({}, {}).empty()); + IntegerValue zero(0); + IntegerValue two(2); + IntegerValue four(4); + EXPECT_TRUE(GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {two, four, two, four}}, {}) + .empty()); + std::vector first = {0}; + EXPECT_TRUE(GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {two, four, two, four}}, + absl::MakeSpan(first)) + .empty()); + std::vector both = {0, 1}; + EXPECT_TRUE(GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {two, four, two, four}}, + absl::MakeSpan(both)) + .empty()); + EXPECT_TRUE(GetOverlappingRectangleComponents( + 
{{zero, two, zero, two}, {two, four, zero, two}}, + absl::MakeSpan(both)) + .empty()); + EXPECT_TRUE(GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {zero, two, two, four}}, + absl::MakeSpan(both)) + .empty()); +} + +TEST(GetOverlappingRectangleComponentsTest, ComponentAndActive) { + EXPECT_TRUE(GetOverlappingRectangleComponents({}, {}).empty()); + IntegerValue zero(0); + IntegerValue one(1); + IntegerValue two(2); + IntegerValue three(3); + IntegerValue four(4); + + std::vector all = {0, 1, 2}; + const auto& components = GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {zero, two, one, three}, {zero, two, two, four}}, + absl::MakeSpan(all)); + ASSERT_EQ(1, components.size()); + EXPECT_EQ(3, components[0].size()); + + std::vector only_two = {0, 2}; + EXPECT_TRUE(GetOverlappingRectangleComponents({{zero, two, zero, two}, + {zero, two, one, three}, + {zero, two, two, four}}, + absl::MakeSpan(only_two)) + .empty()); +} + +TEST(AnalyzeIntervalsTest, Random) { + // Generate a random set of intervals until the first conflict. We are in n^5! + absl::BitGen random; + const int64_t size = 20; + std::vector rectangles; + std::vector energies; + std::vector boxes; + for (int i = 0; i < 40; ++i) { + Rectangle box; + box.x_min = IntegerValue(absl::Uniform(random, 0, size)); + box.x_max = + IntegerValue(absl::Uniform(random, box.x_min.value() + 1, size + 1)); + box.y_min = IntegerValue(absl::Uniform(random, 0, size)); + box.y_max = + IntegerValue(absl::Uniform(random, box.y_min.value() + 1, size + 1)); + rectangles.push_back(box); + boxes.push_back(i); + energies.push_back(IntegerValue(absl::Uniform( + random, 1, (box.x_max - box.x_min + 1).value())) * + IntegerValue(absl::Uniform( + random, 1, (box.y_max - box.y_min + 1).value()))); + + LOG(INFO) << i << " " << box << " energy:" << energies.back(); + Rectangle conflict; + if (!BoxesAreInEnergyConflict(rectangles, energies, boxes, &conflict)) { + continue; + } + + LOG(INFO) << "Conflict! 
" << conflict; + + // Make sure whatever filter we do, we do not remove the conflict. + absl::Span s = absl::MakeSpan(boxes); + IntegerValue threshold_x = kMaxIntegerValue; + IntegerValue threshold_y = kMaxIntegerValue; + for (int i = 0; i < 4; ++i) { + if (!AnalyzeIntervals(/*transpose=*/i % 2 == 1, s, rectangles, energies, + &threshold_x, &threshold_y)) { + LOG(INFO) << "Detected by analyse."; + return; + } + s = FilterBoxesAndRandomize(rectangles, s, threshold_x, threshold_y, + random); + LOG(INFO) << "Filtered size: " << s.size() << " x<=" << threshold_x + << " y<=" << threshold_y; + ASSERT_TRUE(BoxesAreInEnergyConflict(rectangles, energies, s)); + } + + break; + } +} + +TEST(FilterBoxesThatAreTooLargeTest, Empty) { + std::vector r; + std::vector energies; + std::vector boxes; + EXPECT_TRUE( + FilterBoxesThatAreTooLarge(r, energies, absl::MakeSpan(boxes)).empty()); +} + +TEST(FilterBoxesThatAreTooLargeTest, BasicTest) { + int num_boxes(3); + std::vector r(num_boxes); + std::vector energies(num_boxes, IntegerValue(25)); + std::vector boxes{0, 1, 2}; + + r[0] = {IntegerValue(0), IntegerValue(5), IntegerValue(0), IntegerValue(5)}; + r[1] = {IntegerValue(0), IntegerValue(10), IntegerValue(0), IntegerValue(10)}; + r[2] = {IntegerValue(0), IntegerValue(6), IntegerValue(0), IntegerValue(6)}; + + EXPECT_THAT(FilterBoxesThatAreTooLarge(r, energies, absl::MakeSpan(boxes)), + ElementsAre(0, 2)); +} + +TEST(ConstructOverlappingSetsTest, BasicTest) { + std::vector> result{{3}}; // To be sure we clear. + + // --------------------0 + // --------1 --------2 + // ------------3 + // ------4 + std::vector intervals{{0, IntegerValue(0), IntegerValue(10)}, + {1, IntegerValue(0), IntegerValue(4)}, + {2, IntegerValue(6), IntegerValue(10)}, + {3, IntegerValue(2), IntegerValue(8)}, + {4, IntegerValue(3), IntegerValue(6)}}; + + // Note that the order is deterministic, but not sorted. 
+ ConstructOverlappingSets(/*already_sorted=*/false, &intervals, &result); + EXPECT_THAT(result, ElementsAre(UnorderedElementsAre(0, 1, 3, 4), + UnorderedElementsAre(3, 0, 2))); +} + +TEST(ConstructOverlappingSetsTest, OneSet) { + std::vector> result{{3}}; // To be sure we clear. + + std::vector intervals{ + {0, IntegerValue(0), IntegerValue(10)}, + {1, IntegerValue(1), IntegerValue(10)}, + {2, IntegerValue(2), IntegerValue(10)}, + {3, IntegerValue(3), IntegerValue(10)}, + {4, IntegerValue(4), IntegerValue(10)}}; + + ConstructOverlappingSets(/*already_sorted=*/false, &intervals, &result); + EXPECT_THAT(result, ElementsAre(ElementsAre(0, 1, 2, 3, 4))); +} + +TEST(GetOverlappingIntervalComponentsTest, BasicTest) { + std::vector> components{{3}}; // To be sure we clear. + + std::vector intervals{{0, IntegerValue(0), IntegerValue(3)}, + {1, IntegerValue(2), IntegerValue(4)}, + {2, IntegerValue(4), IntegerValue(7)}, + {3, IntegerValue(8), IntegerValue(10)}, + {4, IntegerValue(5), IntegerValue(9)}}; + + GetOverlappingIntervalComponents(&intervals, &components); + EXPECT_THAT(components, ElementsAre(ElementsAre(0, 1), ElementsAre(2, 4, 3))); +} + +TEST(GetOverlappingIntervalComponentsAndArticulationPointsTest, + WithWeirdIndicesAndSomeCornerCases) { + // Here are our intervals: 2======5 7====9 + // They are indexed from top to 0===2 4=====7 8======11 + // bottom, from left to right, 1===3 5=6 7=8 + // starting at 10. 
+ std::vector intervals{ + {10, IntegerValue(2), IntegerValue(5)}, + {11, IntegerValue(7), IntegerValue(9)}, + {12, IntegerValue(0), IntegerValue(2)}, + {13, IntegerValue(4), IntegerValue(7)}, + {14, IntegerValue(8), IntegerValue(11)}, + {15, IntegerValue(1), IntegerValue(3)}, + {16, IntegerValue(5), IntegerValue(6)}, + {17, IntegerValue(7), IntegerValue(8)}, + }; + + std::vector> components; + GetOverlappingIntervalComponents(&intervals, &components); + EXPECT_THAT(components, ElementsAre(ElementsAre(12, 15, 10, 13, 16), + ElementsAre(17, 11, 14))); + + EXPECT_THAT(GetIntervalArticulationPoints(&intervals), + ElementsAre(15, 10, 13, 11)); +} + +std::vector GenerateRandomIntervalVector( + absl::BitGenRef random, int num_intervals) { + std::vector intervals; + intervals.reserve(num_intervals); + const int64_t interval_domain = + absl::LogUniform(random, 1, std::numeric_limits::max()); + const int64_t max_interval_length = absl::Uniform( + random, std::max(1, interval_domain / (2 * num_intervals + 1)), + interval_domain); + for (int i = 0; i < num_intervals; ++i) { + const int64_t start = absl::Uniform(random, 0, interval_domain); + const int64_t max_length = + std::min(interval_domain - start, max_interval_length); + const int64_t end = + start + absl::Uniform(absl::IntervalClosed, random, 1, max_length); + intervals.push_back( + IndexedInterval{i, IntegerValue(start), IntegerValue(end)}); + } + return intervals; +} + +std::vector> GetOverlappingIntervalComponentsBruteForce( + const std::vector& intervals) { + // Build the adjacency list. 
+ std::vector> adj(intervals.size()); + for (int i = 1; i < intervals.size(); ++i) { + for (int j = 0; j < i; ++j) { + if (std::max(intervals[i].start, intervals[j].start) < + std::min(intervals[i].end, intervals[j].end)) { + adj[i].push_back(j); + adj[j].push_back(i); + } + } + } + std::vector component_indices = + util::GetConnectedComponents(intervals.size(), adj); + if (component_indices.empty()) return {}; + // Transform that into the expected output: a vector of components. + std::vector> components( + *absl::c_max_element(component_indices) + 1); + for (int i = 0; i < intervals.size(); ++i) { + components[component_indices[i]].push_back(i); + } + // Sort the components by start, like GetOverlappingIntervalComponents(). + absl::c_sort(components, [&intervals](const std::vector& c1, + const std::vector& c2) { + CHECK(!c1.empty() && !c2.empty()); + return intervals[c1[0]].start < intervals[c2[0]].start; + }); + // Inside each component, the intervals should be sorted, too. + // Moreover, we need to convert our indices to IntervalIndex.index. + for (std::vector& component : components) { + absl::c_sort(component, [&intervals](int i, int j) { + return IndexedInterval::ComparatorByStartThenEndThenIndex()(intervals[i], + intervals[j]); + }); + for (int& index : component) index = intervals[index].index; + } + return components; +} + +TEST(GetOverlappingIntervalComponentsTest, RandomizedStressTest) { + // Test duration as of 2021-06: .6s in fastbuild, .3s in opt. 
+ constexpr int kNumTests = 10000; + absl::BitGen random; + for (int test = 0; test < kNumTests; ++test) { + const int num_intervals = absl::Uniform(random, 0, 16); + std::vector intervals = + GenerateRandomIntervalVector(random, num_intervals); + const std::vector intervals_copy = intervals; + std::vector> components; + GetOverlappingIntervalComponents(&intervals, &components); + ASSERT_THAT( + components, + ElementsAreArray(GetOverlappingIntervalComponentsBruteForce(intervals))) + << test << " " << absl::StrJoin(intervals_copy, ","); + // Also verify that the function only altered the order of "intervals". + EXPECT_THAT(intervals, UnorderedElementsAreArray(intervals_copy)); + ASSERT_FALSE(HasFailure()) + << test << " " << absl::StrJoin(intervals_copy, ","); + } +} + +TEST(GetIntervalArticulationPointsTest, RandomizedStressTest) { + // THIS TEST ASSUMES THAT GetOverlappingIntervalComponents() IS CORRECT. + // -> don't look at it if GetOverlappingIntervalComponentsTest.StressTest + // fails, and rather investigate that other test first. + + auto get_num_components = [](const std::vector& intervals) { + std::vector mutable_intervals = intervals; + std::vector> components; + GetOverlappingIntervalComponents(&mutable_intervals, &components); + return components.size(); + }; + // Test duration as of 2021-06: 1s in fastbuild, .4s in opt. + constexpr int kNumTests = 10000; + absl::BitGen random; + for (int test = 0; test < kNumTests; ++test) { + const int num_intervals = absl::Uniform(random, 0, 16); + const std::vector intervals = + GenerateRandomIntervalVector(random, num_intervals); + const int baseline_num_components = get_num_components(intervals); + + // Compute the expected articulation points: try removing each interval + // individually and check whether there are more components if we do. 
+ std::vector expected_articulation_points; + for (int i = 0; i < num_intervals; ++i) { + std::vector tmp_intervals = intervals; + tmp_intervals.erase(tmp_intervals.begin() + i); + if (get_num_components(tmp_intervals) > baseline_num_components) { + expected_articulation_points.push_back(i); + } + } + // Sort the articulation points by start, and replace them by their + // corresponding IndexedInterval.index. + absl::c_sort(expected_articulation_points, [&intervals](int i, int j) { + return intervals[i].start < intervals[j].start; + }); + for (int& idx : expected_articulation_points) idx = intervals[idx].index; + + // Compare our function with the expected values. + std::vector mutable_intervals = intervals; + EXPECT_THAT(GetIntervalArticulationPoints(&mutable_intervals), + ElementsAreArray(expected_articulation_points)); + + // Also verify that the function only altered the order of "intervals". + EXPECT_THAT(mutable_intervals, UnorderedElementsAreArray(intervals)); + ASSERT_FALSE(HasFailure()) << test << " " << absl::StrJoin(intervals, ","); + } +} + +TEST(CapacityProfileTest, BasicApi) { + CapacityProfile profile; + profile.AddRectangle(IntegerValue(2), IntegerValue(6), IntegerValue(0), + IntegerValue(2)); + profile.AddRectangle(IntegerValue(4), IntegerValue(12), IntegerValue(0), + IntegerValue(1)); + profile.AddRectangle(IntegerValue(4), IntegerValue(8), IntegerValue(0), + IntegerValue(5)); + std::vector result; + profile.BuildResidualCapacityProfile(&result); + EXPECT_THAT( + result, + ElementsAre( + CapacityProfile::Rectangle(kMinIntegerValue, IntegerValue(0)), + CapacityProfile::Rectangle(IntegerValue(2), IntegerValue(2)), + CapacityProfile::Rectangle(IntegerValue(4), IntegerValue(5)), + CapacityProfile::Rectangle(IntegerValue(8), IntegerValue(1)), + CapacityProfile::Rectangle(IntegerValue(12), IntegerValue(0)))); + + // We query it twice to test that it can be done and that the result is not + // messed up. 
+ profile.BuildResidualCapacityProfile(&result); + EXPECT_THAT( + result, + ElementsAre( + CapacityProfile::Rectangle(kMinIntegerValue, IntegerValue(0)), + CapacityProfile::Rectangle(IntegerValue(2), IntegerValue(2)), + CapacityProfile::Rectangle(IntegerValue(4), IntegerValue(5)), + CapacityProfile::Rectangle(IntegerValue(8), IntegerValue(1)), + CapacityProfile::Rectangle(IntegerValue(12), IntegerValue(0)))); + EXPECT_EQ(IntegerValue(2 * 2 + 4 * 5 + 4 * 1), profile.GetBoundingArea()); +} + +TEST(CapacityProfileTest, ProfileWithMandatoryPart) { + CapacityProfile profile; + profile.AddRectangle(IntegerValue(2), IntegerValue(6), IntegerValue(0), + IntegerValue(2)); + profile.AddRectangle(IntegerValue(4), IntegerValue(12), IntegerValue(0), + IntegerValue(1)); + profile.AddRectangle(IntegerValue(4), IntegerValue(8), IntegerValue(0), + IntegerValue(5)); + profile.AddMandatoryConsumption(IntegerValue(5), IntegerValue(10), + IntegerValue(1)); + std::vector result; + + // Add a dummy rectangle to test the result is cleared. result.push_back(..); + result.push_back( + CapacityProfile::Rectangle(IntegerValue(2), IntegerValue(3))); + + profile.BuildResidualCapacityProfile(&result); + EXPECT_THAT( + result, + ElementsAre( + CapacityProfile::Rectangle(kMinIntegerValue, IntegerValue(0)), + CapacityProfile::Rectangle(IntegerValue(2), IntegerValue(2)), + CapacityProfile::Rectangle(IntegerValue(4), IntegerValue(5)), + CapacityProfile::Rectangle(IntegerValue(5), IntegerValue(4)), + CapacityProfile::Rectangle(IntegerValue(8), IntegerValue(0)), + CapacityProfile::Rectangle(IntegerValue(10), IntegerValue(1)), + CapacityProfile::Rectangle(IntegerValue(12), IntegerValue(0)))); + + // The bounding area should not be impacted by the mandatory consumption. 
+ EXPECT_EQ(IntegerValue(2 * 2 + 4 * 5 + 4 * 1), profile.GetBoundingArea()); +} + +IntegerValue NaiveSmallest1DIntersection(IntegerValue range_min, + IntegerValue range_max, + IntegerValue size, + IntegerValue interval_min, + IntegerValue interval_max) { + IntegerValue min_intersection = std::numeric_limits::max(); + for (IntegerValue start = range_min; start + size <= range_max; ++start) { + // Interval is [start, start + size] + const IntegerValue intersection_start = std::max(start, interval_min); + const IntegerValue intersection_end = std::min(start + size, interval_max); + const IntegerValue intersection_length = + std::max(IntegerValue(0), intersection_end - intersection_start); + min_intersection = std::min(min_intersection, intersection_length); + } + return min_intersection; +} + +TEST(Smallest1DIntersectionTest, BasicTest) { + absl::BitGen random; + const int64_t max_size = 20; + constexpr int num_runs = 400; + for (int k = 0; k < num_runs; k++) { + const IntegerValue range_min = + IntegerValue(absl::Uniform(random, 0, max_size - 1)); + const IntegerValue range_max = + IntegerValue(absl::Uniform(random, range_min.value() + 1, max_size)); + const IntegerValue size = + absl::Uniform(random, 1, range_max.value() - range_min.value()); + + const IntegerValue interval_min = + IntegerValue(absl::Uniform(random, 0, max_size - 1)); + const IntegerValue interval_max = + IntegerValue(absl::Uniform(random, interval_min.value() + 1, max_size)); + EXPECT_EQ(NaiveSmallest1DIntersection(range_min, range_max, size, + interval_min, interval_max), + Smallest1DIntersection(range_min, range_max, size, interval_min, + interval_max)); + } +} + +TEST(RectangleTest, BasicTest) { + Rectangle r1 = {.x_min = 0, .x_max = 2, .y_min = 0, .y_max = 2}; + Rectangle r2 = {.x_min = 1, .x_max = 3, .y_min = 1, .y_max = 3}; + EXPECT_EQ(r1.Intersect(r2), + Rectangle({.x_min = 1, .x_max = 2, .y_min = 1, .y_max = 2})); +} + +TEST(RectangleTest, RandomRegionDifferenceTest) { + absl::BitGen 
random; + const int64_t size = 20; + constexpr int num_runs = 400; + for (int k = 0; k < num_runs; k++) { + Rectangle ret[2]; + for (int i = 0; i < 2; ++i) { + ret[i].x_min = IntegerValue(absl::Uniform(random, 0, size - 1)); + ret[i].x_max = + ret[i].x_min + IntegerValue(absl::Uniform(random, 1, size - 1)); + ret[i].y_min = IntegerValue(absl::Uniform(random, 0, size - 1)); + ret[i].y_max = + ret[i].y_min + IntegerValue(absl::Uniform(random, 1, size - 1)); + } + auto set_diff = ret[0].RegionDifference(ret[1]); + EXPECT_EQ(set_diff.empty(), ret[0].Intersect(ret[1]) == ret[0]); + IntegerValue diff_area = 0; + for (int i = 0; i < set_diff.size(); ++i) { + for (int j = i + 1; j < set_diff.size(); ++j) { + EXPECT_TRUE(set_diff[i].IsDisjoint(set_diff[j])); + } + EXPECT_NE(set_diff[i].Intersect(ret[0]), Rectangle::GetEmpty()); + EXPECT_EQ(set_diff[i].Intersect(ret[1]), Rectangle::GetEmpty()); + IntegerValue area = set_diff[i].Area(); + EXPECT_GT(area, 0); + diff_area += area; + } + EXPECT_EQ(ret[0].IntersectArea(ret[1]) + diff_area, ret[0].Area()); + } +} + +TEST(RectangleTest, RandomPavedRegionDifferenceTest) { + absl::BitGen random; + constexpr int num_runs = 100; + for (int k = 0; k < num_runs; k++) { + const std::vector set1 = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, random); + const std::vector set2 = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, random); + + const std::vector diff = PavedRegionDifference(set1, set2); + for (int i = 0; i < diff.size(); ++i) { + for (int j = i + 1; j < diff.size(); ++j) { + EXPECT_TRUE(diff[i].IsDisjoint(diff[j])); + } + } + for (const auto& r_diff : diff) { + for (const auto& r2 : set2) { + EXPECT_TRUE(r_diff.IsDisjoint(r2)); + } + IntegerValue area = 0; + for (const auto& r1 : set1) { + area += r_diff.IntersectArea(r1); + } + EXPECT_EQ(area, r_diff.Area()); + } + } +} + +TEST(GetMinimumOverlapTest, BasicTest) { + RectangleInRange range_ret = { + .bounding_area = {.x_min = 0, .x_max = 15, .y_min = 
0, .y_max = 15}, + .x_size = 10, + .y_size = 10}; + + // Minimum intersection is when the item is in the bottom-left corner of the + // allowed space. + Rectangle r = {.x_min = 3, .x_max = 30, .y_min = 3, .y_max = 30}; + EXPECT_EQ(range_ret.GetMinimumIntersection(r).Area(), 7 * 7); + EXPECT_EQ(range_ret.GetAtCorner(RectangleInRange::Corner::BOTTOM_LEFT), + Rectangle({.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10})); + EXPECT_EQ(range_ret.GetAtCorner(RectangleInRange::Corner::BOTTOM_LEFT) + .Intersect(r) + .Area(), + 7 * 7); + EXPECT_EQ(r.Intersect( + Rectangle({.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10})), + Rectangle({.x_min = 3, .x_max = 10, .y_min = 3, .y_max = 10})); + + RectangleInRange bigger = + RectangleInRange::BiggestWithMinIntersection(r, range_ret, 7, 7); + // This should be a broader range but don't increase the minimum intersection. + EXPECT_EQ(bigger.GetMinimumIntersection(r).Area(), 7 * 7); + for (const auto& pos : + {RectangleInRange::Corner::BOTTOM_LEFT, + RectangleInRange::Corner::TOP_LEFT, RectangleInRange::Corner::TOP_RIGHT, + RectangleInRange::Corner::BOTTOM_RIGHT}) { + EXPECT_EQ(bigger.GetAtCorner(pos).Intersect(r).Area(), 7 * 7); + } + EXPECT_EQ(bigger.bounding_area.x_min, 0); + EXPECT_EQ(bigger.bounding_area.x_max, 33); + EXPECT_EQ(bigger.bounding_area.y_min, 0); + EXPECT_EQ(bigger.bounding_area.y_max, 33); + EXPECT_EQ(r.Intersect(Rectangle( + {.x_min = 23, .x_max = 33, .y_min = 23, .y_max = 33})), + Rectangle({.x_min = 23, .x_max = 30, .y_min = 23, .y_max = 30})); + + RectangleInRange range_ret2 = { + .bounding_area = {.x_min = 0, .x_max = 105, .y_min = 0, .y_max = 120}, + .x_size = 100, + .y_size = 100}; + Rectangle r2 = {.x_min = 2, .x_max = 4, .y_min = 0, .y_max = 99}; + EXPECT_EQ(range_ret2.GetMinimumIntersection(r2), Rectangle::GetEmpty()); +} + +IntegerValue RecomputeEnergy(const Rectangle& rectangle, + const std::vector& intervals) { + IntegerValue ret = 0; + for (const RectangleInRange& range : intervals) { + const 
Rectangle min_intersect = range.GetMinimumIntersection(rectangle); + EXPECT_LE(min_intersect.SizeX(), range.x_size); + EXPECT_LE(min_intersect.SizeY(), range.y_size); + ret += min_intersect.Area(); + } + return ret; +} + +IntegerValue RecomputeEnergy(const ProbingRectangle& ranges) { + return RecomputeEnergy(ranges.GetCurrentRectangle(), ranges.Intervals()); +} + +void MoveAndCheck(ProbingRectangle& ranges, ProbingRectangle::Edge type) { + EXPECT_TRUE(ranges.CanShrink(type)); + const IntegerValue expected_area = + ranges.GetCurrentRectangle().Area() - ranges.GetShrinkDeltaArea(type); + const IntegerValue expected_min_energy = + ranges.GetMinimumEnergy() - ranges.GetShrinkDeltaEnergy(type); + ranges.Shrink(type); + EXPECT_EQ(ranges.GetMinimumEnergy(), RecomputeEnergy(ranges)); + EXPECT_EQ(ranges.GetMinimumEnergy(), expected_min_energy); + EXPECT_EQ(ranges.GetCurrentRectangle().Area(), expected_area); + ranges.ValidateInvariants(); +} + +TEST(ProbingRectangleTest, BasicTest) { + RectangleInRange range_ret = { + .bounding_area = {.x_min = 0, .x_max = 15, .y_min = 0, .y_max = 13}, + .x_size = 10, + .y_size = 8}; + RectangleInRange range_ret2 = { + .bounding_area = {.x_min = 1, .x_max = 8, .y_min = 7, .y_max = 14}, + .x_size = 5, + .y_size = 5}; + + std::vector ranges_vec = {range_ret, range_ret2}; + ProbingRectangle ranges(ranges_vec); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 0, .x_max = 15, .y_min = 0, .y_max = 14})); + + // Start with the full bounding box, thus both are fully inside. 
+ EXPECT_EQ(ranges.GetMinimumEnergy(), 10 * 8 + 5 * 5); + + EXPECT_EQ(ranges.GetMinimumEnergy(), RecomputeEnergy(ranges)); + + MoveAndCheck(ranges, ProbingRectangle::Edge::LEFT); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 1, .x_max = 15, .y_min = 0, .y_max = 14})); + + MoveAndCheck(ranges, ProbingRectangle::Edge::LEFT); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 3, .x_max = 15, .y_min = 0, .y_max = 14})); + + MoveAndCheck(ranges, ProbingRectangle::Edge::LEFT); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 5, .x_max = 15, .y_min = 0, .y_max = 14})); + + MoveAndCheck(ranges, ProbingRectangle::Edge::LEFT); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 6, .x_max = 15, .y_min = 0, .y_max = 14})); + + MoveAndCheck(ranges, ProbingRectangle::Edge::TOP); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 6, .x_max = 15, .y_min = 0, .y_max = 13})); + + MoveAndCheck(ranges, ProbingRectangle::Edge::TOP); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 6, .x_max = 15, .y_min = 0, .y_max = 8})); + + MoveAndCheck(ranges, ProbingRectangle::Edge::TOP); + EXPECT_EQ(ranges.GetCurrentRectangle(), + Rectangle({.x_min = 6, .x_max = 15, .y_min = 0, .y_max = 5})); +} + +void ReduceUntilDone(ProbingRectangle& ranges, absl::BitGen& random) { + static constexpr ProbingRectangle::Edge kAllEdgesArr[] = { + ProbingRectangle::Edge::LEFT, + ProbingRectangle::Edge::TOP, + ProbingRectangle::Edge::RIGHT, + ProbingRectangle::Edge::BOTTOM, + }; + static constexpr absl::Span kAllMoveTypes( + kAllEdgesArr); + while (!ranges.IsMinimal()) { + ProbingRectangle::Edge type = + kAllMoveTypes.at(absl::Uniform(random, 0, (int)kAllMoveTypes.size())); + if (!ranges.CanShrink(type)) continue; + MoveAndCheck(ranges, type); + } +} + +// This function will find the conflicts for rectangles that have as coordinates +// for the edges one of {min, min + size, max - size, max} for every possible +// item that is at 
least partially inside the rectangle. Note that we might not +// detect a conflict even if there is one by looking only at those rectangles, +// see the ProbingRectangleTest.CounterExample unit test for a concrete example. +std::optional FindRectangleWithEnergyTooLargeExhaustive( + const std::vector& box_ranges) { + int num_boxes = box_ranges.size(); + std::vector x; + x.reserve(num_boxes * 4); + std::vector y; + y.reserve(num_boxes * 4); + for (const auto& box : box_ranges) { + x.push_back(box.bounding_area.x_min); + x.push_back(box.bounding_area.x_min + box.x_size); + x.push_back(box.bounding_area.x_max - box.x_size); + x.push_back(box.bounding_area.x_max); + y.push_back(box.bounding_area.y_min); + y.push_back(box.bounding_area.y_min + box.y_size); + y.push_back(box.bounding_area.y_max - box.y_size); + y.push_back(box.bounding_area.y_max); + } + std::sort(x.begin(), x.end()); + std::sort(y.begin(), y.end()); + x.erase(std::unique(x.begin(), x.end()), x.end()); + y.erase(std::unique(y.begin(), y.end()), y.end()); + for (int i = 0; i < x.size(); ++i) { + for (int j = i + 1; j < x.size(); ++j) { + for (int k = 0; k < y.size(); ++k) { + for (int l = k + 1; l < y.size(); ++l) { + IntegerValue used_energy = 0; + Rectangle rect = { + .x_min = x[i], .x_max = x[j], .y_min = y[k], .y_max = y[l]}; + for (const auto& box : box_ranges) { + auto intersection = box.GetMinimumIntersection(rect); + used_energy += intersection.Area(); + } + if (used_energy > rect.Area()) { + std::vector items_inside; + for (const auto& box : box_ranges) { + if (box.GetMinimumIntersectionArea(rect) > 0) { + items_inside.push_back(box); + } + } + if (items_inside.size() == num_boxes) { + return rect; + } else { + // Call it again after removing items that are outside. 
+ auto try2 = + FindRectangleWithEnergyTooLargeExhaustive(items_inside); + if (try2.has_value()) { + return try2; + } + } + } + } + } + } + } + return std::nullopt; +} + +// This function should give exactly the same result as the +// `FindRectangleWithEnergyTooLargeExhaustive` above, but exercising the +// `ProbingRectangle` class. +std::optional FindRectangleWithEnergyTooLargeWithProbingRectangle( + std::vector& box_ranges) { + int left_shrinks = 0; + int right_shrinks = 0; + int top_shrinks = 0; + + ProbingRectangle ranges(box_ranges); + + while (true) { + // We want to do the equivalent of what + // `FindRectangleWithEnergyTooLargeExhaustive` does: for every + // left/right/top coordinates, try all possible bottom for conflicts. But + // since we cannot fix the coordinates with ProbingRectangle, we fix the + // number of shrinks instead. + ranges.Reset(); + for (int i = 0; i < left_shrinks; i++) { + CHECK(ranges.CanShrink(ProbingRectangle::Edge::LEFT)); + ranges.Shrink(ProbingRectangle::Edge::LEFT); + } + const bool left_end = !ranges.CanShrink(ProbingRectangle::Edge::LEFT); + for (int i = 0; i < top_shrinks; i++) { + CHECK(ranges.CanShrink(ProbingRectangle::Edge::TOP)); + ranges.Shrink(ProbingRectangle::Edge::TOP); + } + const bool top_end = !ranges.CanShrink(ProbingRectangle::Edge::TOP); + for (int i = 0; i < right_shrinks; i++) { + CHECK(ranges.CanShrink(ProbingRectangle::Edge::RIGHT)); + ranges.Shrink(ProbingRectangle::Edge::RIGHT); + } + const bool right_end = !ranges.CanShrink(ProbingRectangle::Edge::RIGHT); + if (ranges.GetMinimumEnergy() > ranges.GetCurrentRectangleArea()) { + return ranges.GetCurrentRectangle(); + } + while (ranges.CanShrink(ProbingRectangle::Edge::BOTTOM)) { + ranges.Shrink(ProbingRectangle::Edge::BOTTOM); + if (ranges.GetMinimumEnergy() > ranges.GetCurrentRectangleArea()) { + return ranges.GetCurrentRectangle(); + } + } + if (!right_end) { + right_shrinks++; + } else if (!top_end) { + top_shrinks++; + right_shrinks = 0; + } else if 
(!left_end) { + left_shrinks++; + top_shrinks = 0; + right_shrinks = 0; + } else { + break; + } + } + return std::nullopt; +} + +TEST(ProbingRectangleTest, Random) { + absl::BitGen random; + const int64_t size = 20; + std::vector rectangles; + int count = 0; + int comprehensive_count = 0; + constexpr int num_runs = 400; + for (int k = 0; k < num_runs; k++) { + const int num_intervals = absl::Uniform(random, 1, 20); + IntegerValue total_area = 0; + rectangles.clear(); + for (int i = 0; i < num_intervals; ++i) { + RectangleInRange& range = rectangles.emplace_back(); + range.bounding_area.x_min = IntegerValue(absl::Uniform(random, 0, size)); + range.bounding_area.x_max = IntegerValue( + absl::Uniform(random, range.bounding_area.x_min.value() + 1, size)); + range.x_size = absl::Uniform(random, 1, + range.bounding_area.x_max.value() - + range.bounding_area.x_min.value()); + + range.bounding_area.y_min = IntegerValue(absl::Uniform(random, 0, size)); + range.bounding_area.y_max = IntegerValue( + absl::Uniform(random, range.bounding_area.y_min.value() + 1, size)); + range.y_size = absl::Uniform(random, 1, + range.bounding_area.y_max.value() - + range.bounding_area.y_min.value()); + total_area += range.x_size * range.y_size; + } + auto ret = FindRectanglesWithEnergyConflictMC(rectangles, random, 1.0, 0.8); + count += !ret.conflicts.empty(); + ProbingRectangle ranges(rectangles); + EXPECT_EQ(total_area, ranges.GetMinimumEnergy()); + const bool has_possible_conflict = + FindRectangleWithEnergyTooLargeExhaustive(rectangles).has_value(); + if (has_possible_conflict) { + EXPECT_TRUE( + FindRectangleWithEnergyTooLargeWithProbingRectangle(rectangles) + .has_value()); + } + ReduceUntilDone(ranges, random); + comprehensive_count += has_possible_conflict; + } + LOG(INFO) << count << "/" << num_runs << " had an heuristic (out of " + << comprehensive_count << " possible)."; +} + +// Counterexample for proposition 5.4 of Clautiaux, François, et al. 
"A new +// constraint programming approach for the orthogonal packing problem." +// Computers & Operations Research 35.3 (2008): 944-959. +TEST(ProbingRectangleTest, CounterExample) { + const std::vector rectangles = { + {.bounding_area = {.x_min = 6, .x_max = 10, .y_min = 11, .y_max = 16}, + .x_size = 3, + .y_size = 2}, + {.bounding_area = {.x_min = 5, .x_max = 17, .y_min = 12, .y_max = 13}, + .x_size = 2, + .y_size = 1}, + {.bounding_area = {.x_min = 15, .x_max = 18, .y_min = 11, .y_max = 14}, + .x_size = 1, + .y_size = 1}, + {.bounding_area = {.x_min = 4, .x_max = 14, .y_min = 4, .y_max = 19}, + .x_size = 8, + .y_size = 7}, + {.bounding_area = {.x_min = 0, .x_max = 16, .y_min = 5, .y_max = 18}, + .x_size = 8, + .y_size = 9}, + {.bounding_area = {.x_min = 4, .x_max = 14, .y_min = 12, .y_max = 16}, + .x_size = 5, + .y_size = 1}, + {.bounding_area = {.x_min = 1, .x_max = 16, .y_min = 12, .y_max = 18}, + .x_size = 6, + .y_size = 1}, + {.bounding_area = {.x_min = 5, .x_max = 19, .y_min = 14, .y_max = 15}, + .x_size = 2, + .y_size = 1}}; + const Rectangle rect = {.x_min = 6, .x_max = 10, .y_min = 7, .y_max = 16}; + // The only other possible rectangle with a conflict is x(7..9), y(7..16), + // but none of {y_min, y_min + y_size, y_max - y_size, y_max} is equal to 7. 
+ const IntegerValue energy = RecomputeEnergy(rect, rectangles); + EXPECT_GT(energy, rect.Area()); + EXPECT_FALSE( + FindRectangleWithEnergyTooLargeExhaustive(rectangles).has_value()); +} + +void BM_FindRectangles(benchmark::State& state) { + absl::BitGen random; + std::vector> problems; + static constexpr int kNumProblems = 20; + for (int i = 0; i < kNumProblems; i++) { + problems.push_back(MakeItemsFromRectangles( + GenerateNonConflictingRectangles(state.range(0), random), + state.range(1) / 100.0, random)); + } + int idx = 0; + for (auto s : state) { + CHECK(FindRectanglesWithEnergyConflictMC(problems[idx], random, 1.0, 0.8) + .conflicts.empty()); + ++idx; + if (idx == kNumProblems) idx = 0; + } +} + +BENCHMARK(BM_FindRectangles) + ->ArgPair(5, 1) + ->ArgPair(10, 1) + ->ArgPair(20, 1) + ->ArgPair(30, 1) + ->ArgPair(40, 1) + ->ArgPair(80, 1) + ->ArgPair(100, 1) + ->ArgPair(200, 1) + ->ArgPair(1000, 1) + ->ArgPair(10000, 1) + ->ArgPair(5, 100) + ->ArgPair(10, 100) + ->ArgPair(20, 100) + ->ArgPair(30, 100) + ->ArgPair(40, 100) + ->ArgPair(80, 100) + ->ArgPair(100, 100) + ->ArgPair(200, 100) + ->ArgPair(1000, 100) + ->ArgPair(10000, 100); + +TEST(FindPairwiseRestrictionsTest, Random) { + absl::BitGen random; + constexpr int num_runs = 400; + for (int k = 0; k < num_runs; k++) { + const int num_rectangles = absl::Uniform(random, 1, 20); + const std::vector rectangles = + GenerateNonConflictingRectangles(num_rectangles, random); + const std::vector items = + GenerateItemsRectanglesWithNoPairwiseConflict( + rectangles, absl::Uniform(random, 0, 1.0), random); + std::vector results; + AppendPairwiseRestrictions(items, &results); + for (const PairwiseRestriction& result : results) { + EXPECT_NE(result.type, + PairwiseRestriction::PairwiseRestrictionType::CONFLICT); + } + } +} + +void BM_FindPairwiseRestrictions(benchmark::State& state) { + absl::BitGen random; + // In the vast majority of the cases the propagator doesn't find any pairwise + // condition to propagate. 
Thus we choose to benchmark for this particular + // case. + const std::vector items = + GenerateItemsRectanglesWithNoPairwisePropagation( + state.range(0), state.range(1) / 100.0, random); + std::vector results; + for (auto s : state) { + AppendPairwiseRestrictions(items, &results); + CHECK(results.empty()); + } +} + +BENCHMARK(BM_FindPairwiseRestrictions) + ->ArgPair(5, 1) + ->ArgPair(10, 1) + ->ArgPair(20, 1) + ->ArgPair(30, 1) + ->ArgPair(40, 1) + ->ArgPair(80, 1) + ->ArgPair(100, 1) + ->ArgPair(200, 1) + ->ArgPair(1000, 1) + ->ArgPair(10000, 1) + ->ArgPair(5, 100) + ->ArgPair(10, 100) + ->ArgPair(20, 100) + ->ArgPair(30, 100) + ->ArgPair(40, 100) + ->ArgPair(80, 100) + ->ArgPair(100, 100) + ->ArgPair(200, 100) + ->ArgPair(1000, 100) + ->ArgPair(10000, 100); + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/disjunctive_test.cc b/ortools/sat/disjunctive_test.cc new file mode 100644 index 00000000000..bf58e2b915a --- /dev/null +++ b/ortools/sat/disjunctive_test.cc @@ -0,0 +1,527 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/disjunctive.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/logging.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +// TODO(user): Add tests for variable duration intervals! The code is trickier +// to get right in this case. + +// Macros to improve the test readability below. +#define MIN_START(v) IntegerValue(v) +#define MIN_DURATION(v) IntegerValue(v) + +TEST(TaskSetTest, AddEntry) { + TaskSet tasks(1000); + std::mt19937 random(12345); + for (int i = 0; i < 1000; ++i) { + tasks.AddEntry({i, MIN_START(absl::Uniform(random, 0, 1000)), + MIN_DURATION(absl::Uniform(random, 0, 100))}); + } + EXPECT_TRUE( + std::is_sorted(tasks.SortedTasks().begin(), tasks.SortedTasks().end())); +} + +TEST(TaskSetTest, EndMinOnEmptySet) { + TaskSet tasks(0); + int critical_index; + EXPECT_EQ(kMinIntegerValue, + tasks.ComputeEndMin(/*task_to_ignore=*/-1, &critical_index)); + EXPECT_EQ(kMinIntegerValue, tasks.ComputeEndMin()); +} + +TEST(TaskSetTest, EndMinBasicTest) { + TaskSet tasks(3); + int critical_index; + tasks.AddEntry({0, MIN_START(2), MIN_DURATION(3)}); + tasks.AddEntry({1, MIN_START(2), MIN_DURATION(3)}); + tasks.AddEntry({2, MIN_START(2), MIN_DURATION(3)}); + EXPECT_EQ(11, tasks.ComputeEndMin(/*task_to_ignore=*/-1, &critical_index)); + EXPECT_EQ(11, tasks.ComputeEndMin()); + EXPECT_EQ(0, critical_index); +} + +TEST(TaskSetTest, EndMinWithNegativeValue) { + 
TaskSet tasks(3); + int critical_index; + tasks.AddEntry({0, MIN_START(-5), MIN_DURATION(1)}); + tasks.AddEntry({1, MIN_START(-6), MIN_DURATION(2)}); + tasks.AddEntry({2, MIN_START(-7), MIN_DURATION(3)}); + EXPECT_EQ(-1, tasks.ComputeEndMin(/*task_to_ignore=*/-1, &critical_index)); + EXPECT_EQ(-1, tasks.ComputeEndMin()); + EXPECT_EQ(0, critical_index); +} + +TEST(TaskSetTest, EndMinLimitCase) { + TaskSet tasks(3); + int critical_index; + tasks.AddEntry({0, MIN_START(2), MIN_DURATION(3)}); + tasks.AddEntry({1, MIN_START(2), MIN_DURATION(3)}); + tasks.AddEntry({2, MIN_START(8), MIN_DURATION(5)}); + EXPECT_EQ(8, tasks.ComputeEndMin(/*task_to_ignore=*/2, &critical_index)); + EXPECT_EQ(0, critical_index); + EXPECT_EQ(13, tasks.ComputeEndMin(/*task_to_ignore=*/-1, &critical_index)); + EXPECT_EQ(2, critical_index); +} + +TEST(TaskSetTest, IgnoringTheLastEntry) { + TaskSet tasks(3); + int critical_index; + tasks.AddEntry({0, MIN_START(2), MIN_DURATION(3)}); + tasks.AddEntry({1, MIN_START(7), MIN_DURATION(3)}); + EXPECT_EQ(10, tasks.ComputeEndMin(/*task_to_ignore=*/-1, &critical_index)); + EXPECT_EQ(5, tasks.ComputeEndMin(/*task_to_ignore=*/1, &critical_index)); +} + +#define MIN_START(v) IntegerValue(v) +#define MIN_DURATION(v) IntegerValue(v) + +// Tests that the DisjunctiveConstraint propagate how expected on the +// given input. Returns false if a conflict is detected (i.e. no feasible +// solution). 
+struct TaskWithDuration { + int min_start; + int max_end; + int min_duration; +}; +struct Task { + int min_start; + int max_end; +}; +bool TestDisjunctivePropagation(absl::Span input, + absl::Span expected, + int expected_num_enqueues) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + IntervalsRepository* intervals = model.GetOrCreate(); + + const int kStart(0); + const int kHorizon(10000); + + std::vector ids; + for (const TaskWithDuration& task : input) { + const IntervalVariable i = + model.Add(NewInterval(kStart, kHorizon, task.min_duration)); + ids.push_back(i); + std::vector no_literal_reason; + std::vector no_integer_reason; + EXPECT_TRUE(integer_trail->Enqueue( + intervals->Start(i).GreaterOrEqual(IntegerValue(task.min_start)), + no_literal_reason, no_integer_reason)); + EXPECT_TRUE( + integer_trail->Enqueue(intervals->End(i).LowerOrEqual(task.max_end), + no_literal_reason, no_integer_reason)); + } + + // Propagate properly the other bounds of the intervals. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + + const int initial_num_enqueues = integer_trail->num_enqueues(); + AddDisjunctive(ids, &model); + if (!model.GetOrCreate()->Propagate()) return false; + CHECK_EQ(input.size(), expected.size()); + for (int i = 0; i < input.size(); ++i) { + EXPECT_EQ(expected[i].min_start, + integer_trail->LowerBound(intervals->Start(ids[i]))) + << "task #" << i; + EXPECT_EQ(expected[i].max_end, + integer_trail->UpperBound(intervals->End(ids[i]))) + << "task #" << i; + } + + // The *2 is because there is one Enqueue() for the start and end variable. 
+ EXPECT_EQ(expected_num_enqueues + initial_num_enqueues, + integer_trail->num_enqueues()); + return true; +} + +// 01234567890 +// (---- ) +// ( ------) +TEST(DisjunctiveConstraintTest, NoPropagation) { + EXPECT_TRUE(TestDisjunctivePropagation({{0, 10, 4}, {0, 10, 6}}, + {{0, 10}, {0, 10}}, 0)); +} + +// 01234567890 +// (---- ) +// ( -------) +TEST(DisjunctiveConstraintTest, Overload) { + EXPECT_FALSE(TestDisjunctivePropagation({{0, 10, 4}, {0, 10, 7}}, {}, 0)); +} + +// 01234567890123456789 +// (----- ) +// ( -----) +// ( ------ ) +TEST(DisjunctiveConstraintTest, OverloadFromVilimPhd) { + EXPECT_FALSE( + TestDisjunctivePropagation({{0, 13, 5}, {1, 14, 5}, {2, 12, 6}}, {}, 0)); +} + +// 0123456789012345678901234567890123456789 +// ( [---- ) +// (--- ) +// ( ---) +// (-----) +// +// TODO(user): The problem with this test is that the other propagators do +// propagate the same bound, but in 2 steps, whereas the edge finding do that in +// one. To properly test this, we need to add options to deactivate some of +// the propagations. +TEST(DisjunctiveConstraintTest, EdgeFindingFromVilimPhd) { + EXPECT_TRUE(TestDisjunctivePropagation( + {{4, 30, 4}, {5, 13, 3}, {5, 13, 3}, {13, 18, 5}}, + {{18, 30}, {5, 13}, {5, 13}, {13, 18}}, /*expected_num_enqueues=*/2)); +} + +// 0123456789012345678901234567890123456789 +// (----------- ) +// ( ----------) +// ( -- ] ) +TEST(DisjunctiveConstraintTest, NotLastFromVilimPhd) { + EXPECT_TRUE(TestDisjunctivePropagation({{0, 25, 11}, {1, 27, 10}, {4, 20, 2}}, + {{0, 25}, {1, 27}, {4, 17}}, 1)); +} + +// 0123456789012345678901234567890123456789 +// (----- ) +// ( -----) +// (--- ) +// [ <- the new bound for the third task. 
+TEST(DisjunctiveConstraintTest, DetectablePrecedenceFromVilimPhd) { + EXPECT_TRUE(TestDisjunctivePropagation({{0, 13, 5}, {1, 14, 5}, {7, 17, 3}}, + {{0, 13}, {1, 14}, {10, 17}}, 1)); +} + +TEST(DisjunctiveConstraintTest, Precedences) { + Model model; + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + auto* precedences = model.GetOrCreate(); + auto* relations = model.GetOrCreate(); + auto* intervals = model.GetOrCreate(); + + const auto add_affine_coeff_one_precedence = [&](const AffineExpression e1, + const AffineExpression& e2) { + CHECK_NE(e1.var, kNoIntegerVariable); + CHECK_EQ(e1.coeff, 1); + CHECK_NE(e2.var, kNoIntegerVariable); + CHECK_EQ(e2.coeff, 1); + precedences->AddPrecedenceWithOffset(e1.var, e2.var, + e1.constant - e2.constant); + relations->Add(e1.var, e2.var, e1.constant - e2.constant); + }; + + const int kStart(0); + const int kHorizon(10000); + + std::vector ids; + ids.push_back(model.Add(NewInterval(kStart, kHorizon, 10))); + ids.push_back(model.Add(NewInterval(kStart, kHorizon, 10))); + ids.push_back(model.Add(NewInterval(kStart, kHorizon, 10))); + AddDisjunctive(ids, &model); + + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + for (const IntervalVariable i : ids) { + EXPECT_EQ(0, integer_trail->LowerBound(intervals->Start(i))); + } + + // Now with the precedences. + add_affine_coeff_one_precedence(intervals->End(ids[0]), + intervals->Start(ids[2])); + add_affine_coeff_one_precedence(intervals->End(ids[1]), + intervals->Start(ids[2])); + EXPECT_TRUE(precedences->Propagate(trail)); + EXPECT_EQ(10, integer_trail->LowerBound(intervals->Start(ids[2]))); + + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_EQ(20, integer_trail->LowerBound(intervals->Start(ids[2]))); +} + +// This test should enumerate all the permutation of kNumIntervals elements. +// It used to fail before CL 134067105. 
+TEST(SchedulingTest, Permutations) { + static const int kNumIntervals = 4; + Model model; + std::vector intervals; + for (int i = 0; i < kNumIntervals; ++i) { + const IntervalVariable interval = + model.Add(NewInterval(0, kNumIntervals, 1)); + intervals.push_back(interval); + } + AddDisjunctive(intervals, &model); + + IntegerTrail* integer_trail = model.GetOrCreate(); + IntervalsRepository* repository = model.GetOrCreate(); + std::vector> solutions; + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + std::vector solution(kNumIntervals, -1); + for (int i = 0; i < intervals.size(); ++i) { + const IntervalVariable interval = intervals[i]; + const int64_t start_time = + integer_trail->LowerBound(repository->Start(interval)).value(); + DCHECK_GE(start_time, 0); + DCHECK_LT(start_time, kNumIntervals); + solution[start_time] = i; + } + solutions.push_back(solution); + LOG(INFO) << "Found solution: {" << absl::StrJoin(solution, ", ") << "}."; + + // Loop to the next solution. + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + // Test that we do have all the permutations (but in a random order). + std::sort(solutions.begin(), solutions.end()); + std::vector expected(kNumIntervals); + std::iota(expected.begin(), expected.end(), 0); + for (int i = 0; i < solutions.size(); ++i) { + EXPECT_EQ(expected, solutions[i]); + if (i + 1 < solutions.size()) { + EXPECT_TRUE(std::next_permutation(expected.begin(), expected.end())); + } else { + // We enumerated all the permutations. + EXPECT_FALSE(std::next_permutation(expected.begin(), expected.end())); + } + } +} + +// ============================================================================ +// Random tests with comparison with a simple time-decomposition encoding. 
+// ============================================================================ + +void AddDisjunctiveTimeDecomposition(absl::Span vars, + Model* model) { + const int num_tasks = vars.size(); + IntegerTrail* integer_trail = model->GetOrCreate(); + IntegerEncoder* encoder = model->GetOrCreate(); + IntervalsRepository* repository = model->GetOrCreate(); + + // Compute time range. + IntegerValue min_start = kMaxIntegerValue; + IntegerValue max_end = kMinIntegerValue; + for (int t = 0; t < num_tasks; ++t) { + const AffineExpression start = repository->Start(vars[t]); + const AffineExpression end = repository->End(vars[t]); + min_start = std::min(min_start, integer_trail->LowerBound(start)); + max_end = std::max(max_end, integer_trail->UpperBound(end)); + } + + // Add a constraint for each point of time. + for (IntegerValue time = min_start; time <= max_end; ++time) { + std::vector presence_at_time; + for (const IntervalVariable var : vars) { + const AffineExpression start = repository->Start(var); + const AffineExpression end = repository->End(var); + + const IntegerValue start_min = integer_trail->LowerBound(start); + const IntegerValue end_max = integer_trail->UpperBound(end); + if (end_max <= time || time < start_min) continue; + + // This will be true iff interval is present at time. + // TODO(user): we actually only need one direction of the equivalence. + presence_at_time.push_back( + Literal(model->Add(NewBooleanVariable()), true)); + + std::vector presence_condition; + presence_condition.push_back(encoder->GetOrCreateAssociatedLiteral( + start.LowerOrEqual(IntegerValue(time)))); + presence_condition.push_back(encoder->GetOrCreateAssociatedLiteral( + end.GreaterOrEqual(IntegerValue(time + 1)))); + if (repository->IsOptional(var)) { + presence_condition.push_back(repository->PresenceLiteral(var)); + } + model->Add(ReifiedBoolAnd(presence_condition, presence_at_time.back())); + } + model->Add(AtMostOneConstraint(presence_at_time)); + + // Abort if UNSAT. 
+ if (model->GetOrCreate()->ModelIsUnsat()) return; + } +} + +struct OptionalTasksWithDuration { + int min_start; + int max_end; + int duration; + bool is_optional; +}; + +// TODO(user): we never generate zero duration for now. +std::vector GenerateRandomInstance( + int num_tasks, absl::BitGenRef randomizer) { + std::vector instance; + for (int i = 0; i < num_tasks; ++i) { + OptionalTasksWithDuration task; + task.min_start = absl::Uniform(randomizer, 0, 10); + task.max_end = absl::Uniform(randomizer, 0, 10); + if (task.min_start > task.max_end) std::swap(task.min_start, task.max_end); + if (task.min_start == task.max_end) ++task.max_end; + task.duration = + 1 + absl::Uniform(randomizer, 0, task.max_end - task.min_start - 1); + task.is_optional = absl::Bernoulli(randomizer, 1.0 / 2); + instance.push_back(task); + } + return instance; +} + +int CountAllSolutions( + absl::Span instance, + const std::function&, Model*)>& + add_disjunctive) { + Model model; + std::vector intervals; + for (const OptionalTasksWithDuration& task : instance) { + if (task.is_optional) { + const Literal is_present = Literal(model.Add(NewBooleanVariable()), true); + intervals.push_back(model.Add(NewOptionalInterval( + task.min_start, task.max_end, task.duration, is_present))); + } else { + intervals.push_back( + model.Add(NewInterval(task.min_start, task.max_end, task.duration))); + } + } + add_disjunctive(intervals, &model); + + int num_solutions_found = 0; + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + num_solutions_found++; + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + return num_solutions_found; +} + +std::string InstanceDebugString( + absl::Span instance) { + std::string result; + for (const OptionalTasksWithDuration& task : instance) { + absl::StrAppend(&result, "[", task.min_start, ", ", task.max_end, + "] duration:", task.duration, + " is_optional:", task.is_optional, 
"\n"); + } + return result; +} + +TEST(DisjunctiveTest, RandomComparisonWithSimpleEncoding) { + std::mt19937 randomizer(12345); + const int num_tests = DEBUG_MODE ? 100 : 1000; + for (int test = 0; test < num_tests; ++test) { + const int num_tasks = absl::Uniform(randomizer, 1, 6); + const std::vector instance = + GenerateRandomInstance(num_tasks, randomizer); + EXPECT_EQ(CountAllSolutions(instance, AddDisjunctiveTimeDecomposition), + CountAllSolutions(instance, AddDisjunctive)) + << InstanceDebugString(instance); + EXPECT_EQ( + CountAllSolutions(instance, AddDisjunctive), + CountAllSolutions(instance, AddDisjunctiveWithBooleanPrecedencesOnly)) + << InstanceDebugString(instance); + } +} + +TEST(DisjunctiveTest, TwoIntervalsTest) { + // All the way to put 2 intervals of size 4 and 3 in [0,9]. There is just + // two non-busy unit interval, so: + // - 2 possibilities with 1 hole of size 2 at beginning + // - 2 possibilities with 1 hole of size 2 at the end. + // - 2 possibilities with 1 hole of size 2 in the middle. + // - 2 possibilities with 2 holes around the interval of size 3. + // - 2 possibilities with 2 holes around the interval of size 4. + // - 2 possibilities with 2 holes on both extremities. 
+ std::vector instance; + instance.push_back({0, 9, 4, false}); + instance.push_back({0, 9, 3, false}); + EXPECT_EQ(12, CountAllSolutions(instance, AddDisjunctive)); +} + +TEST(DisjunctiveTest, Precedences) { + Model model; + + std::vector ids; + ids.push_back(model.Add(NewInterval(0, 7, 3))); + ids.push_back(model.Add(NewInterval(0, 7, 2))); + AddDisjunctive(ids, &model); + + const IntegerVariable var = model.Add(NewIntegerVariable(0, 10)); + IntervalsRepository* intervals = model.GetOrCreate(); + model.Add( + AffineCoeffOneLowerOrEqualWithOffset(intervals->End(ids[0]), var, 5)); + model.Add( + AffineCoeffOneLowerOrEqualWithOffset(intervals->End(ids[1]), var, 4)); + + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_EQ(model.Get(LowerBound(var)), (3 + 2) + std::min(4, 5)); +} + +TEST(DisjunctiveTest, OptionalIntervalsWithLinkedPresence) { + Model model; + const Literal alternative = Literal(model.Add(NewBooleanVariable()), true); + + std::vector intervals; + intervals.push_back(model.Add(NewOptionalInterval(0, 6, 3, alternative))); + intervals.push_back(model.Add(NewOptionalInterval(0, 6, 2, alternative))); + intervals.push_back( + model.Add(NewOptionalInterval(0, 6, 4, alternative.Negated()))); + AddDisjunctive(intervals, &model); + + int num_solutions_found = 0; + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + num_solutions_found++; + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + EXPECT_EQ(num_solutions_found, /*alternative*/ 6 + /*!alternative*/ 3); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/docs/README.md b/ortools/sat/docs/README.md index c00f1d65a52..d5f61110982 100644 --- a/ortools/sat/docs/README.md +++ b/ortools/sat/docs/README.md @@ -226,9 +226,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log 
"github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) func simpleSatProgram() error { @@ -264,7 +264,7 @@ func simpleSatProgram() error { func main() { if err := simpleSatProgram(); err != nil { - glog.Exitf("simpleSatProgram returned with error: %v", err) + log.Exitf("simpleSatProgram returned with error: %v", err) } } ``` diff --git a/ortools/sat/docs/boolean_logic.md b/ortools/sat/docs/boolean_logic.md index 9f0d10c8fba..bef7d6a7759 100644 --- a/ortools/sat/docs/boolean_logic.md +++ b/ortools/sat/docs/boolean_logic.md @@ -113,8 +113,8 @@ public class LiteralSampleSat package main import ( - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) func literalSampleSat() { @@ -123,7 +123,7 @@ func literalSampleSat() { x := model.NewBoolVar().WithName("x") notX := x.Not() - glog.Infof("x = %d, x.Not() = %d", x.Index(), notX.Index()) + log.Infof("x = %d, x.Not() = %d", x.Index(), notX.Index()) } func main() { @@ -248,7 +248,7 @@ public class BoolOrSampleSat package main import ( - "ortools/sat/go/cpmodel" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) func boolOrSampleSat() { @@ -434,7 +434,7 @@ public class ReifiedSampleSat package main import ( - "ortools/sat/go/cpmodel" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) func reifiedSampleSat() { @@ -525,10 +525,10 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func booleanProductSample() error { @@ -552,11 +552,11 @@ func booleanProductSample() error { } // Set `fill_additional_solutions_in_response` and 
`enumerate_all_solutions` to true so // the solver returns all solutions found. - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(4), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -576,7 +576,7 @@ func booleanProductSample() error { func main() { err := booleanProductSample() if err != nil { - glog.Exitf("booleanProductSample returned with error: %v", err) + log.Exitf("booleanProductSample returned with error: %v", err) } } ``` diff --git a/ortools/sat/docs/channeling.md b/ortools/sat/docs/channeling.md index cf16b8960d2..6feed79c202 100644 --- a/ortools/sat/docs/channeling.md +++ b/ortools/sat/docs/channeling.md @@ -308,11 +308,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func channelingSampleSat() error { @@ -344,12 +344,12 @@ func channelingSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(11), SearchBranching: sppb.SatParameters_FIXED_SEARCH.Enum(), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -367,7 +367,7 @@ func channelingSampleSat() error { func 
main() { if err := channelingSampleSat(); err != nil { - glog.Exitf("channelingSampleSat returned with error: %v", err) + log.Exitf("channelingSampleSat returned with error: %v", err) } } ``` @@ -895,8 +895,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const ( @@ -907,7 +907,7 @@ const ( ) type item struct { - Cost, Copies int64_t + Cost, Copies int64 } func binpackingProblemSat() error { @@ -993,7 +993,7 @@ func binpackingProblemSat() error { func main() { if err := binpackingProblemSat(); err != nil { - glog.Exitf("binpackingProblemSat returned with error: %v", err) + log.Exitf("binpackingProblemSat returned with error: %v", err) } } ``` diff --git a/ortools/sat/docs/integer_arithmetic.md b/ortools/sat/docs/integer_arithmetic.md index 2fcd3c5629b..7fac0f7d1ed 100644 --- a/ortools/sat/docs/integer_arithmetic.md +++ b/ortools/sat/docs/integer_arithmetic.md @@ -275,9 +275,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) const numAnimals = 20 @@ -315,7 +315,7 @@ func rabbitsAndPheasants() error { func main() { if err := rabbitsAndPheasants(); err != nil { - glog.Exitf("rabbitsAndPheasants returned with error: %v", err) + log.Exitf("rabbitsAndPheasants returned with error: %v", err) } } ``` @@ -675,11 +675,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" 
+ "google.golang.org/protobuf/proto" ) const ( @@ -719,12 +719,12 @@ func earlinessTardinessCostSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(21), SearchBranching: sppb.SatParameters_FIXED_SEARCH.Enum(), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -742,7 +742,7 @@ func earlinessTardinessCostSampleSat() error { func main() { if err := earlinessTardinessCostSampleSat(); err != nil { - glog.Exitf("earlinessTardinessCostSampleSat returned with error: %v", err) + log.Exitf("earlinessTardinessCostSampleSat returned with error: %v", err) } } ``` @@ -1131,11 +1131,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func stepFunctionSampleSat() error { @@ -1185,12 +1185,12 @@ func stepFunctionSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(21), SearchBranching: sppb.SatParameters_FIXED_SEARCH.Enum(), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -1208,7 +1208,7 
@@ func stepFunctionSampleSat() error { func main() { if err := stepFunctionSampleSat(); err != nil { - glog.Exitf("stepFunctionSampleSat returned with error: %v", err) + log.Exitf("stepFunctionSampleSat returned with error: %v", err) } } ``` diff --git a/ortools/sat/docs/model.md b/ortools/sat/docs/model.md index 3949f8ce2bc..1bc3f3ce509 100644 --- a/ortools/sat/docs/model.md +++ b/ortools/sat/docs/model.md @@ -310,9 +310,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) func solutionHintingSampleSat() error { @@ -353,7 +353,7 @@ func solutionHintingSampleSat() error { func main() { if err := solutionHintingSampleSat(); err != nil { - glog.Exitf("solutionHintingSampleSat returned with error: %v", err) + log.Exitf("solutionHintingSampleSat returned with error: %v", err) } } ``` diff --git a/ortools/sat/docs/scheduling.md b/ortools/sat/docs/scheduling.md index 19a8b5264c9..091ab8891f9 100644 --- a/ortools/sat/docs/scheduling.md +++ b/ortools/sat/docs/scheduling.md @@ -195,8 +195,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const horizon = 100 @@ -231,7 +231,7 @@ func intervalSampleSat() error { func main() { if err := intervalSampleSat(); err != nil { - glog.Exitf("intervalSampleSat returned with error: %v", err) + log.Exitf("intervalSampleSat returned with error: %v", err) } } ``` @@ -421,8 +421,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const horizon = 100 @@ -454,7 +454,7 @@ func optionalIntervalSampleSat() error { func main() { if err := optionalIntervalSampleSat(); err != nil { - 
glog.Exitf("optionalIntervalSampleSat returned with error: %v", err) + log.Exitf("optionalIntervalSampleSat returned with error: %v", err) } } ``` @@ -841,9 +841,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) const horizon = 21 // 3 weeks @@ -909,7 +909,7 @@ func noOverlapSampleSat() error { func main() { if err := noOverlapSampleSat(); err != nil { - glog.Exitf("noOverlapSampleSat returned with error: %v", err) + log.Exitf("noOverlapSampleSat returned with error: %v", err) } } ``` @@ -1865,9 +1865,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) const ( @@ -1937,7 +1937,7 @@ func rankingSampleSat() error { for t := 0; t < numTasks; t++ { start := model.NewIntVarFromDomain(horizon) - duration := cpmodel.NewConstant(int64_t(t + 1)) + duration := cpmodel.NewConstant(int64(t + 1)) end := model.NewIntVarFromDomain(horizon) var presence cpmodel.BoolVar if t < numTasks/2 { @@ -2008,7 +2008,7 @@ func rankingSampleSat() error { func main() { if err := rankingSampleSat(); err != nil { - glog.Exitf("rankingSampleSat returned with error: %v", err) + log.Exitf("rankingSampleSat returned with error: %v", err) } } ``` diff --git a/ortools/sat/docs/solver.md b/ortools/sat/docs/solver.md index 5881234d438..6bff2619c08 100644 --- a/ortools/sat/docs/solver.md +++ b/ortools/sat/docs/solver.md @@ -194,11 +194,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log 
"github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func solveWithTimeLimitSampleSat() error { @@ -217,9 +217,9 @@ func solveWithTimeLimitSampleSat() error { } // Sets a time limit of 10 seconds. - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ MaxTimeInSeconds: proto.Float64(10.0), - }.Build() + } // Solve. response, err := cpmodel.SolveCpModelWithParameters(m, params) @@ -240,7 +240,7 @@ func solveWithTimeLimitSampleSat() error { func main() { if err := solveWithTimeLimitSampleSat(); err != nil { - glog.Exitf("solveWithTimeLimitSampleSat returned with error: %v", err) + log.Exitf("solveWithTimeLimitSampleSat returned with error: %v", err) } } ``` @@ -535,10 +535,10 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func solveAndPrintIntermediateSolutionsSampleSat() error { @@ -562,10 +562,10 @@ func solveAndPrintIntermediateSolutionsSampleSat() error { // Currently, the CpModelBuilder does not allow for callbacks, so intermediate solutions // cannot be printed while solving. However, the CP-SAT solver does allow for returning // the intermediate solutions found while solving in the response. 
- params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), SolutionPoolSize: proto.Int32(10), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -583,7 +583,7 @@ func solveAndPrintIntermediateSolutionsSampleSat() error { func main() { if err := solveAndPrintIntermediateSolutionsSampleSat(); err != nil { - glog.Exitf("solveAndPrintIntermediateSolutionsSampleSat returned with error: %v", err) + log.Exitf("solveAndPrintIntermediateSolutionsSampleSat returned with error: %v", err) } } ``` @@ -872,10 +872,10 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func searchForAllSolutionsSampleSat() error { @@ -895,11 +895,11 @@ func searchForAllSolutionsSampleSat() error { // Currently, the CpModelBuilder does not allow for callbacks, so each feasible solution cannot // be printed while solving. However, the CP Solver can return all of the enumerated solutions // in the response by setting the following parameters. 
- params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ EnumerateAllSolutions: proto.Bool(true), FillAdditionalSolutionsInResponse: proto.Bool(true), SolutionPoolSize: proto.Int32(27), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -917,7 +917,7 @@ func searchForAllSolutionsSampleSat() error { func main() { if err := searchForAllSolutionsSampleSat(); err != nil { - glog.Exitf("searchForAllSolutionsSampleSat returned with error: %v", err) + log.Exitf("searchForAllSolutionsSampleSat returned with error: %v", err) } } ``` diff --git a/ortools/sat/encoding_test.cc b/ortools/sat/encoding_test.cc new file mode 100644 index 00000000000..29608b3fbb5 --- /dev/null +++ b/ortools/sat/encoding_test.cc @@ -0,0 +1,106 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/encoding.h" + +#include +#include +#include +#include + +#include "absl/random/distributions.h" +#include "gtest/gtest.h" +#include "ortools/sat/pb_constraint.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(MergeAllNodesWithDequeTest, BasicPropagation) { + // We start with a sat solver and n Boolean variables. 
+ std::mt19937 random(12345); + const int n = 456; + SatSolver solver; + solver.SetNumVariables(n); + + // We encode the full cardinality constraint on the n variables. + std::deque repository; + std::vector nodes; + for (int i = 0; i < n; ++i) { + repository.push_back(EncodingNode::LiteralNode( + Literal(BooleanVariable(i), true), Coefficient(0))); + nodes.push_back(&repository.back()); + } + const Coefficient an_upper_bound(1000); + EncodingNode* root = + MergeAllNodesWithDeque(an_upper_bound, nodes, &solver, &repository); + EXPECT_EQ(root->lb(), 0); + EXPECT_EQ(root->ub(), n); + EXPECT_EQ(root->size(), n); + EXPECT_EQ(root->depth(), 9); // 2^9 = 512 which is the first value >= n. + + // We fix some of the n variables randomly, and check some property of the + // Encoding nodes. + for (int run = 0; run < 10; ++run) { + const float density = run / 10; + int exact_count = 0; + solver.Backtrack(0); + for (int i = 0; i < n; ++i) { + const bool value = absl::Bernoulli(random, density); + exact_count += value ? 1 : 0; + EXPECT_TRUE(solver.EnqueueDecisionIfNotConflicting( + Literal(BooleanVariable(i), value))); + } + EXPECT_EQ(solver.Solve(), SatSolver::FEASIBLE); + + // We use an exact encoding, so the number of affected variables at the root + // level of the encoding should be exactly exact_count. + if (exact_count > 0) { + EXPECT_TRUE(solver.Assignment().LiteralIsTrue( + root->GreaterThan(exact_count - 1))); + } + if (exact_count < n) { + EXPECT_FALSE( + solver.Assignment().LiteralIsTrue(root->GreaterThan(exact_count))); + } + } +} + +TEST(LazyMergeAllNodeWithPQAndIncreaseLbTest, CorrectDepth) { + // We start with a sat solver and n Boolean variables. + std::mt19937 random(12345); + const int n = 456; + SatSolver solver; + solver.SetNumVariables(n); + + // We encode the full cardinality constraint on the n variables. 
+ std::deque repository; + std::vector nodes; + for (int i = 0; i < n; ++i) { + repository.push_back(EncodingNode::LiteralNode( + Literal(BooleanVariable(i), true), Coefficient(0))); + nodes.push_back(&repository.back()); + } + EncodingNode* root = + LazyMergeAllNodeWithPQAndIncreaseLb(1, nodes, &solver, &repository); + EXPECT_EQ(root->lb(), 1); + EXPECT_EQ(root->ub(), n); + EXPECT_EQ(root->size(), 0); + EXPECT_EQ(root->depth(), 9); // 2^9 = 512 which is the first value >= n. +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/feasibility_jump_test.cc b/ortools/sat/feasibility_jump_test.cc new file mode 100644 index 00000000000..0e848d03ff7 --- /dev/null +++ b/ortools/sat/feasibility_jump_test.cc @@ -0,0 +1,93 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/feasibility_jump.h" + +#include +#include + +#include "gtest/gtest.h" + +namespace operations_research::sat { +namespace { + +TEST(JumpTableTest, TestCachesCalls) { + int num_calls = 0; + JumpTable jumps; + jumps.SetComputeFunction( + [&](int) { return std::make_pair(++num_calls, -1.0); }); + jumps.RecomputeAll(1); + + EXPECT_EQ(jumps.GetJump(0), std::make_pair(int64_t{1}, -1.0)); + EXPECT_EQ(jumps.GetJump(0), std::make_pair(int64_t{1}, -1.0)); + EXPECT_EQ(num_calls, 1); +} + +TEST(JumpTableTest, TestNeedsRecomputationOneVar) { + int num_calls = 0; + JumpTable jumps; + jumps.SetComputeFunction( + [&](int) { return std::make_pair(++num_calls, -1.0); }); + jumps.RecomputeAll(1); + + jumps.GetJump(0); + jumps.Recompute(0); + + EXPECT_EQ(jumps.GetJump(0), std::make_pair(int64_t{2}, -1.0)); + EXPECT_EQ(num_calls, 2); +} + +TEST(JumpTableTest, TestNeedsRecomputationMultiVar) { + int num_calls = 0; + JumpTable jumps; + jumps.SetComputeFunction( + [&](int v) { return std::make_pair(++num_calls, v); }); + jumps.RecomputeAll(2); + + jumps.GetJump(0); + jumps.GetJump(1); + jumps.Recompute(0); + + EXPECT_EQ(jumps.GetJump(0), std::make_pair(int64_t{3}, 0.0)); + EXPECT_EQ(jumps.GetJump(1), std::make_pair(int64_t{2}, 1.0)); + EXPECT_EQ(num_calls, 3); +} + +TEST(JumpTableTest, TestVarsNeedingRecomputePossiblyGood) { + int num_calls = 0; + JumpTable jumps; + jumps.SetComputeFunction( + [&](int) { return std::make_pair(++num_calls, 1.0); }); + jumps.RecomputeAll(1); + + EXPECT_TRUE(jumps.NeedRecomputation(0)); + EXPECT_EQ(num_calls, 0); +} + +TEST(JumpTableTest, TestSetJump) { + int num_calls = 0; + JumpTable jumps; + jumps.SetComputeFunction( + [&](int) { return std::make_pair(++num_calls, -1.0); }); + jumps.RecomputeAll(1); + + jumps.SetJump(0, 1, 1.0); + + EXPECT_FALSE(jumps.NeedRecomputation(0)); + EXPECT_GE(jumps.Score(0), 0); + EXPECT_EQ(jumps.GetJump(0), std::make_pair(int64_t{1}, 1.0)); + EXPECT_EQ(num_calls, 0); +} + +} // namespace +} // namespace 
operations_research::sat diff --git a/ortools/sat/feasibility_pump.cc b/ortools/sat/feasibility_pump.cc index e1577030490..5ff2c7dd856 100644 --- a/ortools/sat/feasibility_pump.cc +++ b/ortools/sat/feasibility_pump.cc @@ -40,6 +40,7 @@ #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/sorted_interval_list.h" #include "ortools/util/strong_integers.h" @@ -610,11 +611,11 @@ bool FeasibilityPump::PropagationRounding() { } const int64_t rounded_value = - static_cast(std::round(lp_solution_[var_index])); + SafeDoubleToInt64(std::round(lp_solution_[var_index])); const int64_t floor_value = - static_cast(std::floor(lp_solution_[var_index])); + SafeDoubleToInt64(std::floor(lp_solution_[var_index])); const int64_t ceil_value = - static_cast(std::ceil(lp_solution_[var_index])); + SafeDoubleToInt64(std::ceil(lp_solution_[var_index])); const bool floor_is_in_domain = (domain.Contains(floor_value) && lb.value() <= floor_value); diff --git a/ortools/sat/go/cpmodel/BUILD.bazel b/ortools/sat/go/cpmodel/BUILD.bazel index 1aac72eae48..7a7e86aaff7 100644 --- a/ortools/sat/go/cpmodel/BUILD.bazel +++ b/ortools/sat/go/cpmodel/BUILD.bazel @@ -1,3 +1,16 @@ +# Copyright 2010-2024 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( @@ -5,8 +18,6 @@ go_library( srcs = [ "cp_model.go", "cp_solver.go", - "cp_solver_c.cc", - "cp_solver_c.h", "domain.go", ], cdeps = [":cp_solver_c"], @@ -44,12 +55,12 @@ cc_library( srcs = ["cp_solver_c.cc"], hdrs = ["cp_solver_c.h"], deps = [ + "//ortools/base:memutil", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:cp_model_solver", + "//ortools/sat:model", "//ortools/sat:sat_parameters_cc_proto", "//ortools/util:time_limit", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/log:check", ], ) diff --git a/ortools/sat/go/cpmodel/cp_model.go b/ortools/sat/go/cpmodel/cp_model.go index 1822f6c55fa..561233aece8 100644 --- a/ortools/sat/go/cpmodel/cp_model.go +++ b/ortools/sat/go/cpmodel/cp_model.go @@ -578,8 +578,8 @@ func (cp *Builder) NewOptionalIntervalVar(start, size, end LinearArgument, prese Start: start.asLinearExpressionProto(), Size: size.asLinearExpressionProto(), End: end.asLinearExpressionProto(), - }, - }}) + }}, + }) return IntervalVar{cpb: cp, ind: ind} } @@ -803,11 +803,10 @@ func (cp *Builder) AddMinEquality(target LinearArgument, exprs ...LinearArgument } return cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_LinMax{ - &cmpb.LinearArgumentProto{ - Target: asNegatedLinearExpressionProto(target), - Exprs: protos, - }}, + Constraint: &cmpb.ConstraintProto_LinMax{&cmpb.LinearArgumentProto{ + Target: asNegatedLinearExpressionProto(target), + Exprs: protos, + }}, }) } @@ -819,11 +818,10 @@ func (cp *Builder) AddMaxEquality(target LinearArgument, exprs ...LinearArgument } return cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_LinMax{ - &cmpb.LinearArgumentProto{ - Target: target.asLinearExpressionProto(), - Exprs: protos, - }}, + Constraint: &cmpb.ConstraintProto_LinMax{&cmpb.LinearArgumentProto{ + Target: target.asLinearExpressionProto(), 
+ Exprs: protos, + }}, }) } @@ -835,53 +833,49 @@ func (cp *Builder) AddMultiplicationEquality(target LinearArgument, exprs ...Lin } return cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_IntProd{ - &cmpb.LinearArgumentProto{ - Target: target.asLinearExpressionProto(), - Exprs: protos, - }}, + Constraint: &cmpb.ConstraintProto_IntProd{&cmpb.LinearArgumentProto{ + Target: target.asLinearExpressionProto(), + Exprs: protos, + }}, }) } // AddDivisionEquality adds the constraint: target == num / denom. func (cp *Builder) AddDivisionEquality(target, num, denom LinearArgument) Constraint { return cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_IntDiv{ - &cmpb.LinearArgumentProto{ - Target: target.asLinearExpressionProto(), - Exprs: []*cmpb.LinearExpressionProto{ - num.asLinearExpressionProto(), - denom.asLinearExpressionProto(), - }, - }}, + Constraint: &cmpb.ConstraintProto_IntDiv{&cmpb.LinearArgumentProto{ + Target: target.asLinearExpressionProto(), + Exprs: []*cmpb.LinearExpressionProto{ + num.asLinearExpressionProto(), + denom.asLinearExpressionProto(), + }, + }}, }) } // AddAbsEquality adds the constraint: target == Abs(expr). func (cp *Builder) AddAbsEquality(target, expr LinearArgument) Constraint { return cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_LinMax{ - &cmpb.LinearArgumentProto{ - Target: target.asLinearExpressionProto(), - Exprs: []*cmpb.LinearExpressionProto{ - expr.asLinearExpressionProto(), - asNegatedLinearExpressionProto(expr), - }, - }}, + Constraint: &cmpb.ConstraintProto_LinMax{&cmpb.LinearArgumentProto{ + Target: target.asLinearExpressionProto(), + Exprs: []*cmpb.LinearExpressionProto{ + expr.asLinearExpressionProto(), + asNegatedLinearExpressionProto(expr), + }, + }}, }) } // AddModuloEquality adds the constraint: target == v % mod. 
func (cp *Builder) AddModuloEquality(target, v, mod LinearArgument) Constraint { return cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_IntMod{ - &cmpb.LinearArgumentProto{ - Target: target.asLinearExpressionProto(), - Exprs: []*cmpb.LinearExpressionProto{ - v.asLinearExpressionProto(), - mod.asLinearExpressionProto(), - }, - }}, + Constraint: &cmpb.ConstraintProto_IntMod{&cmpb.LinearArgumentProto{ + Target: target.asLinearExpressionProto(), + Exprs: []*cmpb.LinearExpressionProto{ + v.asLinearExpressionProto(), + mod.asLinearExpressionProto(), + }, + }}, }) } @@ -894,37 +888,33 @@ func (cp *Builder) AddNoOverlap(vars ...IntervalVar) Constraint { } return cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_NoOverlap{ - &cmpb.NoOverlapConstraintProto{ - Intervals: intervals, - }}, + Constraint: &cmpb.ConstraintProto_NoOverlap{&cmpb.NoOverlapConstraintProto{ + Intervals: intervals, + }}, }) } // AddNoOverlap2D adds a no_overlap2D constraint that prevents a set of boxes from overlapping. func (cp *Builder) AddNoOverlap2D() NoOverlap2DConstraint { return NoOverlap2DConstraint{cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_NoOverlap_2D{ - &cmpb.NoOverlap2DConstraintProto{}, - }})} + Constraint: &cmpb.ConstraintProto_NoOverlap_2D{&cmpb.NoOverlap2DConstraintProto{}}, + })} } // AddCircuitConstraint adds a circuit constraint to the model. The circuit constraint is // defined on a graph where the arcs are present if the corresponding literals are set to true. 
func (cp *Builder) AddCircuitConstraint() CircuitConstraint { return CircuitConstraint{cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Circuit{ - &cmpb.CircuitConstraintProto{}, - }})} + Constraint: &cmpb.ConstraintProto_Circuit{&cmpb.CircuitConstraintProto{}}, + })} } // AddMultipleCircuitConstraint adds a multiple circuit constraint to the model, aka the "VRP" // (Vehicle Routing Problem) constraint. func (cp *Builder) AddMultipleCircuitConstraint() MultipleCircuitConstraint { return MultipleCircuitConstraint{cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Routes{ - &cmpb.RoutesConstraintProto{}, - }})} + Constraint: &cmpb.ConstraintProto_Routes{&cmpb.RoutesConstraintProto{}}, + })} } // AddAllowedAssignments adds an allowed assignments constraint to the model. When all variables @@ -937,9 +927,8 @@ func (cp *Builder) AddAllowedAssignments(vars ...IntVar) TableConstraint { } return TableConstraint{cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Table{ - &cmpb.TableConstraintProto{Vars: varsInd}, - }})} + Constraint: &cmpb.ConstraintProto_Table{&cmpb.TableConstraintProto{Vars: varsInd}}, + })} } // AddReservoirConstraint adds a reservoir constraint with optional refill/emptying events. @@ -951,8 +940,7 @@ func (cp *Builder) AddAllowedAssignments(vars ...IntVar) TableConstraint { // is assigned a value t, then the level of the reservoir changes by // level_change (which is constant) at time t. Therefore, at any time t: // -// sum(level_changes[i] * actives[i] if times[i] <= t) -// in [min_level, max_level] +// sum(level_changes[i] * actives[i] if times[i] <= t) in [min_level, max_level] // // Note that min level must be <= 0, and the max level must be >= 0. // Please use fixed level_changes to simulate an initial state. 
@@ -963,10 +951,9 @@ func (cp *Builder) AddReservoirConstraint(min, max int64) ReservoirConstraint { return ReservoirConstraint{ cp.appendConstraint( &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Reservoir{ - &cmpb.ReservoirConstraintProto{ - MinLevel: min, MaxLevel: max, - }}}, + Constraint: &cmpb.ConstraintProto_Reservoir{&cmpb.ReservoirConstraintProto{ + MinLevel: min, MaxLevel: max, + }}}, ), cp.NewConstant(1).Index()} } @@ -1001,12 +988,11 @@ func (cp *Builder) AddAutomaton(transitionVars []IntVar, startState int64, final transitions = append(transitions, int32(v.Index())) } return AutomatonConstraint{cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Automaton{ - &cmpb.AutomatonConstraintProto{ - Vars: transitions, - StartingState: startState, - FinalStates: finalStates, - }}, + Constraint: &cmpb.ConstraintProto_Automaton{&cmpb.AutomatonConstraintProto{ + Vars: transitions, + StartingState: startState, + FinalStates: finalStates, + }}, })} } @@ -1015,11 +1001,10 @@ func (cp *Builder) AddAutomaton(transitionVars []IntVar, startState int64, final // capacity. func (cp *Builder) AddCumulative(capacity LinearArgument) CumulativeConstraint { return CumulativeConstraint{cp.appendConstraint(&cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Cumulative{ - &cmpb.CumulativeConstraintProto{ - Capacity: capacity.asLinearExpressionProto(), - }, - }})} + Constraint: &cmpb.ConstraintProto_Cumulative{&cmpb.CumulativeConstraintProto{ + Capacity: capacity.asLinearExpressionProto(), + }}, + })} } // Minimize adds a linear minimization objective. 
diff --git a/ortools/sat/go/cpmodel/cp_model_test.go b/ortools/sat/go/cpmodel/cp_model_test.go index ad91fefff05..8fa8c53bc07 100644 --- a/ortools/sat/go/cpmodel/cp_model_test.go +++ b/ortools/sat/go/cpmodel/cp_model_test.go @@ -20,11 +20,10 @@ import ( "sort" "testing" - "github.com/google/go-cmp/cmp" - "google.golang.org/protobuf/testing/protocmp" - log "github.com/golang/glog" + "github.com/google/go-cmp/cmp" cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + "google.golang.org/protobuf/testing/protocmp" ) func Example() { @@ -735,19 +734,17 @@ func TestIntervalVar(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(trueVar.Index())}, - Constraint: &cmpb.ConstraintProto_Interval{ - &cmpb.IntervalConstraintProto{ - Start: &cmpb.LinearExpressionProto{Offset: 1}, - Size: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - }, - End: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, + Constraint: &cmpb.ConstraintProto_Interval{&cmpb.IntervalConstraintProto{ + Start: &cmpb.LinearExpressionProto{Offset: 1}, + Size: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, }, - }, + End: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, + Coeffs: []int64{1}, + }, + }}, }, }, { @@ -759,20 +756,18 @@ func TestIntervalVar(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(trueVar.Index())}, - Constraint: &cmpb.ConstraintProto_Interval{ - &cmpb.IntervalConstraintProto{ - Start: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - }, - Size: &cmpb.LinearExpressionProto{Offset: 5}, - End: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Offset: 5, - }, + Constraint: &cmpb.ConstraintProto_Interval{&cmpb.IntervalConstraintProto{ + Start: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + 
Coeffs: []int64{1}, }, - }, + Size: &cmpb.LinearExpressionProto{Offset: 5}, + End: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Offset: 5, + }, + }}, }, }, { @@ -784,19 +779,17 @@ func TestIntervalVar(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(bv1.Index())}, - Constraint: &cmpb.ConstraintProto_Interval{ - &cmpb.IntervalConstraintProto{ - Start: &cmpb.LinearExpressionProto{Offset: 1}, - Size: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - }, - End: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, + Constraint: &cmpb.ConstraintProto_Interval{&cmpb.IntervalConstraintProto{ + Start: &cmpb.LinearExpressionProto{Offset: 1}, + Size: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, }, - }, + End: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, + Coeffs: []int64{1}, + }, + }}, }, }, { @@ -808,20 +801,18 @@ func TestIntervalVar(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(bv1.Index())}, - Constraint: &cmpb.ConstraintProto_Interval{ - &cmpb.IntervalConstraintProto{ - Start: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - }, - Size: &cmpb.LinearExpressionProto{Offset: 5}, - End: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Offset: 5, - }, + Constraint: &cmpb.ConstraintProto_Interval{&cmpb.IntervalConstraintProto{ + Start: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, }, - }, + Size: &cmpb.LinearExpressionProto{Offset: 5}, + End: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Offset: 5, + }, + }}, }, }, } @@ -880,11 +871,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: 
[]int32{int32(bv3.Index())}, - Constraint: &cmpb.ConstraintProto_BoolOr{ - &cmpb.BoolArgumentProto{ - Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_BoolOr{&cmpb.BoolArgumentProto{ + Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, + }}, }, }, { @@ -896,11 +885,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(bv3.Index())}, - Constraint: &cmpb.ConstraintProto_BoolAnd{ - &cmpb.BoolArgumentProto{ - Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_BoolAnd{&cmpb.BoolArgumentProto{ + Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, + }}, }, }, { @@ -912,11 +899,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(bv3.Index())}, - Constraint: &cmpb.ConstraintProto_BoolXor{ - &cmpb.BoolArgumentProto{ - Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_BoolXor{&cmpb.BoolArgumentProto{ + Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, + }}, }, }, { @@ -928,11 +913,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(bv3.Index())}, - Constraint: &cmpb.ConstraintProto_BoolOr{ - &cmpb.BoolArgumentProto{ - Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_BoolOr{&cmpb.BoolArgumentProto{ + Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, + }}, }, }, { @@ -944,11 +927,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(bv3.Index())}, - Constraint: &cmpb.ConstraintProto_AtMostOne{ - &cmpb.BoolArgumentProto{ - Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, - }, - }, + 
Constraint: &cmpb.ConstraintProto_AtMostOne{&cmpb.BoolArgumentProto{ + Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, + }}, }, }, { @@ -960,11 +941,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { }, want: &cmpb.ConstraintProto{ EnforcementLiteral: []int32{int32(bv3.Index())}, - Constraint: &cmpb.ConstraintProto_ExactlyOne{ - &cmpb.BoolArgumentProto{ - Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_ExactlyOne{&cmpb.BoolArgumentProto{ + Literals: []int32{int32(bv1.Index()), int32(bv2.Not().Index())}, + }}, }, }, { @@ -975,11 +954,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_BoolOr{ - &cmpb.BoolArgumentProto{ - Literals: []int32{int32(bv1.Not().Index()), int32(bv2.Not().Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_BoolOr{&cmpb.BoolArgumentProto{ + Literals: []int32{int32(bv1.Not().Index()), int32(bv2.Not().Index())}, + }}, }, }, { @@ -991,13 +968,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index()), int32(bv1.Index())}, - Coeffs: []int64{1, 1}, - Domain: []int64{-5, -4, -2, -1, 6, 15}, - }, - }, + Constraint: &cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index()), int32(bv1.Index())}, + Coeffs: []int64{1, 1}, + Domain: []int64{-5, -4, -2, -1, 6, 15}, + }}, }, }, { @@ -1008,13 +983,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index()), int32(bv1.Index())}, - Coeffs: []int64{1, 1}, - Domain: []int64{2, 6}, - }, - }, + Constraint: 
&cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index()), int32(bv1.Index())}, + Coeffs: []int64{1, 1}, + Domain: []int64{2, 6}, + }}, }, }, { @@ -1025,13 +998,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Domain: []int64{10, 10}, - }, - }, + Constraint: &cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Domain: []int64{10, 10}, + }}, }, }, { @@ -1042,13 +1013,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Domain: []int64{math.MinInt64, 10}, - }, - }, + Constraint: &cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Domain: []int64{math.MinInt64, 10}, + }}, }, }, { @@ -1059,13 +1028,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Domain: []int64{math.MinInt64, 9}, - }, - }, + Constraint: &cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Domain: []int64{math.MinInt64, 9}, + }}, }, }, { @@ -1076,13 +1043,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Domain: []int64{10, 
math.MaxInt64}, - }, - }, + Constraint: &cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Domain: []int64{10, math.MaxInt64}, + }}, }, }, { @@ -1093,13 +1058,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Domain: []int64{11, math.MaxInt64}, - }, - }, + Constraint: &cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Domain: []int64{11, math.MaxInt64}, + }}, }, }, { @@ -1110,13 +1073,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Linear{ - &cmpb.LinearConstraintProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - Domain: []int64{math.MinInt64, 9, 11, math.MaxInt64}, - }, - }, + Constraint: &cmpb.ConstraintProto_Linear{&cmpb.LinearConstraintProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + Domain: []int64{math.MinInt64, 9, 11, math.MaxInt64}, + }}, }, }, { @@ -1127,30 +1088,28 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_AllDiff{ - &cmpb.AllDifferentConstraintProto{ - Exprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(bv1.Index())}, - Coeffs: []int64{1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(bv2.Index())}, - Coeffs: []int64{-1}, - Offset: 1, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{}, - Coeffs: []int64{}, - Offset: 10, - }, + Constraint: &cmpb.ConstraintProto_AllDiff{&cmpb.AllDifferentConstraintProto{ 
+ Exprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + }, + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(bv1.Index())}, + Coeffs: []int64{1}, + }, + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(bv2.Index())}, + Coeffs: []int64{-1}, + Offset: 1, + }, + &cmpb.LinearExpressionProto{ + Vars: []int32{}, + Coeffs: []int64{}, + Offset: 10, }, }, - }, + }}, }, }, { @@ -1161,13 +1120,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Element{ - &cmpb.ElementConstraintProto{ - Index: int32(iv1.Index()), - Target: int32(iv4.Index()), - Vars: []int32{int32(iv2.Index()), int32(iv3.Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_Element{&cmpb.ElementConstraintProto{ + Index: int32(iv1.Index()), + Target: int32(iv4.Index()), + Vars: []int32{int32(iv2.Index()), int32(iv3.Index())}, + }}, }, }, { @@ -1178,16 +1135,14 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Element{ - &cmpb.ElementConstraintProto{ - Index: int32(iv1.Index()), - Target: int32(iv4.Index()), - Vars: []int32{ - int32(model.NewConstant(10).Index()), - int32(model.NewConstant(20).Index()), - }, + Constraint: &cmpb.ConstraintProto_Element{&cmpb.ElementConstraintProto{ + Index: int32(iv1.Index()), + Target: int32(iv4.Index()), + Vars: []int32{ + int32(model.NewConstant(10).Index()), + int32(model.NewConstant(20).Index()), }, - }, + }}, }, }, { @@ -1198,12 +1153,10 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Inverse{ - &cmpb.InverseConstraintProto{ - FDirect: []int32{int32(iv1.Index()), int32(iv2.Index())}, - FInverse: []int32{int32(iv3.Index()), int32(iv4.Index())}, - }, - }, + 
Constraint: &cmpb.ConstraintProto_Inverse{&cmpb.InverseConstraintProto{ + FDirect: []int32{int32(iv1.Index()), int32(iv2.Index())}, + FInverse: []int32{int32(iv3.Index()), int32(iv4.Index())}, + }}, }, }, { @@ -1214,24 +1167,22 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_LinMax{ - &cmpb.LinearArgumentProto{ - Target: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, + Constraint: &cmpb.ConstraintProto_LinMax{&cmpb.LinearArgumentProto{ + Target: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{-1}, + }, + Exprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, Coeffs: []int64{-1}, }, - Exprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{-1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv3.Index())}, - Coeffs: []int64{-1}, - }, + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv3.Index())}, + Coeffs: []int64{-1}, }, }, - }, + }}, }, }, { @@ -1242,24 +1193,22 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_LinMax{ - &cmpb.LinearArgumentProto{ - Target: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, + Constraint: &cmpb.ConstraintProto_LinMax{&cmpb.LinearArgumentProto{ + Target: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + }, + Exprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, Coeffs: []int64{1}, }, - Exprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv3.Index())}, - Coeffs: []int64{1}, - }, + 
&cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv3.Index())}, + Coeffs: []int64{1}, }, }, - }, + }}, }, }, { @@ -1270,24 +1219,22 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_IntProd{ - &cmpb.LinearArgumentProto{ - Target: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, + Constraint: &cmpb.ConstraintProto_IntProd{&cmpb.LinearArgumentProto{ + Target: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + }, + Exprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, Coeffs: []int64{1}, }, - Exprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv3.Index())}, - Coeffs: []int64{1}, - }, + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv3.Index())}, + Coeffs: []int64{1}, }, }, - }, + }}, }, }, { @@ -1298,24 +1245,22 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_IntDiv{ - &cmpb.LinearArgumentProto{ - Target: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, + Constraint: &cmpb.ConstraintProto_IntDiv{&cmpb.LinearArgumentProto{ + Target: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + }, + Exprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, Coeffs: []int64{1}, }, - Exprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv3.Index())}, - Coeffs: []int64{1}, - }, + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv3.Index())}, + Coeffs: []int64{1}, }, }, - }, + }}, }, }, 
{ @@ -1326,24 +1271,22 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_LinMax{ - &cmpb.LinearArgumentProto{ - Target: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, + Constraint: &cmpb.ConstraintProto_LinMax{&cmpb.LinearArgumentProto{ + Target: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + }, + Exprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, Coeffs: []int64{1}, }, - Exprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{-1}, - }, + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, + Coeffs: []int64{-1}, }, }, - }, + }}, }, }, { @@ -1354,24 +1297,22 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_IntMod{ - &cmpb.LinearArgumentProto{ - Target: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, + Constraint: &cmpb.ConstraintProto_IntMod{&cmpb.LinearArgumentProto{ + Target: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + }, + Exprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, Coeffs: []int64{1}, }, - Exprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv3.Index())}, - Coeffs: []int64{1}, - }, + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv3.Index())}, + Coeffs: []int64{1}, }, }, - }, + }}, }, }, { @@ -1382,11 +1323,9 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return 
m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_NoOverlap{ - &cmpb.NoOverlapConstraintProto{ - Intervals: []int32{int32(interval1.Index()), int32(interval2.Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_NoOverlap{&cmpb.NoOverlapConstraintProto{ + Intervals: []int32{int32(interval1.Index()), int32(interval2.Index())}, + }}, }, }, { @@ -1399,12 +1338,10 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_NoOverlap_2D{ - &cmpb.NoOverlap2DConstraintProto{ - XIntervals: []int32{int32(interval1.Index()), int32(interval3.Index())}, - YIntervals: []int32{int32(interval2.Index()), int32(interval4.Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_NoOverlap_2D{&cmpb.NoOverlap2DConstraintProto{ + XIntervals: []int32{int32(interval1.Index()), int32(interval3.Index())}, + YIntervals: []int32{int32(interval2.Index()), int32(interval4.Index())}, + }}, }, }, { @@ -1416,13 +1353,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Circuit{ - &cmpb.CircuitConstraintProto{ - Tails: []int32{0}, - Heads: []int32{1}, - Literals: []int32{int32(bv1.Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_Circuit{&cmpb.CircuitConstraintProto{ + Tails: []int32{0}, + Heads: []int32{1}, + Literals: []int32{int32(bv1.Index())}, + }}, }, }, { @@ -1434,13 +1369,11 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Routes{ - &cmpb.RoutesConstraintProto{ - Tails: []int32{0}, - Heads: []int32{1}, - Literals: []int32{int32(bv1.Index())}, - }, - }, + Constraint: &cmpb.ConstraintProto_Routes{&cmpb.RoutesConstraintProto{ + Tails: []int32{0}, + Heads: []int32{1}, + Literals: []int32{int32(bv1.Index())}, + }}, }, }, { 
@@ -1453,12 +1386,10 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Table{ - &cmpb.TableConstraintProto{ - Vars: []int32{int32(iv1.Index()), int32(iv2.Index())}, - Values: []int64{0, 2, 1, 3}, - }, - }, + Constraint: &cmpb.ConstraintProto_Table{&cmpb.TableConstraintProto{ + Vars: []int32{int32(iv1.Index()), int32(iv2.Index())}, + Values: []int64{0, 2, 1, 3}, + }}, }, }, { @@ -1470,24 +1401,22 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Reservoir{ - &cmpb.ReservoirConstraintProto{ - MinLevel: 10, - MaxLevel: 20, - TimeExprs: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, - Coeffs: []int64{2}, - }, + Constraint: &cmpb.ConstraintProto_Reservoir{&cmpb.ReservoirConstraintProto{ + MinLevel: 10, + MaxLevel: 20, + TimeExprs: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{2}, }, - LevelChanges: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Offset: 15, - }, + }, + LevelChanges: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Offset: 15, }, - ActiveLiterals: []int32{int32(one.Index())}, }, - }, + ActiveLiterals: []int32{int32(one.Index())}, + }}, }, }, { @@ -1500,16 +1429,14 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Automaton{ - &cmpb.AutomatonConstraintProto{ - Vars: []int32{int32(iv1.Index()), int32(iv2.Index())}, - StartingState: 0, - FinalStates: []int64{5, 10}, - TransitionTail: []int64{0, 2}, - TransitionHead: []int64{1, 3}, - TransitionLabel: []int64{10, 15}, - }, - }, + Constraint: &cmpb.ConstraintProto_Automaton{&cmpb.AutomatonConstraintProto{ + Vars: 
[]int32{int32(iv1.Index()), int32(iv2.Index())}, + StartingState: 0, + FinalStates: []int64{5, 10}, + TransitionTail: []int64{0, 2}, + TransitionHead: []int64{1, 3}, + TransitionLabel: []int64{10, 15}, + }}, }, }, { @@ -1521,21 +1448,19 @@ func TestCpModelBuilder_Constraints(t *testing.T) { return m.GetConstraints()[c.Index()] }, want: &cmpb.ConstraintProto{ - Constraint: &cmpb.ConstraintProto_Cumulative{ - &cmpb.CumulativeConstraintProto{ - Capacity: &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv1.Index())}, + Constraint: &cmpb.ConstraintProto_Cumulative{&cmpb.CumulativeConstraintProto{ + Capacity: &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv1.Index())}, + Coeffs: []int64{1}, + }, + Intervals: []int32{int32(interval1.Index())}, + Demands: []*cmpb.LinearExpressionProto{ + &cmpb.LinearExpressionProto{ + Vars: []int32{int32(iv2.Index())}, Coeffs: []int64{1}, }, - Intervals: []int32{int32(interval1.Index())}, - Demands: []*cmpb.LinearExpressionProto{ - &cmpb.LinearExpressionProto{ - Vars: []int32{int32(iv2.Index())}, - Coeffs: []int64{1}, - }, - }, }, - }, + }}, }, }, } diff --git a/ortools/sat/go/cpmodel/cp_solver.go b/ortools/sat/go/cpmodel/cp_solver.go index 752b5d05624..e50a1f2f5b6 100644 --- a/ortools/sat/go/cpmodel/cp_solver.go +++ b/ortools/sat/go/cpmodel/cp_solver.go @@ -18,10 +18,10 @@ import ( "sync" "unsafe" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - - "google.golang.org/protobuf/proto" ) /* diff --git a/ortools/sat/go/cpmodel/cp_solver_c.cc b/ortools/sat/go/cpmodel/cp_solver_c.cc index 1f5b36808f4..e8bba37b13e 100644 --- a/ortools/sat/go/cpmodel/cp_solver_c.cc +++ b/ortools/sat/go/cpmodel/cp_solver_c.cc @@ -14,13 +14,12 @@ #include "ortools/sat/go/cpmodel/cp_solver_c.h" #include -#include -#include "absl/status/status.h" -#include "absl/strings/internal/memutil.h" -#include "ortools/base/logging.h" +#include 
"absl/log/check.h" +#include "ortools/base/memutil.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/model.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/util/time_limit.h" @@ -28,13 +27,6 @@ namespace operations_research::sat { namespace { -char* memdup(const char* s, size_t slen) { - void* copy; - if ((copy = malloc(slen)) == nullptr) return nullptr; - memcpy(copy, s, slen); - return reinterpret_cast(copy); -} - CpSolverResponse solveWithParameters(std::atomic* const limit_reached, const CpModelProto& proto, const SatParameters& params) { @@ -81,7 +73,7 @@ void SolveCpInterruptible(void* const limit_reached, const void* creq, CHECK(res.SerializeToString(&res_str)); *cres_len = static_cast(res_str.size()); - *cres = memdup(res_str.data(), *cres_len); + *cres = strings::memdup(res_str.data(), *cres_len); CHECK(*cres != nullptr); } diff --git a/ortools/sat/implied_bounds_test.cc b/ortools/sat/implied_bounds_test.cc new file mode 100644 index 00000000000..935c30b33df --- /dev/null +++ b/ortools/sat/implied_bounds_test.cc @@ -0,0 +1,706 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/implied_bounds.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/strong_vector.h" +#include "ortools/lp_data/lp_types.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +TEST(ImpliedBoundsTest, BasicTest) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + auto* ib = model.GetOrCreate(); + auto* sat_solver = model.GetOrCreate(); + auto* integer_trail = model.GetOrCreate(); + + const Literal enforcement(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 10))); + + EXPECT_TRUE(ib->Add(enforcement, + IntegerLiteral::GreaterOrEqual(var, IntegerValue(3)))); + EXPECT_TRUE(ib->Add(enforcement.Negated(), + IntegerLiteral::GreaterOrEqual(var, IntegerValue(7)))); + + // Here because we are at level-zero everything is propagated right away. + EXPECT_EQ(integer_trail->LowerBound(var), IntegerValue(3)); + EXPECT_EQ(integer_trail->LevelZeroLowerBound(var), IntegerValue(3)); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_EQ(integer_trail->LowerBound(var), IntegerValue(3)); +} + +TEST(ImpliedBoundsTest, BasicTestPositiveLevel) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + auto* ib = model.GetOrCreate(); + auto* sat_solver = model.GetOrCreate(); + auto* integer_trail = model.GetOrCreate(); + + const Literal enforcement(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 10))); + + // We can do the same at a positive level. 
+ const Literal to_enqueue(model.Add(NewBooleanVariable()), true); + EXPECT_TRUE(sat_solver->ResetToLevelZero()); + EXPECT_TRUE(sat_solver->EnqueueDecisionIfNotConflicting(to_enqueue)); + EXPECT_GT(sat_solver->CurrentDecisionLevel(), 0); + + EXPECT_TRUE(ib->Add(enforcement, + IntegerLiteral::GreaterOrEqual(var, IntegerValue(3)))); + EXPECT_TRUE(ib->Add(enforcement.Negated(), + IntegerLiteral::GreaterOrEqual(var, IntegerValue(7)))); + + // Now, only the level zero bound is up to date. + EXPECT_EQ(integer_trail->LowerBound(var), IntegerValue(0)); + EXPECT_EQ(integer_trail->LevelZeroLowerBound(var), IntegerValue(3)); + + // But on the next restart, nothing is lost. + EXPECT_TRUE(sat_solver->ResetToLevelZero()); + EXPECT_EQ(integer_trail->LowerBound(var), IntegerValue(3)); +} + +// Same test as above but no deduction since parameter is false. +TEST(ImpliedBoundsTest, BasicTestWithFalseParameters) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(false); + auto* ib = model.GetOrCreate(); + auto* sat_solver = model.GetOrCreate(); + auto* integer_trail = model.GetOrCreate(); + + const Literal enforcement(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 10))); + + EXPECT_TRUE(ib->Add(enforcement, + IntegerLiteral::GreaterOrEqual(var, IntegerValue(3)))); + EXPECT_TRUE(ib->Add(enforcement.Negated(), + IntegerLiteral::GreaterOrEqual(var, IntegerValue(7)))); + + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_EQ(integer_trail->LowerBound(var), IntegerValue(0)); +} + +TEST(ImpliedBoundsTest, ReadBoundsFromTrail) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const Literal l(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 100))); + + // Make sure l as a view. 
+ const IntegerVariable view(model.Add(NewIntegerVariable(0, 1))); + model.GetOrCreate()->AssociateToIntegerEqualValue( + l, view, IntegerValue(1)); + + // So that there is a decision. + auto* sat_solver = model.GetOrCreate(); + EXPECT_TRUE(sat_solver->EnqueueDecisionIfNotConflicting(l)); + EXPECT_TRUE(sat_solver->Propagate()); + + // Enqueue a bunch of fact. + auto* integer_trail = model.GetOrCreate(); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(2)), {l.Negated()}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(4)), {l.Negated()}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(8)), {l.Negated()}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(9)), {l.Negated()}, {})); + + // Read from trail. + auto* ib = model.GetOrCreate(); + ib->ProcessIntegerTrail(l); + + std::vector result = ib->GetImpliedBounds(var); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0].literal_view, view); + EXPECT_EQ(result[0].lower_bound, IntegerValue(9)); + EXPECT_TRUE(result[0].is_positive); +} + +TEST(ImpliedBoundsTest, DetectEqualityFromMin) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const Literal literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 100))); + + auto* ib = model.GetOrCreate(); + ib->Add(literal, IntegerLiteral::LowerOrEqual(var, IntegerValue(0))); + + EXPECT_THAT( + ib->GetImpliedValues(literal), + testing::UnorderedElementsAre(testing::Pair(var, IntegerValue(0)))); +} + +TEST(ImpliedBoundsTest, DetectEqualityFromMax) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const Literal literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 100))); + + auto* ib = model.GetOrCreate(); + ib->Add(literal, 
IntegerLiteral::GreaterOrEqual(var, IntegerValue(100))); + + EXPECT_THAT(ib->GetImpliedValues(literal), + UnorderedElementsAre(Pair(var, IntegerValue(100)))); +} + +TEST(ImpliedBoundsTest, DetectEqualityFromBothInequalities) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const Literal literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 100))); + + auto* ib = model.GetOrCreate(); + ib->Add(literal, IntegerLiteral::LowerOrEqual(var, IntegerValue(7))); + ib->Add(literal, IntegerLiteral::GreaterOrEqual(var, IntegerValue(7))); + + EXPECT_THAT(ib->GetImpliedValues(literal), + UnorderedElementsAre(Pair(var, IntegerValue(7)))); +} + +TEST(ImpliedBoundsTest, NoEqualityDetection) { + Model model; + model.GetOrCreate()->set_use_implied_bounds(true); + + const Literal literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 100))); + + auto* ib = model.GetOrCreate(); + ib->Add(literal, IntegerLiteral::LowerOrEqual(var, IntegerValue(7))); + ib->Add(literal, IntegerLiteral::GreaterOrEqual(var, IntegerValue(6))); + + EXPECT_TRUE(ib->GetImpliedValues(literal).empty()); +} + +TEST(DetectLinearEncodingOfProductsTest, MatchingElementEncodings) { + Model model; + const Literal l0(model.Add(NewBooleanVariable()), true); + const Literal l1(model.Add(NewBooleanVariable()), true); + const Literal l2(model.Add(NewBooleanVariable()), true); + const Literal l3(model.Add(NewBooleanVariable()), true); + + model.Add(NewIntegerVariableFromLiteral(l0)); + model.Add(NewIntegerVariableFromLiteral(l1)); + model.Add(NewIntegerVariableFromLiteral(l2)); + model.Add(NewIntegerVariableFromLiteral(l3)); + + const IntegerVariable x0(model.Add(NewIntegerVariable(0, 100))); + const IntegerVariable x1(model.Add(NewIntegerVariable(0, 100))); + auto* element_encodings = model.GetOrCreate(); + element_encodings->Add(x0, + {{IntegerValue(2), l0}, + {IntegerValue(4), l1}, + 
{IntegerValue(2), l2}, + {IntegerValue(10), l3}}, + 2); + element_encodings->Add(x1, + {{IntegerValue(3), l0}, + {IntegerValue(10), l1}, + {IntegerValue(20), l2}, + {IntegerValue(30), l3}}, + 2); + LinearConstraintBuilder builder(&model); + builder.AddConstant(IntegerValue(-1)); // To be cleared. + EXPECT_TRUE( + model.GetOrCreate()->TryToLinearize(x0, x1, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "34*X1 34*X2 294*X3 + 6"); + + builder.Clear(); + EXPECT_TRUE( + model.GetOrCreate()->TryToLinearize(x1, x0, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "34*X1 34*X2 294*X3 + 6"); +} + +TEST(DetectLinearEncodingOfProductsTest, MatchingEncodingAndSizeTwoEncoding) { + Model model; + const Literal l0(model.Add(NewBooleanVariable()), true); + const Literal l1(model.Add(NewBooleanVariable()), true); + const Literal l2(model.Add(NewBooleanVariable()), true); + const Literal l3(model.Add(NewBooleanVariable()), true); + const IntegerVariable x0(model.Add(NewIntegerVariable(0, 100))); + const IntegerVariable x1(model.Add(NewIntegerVariable(6, 7))); + auto* element_encodings = model.GetOrCreate(); + auto* integer_encoder = model.GetOrCreate(); + element_encodings->Add(x0, + {{IntegerValue(2), l0}, + {IntegerValue(4), l1}, + {IntegerValue(2), l2}, + {IntegerValue(10), l3}}, + 2); + integer_encoder->AssociateToIntegerEqualValue(l2, x1, IntegerValue(7)); + model.Add(NewIntegerVariableFromLiteral(l0)); + model.Add(NewIntegerVariableFromLiteral(l1)); + model.Add(NewIntegerVariableFromLiteral(l2)); + model.Add(NewIntegerVariableFromLiteral(l3)); + + LinearConstraintBuilder builder(&model); + builder.AddConstant(IntegerValue(-1)); // To be cleared. 
+ EXPECT_TRUE( + model.GetOrCreate()->TryToLinearize(x0, x1, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "12*X3 2*X4 48*X5 + 12"); + + EXPECT_TRUE( + model.GetOrCreate()->TryToLinearize(x1, x0, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "12*X3 2*X4 48*X5 + 12"); +} + +TEST(DetectLinearEncodingOfProductsTest, BooleanAffinePosPosProduct) { + Model model; + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1)); + const AffineExpression left(var, IntegerValue(2), IntegerValue(-1)); + const AffineExpression right(var, IntegerValue(3), IntegerValue(1)); + + LinearConstraintBuilder builder(&model); + util_intops::StrongVector lp_values(2, 0.0); + + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + left, right, &builder)); + for (int value : {0, 1}) { + lp_values[var] = static_cast(value); + lp_values[NegationOf(var)] = static_cast(-value); + EXPECT_EQ(builder.BuildExpression().LpValue(lp_values), + left.LpValue(lp_values) * right.LpValue(lp_values)); + } + + builder.Clear(); + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + right, left, &builder)); + for (int value : {0, 1}) { + lp_values[var] = static_cast(value); + lp_values[NegationOf(var)] = static_cast(-value); + EXPECT_EQ(builder.BuildExpression().LpValue(lp_values), + left.LpValue(lp_values) * right.LpValue(lp_values)); + } +} + +TEST(DetectLinearEncodingOfProductsTest, BooleanAffinePosNegProduct) { + Model model; + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1)); + const AffineExpression left(var, IntegerValue(2), IntegerValue(-1)); + const AffineExpression right(NegationOf(var), IntegerValue(3), + IntegerValue(1)); + + LinearConstraintBuilder builder(&model); + util_intops::StrongVector lp_values(2, 0.0); + + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + left, right, &builder)); + for (int value : {0, 1}) { + lp_values[var] = static_cast(value); + lp_values[NegationOf(var)] = static_cast(-value); + 
EXPECT_EQ(builder.BuildExpression().LpValue(lp_values), + left.LpValue(lp_values) * right.LpValue(lp_values)); + } + builder.Clear(); + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + right, left, &builder)); + for (int value : {0, 1}) { + lp_values[var] = static_cast(value); + lp_values[NegationOf(var)] = static_cast(-value); + EXPECT_EQ(builder.BuildExpression().LpValue(lp_values), + left.LpValue(lp_values) * right.LpValue(lp_values)); + } +} + +TEST(DetectLinearEncodingOfProductsTest, BooleanAffineNegNegProduct) { + Model model; + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1)); + const AffineExpression left(NegationOf(var), IntegerValue(2), + IntegerValue(-1)); + const AffineExpression right(NegationOf(var), IntegerValue(3), + IntegerValue(1)); + + LinearConstraintBuilder builder(&model); + util_intops::StrongVector lp_values(2, 0.0); + + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + left, right, &builder)); + for (int value : {0, 1}) { + lp_values[var] = static_cast(value); + lp_values[NegationOf(var)] = static_cast(-value); + EXPECT_EQ(builder.BuildExpression().LpValue(lp_values), + left.LpValue(lp_values) * right.LpValue(lp_values)); + } + + builder.Clear(); + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + right, left, &builder)); + for (int value : {0, 1}) { + lp_values[var] = static_cast(value); + lp_values[NegationOf(var)] = static_cast(-value); + EXPECT_EQ(builder.BuildExpression().LpValue(lp_values), + left.LpValue(lp_values) * right.LpValue(lp_values)); + } +} + +TEST(DetectLinearEncodingOfProductsTest, NoDetectionWhenNotBooleanA) { + Model model; + const IntegerVariable var = model.Add(NewIntegerVariable(0, 2)); + const AffineExpression left(var, IntegerValue(2), IntegerValue(-1)); + const AffineExpression right(var, IntegerValue(3), IntegerValue(1)); + + LinearConstraintBuilder builder(&model); + EXPECT_FALSE(model.GetOrCreate()->TryToLinearize( + left, right, &builder)); +} + +TEST(DetectLinearEncodingOfProductsTest, 
NoDetectionWhenNotBooleanB) { + Model model; + const IntegerVariable var = model.Add(NewIntegerVariable(-1, 1)); + const AffineExpression left(var, IntegerValue(2), IntegerValue(-1)); + const AffineExpression right(var, IntegerValue(3), IntegerValue(1)); + + LinearConstraintBuilder builder(&model); + EXPECT_FALSE(model.GetOrCreate()->TryToLinearize( + left, right, &builder)); +} + +TEST(DetectLinearEncodingOfProductsTest, AffineTimesConstant) { + Model model; + const IntegerVariable var = model.Add(NewIntegerVariable(0, 5)); + const AffineExpression left(var, IntegerValue(2), IntegerValue(-1)); + const AffineExpression right = IntegerValue(3); + + LinearConstraintBuilder builder(&model); + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + left, right, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "6*X0 + -3"); + + EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( + right, left, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "6*X0 + -3"); +} + +TEST(DecomposeProductTest, MatchingElementEncodings) { + Model model; + + const Literal l0(model.Add(NewBooleanVariable()), true); + const Literal l1(model.Add(NewBooleanVariable()), true); + const Literal l2(model.Add(NewBooleanVariable()), true); + const Literal l3(model.Add(NewBooleanVariable()), true); + + model.Add(NewIntegerVariableFromLiteral(l0)); + model.Add(NewIntegerVariableFromLiteral(l1)); + model.Add(NewIntegerVariableFromLiteral(l2)); + model.Add(NewIntegerVariableFromLiteral(l3)); + + const IntegerVariable x0(model.Add(NewIntegerVariable(0, 100))); + const IntegerVariable x1(model.Add(NewIntegerVariable(0, 100))); + + auto* element_encodings = model.GetOrCreate(); + element_encodings->Add(x0, + {{IntegerValue(2), l0}, + {IntegerValue(4), l1}, + {IntegerValue(2), l2}, + {IntegerValue(10), l3}}, + 2); + element_encodings->Add(x1, + {{IntegerValue(3), l0}, + {IntegerValue(10), l1}, + {IntegerValue(20), l2}, + {IntegerValue(30), l3}}, + 2); + + auto* decomposer = 
model.GetOrCreate(); + const std::vector terms_a = + decomposer->TryToDecompose(x0, x1); + const std::vector expected_terms_a = { + {l0, IntegerValue(2), IntegerValue(3)}, + {l1, IntegerValue(4), IntegerValue(10)}, + {l2, IntegerValue(2), IntegerValue(20)}, + {l3, IntegerValue(10), IntegerValue(30)}, + }; + ASSERT_FALSE(terms_a.empty()); + EXPECT_EQ(terms_a, expected_terms_a); + + const std::vector terms_b = + decomposer->TryToDecompose(x1, x0); + const std::vector expected_terms_b = { + {l0, IntegerValue(3), IntegerValue(2)}, + {l1, IntegerValue(10), IntegerValue(4)}, + {l2, IntegerValue(20), IntegerValue(2)}, + {l3, IntegerValue(30), IntegerValue(10)}, + }; + ASSERT_FALSE(terms_b.empty()); + EXPECT_EQ(terms_b, expected_terms_b); +} + +TEST(DecomposeProductTest, MatchingEncodingAndSizeTwoEncoding) { + Model model; + + const Literal l0(model.Add(NewBooleanVariable()), true); + const Literal l1(model.Add(NewBooleanVariable()), true); + const Literal l2(model.Add(NewBooleanVariable()), true); + const Literal l3(model.Add(NewBooleanVariable()), true); + const IntegerVariable x0(model.Add(NewIntegerVariable(0, 100))); + const IntegerVariable x1(model.Add(NewIntegerVariable(6, 7))); + + auto* element_encodings = model.GetOrCreate(); + element_encodings->Add(x0, + {{IntegerValue(2), l0}, + {IntegerValue(4), l1}, + {IntegerValue(2), l2}, + {IntegerValue(10), l3}}, + 2); + + auto* integer_encoder = model.GetOrCreate(); + integer_encoder->AssociateToIntegerEqualValue(l2, x1, IntegerValue(7)); + model.Add(NewIntegerVariableFromLiteral(l0)); + model.Add(NewIntegerVariableFromLiteral(l1)); + model.Add(NewIntegerVariableFromLiteral(l2)); + model.Add(NewIntegerVariableFromLiteral(l3)); + + auto* decomposer = model.GetOrCreate(); + const std::vector terms_a = + decomposer->TryToDecompose(x0, x1); + const std::vector expected_terms_a = { + {l0, IntegerValue(2), IntegerValue(6)}, + {l1, IntegerValue(4), IntegerValue(6)}, + {l2, IntegerValue(2), IntegerValue(7)}, + {l3, 
IntegerValue(10), IntegerValue(6)}, + }; + EXPECT_EQ(terms_a, expected_terms_a); + + const std::vector terms_b = + decomposer->TryToDecompose(x1, x0); + const std::vector expected_terms_b = { + {l0, IntegerValue(6), IntegerValue(2)}, + {l1, IntegerValue(6), IntegerValue(4)}, + {l2, IntegerValue(7), IntegerValue(2)}, + {l3, IntegerValue(6), IntegerValue(10)}, + }; + EXPECT_EQ(terms_b, expected_terms_b); +} + +TEST(DecomposeProductTest, MatchingSizeTwoEncodingsFirstFirst) { + Model model; + + const Literal l0(model.Add(NewBooleanVariable()), true); + const IntegerVariable x0(model.Add(NewIntegerVariable(5, 6))); + const IntegerVariable x1(model.Add(NewIntegerVariable(6, 7))); + + auto* integer_encoder = model.GetOrCreate(); + integer_encoder->AssociateToIntegerEqualValue(l0, x0, IntegerValue(5)); + integer_encoder->AssociateToIntegerEqualValue(l0, x1, IntegerValue(6)); + + auto* decomposer = model.GetOrCreate(); + const std::vector terms_a = + decomposer->TryToDecompose(x0, x1); + const std::vector expected_terms_a = { + {l0, IntegerValue(5), IntegerValue(6)}, + {l0.Negated(), IntegerValue(6), IntegerValue(7)}, + }; + EXPECT_EQ(terms_a, expected_terms_a); +} + +TEST(DecomposeProductTest, MatchingSizeTwoEncodingsFirstLast) { + Model model; + + const Literal l0(model.Add(NewBooleanVariable()), true); + const IntegerVariable x0(model.Add(NewIntegerVariable(5, 6))); + const IntegerVariable x1(model.Add(NewIntegerVariable(6, 7))); + + auto* integer_encoder = model.GetOrCreate(); + integer_encoder->AssociateToIntegerEqualValue(l0, x0, IntegerValue(5)); + integer_encoder->AssociateToIntegerEqualValue(l0, x1, IntegerValue(7)); + + auto* decomposer = model.GetOrCreate(); + const std::vector terms_a = + decomposer->TryToDecompose(x0, x1); + const std::vector expected_terms_a = { + {l0, IntegerValue(5), IntegerValue(7)}, + {l0.Negated(), IntegerValue(6), IntegerValue(6)}, + }; + EXPECT_EQ(terms_a, expected_terms_a); +} + +TEST(DecomposeProductTest, 
MatchingSizeTwoEncodingslastFirst) { + Model model; + + const Literal l0(model.Add(NewBooleanVariable()), true); + const IntegerVariable x0(model.Add(NewIntegerVariable(5, 6))); + const IntegerVariable x1(model.Add(NewIntegerVariable(6, 7))); + + auto* integer_encoder = model.GetOrCreate(); + integer_encoder->AssociateToIntegerEqualValue(l0, x0, IntegerValue(6)); + integer_encoder->AssociateToIntegerEqualValue(l0, x1, IntegerValue(6)); + + auto* decomposer = model.GetOrCreate(); + const std::vector terms_a = + decomposer->TryToDecompose(x0, x1); + const std::vector expected_terms_a = { + {l0.Negated(), IntegerValue(5), IntegerValue(7)}, + {l0, IntegerValue(6), IntegerValue(6)}, + }; + EXPECT_EQ(terms_a, expected_terms_a); +} + +TEST(DecomposeProductTest, MatchingSizeTwoEncodingsLastLast) { + Model model; + + const Literal l0(model.Add(NewBooleanVariable()), true); + const IntegerVariable x0(model.Add(NewIntegerVariable(5, 6))); + const IntegerVariable x1(model.Add(NewIntegerVariable(6, 7))); + + auto* integer_encoder = model.GetOrCreate(); + integer_encoder->AssociateToIntegerEqualValue(l0, x0, IntegerValue(6)); + integer_encoder->AssociateToIntegerEqualValue(l0, x1, IntegerValue(7)); + + auto* decomposer = model.GetOrCreate(); + const std::vector terms_a = + decomposer->TryToDecompose(x0, x1); + const std::vector expected_terms_a = { + {l0.Negated(), IntegerValue(5), IntegerValue(6)}, + {l0, IntegerValue(6), IntegerValue(7)}, + }; + EXPECT_EQ(terms_a, expected_terms_a); +} + +TEST(ProductDetectorTest, BasicCases) { + Model model; + model.GetOrCreate()->set_detect_linearized_product(true); + model.GetOrCreate()->set_linearization_level(2); + auto* detector = model.GetOrCreate(); + detector->ProcessTernaryClause(Literals({+1, +2, +3})); + detector->ProcessBinaryClause(Literals({-1, -2})); + detector->ProcessBinaryClause(Literals({-1, -3})); + EXPECT_EQ(kNoLiteralIndex, detector->GetProduct(Literal(-1), Literal(-2))); + EXPECT_EQ(kNoLiteralIndex, 
detector->GetProduct(Literal(-1), Literal(-3))); + EXPECT_EQ(Literal(+1).Index(), + detector->GetProduct(Literal(-2), Literal(-3))); +} + +TEST(ProductDetectorTest, BasicIntCase1) { + Model model; + model.GetOrCreate()->set_detect_linearized_product(true); + model.GetOrCreate()->set_linearization_level(2); + auto* detector = model.GetOrCreate(); + + IntegerVariable x(10); + IntegerVariable y(20); + detector->ProcessConditionalZero(Literal(+1), x); + detector->ProcessConditionalEquality(Literal(-1), x, y); + + EXPECT_EQ(x, detector->GetProduct(Literal(-1), y)); + EXPECT_EQ(kNoIntegerVariable, detector->GetProduct(Literal(-1), x)); + EXPECT_EQ(kNoIntegerVariable, detector->GetProduct(Literal(1), x)); + EXPECT_EQ(kNoIntegerVariable, detector->GetProduct(Literal(1), y)); +} + +TEST(ProductDetectorTest, BasicIntCase2) { + Model model; + model.GetOrCreate()->set_detect_linearized_product(true); + model.GetOrCreate()->set_linearization_level(2); + auto* detector = model.GetOrCreate(); + + IntegerVariable x(10); + IntegerVariable y(20); + detector->ProcessConditionalEquality(Literal(-1), x, y); + detector->ProcessConditionalZero(Literal(+1), x); + + EXPECT_EQ(x, detector->GetProduct(Literal(-1), y)); + EXPECT_EQ(kNoIntegerVariable, detector->GetProduct(Literal(-1), x)); + EXPECT_EQ(kNoIntegerVariable, detector->GetProduct(Literal(1), x)); + EXPECT_EQ(kNoIntegerVariable, detector->GetProduct(Literal(1), y)); +} + +TEST(ProductDetectorTest, RLT) { + Model model; + model.GetOrCreate()->set_add_rlt_cuts(true); + model.GetOrCreate()->set_linearization_level(2); + auto* detector = model.GetOrCreate(); + auto* integer_encoder = model.GetOrCreate(); + + const Literal l0(model.Add(NewBooleanVariable()), true); + const IntegerVariable x(model.Add(NewIntegerVariable(0, 1))); + integer_encoder->AssociateToIntegerEqualValue(l0, x, IntegerValue(1)); + + const Literal l1(model.Add(NewBooleanVariable()), true); + const IntegerVariable y(model.Add(NewIntegerVariable(0, 1))); + 
integer_encoder->AssociateToIntegerEqualValue(l1, y, IntegerValue(1)); + + const Literal l2(model.Add(NewBooleanVariable()), true); + const IntegerVariable z(model.Add(NewIntegerVariable(0, 1))); + integer_encoder->AssociateToIntegerEqualValue(l2, z, IntegerValue(1)); + + // X + (1 - Y) + Z >= 1 + detector->ProcessTernaryClause(Literals({+1, -2, +3})); + + // Lets choose value so that X + Z >= Y is tight. + util_intops::StrongVector lp_values(10, 0.0); + lp_values[x] = 0.7; + lp_values[y] = 0.9; + lp_values[z] = 0.2; + const absl::flat_hash_map lp_vars = { + {x, glop::ColIndex(0)}, {y, glop::ColIndex(1)}, {z, glop::ColIndex(2)}}; + detector->InitializeBooleanRLTCuts(lp_vars, lp_values); + + // (1 - X) * Y <= Z, 0.3 * 0.9 == 0.27 <= 0.2, interesting! + // (1 - X) * (1 - Z) <= (1 - Y), 0.3 * 0.8 == 0.24 <= 0.1, interesting ! + // Y * (1 - Z) <= X, 0.9 * 0.8 == 0.72 <= 0.7, interesting ! + EXPECT_EQ(detector->BoolRLTCandidates().size(), 3); + EXPECT_THAT(detector->BoolRLTCandidates().at(NegationOf(x)), + UnorderedElementsAre(y, NegationOf(z))); + EXPECT_THAT(detector->BoolRLTCandidates().at(y), + UnorderedElementsAre(NegationOf(x), NegationOf(z))); + EXPECT_THAT(detector->BoolRLTCandidates().at(NegationOf(z)), + UnorderedElementsAre(y, NegationOf(x))); + + // And we can recover the literal ub. + EXPECT_EQ(detector->LiteralProductUpperBound(NegationOf(x), y), z); + EXPECT_EQ(detector->LiteralProductUpperBound(NegationOf(x), NegationOf(z)), + NegationOf(y)); + EXPECT_EQ(detector->LiteralProductUpperBound(y, NegationOf(z)), x); + + // If we change values, we might get less candidates though + lp_values[x] = 0.0; + lp_values[y] = 0.2; + lp_values[z] = 0.2; + detector->InitializeBooleanRLTCuts(lp_vars, lp_values); + + // (1 - X) * Y <= Z, 1.0 * 0.2 <= 0.2, tight, but not interesting. + // (1 - X) * (1 - Z) <= (1 - Y), 1.0 * 0.8 <= 0.8 tight, but not interesting. + // Y * (1 - Z) <= X, 0.2 * 0.8 <= 0.0, interesting ! 
+ EXPECT_EQ(detector->BoolRLTCandidates().size(), 2); + EXPECT_THAT(detector->BoolRLTCandidates().at(y), + UnorderedElementsAre(NegationOf(z))); + EXPECT_THAT(detector->BoolRLTCandidates().at(NegationOf(z)), + UnorderedElementsAre(y)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/inclusion_test.cc b/ortools/sat/inclusion_test.cc new file mode 100644 index 00000000000..7f5276708f5 --- /dev/null +++ b/ortools/sat/inclusion_test.cc @@ -0,0 +1,177 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/inclusion.h" + +#include +#include + +#include "absl/random/random.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/util.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(InclusionDetectorTest, SymmetricExample) { + CompactVectorVector storage; + InclusionDetector detector(storage); + detector.AddPotentialSet(storage.Add({1, 2})); + detector.AddPotentialSet(storage.Add({1, 3})); + detector.AddPotentialSet(storage.Add({1, 2, 3})); + detector.AddPotentialSet(storage.Add({1, 4, 3, 2})); + + std::vector> included; + detector.DetectInclusions([&included](int subset, int superset) { + included.push_back({subset, superset}); + }); + EXPECT_THAT(included, + ::testing::ElementsAre(std::make_pair(0, 2), std::make_pair(1, 2), + std::make_pair(0, 3), std::make_pair(1, 3), + std::make_pair(2, 3))); +} + +// If sets are duplicates, we do not detect both inclusions, but just one. +TEST(InclusionDetectorTest, DuplicateBehavior) { + CompactVectorVector storage; + InclusionDetector detector(storage); + detector.AddPotentialSet(storage.Add({1, 2})); + detector.AddPotentialSet(storage.Add({1, 2})); + detector.AddPotentialSet(storage.Add({1, 2})); + detector.AddPotentialSet(storage.Add({1, 2})); + + std::vector> included; + detector.DetectInclusions([&included](int subset, int superset) { + included.push_back({subset, superset}); + }); + EXPECT_THAT(included, ::testing::ElementsAre( + std::make_pair(0, 1), std::make_pair(0, 2), + std::make_pair(1, 2), std::make_pair(0, 3), + std::make_pair(2, 3), std::make_pair(1, 3))); +} + +TEST(InclusionDetectorTest, NonSymmetricExample) { + CompactVectorVector storage; + InclusionDetector detector(storage); + + // Index 0, 1, 2 + detector.AddPotentialSubset(storage.Add({1, 2})); + detector.AddPotentialSubset(storage.Add({1, 3})); + detector.AddPotentialSubset(storage.Add({1, 2, 3})); + + // Index 3, 4, 5, 6 + 
detector.AddPotentialSuperset(storage.Add({1, 2})); + detector.AddPotentialSuperset(storage.Add({1, 4, 3})); + detector.AddPotentialSuperset(storage.Add({1, 4, 3})); + detector.AddPotentialSuperset(storage.Add({1, 5, 2, 3})); + + std::vector> included; + detector.DetectInclusions([&included](int subset, int superset) { + included.push_back({subset, superset}); + }); + EXPECT_THAT(included, ::testing::ElementsAre( + std::make_pair(0, 3), std::make_pair(1, 4), + std::make_pair(1, 5), std::make_pair(0, 6), + std::make_pair(2, 6), std::make_pair(1, 6))); + + // Class can be used multiple time. + // Here we test exclude a subset for appearing twice. + included.clear(); + detector.DetectInclusions([&detector, &included](int subset, int superset) { + included.push_back({subset, superset}); + detector.StopProcessingCurrentSubset(); + }); + EXPECT_THAT(included, + ::testing::ElementsAre(std::make_pair(0, 3), std::make_pair(1, 4), + std::make_pair(2, 6))); + + // Here we test exclude a superset for appearing twice. + included.clear(); + detector.DetectInclusions([&detector, &included](int subset, int superset) { + included.push_back({subset, superset}); + detector.StopProcessingCurrentSuperset(); + }); + EXPECT_THAT(included, ::testing::ElementsAre( + std::make_pair(0, 3), std::make_pair(1, 4), + std::make_pair(1, 5), std::make_pair(0, 6))); + + // Here we stop on first match. 
+ included.clear(); + detector.DetectInclusions([&detector, &included](int subset, int superset) { + included.push_back({subset, superset}); + detector.Stop(); + }); + EXPECT_THAT(included, ::testing::ElementsAre(std::make_pair(0, 3))); +} + +TEST(InclusionDetectorTest, InclusionChain) { + CompactVectorVector storage; + InclusionDetector detector(storage); + detector.AddPotentialSet(storage.Add({1})); + detector.AddPotentialSet(storage.Add({1, 2})); + detector.AddPotentialSet(storage.Add({1, 2, 3})); + + std::vector> included; + detector.DetectInclusions([&included](int subset, int superset) { + included.push_back({subset, superset}); + }); + EXPECT_THAT(included, + ::testing::ElementsAre(std::make_pair(0, 1), std::make_pair(0, 2), + std::make_pair(1, 2))); + + // If we stop processing a superset that can also be a subset, it should + // not appear as such. + included.clear(); + detector.DetectInclusions([&](int subset, int superset) { + detector.StopProcessingCurrentSuperset(); + included.push_back({subset, superset}); + }); + EXPECT_THAT(included, ::testing::ElementsAre(std::make_pair(0, 1), + std::make_pair(0, 2))); +} + +// We just check that nothing crashes. 
+TEST(InclusionDetectorTest, RandomTest) { + absl::BitGen random; + CompactVectorVector storage; + InclusionDetector detector(storage); + + std::vector temp; + for (int i = 0; i < 1000; ++i) { + temp.clear(); + const int size = absl::Uniform(random, 0, 100); + for (int j = 0; j < size; ++j) { + temp.push_back(absl::Uniform(random, 0, 10000)); + } + if (absl::Bernoulli(random, 0.5)) { + detector.AddPotentialSet(storage.Add(temp)); + } else { + if (absl::Bernoulli(random, 0.5)) { + detector.AddPotentialSubset(storage.Add(temp)); + } else { + detector.AddPotentialSuperset(storage.Add(temp)); + } + } + } + + int num_inclusions = 0; + detector.DetectInclusions( + [&num_inclusions](int subset, int superset) { ++num_inclusions; }); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index ea6b4885af0..5803e39202c 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -49,7 +49,7 @@ namespace operations_research { namespace sat { std::vector NegationOf( - const std::vector& vars) { + absl::Span vars) { std::vector result(vars.size()); for (int i = 0; i < vars.size(); ++i) { result[i] = NegationOf(vars[i]); @@ -1686,13 +1686,13 @@ bool IntegerTrail::EnqueueInternal( } const int prev_trail_index = var_trail_index_[i_lit.var]; + var_lbs_[i_lit.var] = i_lit.bound; + var_trail_index_[i_lit.var] = integer_trail_.size(); integer_trail_.push_back({/*bound=*/i_lit.bound, /*var=*/i_lit.var, /*prev_trail_index=*/prev_trail_index, /*reason_index=*/reason_index}); - var_lbs_[i_lit.var] = i_lit.bound; - var_trail_index_[i_lit.var] = integer_trail_.size() - 1; return true; } @@ -1737,13 +1737,13 @@ bool IntegerTrail::EnqueueAssociatedIntegerLiteral(IntegerLiteral i_lit, const int reason_index = AppendReasonToInternalBuffers({literal_reason.Negated()}, {}); const int prev_trail_index = var_trail_index_[i_lit.var]; + var_lbs_[i_lit.var] = i_lit.bound; + var_trail_index_[i_lit.var] = 
integer_trail_.size(); integer_trail_.push_back({/*bound=*/i_lit.bound, /*var=*/i_lit.var, /*prev_trail_index=*/prev_trail_index, /*reason_index=*/reason_index}); - var_lbs_[i_lit.var] = i_lit.bound; - var_trail_index_[i_lit.var] = integer_trail_.size() - 1; return true; } @@ -2113,9 +2113,10 @@ void GenericLiteralWatcher::CallOnNextPropagate(int id) { void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) { // Process any new Literal on the trail. + const int literal_limit = literal_to_watcher_.size(); while (propagation_trail_index_ < trail->Index()) { const Literal literal = (*trail)[propagation_trail_index_++]; - if (literal.Index() >= literal_to_watcher_.size()) continue; + if (literal.Index() >= literal_limit) continue; for (const auto entry : literal_to_watcher_[literal]) { if (!in_queue_[entry.id]) { in_queue_[entry.id] = true; @@ -2128,8 +2129,9 @@ void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) { } // Process the newly changed variables lower bounds. + const int var_limit = var_to_watcher_.size(); for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) { - if (var.value() >= var_to_watcher_.size()) continue; + if (var.value() >= var_limit) continue; for (const auto entry : var_to_watcher_[var]) { if (!in_queue_[entry.id]) { in_queue_[entry.id] = true; diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 3923dcb89d5..8aa355bddbc 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -207,8 +207,7 @@ inline std::string IntegerTermDebugString(IntegerVariable var, } // Returns the vector of the negated variables. -std::vector NegationOf( - const std::vector& vars); +std::vector NegationOf(absl::Span vars); // The integer equivalent of a literal. // It represents an IntegerVariable and an upper/lower bound on it. 
diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 9d9924f8963..29ca184f70b 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -445,10 +445,11 @@ bool LinearConstraintPropagator::PropagateAtLevelZero() { IntegerValue new_ub; if (use_int128) { const IntegerValue ub = shared_->integer_trail->LevelZeroUpperBound(var); - const absl::int128 div128 = slack128 / absl::int128(coeff.value()); - if (absl::int128(lb.value()) + div128 >= absl::int128(ub.value())) { + if (absl::int128((ub - lb).value()) * absl::int128(coeff.value()) <= + slack128) { continue; } + const absl::int128 div128 = slack128 / absl::int128(coeff.value()); new_ub = lb + IntegerValue(static_cast(div128)); } else { const IntegerValue div = slack / coeff; diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 67e0e4632a2..7270d4bca9b 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -61,10 +61,10 @@ IntegerLiteral AtMinValue(IntegerVariable var, IntegerTrail* integer_trail) { return IntegerLiteral::LowerOrEqual(var, lb); } -IntegerLiteral ChooseBestObjectiveValue(IntegerVariable var, Model* model) { - const auto& variables = - model->GetOrCreate()->objective_impacting_variables; - auto* integer_trail = model->GetOrCreate(); +IntegerLiteral ChooseBestObjectiveValue( + IntegerVariable var, IntegerTrail* integer_trail, + ObjectiveDefinition* objective_definition) { + const auto& variables = objective_definition->objective_impacting_variables; if (variables.contains(var)) { return AtMinValue(var, integer_trail); } else if (variables.contains(NegationOf(var))) { @@ -394,8 +394,11 @@ std::function IntegerValueSelectionHeuristic( // Objective based value. 
if (parameters.exploit_objective()) { - value_selection_heuristics.push_back([model](IntegerVariable var) { - return ChooseBestObjectiveValue(var, model); + auto* integer_trail = model->GetOrCreate(); + auto* objective_definition = model->GetOrCreate(); + value_selection_heuristics.push_back([integer_trail, objective_definition]( + IntegerVariable var) { + return ChooseBestObjectiveValue(var, integer_trail, objective_definition); }); } diff --git a/ortools/sat/integer_search.h b/ortools/sat/integer_search.h index 5a3d5d64b4f..bf545851dc5 100644 --- a/ortools/sat/integer_search.h +++ b/ortools/sat/integer_search.h @@ -139,7 +139,9 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding(Model* model); IntegerLiteral AtMinValue(IntegerVariable var, IntegerTrail* integer_trail); // If a variable appear in the objective, branch on its best objective value. -IntegerLiteral ChooseBestObjectiveValue(IntegerVariable var, Model* model); +IntegerLiteral ChooseBestObjectiveValue( + IntegerVariable var, IntegerTrail* integer_trail, + ObjectiveDefinition* objective_definition); // Returns decision corresponding to var >= lb + max(1, (ub - lb) / 2). It also // CHECKs that the variable is not fixed. diff --git a/ortools/sat/integer_test.cc b/ortools/sat/integer_test.cc new file mode 100644 index 00000000000..48fe3902f10 --- /dev/null +++ b/ortools/sat/integer_test.cc @@ -0,0 +1,1333 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/integer.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/logging.h" +#include "ortools/base/types.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/sorted_interval_list.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +TEST(AffineExpressionTest, Inequalities) { + const IntegerVariable var(1); + EXPECT_EQ( + AffineExpression(var, IntegerValue(3)).LowerOrEqual(IntegerValue(8)), + IntegerLiteral::LowerOrEqual(var, IntegerValue(2))); + EXPECT_EQ( + AffineExpression(var, IntegerValue(-3)).LowerOrEqual(IntegerValue(-1)), + IntegerLiteral::GreaterOrEqual(var, IntegerValue(1))); + EXPECT_EQ( + AffineExpression(var, IntegerValue(2)).GreaterOrEqual(IntegerValue(3)), + IntegerLiteral::GreaterOrEqual(var, IntegerValue(2))); +} + +TEST(AffineExpressionTest, ValueAt) { + const IntegerVariable var(1); + EXPECT_EQ(AffineExpression(var, IntegerValue(3)).ValueAt(IntegerValue(8)), + IntegerValue(3 * 8)); + EXPECT_EQ(AffineExpression(var, IntegerValue(3), IntegerValue(-2)) + .ValueAt(IntegerValue(5)), + IntegerValue(3 * 5 - 2)); +} + +TEST(AffineExpressionTest, NegatedConstant) { + const AffineExpression negated = AffineExpression(IntegerValue(3)).Negated(); + EXPECT_EQ(negated.var, kNoIntegerVariable); + EXPECT_EQ(negated.coeff, 0); + EXPECT_EQ(negated.constant, -3); +} + +TEST(AffineExpressionTest, ApiWithoutVar) { + const AffineExpression three(IntegerValue(3)); + EXPECT_TRUE(three.GreaterOrEqual(IntegerValue(2)).IsAlwaysTrue()); + EXPECT_TRUE(three.LowerOrEqual(IntegerValue(2)).IsAlwaysFalse()); +} + 
+TEST(ToDoubleTest, Infinities) { + EXPECT_EQ(ToDouble(IntegerValue(100)), 100.0); + + const double kInfinity = std::numeric_limits::infinity(); + EXPECT_EQ(ToDouble(kMaxIntegerValue), kInfinity); + EXPECT_EQ(ToDouble(kMinIntegerValue), -kInfinity); + + EXPECT_LT(ToDouble(kMaxIntegerValue - IntegerValue(1)), kInfinity); + EXPECT_GT(ToDouble(kMinIntegerValue + IntegerValue(1)), -kInfinity); +} + +TEST(FloorRatioTest, AllSmallCases) { + // Dividend can take any value. + for (IntegerValue dividend(-100); dividend < 100; ++dividend) { + // Divisor must be positive. + for (IntegerValue divisor(1); divisor < 100; ++divisor) { + const IntegerValue floor = FloorRatio(dividend, divisor); + EXPECT_LE(floor * divisor, dividend); + EXPECT_GT((floor + 1) * divisor, dividend); + } + } +} + +TEST(PositiveRemainderTest, AllCasesForFixedDivisor) { + IntegerValue divisor(17); + for (IntegerValue dividend(-100); dividend < 100; ++dividend) { + EXPECT_EQ(PositiveRemainder(dividend, divisor), + dividend - divisor * FloorRatio(dividend, divisor)); + } +} + +TEST(CeilRatioTest, AllSmallCases) { + // Dividend can take any value. + for (IntegerValue dividend(-100); dividend < 100; ++dividend) { + // Divisor must be positive. 
+ for (IntegerValue divisor(1); divisor < 100; ++divisor) { + const IntegerValue ceil = CeilRatio(dividend, divisor); + EXPECT_GE(ceil * divisor, dividend); + EXPECT_LT((ceil - 1) * divisor, dividend); + } + } +} + +TEST(NegationOfTest, IsIdempotent) { + for (int i = 0; i < 100; ++i) { + const IntegerVariable var(i); + EXPECT_EQ(NegationOf(NegationOf(var)), var); + } +} + +TEST(NegationOfTest, VectorArgument) { + std::vector vars{IntegerVariable(1), IntegerVariable(2)}; + std::vector negated_vars = NegationOf(vars); + EXPECT_EQ(negated_vars.size(), vars.size()); + for (int i = 0; i < vars.size(); ++i) { + EXPECT_EQ(negated_vars[i], NegationOf(vars[i])); + } +} + +TEST(IntegerValue, NegatedCannotOverflow) { + EXPECT_GT(kMinIntegerValue - 1, std::numeric_limits::min()); +} + +TEST(IntegerLiteral, OverflowValueAreCapped) { + const IntegerVariable var(0); + EXPECT_EQ(IntegerLiteral::GreaterOrEqual(var, kMaxIntegerValue + 1), + IntegerLiteral::GreaterOrEqual( + var, IntegerValue(std::numeric_limits::max()))); + EXPECT_EQ(IntegerLiteral::LowerOrEqual(var, kMinIntegerValue - 1), + IntegerLiteral::LowerOrEqual( + var, IntegerValue(std::numeric_limits::min()))); +} + +TEST(IntegerLiteral, NegatedIsIdempotent) { + for (const IntegerValue value : + {kMinIntegerValue, kMaxIntegerValue, kMaxIntegerValue + 1, + IntegerValue(0), IntegerValue(1), IntegerValue(2)}) { + const IntegerLiteral literal = + IntegerLiteral::GreaterOrEqual(IntegerVariable(0), value); + CHECK_EQ(literal, literal.Negated().Negated()); + } +} + +// A bound difference of exactly kint64max is ok. +TEST(IntegerTrailDeathTest, LargeVariableDomain) { + Model model; + model.Add(NewIntegerVariable(-3, std::numeric_limits::max() - 3)); + + if (DEBUG_MODE) { + // But one of kint64max + 1 cause a check fail in debug. 
+ EXPECT_DEATH(model.Add(NewIntegerVariable( + -3, std::numeric_limits::max() - 2)), + ""); + } +} + +TEST(IntegerTrailTest, ConstantIntegerVariableSharing) { + Model model; + const IntegerVariable a = model.Add(ConstantIntegerVariable(0)); + const IntegerVariable b = model.Add(ConstantIntegerVariable(7)); + const IntegerVariable c = model.Add(ConstantIntegerVariable(-7)); + const IntegerVariable d = model.Add(ConstantIntegerVariable(0)); + const IntegerVariable e = model.Add(ConstantIntegerVariable(3)); + EXPECT_EQ(a, d); + EXPECT_EQ(b, NegationOf(c)); + EXPECT_NE(a, e); + EXPECT_EQ(0, model.Get(Value(a))); + EXPECT_EQ(7, model.Get(Value(b))); + EXPECT_EQ(-7, model.Get(Value(c))); + EXPECT_EQ(0, model.Get(Value(d))); + EXPECT_EQ(3, model.Get(Value(e))); +} + +TEST(IntegerTrailTest, VariableCreationAndBoundGetter) { + Model model; + IntegerTrail* p = model.GetOrCreate(); + IntegerVariable a = model.Add(NewIntegerVariable(0, 10)); + IntegerVariable b = model.Add(NewIntegerVariable(-10, 10)); + IntegerVariable c = model.Add(NewIntegerVariable(20, 30)); + + // Index are dense and contiguous, but two indices are created each time. + // They start at zero. + EXPECT_EQ(0, a.value()); + EXPECT_EQ(1, NegationOf(a).value()); + EXPECT_EQ(2, b.value()); + EXPECT_EQ(3, NegationOf(b).value()); + EXPECT_EQ(4, c.value()); + EXPECT_EQ(5, NegationOf(c).value()); + + // Bounds matches the one we passed at creation. + EXPECT_EQ(0, p->LowerBound(a)); + EXPECT_EQ(10, p->UpperBound(a)); + EXPECT_EQ(-10, p->LowerBound(b)); + EXPECT_EQ(10, p->UpperBound(b)); + EXPECT_EQ(20, p->LowerBound(c)); + EXPECT_EQ(30, p->UpperBound(c)); + + // Test level-zero enqueue. 
+ EXPECT_TRUE( + p->Enqueue(IntegerLiteral::LowerOrEqual(a, IntegerValue(20)), {}, {})); + EXPECT_EQ(10, p->UpperBound(a)); + EXPECT_TRUE( + p->Enqueue(IntegerLiteral::LowerOrEqual(a, IntegerValue(7)), {}, {})); + EXPECT_EQ(7, p->UpperBound(a)); + EXPECT_TRUE( + p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(5)), {}, {})); + EXPECT_EQ(5, p->LowerBound(a)); +} + +TEST(IntegerTrailTest, Untrail) { + Model model; + IntegerTrail* p = model.GetOrCreate(); + IntegerVariable a = p->AddIntegerVariable(IntegerValue(1), IntegerValue(10)); + IntegerVariable b = p->AddIntegerVariable(IntegerValue(2), IntegerValue(10)); + + Trail* trail = model.GetOrCreate(); + trail->Resize(10); + + // We need a reason for the Enqueue(): + const Literal r(model.Add(NewBooleanVariable()), true); + trail->EnqueueWithUnitReason(r.Negated()); + + // Enqueue. + trail->SetDecisionLevel(1); + EXPECT_TRUE(p->Propagate(trail)); + EXPECT_TRUE( + p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(5)), {r}, {})); + EXPECT_EQ(5, p->LowerBound(a)); + EXPECT_TRUE( + p->Enqueue(IntegerLiteral::GreaterOrEqual(b, IntegerValue(7)), {r}, {})); + EXPECT_EQ(7, p->LowerBound(b)); + + trail->SetDecisionLevel(2); + EXPECT_TRUE(p->Propagate(trail)); + EXPECT_TRUE( + p->Enqueue(IntegerLiteral::GreaterOrEqual(b, IntegerValue(9)), {r}, {})); + EXPECT_EQ(9, p->LowerBound(b)); + + // Untrail. 
+ trail->SetDecisionLevel(1); + p->Untrail(*trail, 0); + EXPECT_EQ(7, p->LowerBound(b)); + + trail->SetDecisionLevel(0); + p->Untrail(*trail, 0); + EXPECT_EQ(1, p->LowerBound(a)); + EXPECT_EQ(2, p->LowerBound(b)); +} + +TEST(IntegerTrailTest, BasicReason) { + Model model; + IntegerTrail* p = model.GetOrCreate(); + IntegerVariable a = p->AddIntegerVariable(IntegerValue(1), IntegerValue(10)); + + Trail* trail = model.GetOrCreate(); + trail->Resize(10); + trail->EnqueueWithUnitReason(Literal(-1)); + trail->EnqueueWithUnitReason(Literal(-2)); + trail->EnqueueWithUnitReason(Literal(+3)); + trail->EnqueueWithUnitReason(Literal(+4)); + trail->SetDecisionLevel(1); + EXPECT_TRUE(p->Propagate(trail)); + + // Enqueue. + EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)), + Literals({+1}), {})); + EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(3)), + Literals({+2}), {})); + EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(5)), + Literals({-3}), {})); + EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(6)), + Literals({-4}), {})); + + EXPECT_THAT(p->ReasonFor(IntegerLiteral::GreaterOrEqual(a, IntegerValue(6))), + ElementsAre(Literal(-4))); + EXPECT_THAT(p->ReasonFor(IntegerLiteral::GreaterOrEqual(a, IntegerValue(5))), + ElementsAre(Literal(-3))); + EXPECT_THAT(p->ReasonFor(IntegerLiteral::GreaterOrEqual(a, IntegerValue(4))), + ElementsAre(Literal(-3))); + EXPECT_THAT(p->ReasonFor(IntegerLiteral::GreaterOrEqual(a, IntegerValue(3))), + ElementsAre(Literal(+2))); + EXPECT_TRUE( + p->ReasonFor(IntegerLiteral::GreaterOrEqual(a, IntegerValue(0))).empty()); + EXPECT_TRUE(p->ReasonFor(IntegerLiteral::GreaterOrEqual(a, IntegerValue(-10))) + .empty()); +} + +struct LazyReasonForTest : public LazyReasonInterface { + bool called = false; + + void Explain(int /*id*/, IntegerValue /*propagation_slack*/, + IntegerVariable /*variable_to_explain*/, int /*trail_index*/, + std::vector* /*literals_reason*/, + 
std::vector* /*trail_indices_reason*/) final { + called = true; + } +}; + +TEST(IntegerTrailTest, LazyReason) { + Model model; + IntegerTrail* p = model.GetOrCreate(); + IntegerVariable a = p->AddIntegerVariable(IntegerValue(1), IntegerValue(10)); + + Trail* trail = model.GetOrCreate(); + trail->Resize(10); + trail->SetDecisionLevel(1); + EXPECT_TRUE(p->Propagate(trail)); + + LazyReasonForTest mock; + + // Enqueue. + EXPECT_TRUE(p->EnqueueWithLazyReason( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)), 0, 0, &mock)); + EXPECT_TRUE(p->Propagate(trail)); + EXPECT_FALSE(mock.called); + + // Called if needed for the conflict. + EXPECT_FALSE( + p->Enqueue(IntegerLiteral::LowerOrEqual(a, IntegerValue(1)), {}, {})); + EXPECT_TRUE(mock.called); +} + +TEST(IntegerTrailTest, LiteralAndBoundReason) { + Model model; + IntegerTrail* p = model.GetOrCreate(); + IntegerVariable a = model.Add(NewIntegerVariable(0, 10)); + IntegerVariable b = model.Add(NewIntegerVariable(0, 10)); + IntegerVariable c = model.Add(NewIntegerVariable(0, 10)); + + Trail* trail = model.GetOrCreate(); + trail->Resize(10); + trail->EnqueueWithUnitReason(Literal(-1)); + trail->EnqueueWithUnitReason(Literal(-2)); + trail->EnqueueWithUnitReason(Literal(-3)); + trail->EnqueueWithUnitReason(Literal(-4)); + trail->SetDecisionLevel(1); + EXPECT_TRUE(p->Propagate(trail)); + + // Enqueue. 
+ EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(1)), + Literals({+1}), {})); + EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)), + Literals({+2}), {})); + EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(b, IntegerValue(3)), + Literals({+3}), + {IntegerLiteral::GreaterOrEqual(a, IntegerValue(1))})); + EXPECT_TRUE(p->Enqueue(IntegerLiteral::GreaterOrEqual(c, IntegerValue(5)), + Literals({+4, +3}), + {IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)), + IntegerLiteral::GreaterOrEqual(b, IntegerValue(3))})); + + EXPECT_THAT(p->ReasonFor(IntegerLiteral::GreaterOrEqual(b, IntegerValue(2))), + UnorderedElementsAre(Literal(+1), Literal(+3))); + EXPECT_THAT(p->ReasonFor(IntegerLiteral::GreaterOrEqual(c, IntegerValue(3))), + UnorderedElementsAre(Literal(+2), Literal(+3), Literal(+4))); +} + +TEST(IntegerTrailTest, LevelZeroBounds) { + Model model; + auto* integer_trail = model.GetOrCreate(); + IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + + Trail* trail = model.GetOrCreate(); + trail->Resize(10); + trail->SetDecisionLevel(1); + trail->EnqueueWithUnitReason(Literal(-1)); + trail->EnqueueWithUnitReason(Literal(-2)); + EXPECT_TRUE(integer_trail->Propagate(trail)); + + // Enqueue. + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(x, IntegerValue(1)), Literals({+1}), {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::LowerOrEqual(x, IntegerValue(2)), Literals({+2}), {})); + + // TEST. 
+ EXPECT_EQ(integer_trail->LowerBound(x), IntegerValue(1)); + EXPECT_EQ(integer_trail->UpperBound(x), IntegerValue(2)); + EXPECT_EQ(integer_trail->LevelZeroLowerBound(x), IntegerValue(0)); + EXPECT_EQ(integer_trail->LevelZeroUpperBound(x), IntegerValue(10)); +} + +TEST(IntegerTrailTest, RelaxLinearReason) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + const IntegerVariable a = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable b = model.Add(NewIntegerVariable(0, 10)); + + Trail* trail = model.GetOrCreate(); + trail->SetDecisionLevel(1); + EXPECT_TRUE(integer_trail->Propagate(trail)); + + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(1)), {}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)), {}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(b, IntegerValue(1)), {}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(3)), {}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(b, IntegerValue(3)), {}, {})); + + std::vector coeffs(2, IntegerValue(1)); + std::vector reasons{ + IntegerLiteral::GreaterOrEqual(a, IntegerValue(3)), + IntegerLiteral::GreaterOrEqual(b, IntegerValue(3))}; + + // No slack, nothing happens. + integer_trail->RelaxLinearReason(IntegerValue(0), coeffs, &reasons); + EXPECT_THAT(reasons, + ElementsAre(IntegerLiteral::GreaterOrEqual(a, IntegerValue(3)), + IntegerLiteral::GreaterOrEqual(b, IntegerValue(3)))); + + // Some slack, we find the "lowest" possible reason in term of trail index. 
+ integer_trail->RelaxLinearReason(IntegerValue(3), coeffs, &reasons); + EXPECT_THAT(reasons, + ElementsAre(IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)), + IntegerLiteral::GreaterOrEqual(b, IntegerValue(1)))); +} + +TEST(IntegerTrailTest, LiteralIsTrueOrFalse) { + Model model; + const IntegerVariable a = model.Add(NewIntegerVariable(1, 9)); + + auto* integer_trail = model.GetOrCreate(); + EXPECT_TRUE(integer_trail->IntegerLiteralIsTrue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(0)))); + EXPECT_TRUE(integer_trail->IntegerLiteralIsTrue( + IntegerLiteral::LowerOrEqual(a, IntegerValue(10)))); + + EXPECT_TRUE(integer_trail->IntegerLiteralIsTrue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(1)))); + EXPECT_FALSE(integer_trail->IntegerLiteralIsFalse( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(1)))); + + EXPECT_FALSE(integer_trail->IntegerLiteralIsTrue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)))); + EXPECT_FALSE(integer_trail->IntegerLiteralIsFalse( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(2)))); + + EXPECT_FALSE(integer_trail->IntegerLiteralIsTrue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(10)))); + EXPECT_TRUE(integer_trail->IntegerLiteralIsFalse( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(10)))); +} + +TEST(IntegerTrailTest, VariableWithHole) { + Model model; + IntegerVariable a = + model.Add(NewIntegerVariable(Domain::FromIntervals({{1, 3}, {6, 7}}))); + model.Add(GreaterOrEqual(a, 4)); + EXPECT_EQ(model.Get(LowerBound(a)), 6); +} + +TEST(GenericLiteralWatcherTest, LevelZeroModifiedVariablesCallbackTest) { + Model model; + auto* integer_trail = model.GetOrCreate(); + auto* watcher = model.GetOrCreate(); + IntegerVariable a = model.Add(NewIntegerVariable(0, 10)); + IntegerVariable b = model.Add(NewIntegerVariable(-10, 10)); + IntegerVariable c = model.Add(NewIntegerVariable(20, 30)); + + std::vector collector; + watcher->RegisterLevelZeroModifiedVariablesCallback( + [&collector](const std::vector& modified_vars) { + 
collector = modified_vars; + }); + + // No propagation. + auto* sat_solver = model.GetOrCreate(); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_EQ(0, collector.size()); + + // Modify 1 variable. + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::LowerOrEqual(c, IntegerValue(27)), {}, {})); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_EQ(1, collector.size()); + EXPECT_EQ(NegationOf(c), collector[0]); + + // Modify 2 variables. + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(a, IntegerValue(10)), {}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::LowerOrEqual(b, IntegerValue(7)), {}, {})); + EXPECT_TRUE(sat_solver->Propagate()); + ASSERT_EQ(2, collector.size()); + EXPECT_EQ(a, collector[0]); + EXPECT_EQ(NegationOf(b), collector[1]); + + // Modify 1 variable at level 1. + model.GetOrCreate()->SetDecisionLevel(1); + EXPECT_TRUE(sat_solver->Propagate()); + collector.clear(); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::LowerOrEqual(b, IntegerValue(6)), {}, {})); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_TRUE(collector.empty()); +} + +TEST(GenericLiteralWatcherTest, RevIsInDiveUpdate) { + Model model; + bool is_in_dive = false; + auto* sat_solver = model.GetOrCreate(); + auto* watcher = model.GetOrCreate(); + const Literal a(sat_solver->NewBooleanVariable(), true); + const Literal b(sat_solver->NewBooleanVariable(), true); + + // First decision. + EXPECT_TRUE(sat_solver->EnqueueDecisionIfNotConflicting(a)); + EXPECT_FALSE(is_in_dive); + watcher->SetUntilNextBacktrack(&is_in_dive); + + // Second decision. + EXPECT_TRUE(sat_solver->EnqueueDecisionIfNotConflicting(b)); + EXPECT_TRUE(is_in_dive); + watcher->SetUntilNextBacktrack(&is_in_dive); + + // If we backtrack, it should be set to false. + EXPECT_TRUE(sat_solver->ResetToLevelZero()); + EXPECT_FALSE(is_in_dive); + + // We can redo the same. 
+ EXPECT_FALSE(is_in_dive); + watcher->SetUntilNextBacktrack(&is_in_dive); + + EXPECT_TRUE(sat_solver->EnqueueDecisionIfNotConflicting(a)); + EXPECT_TRUE(is_in_dive); +} + +TEST(IntegerEncoderTest, BasicInequalityEncoding) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 10)); + const Literal l3 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(3))); + const Literal l7 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(7))); + const Literal l5 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(5))); + + // Test SearchForLiteralAtOrBefore(). + for (IntegerValue v(0); v < 10; ++v) { + IntegerValue unused; + const LiteralIndex lb_index = encoder->SearchForLiteralAtOrBefore( + IntegerLiteral::GreaterOrEqual(var, v), &unused); + const LiteralIndex ub_index = encoder->SearchForLiteralAtOrBefore( + IntegerLiteral::LowerOrEqual(var, v), &unused); + if (v < 3) { + EXPECT_EQ(lb_index, kNoLiteralIndex); + EXPECT_EQ(ub_index, l3.NegatedIndex()); + } else if (v < 5) { + EXPECT_EQ(lb_index, l3.Index()); + EXPECT_EQ(ub_index, l5.NegatedIndex()); + } else if (v < 7) { + EXPECT_EQ(lb_index, l5.Index()); + EXPECT_EQ(ub_index, l7.NegatedIndex()); + } else { + EXPECT_EQ(lb_index, l7.Index()); + EXPECT_EQ(ub_index, kNoLiteralIndex); + } + } + + // Test the propagation from the literal to the bounds. + // By default the polarity of the literal are false. + EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); + EXPECT_FALSE(model.Get(Value(l3))); + EXPECT_FALSE(model.Get(Value(l5))); + EXPECT_FALSE(model.Get(Value(l7))); + EXPECT_EQ(0, model.Get(LowerBound(var))); + EXPECT_EQ(2, model.Get(UpperBound(var))); + + // Test the other way around. 
+ model.GetOrCreate()->Backtrack(0); + model.Add(GreaterOrEqual(var, 4)); + EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); + EXPECT_TRUE(model.Get(Value(l3))); + EXPECT_FALSE(model.Get(Value(l5))); + EXPECT_FALSE(model.Get(Value(l7))); + EXPECT_EQ(4, model.Get(LowerBound(var))); + EXPECT_EQ(4, model.Get(UpperBound(var))); +} + +TEST(IntegerEncoderTest, GetOrCreateTrivialAssociatedLiteral) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 10)); + EXPECT_EQ(encoder->GetTrueLiteral(), + encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(0)))); + EXPECT_EQ(encoder->GetTrueLiteral(), + encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(-1)))); + EXPECT_EQ(encoder->GetTrueLiteral(), + encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, IntegerValue(10)))); + EXPECT_EQ(encoder->GetFalseLiteral(), + encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(11)))); + EXPECT_EQ(encoder->GetFalseLiteral(), + encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(12)))); + EXPECT_EQ(encoder->GetFalseLiteral(), + encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, IntegerValue(-1)))); +} + +TEST(IntegerEncoderTest, ShiftedBinary) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(1, 2)); + + encoder->FullyEncodeVariable(var); + EXPECT_EQ(encoder->FullDomainEncoding(var).size(), 2); + const std::vector var_encoding = + encoder->FullDomainEncoding(var); + + const Literal g2 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(2))); + const Literal l1 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, IntegerValue(1))); + + EXPECT_EQ(g2, var_encoding[1].literal); + 
EXPECT_EQ(l1, var_encoding[0].literal); + EXPECT_EQ(g2, l1.Negated()); +} + +TEST(IntegerEncoderTest, SizeTwoDomains) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromValues({1, 3}))); + + encoder->FullyEncodeVariable(var); + EXPECT_EQ(encoder->FullDomainEncoding(var).size(), 2); + const std::vector var_encoding = + encoder->FullDomainEncoding(var); + + const Literal g2 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(2))); + const Literal g3 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(3))); + const Literal l1 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, IntegerValue(1))); + const Literal l2 = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, IntegerValue(2))); + + EXPECT_EQ(g3, var_encoding[1].literal); + EXPECT_EQ(l1, var_encoding[0].literal); + EXPECT_EQ(g3, l1.Negated()); + EXPECT_EQ(g2, g3); + EXPECT_EQ(l1, l2); +} + +TEST(IntegerEncoderDeathTest, NegatedIsNotCreatedTwice) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 10)); + const IntegerLiteral l = IntegerLiteral::GreaterOrEqual(var, IntegerValue(3)); + const Literal associated = encoder->GetOrCreateAssociatedLiteral(l); + EXPECT_EQ(associated.Negated(), + encoder->GetOrCreateAssociatedLiteral(l.Negated())); +} + +TEST(IntegerEncoderTest, AutomaticallyDetectFullEncoding) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromValues({3, -4, 0}))); + + // Adding <= min should automatically also add == min. + encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, IntegerValue(-4))); + + // We still miss one value. 
+ EXPECT_FALSE(encoder->VariableIsFullyEncoded(var)); + EXPECT_FALSE(encoder->VariableIsFullyEncoded(NegationOf(var))); + + // This is enough to fully encode, because not(<=0) is >=3 which is ==3, and + // we do have all values. + encoder->GetOrCreateLiteralAssociatedToEquality(var, IntegerValue(0)); + EXPECT_TRUE(encoder->VariableIsFullyEncoded(var)); + EXPECT_TRUE(encoder->VariableIsFullyEncoded(NegationOf(var))); + + std::vector values; + for (const auto pair : encoder->FullDomainEncoding(var)) { + values.push_back(pair.value.value()); + } + EXPECT_THAT(values, ElementsAre(-4, 0, 3)); +} + +TEST(IntegerEncoderTest, BasicFullEqualityEncoding) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromValues({3, -4, 0}))); + encoder->FullyEncodeVariable(var); + + // Normal var. + { + const auto& result = encoder->FullDomainEncoding(var); + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0], ValueLiteralPair({IntegerValue(-4), + Literal(BooleanVariable(0), true)})); + EXPECT_EQ(result[1], ValueLiteralPair({IntegerValue(0), + Literal(BooleanVariable(1), true)})); + EXPECT_EQ(result[2], + ValueLiteralPair( + {IntegerValue(3), Literal(BooleanVariable(2), false)})); + } + + // Its negation. + { + const auto& result = encoder->FullDomainEncoding(NegationOf(var)); + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0], + ValueLiteralPair( + {IntegerValue(-3), Literal(BooleanVariable(2), false)})); + EXPECT_EQ(result[1], ValueLiteralPair({IntegerValue(0), + Literal(BooleanVariable(1), true)})); + EXPECT_EQ(result[2], ValueLiteralPair({IntegerValue(4), + Literal(BooleanVariable(0), true)})); + } +} + +TEST(IntegerEncoderTest, PartialEncodingOfBinaryVarIsFull) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromValues({0, 5}))); + const Literal lit(model.Add(NewBooleanVariable()), true); + + // Initially empty. 
+ EXPECT_TRUE(encoder->PartialDomainEncoding(var).empty()); + + // Normal var. + encoder->AssociateToIntegerEqualValue(lit, var, IntegerValue(0)); + { + const auto& result = encoder->PartialDomainEncoding(var); + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result[0], ValueLiteralPair({IntegerValue(0), lit})); + EXPECT_EQ(result[1], ValueLiteralPair({IntegerValue(5), lit.Negated()})); + } + + // Its negation. + { + const auto& result = encoder->PartialDomainEncoding(NegationOf(var)); + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result[0], ValueLiteralPair({IntegerValue(-5), lit.Negated()})); + EXPECT_EQ(result[1], ValueLiteralPair({IntegerValue(0), lit})); + } +} + +TEST(IntegerEncoderTest, PartialEncodingOfLargeVar) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1e12)); + for (const int value : {50, 1000, 1}) { + const Literal lit(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerEqualValue(lit, var, IntegerValue(value)); + } + const auto& result = encoder->PartialDomainEncoding(var); + EXPECT_EQ(result.size(), 4); + // Zero is created because encoding (== 1) requires (>= 1 and <= 1), but the + // negation of (>= 1) is also (== 0). + EXPECT_EQ(result[0].value, IntegerValue(0)); + EXPECT_EQ(result[1].value, IntegerValue(1)); + EXPECT_EQ(result[2].value, IntegerValue(50)); + EXPECT_EQ(result[3].value, IntegerValue(1000)); +} + +TEST(IntegerEncoderTest, UpdateInitialDomain) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromValues({3, -4, 0}))); + encoder->FullyEncodeVariable(var); + EXPECT_TRUE(model.GetOrCreate()->UpdateInitialDomain( + var, Domain::FromIntervals({{-4, -4}, {0, 0}, {5, 5}}))); + + // Note that we return the filtered encoding. 
+ { + const auto& result = encoder->FullDomainEncoding(var); + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result[0], ValueLiteralPair({IntegerValue(-4), + Literal(BooleanVariable(0), true)})); + EXPECT_EQ(result[1], ValueLiteralPair({IntegerValue(0), + Literal(BooleanVariable(1), true)})); + } +} + +TEST(IntegerEncoderTest, Canonicalize) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromIntervals({{1, 4}, {7, 9}}))); + + EXPECT_EQ(encoder->Canonicalize( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(2))), + std::make_pair(IntegerLiteral::GreaterOrEqual(var, IntegerValue(2)), + IntegerLiteral::LowerOrEqual(var, IntegerValue(1)))); + EXPECT_EQ(encoder->Canonicalize( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(4))), + std::make_pair(IntegerLiteral::GreaterOrEqual(var, IntegerValue(4)), + IntegerLiteral::LowerOrEqual(var, IntegerValue(3)))); + EXPECT_EQ( + encoder->Canonicalize(IntegerLiteral::LowerOrEqual(var, IntegerValue(4))), + std::make_pair(IntegerLiteral::LowerOrEqual(var, IntegerValue(4)), + IntegerLiteral::GreaterOrEqual(var, IntegerValue(7)))); + EXPECT_EQ( + encoder->Canonicalize(IntegerLiteral::LowerOrEqual(var, IntegerValue(6))), + std::make_pair(IntegerLiteral::LowerOrEqual(var, IntegerValue(4)), + IntegerLiteral::GreaterOrEqual(var, IntegerValue(7)))); +} + +TEST(IntegerEncoderDeathTest, CanonicalizeDoNotAcceptTrivialLiterals) { + if (!DEBUG_MODE) GTEST_SKIP() << "Moot in opt mode"; + + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromIntervals({{1, 4}, {7, 9}}))); + + EXPECT_DEATH(encoder->Canonicalize( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(1))), + ""); + EXPECT_DEATH(encoder->Canonicalize( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(0))), + ""); + EXPECT_DEATH( + encoder->Canonicalize(IntegerLiteral::LowerOrEqual(var, IntegerValue(0))), + ""); + 
EXPECT_DEATH(encoder->Canonicalize( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(0))), + ""); + + EXPECT_DEATH( + encoder->Canonicalize(IntegerLiteral::LowerOrEqual(var, IntegerValue(9))), + ""); + EXPECT_DEATH(encoder->Canonicalize( + IntegerLiteral::LowerOrEqual(var, IntegerValue(15))), + ""); +} + +TEST(IntegerEncoderTest, TrivialAssociation) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromIntervals({{1, 1}, {5, 5}}))); + + { + const Literal l(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerLiteral( + l, IntegerLiteral::GreaterOrEqual(var, IntegerValue(1))); + EXPECT_EQ(model.Get(Value(l)), true); + } + { + const Literal l(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerLiteral( + l, IntegerLiteral::GreaterOrEqual(var, IntegerValue(6))); + EXPECT_EQ(model.Get(Value(l)), false); + } + { + const Literal l(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerEqualValue(l, var, IntegerValue(4)); + EXPECT_EQ(model.Get(Value(l)), false); + } +} + +TEST(IntegerEncoderTest, TrivialAssociationWithFixedVariable) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(Domain(1))); + { + const Literal l(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerEqualValue(l, var, IntegerValue(1)); + EXPECT_EQ(model.Get(Value(l)), true); + } +} + +TEST(IntegerEncoderTest, FullEqualityEncodingForTwoValuesWithDuplicates) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = + model.Add(NewIntegerVariable(Domain::FromValues({3, 5, 3}))); + encoder->FullyEncodeVariable(var); + + // Normal var. 
+ { + const auto& result = encoder->FullDomainEncoding(var); + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result[0], ValueLiteralPair({IntegerValue(3), + Literal(BooleanVariable(0), true)})); + EXPECT_EQ(result[1], + ValueLiteralPair( + {IntegerValue(5), Literal(BooleanVariable(0), false)})); + } + + // Its negation. + { + const auto& result = encoder->FullDomainEncoding(NegationOf(var)); + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result[0], + ValueLiteralPair( + {IntegerValue(-5), Literal(BooleanVariable(0), false)})); + EXPECT_EQ(result[1], ValueLiteralPair({IntegerValue(-3), + Literal(BooleanVariable(0), true)})); + } +} + +#define EXPECT_BOUNDS_EQ(var, lb, ub) \ + EXPECT_EQ(model.Get(LowerBound(var)), lb); \ + EXPECT_EQ(model.Get(UpperBound(var)), ub) + +TEST(IntegerEncoderTest, IntegerTrailToEncodingPropagation) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + IntegerEncoder* encoder = model.GetOrCreate(); + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + + const IntegerVariable var = model.Add( + NewIntegerVariable(Domain::FromIntervals({{3, 4}, {7, 7}, {9, 9}}))); + model.Add(FullyEncodeVariable(var)); + + // We copy this because Enqueue() might change it. + const auto encoding = encoder->FullDomainEncoding(var); + + // Initial propagation is correct. + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_BOUNDS_EQ(var, 3, 9); + + // Note that the bounds snap to the possible values. 
+ const VariablesAssignment& assignment = trail->Assignment(); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::LowerOrEqual(var, IntegerValue(8)), {}, {})); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_TRUE(assignment.LiteralIsFalse(encoding[3].literal)); + EXPECT_FALSE(assignment.VariableIsAssigned(encoding[0].literal.Variable())); + EXPECT_FALSE(assignment.VariableIsAssigned(encoding[1].literal.Variable())); + EXPECT_FALSE(assignment.VariableIsAssigned(encoding[2].literal.Variable())); + EXPECT_BOUNDS_EQ(var, 3, 7); + + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(5)), {}, {})); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_TRUE(assignment.LiteralIsFalse(encoding[0].literal)); + EXPECT_TRUE(assignment.LiteralIsFalse(encoding[1].literal)); + EXPECT_TRUE(assignment.LiteralIsTrue(encoding[2].literal)); + EXPECT_BOUNDS_EQ(var, 7, 7); + + // Encoding[2] will become true on the sat solver propagation. + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_TRUE(assignment.LiteralIsTrue(encoding[2].literal)); +} + +TEST(IntegerEncoderTest, EncodingToIntegerTrailPropagation) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + IntegerEncoder* encoder = model.GetOrCreate(); + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + const IntegerVariable var = model.Add( + NewIntegerVariable(Domain::FromIntervals({{3, 4}, {7, 7}, {9, 9}}))); + model.Add(FullyEncodeVariable(var)); + const auto& encoding = encoder->FullDomainEncoding(var); + + // Initial propagation is correct. + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_BOUNDS_EQ(var, 3, 9); + + // We remove the value 4, nothing happen. + trail->SetDecisionLevel(1); + trail->EnqueueSearchDecision(encoding[1].literal.Negated()); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_BOUNDS_EQ(var, 3, 9); + + // When we remove 3, the lower bound change though. 
+ trail->SetDecisionLevel(2); + trail->EnqueueSearchDecision(encoding[0].literal.Negated()); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_BOUNDS_EQ(var, 7, 9); + + // The reason for the lower bounds is that both encoding[0] and encoding[1] + // are false. But it is captured by the literal associated to x >= 7. + { + const IntegerLiteral l = integer_trail->LowerBoundAsLiteral(var); + EXPECT_EQ(integer_trail->ReasonFor(l), + std::vector{ + Literal(encoder->GetAssociatedLiteral(l)).Negated()}); + } + + // Test the other direction. + trail->SetDecisionLevel(3); + trail->EnqueueSearchDecision(encoding[3].literal.Negated()); + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_BOUNDS_EQ(var, 7, 7); + { + const IntegerLiteral l = integer_trail->UpperBoundAsLiteral(var); + EXPECT_EQ(integer_trail->ReasonFor(l), + std::vector{ + Literal(encoder->GetAssociatedLiteral(l)).Negated()}); + } +} + +TEST(IntegerEncoderTest, IsFixedOrHasAssociatedLiteral) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add( + NewIntegerVariable(Domain::FromIntervals({{3, 4}, {7, 7}, {9, 9}}))); + + // Initial propagation is correct. + EXPECT_TRUE(sat_solver->Propagate()); + EXPECT_BOUNDS_EQ(var, 3, 9); + + // These are trivially true/false. + EXPECT_TRUE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 2))); + EXPECT_TRUE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 3))); + EXPECT_TRUE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 10))); + + // Not other encoding currently. + EXPECT_FALSE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 4))); + EXPECT_FALSE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 9))); + + // Add one encoding and test. 
+ encoder->GetOrCreateAssociatedLiteral(IntegerLiteral::GreaterOrEqual(var, 7)); + EXPECT_TRUE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 5))); + EXPECT_TRUE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 7))); + EXPECT_TRUE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, 6))); + EXPECT_TRUE(encoder->IsFixedOrHasAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, 4))); +} + +TEST(IntegerEncoderTest, EncodingOfConstantVariableHasSizeOne) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(7, 7)); + model.Add(FullyEncodeVariable(var)); + const auto& encoding = encoder->FullDomainEncoding(var); + EXPECT_EQ(encoding.size(), 1); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsTrue( + encoding[0].literal)); +} + +TEST(IntegerEncoderTest, IntegerVariableOfAssignedLiteralIsFixed) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + + { + Literal literal_false = Literal(sat_solver->NewBooleanVariable(), true); + CHECK(sat_solver->AddUnitClause(literal_false.Negated())); + const IntegerVariable zero = + model.Add(NewIntegerVariableFromLiteral(literal_false)); + EXPECT_EQ(model.Get(UpperBound(zero)), 0); + } + + { + Literal literal_true = Literal(sat_solver->NewBooleanVariable(), true); + CHECK(sat_solver->AddUnitClause(literal_true)); + const IntegerVariable one = + model.Add(NewIntegerVariableFromLiteral(literal_true)); + EXPECT_EQ(model.Get(LowerBound(one)), 1); + } +} + +TEST(IntegerEncoderTest, LiteralView1) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1)); + const Literal literal(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerEqualValue(literal, var, IntegerValue(1)); + EXPECT_EQ(var, encoder->GetLiteralView(literal)); + EXPECT_EQ(kNoIntegerVariable, 
encoder->GetLiteralView(literal.Negated())); +} + +TEST(IntegerEncoderTest, LiteralView2) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1)); + const Literal literal(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerEqualValue(literal, var, IntegerValue(0)); + EXPECT_EQ(kNoIntegerVariable, encoder->GetLiteralView(literal)); + EXPECT_EQ(var, encoder->GetLiteralView(literal.Negated())); +} + +TEST(IntegerEncoderTest, LiteralView3) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1)); + const Literal literal(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerLiteral( + literal, IntegerLiteral::GreaterOrEqual(var, IntegerValue(1))); + EXPECT_EQ(var, encoder->GetLiteralView(literal)); + EXPECT_EQ(kNoIntegerVariable, encoder->GetLiteralView(literal.Negated())); +} + +TEST(IntegerEncoderTest, LiteralView4) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + const IntegerVariable var = model.Add(NewIntegerVariable(0, 1)); + const Literal literal(model.Add(NewBooleanVariable()), true); + encoder->AssociateToIntegerLiteral( + literal, IntegerLiteral::LowerOrEqual(var, IntegerValue(0))); + EXPECT_EQ(kNoIntegerVariable, encoder->GetLiteralView(literal)); + EXPECT_EQ(var, encoder->GetLiteralView(literal.Negated())); +} + +TEST(IntegerEncoderTest, IssueWhenNotFullyingPropagatingAtLoading) { + Model model; + auto* integer_trail = model.GetOrCreate(); + auto* integer_encoder = model.GetOrCreate(); + const IntegerVariable var = + integer_trail->AddIntegerVariable(Domain::FromValues({0, 3, 7, 9})); + const Literal false_literal = integer_encoder->GetFalseLiteral(); + integer_encoder->DisableImplicationBetweenLiteral(); + + // This currently doesn't propagate the domain. 
+ integer_encoder->AssociateToIntegerLiteral( + false_literal, IntegerLiteral::GreaterOrEqual(var, IntegerValue(5))); + EXPECT_EQ(integer_trail->LowerBound(var), 0); + EXPECT_EQ(integer_trail->UpperBound(var), 9); + + // And that used to fail because it does some domain propagation when it + // detect that some value cannot be there and update the domains of var while + // iterating over it. + integer_encoder->FullyEncodeVariable(var); +} + +#undef EXPECT_BOUNDS_EQ + +TEST(SolveIntegerProblemWithLazyEncodingTest, Sat) { + static const int kNumVariables = 10; + Model model; + std::vector integer_vars; + for (int i = 0; i < kNumVariables; ++i) { + integer_vars.push_back(model.Add(NewIntegerVariable(0, 10))); + } + model.GetOrCreate()->fixed_search = + FirstUnassignedVarAtItsMinHeuristic(integer_vars, &model); + ConfigureSearchHeuristics(&model); + ASSERT_EQ(model.GetOrCreate()->SolveIntegerProblem(), + SatSolver::Status::FEASIBLE); + for (const IntegerVariable var : integer_vars) { + EXPECT_EQ(model.Get(LowerBound(var)), model.Get(UpperBound(var))); + } +} + +TEST(SolveIntegerProblemWithLazyEncodingTest, Unsat) { + Model model; + const IntegerVariable var = model.Add(NewIntegerVariable(-100, 100)); + model.Add(LowerOrEqual(var, -10)); + model.Add(GreaterOrEqual(var, 10)); + model.GetOrCreate()->fixed_search = + FirstUnassignedVarAtItsMinHeuristic({var}, &model); + ConfigureSearchHeuristics(&model); + EXPECT_EQ(model.GetOrCreate()->SolveIntegerProblem(), + SatSolver::Status::INFEASIBLE); +} + +TEST(IntegerTrailTest, InitialVariableDomainIsUpdated) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + const IntegerVariable var = + integer_trail->AddIntegerVariable(IntegerValue(0), IntegerValue(1000)); + EXPECT_EQ(integer_trail->InitialVariableDomain(var), Domain(0, 1000)); + EXPECT_EQ(integer_trail->InitialVariableDomain(NegationOf(var)), + Domain(-1000, 0)); + + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, 
IntegerValue(7)), {}, {})); + EXPECT_EQ(integer_trail->InitialVariableDomain(var), Domain(7, 1000)); + EXPECT_EQ(integer_trail->InitialVariableDomain(NegationOf(var)), + Domain(-1000, -7)); +} + +TEST(IntegerTrailTest, AppendNewBounds) { + Model model; + const Literal l(model.Add(NewBooleanVariable()), true); + const IntegerVariable var(model.Add(NewIntegerVariable(0, 100))); + + // So that there is a decision. + EXPECT_TRUE( + model.GetOrCreate()->EnqueueDecisionIfNotConflicting(l)); + + // Enqueue a bunch of fact. + IntegerTrail* integer_trail = model.GetOrCreate(); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(2)), {l.Negated()}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(4)), {l.Negated()}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(8)), {l.Negated()}, {})); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(9)), {l.Negated()}, {})); + + // Only the last bound should be present. 
+ std::vector bounds; + integer_trail->AppendNewBounds(&bounds); + EXPECT_THAT(bounds, ElementsAre(IntegerLiteral::GreaterOrEqual( + var, IntegerValue(9)))); +} + +TEST(FastDivisionTest, AllPossibleValues) { + for (int i = 1; i <= std::numeric_limits::max(); ++i) { + const QuickSmallDivision div(i); + for (int j = 0; j <= std::numeric_limits::max(); ++j) { + const uint16_t result = div.DivideByDivisor(j); + const uint16_t j_rounded_to_lowest_multiple = result * i; + CHECK_LE(j_rounded_to_lowest_multiple, j); + CHECK_GT(j_rounded_to_lowest_multiple + i, j); + } + } +} + +static void BM_FloorRatio(benchmark::State& state) { + IntegerValue divisor(654676436498); + IntegerValue dividend(45454655155444); + IntegerValue test(0); + for (auto _ : state) { + dividend++; + divisor++; + benchmark::DoNotOptimize(test += FloorRatio(dividend, divisor)); + } + state.SetBytesProcessed(static_cast(state.iterations())); +} + +static void BM_PositiveRemainder(benchmark::State& state) { + IntegerValue divisor(654676436498); + IntegerValue dividend(45454655155444); + IntegerValue test(0); + for (auto _ : state) { + dividend++; + divisor++; + benchmark::DoNotOptimize(test += PositiveRemainder(dividend, divisor)); + } + state.SetBytesProcessed(static_cast(state.iterations())); +} + +static void BM_PositiveRemainderAlternative(benchmark::State& state) { + IntegerValue divisor(654676436498); + IntegerValue dividend(45454655155444); + IntegerValue test(0); + for (auto _ : state) { + dividend++; + divisor++; + benchmark::DoNotOptimize(test += dividend - + divisor * FloorRatio(dividend, divisor)); + } + state.SetBytesProcessed(static_cast(state.iterations())); +} + +// What we use in the code. This is safe of integer overflow. The compiler +// should also do a single integer division to get the quotient and remainder. 
+static void BM_DivisionAndRemainder(benchmark::State& state) { + IntegerValue divisor(654676436498); + IntegerValue dividend(45454655155444); + IntegerValue test(0); + for (auto _ : state) { + dividend++; + divisor++; + benchmark::DoNotOptimize(test += FloorRatio(dividend, divisor)); + benchmark::DoNotOptimize(test += PositiveRemainder(dividend, divisor)); + } + state.SetBytesProcessed(static_cast(state.iterations())); +} + +// An alternative version, note however that divisor * f might overflow! +static void BM_DivisionAndRemainderAlternative(benchmark::State& state) { + IntegerValue divisor(654676436498); + IntegerValue dividend(45454655155444); + IntegerValue test(0); + for (auto _ : state) { + dividend++; + divisor++; + const IntegerValue f = FloorRatio(dividend, divisor); + benchmark::DoNotOptimize(test += f); + benchmark::DoNotOptimize(test += dividend - divisor * f); + } + state.SetBytesProcessed(static_cast(state.iterations())); +} + +// The best we can hope for ? +static void BM_DivisionAndRemainderBaseline(benchmark::State& state) { + IntegerValue divisor(654676436498); + IntegerValue dividend(45454655155444); + IntegerValue test(0); + for (auto _ : state) { + dividend++; + divisor++; + benchmark::DoNotOptimize(test += dividend / divisor); + benchmark::DoNotOptimize(test += dividend % divisor); + } + state.SetBytesProcessed(static_cast(state.iterations())); +} + +BENCHMARK(BM_FloorRatio); +BENCHMARK(BM_PositiveRemainder); +BENCHMARK(BM_PositiveRemainderAlternative); +BENCHMARK(BM_DivisionAndRemainder); +BENCHMARK(BM_DivisionAndRemainderAlternative); +BENCHMARK(BM_DivisionAndRemainderBaseline); + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/intervals_test.cc b/ortools/sat/intervals_test.cc new file mode 100644 index 00000000000..ab2ead8f90e --- /dev/null +++ b/ortools/sat/intervals_test.cc @@ -0,0 +1,278 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/intervals.h" + +#include + +#include +#include + +#include "gtest/gtest.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(IntervalsRepositoryTest, Precedences) { + Model model; + const AffineExpression start1(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size1(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end1(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression start2(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size2(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end2(model.Add(NewIntegerVariable(0, 10))); + + auto* repo = model.GetOrCreate(); + const IntervalVariable a = repo->CreateInterval(start1, end1, size1); + const IntervalVariable b = repo->CreateInterval(start2, end2, size2); + + // Ok to call many times. 
+ repo->CreateDisjunctivePrecedenceLiteral(a, b); + repo->CreateDisjunctivePrecedenceLiteral(a, b); + + EXPECT_NE(kNoLiteralIndex, repo->GetPrecedenceLiteral(a, b)); + EXPECT_EQ(Literal(repo->GetPrecedenceLiteral(a, b)), + Literal(repo->GetPrecedenceLiteral(b, a)).Negated()); +} + +TEST(SchedulingConstraintHelperTest, PushConstantBoundWithOptionalIntervals) { + Model model; + auto* repo = model.GetOrCreate(); + + const AffineExpression start(IntegerValue(0)); + const AffineExpression size(IntegerValue(10)); + const AffineExpression end(IntegerValue(10)); + + Literal presence2 = Literal(model.Add(NewBooleanVariable()), true); + IntervalVariable inter1 = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + IntervalVariable inter2 = + repo->CreateInterval(start, end, size, presence2.Index(), false); + + SchedulingConstraintHelper helper({inter1, inter2}, &model); + + EXPECT_TRUE(helper.IncreaseStartMin(1, IntegerValue(20))); + EXPECT_FALSE(model.Get(Value(presence2))); +} + +TEST(SchedulingDemandHelperTest, EnergyInWindow) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + const IntervalVariable inter = + model.GetOrCreate()->CreateInterval( + start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper helper({inter}, &model); + SchedulingDemandHelper demands_helper({demand}, &helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + demands_helper.OverrideDecomposedEnergies( + {{{alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}}); + 
demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(8)); + + EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 8, 2)); + EXPECT_EQ(8, demands_helper.EnergyMinInWindow(0, 0, 10)); + EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 2, 10)); + EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 0, 8)); + EXPECT_EQ(4, demands_helper.EnergyMinInWindow(0, 0, 9)); +} + +TEST(SchedulingDemandHelperTest, EnergyInWindowTakeIntoAccountWindowSize) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 4))); + const AffineExpression size(model.Add(NewIntegerVariable(6, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + const IntervalVariable inter = + model.GetOrCreate()->CreateInterval( + start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(6, 10))); + + SchedulingConstraintHelper helper({inter}, &model); + SchedulingDemandHelper demands_helper({demand}, &helper, &model); + demands_helper.CacheAllEnergyValues(); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + demands_helper.OverrideDecomposedEnergies( + {{{alt1, IntegerValue(8), IntegerValue(6)}, + {alt2, IntegerValue(6), IntegerValue(8)}}}); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(48)); + + EXPECT_EQ(6, demands_helper.EnergyMinInWindow(0, 5, 6)); +} + +TEST(SchedulingDemandHelperTest, LinearizedDemandWithAffineExpression) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + const IntervalVariable inter = + model.GetOrCreate()->CreateInterval( + start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand( + 
AffineExpression(model.Add(NewIntegerVariable(2, 10)), 2, 5)); + + SchedulingConstraintHelper helper({inter}, &model); + SchedulingDemandHelper demands_helper({demand}, &helper, &model); + demands_helper.CacheAllEnergyValues(); + + LinearConstraintBuilder builder(&model); + ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "2*X3 + 5"); +} + +TEST(SchedulingDemandHelperTest, LinearizedDemandWithDecomposedEnergy) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + const IntervalVariable inter = + model.GetOrCreate()->CreateInterval( + start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper helper({inter}, &model); + SchedulingDemandHelper demands_helper({demand}, &helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); + model.GetOrCreate()->AssociateToIntegerEqualValue( + alt1, var1, IntegerValue(1)); + + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); + model.GetOrCreate()->AssociateToIntegerEqualValue( + alt2, var2, IntegerValue(1)); + demands_helper.OverrideDecomposedEnergies( + {{{alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}}); + demands_helper.CacheAllEnergyValues(); + LinearConstraintBuilder builder(&model); + ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "4*X4 2*X5"); +} + +TEST(SchedulingDemandHelperTest, FilteredDecomposedEnergy) { + Model 
model; + SatSolver* sat_solver = model.GetOrCreate(); + IntegerEncoder* encoder = model.GetOrCreate(); + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + const IntervalVariable inter = + model.GetOrCreate()->CreateInterval( + start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper helper({inter}, &model); + SchedulingDemandHelper demands_helper({demand}, &helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const std::vector no_energy; + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), no_energy); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); + encoder->AssociateToIntegerEqualValue(alt1, var1, IntegerValue(1)); + + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); + encoder->AssociateToIntegerEqualValue(alt2, var2, IntegerValue(1)); + const std::vector energy = { + {alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}; + demands_helper.OverrideDecomposedEnergies({energy}); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), energy); + + EXPECT_EQ(sat_solver->EnqueueDecisionAndBackjumpOnConflict(alt1.Negated()), + 0); + const std::vector filtered_energy = { + {alt2, IntegerValue(4), IntegerValue(2)}}; + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), filtered_energy); + EXPECT_EQ(demands_helper.DecomposedEnergies()[0], energy); +} + +TEST(SchedulingDemandHelperTest, FilteredDecomposedEnergyWithFalseLiteral) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + + const 
AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + const IntervalVariable inter = + model.GetOrCreate()->CreateInterval( + start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper helper({inter}, &model); + SchedulingDemandHelper demands_helper({demand}, &helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const std::vector no_energy; + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), no_energy); + + const Literal alt1 = encoder->GetFalseLiteral(); + const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); + model.GetOrCreate()->AssociateToIntegerEqualValue( + alt1, var1, IntegerValue(1)); + + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); + encoder->AssociateToIntegerEqualValue(alt2, var2, IntegerValue(1)); + const std::vector energy = { + {alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}; + demands_helper.OverrideDecomposedEnergies({energy}); + demands_helper.CacheAllEnergyValues(); + const std::vector filtered_energy = { + {alt2, IntegerValue(4), IntegerValue(2)}}; + EXPECT_EQ(demands_helper.DecomposedEnergies()[0], filtered_energy); + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), filtered_energy); + EXPECT_EQ(0, model.GetOrCreate()->CurrentDecisionLevel()); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/linear_constraint.cc b/ortools/sat/linear_constraint.cc index f3f0c891fdc..046f27f9681 100644 --- a/ortools/sat/linear_constraint.cc +++ b/ortools/sat/linear_constraint.cc @@ -172,26 +172,27 @@ double ComputeActivity( double a1 = 0.0; double a2 = 0.0; double a3 
= 0.0; + const double* view = values.data(); for (; i < shifted_size; i += 4) { a0 += static_cast(constraint.coeffs[i].value()) * - values[constraint.vars[i]]; + view[constraint.vars[i].value()]; a1 += static_cast(constraint.coeffs[i + 1].value()) * - values[constraint.vars[i + 1]]; + view[constraint.vars[i + 1].value()]; a2 += static_cast(constraint.coeffs[i + 2].value()) * - values[constraint.vars[i + 2]]; + view[constraint.vars[i + 2].value()]; a3 += static_cast(constraint.coeffs[i + 3].value()) * - values[constraint.vars[i + 3]]; + view[constraint.vars[i + 3].value()]; } double activity = a0 + a1 + a2 + a3; if (i < size) { activity += static_cast(constraint.coeffs[i].value()) * - values[constraint.vars[i]]; + view[constraint.vars[i].value()]; if (i + 1 < size) { activity += static_cast(constraint.coeffs[i + 1].value()) * - values[constraint.vars[i + 1]]; + view[constraint.vars[i + 1].value()]; if (i + 2 < size) { activity += static_cast(constraint.coeffs[i + 2].value()) * - values[constraint.vars[i + 2]]; + view[constraint.vars[i + 2].value()]; } } } diff --git a/ortools/sat/linear_constraint_manager.h b/ortools/sat/linear_constraint_manager.h index eb39b52c3b5..2085fa63d43 100644 --- a/ortools/sat/linear_constraint_manager.h +++ b/ortools/sat/linear_constraint_manager.h @@ -48,6 +48,12 @@ struct ModelLpValues ModelLpValues() = default; }; +// Same as ModelLpValues for reduced costs. +struct ModelReducedCosts + : public util_intops::StrongVector { + ModelReducedCosts() = default; +}; + // This class holds a list of globally valid linear constraints and has some // logic to decide which one should be part of the LP relaxation. 
We want more // for a better relaxation, but for efficiency we do not want to have too much @@ -98,6 +104,7 @@ class LinearConstraintManager { integer_trail_(*model->GetOrCreate()), time_limit_(model->GetOrCreate()), expanded_lp_solution_(*model->GetOrCreate()), + expanded_reduced_costs_(*model->GetOrCreate()), model_(model), logger_(model->GetOrCreate()) {} ~LinearConstraintManager(); @@ -161,6 +168,9 @@ class LinearConstraintManager { const util_intops::StrongVector& LpValues() { return expanded_lp_solution_; } + const util_intops::StrongVector& ReducedCosts() { + return expanded_reduced_costs_; + } // Stats. int64_t num_constraints() const { return constraint_infos_.size(); } @@ -267,6 +277,7 @@ class LinearConstraintManager { TimeLimit* time_limit_; ModelLpValues& expanded_lp_solution_; + ModelReducedCosts& expanded_reduced_costs_; Model* model_; SolverLogger* logger_; diff --git a/ortools/sat/linear_constraint_manager_test.cc b/ortools/sat/linear_constraint_manager_test.cc new file mode 100644 index 00000000000..e842599ca59 --- /dev/null +++ b/ortools/sat/linear_constraint_manager_test.cc @@ -0,0 +1,421 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/linear_constraint_manager.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/strong_vector.h" +#include "ortools/glop/variables_info.h" +#include "ortools/lp_data/lp_types.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; +using ::testing::EndsWith; +using ::testing::StartsWith; +using ::testing::UnorderedElementsAre; +using ConstraintIndex = LinearConstraintManager::ConstraintIndex; + +TEST(LinearConstraintManagerTest, DuplicateDetection) { + Model model; + LinearConstraintManager manager(&model); + const IntegerVariable x = model.Add(NewIntegerVariable(-10, 10)); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(10)); + ct_one.AddTerm(x, IntegerValue(2)); + manager.Add(ct_one.Build()); + + LinearConstraintBuilder ct_two(IntegerValue(-4), IntegerValue(6)); + ct_two.AddTerm(NegationOf(x), IntegerValue(-2)); + manager.Add(ct_two.Build()); + + EXPECT_EQ(manager.AllConstraints().size(), 1); + EXPECT_EQ(manager.AllConstraints().front().constraint.DebugString(), + "0 <= 1*X0 <= 3"); +} + +void SetLpValue(IntegerVariable v, double value, Model* model) { + auto& values = *model->GetOrCreate(); + const int needed_size = 1 + std::max(v.value(), NegationOf(v).value()); + if (needed_size > values.size()) values.resize(needed_size, 0.0); + values[v] = value; + values[NegationOf(v)] = -value; +} + +TEST(LinearConstraintManagerTest, DuplicateDetectionCuts) { + Model model; + LinearConstraintManager manager(&model); + const IntegerVariable x = model.Add(NewIntegerVariable(-10, 10)); + SetLpValue(x, -4.0, &model); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(10)); + 
ct_one.AddTerm(x, IntegerValue(2)); + manager.AddCut(ct_one.Build(), "Cut"); + + LinearConstraintBuilder ct_two(IntegerValue(-4), IntegerValue(6)); + ct_two.AddTerm(NegationOf(x), IntegerValue(-2)); + manager.AddCut(ct_two.Build(), "Cut"); + + // The second cut is more restrictive so it counts. + EXPECT_EQ(manager.num_cuts(), 2); + + EXPECT_EQ(manager.AllConstraints().size(), 1); + EXPECT_EQ(manager.AllConstraints().front().constraint.DebugString(), + "0 <= 1*X0 <= 3"); +} + +TEST(LinearConstraintManagerTest, DuplicateDetectionCauseLpChange) { + Model model; + LinearConstraintManager manager(&model); + const IntegerVariable x = model.Add(NewIntegerVariable(-10, 10)); + SetLpValue(x, 0.0, &model); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(10)); + ct_one.AddTerm(x, IntegerValue(2)); + manager.Add(ct_one.Build()); + + manager.AddAllConstraintsToLp(); + EXPECT_THAT(manager.LpConstraints(), + UnorderedElementsAre(ConstraintIndex(0))); + glop::BasisState state; + state.statuses.resize(glop::ColIndex(1)); + EXPECT_FALSE(manager.ChangeLp(&state)); + + // Adding the second constraint will cause a bound change, so ChangeLp() will + // returns true even if the constraint is satisfied. 
+ LinearConstraintBuilder ct_two(IntegerValue(-4), IntegerValue(6)); + ct_two.AddTerm(x, IntegerValue(2)); + manager.Add(ct_two.Build()); + EXPECT_TRUE(manager.ChangeLp(&state)); + + EXPECT_EQ(manager.AllConstraints().size(), 1); + EXPECT_EQ(manager.AllConstraints().front().constraint.DebugString(), + "0 <= 1*X0 <= 3"); +} + +TEST(LinearConstraintManagerTest, OnlyAddInfeasibleConstraints) { + Model model; + LinearConstraintManager manager(&model); + const IntegerVariable x = model.Add(NewIntegerVariable(-10, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(-10, 10)); + SetLpValue(x, 0.0, &model); + SetLpValue(y, 0.0, &model); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(10)); + ct_one.AddTerm(x, IntegerValue(2)); + ct_one.AddTerm(y, IntegerValue(3)); + manager.Add(ct_one.Build()); + + LinearConstraintBuilder ct_two(IntegerValue(-4), IntegerValue(6)); + ct_two.AddTerm(x, IntegerValue(3)); + ct_one.AddTerm(y, IntegerValue(2)); + manager.Add(ct_two.Build()); + + EXPECT_TRUE(manager.LpConstraints().empty()); + EXPECT_EQ(manager.AllConstraints().size(), 2); + + // All constraints satisfy this, so no change. + glop::BasisState state; + state.statuses.resize(glop::ColIndex(2)); // Content is not relevant. + EXPECT_FALSE(manager.ChangeLp(&state)); + EXPECT_FALSE(manager.ChangeLp(&state)); + + SetLpValue(x, -1.0, &model); + EXPECT_TRUE(manager.ChangeLp(&state)); + EXPECT_THAT(manager.LpConstraints(), + UnorderedElementsAre(ConstraintIndex(0))); + EXPECT_EQ(state.statuses.size(), glop::ColIndex(3)); // State was resized. + EXPECT_EQ(state.statuses[glop::ColIndex(2)], glop::VariableStatus::BASIC); + + // Note that we keep the first constraint even if the value of 4.0 make it + // satisfied. + SetLpValue(x, 4.0, &model); + EXPECT_TRUE(manager.ChangeLp(&state)); + EXPECT_THAT(manager.LpConstraints(), + UnorderedElementsAre(ConstraintIndex(0), ConstraintIndex(1))); + EXPECT_EQ(state.statuses.size(), glop::ColIndex(4)); // State was resized. 
+ EXPECT_EQ(state.statuses[glop::ColIndex(3)], glop::VariableStatus::BASIC); +} + +TEST(LinearConstraintManagerTest, OnlyAddOrthogonalConstraints) { + Model model; + model.GetOrCreate()->set_min_orthogonality_for_lp_constraints( + 0.8); + LinearConstraintManager manager(&model); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + SetLpValue(x, 1.0, &model); + SetLpValue(y, 1.0, &model); + SetLpValue(z, 1.0, &model); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(11)); + ct_one.AddTerm(x, IntegerValue(3)); + ct_one.AddTerm(y, IntegerValue(-4)); + manager.Add(ct_one.Build()); + + LinearConstraintBuilder ct_two(IntegerValue(-4), IntegerValue(2)); + ct_two.AddTerm(z, IntegerValue(-5)); + manager.Add(ct_two.Build()); + + LinearConstraintBuilder ct_three(IntegerValue(0), IntegerValue(14)); + ct_three.AddTerm(x, IntegerValue(5)); + ct_three.AddTerm(y, IntegerValue(5)); + ct_three.AddTerm(z, IntegerValue(5)); + manager.Add(ct_three.Build()); + + EXPECT_TRUE(manager.LpConstraints().empty()); + EXPECT_EQ(manager.AllConstraints().size(), 3); + + // First Call. Last constraint does not satisfy the orthogonality criteria. + glop::BasisState state; + EXPECT_TRUE(manager.ChangeLp(&state)); + EXPECT_THAT(manager.LpConstraints(), + UnorderedElementsAre(ConstraintIndex(0), ConstraintIndex(1))); + + // Second Call. Only the last constraint is considered. The other two + // constraints are already added. 
+ EXPECT_TRUE(manager.ChangeLp(&state)); + EXPECT_THAT(manager.LpConstraints(), + UnorderedElementsAre(ConstraintIndex(0), ConstraintIndex(1), + ConstraintIndex(2))); +} + +TEST(LinearConstraintManagerTest, RemoveIneffectiveCuts) { + Model model; + model.GetOrCreate()->set_max_consecutive_inactive_count(0); + + LinearConstraintManager manager(&model); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + SetLpValue(x, 1.0, &model); + SetLpValue(y, 1.0, &model); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(11)); + ct_one.AddTerm(x, IntegerValue(3)); + ct_one.AddTerm(y, IntegerValue(-4)); + manager.AddCut(ct_one.Build(), "Cut"); + + EXPECT_TRUE(manager.LpConstraints().empty()); + EXPECT_EQ(manager.AllConstraints().size(), 1); + + // First Call. The constraint is added to LP. + glop::BasisState state; + EXPECT_TRUE(manager.ChangeLp(&state)); + EXPECT_THAT(manager.LpConstraints(), + UnorderedElementsAre(ConstraintIndex(0))); + + // Second Call. Constraint is inactive and hence removed. 
+ state.statuses.resize(glop::ColIndex(2 + manager.LpConstraints().size())); + state.statuses[glop::ColIndex(2)] = glop::VariableStatus::BASIC; + EXPECT_TRUE(manager.ChangeLp(&state)); + EXPECT_TRUE(manager.LpConstraints().empty()); + EXPECT_EQ(state.statuses.size(), glop::ColIndex(2)); +} + +TEST(LinearConstraintManagerTest, ObjectiveParallelism) { + Model model; + LinearConstraintManager manager(&model); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + SetLpValue(x, 1.0, &model); + SetLpValue(y, 1.0, &model); + SetLpValue(z, 1.0, &model); + + manager.SetObjectiveCoefficient(x, IntegerValue(1)); + manager.SetObjectiveCoefficient(y, IntegerValue(1)); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(0)); + ct_one.AddTerm(z, IntegerValue(-1)); + manager.Add(ct_one.Build()); + + LinearConstraintBuilder ct_two(IntegerValue(0), IntegerValue(2)); + ct_two.AddTerm(x, IntegerValue(1)); + ct_two.AddTerm(y, IntegerValue(1)); + ct_two.AddTerm(z, IntegerValue(1)); + manager.Add(ct_two.Build()); + + EXPECT_TRUE(manager.LpConstraints().empty()); + EXPECT_EQ(manager.AllConstraints().size(), 2); + + // Last constraint is more parallel to the objective. 
+ glop::BasisState state; + EXPECT_TRUE(manager.ChangeLp(&state)); + // scores: efficacy, orthogonality, obj_para, total + // ct_one: 1, 1, 0, 2 + // ct_two: 0.5774, 1, 0.8165, 2.394 + + EXPECT_THAT(manager.LpConstraints(), + ElementsAre(ConstraintIndex(1), ConstraintIndex(0))); +} + +TEST(LinearConstraintManagerTest, SimplificationRemoveFixedVariable) { + Model model; + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + SetLpValue(x, 0.0, &model); + SetLpValue(y, 0.0, &model); + SetLpValue(z, 0.0, &model); + + LinearConstraintManager manager(&model); + + { + LinearConstraintBuilder ct(IntegerValue(0), IntegerValue(11)); + ct.AddTerm(x, IntegerValue(3)); + ct.AddTerm(y, IntegerValue(-4)); + ct.AddTerm(z, IntegerValue(7)); + manager.Add(ct.Build()); + } + + const LinearConstraintManager::ConstraintIndex index(0); + EXPECT_EQ("0 <= 3*X0 -4*X1 7*X2 <= 11", + manager.AllConstraints()[index].constraint.DebugString()); + + // ChangeLp will trigger the simplification. + EXPECT_TRUE(model.GetOrCreate()->Enqueue( + IntegerLiteral::GreaterOrEqual(y, IntegerValue(5)), {}, {})); + glop::BasisState state; + EXPECT_TRUE(manager.ChangeLp(&state)); + EXPECT_EQ(1, manager.num_shortened_constraints()); + EXPECT_EQ("20 <= 3*X0 7*X2 <= 31", + manager.AllConstraints()[index].constraint.DebugString()); + + // We also test that the constraint equivalence work with the change. + // Adding a constraint equiv to the new one is detected. 
+ { + LinearConstraintBuilder ct(IntegerValue(0), IntegerValue(21)); + ct.AddTerm(x, IntegerValue(3)); + ct.AddTerm(z, IntegerValue(7)); + manager.Add(ct.Build()); + } + EXPECT_EQ(manager.AllConstraints().size(), 1); + EXPECT_EQ("20 <= 3*X0 7*X2 <= 21", + manager.AllConstraints()[index].constraint.DebugString()); +} + +TEST(LinearConstraintManagerTest, SimplificationStrenghtenUb) { + Model model; + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + LinearConstraintManager manager(&model); + + LinearConstraintBuilder ct(IntegerValue(-100), IntegerValue(30 + 70 - 5)); + ct.AddTerm(x, IntegerValue(3)); + ct.AddTerm(y, IntegerValue(-8)); + ct.AddTerm(z, IntegerValue(7)); + manager.Add(ct.Build()); + + const LinearConstraintManager::ConstraintIndex index(0); + EXPECT_EQ(2, manager.num_coeff_strenghtening()); + EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), + EndsWith("3*X0 -5*X1 5*X2 <= 75")); +} + +TEST(LinearConstraintManagerTest, SimplificationStrenghtenLb) { + Model model; + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + LinearConstraintManager manager(&model); + + LinearConstraintBuilder ct(IntegerValue(-75), IntegerValue(1000)); + ct.AddTerm(x, IntegerValue(3)); + ct.AddTerm(y, IntegerValue(-8)); + ct.AddTerm(z, IntegerValue(7)); + manager.Add(ct.Build()); + + const LinearConstraintManager::ConstraintIndex index(0); + EXPECT_EQ(2, manager.num_coeff_strenghtening()); + EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), + StartsWith("-45 <= 3*X0 -5*X1 5*X2")); +} + +TEST(LinearConstraintManagerTest, AdvancedStrenghtening1) { + Model model; + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = 
model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + LinearConstraintManager manager(&model); + + LinearConstraintBuilder ct(IntegerValue(16), IntegerValue(1000)); + ct.AddTerm(x, IntegerValue(15)); + ct.AddTerm(y, IntegerValue(9)); + ct.AddTerm(z, IntegerValue(14)); + manager.Add(ct.Build()); + + const LinearConstraintManager::ConstraintIndex index(0); + EXPECT_EQ(3, manager.num_coeff_strenghtening()); + EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), + StartsWith("2 <= 1*X0 1*X1 1*X2")); +} + +TEST(LinearConstraintManagerTest, AdvancedStrenghtening2) { + Model model; + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + LinearConstraintManager manager(&model); + + LinearConstraintBuilder ct(IntegerValue(16), IntegerValue(1000)); + ct.AddTerm(x, IntegerValue(15)); + ct.AddTerm(y, IntegerValue(7)); + ct.AddTerm(z, IntegerValue(14)); + manager.Add(ct.Build()); + + const LinearConstraintManager::ConstraintIndex index(0); + EXPECT_EQ(2, manager.num_coeff_strenghtening()); + EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), + StartsWith("16 <= 9*X0 7*X1 9*X2")); +} + +TEST(LinearConstraintManagerTest, AdvancedStrenghtening3) { + Model model; + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + LinearConstraintManager manager(&model); + + LinearConstraintBuilder ct(IntegerValue(5), IntegerValue(1000)); + ct.AddTerm(x, IntegerValue(5)); + ct.AddTerm(y, IntegerValue(5)); + ct.AddTerm(z, IntegerValue(4)); + manager.Add(ct.Build()); + + // TODO(user): Technically, because the 5 are "enforcement" the inner + // constraint is 4*X2 >= 5 which can be rewriten and X2 >= 2, and we could + // instead 
have 2X0 + 2X1 + X2 >= 2 which should be tighter. + const LinearConstraintManager::ConstraintIndex index(0); + EXPECT_EQ(1, manager.num_coeff_strenghtening()); + EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), + StartsWith("5 <= 5*X0 5*X1 3*X2")); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/linear_constraint_test.cc b/ortools/sat/linear_constraint_test.cc new file mode 100644 index 00000000000..cad9ad4d9e9 --- /dev/null +++ b/ortools/sat/linear_constraint_test.cc @@ -0,0 +1,481 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/linear_constraint.h" + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/strong_vector.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; + +TEST(ComputeActivityTest, BasicBehavior) { + // The bounds are not useful for this test. 
+ LinearConstraintBuilder ct(IntegerValue(0), IntegerValue(0)); + + ct.AddTerm(IntegerVariable(0), IntegerValue(1)); + ct.AddTerm(IntegerVariable(2), IntegerValue(-2)); + ct.AddTerm(IntegerVariable(4), IntegerValue(3)); + + util_intops::StrongVector values = {0.5, 0.0, 1.4, + 0.0, -2.1, 0.0}; + EXPECT_NEAR(ComputeActivity(ct.Build(), values), 1 * 0.5 - 2 * 1.4 - 3 * 2.1, + 1e-6); +} + +TEST(ComputeActivityTest, EmptyConstraint) { + // The bounds are not useful for this test. + LinearConstraintBuilder ct(IntegerValue(-10), IntegerValue(10)); + util_intops::StrongVector values; + EXPECT_EQ(ComputeActivity(ct.Build(), values), 0.0); +} + +TEST(ComputeInfinityNormTest, BasicTest) { + IntegerVariable x(0); + IntegerVariable y(2); + IntegerVariable z(4); + { + LinearConstraint constraint; + EXPECT_EQ(IntegerValue(0), ComputeInfinityNorm(constraint)); + } + { + LinearConstraintBuilder constraint; + constraint.AddTerm(x, IntegerValue(3)); + constraint.AddTerm(y, IntegerValue(-4)); + constraint.AddTerm(z, IntegerValue(1)); + EXPECT_EQ(IntegerValue(4), ComputeInfinityNorm(constraint.Build())); + } + { + LinearConstraintBuilder constraint; + constraint.AddTerm(y, IntegerValue(std::numeric_limits::max())); + EXPECT_EQ(IntegerValue(std::numeric_limits::max()), + ComputeInfinityNorm(constraint.Build())); + } +} + +TEST(ComputeL2NormTest, BasicTest) { + IntegerVariable x(0); + IntegerVariable y(2); + IntegerVariable z(4); + { + LinearConstraint constraint; + EXPECT_EQ(0.0, ComputeL2Norm(constraint)); + } + { + LinearConstraintBuilder constraint; + constraint.AddTerm(x, IntegerValue(3)); + constraint.AddTerm(y, IntegerValue(-4)); + constraint.AddTerm(z, IntegerValue(12)); + EXPECT_EQ(13.0, ComputeL2Norm(constraint.Build())); + } + { + LinearConstraintBuilder constraint; + constraint.AddTerm(x, kMaxIntegerValue); + constraint.AddTerm(y, kMaxIntegerValue); + EXPECT_EQ(std::numeric_limits::infinity(), + ComputeL2Norm(constraint.Build())); + } + { + LinearConstraintBuilder constraint; 
+ constraint.AddTerm(x, IntegerValue(1LL << 60)); + constraint.AddTerm(y, IntegerValue(1LL << 60)); + EXPECT_NEAR(1.6304772e+18, ComputeL2Norm(constraint.Build()), 1e+16); + } +} + +TEST(ScalarProductTest, BasicTest) { + IntegerVariable x(0); + IntegerVariable y(2); + IntegerVariable z(4); + + LinearConstraintBuilder ct_one(IntegerValue(0), IntegerValue(11)); + ct_one.AddTerm(x, IntegerValue(3)); + ct_one.AddTerm(y, IntegerValue(-4)); + + LinearConstraintBuilder ct_two(IntegerValue(1), IntegerValue(2)); + ct_two.AddTerm(z, IntegerValue(-1)); + + LinearConstraintBuilder ct_three(IntegerValue(0), IntegerValue(2)); + ct_three.AddTerm(x, IntegerValue(1)); + ct_three.AddTerm(y, IntegerValue(1)); + ct_three.AddTerm(z, IntegerValue(1)); + + EXPECT_EQ(0.0, ScalarProduct(ct_one.Build(), ct_two.Build())); + EXPECT_EQ(-1.0, ScalarProduct(ct_one.Build(), ct_three.Build())); + EXPECT_EQ(-1.0, ScalarProduct(ct_two.Build(), ct_three.Build())); +} + +namespace { + +// Creates an upper bounded LinearConstraintBuilder from a dense representation. 
+LinearConstraint CreateUbConstraintForTest( + absl::Span dense_coeffs, int64_t upper_bound) { + LinearConstraint result; + result.resize(dense_coeffs.size()); + int new_size = 0; + for (int i = 0; i < dense_coeffs.size(); ++i) { + if (dense_coeffs[i] != 0) { + result.vars[new_size] = IntegerVariable(i); + result.coeffs[new_size] = dense_coeffs[i]; + ++new_size; + } + } + result.resize(new_size); + result.lb = kMinIntegerValue; + result.ub = upper_bound; + return result; +} + +} // namespace + +TEST(DivideByGCDTest, BasicBehaviorWithoughLowerBound) { + LinearConstraint ct = CreateUbConstraintForTest({2, 4, -8}, 11); + DivideByGCD(&ct); + const LinearConstraint expected = CreateUbConstraintForTest({1, 2, -4}, 5); + EXPECT_EQ(ct, expected); +} + +TEST(DivideByGCDTest, BasicBehaviorWithLowerBound) { + LinearConstraint ct = CreateUbConstraintForTest({2, 4, -8}, 11); + ct.lb = IntegerValue(-3); + DivideByGCD(&ct); + LinearConstraint expected = CreateUbConstraintForTest({1, 2, -4}, 5); + expected.lb = IntegerValue(-1); + EXPECT_EQ(ct, expected); +} + +TEST(RemoveZeroTermsTest, BasicBehavior) { + LinearConstraint ct = CreateUbConstraintForTest({2, 4, -8}, 11); + ct.coeffs[1] = IntegerValue(0); + RemoveZeroTerms(&ct); + EXPECT_EQ(ct, CreateUbConstraintForTest({2, 0, -8}, 11)); +} + +TEST(MakeAllCoefficientsPositiveTest, BasicBehavior) { + // Note that this relies on the fact that the negation of an IntegerVariable + // var is is the one with IntegerVariable(var.value() ^ 1); + LinearConstraint ct = CreateUbConstraintForTest({-2, 0, -7, 0}, 10); + MakeAllCoefficientsPositive(&ct); + EXPECT_EQ(ct, CreateUbConstraintForTest({0, 2, 0, 7}, 10)); +} + +TEST(LinearConstraintBuilderTest, DuplicateCoefficient) { + Model model; + model.GetOrCreate(); + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(10)); + + // Note that internally, positive variable have an even index, so we only + // use those so that we don't remap a negated variable. 
+ builder.AddTerm(IntegerVariable(0), IntegerValue(100)); + builder.AddTerm(IntegerVariable(2), IntegerValue(10)); + builder.AddTerm(IntegerVariable(4), IntegerValue(7)); + builder.AddTerm(IntegerVariable(0), IntegerValue(-10)); + builder.AddTerm(IntegerVariable(2), IntegerValue(1)); + builder.AddTerm(IntegerVariable(4), IntegerValue(-7)); + builder.AddTerm(IntegerVariable(2), IntegerValue(3)); + + EXPECT_EQ(builder.Build(), CreateUbConstraintForTest({90, 0, 14}, 10)); +} + +TEST(LinearConstraintBuilderTest, AffineExpression) { + Model model; + model.GetOrCreate(); + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(10)); + + // Note that internally, positive variable have an even index, so we only + // use those so that we don't remap a negated variable. + const IntegerVariable var(0); + builder.AddTerm(AffineExpression(var, IntegerValue(3), IntegerValue(2)), + IntegerValue(100)); + builder.AddTerm(AffineExpression(var, IntegerValue(-2), IntegerValue(1)), + IntegerValue(70)); + + // Coeff is 3*100 - 2 * 70, ub is 10 - 2*100 - 1*70 + EXPECT_EQ(builder.Build(), CreateUbConstraintForTest({160}, -260)) + << builder.Build().DebugString(); +} + +TEST(LinearConstraintBuilderTest, AddLiterals) { + Model model; + model.GetOrCreate(); + const BooleanVariable b = model.Add(NewBooleanVariable()); + const BooleanVariable c = model.Add(NewBooleanVariable()); + const BooleanVariable d = model.Add(NewBooleanVariable()); + + // Create integer views. + model.Add(NewIntegerVariableFromLiteral(Literal(b, true))); // X0 + model.Add(NewIntegerVariableFromLiteral(Literal(b, false))); // X1 + model.Add(NewIntegerVariableFromLiteral(Literal(c, false))); // X2 + model.Add(NewIntegerVariableFromLiteral(Literal(d, false))); // X3 + model.Add(NewIntegerVariableFromLiteral(Literal(d, true))); // X4 + + // When we have both view, we use the lowest IntegerVariable. 
+ { + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); + EXPECT_TRUE(builder.AddLiteralTerm(Literal(b, true), IntegerValue(1))); + EXPECT_EQ(builder.Build().DebugString(), "1*X0 <= 1"); + } + { + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); + EXPECT_TRUE(builder.AddLiteralTerm(Literal(b, false), IntegerValue(1))); + EXPECT_EQ(builder.Build().DebugString(), "-1*X0 <= 0"); + } + { + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); + EXPECT_TRUE(builder.AddLiteralTerm(Literal(d, true), IntegerValue(1))); + EXPECT_EQ(builder.Build().DebugString(), "-1*X3 <= 0"); + } + { + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); + EXPECT_TRUE(builder.AddLiteralTerm(Literal(d, false), IntegerValue(1))); + EXPECT_EQ(builder.Build().DebugString(), "1*X3 <= 1"); + } + + // When we have just one view, we use the one we have. + { + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); + EXPECT_TRUE(builder.AddLiteralTerm(Literal(c, true), IntegerValue(1))); + EXPECT_EQ(builder.Build().DebugString(), "-1*X2 <= 0"); + } + { + LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); + EXPECT_TRUE(builder.AddLiteralTerm(Literal(c, false), IntegerValue(1))); + EXPECT_EQ(builder.Build().DebugString(), "1*X2 <= 1"); + } +} + +TEST(LinearConstraintBuilderTest, AddConstant) { + Model model; + model.GetOrCreate(); + LinearConstraintBuilder builder1(&model, kMinIntegerValue, IntegerValue(10)); + builder1.AddTerm(IntegerVariable(0), IntegerValue(5)); + builder1.AddTerm(IntegerVariable(2), IntegerValue(10)); + builder1.AddConstant(IntegerValue(3)); + EXPECT_EQ(builder1.Build().DebugString(), "5*X0 10*X1 <= 7"); + + LinearConstraintBuilder builder2(&model, IntegerValue(4), kMaxIntegerValue); + builder2.AddTerm(IntegerVariable(0), IntegerValue(5)); + builder2.AddTerm(IntegerVariable(2), IntegerValue(10)); + builder2.AddConstant(IntegerValue(-3)); + 
EXPECT_EQ(builder2.Build().DebugString(), "7 <= 5*X0 10*X1"); + + LinearConstraintBuilder builder3(&model, kMinIntegerValue, IntegerValue(10)); + builder3.AddTerm(IntegerVariable(0), IntegerValue(5)); + builder3.AddTerm(IntegerVariable(2), IntegerValue(10)); + builder3.AddConstant(IntegerValue(-3)); + EXPECT_EQ(builder3.Build().DebugString(), "5*X0 10*X1 <= 13"); + + LinearConstraintBuilder builder4(&model, IntegerValue(4), kMaxIntegerValue); + builder4.AddTerm(IntegerVariable(0), IntegerValue(5)); + builder4.AddTerm(IntegerVariable(2), IntegerValue(10)); + builder4.AddConstant(IntegerValue(3)); + EXPECT_EQ(builder4.Build().DebugString(), "1 <= 5*X0 10*X1"); + + LinearConstraintBuilder builder5(&model, IntegerValue(4), IntegerValue(10)); + builder5.AddTerm(IntegerVariable(0), IntegerValue(5)); + builder5.AddTerm(IntegerVariable(2), IntegerValue(10)); + builder5.AddConstant(IntegerValue(3)); + EXPECT_EQ(builder5.Build().DebugString(), "1 <= 5*X0 10*X1 <= 7"); +} + +TEST(CleanTermsAndFillConstraintTest, VarAndItsNegation) { + std::vector> terms; + terms.push_back({IntegerVariable(4), IntegerValue(7)}); + terms.push_back({IntegerVariable(5), IntegerValue(4)}); + LinearConstraint constraint; + CleanTermsAndFillConstraint(&terms, &constraint); + EXPECT_EQ(constraint.DebugString(), "0 <= 3*X2 <= 0"); +} + +TEST(LinearConstraintBuilderTest, AddQuadraticLowerBound) { + Model model; + model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + IntegerVariable x0 = model.Add(NewIntegerVariable(2, 5)); + IntegerVariable x1 = model.Add(NewIntegerVariable(3, 6)); + LinearConstraintBuilder builder1(&model, kMinIntegerValue, IntegerValue(10)); + AffineExpression a0(x0, IntegerValue(3), IntegerValue(2)); // 3 * x0 + 2. 
+ builder1.AddQuadraticLowerBound(a0, x1, integer_trail); + EXPECT_EQ(builder1.Build().DebugString(), "9*X0 8*X1 <= 28"); +} + +TEST(LinearConstraintBuilderTest, AddQuadraticLowerBoundAffineIsVar) { + Model model; + model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + IntegerVariable x0 = model.Add(NewIntegerVariable(2, 5)); + IntegerVariable x1 = model.Add(NewIntegerVariable(3, 6)); + LinearConstraintBuilder builder1(&model, kMinIntegerValue, IntegerValue(10)); + builder1.AddQuadraticLowerBound(x0, x1, integer_trail); + EXPECT_EQ(builder1.Build().DebugString(), "3*X0 2*X1 <= 16"); +} + +TEST(LinearConstraintBuilderTest, AddQuadraticLowerBoundAffineIsConstant) { + Model model; + model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + IntegerVariable x0 = model.Add(NewIntegerVariable(2, 5)); + LinearConstraintBuilder builder1(&model, kMinIntegerValue, IntegerValue(10)); + builder1.AddQuadraticLowerBound(IntegerValue(4), x0, integer_trail); + EXPECT_EQ(builder1.Build().DebugString(), "4*X0 <= 10"); +} + +TEST(LinExprTest, Bounds) { + Model model; + std::vector vars{model.Add(NewIntegerVariable(1, 2)), + model.Add(NewIntegerVariable(0, 3)), + model.Add(NewIntegerVariable(-2, 4))}; + IntegerTrail* integer_trail = model.GetOrCreate(); + LinearExpression expr1; // 2x0 + 3x1 - 5 + expr1.vars = {vars[0], vars[1]}; + expr1.coeffs = {IntegerValue(2), IntegerValue(3)}; + expr1.offset = IntegerValue(-5); + expr1 = CanonicalizeExpr(expr1); + EXPECT_EQ(IntegerValue(-3), expr1.Min(*integer_trail)); + EXPECT_EQ(IntegerValue(8), expr1.Max(*integer_trail)); + + LinearExpression expr2; // 2x1 - 5x2 + 6 + expr2.vars = {vars[1], vars[2]}; + expr2.coeffs = {IntegerValue(2), IntegerValue(-5)}; + expr2.offset = IntegerValue(6); + expr2 = CanonicalizeExpr(expr2); + EXPECT_EQ(IntegerValue(-14), expr2.Min(*integer_trail)); + EXPECT_EQ(IntegerValue(22), expr2.Max(*integer_trail)); + + LinearExpression expr3; // 2x0 + 3x2 + expr3.vars = {vars[0], 
vars[2]}; + expr3.coeffs = {IntegerValue(2), IntegerValue(3)}; + expr3 = CanonicalizeExpr(expr3); + EXPECT_EQ(IntegerValue(-4), expr3.Min(*integer_trail)); + EXPECT_EQ(IntegerValue(16), expr3.Max(*integer_trail)); +} + +TEST(LinExprTest, Canonicalization) { + Model model; + std::vector vars{model.Add(NewIntegerVariable(1, 2)), + model.Add(NewIntegerVariable(0, 3))}; + LinearExpression expr; // 2x0 - 3x1 - 5 + expr.vars = vars; + expr.coeffs = {IntegerValue(2), IntegerValue(-3)}; + expr.offset = IntegerValue(-5); + + LinearExpression canonical_expr = CanonicalizeExpr(expr); + EXPECT_THAT(canonical_expr.vars, ElementsAre(vars[0], NegationOf(vars[1]))); + EXPECT_THAT(canonical_expr.coeffs, + ElementsAre(IntegerValue(2), IntegerValue(3))); + EXPECT_EQ(canonical_expr.offset, IntegerValue(-5)); +} + +TEST(NoDuplicateVariable, BasicBehavior) { + LinearConstraint ct; + ct.lb = kMinIntegerValue; + ct.ub = IntegerValue(10); + + ct.resize(3); + ct.num_terms = 1; + ct.vars[0] = IntegerVariable(4); + ct.coeffs[0] = IntegerValue(1); + EXPECT_TRUE(NoDuplicateVariable(ct)); + + ct.num_terms = 2; + ct.vars[1] = IntegerVariable(2); + ct.coeffs[1] = IntegerValue(5); + EXPECT_TRUE(NoDuplicateVariable(ct)); + + ct.num_terms = 3; + ct.vars[2] = IntegerVariable(4); + ct.coeffs[2] = IntegerValue(1); + EXPECT_FALSE(NoDuplicateVariable(ct)); +} + +TEST(NoDuplicateVariable, BasicBehaviorNegativeVar) { + LinearConstraint ct; + + ct.lb = kMinIntegerValue; + ct.ub = IntegerValue(10); + + ct.resize(3); + ct.num_terms = 1; + ct.vars[0] = IntegerVariable(4); + ct.coeffs[0] = IntegerValue(1); + EXPECT_TRUE(NoDuplicateVariable(ct)); + + ct.num_terms = 2; + ct.vars[1] = IntegerVariable(2); + ct.coeffs[1] = IntegerValue(5); + EXPECT_TRUE(NoDuplicateVariable(ct)); + + ct.num_terms = 3; + ct.vars[2] = IntegerVariable(5); + ct.coeffs[2] = IntegerValue(1); + EXPECT_FALSE(NoDuplicateVariable(ct)); +} + +TEST(PositiveVarExpr, BasicBehaviorNegativeVar) { + LinearExpression ct; + ct.offset = IntegerValue(10); 
+ ct.vars.push_back(IntegerVariable(4)); + ct.coeffs.push_back(IntegerValue(1)); + + ct.vars.push_back(IntegerVariable(1)); + ct.coeffs.push_back(IntegerValue(5)); + + LinearExpression positive_var_expr = PositiveVarExpr(ct); + EXPECT_THAT(positive_var_expr.vars, + ElementsAre(ct.vars[0], NegationOf(ct.vars[1]))); + EXPECT_THAT(positive_var_expr.coeffs, + ElementsAre(ct.coeffs[0], -ct.coeffs[1])); + EXPECT_EQ(positive_var_expr.offset, ct.offset); +} + +TEST(GetCoefficient, BasicBehavior) { + LinearExpression ct; + ct.offset = IntegerValue(10); + ct.vars.push_back(IntegerVariable(4)); + ct.coeffs.push_back(IntegerValue(2)); + + EXPECT_EQ(IntegerValue(2), GetCoefficient(IntegerVariable(4), ct)); + EXPECT_EQ(IntegerValue(-2), GetCoefficient(IntegerVariable(5), ct)); + EXPECT_EQ(IntegerValue(0), GetCoefficient(IntegerVariable(2), ct)); +} + +TEST(GetCoefficientOfPositiveVar, BasicBehavior) { + LinearExpression ct; + ct.offset = IntegerValue(10); + ct.vars.push_back(IntegerVariable(4)); + ct.coeffs.push_back(IntegerValue(2)); + + EXPECT_EQ(IntegerValue(2), + GetCoefficientOfPositiveVar(IntegerVariable(4), ct)); + EXPECT_EQ(IntegerValue(0), + GetCoefficientOfPositiveVar(IntegerVariable(2), ct)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index aa1a79cef31..63237ee2f70 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -106,21 +106,31 @@ bool ScatteredIntegerVector::Add(glop::ColIndex col, IntegerValue value) { template bool ScatteredIntegerVector::AddLinearExpressionMultiple( const IntegerValue multiplier, absl::Span cols, - absl::Span coeffs) { + absl::Span coeffs, IntegerValue max_coeff_magnitude) { + // Since we have the norm, this avoid checking each products below. 
+ if (check_overflow) { + const IntegerValue prod = CapProdI(max_coeff_magnitude, multiplier); + if (AtMinOrMaxInt64(prod.value())) return false; + } + + IntegerValue* data = dense_vector_.data(); const double threshold = 0.1 * static_cast(dense_vector_.size()); const int num_terms = cols.size(); if (is_sparse_ && static_cast(num_terms) < threshold) { for (int i = 0; i < num_terms; ++i) { - if (is_zeros_[cols[i]]) { - is_zeros_[cols[i]] = false; - non_zeros_.push_back(cols[i]); + const glop::ColIndex col = cols[i]; + if (is_zeros_[col]) { + is_zeros_[col] = false; + non_zeros_.push_back(col); } + const IntegerValue product = multiplier * coeffs[i]; if (check_overflow) { - if (!AddProductTo(multiplier, coeffs[i], &dense_vector_[cols[i]])) { + if (AddIntoOverflow(product.value(), + data[col.value()].mutable_value())) { return false; } } else { - dense_vector_[cols[i]] += multiplier * coeffs[i]; + data[col.value()] += product; } } if (static_cast(non_zeros_.size()) > threshold) { @@ -129,12 +139,15 @@ bool ScatteredIntegerVector::AddLinearExpressionMultiple( } else { is_sparse_ = false; for (int i = 0; i < num_terms; ++i) { + const glop::ColIndex col = cols[i]; + const IntegerValue product = multiplier * coeffs[i]; if (check_overflow) { - if (!AddProductTo(multiplier, coeffs[i], &dense_vector_[cols[i]])) { + if (AddIntoOverflow(product.value(), + data[col.value()].mutable_value())) { return false; } } else { - dense_vector_[cols[i]] += multiplier * coeffs[i]; + data[col.value()] += product; } } } @@ -206,10 +219,11 @@ void ScatteredIntegerVector::ConvertToCutData( CutData* result) { result->terms.clear(); result->rhs = rhs; + absl::Span dense_vector = dense_vector_; if (is_sparse_) { std::sort(non_zeros_.begin(), non_zeros_.end()); for (const glop::ColIndex col : non_zeros_) { - const IntegerValue coeff = dense_vector_[col]; + const IntegerValue coeff = dense_vector[col.value()]; if (coeff == 0) continue; const IntegerVariable var = integer_variables[col.value()]; 
CHECK(result->AppendOneTerm(var, coeff, lp_solution[col.value()], @@ -217,12 +231,11 @@ void ScatteredIntegerVector::ConvertToCutData( integer_trail->LevelZeroUpperBound(var))); } } else { - const int size = dense_vector_.size(); - for (glop::ColIndex col(0); col < size; ++col) { - const IntegerValue coeff = dense_vector_[col]; + for (int col(0); col < dense_vector.size(); ++col) { + const IntegerValue coeff = dense_vector[col]; if (coeff == 0) continue; - const IntegerVariable var = integer_variables[col.value()]; - CHECK(result->AppendOneTerm(var, coeff, lp_solution[col.value()], + const IntegerVariable var = integer_variables[col]; + CHECK(result->AppendOneTerm(var, coeff, lp_solution[col], integer_trail->LevelZeroLowerBound(var), integer_trail->LevelZeroUpperBound(var))); } @@ -269,7 +282,8 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( implied_bounds_processor_({}, integer_trail_, model->GetOrCreate()), dispatcher_(model->GetOrCreate()), - expanded_lp_solution_(*model->GetOrCreate()) { + expanded_lp_solution_(*model->GetOrCreate()), + expanded_reduced_costs_(*model->GetOrCreate()) { // Tweak the default parameters to make the solve incremental. simplex_params_.set_use_dual_simplex(true); simplex_params_.set_cost_scaling(glop::GlopParameters::MEAN_COST_SCALING); @@ -314,6 +328,9 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( if (max_index >= expanded_lp_solution_.size()) { expanded_lp_solution_.assign(max_index + 1, 0.0); } + if (max_index >= expanded_reduced_costs_.size()) { + expanded_reduced_costs_.assign(max_index + 1, 0.0); + } } } @@ -718,33 +735,30 @@ bool LinearProgrammingConstraint::SolveLp() { } lp_at_optimal_ = simplex_.GetProblemStatus() == glop::ProblemStatus::OPTIMAL; - if (simplex_.GetProblemStatus() == glop::ProblemStatus::OPTIMAL) { + // If stop_after_root_propagation() is true, we still copy whatever we have as + // these values will be used for the local-branching lns heuristic. 
+ if (simplex_.GetProblemStatus() == glop::ProblemStatus::OPTIMAL || + parameters_.stop_after_root_propagation()) { lp_solution_is_set_ = true; lp_solution_level_ = trail_->CurrentDecisionLevel(); const int num_vars = integer_variables_.size(); + const auto reduced_costs = simplex_.GetReducedCosts().const_view(); for (int i = 0; i < num_vars; i++) { - const glop::Fractional value = - GetVariableValueAtCpScale(glop::ColIndex(i)); + const glop::ColIndex col(i); + const glop::Fractional value = GetVariableValueAtCpScale(col); lp_solution_[i] = value; expanded_lp_solution_[integer_variables_[i]] = value; expanded_lp_solution_[NegationOf(integer_variables_[i])] = -value; + + const glop::Fractional rc = + scaler_.UnscaleReducedCost(col, reduced_costs[col]); + expanded_reduced_costs_[integer_variables_[i]] = rc; + expanded_reduced_costs_[NegationOf(integer_variables_[i])] = -rc; } if (lp_solution_level_ == 0) { level_zero_lp_solution_ = lp_solution_; } - } else { - // If this parameter is true, we still copy whatever we have as these - // values will be used for the local-branching lns heuristic. 
- if (parameters_.stop_after_root_propagation()) { - const int num_vars = integer_variables_.size(); - for (int i = 0; i < num_vars; i++) { - const glop::Fractional value = - GetVariableValueAtCpScale(glop::ColIndex(i)); - expanded_lp_solution_[integer_variables_[i]] = value; - expanded_lp_solution_[NegationOf(integer_variables_[i])] = -value; - } - } } return true; @@ -1232,7 +1246,8 @@ bool LinearProgrammingConstraint::PostprocessAndAddCut( const int slack_index = (var.value() - first_slack.value()) / 2; const glop::RowIndex row = tmp_slack_rows_[slack_index]; if (!tmp_scattered_vector_.AddLinearExpressionMultiple( - coeff, IntegerLpRowCols(row), IntegerLpRowCoeffs(row))) { + coeff, IntegerLpRowCols(row), IntegerLpRowCoeffs(row), + infinity_norms_[row])) { VLOG(2) << "Overflow in slack removal"; ++num_cut_overflows_; return false; @@ -1452,8 +1467,6 @@ void LinearProgrammingConstraint::AddMirCuts() { const int num_rows = lp_data_.num_constraints().value(); std::vector> base_rows; util_intops::StrongVector row_weights(num_rows, 0.0); - util_intops::StrongVector at_ub(num_rows, false); - util_intops::StrongVector at_lb(num_rows, false); for (RowIndex row(0); row < num_rows; ++row) { // We only consider tight rows. // We use both the status and activity to have as much options as possible. 
@@ -1466,13 +1479,11 @@ void LinearProgrammingConstraint::AddMirCuts() { if (activity > lp_data_.constraint_upper_bounds()[row] - 1e-4 || status == glop::ConstraintStatus::AT_UPPER_BOUND || status == glop::ConstraintStatus::FIXED_VALUE) { - at_ub[row] = true; base_rows.push_back({row, IntegerValue(1)}); } if (activity < lp_data_.constraint_lower_bounds()[row] + 1e-4 || status == glop::ConstraintStatus::AT_LOWER_BOUND || status == glop::ConstraintStatus::FIXED_VALUE) { - at_lb[row] = true; base_rows.push_back({row, IntegerValue(-1)}); } @@ -1590,16 +1601,20 @@ void LinearProgrammingConstraint::AddMirCuts() { if (used_rows[row]) continue; used_rows[row] = true; - // We only consider "tight" rows, as defined above. + // Note that we consider all rows here, not only tight one. This makes a + // big difference on problem like blp-ic98.pb.gz. We can also use the + // integrality of the slack when adding a non-tight row to derive good + // cuts. Also, non-tight row will have a low weight, so they should + // still be chosen after the tight-one in most situation. 
bool add_row = false; - if (at_ub[row]) { + if (!integer_lp_[row].ub_is_trivial) { if (entry.coefficient() > 0.0) { if (dense_cut[var_to_eliminate] < 0) add_row = true; } else { if (dense_cut[var_to_eliminate] > 0) add_row = true; } } - if (at_lb[row]) { + if (!integer_lp_[row].lb_is_trivial) { if (entry.coefficient() > 0.0) { if (dense_cut[var_to_eliminate] > 0) add_row = true; } else { @@ -1919,7 +1934,7 @@ bool LinearProgrammingConstraint::ScalingCanOverflow( std::vector> LinearProgrammingConstraint::ScaleLpMultiplier( bool take_objective_into_account, bool ignore_trivial_constraints, - const std::vector>& lp_multipliers, + absl::Span> lp_multipliers, IntegerValue* scaling, int64_t overflow_cap) const { *scaling = 0; @@ -2014,11 +2029,12 @@ bool LinearProgrammingConstraint::ComputeNewLinearConstraint( for (const std::pair& term : integer_multipliers) { const RowIndex row = term.first; const IntegerValue multiplier = term.second; - CHECK_LT(row, integer_lp_.size()); + DCHECK_LT(row, integer_lp_.size()); // Update the constraint. if (!scattered_vector->AddLinearExpressionMultiple( - multiplier, IntegerLpRowCols(row), IntegerLpRowCoeffs(row))) { + multiplier, IntegerLpRowCols(row), IntegerLpRowCoeffs(row), + infinity_norms_[row])) { return false; } @@ -2184,13 +2200,11 @@ void LinearProgrammingConstraint::AdjustNewLinearConstraint( if (to_add != 0) { term.second += to_add; *upper_bound += to_add * row_bound; - - // TODO(user): we could avoid checking overflow here, but this is likely - // not in the hot loop. 
adjusted = true; CHECK(scattered_vector ->AddLinearExpressionMultiple( - to_add, IntegerLpRowCols(row), IntegerLpRowCoeffs(row))); + to_add, IntegerLpRowCols(row), IntegerLpRowCoeffs(row), + infinity_norms_[row])); } } if (adjusted) ++num_adjusts_; @@ -2310,7 +2324,7 @@ bool LinearProgrammingConstraint::PropagateExactLpReason() { } CHECK(tmp_scattered_vector_ .AddLinearExpressionMultiple( - obj_scale, tmp_cols_, tmp_coeffs_)); + obj_scale, tmp_cols_, tmp_coeffs_, objective_infinity_norm_)); CHECK(AddProductTo(-obj_scale, integer_objective_offset_, &rc_ub)); extra_term = {objective_cp_, -obj_scale}; diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index cc6e5c2cc5f..ae988777085 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -73,7 +73,8 @@ class ScatteredIntegerVector { template bool AddLinearExpressionMultiple(IntegerValue multiplier, absl::Span cols, - absl::Span coeffs); + absl::Span coeffs, + IntegerValue max_coeff_magnitude); // This is not const only because non_zeros is sorted. Note that sorting the // non-zeros make the result deterministic whether or not we were in sparse @@ -329,7 +330,7 @@ class LinearProgrammingConstraint : public PropagatorInterface, // will still be exact as it will work for any set of multiplier. std::vector> ScaleLpMultiplier( bool take_objective_into_account, bool ignore_trivial_constraints, - const std::vector>& lp_multipliers, + absl::Span> lp_multipliers, IntegerValue* scaling, int64_t overflow_cap = std::numeric_limits::max()) const; @@ -568,6 +569,7 @@ class LinearProgrammingConstraint : public PropagatorInterface, // Same as lp_solution_ but this vector is indexed by IntegerVariable. ModelLpValues& expanded_lp_solution_; + ModelReducedCosts& expanded_reduced_costs_; // Linear constraints cannot be created or modified after this is registered. 
bool lp_constraint_is_registered_ = false; diff --git a/ortools/sat/linear_propagation_test.cc b/ortools/sat/linear_propagation_test.cc new file mode 100644 index 00000000000..aeda248a414 --- /dev/null +++ b/ortools/sat/linear_propagation_test.cc @@ -0,0 +1,321 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/linear_propagation.h" + +#include + +#include +#include + +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; + +TEST(EnforcementPropagatorTest, BasicTest) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + const EnforcementId id1 = propag->Register(Literals({+1})); + const EnforcementId id2 = propag->Register(Literals({+1, +2})); + const EnforcementId id3 = propag->Register(Literals({-2})); + + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CANNOT_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); 
+ + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+2)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::IS_FALSE); + + CHECK(sat_solver->ResetToLevelZero()); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CANNOT_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); +} + +TEST(EnforcementPropagatorTest, UntrailWork) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + const EnforcementId id1 = propag->Register(Literals({+1})); + const EnforcementId id2 = propag->Register(Literals({+2})); + const EnforcementId id3 = propag->Register(Literals({+3})); + + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+2)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), 
EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + const int level = sat_solver->CurrentDecisionLevel(); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+3)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::IS_ENFORCED); + + sat_solver->Backtrack(level); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); +} + +TEST(EnforcementPropagatorTest, AddingAtPositiveLevelTrue) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + EXPECT_TRUE(propag->Propagate(trail)); + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); + EXPECT_TRUE(propag->Propagate(trail)); + + const EnforcementId id = propag->Register(std::vector{+1}); + EXPECT_EQ(propag->Status(id), EnforcementStatus::IS_ENFORCED); + + sat_solver->Backtrack(0); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id), EnforcementStatus::CAN_PROPAGATE); +} + +TEST(EnforcementPropagatorTest, AddingAtPositiveLevelFalse) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + EXPECT_TRUE(propag->Propagate(trail)); + sat_solver->EnqueueDecisionIfNotConflicting(Literal(-1)); + EXPECT_TRUE(propag->Propagate(trail)); + + const EnforcementId id = propag->Register(std::vector{+1}); + EXPECT_EQ(propag->Status(id), EnforcementStatus::IS_FALSE); + + sat_solver->Backtrack(0); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id), EnforcementStatus::CAN_PROPAGATE); +} + +// TEST 
copied from integer_expr test with little modif to use the new propag. +IntegerVariable AddWeightedSum(const absl::Span vars, + const absl::Span coeffs, + Model* model) { + IntegerVariable sum = model->Add(NewIntegerVariable(-10000, 10000)); + std::vector c; + std::vector v; + for (int i = 0; i < coeffs.size(); ++i) { + c.push_back(IntegerValue(coeffs[i])); + v.push_back(vars[i]); + } + c.push_back(IntegerValue(-1)); + v.push_back(sum); + + // <= sum + auto* propag = model->GetOrCreate(); + propag->AddConstraint({}, v, c, IntegerValue(0)); + + // >= sum + for (IntegerValue& ref : c) ref = -ref; + propag->AddConstraint({}, v, c, IntegerValue(0)); + + return sum; +} + +void AddWeightedSumLowerOrEqual(const absl::Span vars, + const absl::Span coeffs, int64_t rhs, + Model* model) { + std::vector c; + std::vector v; + for (int i = 0; i < coeffs.size(); ++i) { + c.push_back(IntegerValue(coeffs[i])); + v.push_back(vars[i]); + } + auto* propag = model->GetOrCreate(); + propag->AddConstraint({}, v, c, IntegerValue(rhs)); +} + +void AddWeightedSumLowerOrEqualReified( + Literal equiv, const absl::Span vars, + const absl::Span coeffs, int64_t rhs, Model* model) { + std::vector c; + std::vector v; + for (int i = 0; i < coeffs.size(); ++i) { + c.push_back(IntegerValue(coeffs[i])); + v.push_back(vars[i]); + } + auto* propag = model->GetOrCreate(); + propag->AddConstraint({equiv}, v, c, IntegerValue(rhs)); + + for (IntegerValue& ref : c) ref = -ref; + propag->AddConstraint({equiv.Negated()}, v, c, IntegerValue(-rhs) - 1); +} + +// A simple macro to make the code more readable. 
+#define EXPECT_BOUNDS_EQ(var, lb, ub) \ + EXPECT_EQ(model.Get(LowerBound(var)), lb); \ + EXPECT_EQ(model.Get(UpperBound(var)), ub) + +TEST(WeightedSumTest, LevelZeroPropagation) { + Model model; + std::vector vars{model.Add(NewIntegerVariable(4, 9)), + model.Add(NewIntegerVariable(-7, -2)), + model.Add(NewIntegerVariable(3, 8))}; + + const IntegerVariable sum = AddWeightedSum(vars, {1, -2, 3}, &model); + EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); + EXPECT_EQ(model.Get(LowerBound(sum)), 4 + 2 * 2 + 3 * 3); + EXPECT_EQ(model.Get(UpperBound(sum)), 9 + 2 * 7 + 3 * 8); + + // Setting this leave only a slack of 2. + model.Add(LowerOrEqual(sum, 19)); + EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); + EXPECT_BOUNDS_EQ(vars[0], 4, 6); // coeff = 1, slack = 2 + EXPECT_BOUNDS_EQ(vars[1], -3, -2); // coeff = 2, slack = 1 + EXPECT_BOUNDS_EQ(vars[2], 3, 3); // coeff = 3, slack = 0 +} + +// This one used to fail before CL 139204507. +TEST(WeightedSumTest, LevelZeroPropagationWithNegativeNumbers) { + Model model; + std::vector vars{model.Add(NewIntegerVariable(-5, 0)), + model.Add(NewIntegerVariable(-6, 0)), + model.Add(NewIntegerVariable(-4, 0))}; + + const IntegerVariable sum = AddWeightedSum(vars, {3, 3, 3}, &model); + EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); + EXPECT_EQ(model.Get(LowerBound(sum)), -15 * 3); + EXPECT_EQ(model.Get(UpperBound(sum)), 0); + + // Setting this leave only a slack of 5 which is not an exact multiple of 3. 
+ model.Add(LowerOrEqual(sum, -40)); + EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); + EXPECT_BOUNDS_EQ(vars[0], -5, -4); + EXPECT_BOUNDS_EQ(vars[1], -6, -5); + EXPECT_BOUNDS_EQ(vars[2], -4, -3); +} + +TEST(WeightedSumLowerOrEqualTest, UnaryRounding) { + Model model; + IntegerVariable var = model.Add(NewIntegerVariable(0, 10)); + const std::vector coeffs = {-100}; + AddWeightedSumLowerOrEqual({var}, coeffs, -320, &model); + EXPECT_EQ(SatSolver::FEASIBLE, model.GetOrCreate()->Solve()); + EXPECT_EQ(model.Get(LowerBound(var)), 4); +} + +TEST(ReifiedWeightedSumLeTest, ReifToBoundPropagation) { + Model model; + const Literal r = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var = model.Add(NewIntegerVariable(4, 9)); + AddWeightedSumLowerOrEqualReified(r, {var}, {1}, 6, &model); + EXPECT_EQ( + SatSolver::FEASIBLE, + model.GetOrCreate()->ResetAndSolveWithGivenAssumptions({r})); + EXPECT_BOUNDS_EQ(var, 4, 6); + EXPECT_EQ(SatSolver::FEASIBLE, + model.GetOrCreate()->ResetAndSolveWithGivenAssumptions( + {r.Negated()})); + EXPECT_BOUNDS_EQ(var, 7, 9); // The associated literal (x <= 6) is false. +} + +TEST(ReifiedWeightedSumLeTest, ReifToBoundPropagationWithNegatedCoeff) { + Model model; + const Literal r = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var = model.Add(NewIntegerVariable(-9, 9)); + AddWeightedSumLowerOrEqualReified(r, {var}, {-3}, 7, &model); + EXPECT_EQ( + SatSolver::FEASIBLE, + model.GetOrCreate()->ResetAndSolveWithGivenAssumptions({r})); + EXPECT_BOUNDS_EQ(var, -2, 9); + EXPECT_EQ(SatSolver::FEASIBLE, + model.GetOrCreate()->ResetAndSolveWithGivenAssumptions( + {r.Negated()})); + EXPECT_BOUNDS_EQ(var, -9, -3); // The associated literal (x >= -2) is false. 
+} + +TEST(ReifiedWeightedSumGeTest, ReifToBoundPropagation) { + Model model; + const Literal r = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var = model.Add(NewIntegerVariable(4, 9)); + AddWeightedSumLowerOrEqualReified(r, {var}, {-1}, -6, &model); + EXPECT_EQ( + SatSolver::FEASIBLE, + model.GetOrCreate()->ResetAndSolveWithGivenAssumptions({r})); + EXPECT_BOUNDS_EQ(var, 6, 9); + EXPECT_EQ(SatSolver::FEASIBLE, + model.GetOrCreate()->ResetAndSolveWithGivenAssumptions( + {r.Negated()})); + EXPECT_BOUNDS_EQ(var, 4, 5); +} + +TEST(ReifiedWeightedSumTest, BoundToReifTrueLe) { + Model model; + const Literal r = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var = model.Add(NewIntegerVariable(4, 9)); + AddWeightedSumLowerOrEqualReified(r, {var}, {1}, 9, &model); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.Get(Value(r))); +} + +TEST(ReifiedWeightedSumTest, BoundToReifFalseLe) { + Model model; + const Literal r = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var = model.Add(NewIntegerVariable(4, 9)); + AddWeightedSumLowerOrEqualReified(r, {var}, {1}, 3, &model); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.Get(Value(r))); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/model_test.cc b/ortools/sat/model_test.cc new file mode 100644 index 00000000000..b3862ff5d48 --- /dev/null +++ b/ortools/sat/model_test.cc @@ -0,0 +1,92 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/model.h" + +#include + +#include "gtest/gtest.h" + +namespace operations_research { +namespace sat { +namespace { + +struct A { + A() = default; + explicit A(Model* model) {} + std::string name; +}; + +class B { + public: + explicit B(A* a) : a_(a) {} + explicit B(Model* model) : a_(model->GetOrCreate()) {} + + std::string name() const { return a_->name; } + + private: + A* a_; +}; + +TEST(ModelTest, RecursiveCreationTest) { + Model model; + B* b = model.GetOrCreate(); + model.GetOrCreate()->name = "test"; + EXPECT_EQ("test", b->name()); +} + +struct C1 { + C1() = default; +}; +struct C2 { + explicit C2(Model* model) {} +}; +struct C3 { + C3() : name("no_arg") {} + explicit C3(Model*) : name("model") {} + std::string name; +}; + +TEST(ModelTest, DefaultConstructorFallback) { + Model model; + model.GetOrCreate(); + model.GetOrCreate(); + EXPECT_EQ(model.GetOrCreate()->name, "model"); +} + +TEST(ModelTest, Register) { + Model model; + C3 c3; + c3.name = "Shared struct"; + model.Register(&c3); + EXPECT_EQ(model.GetOrCreate()->name, c3.name); +} + +TEST(ModelTest, RegisterDeathTest) { + Model model; + C3 c3; + model.Register(&c3); + C3 c3_2; + EXPECT_DEATH(model.Register(&c3_2), ""); +} + +TEST(ModelTest, RegisterDeathTest2) { + Model model; + model.GetOrCreate(); + C3 c3; + EXPECT_DEATH(model.Register(&c3), ""); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/optimization_test.cc b/ortools/sat/optimization_test.cc new file mode 100644 index 00000000000..0401dc74fa3 --- 
/dev/null +++ b/ortools/sat/optimization_test.cc @@ -0,0 +1,172 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/optimization.h" + +#include + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/strings/str_format.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/boolean_problem.h" +#include "ortools/sat/boolean_problem.pb.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/pb_constraint.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; + +// Test the lazy encoding logic on a trivial problem. 
+TEST(MinimizeIntegerVariableWithLinearScanAndLazyEncodingTest, BasicProblem) { + Model model; + IntegerVariable var = model.Add(NewIntegerVariable(-5, 10)); + model.GetOrCreate()->fixed_search = + FirstUnassignedVarAtItsMinHeuristic({var}, &model); + ConfigureSearchHeuristics(&model); + int num_feasible_solution = 0; + SatSolver::Status status = + MinimizeIntegerVariableWithLinearScanAndLazyEncoding( + var, + /*feasible_solution_observer=*/ + [var, &num_feasible_solution, &model]() { + ++num_feasible_solution; + EXPECT_EQ(model.Get(Value(var)), -5); + }, + + &model); + EXPECT_EQ(num_feasible_solution, 1); + EXPECT_EQ(status, SatSolver::Status::INFEASIBLE); // Search done. +} + +TEST(MinimizeIntegerVariableWithLinearScanAndLazyEncodingTest, + BasicProblemWithSolutionLimit) { + Model model; + SatParameters* parameters = model.GetOrCreate(); + parameters->set_stop_after_first_solution(true); + IntegerVariable var = model.Add(NewIntegerVariable(-5, 10)); + model.GetOrCreate()->fixed_search = + FirstUnassignedVarAtItsMinHeuristic({var}, &model); + ConfigureSearchHeuristics(&model); + + SatSolver::Status status = + MinimizeIntegerVariableWithLinearScanAndLazyEncoding( + var, + /*feasible_solution_observer=*/ + [var, &model]() { EXPECT_EQ(model.Get(Value(var)), -5); }, &model); + EXPECT_EQ(status, SatSolver::Status::LIMIT_REACHED); +} + +TEST(MinimizeIntegerVariableWithLinearScanAndLazyEncodingTest, + BasicProblemWithBadHeuristic) { + Model model; + IntegerVariable var = model.Add(NewIntegerVariable(-5, 10)); + int expected_value = 10; + int num_feasible_solution = 0; + + model.GetOrCreate()->fixed_search = + FirstUnassignedVarAtItsMinHeuristic({NegationOf(var)}, &model); + ConfigureSearchHeuristics(&model); + + SatSolver::Status status = + MinimizeIntegerVariableWithLinearScanAndLazyEncoding( + var, + /*feasible_solution_observer=*/ + [&]() { + ++num_feasible_solution; + EXPECT_EQ(model.Get(Value(var)), expected_value--); + }, + &model); + 
EXPECT_EQ(num_feasible_solution, 16); + EXPECT_EQ(status, SatSolver::Status::INFEASIBLE); // Search done. +} + +// TODO(user): The core find the best solution right away here, so it doesn't +// really exercise the solution limit... +TEST(MinimizeWithCoreAndLazyEncodingTest, BasicProblemWithSolutionLimit) { + Model model; + SatParameters* parameters = model.GetOrCreate(); + parameters->set_stop_after_first_solution(true); + IntegerVariable var = model.Add(NewIntegerVariable(-5, 10)); + std::vector vars = {var}; + std::vector coeffs = {IntegerValue(1)}; + + model.GetOrCreate()->fixed_search = + FirstUnassignedVarAtItsMinHeuristic({var}, &model); + ConfigureSearchHeuristics(&model); + + int num_solutions = 0; + CoreBasedOptimizer core( + var, vars, coeffs, + /*feasible_solution_observer=*/ + [var, &model, &num_solutions]() { + ++num_solutions; + EXPECT_EQ(model.Get(Value(var)), -5); + }, + &model); + SatSolver::Status status = core.Optimize(); + EXPECT_EQ(status, SatSolver::Status::INFEASIBLE); // i.e. optimal. 
+ EXPECT_EQ(1, num_solutions); +} + +TEST(PresolveBooleanLinearExpressionTest, NegateCoeff) { + Coefficient offset(0); + std::vector literals = Literals({+1}); + std::vector coefficients = {Coefficient(-3)}; + PresolveBooleanLinearExpression(&literals, &coefficients, &offset); + EXPECT_THAT(literals, ElementsAre(Literal(-1))); + EXPECT_THAT(coefficients, ElementsAre(Coefficient(3))); + EXPECT_EQ(offset, -3); +} + +TEST(PresolveBooleanLinearExpressionTest, Duplicate) { + Coefficient offset(0); + std::vector literals = Literals({+1, -4, +1}); + std::vector coefficients = {Coefficient(-3), Coefficient(7), + Coefficient(5)}; + PresolveBooleanLinearExpression(&literals, &coefficients, &offset); + EXPECT_THAT(literals, ElementsAre(Literal(+1), Literal(-4))); + EXPECT_THAT(coefficients, ElementsAre(Coefficient(2), Coefficient(7))); + EXPECT_EQ(offset, 0); +} + +TEST(PresolveBooleanLinearExpressionTest, NegatedLiterals) { + Coefficient offset(0); + std::vector literals = Literals({+1, -4, -1}); + std::vector coefficients = {Coefficient(-3), Coefficient(7), + Coefficient(-5)}; + PresolveBooleanLinearExpression(&literals, &coefficients, &offset); + EXPECT_THAT(literals, ElementsAre(Literal(+1), Literal(-4))); + EXPECT_THAT(coefficients, ElementsAre(Coefficient(2), Coefficient(7))); + EXPECT_EQ(offset, -5); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/parameters_validation_test.cc b/ortools/sat/parameters_validation_test.cc new file mode 100644 index 00000000000..3c039f3e8ca --- /dev/null +++ b/ortools/sat/parameters_validation_test.cc @@ -0,0 +1,125 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/parameters_validation.h" + +#include +#include + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/sat_parameters.pb.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::HasSubstr; +using ::testing::IsEmpty; + +TEST(ValidateParameters, MaxTimeInSeconds) { + SatParameters params; + params.set_max_time_in_seconds(-1); + EXPECT_THAT(ValidateParameters(params), HasSubstr("non-negative")); +} + +TEST(ValidateParameters, ParametersInRange) { + SatParameters params; + params.set_mip_max_bound(-1); + EXPECT_THAT(ValidateParameters(params), + HasSubstr("'mip_max_bound' should be in")); +} + +TEST(ValidateParameters, NumWorkers) { + SatParameters params; + params.set_num_workers(-1); + EXPECT_THAT(ValidateParameters(params), HasSubstr("should be in [0,10000]")); +} + +TEST(ValidateParameters, NumSearchWorkers) { + SatParameters params; + params.set_num_search_workers(-1); + EXPECT_THAT(ValidateParameters(params), HasSubstr("should be in [0,10000]")); +} + +TEST(ValidateParameters, LinearizationLevel) { + SatParameters params; + params.set_linearization_level(-1); + EXPECT_THAT(ValidateParameters(params), HasSubstr("non-negative")); +} + +TEST(ValidateParameters, NumSharedTreeSearchWorkers) { + SatParameters params; + params.set_shared_tree_num_workers(-1); + EXPECT_THAT(ValidateParameters(params), HasSubstr("should be in [0,10000]")); +} + +TEST(ValidateParameters, 
SharedTreeSearchMaxNodesPerWorker) { + SatParameters params; + params.set_shared_tree_max_nodes_per_worker(0); + EXPECT_THAT(ValidateParameters(params), HasSubstr("positive")); +} + +TEST(ValidateParameters, SharedTreeSearchOpenLeavesPerWorker) { + SatParameters params; + params.set_shared_tree_open_leaves_per_worker(0.0); + EXPECT_THAT(ValidateParameters(params), HasSubstr("should be in [1,10000]")); +} + +TEST(ValidateParameters, UseSharedTreeSearch) { + SatParameters params; + params.set_use_shared_tree_search(true); + EXPECT_THAT(ValidateParameters(params), HasSubstr("only be set on workers")); +} + +TEST(ValidateParameters, NaNs) { + const google::protobuf::Descriptor& descriptor = *SatParameters::descriptor(); + const google::protobuf::Reflection& reflection = + *SatParameters::GetReflection(); + for (int i = 0; i < descriptor.field_count(); ++i) { + const google::protobuf::FieldDescriptor* const field = descriptor.field(i); + SCOPED_TRACE(field->name()); + + SatParameters params; + switch (field->type()) { + case google::protobuf::FieldDescriptor::TYPE_DOUBLE: + reflection.SetDouble(¶ms, field, + std::numeric_limits::quiet_NaN()); + break; + case google::protobuf::FieldDescriptor::TYPE_FLOAT: + reflection.SetFloat(¶ms, field, + std::numeric_limits::quiet_NaN()); + break; + default: + continue; + } + + EXPECT_THAT(ValidateParameters(params), + AllOf(HasSubstr(field->name()), HasSubstr("NaN"))); + } +} + +TEST(ValidateParameters, ValidateSubsolvers) { + SatParameters params; + params.add_extra_subsolvers("not_defined"); + EXPECT_THAT(ValidateParameters(params), HasSubstr("is not valid")); + + params.add_subsolver_params()->set_name("not_defined"); + EXPECT_THAT(ValidateParameters(params), IsEmpty()); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/pb_constraint_test.cc b/ortools/sat/pb_constraint_test.cc new file mode 100644 index 00000000000..5b3d7d91bee --- /dev/null +++ 
b/ortools/sat/pb_constraint_test.cc @@ -0,0 +1,673 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/pb_constraint.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/strong_vector.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ContainerEq; + +template +auto LiteralsAre(Args... 
literals) { + return ::testing::ElementsAre(Literal(literals)...); +} + +std::vector MakePb( + absl::Span> input) { + std::vector result; + result.reserve(input.size()); + for (const auto p : input) { + result.push_back({Literal(p.first), p.second}); + } + return result; +} + +TEST(ComputeBooleanLinearExpressionCanonicalForm, RemoveZeroCoefficient) { + Coefficient bound_shift, max_value; + auto cst = MakePb({{+1, 4}, {+2, 0}, {+3, 4}, {+5, 0}}); + const auto result = MakePb({{+1, 4}, {+3, 4}}); + EXPECT_TRUE(ComputeBooleanLinearExpressionCanonicalForm(&cst, &bound_shift, + &max_value)); + EXPECT_THAT(cst, ContainerEq(result)); + EXPECT_EQ(bound_shift, 0); + EXPECT_EQ(max_value, 8); +} + +TEST(ComputeBooleanLinearExpressionCanonicalForm, MakeAllCoefficientPositive) { + Coefficient bound_shift, max_value; + auto cst = MakePb({{+1, 4}, {+2, -3}, {+3, 4}, {+5, -1}}); + const auto result = MakePb({{-5, 1}, {-2, 3}, {+1, 4}, {+3, 4}}); + EXPECT_TRUE(ComputeBooleanLinearExpressionCanonicalForm(&cst, &bound_shift, + &max_value)); + EXPECT_THAT(cst, ContainerEq(result)); + EXPECT_EQ(bound_shift, 4); + EXPECT_EQ(max_value, 12); +} + +TEST(ComputeBooleanLinearExpressionCanonicalForm, MergeSameVariableCase1) { + Coefficient bound_shift, max_value; + // 4x -3(1-x) +4(1-x) -x is the same as to 2x + 1 + auto cst = MakePb({{+1, 4}, {-1, -3}, {-1, 4}, {+1, -1}}); + const auto result = MakePb({{+1, 2}}); + EXPECT_TRUE(ComputeBooleanLinearExpressionCanonicalForm(&cst, &bound_shift, + &max_value)); + EXPECT_THAT(cst, ContainerEq(result)); + EXPECT_EQ(bound_shift, -1); + EXPECT_EQ(max_value, 2); +} + +TEST(ComputeBooleanLinearExpressionCanonicalForm, MergeSameVariableCase2) { + Coefficient bound_shift, max_value; + // 4x -3(1-x) +4(1-x) -5x is the same as to -2x + 1 + // which is expressed as 2(1-x) -2 +1 + auto cst = MakePb({{+1, 4}, {-1, -3}, {-1, 4}, {+1, -5}}); + const auto result = MakePb({{-1, 2}}); + EXPECT_TRUE(ComputeBooleanLinearExpressionCanonicalForm(&cst, &bound_shift, + 
&max_value)); + EXPECT_THAT(cst, ContainerEq(result)); + EXPECT_EQ(bound_shift, 1); + EXPECT_EQ(max_value, 2); +} + +TEST(ComputeBooleanLinearExpressionCanonicalForm, MergeSameVariableCase3) { + Coefficient bound_shift, max_value; + // Here the last variable will disappear completely + auto cst = MakePb({{+1, 4}, {+2, -3}, {+2, 4}, {+2, -1}}); + const auto result = MakePb({{+1, 4}}); + EXPECT_TRUE(ComputeBooleanLinearExpressionCanonicalForm(&cst, &bound_shift, + &max_value)); + EXPECT_THAT(cst, ContainerEq(result)); + EXPECT_EQ(bound_shift, 0); + EXPECT_EQ(max_value, 4); +} + +TEST(ComputeBooleanLinearExpressionCanonicalForm, Overflow) { + Coefficient bound_shift, max_value; + auto cst = MakePb({{+1, -kCoefficientMax}, {+2, -kCoefficientMax}}); + EXPECT_FALSE(ComputeBooleanLinearExpressionCanonicalForm(&cst, &bound_shift, + &max_value)); +} + +TEST(ComputeBooleanLinearExpressionCanonicalForm, BigIntCase) { + Coefficient bound_shift, max_value; + auto cst = MakePb({{+1, -kCoefficientMax}, {-1, -kCoefficientMax}}); + const auto result = MakePb({}); + EXPECT_TRUE(ComputeBooleanLinearExpressionCanonicalForm(&cst, &bound_shift, + &max_value)); + EXPECT_THAT(cst, ContainerEq(result)); + EXPECT_EQ(bound_shift, kCoefficientMax); + EXPECT_EQ(max_value, 0); +} + +TEST(ApplyLiteralMappingTest, BasicTest) { + Coefficient bound_shift, max_value; + + // This is needed to initizalize the ITIVector below. + std::vector temp{ + kTrueLiteralIndex, kFalseLiteralIndex, // var1 fixed to true. 
+ Literal(-1).Index(), Literal(+1).Index(), // var2 mapped to not(var1) + Literal(+2).Index(), Literal(-2).Index(), // var3 mapped to var2 + kFalseLiteralIndex, kTrueLiteralIndex, // var4 fixed to false + Literal(+2).Index(), Literal(-2).Index()}; // var5 mapped to var2 + util_intops::StrongVector mapping(temp.begin(), + temp.end()); + + auto cst = MakePb({{+1, 4}, {+3, -3}, {+2, 4}, {+4, 7}, {+5, 5}}); + EXPECT_TRUE(ApplyLiteralMapping(mapping, &cst, &bound_shift, &max_value)); + const auto result = MakePb({{+2, 2}, {-1, 4}}); + EXPECT_THAT(cst, ContainerEq(result)); + EXPECT_EQ(bound_shift, -4); + EXPECT_EQ(max_value, 6); +} + +TEST(SimplifyCanonicalBooleanLinearConstraint, CoefficientsLargerThanRhs) { + auto cst = MakePb({{+1, 4}, {+2, 5}, {+3, 6}, {-4, 7}}); + Coefficient rhs(10); + SimplifyCanonicalBooleanLinearConstraint(&cst, &rhs); + EXPECT_THAT(cst, ContainerEq(cst)); + rhs = Coefficient(5); + SimplifyCanonicalBooleanLinearConstraint(&cst, &rhs); + const auto result = MakePb({{+1, 4}, {+2, 5}, {+3, 6}, {-4, 6}}); + EXPECT_THAT(cst, ContainerEq(result)); +} + +TEST(CanonicalBooleanLinearProblem, BasicTest) { + auto cst = MakePb({{+1, 4}, {+2, -5}, {+3, 6}, {-4, 7}}); + CanonicalBooleanLinearProblem problem; + problem.AddLinearConstraint(true, Coefficient(-5), true, Coefficient(5), + &cst); + + // We have just one constraint because the >= -5 is always true. + EXPECT_EQ(1, problem.NumConstraints()); + const auto result0 = MakePb({{+1, 4}, {-2, 5}, {+3, 6}, {-4, 7}}); + EXPECT_EQ(problem.Rhs(0), 10); + EXPECT_THAT(problem.Constraint(0), ContainerEq(result0)); + + // So lets restrict it and only use the lower bound + // Note that the API destroy the input so we have to reconstruct it. + cst = MakePb({{+1, 4}, {+2, -5}, {+3, 6}, {-4, 7}}); + problem.AddLinearConstraint(true, Coefficient(-4), false, + /*unused*/ Coefficient(-10), &cst); + + // Now we have another constraint corresponding to the >= -4 constraint. 
+ EXPECT_EQ(2, problem.NumConstraints()); + const auto result1 = MakePb({{-1, 4}, {+2, 5}, {-3, 6}, {+4, 7}}); + EXPECT_EQ(problem.Rhs(1), 21); + EXPECT_THAT(problem.Constraint(1), ContainerEq(result1)); +} + +TEST(CanonicalBooleanLinearProblem, BasicTest2) { + auto cst = MakePb({{+1, 1}, {+2, 2}}); + CanonicalBooleanLinearProblem problem; + problem.AddLinearConstraint(true, Coefficient(2), false, + /*unused*/ Coefficient(0), &cst); + + EXPECT_EQ(1, problem.NumConstraints()); + const auto result = MakePb({{-1, 1}, {-2, 2}}); + EXPECT_EQ(problem.Rhs(0), 1); + EXPECT_THAT(problem.Constraint(0), ContainerEq(result)); +} + +TEST(CanonicalBooleanLinearProblem, OverflowCases) { + auto cst = MakePb({}); + CanonicalBooleanLinearProblem problem; + for (int i = 0; i < 2; ++i) { + std::vector reference; + if (i == 0) { + // This is a constraint with a "bound shift" of 10. + reference = MakePb({{+1, -10}, {+2, 10}}); + } else { + // This is a constraint with a "bound shift" of -10 since its domain value + // is actually [10, 10]. + reference = MakePb({{+1, 10}, {-1, 10}}); + } + + // All These constraint are trivially satisfiables, so no new constraints + // should be added. + cst = reference; + EXPECT_TRUE(problem.AddLinearConstraint(true, -kCoefficientMax, true, + kCoefficientMax, &cst)); + cst = reference; + EXPECT_TRUE(problem.AddLinearConstraint(true, -kCoefficientMax - 1, true, + kCoefficientMax, &cst)); + cst = reference; + EXPECT_TRUE(problem.AddLinearConstraint(true, Coefficient(-10), true, + Coefficient(10), &cst)); + + // These are trivially unsat, and all AddLinearConstraint() should return + // false. 
+ cst = reference; + EXPECT_FALSE(problem.AddLinearConstraint(true, kCoefficientMax, true, + kCoefficientMax, &cst)); + cst = reference; + EXPECT_FALSE(problem.AddLinearConstraint(true, -kCoefficientMax, true, + -kCoefficientMax, &cst)); + cst = reference; + EXPECT_FALSE(problem.AddLinearConstraint( + true, -kCoefficientMax, true, -kCoefficientMax - Coefficient(1), &cst)); + } + + // No constraints were actually added. + EXPECT_EQ(problem.NumConstraints(), 0); +} + +// Constructs a vector from the current trail, so we can use LiteralsAre(). +std::vector TrailToVector(const Trail& trail) { + std::vector output; + for (int i = 0; i < trail.Index(); ++i) output.push_back(trail[i]); + return output; +} + +TEST(UpperBoundedLinearConstraintTest, ConstructionAndBasicPropagation) { + Coefficient threshold; + PbConstraintsEnqueueHelper helper; + helper.reasons.resize(10); + Trail trail; + trail.Resize(10); + + UpperBoundedLinearConstraint cst( + MakePb({{+1, 4}, {+2, 4}, {-3, 5}, {+4, 10}})); + cst.InitializeRhs(Coefficient(7), 0, &threshold, &trail, &helper); + EXPECT_EQ(threshold, 2); + EXPECT_THAT(TrailToVector(trail), LiteralsAre(-4)); + + trail.Enqueue(Literal(-3), AssignmentType::kSearchDecision); + threshold -= 5; // The coeff of -3 in cst. + EXPECT_TRUE(cst.Propagate(trail.Info(Literal(-3).Variable()).trail_index, + &threshold, &trail, &helper)); + EXPECT_EQ(threshold, 2); + EXPECT_THAT(TrailToVector(trail), LiteralsAre(-4, -3, -1, -2)); + + // Untrail. + trail.Untrail(0); + threshold += 5; + cst.Untrail(&threshold, 0); + EXPECT_EQ(threshold, 2); +} + +TEST(UpperBoundedLinearConstraintTest, Conflict) { + Coefficient threshold; + Trail trail; + trail.Resize(10); + PbConstraintsEnqueueHelper helper; + helper.reasons.resize(10); + + // At most one constraint. 
+ UpperBoundedLinearConstraint cst( + MakePb({{+1, 1}, {+2, 1}, {+3, 1}, {+4, 1}})); + cst.InitializeRhs(Coefficient(1), 0, &threshold, &trail, &helper); + EXPECT_EQ(threshold, 0); + + // Two assignment from other part of the solver. + trail.SetDecisionLevel(1); + trail.Enqueue(Literal(+1), AssignmentType::kSearchDecision); + trail.SetDecisionLevel(2); + trail.Enqueue(Literal(+2), AssignmentType::kSearchDecision); + + // We propagate only +1. + threshold -= 1; + EXPECT_FALSE(cst.Propagate(trail.Info(Literal(+1).Variable()).trail_index, + &threshold, &trail, &helper)); + EXPECT_THAT(helper.conflict, LiteralsAre(-1, -2)); +} + +TEST(UpperBoundedLinearConstraintTest, CompactReason) { + Coefficient threshold; + Trail trail; + trail.Resize(10); + PbConstraintsEnqueueHelper helper; + helper.reasons.resize(10); + + // At most one constraint. + UpperBoundedLinearConstraint cst( + MakePb({{+1, 1}, {+2, 2}, {+3, 3}, {+4, 4}})); + cst.InitializeRhs(Coefficient(7), 0, &threshold, &trail, &helper); + EXPECT_EQ(threshold, 3); + + // Two assignment from other part of the solver. + trail.SetDecisionLevel(1); + trail.Enqueue(Literal(+1), AssignmentType::kSearchDecision); + trail.SetDecisionLevel(2); + trail.Enqueue(Literal(+2), AssignmentType::kSearchDecision); + trail.SetDecisionLevel(3); + trail.Enqueue(Literal(+3), AssignmentType::kSearchDecision); + + // We propagate when +3 is processed. + threshold = -3; + const int source_trail_index = trail.Info(Literal(+3).Variable()).trail_index; + EXPECT_TRUE(cst.Propagate(source_trail_index, &threshold, &trail, &helper)); + EXPECT_EQ(trail.Index(), 4); + EXPECT_EQ(trail[3], Literal(-4)); + + // -1 do not need to be in the reason since {-3, -2} propagates exactly + // the same way. 
+ cst.FillReason(trail, source_trail_index, Literal(-4).Variable(), + &helper.conflict); + EXPECT_THAT(helper.conflict, LiteralsAre(-3, -2)); +} + +TEST(PbConstraintsTest, Duplicates) { + Model model; + PbConstraints& csts = *(model.GetOrCreate()); + Trail& trail = *(model.GetOrCreate()); + + trail.Resize(10); + csts.Resize(10); + + CHECK_EQ(csts.NumberOfConstraints(), 0); + csts.AddConstraint(MakePb({{-1, 7}, {-2, 7}, {+3, 7}}), Coefficient(20), + &trail); + csts.AddConstraint(MakePb({{-1, 1}, {-2, 3}, {+3, 7}}), Coefficient(20), + &trail); + CHECK_EQ(csts.NumberOfConstraints(), 2); + + // Adding the same constraints will do nothing. + csts.AddConstraint(MakePb({{-1, 7}, {-2, 7}, {+3, 7}}), Coefficient(20), + &trail); + CHECK_EQ(csts.NumberOfConstraints(), 2); + CHECK_EQ(trail.Index(), 0); + + // Over constraining it will fix the 3 literals. + csts.AddConstraint(MakePb({{-1, 7}, {-2, 7}, {+3, 7}}), Coefficient(6), + &trail); + CHECK_EQ(csts.NumberOfConstraints(), 2); + EXPECT_THAT(TrailToVector(trail), LiteralsAre(+1, +2, -3)); +} + +TEST(PbConstraintsTest, BasicPropagation) { + Model model; + PbConstraints& csts = *(model.GetOrCreate()); + Trail& trail = *(model.GetOrCreate()); + + trail.Resize(10); + trail.SetDecisionLevel(1); + trail.Enqueue(Literal(-1), AssignmentType::kSearchDecision); + + csts.Resize(10); + csts.AddConstraint(MakePb({{-1, 1}, {+2, 1}}), Coefficient(1), &trail); + csts.AddConstraint(MakePb({{-1, 7}, {-2, 7}, {+3, 7}}), Coefficient(20), + &trail); + csts.AddConstraint(MakePb({{-1, 1}, {-2, 1}, {-3, 1}, {+4, 1}}), + Coefficient(3), &trail); + + EXPECT_THAT(TrailToVector(trail), LiteralsAre(-1, -2)); + while (!csts.PropagationIsDone(trail)) EXPECT_TRUE(csts.Propagate(&trail)); + EXPECT_THAT(TrailToVector(trail), LiteralsAre(-1, -2, -3, -4)); + + // Test the reason for each assignment. 
+ EXPECT_THAT(trail.Reason(Literal(-2).Variable()), LiteralsAre(+1)); + EXPECT_THAT(trail.Reason(Literal(-3).Variable()), LiteralsAre(+2, +1)); + EXPECT_THAT(trail.Reason(Literal(-4).Variable()), LiteralsAre(+3, +2, +1)); + + // Untrail, and repropagate everything. + csts.Untrail(trail, 0); + trail.Untrail(0); + trail.Enqueue(Literal(-1), AssignmentType::kSearchDecision); + while (!csts.PropagationIsDone(trail)) EXPECT_TRUE(csts.Propagate(&trail)); + EXPECT_THAT(TrailToVector(trail), LiteralsAre(-1, -2, -3, -4)); +} + +TEST(PbConstraintsTest, BasicDeletion) { + Model model; + PbConstraints& csts = *(model.GetOrCreate()); + Trail& trail = *(model.GetOrCreate()); + + PbConstraintsEnqueueHelper helper; + helper.reasons.resize(10); + trail.Resize(10); + trail.SetDecisionLevel(0); + csts.Resize(10); + csts.AddConstraint(MakePb({{-1, 1}, {+2, 1}}), Coefficient(1), &trail); + csts.AddConstraint(MakePb({{-1, 7}, {-2, 7}, {+3, 7}}), Coefficient(20), + &trail); + csts.AddConstraint(MakePb({{-1, 1}, {-2, 1}, {-3, 1}, {+4, 1}}), + Coefficient(3), &trail); + + // Delete the first constraint. + EXPECT_EQ(3, csts.NumberOfConstraints()); + csts.DeleteConstraint(0); + EXPECT_EQ(2, csts.NumberOfConstraints()); + + // The constraint 1 is deleted, so enqueuing -1 shouldn't propagate. + trail.Enqueue(Literal(-1), AssignmentType::kSearchDecision); + while (!csts.PropagationIsDone(trail)) EXPECT_TRUE(csts.Propagate(&trail)); + EXPECT_EQ("-1", trail.DebugString()); + + // But also enqueing -2 should. + trail.Enqueue(Literal(-2), AssignmentType::kSearchDecision); + while (!csts.PropagationIsDone(trail)) EXPECT_TRUE(csts.Propagate(&trail)); + EXPECT_EQ("-1 -2 -3 -4", trail.DebugString()); + + // Let's bactrack. + trail.Untrail(1); + csts.Untrail(trail, 1); + + // Let's delete one more constraint. + csts.DeleteConstraint(0); + EXPECT_EQ(1, csts.NumberOfConstraints()); + + // Now, if we enqueue -2 again, nothing is propagated. 
+ trail.Enqueue(Literal(-2), AssignmentType::kSearchDecision); + while (!csts.PropagationIsDone(trail)) EXPECT_TRUE(csts.Propagate(&trail)); + EXPECT_EQ("-1 -2", trail.DebugString()); + + // We need to also enqueue -3 for -4 to be propagated. + trail.Enqueue(Literal(-3), AssignmentType::kSearchDecision); + while (!csts.PropagationIsDone(trail)) EXPECT_TRUE(csts.Propagate(&trail)); + EXPECT_EQ("-1 -2 -3 -4", trail.DebugString()); + + // Deleting everything doesn't crash. + csts.DeleteConstraint(0); + EXPECT_EQ(0, csts.NumberOfConstraints()); +} + +TEST(PbConstraintsTest, UnsatAtConstruction) { + Model model; + PbConstraints& csts = *(model.GetOrCreate()); + Trail& trail = *(model.GetOrCreate()); + + trail.Resize(10); + trail.SetDecisionLevel(1); + trail.Enqueue(Literal(+1), AssignmentType::kUnitReason); + trail.Enqueue(Literal(+2), AssignmentType::kUnitReason); + trail.Enqueue(Literal(+3), AssignmentType::kUnitReason); + + csts.Resize(10); + + EXPECT_TRUE( + csts.AddConstraint(MakePb({{+1, 1}, {+2, 1}}), Coefficient(2), &trail)); + while (!csts.PropagationIsDone(trail)) EXPECT_TRUE(csts.Propagate(&trail)); + + // We need to propagate before adding this constraint for the AddConstraint() + // to notice that it is unsat. Otherwise, it will be noticed at propagation + // time. 
+ EXPECT_FALSE(csts.AddConstraint(MakePb({{+1, 1}, {+2, 1}, {+3, 1}}), + Coefficient(2), &trail)); + EXPECT_TRUE(csts.AddConstraint(MakePb({{+1, 1}, {+2, 1}, {+4, 1}}), + Coefficient(2), &trail)); +} + +TEST(PbConstraintsTest, AddConstraintWithLevel0Propagation) { + Model model; + PbConstraints& csts = *(model.GetOrCreate()); + Trail& trail = *(model.GetOrCreate()); + + trail.Resize(10); + trail.SetDecisionLevel(0); + csts.Resize(10); + + EXPECT_TRUE(csts.AddConstraint(MakePb({{+1, 1}, {+2, 3}, {+3, 7}}), + Coefficient(2), &trail)); + EXPECT_EQ(trail.Index(), 2); + EXPECT_EQ(trail[0], Literal(-2)); + EXPECT_EQ(trail[1], Literal(-3)); +} + +TEST(PbConstraintsTest, AddConstraintUMR) { + const auto cst = MakePb({{+3, 7}}); + UpperBoundedLinearConstraint c(cst); + // Calling hashing on c generates an UMR that is triggered during the hash_map + // lookup below. + const uint64_t ct_hash = c.hash(); + absl::flat_hash_map> store; + std::vector& vec = store[ct_hash]; + EXPECT_EQ(vec.size(), 0); +} + +TEST(PbConstraintsDeathTest, AddConstraintWithLevel0PropagationInSearch) { + Model model; + PbConstraints& csts = *(model.GetOrCreate()); + Trail& trail = *(model.GetOrCreate()); + + trail.Resize(10); + trail.SetDecisionLevel(10); + csts.Resize(10); + + // If the decision level is not 0, this will fail. 
+ ASSERT_DEATH(csts.AddConstraint(MakePb({{+1, 1}, {+2, 3}, {+3, 7}}), + Coefficient(2), &trail), + "var should have been propagated at an earlier level."); +} + +TEST(PbConstraintsDeathTest, AddConstraintPrecondition) { + Model model; + PbConstraints& csts = *(model.GetOrCreate()); + Trail& trail = *(model.GetOrCreate()); + + trail.Resize(10); + trail.SetDecisionLevel(1); + trail.Enqueue(Literal(+1), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(+2), AssignmentType::kUnitReason); + trail.SetDecisionLevel(2); + trail.Enqueue(Literal(+3), AssignmentType::kSearchDecision); + csts.Resize(10); + + // We can't add this constraint since it is conflicting under the current + // assignment. + EXPECT_FALSE(csts.AddConstraint(MakePb({{+1, 1}, {+2, 1}, {+3, 1}}), + Coefficient(2), &trail)); + + trail.Untrail(trail.Index() - 1); // Remove the +3. + EXPECT_EQ(trail.Index(), 2); + csts.Untrail(trail, 2); + + // Adding this one at a decision level of 2 will also fail because it will + // propagate 3 from decision level 1. + ASSERT_DEATH(csts.AddConstraint(MakePb({{+1, 1}, {+2, 1}, {+3, 2}}), + Coefficient(3), &trail), + "var should have been propagated at an earlier level."); + + // However, adding the same constraint while the decision level is 1 is ok. + // It will propagate -3 at the correct decision level. + trail.SetDecisionLevel(1); + EXPECT_TRUE(csts.AddConstraint(MakePb({{+1, 1}, {+2, 1}, {+3, 2}}), + Coefficient(3), &trail)); + EXPECT_EQ(trail.Index(), 3); + EXPECT_EQ(trail[2], Literal(-3)); +} + +TEST(MutableUpperBoundedLinearConstraintTest, LinearAddition) { + MutableUpperBoundedLinearConstraint cst_a; + cst_a.ClearAndResize(5); + cst_a.AddTerm(Literal(+1), Coefficient(3)); + cst_a.AddTerm(Literal(+2), Coefficient(4)); + cst_a.AddTerm(Literal(+3), Coefficient(5)); + cst_a.AddTerm(Literal(+4), Coefficient(1)); + cst_a.AddTerm(Literal(+5), Coefficient(1)); + cst_a.AddToRhs(Coefficient(10)); + + // The result of cst_a + cst_b is describes in the comments. 
+ MutableUpperBoundedLinearConstraint cst_b; + cst_b.ClearAndResize(5); + cst_b.AddTerm(Literal(+1), Coefficient(3)); // 3x + 3x = 6x + cst_b.AddTerm(Literal(-2), Coefficient(3)); // 4x + 3(1-x) = x + 3 + cst_b.AddTerm(Literal(+3), Coefficient(3)); // 5x + 3x = 8x + cst_b.AddTerm(Literal(-4), Coefficient(6)); // x + 6(1-x) = 5(1-x) + 1 + cst_b.AddTerm(Literal(+5), Coefficient(5)); // x + 5x = 6x + cst_b.AddToRhs(Coefficient(10)); + + for (BooleanVariable var : cst_b.PossibleNonZeros()) { + cst_a.AddTerm(cst_b.GetLiteral(var), cst_b.GetCoefficient(var)); + } + cst_a.AddToRhs(cst_b.Rhs()); + + EXPECT_EQ(cst_a.DebugString(), "6[+1] + 1[+2] + 8[+3] + 5[-4] + 6[+5] <= 16"); +} + +TEST(MutableUpperBoundedLinearConstraintTest, ReduceCoefficients) { + MutableUpperBoundedLinearConstraint cst; + cst.ClearAndResize(100); + Coefficient max_value(0); + for (int i = 1; i <= 10; ++i) { + max_value += Coefficient(i); + cst.AddTerm(Literal(BooleanVariable(i), true), Coefficient(i)); + } + cst.AddToRhs(max_value - 3); + + // The constraint is equivalent to sum i * Literal(i, false) >= 3, + // So we can reduce any coeff > 3 to 3 and change the rhs accordingly. 
+ cst.ReduceCoefficients(); + for (BooleanVariable var : cst.PossibleNonZeros()) { + EXPECT_LE(cst.GetCoefficient(var), 3); + } + EXPECT_EQ(cst.Rhs(), 1 + 2 + 3 * 8 - 3); +} + +TEST(MutableUpperBoundedLinearConstraintTest, ComputeSlackForTrailPrefix) { + MutableUpperBoundedLinearConstraint cst; + cst.ClearAndResize(100); + cst.AddTerm(Literal(+1), Coefficient(3)); + cst.AddTerm(Literal(+2), Coefficient(4)); + cst.AddTerm(Literal(+3), Coefficient(5)); + cst.AddTerm(Literal(+4), Coefficient(6)); + cst.AddTerm(Literal(+5), Coefficient(7)); + cst.AddToRhs(Coefficient(10)); + + Trail trail; + trail.Resize(10); + trail.Enqueue(Literal(+1), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(-2), AssignmentType::kUnitReason); + trail.Enqueue(Literal(+3), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(-5), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(+4), AssignmentType::kSearchDecision); + + EXPECT_EQ(Coefficient(10), cst.ComputeSlackForTrailPrefix(trail, 0)); + EXPECT_EQ(Coefficient(10 - 3), cst.ComputeSlackForTrailPrefix(trail, 1)); + EXPECT_EQ(Coefficient(10 - 3), cst.ComputeSlackForTrailPrefix(trail, 2)); + EXPECT_EQ(Coefficient(10 - 3 - 5), cst.ComputeSlackForTrailPrefix(trail, 3)); + EXPECT_EQ(Coefficient(10 - 3 - 5), cst.ComputeSlackForTrailPrefix(trail, 4)); + EXPECT_EQ(Coefficient(10 - 14), cst.ComputeSlackForTrailPrefix(trail, 5)); + EXPECT_EQ(Coefficient(10 - 14), cst.ComputeSlackForTrailPrefix(trail, 50)); +} + +TEST(MutableUpperBoundedLinearConstraintTest, ReduceSlackToZero) { + MutableUpperBoundedLinearConstraint cst; + cst.ClearAndResize(100); + cst.AddTerm(Literal(+1), Coefficient(3)); + cst.AddTerm(Literal(+2), Coefficient(1)); + cst.AddTerm(Literal(+3), Coefficient(5)); + cst.AddTerm(Literal(+4), Coefficient(6)); + cst.AddTerm(Literal(+5), Coefficient(7)); + cst.AddToRhs(Coefficient(10)); + + Trail trail; + trail.Resize(10); + trail.Enqueue(Literal(+1), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(-2), 
AssignmentType::kUnitReason); + trail.Enqueue(Literal(+3), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(+5), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(+4), AssignmentType::kSearchDecision); + + // +1, -2 and +3 gives a slack of 2. + EXPECT_EQ(Coefficient(2), cst.ComputeSlackForTrailPrefix(trail, 3)); + + // It also propagate -4 and -5, to have the same propagation but with a slack + // of zero, we can call ReduceSlackToZero(). + cst.ReduceSlackTo(trail, 3, Coefficient(2), Coefficient(0)); + + // +1 and +3 have the same coeff. + EXPECT_EQ(cst.GetCoefficient(BooleanVariable(0)), Coefficient(3)); + EXPECT_EQ(cst.GetCoefficient(BooleanVariable(2)), Coefficient(5)); + + // the variable 1 disappeared. + EXPECT_EQ(cst.GetCoefficient(BooleanVariable(1)), Coefficient(0)); + + // The propagated variable coeff has been reduced by the slack. + EXPECT_EQ(cst.GetCoefficient(BooleanVariable(3)), Coefficient(6 - 2)); + EXPECT_EQ(cst.GetCoefficient(BooleanVariable(4)), Coefficient(7 - 2)); + + // The rhs has been reduced by slack, and the slack is now 0. + EXPECT_EQ(cst.Rhs(), Coefficient(10 - 2)); + EXPECT_EQ(Coefficient(0), cst.ComputeSlackForTrailPrefix(trail, 3)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/precedences_test.cc b/ortools/sat/precedences_test.cc new file mode 100644 index 00000000000..a18cfe26988 --- /dev/null +++ b/ortools/sat/precedences_test.cc @@ -0,0 +1,592 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/precedences.h" + +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/sorted_interval_list.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +// A simple macro to make the code more readable. +// TODO(user): move that in a common place. test_utils? +#define EXPECT_BOUNDS_EQ(var, lb, ub) \ + EXPECT_EQ(integer_trail->LowerBound(var), lb); \ + EXPECT_EQ(integer_trail->UpperBound(var), ub) + +// All the tests here uses 10 integer variables initially in [0, 100]. +std::vector AddVariables(IntegerTrail* integer_trail) { + std::vector vars; + const int num_variables = 10; + const IntegerValue lower_bound(0); + const IntegerValue upper_bound(100); + for (int i = 0; i < num_variables; ++i) { + vars.push_back(integer_trail->AddIntegerVariable(lower_bound, upper_bound)); + } + return vars; +} + +TEST(PrecedenceRelationsTest, BasicAPI) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + const std::vector vars = AddVariables(integer_trail); + + // Note that odd indices are for the negation. 
+ IntegerVariable a(0), b(2), c(4), d(6); + + PrecedenceRelations precedences(&model); + precedences.Add(a, b, 10); + precedences.Add(d, c, 7); + precedences.Add(b, d, 5); + + precedences.Build(); + EXPECT_EQ(precedences.GetOffset(a, b), 10); + EXPECT_EQ(precedences.GetOffset(NegationOf(b), NegationOf(a)), 10); + EXPECT_EQ(precedences.GetOffset(a, c), 22); + EXPECT_EQ(precedences.GetOffset(NegationOf(c), NegationOf(a)), 22); + EXPECT_EQ(precedences.GetOffset(a, d), 15); + EXPECT_EQ(precedences.GetOffset(NegationOf(d), NegationOf(a)), 15); + EXPECT_EQ(precedences.GetOffset(d, a), kMinIntegerValue); + + // Once built, we can update the offsets. + // Note however that this would not propagate through the precedence graphs. + precedences.Add(a, b, 15); + EXPECT_EQ(precedences.GetOffset(a, b), 15); + EXPECT_EQ(precedences.GetOffset(NegationOf(b), NegationOf(a)), 15); +} + +TEST(PrecedenceRelationsTest, CornerCase1) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + const std::vector vars = AddVariables(integer_trail); + + // Note that odd indices are for the negation. + IntegerVariable a(0), b(2), c(4), d(6); + + PrecedenceRelations precedences(&model); + precedences.Add(a, b, 10); + precedences.Add(b, c, 7); + precedences.Add(b, d, 5); + precedences.Add(NegationOf(b), a, 5); + + precedences.Build(); + EXPECT_EQ(precedences.GetOffset(NegationOf(b), a), 5); + EXPECT_EQ(precedences.GetOffset(NegationOf(b), b), 15); + EXPECT_EQ(precedences.GetOffset(NegationOf(b), c), 22); + EXPECT_EQ(precedences.GetOffset(NegationOf(b), d), 20); +} + +TEST(PrecedenceRelationsTest, CornerCase2) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + const std::vector vars = AddVariables(integer_trail); + + // Note that odd indices are for the negation. 
+ IntegerVariable a(0), b(2), c(4), d(6); + + PrecedenceRelations precedences(&model); + precedences.Add(NegationOf(a), a, 10); + precedences.Add(a, b, 7); + precedences.Add(a, c, 5); + precedences.Add(a, d, 2); + + precedences.Build(); + EXPECT_EQ(precedences.GetOffset(NegationOf(a), a), 10); + EXPECT_EQ(precedences.GetOffset(NegationOf(a), b), 17); + EXPECT_EQ(precedences.GetOffset(NegationOf(a), c), 15); + EXPECT_EQ(precedences.GetOffset(NegationOf(a), d), 12); +} + +TEST(PrecedenceRelationsTest, ConditionalRelations) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* integer_trail = model.GetOrCreate(); + const std::vector vars = AddVariables(integer_trail); + + const Literal l(model.Add(NewBooleanVariable()), true); + EXPECT_TRUE(sat_solver->EnqueueDecisionIfNotConflicting(l)); + + // Note that odd indices are for the negation. + IntegerVariable a(0), b(2); + PrecedenceRelations precedences(&model); + precedences.PushConditionalRelation({l}, a, b, 15); + precedences.PushConditionalRelation({l}, a, b, 20); + + // We only keep the best one. + EXPECT_EQ(precedences.GetConditionalOffset(a, NegationOf(b)), -15); + EXPECT_THAT(precedences.GetConditionalEnforcements(a, NegationOf(b)), + ElementsAre(l)); + + // Backtrack works. 
+ EXPECT_TRUE(sat_solver->ResetToLevelZero()); + EXPECT_EQ(precedences.GetConditionalOffset(a, NegationOf(b)), + kMinIntegerValue); + EXPECT_THAT(precedences.GetConditionalEnforcements(a, NegationOf(b)), + ElementsAre()); +} + +TEST(PrecedencesPropagatorTest, Empty) { + Model model; + Trail* trail = model.GetOrCreate(); + PrecedencesPropagator* propagator = + model.GetOrCreate(); + EXPECT_TRUE(propagator->Propagate(trail)); + EXPECT_TRUE(propagator->Propagate(trail)); + propagator->Untrail(*trail, 0); +} + +TEST(PrecedencesPropagatorTest, BasicPropagationTest) { + Model model; + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* propagator = + model.GetOrCreate(); + + std::vector vars = AddVariables(integer_trail); + propagator->AddPrecedenceWithOffset(vars[0], vars[1], IntegerValue(4)); + propagator->AddPrecedenceWithOffset(vars[0], vars[2], IntegerValue(8)); + propagator->AddPrecedenceWithOffset(vars[1], vars[2], IntegerValue(10)); + + EXPECT_TRUE(propagator->Propagate(trail)); + EXPECT_BOUNDS_EQ(vars[0], 0, 86); + EXPECT_BOUNDS_EQ(vars[1], 4, 90); + EXPECT_BOUNDS_EQ(vars[2], 14, 100); + + // Let's now move vars[1] lower bound.
+ std::vector lr; + std::vector ir; + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(vars[1], IntegerValue(20)), lr, ir)); + + EXPECT_TRUE(propagator->Propagate(trail)); + EXPECT_BOUNDS_EQ(vars[1], 20, 90); + EXPECT_BOUNDS_EQ(vars[2], 30, 100); +} + +TEST(PrecedencesPropagatorTest, PropagationTestWithVariableOffset) { + Model model; + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* propagator = + model.GetOrCreate(); + + std::vector vars = AddVariables(integer_trail); + propagator->AddPrecedenceWithVariableOffset(vars[0], vars[1], vars[2]); + + // Make var[2] >= 10 and propagate + std::vector lr; + std::vector ir; + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(vars[2], IntegerValue(10)), lr, ir)); + EXPECT_TRUE(propagator->Propagate(trail)); + EXPECT_BOUNDS_EQ(vars[0], 0, 90); + EXPECT_BOUNDS_EQ(vars[1], 10, 100); + + // Change the lower bound to 40 and propagate again. + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(vars[2], IntegerValue(40)), lr, ir)); + EXPECT_TRUE(propagator->Propagate(trail)); + EXPECT_BOUNDS_EQ(vars[0], 0, 60); + EXPECT_BOUNDS_EQ(vars[1], 40, 100); +} + +TEST(PrecedencesPropagatorTest, BasicPropagation) { + Model model; + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* propagator = + model.GetOrCreate(); + trail->Resize(10); + + std::vector vars = AddVariables(integer_trail); + propagator->AddPrecedenceWithOffset(vars[0], vars[1], IntegerValue(4)); + propagator->AddPrecedenceWithOffset(vars[1], vars[2], IntegerValue(8)); + propagator->AddPrecedenceWithOffset(vars[0], vars[3], IntegerValue(90)); + + // These arcs are not possible, because the upper bound of vars[0] is 10. 
+ propagator->AddConditionalPrecedenceWithOffset(vars[1], vars[0], + IntegerValue(7), Literal(+1)); + propagator->AddConditionalPrecedenceWithOffset(vars[2], vars[0], + IntegerValue(-1), Literal(+2)); + + // These are ok. + propagator->AddConditionalPrecedenceWithOffset(vars[1], vars[0], + IntegerValue(6), Literal(+3)); + propagator->AddConditionalPrecedenceWithOffset(vars[2], vars[0], + IntegerValue(-2), Literal(+4)); + + EXPECT_TRUE(propagator->Propagate(trail)); + EXPECT_TRUE(trail->Assignment().LiteralIsFalse(Literal(+1))); + EXPECT_TRUE(trail->Assignment().LiteralIsFalse(Literal(+2))); + EXPECT_FALSE(trail->Assignment().VariableIsAssigned(Literal(+3).Variable())); + EXPECT_FALSE(trail->Assignment().VariableIsAssigned(Literal(+4).Variable())); +} + +TEST(PrecedencesPropagatorTest, PropagateOnVariableOffset) { + Model model; + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* propagator = + model.GetOrCreate(); + trail->Resize(10); + + std::vector vars = AddVariables(integer_trail); + propagator->AddPrecedenceWithVariableOffset(vars[0], vars[1], vars[2]); + propagator->AddPrecedenceWithOffset(vars[1], vars[3], IntegerValue(50)); + + EXPECT_TRUE(propagator->Propagate(trail)); + EXPECT_BOUNDS_EQ(vars[0], 0, 50); + EXPECT_BOUNDS_EQ(vars[1], 0, 50); + EXPECT_BOUNDS_EQ(vars[2], 0, 50); +} + +TEST(PrecedencesPropagatorTest, Cycles) { + Model model; + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* propagator = + model.GetOrCreate(); + trail->Resize(10); + + std::vector vars = AddVariables(integer_trail); + propagator->AddPrecedenceWithOffset(vars[0], vars[1], IntegerValue(4)); + propagator->AddPrecedenceWithOffset(vars[1], vars[2], IntegerValue(8)); + propagator->AddConditionalPrecedenceWithOffset( + vars[2], vars[3], IntegerValue(-10), Literal(+1)); + propagator->AddConditionalPrecedenceWithOffset(vars[3], vars[0], + IntegerValue(-2),
Literal(+2)); + propagator->AddConditionalPrecedence(vars[3], vars[0], Literal(+3)); + + // This one will force the upper bound of vars[0] to be 50, so we can + // check that the cycle is detected before the lower bound of var[0] crosses + // this bound. + propagator->AddConditionalPrecedenceWithOffset(vars[0], vars[4], + IntegerValue(50), Literal(+4)); + + // If we add this one, the cycle will be detected using the integer bound and + // not the graph cycle. TODO(user): Maybe this is a bad thing? but it seems + // difficult to avoid it without extra computations. + propagator->AddConditionalPrecedenceWithOffset(vars[0], vars[4], + IntegerValue(99), Literal(+5)); + + EXPECT_TRUE(propagator->Propagate(trail)); + + // Cycle of weight zero is fine. + trail->SetDecisionLevel(1); + EXPECT_TRUE(integer_trail->Propagate(trail)); + trail->Enqueue(Literal(+1), AssignmentType::kUnitReason); + trail->Enqueue(Literal(+2), AssignmentType::kUnitReason); + trail->Enqueue(Literal(+4), AssignmentType::kUnitReason); + EXPECT_TRUE(propagator->Propagate(trail)); + + // But a cycle of positive length is not! + trail->Enqueue(Literal(+3), AssignmentType::kUnitReason); + EXPECT_FALSE(propagator->Propagate(trail)); + EXPECT_THAT(trail->FailingClause(), + UnorderedElementsAre(Literal(-1), Literal(-3))); + + // Test the untrail. + trail->SetDecisionLevel(0); + integer_trail->Untrail(*trail, 0); + propagator->Untrail(*trail, 0); + trail->Untrail(0); + EXPECT_TRUE(propagator->Propagate(trail)); + + // Still fine here. + trail->SetDecisionLevel(1); + EXPECT_TRUE(integer_trail->Propagate(trail)); + trail->Enqueue(Literal(+5), AssignmentType::kUnitReason); + EXPECT_TRUE(propagator->Propagate(trail)); + + // But fail there with a different and longer reason. 
+ trail->Enqueue(Literal(+1), AssignmentType::kUnitReason); + trail->Enqueue(Literal(+3), AssignmentType::kUnitReason); + EXPECT_FALSE(propagator->Propagate(trail)); + EXPECT_THAT(trail->FailingClause(), + UnorderedElementsAre(Literal(-1), Literal(-3), Literal(-5))); +} + +// This tests a tricky situation: +// +// vars[0] + (offset = vars[2]) <= var[1] +// vars[1] <= vars[2] !! +TEST(PrecedencesPropagatorTest, TrickyCycle) { + Model model; + Trail* trail = model.GetOrCreate(); + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* propagator = + model.GetOrCreate(); + trail->Resize(10); + + std::vector vars = AddVariables(integer_trail); + propagator->AddPrecedenceWithVariableOffset(vars[0], vars[1], vars[2]); + propagator->AddPrecedence(vars[1], vars[2]); + + // This will cause an infinite cycle. + propagator->AddConditionalPrecedenceWithOffset(vars[3], vars[0], + IntegerValue(1), Literal(+1)); + + // So far so good. + EXPECT_TRUE(propagator->Propagate(trail)); + trail->SetDecisionLevel(1); + EXPECT_TRUE(integer_trail->Propagate(trail)); + + // Conflict. + trail->Enqueue(Literal(+1), AssignmentType::kUnitReason); + EXPECT_FALSE(propagator->Propagate(trail)); + EXPECT_THAT(trail->FailingClause(), ElementsAre(Literal(-1))); + + // Test that the code detected properly a positive cycle in the dependency + // graph instead of just pushing the bounds until the upper bound is reached. + EXPECT_LT(integer_trail->num_enqueues(), 10); +} + +TEST(PrecedencesPropagatorTest, ZeroWeightCycleOnDiscreteDomain) { + Model model; + IntegerVariable a = model.Add( + NewIntegerVariable(Domain::FromValues({2, 5, 7, 15, 16, 17, 20, 32}))); + IntegerVariable b = model.Add( + NewIntegerVariable(Domain::FromValues({3, 6, 9, 14, 16, 18, 20, 35}))); + + // Add the fact that a == b with two inequalities. + model.Add(LowerOrEqual(a, b)); + model.Add(LowerOrEqual(b, a)); + + // After propagation, we should detect that the only common values fall in + // [16, 20].
+ EXPECT_TRUE(model.GetOrCreate()->Propagate()); + + // The integer_trail is only used in the macros below. + IntegerTrail* integer_trail = model.GetOrCreate(); + EXPECT_BOUNDS_EQ(a, 16, 20); + EXPECT_BOUNDS_EQ(b, 16, 20); +} + +// This was failing before CL 135903015. +TEST(PrecedencesPropagatorTest, ConditionalPrecedencesOnFixedLiteral) { + Model model; + + // To trigger the old bug, we need to add some precedences. + IntegerVariable x = model.Add(NewIntegerVariable(0, 100)); + IntegerVariable y = model.Add(NewIntegerVariable(50, 100)); + model.Add(LowerOrEqual(x, y)); + + // We then add a Boolean variable and fix it. + // This will trigger a propagation. + BooleanVariable b = model.Add(NewBooleanVariable()); + model.Add(ClauseConstraint({Literal(b, true)})); // Fix b To true. + + // We now add a conditional precedences using the fixed variable. + // This used to not be taken into account. + model.Add(ConditionalLowerOrEqualWithOffset(y, x, 0, Literal(b, true))); + + EXPECT_EQ(SatSolver::FEASIBLE, SolveIntegerProblemWithLazyEncoding(&model)); + EXPECT_EQ(model.Get(Value(x)), model.Get(Value(y))); +} + +#undef EXPECT_BOUNDS_EQ + +TEST(PrecedenceRelationsTest, CollectPrecedences) { + Model model; + auto* integer_trail = model.GetOrCreate(); + auto* relations = model.GetOrCreate(); + + std::vector vars = AddVariables(integer_trail); + relations->Add(vars[0], vars[2], IntegerValue(1)); + relations->Add(vars[0], vars[5], IntegerValue(1)); + relations->Add(vars[1], vars[2], IntegerValue(1)); + relations->Add(vars[2], vars[4], IntegerValue(1)); + relations->Add(vars[3], vars[4], IntegerValue(1)); + relations->Add(vars[4], vars[5], IntegerValue(1)); + + std::vector p; + relations->CollectPrecedences({vars[0], vars[2], vars[3]}, &p); + + // Note that we do not return precedences with just one variable. 
+ std::vector indices; + std::vector variables; + for (const auto precedence : p) { + indices.push_back(precedence.index); + variables.push_back(precedence.var); + } + EXPECT_EQ(indices, (std::vector{1, 2})); + EXPECT_EQ(variables, (std::vector{vars[4], vars[4]})); + + // Same with NegationOf() and also test that p is cleared. + relations->CollectPrecedences({NegationOf(vars[0]), NegationOf(vars[4])}, &p); + EXPECT_TRUE(p.empty()); +} + +TEST(GreaterThanAtLeastOneOfDetectorTest, AddGreaterThanAtLeastOneOf) { + Model model; + const IntegerVariable a = model.Add(NewIntegerVariable(2, 10)); + const IntegerVariable b = model.Add(NewIntegerVariable(5, 10)); + const IntegerVariable c = model.Add(NewIntegerVariable(3, 10)); + const IntegerVariable d = model.Add(NewIntegerVariable(0, 10)); + const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); + const Literal lit_b = Literal(model.Add(NewBooleanVariable()), true); + const Literal lit_c = Literal(model.Add(NewBooleanVariable()), true); + model.Add(ClauseConstraint({lit_a, lit_b, lit_c})); + + auto* detector = model.GetOrCreate(); + detector->Add(lit_a, {a, -1}, {d, 1}, 2, 1000); // d >= a + 2 + detector->Add(lit_b, {b, -1}, {d, 1}, -1, 1000); // d >= b -1 + detector->Add(lit_c, {c, -1}, {d, 1}, 0, 1000); // d >= c + + auto* solver = model.GetOrCreate(); + EXPECT_TRUE(solver->Propagate()); + EXPECT_EQ(model.Get(LowerBound(d)), 0); + + EXPECT_EQ(1, detector->AddGreaterThanAtLeastOneOfConstraints(&model)); + EXPECT_TRUE(solver->Propagate()); + EXPECT_EQ(model.Get(LowerBound(d)), std::min({2 + 2, 5 - 1, 3 + 0})); +} + +TEST(GreaterThanAtLeastOneOfDetectorTest, + AddGreaterThanAtLeastOneOfWithAutoDetect) { + Model model; + const IntegerVariable a = model.Add(NewIntegerVariable(2, 10)); + const IntegerVariable b = model.Add(NewIntegerVariable(5, 10)); + const IntegerVariable c = model.Add(NewIntegerVariable(3, 10)); + const IntegerVariable d = model.Add(NewIntegerVariable(0, 10)); + const Literal lit_a = 
Literal(model.Add(NewBooleanVariable()), true); + const Literal lit_b = Literal(model.Add(NewBooleanVariable()), true); + const Literal lit_c = Literal(model.Add(NewBooleanVariable()), true); + model.Add(ClauseConstraint({lit_a, lit_b, lit_c})); + + auto* detector = model.GetOrCreate(); + detector->Add(lit_a, {a, -1}, {d, 1}, 2, 1000); // d >= a + 2 + detector->Add(lit_b, {b, -1}, {d, 1}, -1, 1000); // d >= b -1 + detector->Add(lit_c, {c, -1}, {d, 1}, 0, 1000); // d >= c + + auto* solver = model.GetOrCreate(); + EXPECT_TRUE(solver->Propagate()); + EXPECT_EQ(model.Get(LowerBound(d)), 0); + + EXPECT_EQ(1, detector->AddGreaterThanAtLeastOneOfConstraints( + &model, /*auto_detect_clauses=*/true)); + EXPECT_TRUE(solver->Propagate()); + EXPECT_EQ(model.Get(LowerBound(d)), std::min({2 + 2, 5 - 1, 3 + 0})); +} + +TEST(PrecedencesPropagatorTest, ComputeFullPrecedencesIfCycle) { + Model model; + std::vector vars(10); + for (int i = 0; i < vars.size(); ++i) { + vars[i] = model.Add(NewIntegerVariable(0, 10)); + } + + // Even if the weight are compatible, we will fail here. 
+ model.Add(LowerOrEqualWithOffset(vars[0], vars[1], 2)); + model.Add(LowerOrEqualWithOffset(vars[1], vars[2], 2)); + model.Add(LowerOrEqualWithOffset(vars[2], vars[1], -10)); + model.Add(LowerOrEqualWithOffset(vars[0], vars[2], 5)); + + std::vector precedences; + model.GetOrCreate()->ComputeFullPrecedences( + {vars[0], vars[1]}, &precedences); + EXPECT_TRUE(precedences.empty()); +} + +TEST(PrecedencesPropagatorTest, BasicFiltering) { + Model model; + std::vector vars(10); + for (int i = 0; i < vars.size(); ++i) { + vars[i] = model.Add(NewIntegerVariable(0, 10)); + } + + // 1 + // / \ + // 0 2 -- 4 + // \ / + // 3 + model.Add(LowerOrEqualWithOffset(vars[0], vars[1], 2)); + model.Add(LowerOrEqualWithOffset(vars[1], vars[2], 2)); + model.Add(LowerOrEqualWithOffset(vars[0], vars[3], 1)); + model.Add(LowerOrEqualWithOffset(vars[3], vars[2], 2)); + model.Add(LowerOrEqualWithOffset(vars[2], vars[4], 2)); + + std::vector precedences; + model.GetOrCreate()->ComputeFullPrecedences( + {vars[0], vars[1], vars[3]}, &precedences); + + // We only output size at least 2, and "relevant" precedences. + // So here only vars[2]. 
+ ASSERT_EQ(precedences.size(), 1); + EXPECT_EQ(precedences[0].var, vars[2]); + EXPECT_THAT(precedences[0].offsets, ElementsAre(4, 2, 2)); + EXPECT_THAT(precedences[0].indices, ElementsAre(0, 1, 2)); +} + +TEST(PrecedencesPropagatorTest, BasicFiltering2) { + Model model; + std::vector vars(10); + for (int i = 0; i < vars.size(); ++i) { + vars[i] = model.Add(NewIntegerVariable(0, 10)); + } + + // 1 + // / \ + // 0 2 -- 4 + // \ / / + // 3 5 + model.Add(LowerOrEqualWithOffset(vars[0], vars[1], 2)); + model.Add(LowerOrEqualWithOffset(vars[1], vars[2], 2)); + model.Add(LowerOrEqualWithOffset(vars[0], vars[3], 1)); + model.Add(LowerOrEqualWithOffset(vars[3], vars[2], 2)); + model.Add(LowerOrEqualWithOffset(vars[2], vars[4], 2)); + model.Add(LowerOrEqualWithOffset(vars[5], vars[4], 7)); + + std::vector precedences; + model.GetOrCreate()->ComputeFullPrecedences( + {vars[0], vars[1], vars[3]}, &precedences); + + // Same as before here. + ASSERT_EQ(precedences.size(), 1); + EXPECT_EQ(precedences[0].var, vars[2]); + EXPECT_THAT(precedences[0].offsets, ElementsAre(4, 2, 2)); + EXPECT_THAT(precedences[0].indices, ElementsAre(0, 1, 2)); + + // But if we ask for 5, we will get two results. 
+ precedences.clear(); + model.GetOrCreate()->ComputeFullPrecedences( + {vars[0], vars[1], vars[3], vars[5]}, &precedences); + ASSERT_EQ(precedences.size(), 2); + EXPECT_EQ(precedences[0].var, vars[2]); + EXPECT_THAT(precedences[0].offsets, ElementsAre(4, 2, 2)); + EXPECT_THAT(precedences[0].indices, ElementsAre(0, 1, 2)); + EXPECT_EQ(precedences[1].var, vars[4]); + EXPECT_THAT(precedences[1].offsets, ElementsAre(6, 4, 4, 7)); + EXPECT_THAT(precedences[1].indices, ElementsAre(0, 1, 2, 3)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 7377477eee1..c4789f9bff0 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -33,11 +33,11 @@ #include "absl/numeric/int128.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/cp_model_checker.h" #include "ortools/sat/cp_model_loader.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cp_model_utils.h" @@ -98,10 +98,13 @@ int PresolveContext::NewIntVarWithDefinition( return new_var; } -int PresolveContext::NewBoolVar() { return NewIntVar(Domain(0, 1)); } +int PresolveContext::NewBoolVar(absl::string_view source) { + UpdateRuleStats(absl::StrCat("new_bool: ", source)); + return NewIntVar(Domain(0, 1)); +} int PresolveContext::NewBoolVarWithClause(absl::Span clause) { - const int new_var = NewBoolVar(); + const int new_var = NewBoolVar("with clause"); if (hint_is_loaded_) { bool all_have_hint = true; for (const int literal : clause) { @@ -125,7 +128,7 @@ int PresolveContext::NewBoolVarWithClause(absl::Span clause) { } } - // If there all literal where hinted and at zero, we set the hint of + // If all literals where hinted and at zero, we set the 
hint of // new_var to zero, otherwise we leave it unassigned. if (all_have_hint && !hint_has_value_[new_var]) { hint_has_value_[new_var] = true; @@ -599,7 +602,15 @@ void PresolveContext::UpdateRuleStats(const std::string& name, int num_times) { if (!is_todo) num_presolve_operations += num_times; if (logger_->LoggingIsEnabled()) { - VLOG(is_todo ? 3 : 2) << num_presolve_operations << " : " << name; + if (VLOG_IS_ON(1)) { + int level = is_todo ? 3 : 2; + if (std::abs(num_presolve_operations - + params_.debug_max_num_presolve_operations()) <= 100) { + level = 1; + } + VLOG(level) << num_presolve_operations << " : " << name; + } + stats_by_rule_name_[name] += num_times; } } @@ -715,6 +726,7 @@ void PresolveContext::UpdateConstraintVariableUsage(int c) { } bool PresolveContext::ConstraintVariableGraphIsUpToDate() const { + if (is_unsat_) return true; // We do not care in this case. return constraint_to_vars_.size() == working_model->constraints_size(); } @@ -1006,6 +1018,12 @@ bool PresolveContext::CanonicalizeAffineVariable(int ref, int64_t coeff, return true; } +void PresolveContext::PermuteHintValues(const SparsePermutation& perm) { + CHECK(hint_is_loaded_); + perm.ApplyToDenseCollection(hint_); + perm.ApplyToDenseCollection(hint_has_value_); +} + bool PresolveContext::StoreAffineRelation(int ref_x, int ref_y, int64_t coeff, int64_t offset, bool debug_no_recursion) { @@ -1361,8 +1379,9 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { max_literal = max_it->second.Get(this); if (min_literal != NegatedRef(max_literal)) { UpdateRuleStats("variables with 2 values: merge encoding literals"); - StoreBooleanEqualityRelation(min_literal, NegatedRef(max_literal)); - if (is_unsat_) return; + if (!StoreBooleanEqualityRelation(min_literal, NegatedRef(max_literal))) { + return; + } } min_literal = GetLiteralRepresentative(min_literal); max_literal = GetLiteralRepresentative(max_literal); @@ -1379,7 +1398,7 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int 
var) { var_map[var_min] = SavedLiteral(min_literal); } else { UpdateRuleStats("variables with 2 values: create encoding literal"); - max_literal = NewBoolVar(); + max_literal = NewBoolVar("var with 2 values"); min_literal = NegatedRef(max_literal); var_map[var_min] = SavedLiteral(min_literal); var_map[var_max] = SavedLiteral(max_literal); @@ -1409,7 +1428,7 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { } } -void PresolveContext::InsertVarValueEncodingInternal(int literal, int var, +bool PresolveContext::InsertVarValueEncodingInternal(int literal, int var, int64_t value, bool add_constraints) { DCHECK(RefIsPositive(var)); @@ -1436,10 +1455,12 @@ void PresolveContext::InsertVarValueEncodingInternal(int literal, int var, if (literal != previous_literal) { UpdateRuleStats( "variables: merge equivalent var value encoding literals"); - StoreBooleanEqualityRelation(literal, previous_literal); + if (!StoreBooleanEqualityRelation(literal, previous_literal)) { + return false; + } } } - return; + return true; } if (DomainOf(var).Size() == 2) { @@ -1451,6 +1472,9 @@ void PresolveContext::InsertVarValueEncodingInternal(int literal, int var, AddImplyInDomain(literal, var, Domain(value)); AddImplyInDomain(NegatedRef(literal), var, Domain(value).Complement()); } + + // The canonicalization might have proven UNSAT. + return !ModelIsUnsat(); } bool PresolveContext::InsertHalfVarValueEncoding(int literal, int var, @@ -1474,8 +1498,10 @@ bool PresolveContext::InsertHalfVarValueEncoding(int literal, int var, if (other_set.contains({NegatedRef(literal), var, value})) { UpdateRuleStats("variables: detect fully reified value encoding"); const int imply_eq_literal = imply_eq ? 
literal : NegatedRef(literal); - InsertVarValueEncodingInternal(imply_eq_literal, var, value, - /*add_constraints=*/false); + if (!InsertVarValueEncodingInternal(imply_eq_literal, var, value, + /*add_constraints=*/false)) { + return false; + } } return true; @@ -1495,7 +1521,10 @@ bool PresolveContext::InsertVarValueEncoding(int literal, int var, return SetLiteralToFalse(literal); } literal = GetLiteralRepresentative(literal); - InsertVarValueEncodingInternal(literal, var, value, /*add_constraints=*/true); + if (!InsertVarValueEncodingInternal(literal, var, value, + /*add_constraints=*/true)) { + return false; + } eq_half_encoding_.insert({literal, var, value}); neq_half_encoding_.insert({NegatedRef(literal), var, value}); @@ -1621,14 +1650,14 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) { var_map[0] = SavedLiteral(NegatedRef(representative)); return value == 1 ? representative : NegatedRef(representative); } else { - const int literal = NewBoolVar(); + const int literal = NewBoolVar("integer encoding"); InsertVarValueEncoding(literal, var, var_max); const int representative = GetLiteralRepresentative(literal); return value == var_max ? representative : NegatedRef(representative); } } - const int literal = NewBoolVar(); + const int literal = NewBoolVar("integer encoding"); InsertVarValueEncoding(literal, var, value); return GetLiteralRepresentative(literal); } @@ -2155,7 +2184,7 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( const auto& it = reified_precedences_cache_.find(key); if (it != reified_precedences_cache_.end()) return it->second; - const int result = NewBoolVar(); + const int result = NewBoolVar("reified precedence"); reified_precedences_cache_[key] = result; // result => (time_i <= time_j) && active_i && active_j. 
diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index c2e581289df..1dfac184eff 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -28,6 +28,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/logging.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" @@ -97,7 +98,7 @@ class PresolveContext { // TODO(user): We should control more how this is called so we can update // a solution hint accordingly. int NewIntVar(const Domain& domain); - int NewBoolVar(); + int NewBoolVar(absl::string_view source); // This should replace NewIntVar() eventually in order to be able to crush // primal solution or just update the hint. @@ -574,9 +575,21 @@ class PresolveContext { // the hint, in order to maintain it as best as possible during presolve. void LoadSolutionHint(); + void PermuteHintValues(const SparsePermutation& perm); + // Solution hint accessor. bool VarHasSolutionHint(int var) const { return hint_has_value_[var]; } int64_t SolutionHint(int var) const { return hint_[var]; } + bool HintIsLoaded() const { return hint_is_loaded_; } + absl::Span SolutionHint() const { return hint_; } + + // Allows to set the hint of a newly created variable. + void SetNewVariableHint(int var, int64_t value) { + CHECK(hint_is_loaded_); + CHECK(!hint_has_value_[var]); + hint_has_value_[var] = true; + hint_[var] = value; + } SolverLogger* logger() const { return logger_; } const SatParameters& params() const { return params_; } @@ -654,7 +667,8 @@ class PresolveContext { bool imply_eq); // Insert fully reified var-value encoding. - void InsertVarValueEncodingInternal(int literal, int var, int64_t value, + // Returns false if this make the problem infeasible. 
+ bool InsertVarValueEncodingInternal(int literal, int var, int64_t value, bool add_constraints); SolverLogger* logger_; diff --git a/ortools/sat/probing.cc b/ortools/sat/probing.cc index e11c15b1f70..941c78c8309 100644 --- a/ortools/sat/probing.cc +++ b/ortools/sat/probing.cc @@ -149,7 +149,8 @@ bool Prober::ProbeOneVariableInternal(BooleanVariable b) { IntegerValue ub_min = kMaxIntegerValue; new_integer_bounds_.push_back(IntegerLiteral()); // Sentinel. - for (int i = 0; i < new_integer_bounds_.size(); ++i) { + const int limit = new_integer_bounds_.size(); + for (int i = 0; i < limit; ++i) { const IntegerVariable var = new_integer_bounds_[i].var; // Hole detection. diff --git a/ortools/sat/probing_test.cc b/ortools/sat/probing_test.cc new file mode 100644 index 00000000000..d57e4744477 --- /dev/null +++ b/ortools/sat/probing_test.cc @@ -0,0 +1,80 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/probing.h" + +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(ProbeBooleanVariablesTest, IntegerBoundInference) { + Model model; + const BooleanVariable a = model.Add(NewBooleanVariable()); + const IntegerVariable b = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable c = model.Add(NewIntegerVariable(0, 10)); + + // Bound restriction. + model.Add(Implication({Literal(a, true)}, + IntegerLiteral::GreaterOrEqual(b, IntegerValue(2)))); + model.Add(Implication({Literal(a, false)}, + IntegerLiteral::GreaterOrEqual(b, IntegerValue(3)))); + model.Add(Implication({Literal(a, true)}, + IntegerLiteral::LowerOrEqual(b, IntegerValue(7)))); + model.Add(Implication({Literal(a, false)}, + IntegerLiteral::LowerOrEqual(b, IntegerValue(9)))); + + // Hole. + model.Add(Implication({Literal(a, true)}, + IntegerLiteral::GreaterOrEqual(c, IntegerValue(7)))); + model.Add(Implication({Literal(a, false)}, + IntegerLiteral::LowerOrEqual(c, IntegerValue(4)))); + + Prober* prober = model.GetOrCreate(); + prober->ProbeBooleanVariables(/*deterministic_time_limit=*/1.0); + auto* integer_trail = model.GetOrCreate(); + EXPECT_EQ("[2,9]", integer_trail->InitialVariableDomain(b).ToString()); + EXPECT_EQ("[0,4][7,10]", integer_trail->InitialVariableDomain(c).ToString()); +} + +TEST(FailedLiteralProbingRoundTest, TrivialExample) { + Model model; + const Literal a(model.Add(NewBooleanVariable()), true); + const Literal b(model.Add(NewBooleanVariable()), true); + const Literal c(model.Add(NewBooleanVariable()), true); + + // Setting a to false will result in a constradiction, so a must be true. 
+ model.Add(ClauseConstraint({a, b, c})); + model.Add(Implication(a.Negated(), b.Negated())); + model.Add(Implication(c, a)); + + auto* sat_soler = model.GetOrCreate(); + EXPECT_TRUE(sat_soler->Propagate()); + EXPECT_FALSE(sat_soler->Assignment().LiteralIsAssigned(a)); + + EXPECT_TRUE(FailedLiteralProbingRound(ProbingOptions(), &model)); + EXPECT_TRUE(sat_soler->Assignment().LiteralIsTrue(a)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/pseudo_costs.cc b/ortools/sat/pseudo_costs.cc index 2afe38e7905..75e424d3a5a 100644 --- a/ortools/sat/pseudo_costs.cc +++ b/ortools/sat/pseudo_costs.cc @@ -100,7 +100,36 @@ bool PseudoCosts::SaveLpInfo() { void PseudoCosts::SaveBoundChanges(Literal decision, absl::Span lp_values) { - bound_changes_ = GetBoundChanges(decision, lp_values); + bound_changes_.clear(); + for (const IntegerLiteral l : encoder_->GetIntegerLiterals(decision)) { + PseudoCosts::VariableBoundChange entry; + entry.var = l.var; + entry.lower_bound_change = l.bound - integer_trail_->LowerBound(l.var); + if (l.var < lp_values.size()) { + entry.lp_increase = + std::max(0.0, ToDouble(l.bound) - lp_values[l.var.value()]); + } + bound_changes_.push_back(entry); + } + + // NOTE: We ignore literal associated to var != value. + for (const auto [var, value] : encoder_->GetEqualityLiterals(decision)) { + { + PseudoCosts::VariableBoundChange entry; + entry.var = var; + entry.lower_bound_change = value - integer_trail_->LowerBound(var); + bound_changes_.push_back(entry); + } + + // Also do the negation. 
+ { + PseudoCosts::VariableBoundChange entry; + entry.var = NegationOf(var); + entry.lower_bound_change = + (-value) - integer_trail_->LowerBound(NegationOf(var)); + bound_changes_.push_back(entry); + } + } } void PseudoCosts::BeforeTakingDecision(Literal decision) { @@ -281,42 +310,5 @@ IntegerVariable PseudoCosts::GetBestDecisionVar() { return chosen_var; } -std::vector PseudoCosts::GetBoundChanges( - Literal decision, absl::Span lp_values) { - std::vector bound_changes; - - for (const IntegerLiteral l : encoder_->GetIntegerLiterals(decision)) { - PseudoCosts::VariableBoundChange entry; - entry.var = l.var; - entry.lower_bound_change = l.bound - integer_trail_->LowerBound(l.var); - if (l.var < lp_values.size()) { - entry.lp_increase = - std::max(0.0, ToDouble(l.bound) - lp_values[l.var.value()]); - } - bound_changes.push_back(entry); - } - - // NOTE: We ignore literal associated to var != value. - for (const auto [var, value] : encoder_->GetEqualityLiterals(decision)) { - { - PseudoCosts::VariableBoundChange entry; - entry.var = var; - entry.lower_bound_change = value - integer_trail_->LowerBound(var); - bound_changes.push_back(entry); - } - - // Also do the negation. - { - PseudoCosts::VariableBoundChange entry; - entry.var = NegationOf(var); - entry.lower_bound_change = - (-value) - integer_trail_->LowerBound(NegationOf(var)); - bound_changes.push_back(entry); - } - } - - return bound_changes; -} - } // namespace sat } // namespace operations_research diff --git a/ortools/sat/pseudo_costs.h b/ortools/sat/pseudo_costs.h index 3ab96f30e0c..84486f19af2 100644 --- a/ortools/sat/pseudo_costs.h +++ b/ortools/sat/pseudo_costs.h @@ -100,8 +100,9 @@ class PseudoCosts { IntegerValue lower_bound_change = IntegerValue(0); double lp_increase = 0.0; }; - std::vector GetBoundChanges( - Literal decision, absl::Span lp_values); + const std::vector& BoundChanges() { + return bound_changes_; + } private: // Returns the current objective info. 
diff --git a/ortools/sat/pseudo_costs_test.cc b/ortools/sat/pseudo_costs_test.cc new file mode 100644 index 00000000000..fc0b98ad9d9 --- /dev/null +++ b/ortools/sat/pseudo_costs_test.cc @@ -0,0 +1,263 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/pseudo_costs.h" + +#include +#include + +#include "gtest/gtest.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(GetBoundChangeTest, LowerBoundChange) { + Model model; + auto* encoder = model.GetOrCreate(); + + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const Literal decision = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(x, IntegerValue(3))); + + PseudoCosts pseudo_costs(&model); + pseudo_costs.SaveBoundChanges(decision, {}); + auto& bound_changes = pseudo_costs.BoundChanges(); + EXPECT_EQ(1, bound_changes.size()); + PseudoCosts::VariableBoundChange bound_change = bound_changes[0]; + EXPECT_EQ(bound_change.var, x); + EXPECT_EQ(bound_change.lower_bound_change, IntegerValue(3)); +} + +TEST(GetBoundChangeTest, UpperBoundChange) { + Model model; + auto* encoder = model.GetOrCreate(); + + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const 
Literal decision = encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(x, IntegerValue(7))); + + PseudoCosts pseudo_costs(&model); + pseudo_costs.SaveBoundChanges(decision, {}); + auto& bound_changes = pseudo_costs.BoundChanges(); + EXPECT_EQ(1, bound_changes.size()); + PseudoCosts::VariableBoundChange bound_change = bound_changes[0]; + EXPECT_EQ(bound_change.var, NegationOf(x)); + EXPECT_EQ(bound_change.lower_bound_change, IntegerValue(3)); +} + +TEST(GetBoundChangeTest, EqualityDecision) { + Model model; + auto* encoder = model.GetOrCreate(); + + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + Literal decision(model.GetOrCreate()->NewBooleanVariable(), true); + encoder->AssociateToIntegerEqualValue(decision, x, IntegerValue(6)); + + PseudoCosts pseudo_costs(&model); + pseudo_costs.SaveBoundChanges(decision, {}); + auto& bound_changes = pseudo_costs.BoundChanges(); + EXPECT_EQ(2, bound_changes.size()); + PseudoCosts::VariableBoundChange lower_bound_change = bound_changes[0]; + EXPECT_EQ(lower_bound_change.var, x); + EXPECT_EQ(lower_bound_change.lower_bound_change, IntegerValue(6)); + PseudoCosts::VariableBoundChange upper_bound_change = bound_changes[1]; + EXPECT_EQ(upper_bound_change.var, NegationOf(x)); + EXPECT_EQ(upper_bound_change.lower_bound_change, IntegerValue(4)); +} + +TEST(PseudoCosts, Initialize) { + Model model; + SatParameters* parameters = model.GetOrCreate(); + parameters->set_pseudo_cost_reliability_threshold(1); + + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + + PseudoCosts pseudo_costs(&model); + + EXPECT_EQ(0.0, pseudo_costs.GetCost(x)); + EXPECT_EQ(0.0, pseudo_costs.GetCost(NegationOf(x))); + EXPECT_EQ(0.0, pseudo_costs.GetCost(y)); + EXPECT_EQ(0.0, pseudo_costs.GetCost(NegationOf(y))); + EXPECT_EQ(0, pseudo_costs.GetNumRecords(x)); + EXPECT_EQ(0, pseudo_costs.GetNumRecords(NegationOf(x))); + EXPECT_EQ(0, 
pseudo_costs.GetNumRecords(y)); + EXPECT_EQ(0, pseudo_costs.GetNumRecords(NegationOf(y))); +} + +namespace { +void SimulateDecision(Literal decision, IntegerValue obj_delta, Model* model) { + const IntegerVariable objective_var = + model->GetOrCreate()->objective_var; + auto* integer_trail = model->GetOrCreate(); + auto* pseudo_costs = model->GetOrCreate(); + + pseudo_costs->BeforeTakingDecision(decision); + const IntegerValue lb = integer_trail->LowerBound(objective_var); + EXPECT_TRUE(integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(objective_var, lb + obj_delta), {}, {})); + pseudo_costs->AfterTakingDecision(); +} +} // namespace + +TEST(PseudoCosts, UpdateCostOfNewVar) { + Model model; + auto* encoder = model.GetOrCreate(); + SatParameters* parameters = model.GetOrCreate(); + parameters->set_pseudo_cost_reliability_threshold(1); + + const IntegerVariable objective_var = model.Add(NewIntegerVariable(0, 100)); + model.GetOrCreate()->objective_var = objective_var; + + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + auto* pseudo_costs = model.GetOrCreate(); + + SimulateDecision(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(x, IntegerValue(3))), + IntegerValue(6), &model); + + EXPECT_EQ(2.0, pseudo_costs->GetCost(x)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(x))); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(x)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(x))); + + SimulateDecision(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(y, IntegerValue(8))), + IntegerValue(6), &model); + + EXPECT_EQ(2.0, pseudo_costs->GetCost(x)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(x))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(y)); + EXPECT_EQ(3.0, pseudo_costs->GetCost(NegationOf(y))); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(x)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(x))); + EXPECT_EQ(0, 
pseudo_costs->GetNumRecords(y)); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(NegationOf(y))); +} + +TEST(PseudoCosts, BasicCostUpdate) { + Model model; + auto* encoder = model.GetOrCreate(); + SatParameters* parameters = model.GetOrCreate(); + parameters->set_pseudo_cost_reliability_threshold(1); + + const IntegerVariable objective_var = model.Add(NewIntegerVariable(0, 100)); + model.GetOrCreate()->objective_var = objective_var; + + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + auto* pseudo_costs = model.GetOrCreate(); + + SimulateDecision(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(x, IntegerValue(3))), + IntegerValue(6), &model); + + EXPECT_EQ(2.0, pseudo_costs->GetCost(x)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(x))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(y)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(y))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(z)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(z))); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(x)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(x))); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(y)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(y))); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(z)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(z))); + + SimulateDecision(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(y, IntegerValue(8))), + IntegerValue(6), &model); + + EXPECT_EQ(2.0, pseudo_costs->GetCost(x)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(x))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(y)); + EXPECT_EQ(3.0, pseudo_costs->GetCost(NegationOf(y))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(z)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(z))); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(x)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(x))); + 
EXPECT_EQ(0, pseudo_costs->GetNumRecords(y)); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(NegationOf(y))); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(z)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(z))); +} + +TEST(PseudoCosts, PseudoCostReliabilityTest) { + Model model; + auto* encoder = model.GetOrCreate(); + SatParameters* parameters = model.GetOrCreate(); + parameters->set_pseudo_cost_reliability_threshold(2); + + const IntegerVariable objective_var = model.Add(NewIntegerVariable(0, 100)); + model.GetOrCreate()->objective_var = objective_var; + + const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + auto* pseudo_costs = model.GetOrCreate(); + + SimulateDecision(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(x, IntegerValue(3))), + IntegerValue(6), &model); + + EXPECT_EQ(2.0, pseudo_costs->GetCost(x)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(x))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(y)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(y))); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(x)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(x))); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(y)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(y))); + EXPECT_EQ(kNoIntegerVariable, pseudo_costs->GetBestDecisionVar()); + + SimulateDecision(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(y, IntegerValue(8))), + IntegerValue(14), &model); + + EXPECT_EQ(2.0, pseudo_costs->GetCost(x)); + EXPECT_EQ(0.0, pseudo_costs->GetCost(NegationOf(x))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(y)); + EXPECT_EQ(7.0, pseudo_costs->GetCost(NegationOf(y))); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(x)); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(NegationOf(x))); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(y)); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(NegationOf(y))); + EXPECT_EQ(kNoIntegerVariable, pseudo_costs->GetBestDecisionVar()); 
+ + SimulateDecision(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(x, IntegerValue(8))), + IntegerValue(6), &model); + + EXPECT_EQ(2.0, pseudo_costs->GetCost(x)); + EXPECT_EQ(3.0, pseudo_costs->GetCost(NegationOf(x))); + EXPECT_EQ(0.0, pseudo_costs->GetCost(y)); + EXPECT_EQ(7.0, pseudo_costs->GetCost(NegationOf(y))); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(x)); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(NegationOf(x))); + EXPECT_EQ(0, pseudo_costs->GetNumRecords(y)); + EXPECT_EQ(1, pseudo_costs->GetNumRecords(NegationOf(y))); + EXPECT_EQ(NegationOf(x), pseudo_costs->GetBestDecisionVar()); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index d3fb0878a7a..8bd2aae00bb 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for ortools.sat.python.cp_model.""" +import itertools +import time from absl.testing import absltest import pandas as pd @@ -95,30 +96,57 @@ def bool_var_values(self): return self.__bool_var_values +class TimeRecorder(cp_model.CpSolverSolutionCallback): + + def __init__(self) -> None: + super().__init__() + self.__last_time: float = 0.0 + + def on_solution_callback(self) -> None: + self.__last_time = time.time() + + @property + def last_time(self) -> float: + return self.__last_time + + class LogToString: """Record log in a string.""" - def __init__(self): + def __init__(self) -> None: self.__log = "" - def new_message(self, message: str): + def new_message(self, message: str) -> None: self.__log += message self.__log += "\n" @property - def log(self): + def log(self) -> str: return self.__log class BestBoundCallback: - def __init__(self): + def __init__(self) -> None: self.best_bound: float = 0.0 - def new_best_bound(self, bb: float): + def new_best_bound(self, bb: float) -> None: self.best_bound = bb +class BestBoundTimeCallback: + + def __init__(self) -> None: + self.__last_time: float = 0.0 + + def new_best_bound(self, unused_bb: float): + self.__last_time = time.time() + + @property + def last_time(self) -> float: + return self.__last_time + + class CpModelTest(absltest.TestCase): def testCreateIntegerVariable(self): @@ -1649,6 +1677,222 @@ def testIntervalVarSeries(self): ) self.assertLen(model.proto.constraints, 13) + def testIssue4376SatModel(self): + print("testIssue4376SatModel") + letters: str = "BCFLMRT" + + def symbols_from_string(text: str) -> list[int]: + return [letters.index(char) for char in text] + + def rotate_symbols(symbols: list[int], turns: int) -> list[int]: + return symbols[turns:] + symbols[:turns] + + data = """FMRC +FTLB +MCBR +FRTM +FBTM +BRFM +BTRM +BCRM +RTCF +TFRC +CTRM +CBTM +TFBM +TCBM +CFTM +BLTR +RLFM +CFLM +CRML +FCLR +FBTR +TBRF +RBCF +RBCT +BCTF +TFCR +CBRT +FCBT +FRTB +RBCM +MTFC +MFTC +MBFC +RTBM +RBFM +TRFM""" + 
+ tiles = [symbols_from_string(line) for line in data.splitlines()] + + model = cp_model.CpModel() + + # choices[i, x, y, r] is true iff we put tile i in cell (x,y) with + # rotation r. + choices = {} + for i in range(len(tiles)): + for x in range(6): + for y in range(6): + for r in range(4): + choices[(i, x, y, r)] = model.new_bool_var( + f"tile_{i}_{x}_{y}_{r}" + ) + + # corners[x, y, s] is true iff the corner at (x,y) contains symbol s. + corners = {} + for x in range(7): + for y in range(7): + for s in range(7): + corners[(x, y, s)] = model.new_bool_var(f"corner_{x}_{y}_{s}") + + # Placing a tile puts a symbol in each corner. + for (i, x, y, r), choice in choices.items(): + symbols = rotate_symbols(tiles[i], r) + model.add_implication(choice, corners[x, y, symbols[0]]) + model.add_implication(choice, corners[x, y + 1, symbols[1]]) + model.add_implication(choice, corners[x + 1, y + 1, symbols[2]]) + model.add_implication(choice, corners[x + 1, y, symbols[3]]) + + # We must make exactly one choice for each tile. + for i in range(len(tiles)): + tmp_literals = [] + for x in range(6): + for y in range(6): + for r in range(4): + tmp_literals.append(choices[(i, x, y, r)]) + model.add_exactly_one(tmp_literals) + + # We must make exactly one choice for each square. + for x, y in itertools.product(range(6), range(6)): + tmp_literals = [] + for i in range(len(tiles)): + for r in range(4): + tmp_literals.append(choices[(i, x, y, r)]) + model.add_exactly_one(tmp_literals) + + # Each corner contains exactly one symbol. + for x, y in itertools.product(range(7), range(7)): + model.add_exactly_one(corners[x, y, s] for s in range(7)) + + # Solve. 
+ solver = cp_model.CpSolver() + solver.parameters.num_workers = 8 + solver.parameters.max_time_in_seconds = 20 + solver.parameters.log_search_progress = True + solver.parameters.cp_model_presolve = False + solver.parameters.symmetry_level = 0 + + solution_callback = TimeRecorder() + status = solver.Solve(model, solution_callback) + if status == cp_model.OPTIMAL: + self.assertLess(time.time(), solution_callback.last_time + 5.0) + + def testIssue4376MinimizeModel(self): + print("testIssue4376MinimizeModel") + + model = cp_model.CpModel() + + jobs = [ + [3, 3], # [duration, width] + [2, 5], + [1, 3], + [3, 7], + [7, 3], + [2, 2], + [2, 2], + [5, 5], + [10, 2], + [4, 3], + [2, 6], + [1, 2], + [6, 8], + [4, 5], + [3, 7], + ] + + max_width = 10 + + horizon = sum(t[0] for t in jobs) + num_jobs = len(jobs) + all_jobs = range(num_jobs) + + intervals = [] + intervals0 = [] + intervals1 = [] + performed = [] + starts = [] + ends = [] + demands = [] + + for i in all_jobs: + # Create main interval. + start = model.new_int_var(0, horizon, f"start_{i}") + duration = jobs[i][0] + end = model.new_int_var(0, horizon, f"end_{i}") + interval = model.new_interval_var(start, duration, end, f"interval_{i}") + starts.append(start) + intervals.append(interval) + ends.append(end) + demands.append(jobs[i][1]) + + # Create an optional copy of interval to be executed on machine 0. + performed_on_m0 = model.new_bool_var(f"perform_{i}_on_m0") + performed.append(performed_on_m0) + start0 = model.new_int_var(0, horizon, f"start_{i}_on_m0") + end0 = model.new_int_var(0, horizon, f"end_{i}_on_m0") + interval0 = model.new_optional_interval_var( + start0, duration, end0, performed_on_m0, f"interval_{i}_on_m0" + ) + intervals0.append(interval0) + + # Create an optional copy of interval to be executed on machine 1. 
+ start1 = model.new_int_var(0, horizon, f"start_{i}_on_m1") + end1 = model.new_int_var(0, horizon, f"end_{i}_on_m1") + interval1 = model.new_optional_interval_var( + start1, + duration, + end1, + ~performed_on_m0, + f"interval_{i}_on_m1", + ) + intervals1.append(interval1) + + # We only propagate the constraint if the tasks is performed on the + # machine. + model.add(start0 == start).only_enforce_if(performed_on_m0) + model.add(start1 == start).only_enforce_if(~performed_on_m0) + + # Width constraint (modeled as a cumulative) + model.add_cumulative(intervals, demands, max_width) + + # Choose which machine to perform the jobs on. + model.add_no_overlap(intervals0) + model.add_no_overlap(intervals1) + + # Objective variable. + makespan = model.new_int_var(0, horizon, "makespan") + model.add_max_equality(makespan, ends) + model.minimize(makespan) + + # Symmetry breaking. + model.add(performed[0] == 0) + + # Solve. + solver = cp_model.CpSolver() + solver.parameters.num_workers = 8 + solver.parameters.max_time_in_seconds = 50 + solver.parameters.log_search_progress = True + solution_callback = TimeRecorder() + best_bound_callback = BestBoundTimeCallback() + solver.best_bound_callback = best_bound_callback.new_best_bound + status = solver.Solve(model, solution_callback) + if status == cp_model.OPTIMAL: + self.assertLess( + time.time(), + max(best_bound_callback.last_time, solution_callback.last_time) + 5.0, + ) + if __name__ == "__main__": absltest.main() diff --git a/ortools/sat/restart_test.cc b/ortools/sat/restart_test.cc new file mode 100644 index 00000000000..fd319ed33d7 --- /dev/null +++ b/ortools/sat/restart_test.cc @@ -0,0 +1,86 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/restart.h" + +#include + +#include "absl/base/macros.h" +#include "gtest/gtest.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_parameters.pb.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(SUnivTest, Luby) { + const int kSUniv[] = {1, 1, 2, 1, 1, 2, 4, 1, 1, 2, 1, 1, 2, 4, 8, 1}; + for (int i = 0; i < ABSL_ARRAYSIZE(kSUniv); ++i) { + EXPECT_EQ(kSUniv[i], SUniv(i + 1)); + } +} + +TEST(RestartPolicyTest, BasicRunningAverageTest) { + Model model; + RestartPolicy* restart = model.GetOrCreate(); + SatParameters* params = model.GetOrCreate(); + + // The parameters for this test. + params->clear_restart_algorithms(); + params->add_restart_algorithms(SatParameters::DL_MOVING_AVERAGE_RESTART); + params->set_use_blocking_restart(false); + params->set_restart_dl_average_ratio(1.0); + params->set_restart_running_window_size(10); + restart->Reset(); + + EXPECT_FALSE(restart->ShouldRestart()); + int i = 0; + for (; i < 100; ++i) { + const int unused = 0; + const int decision_level = i; + if (restart->ShouldRestart()) break; + restart->OnConflict(unused, decision_level, unused); + } + + // Increasing decision levels, so as soon as we have 11 conflicts and 10 in + // the window, the window average is > global average. + EXPECT_EQ(i, 11); + + // Now the window is reset, but not the global average. So as soon as we have + // 10 conflicts, we restart. 
+ i = 0; + for (; i < 100; ++i) { + const int unused = 0; + const int decision_level = 1000 - i; + if (restart->ShouldRestart()) break; + restart->OnConflict(unused, decision_level, unused); + } + EXPECT_EQ(i, 10); + + // If we call Reset() the global average is reaset, so if we have conflicts at + // a decreasing decision level, we never restart. + restart->Reset(); + i = 0; + for (; i < 1000; ++i) { + const int unused = 0; + const int decision_level = 1000 - i; + if (restart->ShouldRestart()) break; + restart->OnConflict(unused, decision_level, unused); + } + EXPECT_EQ(i, 1000); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/routing_cuts_test.cc b/ortools/sat/routing_cuts_test.cc new file mode 100644 index 00000000000..e2adb9da48f --- /dev/null +++ b/ortools/sat/routing_cuts_test.cc @@ -0,0 +1,422 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/routing_cuts.h" + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/strong_vector.h" +#include "ortools/graph/max_flow.h" +#include "ortools/sat/cuts.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/linear_constraint_manager.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; + +// Test on a simple tree: +// 3 +// / \ \ +// 1 0 5 +// / \ +// 2 4 +TEST(ExtractAllSubsetsFromForestTest, Basic) { + std::vector parents = {3, 3, 1, 3, 1, 3}; + + std::vector buffer; + std::vector> subsets; + ExtractAllSubsetsFromForest(parents, &buffer, &subsets); + + // Post order but we explore high number first. + // Alternatively, we could use unordered here, but the order is stable. + EXPECT_THAT(buffer, ElementsAre(5, 4, 2, 1, 0, 3)); + EXPECT_THAT(subsets, + ElementsAre(ElementsAre(5), ElementsAre(4), ElementsAre(2), + ElementsAre(4, 2, 1), ElementsAre(0), + ElementsAre(5, 4, 2, 1, 0, 3))); +} + +// +// 0 3 4 +// / \ | +// 1 2 5 +TEST(ExtractAllSubsetsFromForestTest, BasicForest) { + std::vector parents = {0, 0, 0, 3, 4, 4}; + + std::vector buffer; + std::vector> subsets; + ExtractAllSubsetsFromForest(parents, &buffer, &subsets); + + // Post order but we explore high number first. + // Alternatively, we could use unordered here, but the order is stable. + EXPECT_THAT(buffer, ElementsAre(2, 1, 0, 3, 5, 4)); + EXPECT_THAT(subsets, + ElementsAre(ElementsAre(2), ElementsAre(1), ElementsAre(2, 1, 0), + ElementsAre(3), ElementsAre(5), ElementsAre(5, 4))); +} + +TEST(ExtractAllSubsetsFromForestTest, Random) { + const int num_nodes = 20; + absl::BitGen random; + + // Create a random tree rooted at zero. 
+ std::vector parents(num_nodes, 0); + for (int i = 2; i < num_nodes; ++i) { + parents[i] = absl::Uniform(random, 0, i); // in [0, i - 1]. + } + + std::vector buffer; + std::vector> subsets; + ExtractAllSubsetsFromForest(parents, &buffer, &subsets); + + // We don't test that we are exhaustive, but we check basic property. + std::vector in_subset(num_nodes, false); + for (const auto subset : subsets) { + for (const int n : subset) in_subset[n] = true; + + // There should be at most one out edge. + int root = -1; + for (const int n : subset) { + if (in_subset[parents[n]]) continue; + if (root != -1) EXPECT_EQ(parents[n], root); + root = parents[n]; + } + + // No node outside should point inside. + for (int n = 0; n < num_nodes; ++n) { + if (in_subset[n]) continue; + EXPECT_TRUE(!in_subset[parents[n]]); + } + + for (const int n : subset) in_subset[n] = false; + } +} + +TEST(SymmetrizeArcsTest, BasicTest) { + std::vector arcs{{.tail = 0, .head = 1, .lp_value = 0.5}, + {.tail = 2, .head = 0, .lp_value = 0.5}, + {.tail = 1, .head = 0, .lp_value = 0.5}}; + SymmetrizeArcs(&arcs); + EXPECT_THAT( + arcs, ElementsAre(ArcWithLpValue{.tail = 0, .head = 1, .lp_value = 1.0}, + ArcWithLpValue{.tail = 0, .head = 2, .lp_value = 0.5})); +} + +TEST(ComputeGomoryHuTreeTest, Random) { + absl::BitGen random; + + // Lets generate a random graph on a small number of nodes. + const int num_nodes = 10; + const int num_arcs = 100; + std::vector arcs; + for (int i = 0; i < num_arcs; ++i) { + const int tail = absl::Uniform(random, 0, num_nodes); + const int head = absl::Uniform(random, 0, num_nodes); + if (tail == head) continue; + const double lp_value = absl::Uniform(random, 0, 1); + arcs.push_back({tail, head, lp_value}); + } + + // Get all cut from Gomory-Hu tree. + const std::vector parents = ComputeGomoryHuTree(num_nodes, arcs); + std::vector buffer; + std::vector> subsets; + ExtractAllSubsetsFromForest(parents, &buffer, &subsets); + + // Compute the cost of entering (resp. 
leaving) each subset. + // TODO(user): We need the same scaling as in ComputeGomoryHu(), not super + // clean. We might want an integer input to the function, but ok for now. + std::vector in_subset(num_nodes, false); + std::vector out_costs(subsets.size(), 0); + std::vector in_costs(subsets.size(), 0); + for (int i = 0; i < subsets.size(); ++i) { + for (const int n : subsets[i]) in_subset[n] = true; + for (const auto& arc : arcs) { + if (in_subset[arc.tail] && !in_subset[arc.head]) { + out_costs[i] += std::round(1.0e6 * arc.lp_value); + } + if (!in_subset[arc.tail] && in_subset[arc.head]) { + in_costs[i] += std::round(1.0e6 * arc.lp_value); + } + } + for (const int n : subsets[i]) in_subset[n] = false; + } + + // We will test with an exhaustive comparison. We are in n ^ 3 ! + // For all (s,t) pair, get the actual max-flow on the scaled graph. + // Check than one of the cuts separate s and t, with this exact weight. + SimpleMaxFlow max_flow; + for (const auto& [tail, head, lp_value] : arcs) { + // TODO(user): the algo only seems to work on an undirected graph, or + // equivalently when we always have a reverse arc with the same weight. + // Note that you can see below that we compute "min" cut for the sum of + // outgoing + incoming arcs this way. 
+ max_flow.AddArcWithCapacity(tail, head, std::round(1.0e6 * lp_value)); + max_flow.AddArcWithCapacity(head, tail, std::round(1.0e6 * lp_value)); + } + for (int s = 0; s < num_nodes; ++s) { + for (int t = s + 1; t < num_nodes; ++t) { + ASSERT_EQ(max_flow.Solve(s, t), SimpleMaxFlow::OPTIMAL); + const int64_t flow = max_flow.OptimalFlow(); + bool found = false; + for (int i = 0; i < subsets.size(); ++i) { + bool s_out = true; + bool t_out = true; + for (const int n : subsets[i]) { + if (n == s) s_out = false; + if (n == t) t_out = false; + } + if (!s_out && t_out && out_costs[i] + in_costs[i] == flow) found = true; + if (s_out && !t_out && in_costs[i] + out_costs[i] == flow) found = true; + if (found) break; + } + + // Debug. + if (!found) { + LOG(INFO) << s << " -> " << t << " flow= " << flow; + for (int i = 0; i < subsets.size(); ++i) { + bool s_out = true; + bool t_out = true; + for (const int n : subsets[i]) { + if (n == s) s_out = false; + if (n == t) t_out = false; + } + if (!s_out && t_out) { + LOG(INFO) << i << " out= " << out_costs[i] + in_costs[i]; + } + if (s_out && !t_out) { + LOG(INFO) << i << " in= " << in_costs[i] + out_costs[i]; + } + } + } + ASSERT_TRUE(found); + } + } +} + +TEST(CreateStronglyConnectedGraphCutGeneratorTest, BasicExample) { + Model model; + + // Lets create a simple square graph with arcs in both directions: + // + // 0 ---- 1 + // | | + // | | + // 2 ---- 3 + const int num_nodes = 4; + std::vector tails{0, 1, 1, 3, 3, 2, 2, 0}; + std::vector heads{1, 0, 3, 1, 2, 3, 0, 2}; + std::vector literals; + std::vector vars; + for (int i = 0; i < 2 * num_nodes; ++i) { + literals.push_back(Literal(model.Add(NewBooleanVariable()), true)); + vars.push_back(model.Add(NewIntegerVariableFromLiteral(literals.back()))); + } + + CutGenerator generator = CreateStronglyConnectedGraphCutGenerator( + num_nodes, tails, heads, literals, &model); + + // Suppose only 0-1 and 2-3 are in the lp solution (values do not matter). 
+ auto& lp_values = *model.GetOrCreate(); + lp_values.resize(16, 0.0); + lp_values[vars[0]] = 0.5; + lp_values[vars[1]] = 0.5; + lp_values[vars[4]] = 1.0; + lp_values[vars[5]] = 1.0; + LinearConstraintManager manager(&model); + generator.generate_cuts(&manager); + + // We should get two cuts. + EXPECT_EQ(manager.num_cuts(), 2); + EXPECT_THAT(manager.AllConstraints().front().constraint.VarsAsSpan(), + ElementsAre(vars[3], vars[6])); + EXPECT_THAT(manager.AllConstraints().back().constraint.VarsAsSpan(), + ElementsAre(vars[2], vars[7])); +} + +TEST(CreateStronglyConnectedGraphCutGeneratorTest, AnotherExample) { + // This time, the graph is fully connected, but we still detect that {1, 2, 3} + // do not have enough outgoing flow: + // + // 0.5 + // 0 <--> 1 + // ^ | 0.5 + // 0.5 | | 1 and 2 ----> 1 + // v v + // 2 <--- 3 + // 1 + const int num_nodes = 4; + std::vector tails{0, 1, 0, 2, 1, 3, 2}; + std::vector heads{1, 0, 2, 0, 3, 2, 1}; + std::vector values{0.5, 0.0, 0.5, 0.0, 1.0, 1.0, 0.5}; + + Model model; + std::vector literals; + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(16, 0.0); + for (int i = 0; i < values.size(); ++i) { + literals.push_back(Literal(model.Add(NewBooleanVariable()), true)); + lp_values[model.Add(NewIntegerVariableFromLiteral(literals.back()))] = + values[i]; + } + + CutGenerator generator = CreateStronglyConnectedGraphCutGenerator( + num_nodes, tails, heads, literals, &model); + + LinearConstraintManager manager(&model); + generator.generate_cuts(&manager); + + // The sets {2, 3} and {1, 2, 3} will generate cuts. + // However as an heuristic, we will wait another round to generate {1, 2 ,3}. 
+ EXPECT_EQ(manager.num_cuts(), 1); + EXPECT_THAT(manager.AllConstraints().back().constraint.DebugString(), + ::testing::StartsWith("1 <= 1*X3 1*X6")); +} + +TEST(GenerateInterestingSubsetsTest, BasicExample) { + const int num_nodes = 6; + const std::vector> arcs = {{0, 5}, {2, 3}, {3, 4}}; + + // Note that the order is not important, but is currently fixed. + // This document the actual order. + std::vector subset_data; + std::vector> subsets; + GenerateInterestingSubsets(num_nodes, arcs, + /*stop_at_num_components=*/2, &subset_data, + &subsets); + EXPECT_THAT( + subsets, + ElementsAre(ElementsAre(1), ElementsAre(5), ElementsAre(0), + ElementsAre(5, 0), ElementsAre(3), ElementsAre(2), + ElementsAre(3, 2), ElementsAre(4), ElementsAre(3, 2, 4))); + + // We can call it more than once. + GenerateInterestingSubsets(num_nodes, arcs, + /*stop_at_num_components=*/2, &subset_data, + &subsets); + EXPECT_THAT( + subsets, + ElementsAre(ElementsAre(1), ElementsAre(5), ElementsAre(0), + ElementsAre(5, 0), ElementsAre(3), ElementsAre(2), + ElementsAre(3, 2), ElementsAre(4), ElementsAre(3, 2, 4))); +} + +TEST(CreateFlowCutGeneratorTest, BasicExample) { + // + // /---> 2 + // 0 ---> 1 ^ + // \---> 3 + // + // With a flow of 2 leaving 0 and a flow of 1 requested at 2 and 3. + // On each arc the flow <= max_flow * arc_indicator where max_flow = 2. 
+ const int num_nodes = 4; + std::vector tails{0, 1, 1, 3}; + std::vector heads{1, 2, 3, 2}; + std::vector values{1.0, 0.5, 0.5, 0.0}; + + Model model; + std::vector capacities; + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(16, 0.0); + for (int i = 0; i < values.size(); ++i) { + AffineExpression expr; + expr.var = model.Add(NewIntegerVariable(0, 1)); + expr.coeff = 2; + expr.constant = 0; + capacities.emplace_back(expr); + lp_values[capacities.back().var] = values[i]; + } + + const auto get_flows = [](const std::vector& in_subset, + IntegerValue* min_incoming_flow, + IntegerValue* min_outgoing_flow) { + IntegerValue demand(0); + if (in_subset[0]) demand -= 2; + if (in_subset[2]) demand += 1; + if (in_subset[3]) demand += 1; + *min_incoming_flow = std::max(IntegerValue(0), demand); + *min_outgoing_flow = std::max(IntegerValue(0), -demand); + }; + const CutGenerator generator = CreateFlowCutGenerator( + num_nodes, tails, heads, capacities, get_flows, &model); + + LinearConstraintManager manager(&model); + generator.generate_cuts(&manager); + + // The sets {2} and {3} will generate incoming flow cuts. 
+ EXPECT_EQ(manager.num_cuts(), 2); + EXPECT_THAT(manager.AllConstraints().front().constraint.DebugString(), + ::testing::StartsWith("1 <= 1*X2")); + EXPECT_THAT(manager.AllConstraints().back().constraint.DebugString(), + ::testing::StartsWith("1 <= 1*X1 1*X3")); +} + +TEST(CreateFlowCutGeneratorTest, WithMinusOneArcs) { + // 0 ---> 1 --> + // | + // \ --> + const int num_nodes = 2; + std::vector tails{0, 1, 1}; + std::vector heads{1, -1, -1}; + std::vector values{1.0, 0.5, 0.0}; + + Model model; + std::vector capacities; + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(16, 0.0); + for (int i = 0; i < values.size(); ++i) { + AffineExpression expr; + expr.var = model.Add(NewIntegerVariable(0, 1)); + expr.coeff = 2; + expr.constant = 0; + capacities.emplace_back(expr); + lp_values[capacities.back().var] = values[i]; + } + + const auto get_flows = [](const std::vector& in_subset, + IntegerValue* min_incoming_flow, + IntegerValue* min_outgoing_flow) { + IntegerValue demand(0); + if (in_subset[0]) demand -= 2; + *min_incoming_flow = std::max(IntegerValue(0), demand); + *min_outgoing_flow = std::max(IntegerValue(0), -demand); + }; + const CutGenerator generator = CreateFlowCutGenerator( + num_nodes, tails, heads, capacities, get_flows, &model); + + LinearConstraintManager manager(&model); + generator.generate_cuts(&manager); + + // We artificially put bad LP values so that {1} generate outgoing flow cut. + EXPECT_EQ(manager.num_cuts(), 1); + EXPECT_THAT(manager.AllConstraints().front().constraint.DebugString(), + ::testing::StartsWith("1 <= 1*X1 1*X2")); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/samples/BUILD.bazel b/ortools/sat/samples/BUILD.bazel index 47f1ef0c753..95de48b8098 100644 --- a/ortools/sat/samples/BUILD.bazel +++ b/ortools/sat/samples/BUILD.bazel @@ -11,7 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load(":code_samples.bzl", "code_sample_cc_py", "code_sample_java", "code_sample_py") +load( + ":code_samples.bzl", + "code_sample_cc_go_py", + "code_sample_cc_py", + "code_sample_go", + "code_sample_java", + "code_sample_py", +) code_sample_py(name = "all_different_except_zero_sample_sat") @@ -23,19 +30,21 @@ code_sample_cc_py(name = "assignment_task_sizes_sat") code_sample_cc_py(name = "assignment_teams_sat") -code_sample_cc_py(name = "assumptions_sample_sat") +code_sample_cc_go_py(name = "assumptions_sample_sat") -code_sample_cc_py(name = "binpacking_problem_sat") +code_sample_cc_go_py(name = "binpacking_problem_sat") code_sample_py(name = "bin_packing_sat") code_sample_py(name = "bool_and_int_var_product_sample_sat") -code_sample_cc_py(name = "bool_or_sample_sat") +code_sample_cc_go_py(name = "bool_or_sample_sat") + +code_sample_go(name = "boolean_product_sample_sat") code_sample_py(name = "boolean_product_sample_sat") -code_sample_cc_py(name = "channeling_sample_sat") +code_sample_cc_go_py(name = "channeling_sample_sat") code_sample_cc_py(name = "clone_model_sample_sat") @@ -45,55 +54,55 @@ code_sample_cc_py(name = "cp_sat_example") code_sample_py(name = "cumulative_variable_profile_sample_sat") -code_sample_cc_py(name = "earliness_tardiness_cost_sample_sat") +code_sample_cc_go_py(name = "earliness_tardiness_cost_sample_sat") code_sample_py(name = "index_first_boolvar_true_sample_sat") code_sample_py(name = "interval_relations_sample_sat") -code_sample_cc_py(name = "interval_sample_sat") +code_sample_cc_go_py(name = "interval_sample_sat") code_sample_cc_py(name = "minimal_jobshop_sat") -code_sample_cc_py(name = "literal_sample_sat") +code_sample_cc_go_py(name = "literal_sample_sat") code_sample_cc_py(name = "multiple_knapsack_sat") code_sample_cc_py(name = "non_linear_sat") -code_sample_cc_py(name = "nqueens_sat") +code_sample_cc_go_py(name = "no_overlap_sample_sat") -code_sample_cc_py(name = "nurses_sat") +code_sample_cc_go_py(name = "nqueens_sat") 
-code_sample_cc_py(name = "optional_interval_sample_sat") +code_sample_cc_go_py(name = "nurses_sat") -code_sample_cc_py(name = "no_overlap_sample_sat") +code_sample_cc_go_py(name = "optional_interval_sample_sat") code_sample_py(name = "overlapping_intervals_sample_sat") -code_sample_cc_py(name = "rabbits_and_pheasants_sat") +code_sample_cc_go_py(name = "rabbits_and_pheasants_sat") code_sample_py(name = "ranking_circuit_sample_sat") -code_sample_cc_py(name = "ranking_sample_sat") +code_sample_cc_go_py(name = "ranking_sample_sat") -code_sample_cc_py(name = "reified_sample_sat") +code_sample_cc_go_py(name = "reified_sample_sat") code_sample_cc_py(name = "schedule_requests_sat") code_sample_py(name = "scheduling_with_calendar_sample_sat") -code_sample_cc_py(name = "simple_sat_program") +code_sample_cc_go_py(name = "search_for_all_solutions_sample_sat") -code_sample_cc_py(name = "search_for_all_solutions_sample_sat") +code_sample_cc_go_py(name = "simple_sat_program") -code_sample_cc_py(name = "solution_hinting_sample_sat") +code_sample_cc_go_py(name = "solution_hinting_sample_sat") -code_sample_cc_py(name = "solve_and_print_intermediate_solutions_sample_sat") +code_sample_cc_go_py(name = "solve_and_print_intermediate_solutions_sample_sat") -code_sample_cc_py(name = "step_function_sample_sat") +code_sample_cc_go_py(name = "solve_with_time_limit_sample_sat") -code_sample_cc_py(name = "solve_with_time_limit_sample_sat") +code_sample_cc_go_py(name = "step_function_sample_sat") code_sample_cc_py(name = "stop_after_n_solutions_sample_sat") diff --git a/ortools/sat/samples/assumptions_sample_sat.go b/ortools/sat/samples/assumptions_sample_sat.go index 56f700c5d99..c8b291b6883 100644 --- a/ortools/sat/samples/assumptions_sample_sat.go +++ b/ortools/sat/samples/assumptions_sample_sat.go @@ -17,9 +17,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + 
"github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) func assumptionsSampleSat() error { @@ -62,6 +62,6 @@ func assumptionsSampleSat() error { func main() { if err := assumptionsSampleSat(); err != nil { - glog.Exitf("assumptionsSampleSat returned with error: %v", err) + log.Exitf("assumptionsSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/binpacking_problem_sat.go b/ortools/sat/samples/binpacking_problem_sat.go index 5ba42154794..972d7a6d6cd 100644 --- a/ortools/sat/samples/binpacking_problem_sat.go +++ b/ortools/sat/samples/binpacking_problem_sat.go @@ -18,8 +18,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const ( @@ -30,7 +30,7 @@ const ( ) type item struct { - Cost, Copies int64_t + Cost, Copies int64 } func binpackingProblemSat() error { @@ -116,6 +116,6 @@ func binpackingProblemSat() error { func main() { if err := binpackingProblemSat(); err != nil { - glog.Exitf("binpackingProblemSat returned with error: %v", err) + log.Exitf("binpackingProblemSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/bool_or_sample_sat.go b/ortools/sat/samples/bool_or_sample_sat.go index 0af18fde136..14ef894e981 100644 --- a/ortools/sat/samples/bool_or_sample_sat.go +++ b/ortools/sat/samples/bool_or_sample_sat.go @@ -15,7 +15,7 @@ package main import ( - "ortools/sat/go/cpmodel" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) func boolOrSampleSat() { diff --git a/ortools/sat/samples/boolean_product_sample_sat.go b/ortools/sat/samples/boolean_product_sample_sat.go index 874294ebd8c..d79ae52d8b7 100644 --- a/ortools/sat/samples/boolean_product_sample_sat.go +++ b/ortools/sat/samples/boolean_product_sample_sat.go @@ -17,10 +17,10 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - 
"ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func booleanProductSample() error { @@ -44,11 +44,11 @@ func booleanProductSample() error { } // Set `fill_additional_solutions_in_response` and `enumerate_all_solutions` to true so // the solver returns all solutions found. - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(4), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -68,6 +68,6 @@ func booleanProductSample() error { func main() { err := booleanProductSample() if err != nil { - glog.Exitf("booleanProductSample returned with error: %v", err) + log.Exitf("booleanProductSample returned with error: %v", err) } } diff --git a/ortools/sat/samples/channeling_sample_sat.go b/ortools/sat/samples/channeling_sample_sat.go index 88db3279f9f..33d3d52df26 100644 --- a/ortools/sat/samples/channeling_sample_sat.go +++ b/ortools/sat/samples/channeling_sample_sat.go @@ -17,11 +17,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func channelingSampleSat() error { @@ -53,12 +53,12 @@ func channelingSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := 
sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(11), SearchBranching: sppb.SatParameters_FIXED_SEARCH.Enum(), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -76,6 +76,6 @@ func channelingSampleSat() error { func main() { if err := channelingSampleSat(); err != nil { - glog.Exitf("channelingSampleSat returned with error: %v", err) + log.Exitf("channelingSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/code_samples.bzl b/ortools/sat/samples/code_samples.bzl index 92f9ebde9d5..48764a31c45 100644 --- a/ortools/sat/samples/code_samples.bzl +++ b/ortools/sat/samples/code_samples.bzl @@ -13,11 +13,14 @@ """Helper macro to compile and test code samples.""" +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("@pip_deps//:requirements.bzl", "requirement") +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_test") +load("@rules_java//java:defs.bzl", "java_test") load("@rules_python//python:defs.bzl", "py_binary", "py_test") def code_sample_cc(name): - native.cc_binary( + cc_binary( name = name + "_cc", srcs = [name + ".cc"], deps = [ @@ -28,7 +31,7 @@ def code_sample_cc(name): ], ) - native.cc_test( + cc_test( name = name + "_cc_test", size = "small", srcs = [name + ".cc"], @@ -41,6 +44,20 @@ def code_sample_cc(name): ], ) +def code_sample_go(name): + go_test( + name = name + "_go_test", + size = "small", + srcs = [name + ".go"], + deps = [ + "//ortools/sat:cp_model_go_proto", + "//ortools/sat:sat_parameters_go_proto", + "//ortools/sat/go/cpmodel", + "@com_github_golang_glog//:glog", + "@org_golang_google_protobuf//proto", + ], + ) + def code_sample_py(name): py_binary( name = name + "_py3", @@ -74,12 +91,17 @@ def code_sample_py(name): srcs_version = "PY3", ) +def code_sample_cc_go_py(name): + 
code_sample_cc(name = name) + code_sample_go(name = name) + code_sample_py(name = name) + def code_sample_cc_py(name): code_sample_cc(name = name) code_sample_py(name = name) def code_sample_java(name): - native.java_test( + java_test( name = name + "_java_test", size = "small", srcs = [name + ".java"], diff --git a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go index 43f7c2f1812..c8bdab2b25c 100644 --- a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go +++ b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go @@ -18,11 +18,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) const ( @@ -62,12 +62,12 @@ func earlinessTardinessCostSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(21), SearchBranching: sppb.SatParameters_FIXED_SEARCH.Enum(), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -85,6 +85,6 @@ func earlinessTardinessCostSampleSat() error { func main() { if err := earlinessTardinessCostSampleSat(); err != nil { - glog.Exitf("earlinessTardinessCostSampleSat returned with error: %v", err) + log.Exitf("earlinessTardinessCostSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/interval_sample_sat.go 
b/ortools/sat/samples/interval_sample_sat.go index e0e2776631b..24d89f677bb 100644 --- a/ortools/sat/samples/interval_sample_sat.go +++ b/ortools/sat/samples/interval_sample_sat.go @@ -17,8 +17,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const horizon = 100 @@ -53,6 +53,6 @@ func intervalSampleSat() error { func main() { if err := intervalSampleSat(); err != nil { - glog.Exitf("intervalSampleSat returned with error: %v", err) + log.Exitf("intervalSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/literal_sample_sat.go b/ortools/sat/samples/literal_sample_sat.go index be171a9161c..6018b1911ee 100644 --- a/ortools/sat/samples/literal_sample_sat.go +++ b/ortools/sat/samples/literal_sample_sat.go @@ -15,8 +15,8 @@ package main import ( - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) func literalSampleSat() { @@ -25,7 +25,7 @@ func literalSampleSat() { x := model.NewBoolVar().WithName("x") notX := x.Not() - glog.Infof("x = %d, x.Not() = %d", x.Index(), notX.Index()) + log.Infof("x = %d, x.Not() = %d", x.Index(), notX.Index()) } func main() { diff --git a/ortools/sat/samples/no_overlap_sample_sat.go b/ortools/sat/samples/no_overlap_sample_sat.go index ce5fccab1ee..cbe175f404d 100644 --- a/ortools/sat/samples/no_overlap_sample_sat.go +++ b/ortools/sat/samples/no_overlap_sample_sat.go @@ -17,9 +17,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) const horizon = 21 // 3 weeks @@ -85,6 +85,6 @@ func noOverlapSampleSat() error { func main() { if err := noOverlapSampleSat(); err != nil { - 
glog.Exitf("noOverlapSampleSat returned with error: %v", err) + log.Exitf("noOverlapSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/nqueens_sat.go b/ortools/sat/samples/nqueens_sat.go index b2a7b8323c0..5694bfd7eaf 100644 --- a/ortools/sat/samples/nqueens_sat.go +++ b/ortools/sat/samples/nqueens_sat.go @@ -17,8 +17,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const boardSize = 8 @@ -30,7 +30,7 @@ func nQueensSat() error { // of the board. The value of each variable is the row that the queen is in. var queenRows []cpmodel.LinearArgument for i := 0; i < boardSize; i++ { - queenRows = append(queenRows, model.NewIntVar(0, int64_t(boardSize-1))) + queenRows = append(queenRows, model.NewIntVar(0, int64(boardSize-1))) } // The following sets the constraint that all queens are in different rows. @@ -40,8 +40,8 @@ func nQueensSat() error { var diag1 []cpmodel.LinearArgument var diag2 []cpmodel.LinearArgument for i := 0; i < boardSize; i++ { - diag1 = append(diag1, cpmodel.NewConstant(int64_t(i)).Add(queenRows[i])) - diag2 = append(diag2, cpmodel.NewConstant(int64_t(-i)).Add(queenRows[i])) + diag1 = append(diag1, cpmodel.NewConstant(int64(i)).Add(queenRows[i])) + diag2 = append(diag2, cpmodel.NewConstant(int64(-i)).Add(queenRows[i])) } model.AddAllDifferent(diag1...) model.AddAllDifferent(diag2...) 
@@ -59,7 +59,7 @@ func nQueensSat() error { fmt.Printf("Objective: %v\n", response.GetObjectiveValue()) fmt.Printf("Solution:\n") - for i := int64_t(0); i < boardSize; i++ { + for i := int64(0); i < boardSize; i++ { for j := 0; j < boardSize; j++ { if cpmodel.SolutionIntegerValue(response, queenRows[j]) == i { fmt.Print("Q") @@ -76,6 +76,6 @@ func nQueensSat() error { func main() { err := nQueensSat() if err != nil { - glog.Exitf("nQueensSat returned with error: %v", err) + log.Exitf("nQueensSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/nurses_sat.go b/ortools/sat/samples/nurses_sat.go index 59ffa83e8a7..8b1df4aa335 100644 --- a/ortools/sat/samples/nurses_sat.go +++ b/ortools/sat/samples/nurses_sat.go @@ -17,8 +17,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const ( @@ -121,6 +121,6 @@ func nursesSat() error { func main() { if err := nursesSat(); err != nil { - glog.Exitf("nursesSat returned with error: %v", err) + log.Exitf("nursesSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/optional_interval_sample_sat.go b/ortools/sat/samples/optional_interval_sample_sat.go index b586dce62dd..1dffb2d88ea 100644 --- a/ortools/sat/samples/optional_interval_sample_sat.go +++ b/ortools/sat/samples/optional_interval_sample_sat.go @@ -18,8 +18,8 @@ package main import ( "fmt" - "github.com/golang/glog" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) const horizon = 100 @@ -51,6 +51,6 @@ func optionalIntervalSampleSat() error { func main() { if err := optionalIntervalSampleSat(); err != nil { - glog.Exitf("optionalIntervalSampleSat returned with error: %v", err) + log.Exitf("optionalIntervalSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/rabbits_and_pheasants_sat.go 
b/ortools/sat/samples/rabbits_and_pheasants_sat.go index f00a76adb65..a49d688e10f 100644 --- a/ortools/sat/samples/rabbits_and_pheasants_sat.go +++ b/ortools/sat/samples/rabbits_and_pheasants_sat.go @@ -18,9 +18,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) const numAnimals = 20 @@ -58,6 +58,6 @@ func rabbitsAndPheasants() error { func main() { if err := rabbitsAndPheasants(); err != nil { - glog.Exitf("rabbitsAndPheasants returned with error: %v", err) + log.Exitf("rabbitsAndPheasants returned with error: %v", err) } } diff --git a/ortools/sat/samples/ranking_sample_sat.go b/ortools/sat/samples/ranking_sample_sat.go index cb1a3989840..a4bf692340e 100644 --- a/ortools/sat/samples/ranking_sample_sat.go +++ b/ortools/sat/samples/ranking_sample_sat.go @@ -17,9 +17,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) const ( @@ -89,7 +89,7 @@ func rankingSampleSat() error { for t := 0; t < numTasks; t++ { start := model.NewIntVarFromDomain(horizon) - duration := cpmodel.NewConstant(int64_t(t + 1)) + duration := cpmodel.NewConstant(int64(t + 1)) end := model.NewIntVarFromDomain(horizon) var presence cpmodel.BoolVar if t < numTasks/2 { @@ -160,6 +160,6 @@ func rankingSampleSat() error { func main() { if err := rankingSampleSat(); err != nil { - glog.Exitf("rankingSampleSat returned with error: %v", err) + log.Exitf("rankingSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/reified_sample_sat.go b/ortools/sat/samples/reified_sample_sat.go index 63a77dba72e..00ecb4313e5 100644 --- 
a/ortools/sat/samples/reified_sample_sat.go +++ b/ortools/sat/samples/reified_sample_sat.go @@ -15,7 +15,7 @@ package main import ( - "ortools/sat/go/cpmodel" + "github.com/google/or-tools/ortools/sat/go/cpmodel" ) func reifiedSampleSat() { diff --git a/ortools/sat/samples/search_for_all_solutions_sample_sat.go b/ortools/sat/samples/search_for_all_solutions_sample_sat.go index 17c4e3f84d4..6e43abbe659 100644 --- a/ortools/sat/samples/search_for_all_solutions_sample_sat.go +++ b/ortools/sat/samples/search_for_all_solutions_sample_sat.go @@ -18,10 +18,10 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func searchForAllSolutionsSampleSat() error { @@ -41,11 +41,11 @@ func searchForAllSolutionsSampleSat() error { // Currently, the CpModelBuilder does not allow for callbacks, so each feasible solution cannot // be printed while solving. However, the CP Solver can return all of the enumerated solutions // in the response by setting the following parameters. 
- params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ EnumerateAllSolutions: proto.Bool(true), FillAdditionalSolutionsInResponse: proto.Bool(true), SolutionPoolSize: proto.Int32(27), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -63,6 +63,6 @@ func searchForAllSolutionsSampleSat() error { func main() { if err := searchForAllSolutionsSampleSat(); err != nil { - glog.Exitf("searchForAllSolutionsSampleSat returned with error: %v", err) + log.Exitf("searchForAllSolutionsSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/simple_sat_program.go b/ortools/sat/samples/simple_sat_program.go index 0430d34fb49..8cddcb7efeb 100644 --- a/ortools/sat/samples/simple_sat_program.go +++ b/ortools/sat/samples/simple_sat_program.go @@ -17,9 +17,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) func simpleSatProgram() error { @@ -55,6 +55,6 @@ func simpleSatProgram() error { func main() { if err := simpleSatProgram(); err != nil { - glog.Exitf("simpleSatProgram returned with error: %v", err) + log.Exitf("simpleSatProgram returned with error: %v", err) } } diff --git a/ortools/sat/samples/solution_hinting_sample_sat.go b/ortools/sat/samples/solution_hinting_sample_sat.go index 59b2766c2be..d0d4379fe80 100644 --- a/ortools/sat/samples/solution_hinting_sample_sat.go +++ b/ortools/sat/samples/solution_hinting_sample_sat.go @@ -17,9 +17,9 @@ package main import ( "fmt" - "github.com/golang/glog" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) func 
solutionHintingSampleSat() error { @@ -60,6 +60,6 @@ func solutionHintingSampleSat() error { func main() { if err := solutionHintingSampleSat(); err != nil { - glog.Exitf("solutionHintingSampleSat returned with error: %v", err) + log.Exitf("solutionHintingSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go index f1a0d86507d..5ed02d38bc9 100644 --- a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go +++ b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go @@ -17,10 +17,10 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func solveAndPrintIntermediateSolutionsSampleSat() error { @@ -44,10 +44,10 @@ func solveAndPrintIntermediateSolutionsSampleSat() error { // Currently, the CpModelBuilder does not allow for callbacks, so intermediate solutions // cannot be printed while solving. However, the CP-SAT solver does allow for returning // the intermediate solutions found while solving in the response. 
- params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), SolutionPoolSize: proto.Int32(10), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -65,6 +65,6 @@ func solveAndPrintIntermediateSolutionsSampleSat() error { func main() { if err := solveAndPrintIntermediateSolutionsSampleSat(); err != nil { - glog.Exitf("solveAndPrintIntermediateSolutionsSampleSat returned with error: %v", err) + log.Exitf("solveAndPrintIntermediateSolutionsSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/solve_with_time_limit_sample_sat.go b/ortools/sat/samples/solve_with_time_limit_sample_sat.go index 4fe04c95a8d..5a5bd6cb1b3 100644 --- a/ortools/sat/samples/solve_with_time_limit_sample_sat.go +++ b/ortools/sat/samples/solve_with_time_limit_sample_sat.go @@ -17,11 +17,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func solveWithTimeLimitSampleSat() error { @@ -40,9 +40,9 @@ func solveWithTimeLimitSampleSat() error { } // Sets a time limit of 10 seconds. - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ MaxTimeInSeconds: proto.Float64(10.0), - }.Build() + } // Solve. 
response, err := cpmodel.SolveCpModelWithParameters(m, params) @@ -63,6 +63,6 @@ func solveWithTimeLimitSampleSat() error { func main() { if err := solveWithTimeLimitSampleSat(); err != nil { - glog.Exitf("solveWithTimeLimitSampleSat returned with error: %v", err) + log.Exitf("solveWithTimeLimitSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/samples/step_function_sample_sat.go b/ortools/sat/samples/step_function_sample_sat.go index 5fb4e66f9f3..8c927f1e751 100644 --- a/ortools/sat/samples/step_function_sample_sat.go +++ b/ortools/sat/samples/step_function_sample_sat.go @@ -17,11 +17,11 @@ package main import ( "fmt" - "github.com/golang/glog" - "golang/protobuf/v2/proto/proto" - cmpb "ortools/sat/cp_model_go_proto" - "ortools/sat/go/cpmodel" - sppb "ortools/sat/sat_parameters_go_proto" + log "github.com/golang/glog" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" + "google.golang.org/protobuf/proto" ) func stepFunctionSampleSat() error { @@ -71,12 +71,12 @@ func stepFunctionSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := sppb.SatParameters_builder{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(21), SearchBranching: sppb.SatParameters_FIXED_SEARCH.Enum(), - }.Build() + } response, err := cpmodel.SolveCpModelWithParameters(m, params) if err != nil { return fmt.Errorf("failed to solve the model: %w", err) @@ -94,6 +94,6 @@ func stepFunctionSampleSat() error { func main() { if err := stepFunctionSampleSat(); err != nil { - glog.Exitf("stepFunctionSampleSat returned with error: %v", err) + log.Exitf("stepFunctionSampleSat returned with error: %v", err) } } diff --git a/ortools/sat/sat_base_test.cc b/ortools/sat/sat_base_test.cc new file 
mode 100644 index 00000000000..391efb1d092 --- /dev/null +++ b/ortools/sat/sat_base_test.cc @@ -0,0 +1,74 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/sat_base.h" + +#include + +#include "gtest/gtest.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(BooleanVariableTest, Api) { + BooleanVariable var1(1); + BooleanVariable var2(2); + BooleanVariable var3(2); + EXPECT_NE(var1, var2); + EXPECT_EQ(var2, var3); +} + +TEST(LiteralTest, Api) { + BooleanVariable var1(1); + BooleanVariable var2(2); + Literal l1(var1, true); + Literal l2(var2, false); + Literal l3 = l2.Negated(); + EXPECT_EQ(l1.Variable(), var1); + EXPECT_EQ(l2.Variable(), var2); + EXPECT_EQ(l3.Variable(), var2); + EXPECT_TRUE(l1.IsPositive()); + EXPECT_TRUE(l2.IsNegative()); + EXPECT_TRUE(l3.IsPositive()); +} + +TEST(VariablesAssignmentTest, Api) { + BooleanVariable var0(0); + BooleanVariable var1(1); + BooleanVariable var2(2); + + VariablesAssignment assignment; + assignment.Resize(3); + assignment.AssignFromTrueLiteral(Literal(var0, true)); + assignment.AssignFromTrueLiteral(Literal(var1, false)); + + EXPECT_TRUE(assignment.LiteralIsTrue(Literal(var0, true))); + EXPECT_TRUE(assignment.LiteralIsFalse(Literal(var0, false))); + EXPECT_TRUE(assignment.LiteralIsTrue(Literal(var1, false))); + EXPECT_FALSE(assignment.VariableIsAssigned(var2)); + + 
assignment.UnassignLiteral(Literal(var0, true)); + EXPECT_FALSE(assignment.VariableIsAssigned(var0)); + + assignment.AssignFromTrueLiteral(Literal(var2, false)); + EXPECT_TRUE(assignment.LiteralIsTrue(Literal(var2, false))); + EXPECT_FALSE(assignment.LiteralIsTrue(Literal(var2, true))); + EXPECT_TRUE(assignment.LiteralIsFalse(Literal(var2, true))); + EXPECT_FALSE(assignment.LiteralIsFalse(Literal(var2, false))); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/sat_inprocessing_test.cc b/ortools/sat/sat_inprocessing_test.cc new file mode 100644 index 00000000000..291ca0aff97 --- /dev/null +++ b/ortools/sat/sat_inprocessing_test.cc @@ -0,0 +1,287 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/sat_inprocessing.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/clause.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(InprocessingTest, ClauseCleanupWithFixedVariables) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* clause_manager = model.GetOrCreate(); + auto* inprocessing = model.GetOrCreate(); + + // Lets add some clauses. 
+ sat_solver->SetNumVariables(100); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +2, +3, +4}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, -2, -3, +5}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+2, -2, -3, +1, +1}))); + + // Nothing fixed, we don't even look at the clause. + const bool log_info = true; + EXPECT_TRUE(inprocessing->DetectEquivalencesAndStamp(false, log_info)); + EXPECT_TRUE(inprocessing->RemoveFixedAndEquivalentVariables(log_info)); + { + const auto& all_clauses = clause_manager->AllClausesInCreationOrder(); + EXPECT_EQ(all_clauses.size(), 3); + EXPECT_EQ(all_clauses[2]->AsSpan(), Literals({+2, -2, -3, +1, +1})); + } + + // Lets fix 3. + CHECK(sat_solver->AddUnitClause(Literal(+3))); + EXPECT_TRUE(sat_solver->FinishPropagation()); + EXPECT_TRUE(inprocessing->DetectEquivalencesAndStamp(false, log_info)); + EXPECT_TRUE(inprocessing->RemoveFixedAndEquivalentVariables(log_info)); + { + const auto& all_clauses = clause_manager->AllClausesInCreationOrder(); + EXPECT_EQ(all_clauses.size(), 3); + EXPECT_EQ(all_clauses[0]->AsSpan(), Literals({})); // +3 true. + EXPECT_EQ(all_clauses[1]->AsSpan(), Literals({+1, -2, +5})); + EXPECT_EQ(all_clauses[2]->AsSpan(), Literals({})); // trivially true. + } +} + +TEST(InprocessingTest, ClauseCleanupWithEquivalence) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* clause_manager = model.GetOrCreate(); + auto* implication_graph = model.GetOrCreate(); + auto* inprocessing = model.GetOrCreate(); + + // Lets add some clauses. + sat_solver->SetNumVariables(100); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +2, +5, +4}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, -2, -3, +5}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+2, +6, -3, +1, +1}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+2, +6, -3, +1, -5}))); + + // Lets make 3 and 5 equivalent. 
+ implication_graph->AddBinaryClause(Literal(-3), Literal(+5)); + implication_graph->AddBinaryClause(Literal(+3), Literal(-5)); + + const bool log_info = true; + EXPECT_TRUE(inprocessing->DetectEquivalencesAndStamp(false, log_info)); + EXPECT_TRUE(inprocessing->RemoveFixedAndEquivalentVariables(log_info)); + { + const auto& all_clauses = clause_manager->AllClausesInCreationOrder(); + EXPECT_EQ(all_clauses.size(), 4); + EXPECT_EQ(all_clauses[0]->AsSpan(), Literals({+1, +2, +3, +4})); + EXPECT_EQ(all_clauses[1]->AsSpan(), Literals({})); + EXPECT_EQ(all_clauses[3]->AsSpan(), Literals({+2, +6, -3, +1})); + + // Note that the +1 +1 is not simplified because this clause do not + // need to be rewritten otherwise and we assume initial simplification. + EXPECT_EQ(all_clauses[2]->AsSpan(), Literals({+2, +6, -3, +1, +1})); + } +} + +TEST(InprocessingTest, ClauseSubsumptionAndStrengthening) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* clause_manager = model.GetOrCreate(); + auto* inprocessing = model.GetOrCreate(); + + // Lets add some clauses. + // Note that the order currently matter for what is left. + // + // Note that currently the binary clauses are not reprocessed. + // TODO(user): Maybe we should so that we always end up with a reduced set. + sat_solver->SetNumVariables(100); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +3, +2}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +2, -3}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +3, +2}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, -2, -3}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+2, +6, -3, +1, +1}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({-3, +6, +2, +1, -5}))); + + const bool log_info = true; + EXPECT_TRUE(inprocessing->DetectEquivalencesAndStamp(false, log_info)); + EXPECT_TRUE(inprocessing->SubsumeAndStrenghtenRound(log_info)); + { + // This function remove empty clauses. 
+ const auto& all_clauses = clause_manager->AllClausesInCreationOrder(); + + // Depending on the order in which clauses are processed (which can + // change as we rely on std::sort()), we have a few cases. + if (all_clauses.size() == 1) { + EXPECT_EQ(all_clauses[0]->AsSpan(), Literals({+1, +2, -3})); + + // We added {+1, +2} and {+1, -3} there. + // TODO(user): make sure we don't add twice the implications. + auto* implication_graph = model.GetOrCreate(); + EXPECT_EQ(implication_graph->num_implications(), 6); + EXPECT_EQ(implication_graph->Implications(Literal(-1)).size(), 3); + EXPECT_THAT(implication_graph->Implications(Literal(-1)), + ::testing::UnorderedElementsAre(Literal(+2), Literal(+2), + Literal(-3))); + } else { + EXPECT_GE(all_clauses.size(), 3); + EXPECT_LE(all_clauses.size(), 4); + EXPECT_EQ(all_clauses[0]->AsSpan(), Literals({+1, +3, +2})); + EXPECT_EQ(all_clauses[1]->AsSpan(), Literals({+1, -2, -3})); + + // Depending on the implication added, we don't get the same clauses. + auto* implication_graph = model.GetOrCreate(); + EXPECT_EQ(implication_graph->num_implications(), 2); + EXPECT_EQ(implication_graph->Implications(Literal(-1)).size(), 1); + if (implication_graph->Implications(Literal(-1))[0] == Literal(+2)) { + EXPECT_EQ(all_clauses[2]->AsSpan(), Literals({+2, +6, +1, +1})); + if (all_clauses.size() == 4) { + EXPECT_EQ(all_clauses[3]->AsSpan(), Literals({+6, +2, +1, -5})); + } + } else { + EXPECT_EQ(implication_graph->Implications(Literal(-1))[0], Literal(-3)); + EXPECT_EQ(all_clauses[2]->AsSpan(), Literals({+6, -3, +1, +1})); + if (all_clauses.size() == 4) { + EXPECT_EQ(all_clauses[3]->AsSpan(), Literals({-3, +6, +1, -5})); + } + } + } + } +} + +TEST(StampingSimplifierTest, StampConstruction) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* implication_graph = model.GetOrCreate(); + auto* simplifier = model.GetOrCreate(); + + // Lets add some clauses. + // Note that the order currently matter for what is left. 
+ sat_solver->SetNumVariables(100); + implication_graph->AddImplication(Literal(+1), Literal(+2)); + implication_graph->AddImplication(Literal(+1), Literal(+3)); + implication_graph->AddImplication(Literal(+1), Literal(+4)); + implication_graph->AddImplication(Literal(+2), Literal(+5)); + implication_graph->AddImplication(Literal(+2), Literal(+6)); + implication_graph->AddImplication(Literal(+3), Literal(+7)); + implication_graph->AddImplication(Literal(+4), Literal(+6)); + + EXPECT_TRUE(implication_graph->DetectEquivalences(true)); + + // Lets test some implications. + simplifier->SampleTreeAndFillParent(); + simplifier->ComputeStamps(); + EXPECT_TRUE(simplifier->ImplicationIsInTree(Literal(+1), Literal(+2))); + EXPECT_TRUE(simplifier->ImplicationIsInTree(Literal(+1), Literal(+5))); + EXPECT_TRUE(simplifier->ImplicationIsInTree(Literal(+1), Literal(+6))); + EXPECT_TRUE(simplifier->ImplicationIsInTree(Literal(+1), Literal(+7))); + EXPECT_TRUE(simplifier->ImplicationIsInTree(Literal(-7), Literal(-3))); +} + +TEST(StampingSimplifierTest, BasicSimplification) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* clause_manager = model.GetOrCreate(); + auto* implication_graph = model.GetOrCreate(); + auto* simplifier = model.GetOrCreate(); + + // Lets add some clauses. + // Note that the order currently matter for what is left. 
+ sat_solver->SetNumVariables(100); + implication_graph->AddImplication(Literal(+1), Literal(+2)); + implication_graph->AddImplication(Literal(+1), Literal(+3)); + implication_graph->AddImplication(Literal(+1), Literal(+4)); + implication_graph->AddImplication(Literal(+2), Literal(+5)); + implication_graph->AddImplication(Literal(+2), Literal(+6)); + implication_graph->AddImplication(Literal(+3), Literal(+7)); + implication_graph->AddImplication(Literal(+4), Literal(+6)); + + EXPECT_TRUE(implication_graph->DetectEquivalences(true)); + + // Lets add some clause that should be simplifiable + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +7, +8, +9}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, -6, +8, +9}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({-3, -7, +8, +9}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({-3, +7, +8, +9}))); + + // Lets test some implications. + EXPECT_TRUE(simplifier->DoOneRound(/*log_info=*/true)); + + // Results. I cover all 4 possibilities, 2 strenghtening for clause 0 and 2, + // one subsumption for clause 3 and nothing for clause 1. + const auto& all_clauses = clause_manager->AllClausesInCreationOrder(); + EXPECT_EQ(all_clauses.size(), 4); + EXPECT_EQ(all_clauses[0]->AsSpan(), Literals({+7, +8, +9})); + EXPECT_EQ(all_clauses[1]->AsSpan(), Literals({+1, -6, +8, +9})); + EXPECT_EQ(all_clauses[2]->AsSpan(), Literals({-3, +8, +9})); + EXPECT_EQ(all_clauses[3]->AsSpan(), Literals({})); +} + +TEST(BlockedClauseSimplifierTest, BasicSimplification) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* clause_manager = model.GetOrCreate(); + auto* implication_graph = model.GetOrCreate(); + auto* simplifier = model.GetOrCreate(); + + // Lets add some clauses. + // Note that the order currently matter for what is left. 
+ sat_solver->SetNumVariables(100); + implication_graph->AddImplication(Literal(+1), Literal(-7)); + implication_graph->AddImplication(Literal(+1), Literal(-8)); + implication_graph->AddImplication(Literal(+1), Literal(-9)); + + // Lets add some clause that should be blocked + EXPECT_TRUE(clause_manager->AddClause(Literals({-1, +7, -8, +9}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +7, +8, +9}))); + + simplifier->DoOneRound(/*log_info=*/true); + + clause_manager->DeleteRemovedClauses(); + const auto& all_clauses = clause_manager->AllClausesInCreationOrder(); + EXPECT_EQ(all_clauses.size(), 0); +} + +TEST(BoundedVariableEliminationTest, BasicSimplification) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* clause_manager = model.GetOrCreate(); + auto* simplifier = model.GetOrCreate(); + + // Lets add some clauses. + sat_solver->SetNumVariables(100); + EXPECT_TRUE(clause_manager->AddClause(Literals({+1, +2, +3, +7}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+3, +4, +5, +7}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({-1, +4, +5, -7}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+3, -2, +5, -7}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+2, +4, -3, -7}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+2, +4, -5, -7}))); + EXPECT_TRUE(clause_manager->AddClause(Literals({+2, +3, -4, -7}))); + + simplifier->DoOneRound(/*log_info=*/true); + + // The problem is so simple that everyting should be simplified. 
+ clause_manager->DeleteRemovedClauses(); + const auto& all_clauses = clause_manager->AllClausesInCreationOrder(); + EXPECT_EQ(all_clauses.size(), 0); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index c98489d3a19..7527fa6481b 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -15,15 +15,15 @@ syntax = "proto2"; package operations_research.sat; +option csharp_namespace = "Google.OrTools.Sat"; +option go_package = "github.com/google/or-tools/ortools/sat/proto/satparameters"; option java_package = "com.google.ortools.sat"; option java_multiple_files = true; -option csharp_namespace = "Google.OrTools.Sat"; - // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 296 +// NEXT TAG: 300 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -580,6 +580,10 @@ message SatParameters { // max-sat. We also minimize problem clauses and not just the learned clause // that we keep forever like in the paper. optional double inprocessing_minimization_dtime = 275 [default = 1.0]; + optional bool inprocessing_minimization_use_conflict_analysis = 297 + [default = true]; + optional bool inprocessing_minimization_use_all_orderings = 298 + [default = false]; // ========================================================================== // Multithread @@ -871,6 +875,8 @@ message SatParameters { optional bool use_area_energetic_reasoning_in_no_overlap_2d = 271 [default = false]; + optional bool use_try_edge_reasoning_in_no_overlap_2d = 299 [default = false]; + // If the number of pairs to look is below this threshold, do an extra step of // propagation in the no_overlap_2d constraint by looking at all pairs of // intervals. 
@@ -1213,7 +1219,10 @@ message SatParameters { // Turns on neighborhood generator based on local branching LP. Based on Huang // et al., "Local Branching Relaxation Heuristics for Integer Linear // Programs", 2023. - optional bool use_lb_relax_lns = 255 [default = false]; + optional bool use_lb_relax_lns = 255 [default = true]; + + // Only use lb-relax if we have at least that many workers. + optional int32 lb_relax_num_workers_threshold = 296 [default = 16]; // Rounding method to use for feasibility pump. enum FPRoundingMethod { diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 4babf78c7cd..c29bdd25e0e 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -1129,7 +1129,7 @@ SatSolver::Status SatSolver::Solve() { return SolveInternal(time_limit_, parameters_->max_number_of_conflicts()); } -void SatSolver::KeepAllClauseUsedToInfer(BooleanVariable variable) { +void SatSolver::KeepAllClausesUsedToInfer(BooleanVariable variable) { CHECK(Assignment().VariableIsAssigned(variable)); if (trail_->Info(variable).level == 0) return; int trail_index = trail_->Info(variable).trail_index; @@ -1161,7 +1161,8 @@ void SatSolver::KeepAllClauseUsedToInfer(BooleanVariable variable) { } } -bool SatSolver::SubsumptionIsInteresting(BooleanVariable variable) { +bool SatSolver::SubsumptionIsInteresting(BooleanVariable variable, + int max_size) { // TODO(user): other id should probably be safe as long as we do not delete // the propagators. Note that symmetry is tricky since we would need to keep // the symmetric clause around in KeepAllClauseUsedToInfer(). 
@@ -1186,7 +1187,10 @@ bool SatSolver::SubsumptionIsInteresting(BooleanVariable variable) { if (type != binary_id && type != clause_id) return false; SatClause* clause = ReasonClauseOrNull(var); if (clause != nullptr && clauses_propagator_->IsRemovable(clause)) { - ++num_clause_to_mark_as_non_deletable; + if (clause->size() > max_size) { + return false; + } + if (++num_clause_to_mark_as_non_deletable > 1) return false; } for (const Literal l : trail_->Reason(var)) { const AssignmentInfo& info = trail_->Info(l.Variable()); @@ -1201,54 +1205,63 @@ bool SatSolver::SubsumptionIsInteresting(BooleanVariable variable) { } // TODO(user): this is really an in-processing stuff and should be moved out -// of here. I think the name for that (or similar) technique is called vivify. -// Ideally this should be scheduled after other faster in-processing technique. +// of here. Ideally this should be scheduled after other faster in-processing +// techniques. This implements "vivification" as described in +// https://doi.org/10.1016/j.artint.2019.103197, with one significant tweak: +// we sort each clause by current trail index before trying to minimize it so +// that we can reuse the trail from previous calls in case there are overlaps. void SatSolver::TryToMinimizeClause(SatClause* clause) { CHECK(clause != nullptr); ++counters_.minimization_num_clauses; - absl::btree_set moved_last; - std::vector candidate(clause->begin(), clause->end()); + std::vector candidate; + candidate.reserve(clause->size()); - // Note that CP-SAT presolve detect the clauses that share n-1 literals and - // transform them into (n-1 enforcement) => (1 literal per clause). We + // Note that CP-SAT presolve detects clauses that share n-1 literals and + // transforms them into (n-1 enforcement) => (1 literal per clause). 
We // currently do not support that internally, but these clauses will still - // likely be loaded one after the other, so there is an high chance that if we + // likely be loaded one after the other, so there is a high chance that if we // call TryToMinimizeClause() on consecutive clauses, there will be a long - // prefix in common ! + // prefix in common! // // TODO(user): Exploit this more by choosing a good minimization order? int longest_valid_prefix = 0; if (CurrentDecisionLevel() > 0) { - // Quick linear scan to see if first literal is there. - const Literal first_decision = decisions_[0].literal; + candidate.resize(clause->size()); + // Insert any compatible decisions into their correct place in candidate + for (Literal lit : *clause) { + if (!Assignment().LiteralIsFalse(lit)) continue; + const AssignmentInfo& info = trail_->Info(lit.Variable()); + if (info.level <= 0 || info.level > clause->size()) continue; + if (decisions_[info.level - 1].literal == lit.Negated()) { + candidate[info.level - 1] = lit; + } + } + // Then compute the matching prefix and discard the rest for (int i = 0; i < candidate.size(); ++i) { - if (candidate[i].Negated() == first_decision) { - std::swap(candidate[0], candidate[i]); - longest_valid_prefix = 1; + if (candidate[i] != Literal()) { + ++longest_valid_prefix; + } else { break; } } - // Lets compute the full maximum prefix if we have already one match. - if (longest_valid_prefix == 1 && CurrentDecisionLevel() > 1) { - // Lets do the full algo. 
- absl::flat_hash_map indexing; - for (int i = 0; i < candidate.size(); ++i) { - indexing[candidate[i].NegatedIndex()] = i; - } - for (; longest_valid_prefix < CurrentDecisionLevel(); - ++longest_valid_prefix) { - const auto it = - indexing.find(decisions_[longest_valid_prefix].literal.Index()); - if (it == indexing.end()) break; - std::swap(candidate[longest_valid_prefix], candidate[it->second]); - indexing[candidate[it->second].NegatedIndex()] = it->second; - } - counters_.minimization_num_reused += longest_valid_prefix; + counters_.minimization_num_reused += longest_valid_prefix; + candidate.resize(longest_valid_prefix); + } + // Then do a second pass to add the remaining literals in order. + for (Literal lit : *clause) { + const AssignmentInfo& info = trail_->Info(lit.Variable()); + // Skip if this literal is already in the prefix. + if (info.level >= 1 && info.level <= longest_valid_prefix && + candidate[info.level - 1] == lit) { + continue; } + candidate.push_back(lit); } - Backtrack(longest_valid_prefix); + CHECK_EQ(candidate.size(), clause->size()); + Backtrack(longest_valid_prefix); + absl::btree_set moved_last; while (!model_is_unsat_) { // We want each literal in candidate to appear last once in our propagation // order. 
We want to do that while maximizing the reutilization of the @@ -1258,12 +1271,15 @@ void SatSolver::TryToMinimizeClause(SatClause* clause) { moved_last, CurrentDecisionLevel(), &candidate); if (target_level == -1) break; Backtrack(target_level); + while (CurrentDecisionLevel() < candidate.size()) { if (time_limit_->LimitReached()) return; const int level = CurrentDecisionLevel(); const Literal literal = candidate[level]; + // Remove false literals if (Assignment().LiteralIsFalse(literal)) { - candidate.erase(candidate.begin() + level); + candidate[level] = candidate.back(); + candidate.pop_back(); continue; } else if (Assignment().LiteralIsTrue(literal)) { const int variable_level = @@ -1277,27 +1293,35 @@ void SatSolver::TryToMinimizeClause(SatClause* clause) { return; } - // If literal (at true) wasn't propagated by this clause, then we - // know that this clause is subsumed by other clauses in the database, - // so we can remove it. Note however that we need to make sure we will - // never remove the clauses that subsumes it later. + if (parameters_->inprocessing_minimization_use_conflict_analysis()) { + // Replace the clause with the reason for the literal being true, plus + // the literal itself. + candidate.clear(); + for (Literal lit : + GetDecisionsFixing(trail_->Reason(literal.Variable()))) { + candidate.push_back(lit.Negated()); + } + } else { + candidate.resize(variable_level); + } + candidate.push_back(literal); + + // If a (true) literal wasn't propagated by this clause, then we know + // that this clause is subsumed by other clauses in the database, so we + // can remove it so long as the subsumption is due to non-removable + // clauses. If we can subsume this clause by making only 1 additional + // clause permanent and that clause is no longer than this one, we will + // do so. 
if (ReasonClauseOrNull(literal.Variable()) != clause && - SubsumptionIsInteresting(literal.Variable())) { + SubsumptionIsInteresting(literal.Variable(), candidate.size())) { counters_.minimization_num_subsumed++; counters_.minimization_num_removed_literals += clause->size(); - KeepAllClauseUsedToInfer(literal.Variable()); + KeepAllClausesUsedToInfer(literal.Variable()); Backtrack(0); clauses_propagator_->Detach(clause); return; - } else { - // Simplify. Note(user): we could only keep in clause the literals - // responsible for the propagation, but because of the subsumption - // above, this is not needed. - if (variable_level + 1 < candidate.size()) { - candidate.resize(variable_level); - candidate.push_back(literal); - } } + break; } else { ++counters_.minimization_num_decisions; @@ -1307,19 +1331,31 @@ void SatSolver::TryToMinimizeClause(SatClause* clause) { return; } if (model_is_unsat_) return; + if (CurrentDecisionLevel() < level) { + // There was a conflict, consider the conflicting literal next so we + // should be able to exploit the conflict in the next iteration. + // TODO(user): I *think* this is sufficient to ensure pushing + // the same literal to the new trail fails, immediately on the next + // iteration, if not we may be able to analyse the last failure and + // skip some propagation steps? + std::swap(candidate[level], candidate[CurrentDecisionLevel()]); + } } } if (candidate.empty()) { model_is_unsat_ = true; return; } + if (!parameters_->inprocessing_minimization_use_all_orderings()) break; moved_last.insert(candidate.back().Index()); } + if (candidate.empty()) { + model_is_unsat_ = true; + return; + } + // Returns if we don't have any minimization. - // - // Note that we don't backtrack right away so maybe if the next clause as - // similar literal, we can reuse the trail prefix! 
if (candidate.size() == clause->size()) return; Backtrack(0); @@ -1504,27 +1540,40 @@ bool SatSolver::MinimizeByPropagation(double dtime) { } std::vector SatSolver::GetLastIncompatibleDecisions() { + std::vector* clause = trail_->MutableConflict(); + int num_true = 0; + for (int i = 0; i < clause->size(); ++i) { + const Literal literal = (*clause)[i]; + if (Assignment().LiteralIsTrue(literal)) { + // literal at true in the conflict must be the last decision/assumption + // that could not be taken. Put it at the front to add to the result + // later. + std::swap((*clause)[i], (*clause)[num_true++]); + } + } + CHECK_LE(num_true, 1); + std::vector result = + GetDecisionsFixing(absl::MakeConstSpan(*clause).subspan(num_true)); + for (int i = 0; i < num_true; ++i) { + result.push_back((*clause)[i].Negated()); + } + return result; +} + +std::vector SatSolver::GetDecisionsFixing( + absl::Span literals) { SCOPED_TIME_STAT(&stats_); std::vector unsat_assumptions; is_marked_.ClearAndResize(num_variables_); int trail_index = 0; - int num_true = 0; - for (const Literal lit : trail_->FailingClause()) { + for (const Literal lit : literals) { CHECK(Assignment().LiteralIsAssigned(lit)); - if (Assignment().LiteralIsTrue(lit)) { - // literal at true in the conflict must be decision/assumptions that could - // not be taken. - ++num_true; - unsat_assumptions.push_back(lit.Negated()); - continue; - } trail_index = std::max(trail_index, trail_->Info(lit.Variable()).trail_index); is_marked_.Set(lit.Variable()); } - CHECK_LE(num_true, 1); // We just expand the conflict until we only have decisions. const int limit = diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index b8d2deaaff5..d1bc181aa53 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -234,6 +234,10 @@ class SatSolver { // the problem UNSAT. 
std::vector GetLastIncompatibleDecisions(); + // Returns a subset of decisions that are sufficient to ensure all literals in + // `literals` are fixed to their current value. + std::vector GetDecisionsFixing(absl::Span literals); + // Advanced usage. The next 3 functions allow to drive the search from outside // the solver. @@ -717,15 +721,16 @@ class SatSolver { std::string StatusString(Status status) const; std::string RunningStatisticsString() const; - // Marks as "non-deletable" all clauses that were used to infer the given - // variable. The variable must be currently assigned. - void KeepAllClauseUsedToInfer(BooleanVariable variable); - bool SubsumptionIsInteresting(BooleanVariable variable); + // Returns true if variable is fixed in the current assignment due to + // non-removable clauses, plus at most one removable clause with size <= + // max_size. + bool SubsumptionIsInteresting(BooleanVariable variable, int max_size); + void KeepAllClausesUsedToInfer(BooleanVariable variable); // Use propagation to try to minimize the given clause. This is really similar - // to MinimizeCoreWithPropagation(). It must be called when the current - // decision level is zero. Note that because this do a small tree search, it - // will impact the variable/clauses activities and may add new conflicts. + // to MinimizeCoreWithPropagation(). Note that because this does a small tree + // search, it will impact the variable/clause activities and may add new + // conflicts. void TryToMinimizeClause(SatClause* clause); // This is used by the old non-model constructor. diff --git a/ortools/sat/scheduling_cuts_test.cc b/ortools/sat/scheduling_cuts_test.cc new file mode 100644 index 00000000000..23a8b92bac1 --- /dev/null +++ b/ortools/sat/scheduling_cuts_test.cc @@ -0,0 +1,576 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/scheduling_cuts.h" + +#include + +#include +#include +#include +#include + +#include "absl/base/log_severity.h" +#include "absl/random/random.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/strong_vector.h" +#include "ortools/sat/cp_model.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/cuts.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/linear_constraint_manager.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::EndsWith; +using ::testing::StartsWith; + +TEST(CumulativeEnergyCutGenerator, TestCutTimeTableGenerator) { + Model model; + + const IntegerVariable start1 = model.Add(NewIntegerVariable(0, 3)); + const IntegerVariable end1 = model.Add(NewIntegerVariable(7, 10)); + const IntegerVariable size1 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i1 = model.Add(NewInterval(start1, end1, size1)); + + const BooleanVariable b = model.Add(NewBooleanVariable()); + const IntegerVariable b_view = model.Add(NewIntegerVariable(0, 1)); + auto* integer_encoder = model.GetOrCreate(); + integer_encoder->AssociateToIntegerEqualValue(Literal(b, true), b_view, + IntegerValue(1)); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(3, 6)); + const 
IntegerVariable end2 = model.Add(NewIntegerVariable(10, 13)); + const IntegerVariable size2 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i2 = + model.Add(NewOptionalInterval(start2, end2, size2, Literal(b, true))); + + const IntegerVariable demand1 = model.Add(NewIntegerVariable(5, 10)); + const IntegerVariable demand2 = model.Add(NewIntegerVariable(3, 10)); + const IntegerVariable capacity = model.Add(NewIntegerVariable(10, 10)); + SchedulingConstraintHelper* helper = + model.GetOrCreate()->GetOrCreateHelper({i1, i2}); + SchedulingDemandHelper* demands_helper = + new SchedulingDemandHelper({demand1, demand2}, helper, &model); + model.TakeOwnership(demands_helper); + CutGenerator cumulative = CreateCumulativeTimeTableCutGenerator( + helper, demands_helper, capacity, &model); + LinearConstraintManager* const manager = + model.GetOrCreate(); + const IntegerVariable num_vars = + model.GetOrCreate()->NumIntegerVariables(); + + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(num_vars.value() * 2, 0.0); + lp_values[start1] = 3.0; // x0 + lp_values[end1] = 10.0; // x1 + lp_values[size1] = 7.0; // x2 + lp_values[b_view] = 1.0; // x3 + lp_values[start2] = 6.0; // x4 + lp_values[end2] = 13.0; // x5 + lp_values[size2] = 7.0; // x6 + lp_values[demand1] = 8.0; // x7 + lp_values[demand2] = 7.0; // x8 + lp_values[capacity] = 10.0; // x9 + + cumulative.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + + // 3*X3 1*X7 -1*X9 <= 0 -> Normalized to 3*X3 1*X7 <= 10 + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("3*X3 1*X7 <= 10")); +} + +TEST(CumulativeEnergyCutGenerator, SameDemand) { + Model model; + + const IntegerVariable start1 = model.Add(NewIntegerVariable(0, 3)); + const IntegerVariable end1 = model.Add(NewIntegerVariable(7, 10)); + const IntegerVariable size1 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i1 = model.Add(NewInterval(start1, end1, size1)); + + const IntegerVariable 
start2 = model.Add(NewIntegerVariable(3, 6)); + const IntegerVariable end2 = model.Add(NewIntegerVariable(10, 13)); + const IntegerVariable size2 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i2 = model.Add(NewInterval(start2, end2, size2)); + + const IntegerVariable start3 = model.Add(NewIntegerVariable(4, 8)); + const IntegerVariable end3 = model.Add(NewIntegerVariable(11, 15)); + const IntegerVariable size3 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i3 = model.Add(NewInterval(start3, end3, size3)); + + const IntegerVariable demand = model.Add(NewIntegerVariable(5, 10)); + const IntegerVariable demand2 = model.Add(NewIntegerVariable(5, 10)); + const IntegerVariable capacity = model.Add(NewIntegerVariable(10, 10)); + + LinearExpression e1; + e1.vars.push_back(demand); + e1.coeffs.push_back(IntegerValue(7)); + LinearExpression e2; + e2.vars.push_back(demand2); + e2.coeffs.push_back(IntegerValue(7)); + + SchedulingConstraintHelper* helper = + model.GetOrCreate()->GetOrCreateHelper({i1, i2, i3}); + SchedulingDemandHelper* demands_helper = + new SchedulingDemandHelper({demand, demand, demand2}, helper, &model); + model.TakeOwnership(demands_helper); + + CutGenerator cumulative = CreateCumulativeEnergyCutGenerator( + helper, demands_helper, capacity, std::optional(), + &model); + LinearConstraintManager* const manager = + model.GetOrCreate(); + const IntegerVariable num_vars = + model.GetOrCreate()->NumIntegerVariables(); + + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(num_vars.value() * 2, 0.0); + lp_values[start1] = 3.0; // x0 + lp_values[end1] = 10.0; // x1 + lp_values[size1] = 7.0; // x2 + lp_values[start2] = 6.0; // x3 + lp_values[end2] = 13.0; // x4 + lp_values[size2] = 7.0; // x5 + lp_values[start3] = 6.0; // x6 + lp_values[end3] = 13.0; // x7 + lp_values[size3] = 7.0; // x8 + lp_values[demand] = 8.0; // x9 + lp_values[demand2] = 8.0; // x10 + lp_values[capacity] = 10.0; // x11 + + 
cumulative.generate_cuts(manager); + ASSERT_EQ(5, manager->num_cuts()); + + // CumulativeEnergy cut. + EXPECT_THAT( + manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(0)] + .constraint.DebugString(), + EndsWith("1*X9 <= 5")); + EXPECT_THAT( + manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(1)] + .constraint.DebugString(), + EndsWith("1*X9 1*X10 <= 10")); + EXPECT_THAT( + manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(2)] + .constraint.DebugString(), + EndsWith("3*X9 2*X10 <= 30")); + EXPECT_THAT( + manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(3)] + .constraint.DebugString(), + EndsWith("5*X9 2*X10 <= 40")); + EXPECT_THAT( + manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(4)] + .constraint.DebugString(), + EndsWith("2*X9 3*X10 <= 30")); +} + +TEST(CumulativeEnergyCutGenerator, SameDemandTimeTableGenerator) { + Model model; + + const IntegerVariable start1 = model.Add(NewIntegerVariable(0, 3)); + const IntegerVariable end1 = model.Add(NewIntegerVariable(7, 10)); + const IntegerVariable size1 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i1 = model.Add(NewInterval(start1, end1, size1)); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(3, 6)); + const IntegerVariable end2 = model.Add(NewIntegerVariable(10, 13)); + const IntegerVariable size2 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i2 = model.Add(NewInterval(start2, end2, size2)); + + const IntegerVariable start3 = model.Add(NewIntegerVariable(4, 8)); + const IntegerVariable end3 = model.Add(NewIntegerVariable(11, 15)); + const IntegerVariable size3 = model.Add(NewIntegerVariable(7, 7)); + const IntervalVariable i3 = model.Add(NewInterval(start3, end3, size3)); + + const IntegerVariable demand = model.Add(NewIntegerVariable(5, 10)); + const IntegerVariable demand2 = model.Add(NewIntegerVariable(5, 10)); + const IntegerVariable capacity = 
model.Add(NewIntegerVariable(10, 10)); + + SchedulingConstraintHelper* helper = + model.GetOrCreate()->GetOrCreateHelper({i1, i2, i3}); + SchedulingDemandHelper* demands_helper = + new SchedulingDemandHelper({demand, demand, demand2}, helper, &model); + model.TakeOwnership(demands_helper); + CutGenerator cumulative = CreateCumulativeTimeTableCutGenerator( + helper, demands_helper, capacity, &model); + LinearConstraintManager* const manager = + model.GetOrCreate(); + const IntegerVariable num_vars = + model.GetOrCreate()->NumIntegerVariables(); + + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(num_vars.value() * 2, 0.0); + lp_values[start1] = 3.0; // x0 + lp_values[end1] = 10.0; // x1 + lp_values[size1] = 7.0; // x2 + lp_values[start2] = 6.0; // x3 + lp_values[end2] = 13.0; // x4 + lp_values[size2] = 7.0; // x5 + lp_values[start3] = 6.0; // x6 + lp_values[end3] = 13.0; // x7 + lp_values[size3] = 7.0; // x8 + lp_values[demand] = 8.0; // x9 + lp_values[demand2] = 8.0; // x10 + lp_values[capacity] = 10.0; // x11 + + cumulative.generate_cuts(manager); + ASSERT_EQ(2, manager->num_cuts()); + + // 1*X9 1*X9 <= X11 -> Normalized to 1*X9 <= 5 + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("1*X9 <= 5")); + // 1*X9 1*X10 <= X11 -> Normalized to 1*X9 1*X10 <= 10 + EXPECT_THAT(manager->AllConstraints().back().constraint.DebugString(), + EndsWith("1*X9 1*X10 <= 10")); +} + +TEST(CumulativeEnergyCutGenerator, DetectedPrecedence) { + Model model; + auto* intervals_repository = model.GetOrCreate(); + + const IntegerValue one(1); + const IntegerVariable start1 = model.Add(NewIntegerVariable(0, 3)); + const IntegerValue size1(3); + const IntervalVariable i1 = intervals_repository->CreateInterval( + start1, AffineExpression(start1, one, size1), AffineExpression(size1), + kNoLiteralIndex, /*add_linear_relation=*/false); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(1, 5)); + const IntegerValue size2(4); + const 
IntervalVariable i2 = intervals_repository->CreateInterval( + start2, AffineExpression(start2, one, size2), AffineExpression(size2), + kNoLiteralIndex, /*add_linear_relation=*/false); + CutGenerator disjunctive = CreateNoOverlapPrecedenceCutGenerator( + intervals_repository->GetOrCreateHelper({ + i1, + i2, + }), + &model); + LinearConstraintManager* const manager = + model.GetOrCreate(); + const IntegerVariable num_vars = + model.GetOrCreate()->NumIntegerVariables(); + + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(num_vars.value() * 2, 0.0); + lp_values[start1] = 0.0; + lp_values[NegationOf(start1)] = 0.0; + lp_values[start2] = 2.0; + lp_values[NegationOf(start2)] = -2.0; + + disjunctive.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("1*X0 -1*X1 <= -3")); +} + +TEST(CumulativeEnergyCutGenerator, DetectedPrecedenceRev) { + Model model; + auto* intervals_repository = model.GetOrCreate(); + + const IntegerValue one(1); + const IntegerVariable start1 = model.Add(NewIntegerVariable(0, 3)); + const IntegerValue size1(3); + const IntervalVariable i1 = intervals_repository->CreateInterval( + start1, AffineExpression(start1, one, size1), AffineExpression(size1), + kNoLiteralIndex, /*add_linear_relation=*/false); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(1, 5)); + const IntegerValue size2(4); + const IntervalVariable i2 = intervals_repository->CreateInterval( + start2, AffineExpression(start2, one, size2), AffineExpression(size2), + kNoLiteralIndex, /*add_linear_relation=*/false); + + CutGenerator disjunctive = CreateNoOverlapPrecedenceCutGenerator( + intervals_repository->GetOrCreateHelper({ + i2, + i1, + }), + &model); + LinearConstraintManager* const manager = + model.GetOrCreate(); + const IntegerVariable num_vars = + model.GetOrCreate()->NumIntegerVariables(); + + auto& lp_values = *model.GetOrCreate(); + 
lp_values.resize(num_vars.value() * 2, 0.0); + lp_values[start1] = 0.0; + lp_values[NegationOf(start1)] = 0.0; + lp_values[start2] = 2.0; + lp_values[NegationOf(start2)] = -2.0; + + disjunctive.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + EndsWith("1*X0 -1*X1 <= -3")); +} + +TEST(CumulativeEnergyCutGenerator, DisjunctionOnStart) { + Model model; + auto* intervals_repository = model.GetOrCreate(); + + const IntegerValue one(1); + const IntegerVariable start1 = model.Add(NewIntegerVariable(0, 5)); + const IntegerValue size1(3); + const IntervalVariable i1 = intervals_repository->CreateInterval( + start1, AffineExpression(start1, one, size1), AffineExpression(size1), + kNoLiteralIndex, /*add_linear_relation=*/false); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(1, 5)); + const IntegerValue size2(4); + const IntervalVariable i2 = intervals_repository->CreateInterval( + start2, AffineExpression(start2, one, size2), AffineExpression(size2), + kNoLiteralIndex, /*add_linear_relation=*/false); + + CutGenerator disjunctive = CreateNoOverlapPrecedenceCutGenerator( + intervals_repository->GetOrCreateHelper({ + i2, + i1, + }), + &model); + LinearConstraintManager* const manager = + model.GetOrCreate(); + const IntegerVariable num_vars = + model.GetOrCreate()->NumIntegerVariables(); + + auto& lp_values = *model.GetOrCreate(); + lp_values.resize(num_vars.value() * 2, 0.0); + lp_values[start1] = 0.0; + lp_values[NegationOf(start1)] = 0.0; + lp_values[start2] = 2.0; + lp_values[NegationOf(start2)] = -2.0; + + disjunctive.generate_cuts(manager); + ASSERT_EQ(1, manager->num_cuts()); + + EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), + StartsWith("15 <= 2*X0 5*X1")); +} + +TEST(ComputeMinSumOfEndMinsTest, CombinationOf3) { + Model model; + auto* intervals_repository = model.GetOrCreate(); + + IntegerValue one(1); + IntegerValue two(2); + + const 
IntegerVariable start1 = model.Add(NewIntegerVariable(0, 10)); + const IntegerValue size1(3); + const IntervalVariable i1 = intervals_repository->CreateInterval( + start1, AffineExpression(start1, one, size1), size1, kNoLiteralIndex, + /*add_linear_relation=*/false); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(0, 10)); + const IntegerValue size2(4); + const IntervalVariable i2 = intervals_repository->CreateInterval( + start2, AffineExpression(start2, one, size2), size2, kNoLiteralIndex, + /*add_linear_relation=*/false); + + const IntegerVariable start3 = model.Add(NewIntegerVariable(0, 10)); + const IntegerValue size3(5); + const IntervalVariable i3 = intervals_repository->CreateInterval( + start3, AffineExpression(start3, one, size3), size3, kNoLiteralIndex, + /*add_linear_relation=*/false); + + SchedulingConstraintHelper* helper = + model.GetOrCreate()->GetOrCreateHelper({i1, i2, i3}); + CtEvent e1(0, helper); + e1.y_size_min = two; + CtEvent e2(1, helper); + e2.y_size_min = one; + CtEvent e3(2, helper); + e3.y_size_min = one; + std::vector events = {{0, e1}, {1, e2}, {1, e3}}; + + IntegerValue min_sum_of_end_mins(0); + IntegerValue min_sum_of_weighted_end_mins(0); + ASSERT_TRUE(ComputeMinSumOfWeightedEndMins( + events, two, min_sum_of_end_mins, min_sum_of_weighted_end_mins, + kMinIntegerValue, kMinIntegerValue)); + EXPECT_EQ(min_sum_of_end_mins, 17); + EXPECT_EQ(min_sum_of_weighted_end_mins, 21); +} + +TEST(ComputeMinSumOfEndMinsTest, CombinationOf3ConstraintStart) { + Model model; + auto* intervals_repository = model.GetOrCreate(); + + IntegerValue one(1); + IntegerValue two(2); + + const IntegerVariable start1 = model.Add(NewIntegerVariable(0, 3)); + const IntegerValue size1(3); + const IntervalVariable i1 = intervals_repository->CreateInterval( + start1, AffineExpression(start1, one, size1), size1, kNoLiteralIndex, + /*add_linear_relation=*/false); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(0, 10)); + const 
IntegerValue size2(4); + const IntervalVariable i2 = intervals_repository->CreateInterval( + start2, AffineExpression(start2, one, size2), size2, kNoLiteralIndex, + /*add_linear_relation=*/false); + + const IntegerVariable start3 = model.Add(NewIntegerVariable(0, 10)); + const IntegerValue size3(5); + const IntervalVariable i3 = intervals_repository->CreateInterval( + start3, AffineExpression(start3, one, size3), size3, kNoLiteralIndex, + /*add_linear_relation=*/false); + + SchedulingConstraintHelper* helper = + model.GetOrCreate()->GetOrCreateHelper({i1, i2, i3}); + CtEvent e1(0, helper); + e1.y_size_min = two; + CtEvent e2(1, helper); + e2.y_size_min = one; + CtEvent e3(2, helper); + e3.y_size_min = one; + std::vector events = {{0, e1}, {1, e2}, {2, e3}}; + + IntegerValue min_sum_of_end_mins(0); + IntegerValue min_sum_of_weighted_end_mins(0); + ASSERT_TRUE(ComputeMinSumOfWeightedEndMins( + events, two, min_sum_of_end_mins, min_sum_of_weighted_end_mins, + kMinIntegerValue, kMinIntegerValue)); + EXPECT_EQ(min_sum_of_end_mins, 18); + EXPECT_EQ(min_sum_of_weighted_end_mins, 21); +} + +TEST(ComputeMinSumOfEndMinsTest, Infeasible) { + Model model; + auto* intervals_repository = model.GetOrCreate(); + + IntegerValue one(1); + IntegerValue two(2); + + const IntegerVariable start1 = model.Add(NewIntegerVariable(1, 3)); + const IntegerValue size1(3); + const IntervalVariable i1 = intervals_repository->CreateInterval( + start1, AffineExpression(start1, one, size1), size1, kNoLiteralIndex, + /*add_linear_relation=*/false); + + const IntegerVariable start2 = model.Add(NewIntegerVariable(0, 3)); + const IntegerValue size2(4); + const IntervalVariable i2 = intervals_repository->CreateInterval( + start2, AffineExpression(start2, one, size2), size2, kNoLiteralIndex, + /*add_linear_relation=*/false); + + const IntegerVariable start3 = model.Add(NewIntegerVariable(0, 3)); + const IntegerValue size3(5); + const IntervalVariable i3 = intervals_repository->CreateInterval( + start3, 
AffineExpression(start3, one, size3), size3, kNoLiteralIndex, + /*add_linear_relation=*/false); + + SchedulingConstraintHelper* helper = + model.GetOrCreate()->GetOrCreateHelper({i1, i2, i3}); + CtEvent e1(0, helper); + e1.y_size_min = two; + CtEvent e2(1, helper); + e2.y_size_min = one; + CtEvent e3(2, helper); + e3.y_size_min = one; + std::vector events = {{0, e1}, {1, e2}, {2, e3}}; + + IntegerValue min_sum_of_end_mins(0); + IntegerValue min_sum_of_weighted_end_mins(0); + ASSERT_FALSE(ComputeMinSumOfWeightedEndMins( + events, two, min_sum_of_end_mins, min_sum_of_weighted_end_mins, + kMinIntegerValue, kMinIntegerValue)); +} + +int64_t ExactMakespan(const std::vector& sizes, std::vector& demands, + int capacity) { + const int64_t kHorizon = 1000; + CpModelBuilder builder; + LinearExpr obj; + CumulativeConstraint cumul = builder.AddCumulative(capacity); + for (int i = 0; i < sizes.size(); ++i) { + IntVar s = builder.NewIntVar({0, kHorizon}); + IntervalVar v = builder.NewFixedSizeIntervalVar(s, sizes[i]); + obj += s + sizes[i]; + cumul.AddDemand(v, demands[i]); + } + builder.Minimize(obj); + const CpSolverResponse response = + SolveWithParameters(builder.Build(), "num_search_workers:8"); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + return static_cast(response.objective_value()); +} + +int64_t ExactMakespanBruteForce(absl::Span sizes, + std::vector& demands, int capacity) { + const int64_t kHorizon = 1000; + Model model; + auto* intervals_repository = model.GetOrCreate(); + IntegerValue one(1); + + std::vector intervals; + for (int i = 0; i < sizes.size(); ++i) { + const IntegerVariable start = model.Add(NewIntegerVariable(0, kHorizon)); + const IntegerValue size(sizes[i]); + const IntervalVariable interval = intervals_repository->CreateInterval( + start, AffineExpression(start, one, size), size, kNoLiteralIndex, + /*add_linear_relation=*/false); + intervals.push_back(interval); + } + + SchedulingConstraintHelper* helper = + 
model.GetOrCreate()->GetOrCreateHelper(intervals); + std::vector events; + for (int i = 0; i < demands.size(); ++i) { + CtEvent e(i, helper); + e.y_size_min = demands[i]; + events.emplace_back(i, e); + } + + IntegerValue min_sum_of_end_mins(0); + IntegerValue min_sum_of_weighted_end_mins(0); + EXPECT_TRUE(ComputeMinSumOfWeightedEndMins( + events, IntegerValue(capacity), min_sum_of_end_mins, + min_sum_of_weighted_end_mins, kMinIntegerValue, kMinIntegerValue)); + return min_sum_of_end_mins.value(); +} + +TEST(ComputeMinSumOfEndMinsTest, RandomCases) { + absl::BitGen random; + const int kNumTests = DEBUG_MODE ? 100 : 1000; + const int kNumTasks = 7; + for (int loop = 0; loop < kNumTests; ++loop) { + const int capacity = absl::Uniform(random, 10, 30); + std::vector sizes; + std::vector demands; + for (int t = 0; t < kNumTasks; ++t) { + sizes.push_back(absl::Uniform(random, 2, 15)); + demands.push_back(absl::Uniform(random, 1, capacity)); + } + + EXPECT_EQ(ExactMakespan(sizes, demands, capacity), + ExactMakespanBruteForce(sizes, demands, capacity)); + } +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/subsolver.cc b/ortools/sat/subsolver.cc index 885913187b1..6641c51add1 100644 --- a/ortools/sat/subsolver.cc +++ b/ortools/sat/subsolver.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -43,18 +44,23 @@ namespace { // only SubSolvers for which TaskIsAvailable() is true are considered. Return -1 // if no SubSolver can generate a new task. // -// For now we use a really basic logic: call the least frequently called. +// For now we use a really basic logic that tries to equilibrate the walltime or +// deterministic time spent in each subsolver. 
int NextSubsolverToSchedule(std::vector>& subsolvers, - absl::Span num_generated_tasks) { + bool deterministic = true) { int best = -1; + double best_score = std::numeric_limits::infinity(); for (int i = 0; i < subsolvers.size(); ++i) { if (subsolvers[i] == nullptr) continue; if (subsolvers[i]->TaskIsAvailable()) { - if (best == -1 || num_generated_tasks[i] < num_generated_tasks[best]) { + const double score = subsolvers[i]->GetSelectionScore(deterministic); + if (best == -1 || score < best_score) { + best_score = score; best = i; } } } + if (best != -1) VLOG(1) << "Scheduling " << subsolvers[best]->name(); return best; } @@ -85,14 +91,13 @@ void SynchronizeAll(const std::vector>& subsolvers) { void SequentialLoop(std::vector>& subsolvers) { int64_t task_id = 0; - std::vector num_generated_tasks(subsolvers.size(), 0); std::vector num_in_flight_per_subsolvers(subsolvers.size(), 0); while (true) { SynchronizeAll(subsolvers); ClearSubsolversThatAreDone(num_in_flight_per_subsolvers, subsolvers); - const int best = NextSubsolverToSchedule(subsolvers, num_generated_tasks); + const int best = NextSubsolverToSchedule(subsolvers); if (best == -1) break; - num_generated_tasks[best]++; + subsolvers[best]->NotifySelection(); WallTimer timer; timer.Start(); @@ -126,7 +131,6 @@ void DeterministicLoop(std::vector>& subsolvers, } int64_t task_id = 0; - std::vector num_generated_tasks(subsolvers.size(), 0); std::vector num_in_flight_per_subsolvers(subsolvers.size(), 0); std::vector> to_run; std::vector indices; @@ -149,10 +153,10 @@ void DeterministicLoop(std::vector>& subsolvers, to_run.clear(); indices.clear(); for (int t = 0; t < batch_size; ++t) { - const int best = NextSubsolverToSchedule(subsolvers, num_generated_tasks); + const int best = NextSubsolverToSchedule(subsolvers); if (best == -1) break; num_in_flight_per_subsolvers[best]++; - num_generated_tasks[best]++; + subsolvers[best]->NotifySelection(); to_run.push_back(subsolvers[best]->GenerateTask(task_id++)); 
indices.push_back(best); } @@ -210,7 +214,6 @@ void NonDeterministicLoop(std::vector>& subsolvers, // to create millions of them, so we use the blocking nature of // pool.Schedule() when the queue capacity is set. int64_t task_id = 0; - std::vector num_generated_tasks(subsolvers.size(), 0); while (true) { // Set to true if no task is pending right now. bool all_done = false; @@ -238,13 +241,14 @@ void NonDeterministicLoop(std::vector>& subsolvers, } SynchronizeAll(subsolvers); + int best = -1; { // We need to do that while holding the lock since substask below might // be currently updating the time via AddTaskDuration(). const absl::MutexLock mutex_lock(&mutex); ClearSubsolversThatAreDone(num_in_flight_per_subsolvers, subsolvers); + best = NextSubsolverToSchedule(subsolvers, /*deterministic=*/false); } - const int best = NextSubsolverToSchedule(subsolvers, num_generated_tasks); if (best == -1) { if (all_done) break; @@ -257,7 +261,7 @@ void NonDeterministicLoop(std::vector>& subsolvers, } // Schedule next task. - num_generated_tasks[best]++; + subsolvers[best]->NotifySelection(); { absl::MutexLock mutex_lock(&mutex); num_in_flight++; diff --git a/ortools/sat/subsolver.h b/ortools/sat/subsolver.h index 3dea6ceb4d9..5f93d8a1dcb 100644 --- a/ortools/sat/subsolver.h +++ b/ortools/sat/subsolver.h @@ -101,9 +101,15 @@ class SubSolver { // Note that this is protected by the global execution mutex and so it is // called sequentially. Subclasses do not need to call this. void AddTaskDuration(double duration_in_seconds) { + ++num_finished_tasks_; + wall_time_ += duration_in_seconds; timing_.AddTimeInSec(duration_in_seconds); } + // Note that this is protected by the global execution mutex and so it is + // called sequentially. Subclasses do not need to call this. + void NotifySelection() { ++num_scheduled_tasks_; } + // This one need to be called by the Subclasses. Usually from Synchronize(), // or from the task itself it we execute a single task at the same time. 
void AddTaskDeterministicDuration(double deterministic_duration) { @@ -127,11 +133,42 @@ class SubSolver { return data; } + // Returns a score used to compare which tasks to schedule next. + // We will schedule the LOWER score. + // + // Tricky: Note that this will only be called sequentially. The deterministic + // time should only be used with the DeterministicLoop() because otherwise it + // can be updated at the same time as this is called. + double GetSelectionScore(bool deterministic) const { + const double time = deterministic ? deterministic_time_ : wall_time_; + const double divisor = num_scheduled_tasks_ > 0 + ? static_cast(num_scheduled_tasks_) + : 1.0; + + // If we have little data, we strongly limit the number of task in flight. + // This is needed if some LNS are stuck for a long time to not just only + // schedule this type at the beginning. + const int64_t in_flight = num_scheduled_tasks_ - num_finished_tasks_; + const double confidence_factor = + num_finished_tasks_ > 10 ? 1.0 : std::exp(in_flight); + + // We assume a "minimum time per task" which will be our base etimation for + // the average running time of this task. + return num_scheduled_tasks_ * std::max(0.1, time / divisor) * + confidence_factor; + } + private: const std::string name_; const SubsolverType type_; + int64_t num_scheduled_tasks_ = 0; + int64_t num_finished_tasks_ = 0; + + // Sum of wall_time / deterministic_time. + double wall_time_ = 0.0; double deterministic_time_ = 0.0; + TimeDistribution timing_ = TimeDistribution("task time"); TimeDistribution dtiming_ = TimeDistribution("task dtime"); }; diff --git a/ortools/sat/subsolver_test.cc b/ortools/sat/subsolver_test.cc new file mode 100644 index 00000000000..06f549cfe7a --- /dev/null +++ b/ortools/sat/subsolver_test.cc @@ -0,0 +1,105 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/subsolver.h" + +#include +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "gtest/gtest.h" + +namespace operations_research { +namespace sat { +namespace { + +// Just a trivial example showing how to use the DeterministicLoop() and +// NonDeterministicLoop() functions. +template +void TestLoopFunction() { + struct GlobalState { + int num_task = 0; + const int limit = 100; + + absl::Mutex mutex; + std::vector updates; + + // This one will be always the same after each batch of task. + int64_t max_update_value = 0; + }; + + class TestSubSolver : public SubSolver { + public: + explicit TestSubSolver(GlobalState* state) + : SubSolver("test", FULL_PROBLEM), state_(state) {} + + bool TaskIsAvailable() override { + // Note that the lock is only needed for the non-deterministic test. + absl::MutexLock mutex_lock(&state_->mutex); + return state_->num_task < state_->limit; + } + + std::function GenerateTask(int64_t id) override { + { + // Note that the lock is only needed for the non-deterministic test. + absl::MutexLock mutex_lock(&state_->mutex); + state_->num_task++; + } + return [this, id] { + absl::MutexLock mutex_lock(&state_->mutex); + state_->updates.push_back(id); + }; + } + + void Synchronize() override { + // Note that the lock is only needed for the non-deterministic test. 
+ absl::MutexLock mutex_lock(&state_->mutex); + for (const int64_t i : state_->updates) { + state_->max_update_value = std::max(state_->max_update_value, i); + } + state_->updates.clear(); + } + + private: + GlobalState* state_; + }; + + GlobalState state; + + // The number of subsolver can be independent of the number of threads. Here + // there is actually no need to have 3 of them except for testing the feature. + std::vector> subsolvers; + for (int i = 0; i < 3; ++i) { + subsolvers.push_back(std::make_unique(&state)); + } + + const int num_threads = 4; + if (deterministic) { + const int batch_size = 20; + DeterministicLoop(subsolvers, num_threads, batch_size); + } else { + NonDeterministicLoop(subsolvers, num_threads); + } + EXPECT_EQ(state.max_update_value, state.limit - 1); +} + +TEST(DeterministicLoop, BasicTest) { TestLoopFunction(); } + +TEST(NonDeterministicLoop, BasicTest) { TestLoopFunction(); } + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/swig_helper.cc b/ortools/sat/swig_helper.cc index b03de25b90c..0d9b045c4e6 100644 --- a/ortools/sat/swig_helper.cc +++ b/ortools/sat/swig_helper.cc @@ -15,7 +15,6 @@ #include -#include #include #include @@ -27,9 +26,9 @@ #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/util.h" #include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" -#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { @@ -90,18 +89,15 @@ bool SolutionCallback::SolutionBooleanValue(int index) { } void SolutionCallback::StopSearch() { - if (stopped_ptr_ != nullptr) { - (*stopped_ptr_) = true; - } + if (wrapper_ != nullptr) wrapper_->StopSearch(); } operations_research::sat::CpSolverResponse SolutionCallback::Response() const { return response_; } -void SolutionCallback::SetAtomicBooleanToStopTheSearch( - std::atomic* stopped_ptr) const { - stopped_ptr_ = 
stopped_ptr; +void SolutionCallback::SetWrapperClass(SolveWrapper* wrapper) const { + wrapper_ = wrapper; } bool SolutionCallback::HasResponse() const { return has_response_; } @@ -116,15 +112,13 @@ void SolveWrapper::SetStringParameters(const std::string& string_parameters) { } void SolveWrapper::AddSolutionCallback(const SolutionCallback& callback) { - // Overwrite the atomic bool. - callback.SetAtomicBooleanToStopTheSearch(&stopped_); + callback.SetWrapperClass(this); model_.Add(NewFeasibleSolutionObserver( [&callback](const CpSolverResponse& r) { return callback.Run(r); })); } void SolveWrapper::ClearSolutionCallback(const SolutionCallback& callback) { - // cleanup the atomic bool. - callback.SetAtomicBooleanToStopTheSearch(nullptr); + callback.SetWrapperClass(nullptr); // Detach the wrapper class. } void SolveWrapper::AddLogCallback( @@ -157,11 +151,13 @@ void SolveWrapper::AddBestBoundCallbackFromClass(BestBoundCallback* callback) { operations_research::sat::CpSolverResponse SolveWrapper::Solve( const operations_research::sat::CpModelProto& model_proto) { FixFlagsAndEnvironmentForSwig(); - model_.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped_); return operations_research::sat::SolveCpModel(model_proto, &model_); } -void SolveWrapper::StopSearch() { stopped_ = true; } +void SolveWrapper::StopSearch() { + model_.GetOrCreate()->Stop(); +} + std::string CpSatHelper::ModelStats( const operations_research::sat::CpModelProto& model_proto) { return CpModelStats(model_proto); diff --git a/ortools/sat/swig_helper.h b/ortools/sat/swig_helper.h index e9821b620d1..3a9cfeec691 100644 --- a/ortools/sat/swig_helper.h +++ b/ortools/sat/swig_helper.h @@ -14,24 +14,20 @@ #ifndef OR_TOOLS_SAT_SWIG_HELPER_H_ #define OR_TOOLS_SAT_SWIG_HELPER_H_ -#include #include #include #include #include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/cp_model_checker.h" -#include "ortools/sat/cp_model_solver.h" -#include "ortools/sat/cp_model_utils.h" #include "ortools/sat/model.h" 
#include "ortools/sat/sat_parameters.pb.h" -#include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" -#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { +class SolveWrapper; + // Base class for SWIG director based on solution callbacks. // See http://www.swig.org/Doc4.0/SWIGDocumentation.html#CSharp_directors. class SolutionCallback { @@ -72,14 +68,14 @@ class SolutionCallback { operations_research::sat::CpSolverResponse Response() const; // We use mutable and non const methods to overcome SWIG difficulties. - void SetAtomicBooleanToStopTheSearch(std::atomic* stopped_ptr) const; + void SetWrapperClass(SolveWrapper* wrapper) const; bool HasResponse() const; private: mutable CpSolverResponse response_; mutable bool has_response_ = false; - mutable std::atomic* stopped_ptr_; + mutable SolveWrapper* wrapper_ = nullptr; }; // Simple director class for C#. @@ -126,7 +122,6 @@ class SolveWrapper { private: Model model_; - std::atomic stopped_ = false; }; // Static methods are stored in a module which name can vary. diff --git a/ortools/sat/symmetry_test.cc b/ortools/sat/symmetry_test.cc new file mode 100644 index 00000000000..7bfe774df06 --- /dev/null +++ b/ortools/sat/symmetry_test.cc @@ -0,0 +1,151 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/symmetry.h" + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/algorithms/sparse_permutation.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/sat_base.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; + +TEST(SymmetryPropagatorTest, Permute) { + const int num_variables = 6; + const int num_literals = 2 * num_variables; + std::unique_ptr perm(new SparsePermutation(num_literals)); + perm->AddToCurrentCycle(Literal(+3).Index().value()); + perm->AddToCurrentCycle(Literal(+2).Index().value()); + perm->AddToCurrentCycle(Literal(-4).Index().value()); + perm->CloseCurrentCycle(); + // Note that the permutation 'p' must be compatible with the negation. + // That is negation(p(l)) = p(negation(l)). This is actually not required + // for this test though. + perm->AddToCurrentCycle(Literal(-3).Index().value()); + perm->AddToCurrentCycle(Literal(-2).Index().value()); + perm->AddToCurrentCycle(Literal(+4).Index().value()); + perm->CloseCurrentCycle(); + + Trail trail; + SymmetryPropagator propagator; + propagator.AddSymmetry(std::move(perm)); + trail.RegisterPropagator(&propagator); + + std::vector literals = Literals({+1, +2, -2, +3}); + std::vector output; + propagator.Permute(0, literals, &output); + EXPECT_THAT(output, + ElementsAre(Literal(+1), Literal(-4), Literal(+4), Literal(+2))); +} + +TEST(SymmetryPropagatorTest, BasicTest) { + const int num_variables = 6; + const int num_literals = 2 * num_variables; + std::unique_ptr perm(new SparsePermutation(num_literals)); + perm->AddToCurrentCycle(Literal(+3).Index().value()); + perm->AddToCurrentCycle(Literal(+2).Index().value()); + perm->AddToCurrentCycle(Literal(-4).Index().value()); + perm->CloseCurrentCycle(); + // Note that the permutation 'p' must be compatible with the negation. + // That is negation(p(l)) = p(negation(l)). 
+ perm->AddToCurrentCycle(Literal(-3).Index().value()); + perm->AddToCurrentCycle(Literal(-2).Index().value()); + perm->AddToCurrentCycle(Literal(+4).Index().value()); + perm->CloseCurrentCycle(); + perm->AddToCurrentCycle(Literal(-5).Index().value()); + perm->AddToCurrentCycle(Literal(+5).Index().value()); + perm->CloseCurrentCycle(); + + Trail trail; + trail.Resize(num_variables); + SymmetryPropagator propagator; + propagator.AddSymmetry(std::move(perm)); + trail.RegisterPropagator(&propagator); + + // We need a mock propagator to inject a reason. + struct MockPropagator : SatPropagator { + MockPropagator() : SatPropagator("MockPropagator") {} + bool Propagate(Trail* trail) final { return true; } + absl::Span Reason(const Trail& /*trail*/, + int /*trail_index*/, + int64_t /*conflict_id*/) const final { + return reason; + } + std::vector reason; + }; + MockPropagator mock_propagator; + trail.RegisterPropagator(&mock_propagator); + + // With such a trail, nothing should propagate because the first non-symmetric + // literal +3 is a decision. + trail.Enqueue(Literal(+3), AssignmentType::kSearchDecision); + trail.Enqueue(Literal(-5), mock_propagator.PropagatorId()); + while (!propagator.PropagationIsDone(trail)) { + EXPECT_TRUE(propagator.Propagate(&trail)); + } + EXPECT_EQ(trail.Index(), 2); + + // Now we take the decision +2 (which is the image of +3). + trail.Enqueue(Literal(+2), AssignmentType::kUnitReason); + + // We need to initialize the reason for -5, because it will be needed during + // the conflict creation that the Propagate() below will trigger. + mock_propagator.reason = Literals({-3}); + + // Because -5 is now the first non-symmetric literal, a conflict is detected + // since +5 can then be propagated. + EXPECT_FALSE(propagator.PropagationIsDone(trail)); + EXPECT_FALSE(propagator.Propagate(&trail)); + + // Let assume that the reason for -5 is the assignment +3 (which make sense + // since it was propagated). 
The expected conflict is as stated below because + // if -5 and +2 are true, by summetry since we had +3 => -5 we know that +2 => + // 5. + // + // Note: by convention all the literals of a reason or a conflict are false. + EXPECT_THAT(trail.FailingClause(), ElementsAre(Literal(-2), Literal(+5))); + + // Let backtrack to the trail to +3. + trail.Untrail(trail.Index() - 2); + propagator.Untrail(trail, trail.Index()); + + // Let now assume that +3 => +2, by symmetry we can also propagate -4! + while (!propagator.PropagationIsDone(trail)) { + EXPECT_TRUE(propagator.Propagate(&trail)); + } + EXPECT_EQ(trail.Index(), 1); + trail.Enqueue(Literal(+2), mock_propagator.PropagatorId()); + EXPECT_FALSE(propagator.PropagationIsDone(trail)); + EXPECT_TRUE(propagator.Propagate(&trail)); + EXPECT_EQ(trail.Index(), 3); + EXPECT_EQ(trail[2], Literal(-4)); + + // Once again, if the reason for +2 was the assignment +3, we can compute + // the reason for the assignment -4 (it is just the symmetric of the other). 
+ EXPECT_THAT(trail.Reason(Literal(-4).Variable()), ElementsAre(Literal(-2))); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/symmetry_util.cc b/ortools/sat/symmetry_util.cc index 78edf8fe87f..c1d96e0a38e 100644 --- a/ortools/sat/symmetry_util.cc +++ b/ortools/sat/symmetry_util.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/algorithms/dynamic_partition.h" @@ -194,5 +195,42 @@ std::vector GetOrbitopeOrbits( return orbits; } +void GetSchreierVectorAndOrbit( + int point, absl::Span> generators, + std::vector* schrier_vector, std::vector* orbit) { + schrier_vector->clear(); + *orbit = {point}; + if (generators.empty()) return; + schrier_vector->resize(generators[0]->Size(), -1); + absl::flat_hash_set orbit_set = {point}; + for (int i = 0; i < orbit->size(); ++i) { + const int orbit_element = (*orbit)[i]; + for (int i = 0; i < generators.size(); ++i) { + DCHECK_EQ(schrier_vector->size(), generators[i]->Size()); + const int image = generators[i]->Image(orbit_element); + if (image == orbit_element) continue; + const auto [it, inserted] = orbit_set.insert(image); + if (inserted) { + (*schrier_vector)[image] = i; + orbit->push_back(image); + } + } + } +} + +std::vector TracePoint( + int point, absl::Span schrier_vector, + absl::Span> generators) { + std::vector result; + while (schrier_vector[point] != -1) { + const SparsePermutation& perm = *generators[schrier_vector[point]]; + result.push_back(schrier_vector[point]); + const int next = perm.InverseImage(point); + DCHECK_NE(next, point); + point = next; + } + return result; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/symmetry_util.h b/ortools/sat/symmetry_util.h index f045be430e3..5e5e813d6e0 100644 --- a/ortools/sat/symmetry_util.h +++ b/ortools/sat/symmetry_util.h @@ -62,6 +62,19 @@ std::vector GetOrbits( std::vector 
GetOrbitopeOrbits(int n, absl::Span> orbitope); +// See Chapter 7 of Butler, Gregory, ed. Fundamental algorithms for permutation +// groups. Berlin, Heidelberg: Springer Berlin Heidelberg, 1991. +void GetSchreierVectorAndOrbit( + int point, absl::Span> generators, + std::vector* schrier_vector, std::vector* orbit); + +// Given a schreier vector for a given base point and a point in the same orbit +// of the base point, returns a list of index of the `generators` to apply to +// get a permutation mapping the base point to get the given point. +std::vector TracePoint( + int point, absl::Span schrier_vector, + absl::Span> generators); + // Given the generators for a permutation group of [0, n-1], update it to // a set of generators of the group stabilizing the given element. // diff --git a/ortools/sat/symmetry_util_test.cc b/ortools/sat/symmetry_util_test.cc index 85a7d674811..9b3a5b19f59 100644 --- a/ortools/sat/symmetry_util_test.cc +++ b/ortools/sat/symmetry_util_test.cc @@ -13,9 +13,12 @@ #include "ortools/sat/symmetry_util.h" +#include #include +#include #include +#include "absl/types/span.h" #include "gtest/gtest.h" #include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/gmock.h" @@ -25,24 +28,25 @@ namespace sat { namespace { using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +std::unique_ptr MakePerm( + int size, absl::Span> cycles) { + auto perm = std::make_unique(size); + for (const auto& cycle : cycles) { + for (const int x : cycle) { + perm->AddToCurrentCycle(x); + } + perm->CloseCurrentCycle(); + } + return perm; +} TEST(GetOrbitsTest, BasicExample) { const int n = 10; std::vector> generators; - generators.push_back(std::make_unique(n)); - generators[0]->AddToCurrentCycle(0); - generators[0]->AddToCurrentCycle(1); - generators[0]->AddToCurrentCycle(2); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(7); - generators[0]->AddToCurrentCycle(8); - generators[0]->CloseCurrentCycle(); - - 
generators.push_back(std::make_unique(n)); - generators[1]->AddToCurrentCycle(3); - generators[1]->AddToCurrentCycle(2); - generators[1]->AddToCurrentCycle(7); - generators[1]->CloseCurrentCycle(); + generators.push_back(MakePerm(n, {{0, 1, 2}, {7, 8}})); + generators.push_back(MakePerm(n, {{3, 2, 7}})); const std::vector orbits = GetOrbits(n, generators); for (const int i : std::vector{0, 1, 2, 3, 7, 8}) { EXPECT_EQ(orbits[i], 0); @@ -60,27 +64,8 @@ TEST(BasicOrbitopeExtractionTest, BasicExample) { const int n = 10; std::vector> generators; - generators.push_back(std::make_unique(n)); - generators[0]->AddToCurrentCycle(0); - generators[0]->AddToCurrentCycle(1); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(4); - generators[0]->AddToCurrentCycle(5); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(8); - generators[0]->AddToCurrentCycle(7); - generators[0]->CloseCurrentCycle(); - - generators.push_back(std::make_unique(n)); - generators[1]->AddToCurrentCycle(2); - generators[1]->AddToCurrentCycle(1); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(5); - generators[1]->AddToCurrentCycle(3); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(6); - generators[1]->AddToCurrentCycle(7); - generators[1]->CloseCurrentCycle(); + generators.push_back(MakePerm(n, {{0, 1}, {4, 5}, {8, 7}})); + generators.push_back(MakePerm(n, {{2, 1}, {5, 3}, {6, 7}})); const std::vector> orbitope = BasicOrbitopeExtraction(generators); @@ -99,27 +84,8 @@ TEST(BasicOrbitopeExtractionTest, NotAnOrbitopeBecauseOfDuplicates) { const int n = 10; std::vector> generators; - generators.push_back(std::make_unique(n)); - generators[0]->AddToCurrentCycle(0); - generators[0]->AddToCurrentCycle(1); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(4); - generators[0]->AddToCurrentCycle(5); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(8); - 
generators[0]->AddToCurrentCycle(7); - generators[0]->CloseCurrentCycle(); - - generators.push_back(std::make_unique(n)); - generators[1]->AddToCurrentCycle(1); - generators[1]->AddToCurrentCycle(2); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(5); - generators[1]->AddToCurrentCycle(8); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(6); - generators[1]->AddToCurrentCycle(9); - generators[1]->CloseCurrentCycle(); + generators.push_back(MakePerm(n, {{0, 1}, {4, 5}, {8, 7}})); + generators.push_back(MakePerm(n, {{1, 2}, {5, 8}, {6, 9}})); const std::vector> orbitope = BasicOrbitopeExtraction(generators); @@ -129,6 +95,66 @@ TEST(BasicOrbitopeExtractionTest, NotAnOrbitopeBecauseOfDuplicates) { EXPECT_THAT(orbitope[2], ElementsAre(8, 7)); } +TEST(GetSchreierVectorTest, Square) { + const int n = 4; + std::vector> generators; + generators.push_back(MakePerm(n, {{0, 1, 2, 3}})); + generators.push_back(MakePerm(n, {{1, 3}})); + + std::vector schrier_vector, orbit; + GetSchreierVectorAndOrbit(0, generators, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, 0, 0, 1)); +} + +TEST(GetSchreierVectorTest, ComplicatedGroup) { + // See Chapter 7 of Butler, Gregory, ed. Fundamental algorithms for + // permutation groups. Berlin, Heidelberg: Springer Berlin Heidelberg, 1991. 
+ const int n = 11; + std::vector> generators; + generators.push_back(MakePerm(n, {{0, 3, 4, 10, 5, 9, 2, 1}, {6, 7}})); + generators.push_back(MakePerm(n, {{0, 3, 4, 10, 5, 9, 2, 1}, {7, 8}})); + generators.push_back(MakePerm(n, {{0, 3, 1, 2}, {4, 10, 9, 5}})); + + std::vector schrier_vector, orbit; + GetSchreierVectorAndOrbit(0, generators, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, 2, 2, 0, 0, 0, -1, -1, -1, 2, 0)); + std::vector generators_idx = TracePoint(9, schrier_vector, generators); + std::vector points = {"0", "1", "2", "3", "4", "5", + "6", "7", "8", "9", "10"}; + for (const int i : generators_idx) { + generators[i]->ApplyToDenseCollection(points); + } + // It needs to take the base point 0 to the traced point 9. + EXPECT_THAT(points, ElementsAre("9", "10", "1", "4", "5", "2", "7", "6", "8", + "3", "0")); + GetSchreierVectorAndOrbit(6, generators, &schrier_vector, &orbit); + EXPECT_THAT(orbit, UnorderedElementsAre(6, 7, 8)); + EXPECT_THAT(schrier_vector, + ElementsAre(-1, -1, -1, -1, -1, -1, -1, 0, 1, -1, -1)); +} + +TEST(GetSchreierVectorTest, ProjectivePlaneOrderTwo) { + const int n = 7; + std::vector> generators; + generators.push_back(MakePerm(n, {{0, 1, 3, 4, 6, 2, 5}})); + generators.push_back(MakePerm(n, {{1, 3}, {2, 4}})); + + std::vector schrier_vector, orbit; + GetSchreierVectorAndOrbit(0, generators, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, 0, 1, 0, 0, 0, 0)); + EXPECT_THAT(orbit, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6)); + + // Now let's see the stabilizer of the point 0. 
+ std::vector> stabilizer; + + stabilizer.push_back(MakePerm(n, {{1, 3}, {2, 4}})); + stabilizer.push_back(MakePerm(n, {{3, 4}, {5, 6}})); + stabilizer.push_back(MakePerm(n, {{3, 5}, {4, 6}})); + + GetSchreierVectorAndOrbit(1, stabilizer, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, -1, 0, 0, 1, 2, 2)); +} + } // namespace } // namespace sat } // namespace operations_research diff --git a/ortools/sat/theta_tree_test.cc b/ortools/sat/theta_tree_test.cc new file mode 100644 index 00000000000..3e3ebb89640 --- /dev/null +++ b/ortools/sat/theta_tree_test.cc @@ -0,0 +1,291 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/theta_tree.h" + +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/sat/integer.h" +#include "ortools/util/random_engine.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { +namespace { + +template +class ThetaTreeTest : public ::testing::Test {}; + +using IntegerTypes = ::testing::Types; +TYPED_TEST_SUITE(ThetaTreeTest, IntegerTypes); + +TYPED_TEST(ThetaTreeTest, EnvelopeOfEmptySet) { + ThetaLambdaTree tree; + tree.Reset(0); + EXPECT_EQ(IntegerTypeMinimumValue(), tree.GetEnvelope()); +} + +template +std::vector IntegerTypeVector(std::vector arg) { + return std::vector(arg.begin(), arg.end()); +} + +TYPED_TEST(ThetaTreeTest, Envelope) { + ThetaLambdaTree tree; + std::vector envelope = + IntegerTypeVector({-10, -7, -6, -4, -2}); + std::vector energy = IntegerTypeVector({2, 1, 3, 2, 2}); + tree.Reset(5); + + for (int i = 0; i < 5; i++) { + tree.AddOrUpdateEvent(i, envelope[i], energy[i], energy[i]); + } + EXPECT_EQ(1, tree.GetEnvelope()); // (-7) + (1+3+2+2) or (-6) + (3+2+2) + EXPECT_EQ(2, tree.GetMaxEventWithEnvelopeGreaterThan(TypeParam(0))); + EXPECT_EQ(4, tree.GetMaxEventWithEnvelopeGreaterThan(TypeParam(-1))); + EXPECT_EQ(0, tree.GetEnvelopeOf(0)); + EXPECT_EQ(1, tree.GetEnvelopeOf(1)); + EXPECT_EQ(1, tree.GetEnvelopeOf(2)); + EXPECT_EQ(0, tree.GetEnvelopeOf(3)); + EXPECT_EQ(0, tree.GetEnvelopeOf(4)); +} + +TYPED_TEST(ThetaTreeTest, EnvelopeOpt) { + ThetaLambdaTree tree; + std::vector envelope = + IntegerTypeVector({-10, -7, -6, -4, -2}); + std::vector energy = IntegerTypeVector({2, 1, 3, 3, 2}); + tree.Reset(5); + + int event, optional_event; + TypeParam energy_max; + + tree.AddOrUpdateEvent(0, envelope[0], energy[0], energy[0]); + tree.AddOrUpdateEvent(1, envelope[1], energy[1], energy[1]); + tree.AddOrUpdateEvent(3, envelope[3], TypeParam(0), energy[3]); + tree.AddOrUpdateEvent(4, envelope[4], energy[4], energy[4]); + 
EXPECT_EQ(1, tree.GetOptionalEnvelope()); + + tree.GetEventsWithOptionalEnvelopeGreaterThan(TypeParam(0), &event, + &optional_event, &energy_max); + EXPECT_EQ(3, event); + EXPECT_EQ(3, optional_event); + EXPECT_EQ(2, energy_max); + + tree.RemoveEvent(4); + tree.AddOrUpdateEvent(2, envelope[2], energy[2], energy[2]); + EXPECT_EQ(0, tree.GetOptionalEnvelope()); + tree.GetEventsWithOptionalEnvelopeGreaterThan(TypeParam(-1), &event, + &optional_event, &energy_max); + EXPECT_EQ(2, event); + EXPECT_EQ(3, optional_event); + EXPECT_EQ(2, energy_max); + EXPECT_EQ(-4, tree.GetEnvelopeOf(0)); + EXPECT_EQ(-3, tree.GetEnvelopeOf(1)); + EXPECT_EQ(-3, tree.GetEnvelopeOf(2)); +} + +TYPED_TEST(ThetaTreeTest, EnvelopeOptWithAddOptional) { + ThetaLambdaTree tree; + std::vector envelope = + IntegerTypeVector({-10, -7, -6, -4, -2}); + std::vector energy = IntegerTypeVector({2, 1, 3, 3, 2}); + tree.Reset(5); + + int event, optional_event; + TypeParam energy_max; + + tree.AddOrUpdateEvent(0, envelope[0], energy[0], energy[0]); + tree.AddOrUpdateEvent(1, envelope[1], energy[1], energy[1]); + tree.AddOrUpdateOptionalEvent(3, envelope[3], energy[3]); + tree.AddOrUpdateEvent(4, envelope[4], energy[4], energy[4]); + EXPECT_EQ(1, tree.GetOptionalEnvelope()); + + tree.GetEventsWithOptionalEnvelopeGreaterThan(TypeParam(0), &event, + &optional_event, &energy_max); + EXPECT_EQ(3, event); + EXPECT_EQ(3, optional_event); + EXPECT_EQ(2, energy_max); + + tree.RemoveEvent(4); + tree.AddOrUpdateEvent(2, envelope[2], energy[2], energy[2]); + EXPECT_EQ(0, tree.GetOptionalEnvelope()); + tree.GetEventsWithOptionalEnvelopeGreaterThan(TypeParam(-1), &event, + &optional_event, &energy_max); + EXPECT_EQ(2, event); + EXPECT_EQ(3, optional_event); + EXPECT_EQ(2, energy_max); + EXPECT_EQ(-4, tree.GetEnvelopeOf(0)); + EXPECT_EQ(-3, tree.GetEnvelopeOf(1)); + EXPECT_EQ(-3, tree.GetEnvelopeOf(2)); +} + +TYPED_TEST(ThetaTreeTest, AddingAndGettingOptionalEvents) { + ThetaLambdaTree tree; + std::vector envelope = + 
IntegerTypeVector({0, 3, 4, 6, 8}); + std::vector energy = IntegerTypeVector({2, 1, 3, 3, 2}); + tree.Reset(5); + + tree.AddOrUpdateEvent(0, envelope[0], energy[0], energy[0]); + tree.AddOrUpdateEvent(1, envelope[1], energy[1], energy[1]); + EXPECT_EQ(4, tree.GetEnvelope()); + + // Even with 0 energy, standard update takes task 3's envelope into account. + tree.AddOrUpdateEvent(3, envelope[3], TypeParam(0), energy[3]); + EXPECT_EQ(6, tree.GetEnvelope()); + EXPECT_EQ(9, tree.GetOptionalEnvelope()); + tree.RemoveEvent(3); + + // Changing task 3 to optional makes it disappear from GetEnvelope(). + tree.AddOrUpdateOptionalEvent(3, envelope[3], energy[3]); + EXPECT_EQ(4, tree.GetEnvelope()); // Same as before adding task 3. + EXPECT_EQ(9, tree.GetOptionalEnvelope()); + + // Changing task 3 to optional changes its optional values. + tree.AddOrUpdateEvent(3, envelope[3], TypeParam(1), TypeParam(9)); + tree.AddOrUpdateOptionalEvent(3, envelope[3], energy[3]); + EXPECT_EQ(4, tree.GetEnvelope()); + EXPECT_EQ(9, tree.GetOptionalEnvelope()); +} + +TYPED_TEST(ThetaTreeTest, RemoveAndDelayedAddOrUpdateEventTest) { + ThetaLambdaTree tree; + // The tree encoding is tricky, check that RecomputeTreeForDelayedOperations() + // works for all values from a power of two until the next. + for (int num_events = 4; num_events < 8; ++num_events) { + tree.Reset(num_events); + std::vector envelope; + std::vector energy; + // Event start envelope = event, energy min = 2, energy max = 3 + for (int event = 0; event < num_events; ++event) { + envelope.push_back(TypeParam{event}); + energy.push_back(TypeParam{2}); + } + EXPECT_EQ(tree.GetEnvelope(), IntegerTypeMinimumValue()); + EXPECT_EQ(tree.GetOptionalEnvelope(), IntegerTypeMinimumValue()); + // Envelope of events [0, i) is (0) + 2 * i. 
+ for (int event = 0; event < num_events; ++event) { + tree.DelayedAddOrUpdateEvent(event, envelope[event], energy[event], + energy[event] + 1); + tree.RecomputeTreeForDelayedOperations(); + EXPECT_EQ(tree.GetEnvelope(), 2 * (event + 1)); + EXPECT_EQ(tree.GetOptionalEnvelope(), 2 * (event + 1) + 1); + } + // Envelope of events [i, n) is (n-1) + 2 + (n - i) + for (int event = 0; event < num_events; ++event) { + EXPECT_EQ(tree.GetEnvelope(), 2 * num_events - event); + EXPECT_EQ(tree.GetOptionalEnvelope(), 2 * num_events - event + 1); + tree.DelayedRemoveEvent(event); + tree.RecomputeTreeForDelayedOperations(); + } + EXPECT_EQ(tree.GetEnvelope(), IntegerTypeMinimumValue()); + EXPECT_EQ(tree.GetOptionalEnvelope(), IntegerTypeMinimumValue()); + } +} + +TYPED_TEST(ThetaTreeTest, DelayedAddOrUpdateOptionalEventTest) { + ThetaLambdaTree tree; + // The tree encoding is tricky, check that RecomputeTreeForDelayedOperations() + // works for all values from a power of two until the next. + for (int num_events = 4; num_events < 8; ++num_events) { + tree.Reset(num_events); + std::vector envelope; + std::vector energy; + // Event start envelope = event, event energy max = 2. + for (int event = 0; event < num_events; ++event) { + envelope.push_back(TypeParam{event}); + energy.push_back(TypeParam{2}); + } + EXPECT_EQ(tree.GetEnvelope(), IntegerTypeMinimumValue()); + EXPECT_EQ(tree.GetOptionalEnvelope(), IntegerTypeMinimumValue()); + // Optional envelope of events [0, i) is i + 2. 
+ for (int event = 0; event < num_events; ++event) { + tree.DelayedAddOrUpdateOptionalEvent(event, envelope[event], + energy[event]); + tree.RecomputeTreeForDelayedOperations(); + EXPECT_EQ(tree.GetEnvelope(), IntegerTypeMinimumValue()); + EXPECT_EQ(tree.GetOptionalEnvelope(), event + 2); + } + } +} + +static void BM_update(benchmark::State& state) { + random_engine_t random_; + const int size = state.range(0); + const int num_updates = 4 * size; + ThetaLambdaTree tree; + std::uniform_int_distribution event_dist(0, size - 1); + std::uniform_int_distribution enveloppe_dist(-10000, 10000); + std::uniform_int_distribution energy_dist(0, 10000); + for (auto _ : state) { + tree.Reset(size); + for (int i = 0; i < num_updates; ++i) { + const int event = event_dist(random_); + const IntegerValue enveloppe(enveloppe_dist(random_)); + const IntegerValue energy1(energy_dist(random_)); + const IntegerValue energy2(energy_dist(random_)); + tree.AddOrUpdateEvent(event, enveloppe, std::min(energy1, energy2), + std::max(energy1, energy2)); + } + } + // Number of updates. 
+ state.SetBytesProcessed(static_cast(state.iterations()) * + num_updates); +} + +// Note that we didn't pick only power of two +BENCHMARK(BM_update)->Arg(10)->Arg(20)->Arg(64)->Arg(100)->Arg(256)->Arg(800); + +static void BM_delayed_update(benchmark::State& state) { + random_engine_t random_; + const int size = state.range(0); + const int num_updates = 4 * size; + ThetaLambdaTree tree; + std::uniform_int_distribution event_dist(0, size - 1); + std::uniform_int_distribution enveloppe_dist(-10000, 10000); + std::uniform_int_distribution energy_dist(0, 10000); + for (auto _ : state) { + tree.Reset(size); + for (int i = 0; i < num_updates; ++i) { + const int event = event_dist(random_); + const IntegerValue enveloppe(enveloppe_dist(random_)); + const IntegerValue energy1(energy_dist(random_)); + const IntegerValue energy2(energy_dist(random_)); + tree.DelayedAddOrUpdateEvent(event, enveloppe, std::min(energy1, energy2), + std::max(energy1, energy2)); + } + tree.RecomputeTreeForDelayedOperations(); + } + // Number of updates. + state.SetBytesProcessed(static_cast(state.iterations()) * + num_updates); +} + +// Note that we didn't pick only power of two +BENCHMARK(BM_delayed_update) + ->Arg(10) + ->Arg(20) + ->Arg(64) + ->Arg(100) + ->Arg(256) + ->Arg(800); + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/timetable_test.cc b/ortools/sat/timetable_test.cc new file mode 100644 index 00000000000..c8999baeddf --- /dev/null +++ b/ortools/sat/timetable_test.cc @@ -0,0 +1,555 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/timetable.h" + +#include + +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/log/check.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/logging.h" +#include "ortools/sat/all_different.h" +#include "ortools/sat/cumulative.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +struct CumulativeTasks { + int min_duration; + int min_demand; + int min_start; + int max_end; +}; + +struct Task { + int min_start; + int max_end; +}; + +bool TestTimeTablingPropagation(absl::Span tasks, + absl::Span expected, int capacity) { + Model model; + IntegerTrail* integer_trail = model.GetOrCreate(); + PrecedencesPropagator* precedences = + model.GetOrCreate(); + IntervalsRepository* intervals = model.GetOrCreate(); + + const int num_tasks = tasks.size(); + std::vector interval_vars(num_tasks); + std::vector start_exprs(num_tasks); + std::vector duration_exprs(num_tasks); + std::vector end_exprs(num_tasks); + std::vector demands(num_tasks); + const AffineExpression capacity_expr = + AffineExpression(IntegerValue(capacity)); + + const int kStart(0); + const int kHorizon(10000); + + for (int t = 0; t < num_tasks; ++t) { + const CumulativeTasks& task = tasks[t]; + // Build the 
task variables. + interval_vars[t] = + model.Add(NewInterval(kStart, kHorizon, task.min_duration)); + start_exprs[t] = intervals->Start(interval_vars[t]); + end_exprs[t] = intervals->End(interval_vars[t]); + demands[t] = AffineExpression(IntegerValue(task.min_demand)); + + // Set task initial minimum starting time. + std::vector no_literal_reason; + std::vector no_integer_reason; + EXPECT_TRUE( + integer_trail->Enqueue(start_exprs[t].GreaterOrEqual(task.min_start), + no_literal_reason, no_integer_reason)); + // Set task initial maximum ending time. + EXPECT_TRUE(integer_trail->Enqueue(end_exprs[t].LowerOrEqual(task.max_end), + no_literal_reason, no_integer_reason)); + } + + // Propagate properly the other bounds of the intervals. + EXPECT_TRUE(precedences->Propagate()); + + SchedulingConstraintHelper* helper = model.TakeOwnership( + new SchedulingConstraintHelper(interval_vars, &model)); + SchedulingDemandHelper* demands_helper = + model.TakeOwnership(new SchedulingDemandHelper(demands, helper, &model)); + + // Propagator responsible for filtering start variables. + TimeTablingPerTask timetabling(capacity_expr, helper, demands_helper, &model); + timetabling.RegisterWith(model.GetOrCreate()); + + // Check initial satisfiability + if (!model.GetOrCreate()->Propagate()) return false; + + // Check consistency of data. + CHECK_EQ(num_tasks, expected.size()); + + for (int t = 0; t < num_tasks; ++t) { + // Check starting time. + EXPECT_EQ(expected[t].min_start, integer_trail->LowerBound(start_exprs[t])) + << "task #" << t; + // Check ending time. + EXPECT_EQ(expected[t].max_end, integer_trail->UpperBound(end_exprs[t])) + << "task #" << t; + } + return true; +} + +// This is an infeasible instance on which the edge finder finds nothing. +// Cumulative Time Table finds the contradiction. +TEST(TimeTablingPropagation, UNSAT) { + EXPECT_FALSE(TestTimeTablingPropagation({{3, 2, 0, 4}, {3, 2, 1, 5}}, {}, 3)); +} + +// This is an instance on Time Table pushes a task. 
+TEST(TimeTablingPropagation, TimeTablePush1) { + EXPECT_TRUE(TestTimeTablingPropagation({{1, 2, 1, 2}, {3, 2, 0, 10}}, + {{1, 2}, {2, 10}}, 3)); +} + +// This is an instance on Time Table pushes a task. +TEST(TimeTablingPropagation, TimeTablePush2) { + EXPECT_TRUE( + TestTimeTablingPropagation({{1, 2, 1, 2}, {1, 2, 3, 4}, {3, 2, 0, 10}}, + {{1, 2}, {3, 4}, {4, 10}}, 3)); +} + +// This is an instance on which Time Table pushes a task. +// Here the two first tasks have the following profile: +// usage ^ +// 2 | ** +// 1 | **--** +// 0 |**------******************> time +// 0 1 2 3 4 5 6 +// The interval [2, 3] has a profile too high to accommodate the third task. +TEST(TimeTablingPropagation, TimeTablePush3) { + EXPECT_TRUE( + TestTimeTablingPropagation({{3, 1, 0, 4}, {3, 1, 1, 5}, {3, 2, 0, 10}}, + {{0, 4}, {1, 5}, {3, 10}}, 3)); +} + +// This is an instance on which Time Table pushes a task. +// Similar to TimeTablePush3, but the two small tasks have the same profile. +TEST(TimeTablingPropagation, TimeTablePush4) { + EXPECT_TRUE( + TestTimeTablingPropagation({{4, 1, 0, 5}, {3, 1, 1, 4}, {3, 2, 0, 10}}, + {{0, 5}, {1, 4}, {4, 10}}, 3)); +} + +// Regression test: there used to be a bug when no profile delta corresponded +// to the start time of a task. +TEST(TimeTablingPropagation, RegressionTest) { + EXPECT_TRUE(TestTimeTablingPropagation({{3, 1, 0, 3}, {2, 1, 2, 5}}, + {{0, 3}, {3, 5}}, 1)); +} + +// Regression test: there used to be a bug that caused Timetabling to stop +// before reaching its fixed-point. +TEST(TimeTablingPropagation, FixedPoint) { + EXPECT_TRUE(TestTimeTablingPropagation( + {{1, 1, 0, 1}, {4, 1, 0, 8}, {2, 1, 1, 5}, {1, 1, 1, 5}}, + {{0, 1}, {3, 8}, {1, 4}, {1, 4}}, 1)); +} + +// Regression test: there used to be a bug when two back to back +// tasks were exceeding the capacity in the partial sum. 
+TEST(TimeTablingPropagation, PartialSumBug) { + EXPECT_TRUE(TestTimeTablingPropagation({{510, 142, 0, 510}, + {268, 130, 242, 510}, + {74, 147, 510, 584}, + {197, 204, 584, 781}, + {72, 138, 781, 853}, + {170, 231, 853, 1023}, + {181, 131, 1023, 1204}}, + {{0, 510}, + {242, 510}, + {510, 584}, + {584, 781}, + {781, 853}, + {853, 1023}, + {1023, 1204}}, + 315)); +} + +// TODO(user): build automatic FindAll tests for the cumulative constraint. +// Test that we find all the solutions. +TEST(TimeTablingSolve, FindAll) { + // Instance. + const std::vector durations = {1, 2, 3, 3, 3, 3}; + const std::vector demands = {1, 1, 1, 1, 4, 4}; + const int capacity = 4; + const int horizon = 11; + + Model model; + std::vector intervals(durations.size()); + std::vector demand_exprs(durations.size()); + const AffineExpression capacity_expr = + AffineExpression(IntegerValue(capacity)); + + for (int i = 0; i < durations.size(); ++i) { + intervals[i] = model.Add(NewInterval(0, horizon, durations[i])); + demand_exprs[i] = AffineExpression(IntegerValue(demands[i])); + } + + model.Add(Cumulative(intervals, demand_exprs, capacity_expr)); + + int num_solutions_found = 0; + auto* integer_trail = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + std::vector solution(durations.size()); + for (int i = 0; i < intervals.size(); ++i) { + solution[i] = + integer_trail->LowerBound(repository->Start(intervals[i])).value(); + } + num_solutions_found++; + LOG(INFO) << "Found solution: {" << absl::StrJoin(solution, ", ") << "}."; + + // Loop to the next solution. + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + // Test that we have the right number of solutions. + EXPECT_EQ(num_solutions_found, 2040); +} + +TEST(TimeTablingSolve, FindAllWithVaryingCapacity) { + // Instance. 
+ const std::vector durations = {1, 2, 3}; + const std::vector demands = {1, 2, 3}; + const int horizon = 6; + + // Collect the number of solution for each capacity value. + int sum = 0; + for (const int capacity : {3, 4, 5}) { + Model model; + std::vector intervals(durations.size()); + std::vector demand_exprs(durations.size()); + const AffineExpression capacity_expr = + AffineExpression(IntegerValue(capacity)); + + for (int i = 0; i < durations.size(); ++i) { + intervals[i] = model.Add(NewInterval(0, horizon, durations[i])); + demand_exprs[i] = AffineExpression(IntegerValue(demands[i])); + } + + model.Add(Cumulative(intervals, demand_exprs, capacity_expr)); + + int num_solutions_found = 0; + auto* integer_trail = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + std::vector solution(durations.size()); + for (int i = 0; i < intervals.size(); ++i) { + solution[i] = + integer_trail->LowerBound(repository->Start(intervals[i])).value(); + } + num_solutions_found++; + LOG(INFO) << "Found solution: {" << absl::StrJoin(solution, ", ") << "}."; + + // Loop to the next solution. + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + LOG(INFO) << "capacity: " << capacity + << " num_solutions: " << num_solutions_found; + sum += num_solutions_found; + } + + // Now solve with a varying capacity. 
+ Model model; + std::vector intervals(durations.size()); + std::vector demand_exprs(durations.size()); + const AffineExpression capacity_expr = + AffineExpression(model.Add(NewIntegerVariable(0, 5))); + + for (int i = 0; i < durations.size(); ++i) { + intervals[i] = model.Add(NewInterval(0, horizon, durations[i])); + demand_exprs[i] = AffineExpression(IntegerValue(demands[i])); + } + + model.Add(Cumulative(intervals, demand_exprs, capacity_expr)); + + int num_solutions_found = 0; + auto* integer_trail = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + std::vector solution(durations.size()); + for (int i = 0; i < intervals.size(); ++i) { + solution[i] = + integer_trail->LowerBound(repository->Start(intervals[i])).value(); + } + num_solutions_found++; + LOG(INFO) << "Found solution: {" << absl::StrJoin(solution, ", ") << "}."; + + // Loop to the next solution. + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + // Test that we have the right number of solutions. + EXPECT_EQ(num_solutions_found, sum); +} + +TEST(TimeTablingSolve, FindAllWithOptionals) { + // Instance. + // Up to two tasks can be scheduled at the same time. 
+ const std::vector durations = {3, 3, 3}; + const std::vector demands = {2, 2, 2}; + const int capacity = 5; + const int horizon = 3; + const int num_solutions = 7; + + Model model; + std::vector intervals(durations.size()); + std::vector demand_exprs(durations.size()); + std::vector is_present_literals(durations.size()); + const AffineExpression capacity_expr = + AffineExpression(IntegerValue(capacity)); + + for (int i = 0; i < durations.size(); ++i) { + is_present_literals[i] = Literal(model.Add(NewBooleanVariable()), true); + intervals[i] = model.Add( + NewOptionalInterval(0, horizon, durations[i], is_present_literals[i])); + demand_exprs[i] = AffineExpression(IntegerValue(demands[i])); + } + + model.Add(Cumulative(intervals, demand_exprs, capacity_expr)); + + int num_solutions_found = 0; + auto* integer_trail = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + std::vector solution(durations.size()); + for (int i = 0; i < intervals.size(); ++i) { + if (model.Get(Value(is_present_literals[i]))) { + solution[i] = + integer_trail->LowerBound(repository->Start(intervals[i])).value(); + } else { + solution[i] = -1; + } + } + num_solutions_found++; + LOG(INFO) << "Found solution: {" << absl::StrJoin(solution, ", ") << "}."; + + // Loop to the next solution. + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + // Test that we have the right number of solutions. + EXPECT_EQ(num_solutions_found, num_solutions); +} + +// This construct a reservoir corresponding to a well behaved parenthesis +// sequence. 
+TEST(ReservoirTest, FindAllParenthesis) { +  const int n = 3; +  const int size = 2 * n; + +  Model model; +  std::vector vars(size); +  std::vector times(size); +  std::vector deltas(size); +  for (int i = 0; i < size; ++i) { +    vars[i] = model.Add(NewIntegerVariable(0, size - 1)); +    times[i] = vars[i]; +    deltas[i] = IntegerValue((i % 2 == 1) ? -1 : 1); +  } +  const Literal true_lit = +      model.GetOrCreate()->GetTrueLiteral(); +  std::vector all_true(size, true_lit); + +  model.Add(AllDifferentOnBounds(vars)); +  AddReservoirConstraint(times, deltas, all_true, 0, size, &model); + +  absl::btree_map sequence_to_count; +  int num_solutions_found = 0; +  while (true) { +    const SatSolver::Status status = +        SolveIntegerProblemWithLazyEncoding(&model); +    if (status != SatSolver::Status::FEASIBLE) break; + +    // Add the solution. +    std::string parenthesis_sequence(size, ' '); +    for (int i = 0; i < size; ++i) { +      const int v = model.Get(Value(vars[i])); +      parenthesis_sequence[v] = (i % 2 == 0) ? '(' : ')'; +    } +    sequence_to_count[parenthesis_sequence]++; +    num_solutions_found++; + +    // Loop to the next solution. +    model.Add(ExcludeCurrentSolutionAndBacktrack()); +  } + +  // To help debug the code. +  for (const auto entry : sequence_to_count) { +    LOG(INFO) << entry.first << " : " << entry.second; +  } +  LOG(INFO) << "decisions: " << model.GetOrCreate()->num_branches(); +  LOG(INFO) << "conflicts: " << model.GetOrCreate()->num_failures(); + +  // Test that we have the right number of solutions. +  // +  // The Catalan number for n, which is 5 for n equal to three, counts the +  // number of well formed parenthesis sequences. But we have to multiply this +  // by the permutation for the open and closing parenthesis that are matched +  // to their positions: n!. +  EXPECT_EQ(num_solutions_found, 5 * 6 * 6); +} + +// Now some might be absent.
+TEST(ReservoirTest, FindAllParenthesisWithOptionality) { + const int n = 2; + const int size = 2 * n; + + Model model; + std::vector vars(size); + std::vector times(size); + std::vector deltas(size); + std::vector present(size); + for (int i = 0; i < size; ++i) { + vars[i] = model.Add(NewIntegerVariable(0, size - 1)); + times[i] = vars[i]; + deltas[i] = IntegerValue((i % 2 == 1) ? -1 : 1); + present[i] = Literal(model.Add(NewBooleanVariable()), true); + } + + model.Add(AllDifferentOnBounds(vars)); + AddReservoirConstraint(times, deltas, present, 0, size, &model); + + absl::btree_map sequence_to_count; + int num_solutions_found = 0; + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + std::string parenthesis_sequence(size, '_'); + for (int i = 0; i < size; ++i) { + if (model.Get(Value(present[i])) == 0) continue; + const int v = model.Get(Value(vars[i])); + parenthesis_sequence[v] = (i % 2 == 0) ? '(' : ')'; + } + sequence_to_count[parenthesis_sequence]++; + num_solutions_found++; + + // Loop to the next solution. + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + // To help debug the code. + for (const auto entry : sequence_to_count) { + LOG(INFO) << entry.first << " : " << entry.second; + } + LOG(INFO) << "decisions: " << model.GetOrCreate()->num_branches(); + LOG(INFO) << "conflicts: " << model.GetOrCreate()->num_failures(); + + // Test that we have the right number of solutions. + EXPECT_EQ(num_solutions_found, 184); +} + +// Enumerate all fixed sequence of [-1, +1] with a partial sum >= 0 and <= 1. 
+TEST(ReservoirTest, VariableLevelChange) { + Model model; + const int size = 8; + std::vector times(size); + std::vector deltas(size); + for (int i = 0; i < size; ++i) { + times[i] = IntegerValue(i); + deltas[i] = model.Add(NewIntegerVariable(-1, 1)); + } + const Literal true_lit = + model.GetOrCreate()->GetTrueLiteral(); + std::vector all_true(size, true_lit); + + const int min_level = 0; + const int max_level = 1; + AddReservoirConstraint(times, deltas, all_true, min_level, max_level, &model); + + absl::btree_map sequence_to_count; + int num_solutions_found = 0; + auto* integer_trail = model.GetOrCreate(); + while (true) { + const SatSolver::Status status = + SolveIntegerProblemWithLazyEncoding(&model); + if (status != SatSolver::Status::FEASIBLE) break; + + // Add the solution. + // Test that it is a valid one. + int sum = 0; + std::vector values; + for (int i = 0; i < size; ++i) { + values.push_back(integer_trail->LowerBound(deltas[i]).value()); + sum += values.back(); + EXPECT_GE(sum, min_level); + EXPECT_LE(sum, max_level); + } + sequence_to_count[absl::StrJoin(values, ",")]++; + num_solutions_found++; + + // Loop to the next solution. + model.Add(ExcludeCurrentSolutionAndBacktrack()); + } + + // To help debug the code. + for (const auto entry : sequence_to_count) { + LOG(INFO) << entry.first << " : " << entry.second; + } + LOG(INFO) << "decisions: " << model.GetOrCreate()->num_branches(); + LOG(INFO) << "conflicts: " << model.GetOrCreate()->num_failures(); + + // Test that we have the right number of solutions. + // For each subset of non-zero position, the value are fixed, it must + // be an alternating sequence starting at 1. 
+ EXPECT_EQ(num_solutions_found, 1 << size); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/zero_half_cuts_test.cc b/ortools/sat/zero_half_cuts_test.cc new file mode 100644 index 00000000000..6aea31cdcdc --- /dev/null +++ b/ortools/sat/zero_half_cuts_test.cc @@ -0,0 +1,114 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/zero_half_cuts.h" + +#include +#include + +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/lp_data/lp_types.h" +#include "ortools/sat/integer.h" + +namespace operations_research { +namespace sat { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +TEST(SymmetricDifferenceTest, BasicExample) { + ZeroHalfCutHelper helper; + std::vector a = {2, 1, 4}; + std::vector b = {4, 3, 2, 7}; + helper.Reset(10); + helper.SymmetricDifference(a, &b); + EXPECT_THAT(b, ElementsAre(3, 7, 1)); +} + +TEST(SymmetricDifferenceTest, BasicExample2) { + ZeroHalfCutHelper helper; + std::vector a = {2, 1, 4}; + std::vector b = {}; + helper.Reset(10); + helper.SymmetricDifference(a, &b); + EXPECT_THAT(b, ElementsAre(2, 1, 4)); +} + +TEST(EliminateVarUsingRowTest, BasicExample) { + // We need to construct a binary matrix for this test. 
+ ZeroHalfCutHelper helper; + helper.ProcessVariables({0.0, 0.0, 0.0, 0.0, 0.12, 0.0, 0.0, 0.0, 0.0}, + std::vector(9, IntegerValue(0)), + std::vector(9, IntegerValue(1))); + helper.AddBinaryRow({{{glop::RowIndex(1), IntegerValue(1)}}, + {0, 2, 3, 4, 7}, + /*rhs*/ 1, + /*slack*/ 0.1}); + helper.AddBinaryRow({{{glop::RowIndex(2), IntegerValue(1)}}, + {0, 2, 3, 4, 7}, + /*rhs*/ 0, + /*slack*/ 0.0}); + helper.AddBinaryRow({{{glop::RowIndex(1), IntegerValue(1)}, + {glop::RowIndex(3), IntegerValue(1)}}, + {0, 5, 4, 8}, + /*rhs*/ 1, + /*slack*/ 0.0}); + + typedef std::vector> MultiplierType; + typedef std::vector VectorType; + + // Let use row with index 2 to eliminate the variable 4. + helper.EliminateVarUsingRow(4, 2); + + // The multipliers, cols and parity behave like a xor. + EXPECT_EQ(helper.MatrixRow(0).multipliers, + MultiplierType({{glop::RowIndex(3), IntegerValue(1)}})); + EXPECT_EQ(helper.MatrixRow(0).cols, VectorType({2, 3, 7, 5, 8})); + EXPECT_EQ(helper.MatrixRow(0).rhs_parity, 0); + EXPECT_EQ(helper.MatrixRow(0).slack, 0.1); + + EXPECT_EQ(helper.MatrixRow(1).multipliers, + MultiplierType({{glop::RowIndex(1), IntegerValue(1)}, + {glop::RowIndex(2), IntegerValue(1)}, + {glop::RowIndex(3), IntegerValue(1)}})); + EXPECT_EQ(helper.MatrixRow(1).cols, VectorType({2, 3, 7, 5, 8})); + EXPECT_EQ(helper.MatrixRow(1).rhs_parity, 1); + EXPECT_EQ(helper.MatrixRow(1).slack, 0.0); + + // The column is eliminated like a singleton column and the lp value become + // the slack. + EXPECT_EQ(helper.MatrixRow(2).multipliers, + MultiplierType({{glop::RowIndex(1), IntegerValue(1)}, + {glop::RowIndex(3), IntegerValue(1)}})); + EXPECT_EQ(helper.MatrixRow(2).cols, VectorType({5, 8})); + EXPECT_EQ(helper.MatrixRow(2).rhs_parity, 1); + EXPECT_EQ(helper.MatrixRow(2).slack, 0.12); + + // The transposed information is up to date. 
+ EXPECT_THAT(helper.MatrixCol(0), IsEmpty()); + EXPECT_THAT(helper.MatrixCol(1), IsEmpty()); + EXPECT_THAT(helper.MatrixCol(2), UnorderedElementsAre(0, 1)); + EXPECT_THAT(helper.MatrixCol(3), UnorderedElementsAre(0, 1)); + EXPECT_THAT(helper.MatrixCol(4), IsEmpty()); + EXPECT_THAT(helper.MatrixCol(5), UnorderedElementsAre(0, 1, 2)); + EXPECT_THAT(helper.MatrixCol(6), IsEmpty()); + EXPECT_THAT(helper.MatrixCol(7), UnorderedElementsAre(0, 1)); + EXPECT_THAT(helper.MatrixCol(8), UnorderedElementsAre(0, 1, 2)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/util/BUILD.bazel b/ortools/util/BUILD.bazel index 38c67dd3ad6..420738a84b7 100644 --- a/ortools/util/BUILD.bazel +++ b/ortools/util/BUILD.bazel @@ -130,6 +130,7 @@ cc_library( ":saturated_arithmetic", "//ortools/base", "//ortools/base:dump_vars", + "//ortools/base:mathutil", "//ortools/base:types", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/strings", diff --git a/ortools/util/bitset.h b/ortools/util/bitset.h index 7f2f3a4932a..9b0636a328f 100644 --- a/ortools/util/bitset.h +++ b/ortools/util/bitset.h @@ -457,10 +457,6 @@ class Bitset64 { : size_(Value(size) > 0 ? size : IndexType(0)), data_(BitLength64(Value(size_))) {} - // This type is neither copyable nor movable. - Bitset64(const Bitset64&) = delete; - Bitset64& operator=(const Bitset64&) = delete; - ConstView const_view() const { return ConstView(this); } View view() { return View(this); } @@ -548,6 +544,7 @@ class Bitset64 { void Set(IndexType i) { DCHECK_GE(Value(i), 0); DCHECK_LT(Value(i), size_); + // The c++ hardening is costly here, so we disable it. data_[BitOffset64(Value(i))] |= OneBit64(BitPos64(Value(i))); } @@ -603,6 +600,19 @@ class Bitset64 { } } + // This one assume both given bitset to be of the same size. + void SetToIntersectionOf(const Bitset64& a, + const Bitset64& b) { + DCHECK_EQ(a.size(), b.size()); + Resize(a.size()); + + // Copy buckets. 
+ const int num_buckets = a.data_.size(); + for (int i = 0; i < num_buckets; ++i) { + data_[i] = a.data_[i] & b.data_[i]; + } + } + // Sets "this" to be the union of "this" and "other". The // bitsets do not have to be the same size. If other is smaller, all // the higher order bits are assumed to be 0. @@ -871,10 +881,14 @@ class SparseBitset { to_clear_.push_back(index); } } - void SetUnsafe(IntegerType index) { - bitset_.Set(index); + + // A bit hacky for really hot loop. + typename Bitset64::View BitsetView() { return bitset_.view(); } + void SetUnsafe(typename Bitset64::View view, IntegerType index) { + view.Set(index); to_clear_.push_back(index); } + void Clear(IntegerType index) { bitset_.Clear(index); } int NumberOfSetCallsWithDifferentArguments() const { return to_clear_.size(); diff --git a/ortools/util/fp_utils.h b/ortools/util/fp_utils.h index 34e0a0d8e70..88e962e1e1a 100644 --- a/ortools/util/fp_utils.h +++ b/ortools/util/fp_utils.h @@ -92,10 +92,16 @@ class ScopedFloatingPointEnv { fenv_.__control &= ~excepts; #elif (defined(__FreeBSD__) || defined(__OpenBSD__)) fenv_.__x87.__control &= ~excepts; +#elif defined(__NetBSD__) + fenv_.x87.control &= ~excepts; #else // Linux fenv_.__control_word &= ~excepts; #endif +#if defined(__NetBSD__) + fenv_.mxcsr &= ~(excepts << 7); +#else fenv_.__mxcsr &= ~(excepts << 7); +#endif CHECK_EQ(0, fesetenv(&fenv_)); #endif } diff --git a/ortools/util/integer_pq.h b/ortools/util/integer_pq.h index 7ae182f3724..1d50e4d7db4 100644 --- a/ortools/util/integer_pq.h +++ b/ortools/util/integer_pq.h @@ -129,8 +129,8 @@ class IntegerPriorityQueue { private: // Puts the given element at heap index i. - void Set(int i, Element element) { - heap_[i] = element; + void Set(Element* heap, int i, Element element) { + heap[i] = element; position_[element.Index()] = i; } @@ -139,44 +139,46 @@ class IntegerPriorityQueue { // this position. 
void SetAndDecreasePriority(int i, const Element element) { const int size = size_; + Element* heap = heap_.data(); while (true) { const int left = i * 2; const int right = left + 1; if (right > size) { if (left > size) break; - const Element left_element = heap_[left]; + const Element left_element = heap[left]; if (!less_(element, left_element)) break; - Set(i, left_element); + Set(heap, i, left_element); i = left; break; } - const Element left_element = heap_[left]; - const Element right_element = heap_[right]; + const Element left_element = heap[left]; + const Element right_element = heap[right]; if (less_(left_element, right_element)) { if (!less_(element, right_element)) break; - Set(i, right_element); + Set(heap, i, right_element); i = right; } else { if (!less_(element, left_element)) break; - Set(i, left_element); + Set(heap, i, left_element); i = left; } } - Set(i, element); + Set(heap, i, element); } // Puts the given element at heap index i and update the heap knowing that the // element has a priority >= than the priority of the element currently at // this position. 
void SetAndIncreasePriority(int i, const Element element) { + Element* heap = heap_.data(); while (i > 1) { const int parent = i >> 1; - const Element parent_element = heap_[parent]; + const Element parent_element = heap[parent]; if (!less_(parent_element, element)) break; - Set(i, parent_element); + Set(heap, i, parent_element); i = parent; } - Set(i, element); + Set(heap, i, element); } int size_; diff --git a/ortools/util/piecewise_linear_function.cc b/ortools/util/piecewise_linear_function.cc index e691f9ae4a8..23bc4eda239 100644 --- a/ortools/util/piecewise_linear_function.cc +++ b/ortools/util/piecewise_linear_function.cc @@ -14,15 +14,22 @@ #include "ortools/util/piecewise_linear_function.h" #include +#include #include -#include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/btree_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "ortools/base/dump_vars.h" #include "ortools/base/logging.h" +#include "ortools/base/mathutil.h" +#include "ortools/base/types.h" #include "ortools/util/saturated_arithmetic.h" namespace operations_research { @@ -78,6 +85,7 @@ uint64_t UnsignedCapProd(uint64_t left, uint64_t right) { } } // namespace +// PiecewiseSegment PiecewiseSegment::PiecewiseSegment(int64_t point_x, int64_t point_y, int64_t slope, int64_t other_point_x) : slope_(slope), reference_x_(point_x), reference_y_(point_y) { @@ -269,6 +277,7 @@ std::string PiecewiseSegment::DebugString() const { return result; } +// PiecewiseLinearFunction const int PiecewiseLinearFunction::kNotFound = -1; PiecewiseLinearFunction::PiecewiseLinearFunction( @@ -800,4 +809,60 @@ bool PiecewiseLinearFunction::IsNonIncreasingInternal() const { return true; } +// FloatSlopePiecewiseLinearFunction +const int FloatSlopePiecewiseLinearFunction::kNoValue = -1; + +FloatSlopePiecewiseLinearFunction::FloatSlopePiecewiseLinearFunction( + 
absl::InlinedVector x_anchors, + absl::InlinedVector y_anchors) + : x_anchors_(std::move(x_anchors)), y_anchors_(std::move(y_anchors)) { + DCHECK(absl::c_is_sorted(x_anchors_)); + DCHECK_EQ(x_anchors_.size(), y_anchors_.size()); + DCHECK_NE(x_anchors_.size(), 1); +} + +std::string FloatSlopePiecewiseLinearFunction::DebugString( + absl::string_view line_prefix) const { + if (x_anchors_.size() <= 10) { + return "{ " + DUMP_VARS(x_anchors_, y_anchors_).str() + "}"; + } + return absl::StrFormat("{\n%s%s\n%s%s\n}", line_prefix, + DUMP_VARS(x_anchors_).str(), line_prefix, + DUMP_VARS(y_anchors_).str()); +} + +int64_t FloatSlopePiecewiseLinearFunction::ComputeInBoundsValue( + int64_t x) const { + const int segment_index = GetSegmentIndex(x); + if (segment_index == kNoValue) return kNoValue; + return GetValueOnSegment(x, segment_index); +} + +int64_t FloatSlopePiecewiseLinearFunction::ComputeConvexValue(int64_t x) const { + if (x_anchors_.empty()) return kNoValue; + + int segment_index = kNoValue; + if (x <= x_anchors_[0]) { + segment_index = 0; + } else if (x >= x_anchors_.back()) { + segment_index = x_anchors_.size() - 2; + } else { + segment_index = GetSegmentIndex(x); + } + + return GetValueOnSegment(x, segment_index); +} + +int64_t FloatSlopePiecewiseLinearFunction::GetValueOnSegment( + int64_t x, int segment_index) const { + DCHECK_GE(segment_index, 0); + DCHECK_LE(segment_index, x_anchors_.size() - 2); + const double slope = + static_cast(y_anchors_[segment_index + 1] - + y_anchors_[segment_index]) / + (x_anchors_[segment_index + 1] - x_anchors_[segment_index]); + return MathUtil::Round(slope * (x - x_anchors_[segment_index]) + + y_anchors_[segment_index]); +} + } // namespace operations_research diff --git a/ortools/util/piecewise_linear_function.h b/ortools/util/piecewise_linear_function.h index 26af5778bf4..b076baddd82 100644 --- a/ortools/util/piecewise_linear_function.h +++ b/ortools/util/piecewise_linear_function.h @@ -22,10 +22,16 @@ #include #include 
+#include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" + namespace operations_research { // This structure stores one straight line. It contains the start point, the // end point and the slope. @@ -268,5 +274,86 @@ class PiecewiseLinearFunction { bool is_non_decreasing_; bool is_non_increasing_; }; + +// The following class defines a piecewise linear formulation with potential +// double values for the slope of each linear function. +// This formulation is meant to be used with a small number of segments (see +// InlinedVector sizes below). +// These segments are determined by int64_t values for the "anchor" x and y +// values, such that (x_anchors_[i], y_anchors_[i]) and +// (x_anchors_[i+1], y_anchors_[i+1]) are respectively the start and end point +// of the i-th segment. +// TODO(user): Adjust the inlined vector sizes based on experiments. +class FloatSlopePiecewiseLinearFunction { + public: + static const int kNoValue; + + FloatSlopePiecewiseLinearFunction() = default; + FloatSlopePiecewiseLinearFunction(absl::InlinedVector x_anchors, + absl::InlinedVector y_anchors); + FloatSlopePiecewiseLinearFunction( + FloatSlopePiecewiseLinearFunction&& other) noexcept { + *this = std::move(other); + } + + FloatSlopePiecewiseLinearFunction& operator=( + FloatSlopePiecewiseLinearFunction&& other) noexcept { + x_anchors_ = std::move(other.x_anchors_); + y_anchors_ = std::move(other.y_anchors_); + return *this; + } + + std::string DebugString(absl::string_view line_prefix = {}) const; + + const absl::InlinedVector& x_anchors() const { + return x_anchors_; + } + const absl::InlinedVector& y_anchors() const { + return y_anchors_; + } + + // Computes the y value associated to 'x'. Returns kNoValue if 'x' is out of + // bounds, i.e. lower than the first x_anchor and largest than the last. 
+ int64_t ComputeInBoundsValue(int64_t x) const; + + // Computes the y value associated to 'x'. Unlike ComputeInBoundsValue(), if + // 'x' is outside the bounds of the function, the function will still be + // defined by its outer segments. + int64_t ComputeConvexValue(int64_t x) const; + + private: + // Returns the index of the segment x belongs to, i.e. the index i such that + // x_anchors_[i] ≤ x < x_anchors_[i+1]. For x = x_anchors_.back(), also + // returns the last segment (i.e. x_anchors_.size() - 2). + // Returns kNoValue if x is out of bounds for the function. + int GetSegmentIndex(int64_t x) const { + if (x_anchors_.empty() || x < x_anchors_[0] || x > x_anchors_.back()) { + return kNoValue; + } + if (x == x_anchors_.back()) return x_anchors_.size() - 2; + + // Search for first element xi such that xi > x. + const auto upper_segment = absl::c_upper_bound(x_anchors_, x); + const int segment_index = + std::distance(x_anchors_.begin(), upper_segment) - 1; + DCHECK_GE(segment_index, 0); + DCHECK_LE(segment_index, x_anchors_.size() - 2); + return segment_index; + } + + // Returns the value of 'x' on the linear segment determined by + // x_anchors_[segment_index] and x_anchors_[segment_index + 1]. + int64_t GetValueOnSegment(int64_t x, int segment_index) const; + + // The set of *increasing* anchor cumul values for the interpolation. + absl::InlinedVector x_anchors_; + // The y values used for the interpolation: + // For any x anchor value, let i be an index such that + // x_anchors[i] ≤ x < x_anchors[i+1], then the y value for x is + // y_anchors[i] * (1-λ) + y_anchors[i+1] * λ, with + // λ = (x - x_anchors[i]) / (x_anchors[i+1] - x_anchors[i]). 
+ absl::InlinedVector y_anchors_; +}; + } // namespace operations_research #endif // OR_TOOLS_UTIL_PIECEWISE_LINEAR_FUNCTION_H_ diff --git a/ortools/util/range_minimum_query.h b/ortools/util/range_minimum_query.h index 7b29576f974..8ea3925c0ac 100644 --- a/ortools/util/range_minimum_query.h +++ b/ortools/util/range_minimum_query.h @@ -11,24 +11,64 @@ // See the License for the specific language governing permissions and // limitations under the License. -// We use the notation min(arr, i, j) for the minimum arr[x] such that i <= x -// and x < j. -// Range Minimum Query (RMQ) is a data structure preprocessing an array arr so -// that querying min(arr, i, j) takes O(1) time. The preprocessing takes -// O(n*log(n)) time and memory. - -// Note: There exists an O(n) preprocessing algorithm, but it is considerably -// more involved and the hidden constants behind it are much higher. +// The range minimum query problem is a range query problem where queries ask +// for the minimum of all elements in ranges of the array. +// The problem is divided into two phases: +// - precomputation: the data structure is given an array A of n elements. +// - query: the data structure must answer queries min(A, begin, end), +// where min(A, begin, end) = min_{i in [begin, end)} A[i]. +// This file has an implementation of the sparse table approach to solving the +// problem, for which the precomputation takes O(n*log(n)) time and memory, +// and further queries take O(1) time. +// Reference: https://en.wikipedia.org/wiki/Range_minimum_query. // -// The algorithms are well explained in Wikipedia: -// https://en.wikipedia.org/wiki/Range_minimum_query. +// The data structure allows to have multiple arrays at the same time, and +// to reset the arrays. // +// Usage, single range: +// RangeMinimumQuery rmq({10, 100, 30, 300, 70}); +// rmq.GetMinimumFromRange(0, 5); // Returns 10. +// rmq.GetMinimumFromRange(2, 4); // Returns 30. 
// -// Implementation: The idea is to cache every min(arr, i, j) where j - i is a -// power of two, i.e. j = i + 2^k for some k. Provided this information, we can -// answer all queries in O(1): given a pair (i, j) find the maximum k such that -// i + 2^k < j and note that -// std::min(min(arr, i, i+2^k), min(arr, j-2^k, j)) = min(arr, i, j). +// Usage, multiple ranges: +// RangeMinimumQuery rmq({10, 100, 30, 300, 70}); +// rmq.GetMinimumFromRange(0, 5); // Returns 10. +// rmq.GetMinimumFromRange(2, 4); // Returns 30. +// +// // We add another array {-3, 10, 5, 2, 15, 3}. +// const int begin2 = rmq.TablesSize(); +// for (const int element : {-3, 10, 5, 2, 15, 3}) { +// rmq.PushBack(element); +// } +// rmq.MakeSparseTableFromNewElements(); +// rmq.GetMinimumFromRange(begin2 + 0, begin2 + 5); // Returns -3. +// rmq.GetMinimumFromRange(begin2 + 2, begin2 + 4); // Returns 2. +// rmq.GetMinimumFromRange(begin2 + 4, begin2 + 6); // Returns 3. +// // The previous array can still be queried. +// rmq.GetMinimumFromRange(1, 3); // Returns 30. +// +// // Forbidden, query ranges can only be within the same array. +// rmq.GetMinimumFromRange(3, 9); // Undefined. +// +// rmq.Clear(); +// // All arrays have been removed, so no range query can be made. +// rmq.GetMinimumFromRange(0, 5); // Undefined. +// +// // Add a new range. +// for (const int element : {0, 3, 2}) { +// rmq.PushBack(element); +// } +// rmq.MakeSparseTableFromNewElements(); +// // Queries on the new array can be made. +// +// Note: There are other space/time tradeoffs for this problem, but they are +// generally worse in terms of the constants in the O(1) query time, moreover +// their implementation is generally more involved. +// +// Implementation: The idea is to cache every min(A, i, i+2^k). +// Provided this information, we can answer all queries in O(1): given a pair +// (i, j), first find the maximum k such that i + 2^k < j, then use +// min(A, i, j) = std::min(min(A, i, i+2^k), min(A, j-2^k, j)). 
#ifndef OR_TOOLS_UTIL_RANGE_MINIMUM_QUERY_H_ #define OR_TOOLS_UTIL_RANGE_MINIMUM_QUERY_H_ @@ -39,12 +79,18 @@ #include #include +#include "absl/log/check.h" #include "ortools/util/bitset.h" namespace operations_research { template > class RangeMinimumQuery { public: + RangeMinimumQuery() { + // This class uses the first two rows of cache_ to know the number of new + // elements, which at any moment is cache_[1].size() - cache_[0].size(). + cache_.resize(2); + }; explicit RangeMinimumQuery(std::vector array); RangeMinimumQuery(std::vector array, Compare cmp); @@ -53,13 +99,37 @@ class RangeMinimumQuery { RangeMinimumQuery& operator=(const RangeMinimumQuery&) = delete; // Returns the minimum (w.r.t. Compare) arr[x], where x is contained in - // [from, to). - T GetMinimumFromRange(int from, int to) const; + // [begin_index, end_index). + // The range [begin_index, end_index) can only cover elements that were new + // at the same call to MakeTableFromNewElements(). + // When calling this method, there must be no pending new elements, i.e. the + // last method called apart from TableSize() must not have been PushBack(). + T RangeMinimum(int begin, int end) const; + + void PushBack(T element) { cache_[0].push_back(element); } + // Generates the sparse table for all new elements, i.e. elements that were + // added with PushBack() since the latest of these events: construction of + // this object, a previous call to MakeTableFromNewElements(), or a call to + // Clear(). + // The range of new elements [begin, end), with begin the Size() at the + // latest event, and end the current Size(). + void MakeTableFromNewElements(); + + // Returns the number of elements in sparse tables, excluding new elements. + int TableSize() const { return cache_[1].size(); } + + // Clears all tables. This invalidates all further range queries on currently + // existing tables. This does *not* release memory held by this object. 
+ void Clear() { + for (auto& row : cache_) row.clear(); + } + + // Returns the concatenated sequence of all elements in all arrays. const std::vector& array() const; private: - // cache_[k][i] = min(arr, i, i+2^k). + // cache_[k][i] = min_{j in [i, i+2^k)} arr[j]. std::vector> cache_; Compare cmp_; }; @@ -76,9 +146,9 @@ class RangeMinimumIndexQuery { RangeMinimumIndexQuery(const RangeMinimumIndexQuery&) = delete; RangeMinimumIndexQuery& operator=(const RangeMinimumIndexQuery&) = delete; - // Returns an index idx from [from, to) such that arr[idx] is the minimum - // value of arr over the interval [from, to). - int GetMinimumIndexFromRange(int from, int to) const; + // Returns an index idx from [begin, end) such that arr[idx] is the minimum + // value of arr over the interval [begin, end). + int GetMinimumIndexFromRange(int begin, int end) const; // Returns the original array. const std::vector& array() const; @@ -99,39 +169,55 @@ template inline RangeMinimumQuery::RangeMinimumQuery(std::vector array) : RangeMinimumQuery(std::move(array), Compare()) {} -// Reminder: The task is to fill cache_ so that -// cache_[k][i] = min(arr, i, i+2^k) for every k <= Log2(n) and i <= n-2^k. -// Note that cache_[k+1][i] = min(cache_[k][i], cache_[k][i+2^k]), hence every -// row can be efficiently computed from the previous. template RangeMinimumQuery::RangeMinimumQuery(std::vector array, Compare cmp) - : cache_(MostSignificantBitPosition32(array.size()) + 1), - cmp_(std::move(cmp)) { - const int array_size = array.size(); + : cache_(2), cmp_(std::move(cmp)) { + // This class uses the first two rows of cache_ to know the number of new + // elements. 
cache_[0] = std::move(array); - for (int row_idx = 1; row_idx < cache_.size(); ++row_idx) { - const int row_length = array_size - (1 << row_idx) + 1; - const int window = 1 << (row_idx - 1); - cache_[row_idx].resize(row_length); - for (int col_idx = 0; col_idx < row_length; ++col_idx) { - cache_[row_idx][col_idx] = - std::min(cache_[row_idx - 1][col_idx], - cache_[row_idx - 1][col_idx + window], cmp_); - } - } + MakeTableFromNewElements(); +} + +template +inline T RangeMinimumQuery::RangeMinimum(int begin, int end) const { + DCHECK_LE(0, begin); + DCHECK_LT(begin, end); + DCHECK_LE(end, cache_[1].size()); + DCHECK_EQ(cache_[0].size(), cache_[1].size()); + const int layer = MostSignificantBitPosition32(end - begin); + DCHECK_LT(layer, cache_.size()); + const int window = 1 << layer; + const T* row = cache_[layer].data(); + DCHECK_LE(end - window, cache_[layer].size()); + return std::min(row[begin], row[end - window], cmp_); } +// Reminder: The task is to fill cache_ so that for i in [begin, end), +// cache_[k][i] = min(arr, i, i+2^k) for every k <= Log2(n) and i <= n-2^k. +// Note that cache_[k+1][i] = min(cache_[k][i], cache_[k][i+2^k]), hence every +// row can be efficiently computed from the previous. template -inline T RangeMinimumQuery::GetMinimumFromRange(int from, - int to) const { - DCHECK_LE(0, from); - DCHECK_LT(from, to); - DCHECK_LE(to, array().size()); - const int log_diff = MostSignificantBitPosition32(to - from); - const int window = 1 << log_diff; - const std::vector& row = cache_[log_diff]; - return std::min(row[from], row[to - window], cmp_); +void RangeMinimumQuery::MakeTableFromNewElements() { + const int new_size = cache_[0].size(); + const int old_size = cache_[1].size(); + if (old_size >= new_size) return; + // This is the minimum number of rows needed to store the sequence of + // new elements, there may be more rows in the cache. 
+ const int num_rows = 1 + MostSignificantBitPosition32(new_size - old_size); + if (cache_.size() < num_rows) cache_.resize(num_rows); + // Record the new number of elements, wastes just size(T) space. + cache_[1].resize(new_size); + + for (int row = 1; row < num_rows; ++row) { + const int half_window = 1 << (row - 1); + const int last_col = new_size - 2 * half_window; + if (cache_[row].size() <= last_col) cache_[row].resize(last_col + 1); + for (int col = old_size; col <= last_col; ++col) { + cache_[row][col] = std::min(cache_[row - 1][col], + cache_[row - 1][col + half_window], cmp_); + } + } } template @@ -153,8 +239,9 @@ RangeMinimumIndexQuery::RangeMinimumIndexQuery(std::vector array, template inline int RangeMinimumIndexQuery::GetMinimumIndexFromRange( - int from, int to) const { - return rmq_.GetMinimumFromRange(from, to); + int begin, int end) const { + DCHECK_LT(begin, end); + return rmq_.RangeMinimum(begin, end); } template diff --git a/ortools/util/range_query_function.cc b/ortools/util/range_query_function.cc index 2a6266e09b1..416b40dfe2d 100644 --- a/ortools/util/range_query_function.cc +++ b/ortools/util/range_query_function.cc @@ -133,15 +133,13 @@ class CachedRangeIntToIntFunction : public RangeIntToIntFunction { DCHECK_LE(domain_start_, from); DCHECK_LT(from, to); DCHECK_LE(to, domain_start_ + static_cast(array().size())); - return rmq_min_.GetMinimumFromRange(from - domain_start_, - to - domain_start_); + return rmq_min_.RangeMinimum(from - domain_start_, to - domain_start_); } int64_t RangeMax(int64_t from, int64_t to) const override { DCHECK_LE(domain_start_, from); DCHECK_LT(from, to); DCHECK_LE(to, domain_start_ + static_cast(array().size())); - return rmq_max_.GetMinimumFromRange(from - domain_start_, - to - domain_start_); + return rmq_max_.RangeMinimum(from - domain_start_, to - domain_start_); } int64_t RangeFirstInsideInterval(int64_t range_begin, int64_t range_end, int64_t interval_begin, diff --git a/ortools/util/range_query_function.h 
b/ortools/util/range_query_function.h index c75fdad54b7..e7451259827 100644 --- a/ortools/util/range_query_function.h +++ b/ortools/util/range_query_function.h @@ -17,10 +17,8 @@ #ifndef OR_TOOLS_UTIL_RANGE_QUERY_FUNCTION_H_ #define OR_TOOLS_UTIL_RANGE_QUERY_FUNCTION_H_ +#include #include -#include - -#include "ortools/base/types.h" namespace operations_research { // RangeIntToIntFunction is an interface to int64_t->int64_t functions diff --git a/ortools/util/saturated_arithmetic.h b/ortools/util/saturated_arithmetic.h index 321426143d8..00bf7b92fd7 100644 --- a/ortools/util/saturated_arithmetic.h +++ b/ortools/util/saturated_arithmetic.h @@ -299,6 +299,19 @@ inline int64_t CapAdd(int64_t x, int64_t y) { #endif } +// This avoid the need to convert to int64_t min/max and is about twice as fast +// if it corresponds to your use case. +inline bool AddIntoOverflow(int64_t x, int64_t* y) { +#if defined(__clang__) + return __builtin_add_overflow(x, *y, y); +#else + const int64_t result = TwosComplementAddition(x, *y); + if (AddHadOverflow(x, *y, result)) return true; + *y = result; + return false; +#endif +} + inline void CapAddTo(int64_t x, int64_t* y) { *y = CapAdd(*y, x); } inline int64_t CapSub(int64_t x, int64_t y) { diff --git a/ortools/util/strong_integers.h b/ortools/util/strong_integers.h index 6153487a9c3..27d436158b7 100644 --- a/ortools/util/strong_integers.h +++ b/ortools/util/strong_integers.h @@ -216,6 +216,7 @@ class StrongInt64 { } constexpr int64_t value() const { return value_; } + int64_t* mutable_value() { return &value_; } template // Needed for StrongVector. 
constexpr ValType value() const { diff --git a/ortools/util/zvector.h b/ortools/util/zvector.h index ec21520c70b..cf0cf9322bb 100644 --- a/ortools/util/zvector.h +++ b/ortools/util/zvector.h @@ -14,7 +14,8 @@ #ifndef OR_TOOLS_UTIL_ZVECTOR_H_ #define OR_TOOLS_UTIL_ZVECTOR_H_ -#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__)) && \ +#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || \ + defined(__OpenBSD__)) && \ defined(__GNUC__) #include #elif !defined(_MSC_VER) && !defined(__MINGW32__) && !defined(__MINGW64__) diff --git a/patches/pybind11_protobuf.patch b/patches/pybind11_protobuf.patch index d16952621c6..6d2388cd8e6 100644 --- a/patches/pybind11_protobuf.patch +++ b/patches/pybind11_protobuf.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 2139dc0..1942ad0 100644 +index 2139dc0..df3f30a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -27,7 +27,7 @@ include(CTest) +@@ -27,58 +27,64 @@ include(CTest) # ============================================================================ # Find Python @@ -11,7 +11,106 @@ index 2139dc0..1942ad0 100644 # ============================================================================ # Build dependencies -@@ -87,8 +87,10 @@ pybind11_add_module( ++include(FetchContent) + +-if(USE_SYSTEM_ABSEIL) +- # Version omitted, as absl only allows EXACT version matches +- set(_absl_package_args REQUIRED) +-else() +- set(_absl_package_args 20230125) +-endif() +-if(USE_SYSTEM_PROTOBUF) +- set(_protobuf_package_args 4.23.3 REQUIRED) +-else() +- set(_protobuf_package_args 4.23.3) +-endif() +-if(USE_SYSTEM_PYBIND) +- set(_pybind11_package_args 2.11.1 REQUIRED) +-else() +- set(_pybind11_package_args 2.11.1) ++message(CHECK_START "Checking for external dependencies") ++list(APPEND CMAKE_MESSAGE_INDENT " ") ++ ++if(NOT TARGET absl::base) ++ if(USE_SYSTEM_ABSEIL) ++ # Version omitted, as absl only allows EXACT version matches ++ set(_absl_package_args REQUIRED) ++ else() ++ 
set(_absl_package_args 20230125) ++ endif() ++ FetchContent_Declare( ++ absl ++ GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git" ++ GIT_TAG 20230125.3 ++ FIND_PACKAGE_ARGS ${_absl_package_args} NAMES absl) ++ set(ABSL_PROPAGATE_CXX_STD ON) ++ set(ABSL_ENABLE_INSTALL ON) ++ FetchContent_MakeAvailable(absl) + endif() + +-set(ABSL_PROPAGATE_CXX_STD ON) +-set(ABSL_ENABLE_INSTALL ON) ++if(NOT TARGET protobuf::libprotobuf) ++ if(USE_SYSTEM_PROTOBUF) ++ set(_protobuf_package_args 4.23.3 REQUIRED) ++ else() ++ set(_protobuf_package_args 4.23.3) ++ endif() ++ FetchContent_Declare( ++ Protobuf ++ GIT_REPOSITORY "https://github.com/protocolbuffers/protobuf.git" ++ GIT_TAG v23.3 ++ GIT_SUBMODULES "" ++ FIND_PACKAGE_ARGS ${_protobuf_package_args} NAMES protobuf) ++ set(protobuf_BUILD_TESTS OFF CACHE INTERNAL "") ++ FetchContent_MakeAvailable(Protobuf) ++endif() + +-include(FetchContent) +-FetchContent_Declare( +- absl +- GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git" +- GIT_TAG 20230125.3 +- FIND_PACKAGE_ARGS ${_absl_package_args} NAMES absl) +- +-# cmake-format: off +-FetchContent_Declare( +- Protobuf +- GIT_REPOSITORY "https://github.com/protocolbuffers/protobuf.git" +- GIT_TAG v23.3 +- GIT_SUBMODULES "" +- FIND_PACKAGE_ARGS ${_protobuf_package_args} NAMES protobuf) +-set(protobuf_BUILD_TESTS OFF CACHE INTERNAL "") +-# cmake-format: on +- +-FetchContent_Declare( +- pybind11 +- GIT_REPOSITORY "https://github.com/pybind/pybind11.git" +- GIT_TAG v2.11.1 +- FIND_PACKAGE_ARGS ${_pybind11_package_args} NAMES pybind11) ++if(NOT TARGET pybind11::pybind11_headers) ++ if(USE_SYSTEM_PYBIND) ++ set(_pybind11_package_args 2.11.1 REQUIRED) ++ else() ++ set(_pybind11_package_args 2.11.1) ++ endif() ++ FetchContent_Declare( ++ pybind11 ++ GIT_REPOSITORY "https://github.com/pybind/pybind11.git" ++ GIT_TAG v2.11.1 ++ FIND_PACKAGE_ARGS ${_pybind11_package_args} NAMES pybind11) ++ FetchContent_MakeAvailable(pybind11) ++endif() + +-message(CHECK_START "Checking for external 
dependencies") +-list(APPEND CMAKE_MESSAGE_INDENT " ") +-FetchContent_MakeAvailable(absl Protobuf pybind11) + list(POP_BACK CMAKE_MESSAGE_INDENT) ++message(CHECK_PASS "found") + + # ============================================================================ + # pybind11_proto_utils pybind11 extension module +@@ -87,8 +93,10 @@ pybind11_add_module( pybind11_protobuf/proto_utils.h) target_link_libraries( @@ -24,7 +123,7 @@ index 2139dc0..1942ad0 100644 target_include_directories( pybind11_proto_utils PRIVATE ${PROJECT_SOURCE_DIR} ${protobuf_INCLUDE_DIRS} -@@ -116,10 +118,11 @@ target_link_libraries( +@@ -116,10 +124,11 @@ target_link_libraries( absl::optional protobuf::libprotobuf pybind11::pybind11 @@ -37,7 +136,7 @@ index 2139dc0..1942ad0 100644 PRIVATE ${PROJECT_SOURCE_DIR} ${protobuf_INCLUDE_DIRS} ${protobuf_SOURCE_DIR} ${pybind11_INCLUDE_DIRS}) -@@ -143,7 +146,7 @@ target_link_libraries( +@@ -143,7 +152,7 @@ target_link_libraries( absl::optional protobuf::libprotobuf pybind11::pybind11