Skip to content

Make it possible to place models in a different package to prevent circular package references for go-server #11623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class GoServerCodegen extends AbstractGoCodegen {
protected int serverPort = 8080;
protected String projectName = "openapi-server";
protected String sourceFolder = "go";
protected String supportFolder = "support";
protected Boolean corsFeatureEnabled = false;
protected Boolean addResponseHeaders = false;

Expand Down Expand Up @@ -230,7 +231,9 @@ public void processOpts() {
}
additionalProperties.put("routers", routers);

modelPackage = packageName;
if (!additionalProperties.containsKey(CodegenConstants.MODEL_PACKAGE)) {
modelPackage = packageName;
}
apiPackage = packageName;

/*
Expand All @@ -242,12 +245,12 @@ public void processOpts() {
supportingFiles.add(new SupportingFile("main.mustache", "", "main.go"));
supportingFiles.add(new SupportingFile("Dockerfile.mustache", "", "Dockerfile"));
supportingFiles.add(new SupportingFile("go.mod.mustache", "", "go.mod"));
supportingFiles.add(new SupportingFile("routers.mustache", sourceFolder, "routers.go"));
supportingFiles.add(new SupportingFile("logger.mustache", sourceFolder, "logger.go"));
supportingFiles.add(new SupportingFile("impl.mustache",sourceFolder, "impl.go"));
supportingFiles.add(new SupportingFile("helpers.mustache", sourceFolder, "helpers.go"));
supportingFiles.add(new SupportingFile("routers.mustache", supportFolder, "routers.go"));
supportingFiles.add(new SupportingFile("logger.mustache", supportFolder, "logger.go"));
supportingFiles.add(new SupportingFile("impl.mustache", supportFolder, "impl.go"));
supportingFiles.add(new SupportingFile("helpers.mustache", supportFolder, "helpers.go"));
supportingFiles.add(new SupportingFile("api.mustache", sourceFolder, "api.go"));
supportingFiles.add(new SupportingFile("error.mustache", sourceFolder, "error.go"));
supportingFiles.add(new SupportingFile("error.mustache", supportFolder, "error.go"));
supportingFiles.add(new SupportingFile("README.mustache", "", "README.md")
.doNotOverwrite());
}
Expand All @@ -269,6 +272,7 @@ public Map<String, Object> postProcessOperationsWithModels(Map<String, Object> o

boolean addedTimeImport = false;
boolean addedOSImport = false;
boolean addedModelImport = false;
for (CodegenOperation operation : operations) {
for (CodegenParameter param : operation.allParams) {
// import "os" if the operation uses files
Expand All @@ -284,12 +288,49 @@ public Map<String, Object> postProcessOperationsWithModels(Map<String, Object> o
addedTimeImport = true;
}
}

// import "models" directory if needed
if (!addedModelImport && param.isModel && !modelPackage.equals(apiPackage)) {
addedModelImport = true;
objs.put("hasDifferentModelDir", true);
}
}
}

return objs;
}

@Override
public Map<String, Object> postProcessSupportingFileData(Map<String, Object> objs) {
objs = super.postProcessSupportingFileData(objs);

Map<String, Object> apiInfo = (Map<String, Object>) objs.get("apiInfo");
List<HashMap<String, Object>> apiList = (List<HashMap<String, Object>>) apiInfo.get("apis");
for (HashMap<String, Object> api : apiList) {
Map<String, Object> objectMap = (Map<String, Object>) api.get("operations");
List<CodegenOperation> operations = (List<CodegenOperation>) objectMap.get("operation");

boolean addedModelImport = false;
boolean addedSupportImport = false;
for (CodegenOperation operation : operations) {
for (CodegenParameter param : operation.allParams) {
// Always Add the support import
if (!addedSupportImport) {
objs.put("hasSupportDir", true);
addedSupportImport = true;
}

// import "models" directory if needed
if (!addedModelImport && param.isModel && !modelPackage.equals(apiPackage)) {
objs.put("hasDifferentModelDir", true);
addedModelImport = true;
}
}
}
}
return objs;
}

@Override
public String apiPackage() {
return sourceFolder;
Expand Down Expand Up @@ -340,7 +381,10 @@ public String apiFileFolder() {

@Override
public String modelFileFolder() {
return outputFolder + File.separator + apiPackage().replace('.', File.separatorChar);
if (!additionalProperties.containsKey(CodegenConstants.MODEL_PACKAGE)) {
return outputFolder + File.separator + apiPackage().replace('.', File.separatorChar);
}
return outputFolder + File.separator + modelPackage().replace('.', File.separatorChar);
}

public void setSourceFolder(String sourceFolder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"net/http"{{#apiInfo}}{{#apis}}{{#imports}}
"{{import}}"{{/imports}}{{/apis}}{{/apiInfo}}
{{#hasSupportDir}}"{{gitHost}}/{{gitUserId}}/{{gitRepoId}}/support"{{/hasSupportDir}}
{{#hasDifferentModelDir}}"{{gitHost}}/{{gitUserId}}/{{gitRepoId}}/{{modelPackage}}"{{/hasDifferentModelDir}}
)


Expand All @@ -28,5 +30,5 @@ type {{classname}}Servicer interface { {{#operations}}{{#operation}}
{{#isDeprecated}}
// Deprecated
{{/isDeprecated}}
{{operationId}}(context.Context{{#allParams}}, {{dataType}}{{/allParams}}) (ImplResponse, error){{/operation}}{{/operations}}
{{operationId}}(context.Context{{#allParams}}, {{#isArray}}[]{{/isArray}}{{#isArray}}{{#items}}{{#isModel}}{{#hasDifferentModelDir}}{{modelPackage}}.{{/hasDifferentModelDir}}{{/isModel}}{{dataType}}{{/items}}{{/isArray}}{{^isArray}}{{#isModel}}{{#hasDifferentModelDir}}{{modelPackage}}.{{/hasDifferentModelDir}}{{/isModel}}{{dataType}}{{/isArray}}{{/allParams}}) (support.ImplResponse, error){{/operation}}{{/operations}}
}{{/apis}}{{/apiInfo}}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package {{packageName}}

import (
"encoding/json"
"net/http"
"strings"

Expand All @@ -14,29 +13,32 @@ import (
"github.com/go-chi/chi/v5"
{{/chi}}
{{/routers}}

support "{{gitHost}}/{{gitUserId}}/{{gitRepoId}}/support"
{{#hasDifferentModelDir}}"{{gitHost}}/{{gitUserId}}/{{gitRepoId}}/{{modelPackage}}"{{/hasDifferentModelDir}}
)

// {{classname}}Controller binds http requests to an api service and writes the service results to the http response
type {{classname}}Controller struct {
service {{classname}}Servicer
errorHandler ErrorHandler
errorHandler support.ErrorHandler
}

// {{classname}}Option for how the controller is set up.
type {{classname}}Option func(*{{classname}}Controller)

// With{{classname}}ErrorHandler inject ErrorHandler into controller
func With{{classname}}ErrorHandler(h ErrorHandler) {{classname}}Option {
func With{{classname}}ErrorHandler(h support.ErrorHandler) {{classname}}Option {
return func(c *{{classname}}Controller) {
c.errorHandler = h
}
}

// New{{classname}}Controller creates a default api controller
func New{{classname}}Controller(s {{classname}}Servicer, opts ...{{classname}}Option) Router {
func New{{classname}}Controller(s {{classname}}Servicer, opts ...{{classname}}Option) support.Router {
controller := &{{classname}}Controller{
service: s,
errorHandler: DefaultErrorHandler,
errorHandler: support.DefaultErrorHandler,
}

for _, opt := range opts {
Expand All @@ -47,8 +49,8 @@ func New{{classname}}Controller(s {{classname}}Servicer, opts ...{{classname}}Op
}

// Routes returns all the api routes for the {{classname}}Controller
func (c *{{classname}}Controller) Routes() Routes {
return Routes{ {{#operations}}{{#operation}}
func (c *{{classname}}Controller) Routes() support.Routes {
return support.Routes{ {{#operations}}{{#operation}}
{
"{{operationId}}",
strings.ToUpper("{{httpMethod}}"),
Expand All @@ -66,13 +68,13 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{#hasFormParams}}
{{#isMultipart}}
if err := r.ParseMultipartForm(32 << 20); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isMultipart}}
{{^isMultipart}}
if err := r.ParseForm(); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isMultipart}}
Expand All @@ -90,16 +92,16 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{#allParams}}
{{#isPathParam}}
{{#isLong}}
{{paramName}}Param, err := parseInt64Parameter({{#routers}}{{#mux}}params["{{baseName}}"]{{/mux}}{{#chi}}chi.URLParam(r, "{{baseName}}"){{/chi}}{{/routers}}, {{required}})
{{paramName}}Param, err := support.ParseInt64Parameter({{#routers}}{{#mux}}params["{{baseName}}"]{{/mux}}{{#chi}}chi.URLParam(r, "{{baseName}}"){{/chi}}{{/routers}}, {{required}})
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isLong}}
{{#isInteger}}
{{paramName}}Param, err := parseInt32Parameter({{#routers}}{{#mux}}params["{{baseName}}"]{{/mux}}{{#chi}}chi.URLParam(r, "{{baseName}}"){{/chi}}{{/routers}}, {{required}})
{{paramName}}Param, err := support.ParseInt32Parameter({{#routers}}{{#mux}}params["{{baseName}}"]{{/mux}}{{#chi}}chi.URLParam(r, "{{baseName}}"){{/chi}}{{/routers}}, {{required}})
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isInteger}}
Expand All @@ -110,38 +112,38 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{/isPathParam}}
{{#isQueryParam}}
{{#isLong}}
{{paramName}}Param, err := parseInt64Parameter(query.Get("{{baseName}}"), {{required}})
{{paramName}}Param, err := support.ParseInt64Parameter(query.Get("{{baseName}}"), {{required}})
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isLong}}
{{#isInteger}}
{{paramName}}Param, err := parseInt32Parameter(query.Get("{{baseName}}"), {{required}})
{{paramName}}Param, err := support.ParseInt32Parameter(query.Get("{{baseName}}"), {{required}})
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isInteger}}
{{#isBoolean}}
{{paramName}}Param, err := parseBoolParameter(query.Get("{{baseName}}"))
{{paramName}}Param, err := support.ParseBoolParameter(query.Get("{{baseName}}"))
if err != nil {
w.WriteHeader(500)
return
}
{{/isBoolean}}
{{#isArray}}
{{#items.isLong}}
{{paramName}}Param, err := parseInt64ArrayParameter(query.Get("{{baseName}}"), ",", {{required}})
{{paramName}}Param, err := support.ParseInt64ArrayParameter(query.Get("{{baseName}}"), ",", {{required}})
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/items.isLong}}
{{#items.isInteger}}
{{paramName}}Param, err := parseInt32ArrayParameter(query.Get("{{baseName}}"), ",", {{required}})
{{paramName}}Param, err := support.ParseInt32ArrayParameter(query.Get("{{baseName}}"), ",", {{required}})
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/items.isInteger}}
Expand All @@ -163,27 +165,27 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{/isQueryParam}}
{{#isFormParam}}
{{#isFile}}{{#isArray}}
{{paramName}}Param, err := ReadFormFilesToTempFiles(r, "{{baseName}}"){{/isArray}}{{^isArray}}
{{paramName}}Param, err := ReadFormFileToTempFile(r, "{{baseName}}")
{{paramName}}Param, err := support.ReadFormFilesToTempFiles(r, "{{baseName}}"){{/isArray}}{{^isArray}}
{{paramName}}Param, err := support.ReadFormFileToTempFile(r, "{{baseName}}")
{{/isArray}}
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isFile}}
{{#isLong}}{{#isArray}}
{{paramName}}Param, err := parseInt64ArrayParameter(r.FormValue("{{baseName}}"), ",", {{required}}){{/isArray}}{{^isArray}}
{{paramName}}Param, err := parseInt64Parameter(r.FormValue("{{baseName}}"), {{required}}){{/isArray}}
{{paramName}}Param, err := support.ParseInt64ArrayParameter(r.FormValue("{{baseName}}"), ",", {{required}}){{/isArray}}{{^isArray}}
{{paramName}}Param, err := support.ParseInt64Parameter(r.FormValue("{{baseName}}"), {{required}}){{/isArray}}
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isLong}}
{{#isInteger}}{{#isArray}}
{{paramName}}Param, err := parseInt32ArrayParameter(r.FormValue("{{baseName}}"), ",", {{required}}){{/isArray}}{{^isArray}}
{{paramName}}Param, err := parseInt32Parameter(r.FormValue("{{baseName}}"), {{required}}){{/isArray}}
{{paramName}}Param, err := support.ParseInt32ArrayParameter(r.FormValue("{{baseName}}"), ",", {{required}}){{/isArray}}{{^isArray}}
{{paramName}}Param, err := support.ParseInt32Parameter(r.FormValue("{{baseName}}"), {{required}}){{/isArray}}
if err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{/isInteger}}
Expand All @@ -197,19 +199,19 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{paramName}}Param := r.Header.Get("{{baseName}}")
{{/isHeaderParam}}
{{#isBodyParam}}
{{paramName}}Param := {{dataType}}{}
d := json.NewDecoder(r.Body)
{{paramName}}Param := {{#isArray}}[]{{^isPrimitiveType}}{{#hasDifferentModelDir}}{{modelPackage}}.{{/hasDifferentModelDir}}{{/isPrimitiveType}}{{baseType}}{{/isArray}}{{^isArray}}{{#isModel}}{{#hasDifferentModelDir}}{{modelPackage}}.{{/hasDifferentModelDir}}{{/isModel}}{{dataType}}{{/isArray}}{}
d := support.NewJSONDecoder(r.Body)
{{^isAdditionalPropertiesTrue}}
d.DisallowUnknownFields()
{{/isAdditionalPropertiesTrue}}
if err := d.Decode(&{{paramName}}Param); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
c.errorHandler(w, r, &support.ParsingError{Err: err}, nil)
return
}
{{#isArray}}
{{#items.isModel}}
for _, el := range {{paramName}}Param {
if err := Assert{{baseType}}Required(el); err != nil {
if err := {{#hasDifferentModelDir}}{{modelPackage}}.{{/hasDifferentModelDir}}Assert{{baseType}}Required(el); err != nil {
c.errorHandler(w, r, err, nil)
return
}
Expand All @@ -218,7 +220,7 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{/isArray}}
{{^isArray}}
{{#isModel}}
if err := Assert{{baseType}}Required({{paramName}}Param); err != nil {
if err := {{#hasDifferentModelDir}}{{modelPackage}}.{{/hasDifferentModelDir}}Assert{{baseType}}Required({{paramName}}Param); err != nil {
c.errorHandler(w, r, err, nil)
return
}
Expand All @@ -233,6 +235,6 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
return
}
// If no error, encode the body and the result code
EncodeJSONResponse(result.Body, &result.Code,{{#addResponseHeaders}} result.Headers,{{/addResponseHeaders}} w)
support.EncodeJSONResponse(result.Body, &result.Code,{{#addResponseHeaders}} result.Headers,{{/addResponseHeaders}} w)

}{{/operation}}{{/operations}}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{{>partial_header}}
package {{packageName}}
package support

import (
"errors"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
{{>partial_header}}
package {{packageName}}
package support

import (
"encoding/json"
"net/http"
"io"
"reflect"
)

Expand Down Expand Up @@ -58,3 +61,33 @@ func AssertRecurseValueRequired(value reflect.Value, callback func(interface{})
}
return nil
}

// EncodeJSONResponse uses the json encoder to write an interface to the http response with an optional status code
func EncodeJSONResponse(i interface{}, status *int,{{#addResponseHeaders}} headers map[string][]string,{{/addResponseHeaders}} w http.ResponseWriter) error {
{{#addResponseHeaders}}
wHeader := w.Header()
if headers != nil {
for key, values := range headers {
for _, value := range values {
wHeader.Add(key, value)
}
}
}
wHeader.Set("Content-Type", "application/json; charset=UTF-8")
{{/addResponseHeaders}}
{{^addResponseHeaders}}
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
{{/addResponseHeaders}}
if status != nil {
w.WriteHeader(*status)
} else {
w.WriteHeader(http.StatusOK)
}

return json.NewEncoder(w).Encode(i)
}

// NewJSONDecoder creates a jew json decoder to decode an http body.
func NewJSONDecoder(r io.Reader) *json.Decoder {
return json.NewDecoder(r)
}
Loading