Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Dall-e-3 #9

Merged
merged 10 commits into from
Nov 7, 2023
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ If you use this library, you must conform to Open AI's [Usage Policies](https://

## Other Language Bindings

For another great Go implementation, see [sashabaranov/go-gpt3](https://github.com/sashabaranov/go-gpt3).
For another great Go implementation, see [sashabaranov/go-openai](https://github.com/sashabaranov/go-openai).
For other languages, see [Open AI's Website](https://beta.openai.com/docs/libraries/libraries).

## Contributing
Expand Down
17 changes: 12 additions & 5 deletions examples/images/images-example.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ func init() {
authentication.SetAPIKey(key)
}

func create() (*images.Response, error) {
func create(model, size string) (*images.Response, error) {
const prompt = "A cute baby sea otter"

fmt.Printf("Creating from prompt: %s\n", prompt)
fmt.Printf("Creating from model=\"%s\", prompt=\"%s\"\n", model, prompt)
resp, _, err := images.MakeModeratedCreationRequest(&images.CreationRequest{
Prompt: prompt,
Size: images.SmallImage,
Size: size,
User: "https://github.com/TannerKvarfordt/gopenai",
Model: model,
}, nil)
if err != nil {
return nil, err
Expand All @@ -39,7 +40,7 @@ func variation(imagename, image string) error {
resp, err := images.MakeVariationRequest(&images.VariationRequest{
Image: image,
ImageName: imagename,
Size: images.SmallImage,
Size: images.Dalle2SmallImage,
User: "https://github.com/TannerKvarfordt/gopenai",
}, nil)
if err != nil {
Expand All @@ -51,7 +52,7 @@ func variation(imagename, image string) error {
}

func main() {
resp, err := create()
resp, err := create(images.ModelDalle2, images.Dalle2SmallImage)
if err != nil {
fmt.Println(err)
return
Expand All @@ -62,4 +63,10 @@ func main() {
fmt.Println(err)
return
}

_, err = create(images.ModelDalle3, images.Dalle3SquareImage)
if err != nil {
fmt.Println(err)
return
}
}
91 changes: 79 additions & 12 deletions images/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,85 @@ const (
)

const (
SmallImage string = "256x256"
MediumImage string = "512x512"
LargeImage string = "1024x1024"
Dalle2SmallImage = "256x256"
Dalle2MediumImage = "512x512"
Dalle2LargeImage = "1024x1024"

Dalle3SquareImage = "1024x1024"
Dalle3LandscapeImage = "1792x1024"
Dalle3PortraitImage = "1024x1792"

// Deprecated: Use Dalle2SmallImage instead.
SmallImage = Dalle2SmallImage
// Deprecated: Use Dalle2MediumImage instead.
MediumImage = Dalle2MediumImage
// Deprecated: Use Dalle2LargeImage instead.
LargeImage = Dalle2LargeImage
)

const (
ResponseFormatURL = "url"
ResponseFormatB64JSON = "b64_json"
)

const (
ModelDalle2 = "dall-e-2"
ModelDalle3 = "dall-e-3"
)

const (
QualityStandard = "standard"
QualityHD = "hd"
)

const (
StyleVivid = "vivid"
StyleNatural = "natural"
)

// Response structure for the image API endpoint.
type Response struct {
Created uint64 `json:"created"`
Data []struct {
URL string `json:"url"`
B64JSON string `json:"b64_json"`
URL string `json:"url"`
B64JSON string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}
Error *common.ResponseError `json:"error,omitempty"`
}

// Request structure for the image creation API endpoint.
type CreationRequest struct {
// A text description of the desired image(s). The maximum length is 1000 characters.
// A text description of the desired image(s).
// The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3.
Prompt string `json:"prompt,omitempty"`

// The model to use for image generation.
Model string `json:"model,omitempty"`

// The number of images to generate. Must be between 1 and 10.
// For dall-e-3, only n=1 is supported.
N *uint64 `json:"n,omitempty"`

// The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024.
Size string `json:"size,omitempty"`
// The quality of the image that will be generated.
// "hd" creates images with finer details and greater consistency across the image.
// This param is only supported for dall-e-3.
Quality string `json:"quality,omitempty"`

// The format in which the generated images are returned. Must be one of url or b64_json.
ResponseFormat string `json:"response_format,omitempty"`

// The size of the generated images.
// Must be one of 256x256, 512x512, or 1024x1024 for dall-e-2.
// Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
Size string `json:"size,omitempty"`

// The style of the generated images. Must be one of vivid or natural.
// Vivid causes the model to lean towards generating hyper-real and dramatic images.
// Natural causes the model to produce more natural, less hyper-real looking images.
// This param is only supported for dall-e-3.
Style string `json:"style,omitempty"`

// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
User string `json:"user,omitempty"`
}
Expand Down Expand Up @@ -111,6 +156,9 @@ type EditRequest struct {
// any path information.
ImageName string `json:"-"`

// A text description of the desired image(s). The maximum length is 1000 characters.
Prompt string `json:"prompt,omitempty"`

// An additional image whose fully transparent areas (e.g. where alpha is zero)
// indicate where image should be edited. Must be a valid PNG file, less than 4MB,
// and have the same dimensions as image.
Expand All @@ -120,8 +168,8 @@ type EditRequest struct {
// path information.
MaskName string `json:"-"`

// A text description of the desired image(s). The maximum length is 1000 characters.
Prompt string `json:"prompt,omitempty"`
// The model to use for image generation. Only dall-e-2 is supported at this time.
Model string `json:"model,omitempty"`

// The number of images to generate. Must be between 1 and 10.
N *uint64 `json:"n,omitempty"`
Expand All @@ -145,14 +193,15 @@ func MakeEditRequest(request *EditRequest, organizationID *string) (*Response, e
buf := new(bytes.Buffer)
writer := multipart.NewWriter(buf)

var err error

if len(request.Prompt) > 0 {
err := common.CreateFormField("prompt", request.Prompt, writer)
err = common.CreateFormField("prompt", request.Prompt, writer)
if err != nil {
return nil, err
}
}

var err error
if request.N != nil {
err = common.CreateFormField("n", request.N, writer)
if err != nil {
Expand Down Expand Up @@ -181,6 +230,13 @@ func MakeEditRequest(request *EditRequest, organizationID *string) (*Response, e
}
}

if len(request.Model) > 0 {
err = common.CreateFormField("model", request.Model, writer)
if err != nil {
return nil, err
}
}

if len(request.Image) > 0 {
err = common.CreateFormFile("image", request.ImageName, request.Image, writer)
if err != nil {
Expand Down Expand Up @@ -240,6 +296,9 @@ type VariationRequest struct {
// any path information.
ImageName string `json:"-"`

// The model to use for image generation. Only dall-e-2 is supported at this time.
Model string `json:"model,omitempty"`

// The number of images to generate. Must be between 1 and 10.
N *uint64 `json:"n,omitempty"`

Expand All @@ -263,6 +322,7 @@ func MakeVariationRequest(request *VariationRequest, organizationID *string) (*R
writer := multipart.NewWriter(buf)

var err error

if request.N != nil {
err = common.CreateFormField("n", request.N, writer)
if err != nil {
Expand Down Expand Up @@ -298,6 +358,13 @@ func MakeVariationRequest(request *VariationRequest, organizationID *string) (*R
}
}

if len(request.Model) > 0 {
err = common.CreateFormField("model", request.Model, writer)
if err != nil {
return nil, err
}
}

writer.Close()
r, err := common.MakeRequestWithForm[Response](buf, VariationEndpoint, http.MethodPost, writer.FormDataContentType(), organizationID)
if err != nil {
Expand Down
23 changes: 16 additions & 7 deletions images/images_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ func init() {
authentication.SetAPIKey(key)
}

func create() (*images.Response, error) {
func create(model, size string) (*images.Response, error) {
const prompt = "A cute baby sea otter"

fmt.Printf("Creating from prompt: %s\n", prompt)
resp, err := images.MakeCreationRequest(&images.CreationRequest{
Prompt: prompt,
Size: images.SmallImage,
Size: size,
User: "https://github.com/TannerKvarfordt/gopenai",
Model: model,
}, nil)
if err != nil {
return nil, err
Expand All @@ -38,14 +39,15 @@ func create() (*images.Response, error) {
return resp, nil
}

func variation(imagename, image string) error {
func variation(model, imagename, image string) error {

fmt.Printf("Generating a variation...")
resp, err := images.MakeVariationRequest(&images.VariationRequest{
Image: image,
ImageName: imagename,
Size: images.SmallImage,
Size: images.Dalle2SmallImage,
User: "https://github.com/TannerKvarfordt/gopenai",
Model: model,
}, nil)
if err != nil {
return err
Expand All @@ -58,13 +60,20 @@ func variation(imagename, image string) error {
return nil
}

func TestImages(t *testing.T) {
resp, err := create()
func TestImagesDalle2(t *testing.T) {
resp, err := create(images.ModelDalle2, images.Dalle2SmallImage)
if err != nil {
t.Fatal(err)
}

err = variation("Original", resp.Data[0].URL)
err = variation(images.ModelDalle2, "Original", resp.Data[0].URL)
if err != nil {
t.Fatal(err)
}
}

func TestImagesDalle3(t *testing.T) {
_, err := create(images.ModelDalle3, images.Dalle3SquareImage)
if err != nil {
t.Fatal(err)
}
Expand Down
5 changes: 5 additions & 0 deletions tools/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ pushd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null
./examples-build.sh

pushd .. >/dev/null
echo "Formatting code..."
go fmt ./...
echo "Running tests..."
go test ./...
echo "Building $(basename "$(pwd)")..."
go build ./...
echo "Vetting..."
go vet ./...
echo "Done."