ClassifierPlayground.jsx
1 import { useState, useEffect } from "react"; 2 import { 3 Box, Button, Card, Chip, Grid, LinearProgress, MenuItem, styled, 4 TextField, Typography, 5 } from "@mui/material"; 6 import { Category } from "@mui/icons-material"; 7 import Breadcrumb from "app/components/Breadcrumb"; 8 import useAuth from "app/hooks/useAuth"; 9 import api from "app/utils/api"; 10 import { toast } from "react-toastify"; 11 import { Trans, useTranslation } from "react-i18next"; 12 13 const Container = styled("div")(({ theme }) => ({ 14 margin: 10, 15 [theme.breakpoints.down("sm")]: { margin: 16 }, 16 "& .breadcrumb": { marginBottom: 30, [theme.breakpoints.down("sm")]: { marginBottom: 16 } } 17 })); 18 19 const ContentBox = styled("div")(({ theme }) => ({ 20 margin: "30px", 21 [theme.breakpoints.down("sm")]: { margin: "16px" } 22 })); 23 24 const COLORS = ["#42a5f5", "#66bb6a", "#ffa726", "#ef5350", "#ab47bc", "#26c6da", "#5c6bc0", "#ec407a"]; 25 26 export default function ClassifierPlayground() { 27 const { t } = useTranslation(); 28 const auth = useAuth(); 29 const [sequence, setSequence] = useState(""); 30 const [labelsText, setLabelsText] = useState(""); 31 const [selectedModel, setSelectedModel] = useState(""); 32 const [classifiers, setClassifiers] = useState([]); 33 const [defaultModel, setDefaultModel] = useState(""); 34 const [results, setResults] = useState(null); 35 const [loading, setLoading] = useState(false); 36 37 useEffect(() => { 38 document.title = (process.env.REACT_APP_RESTAI_NAME || "RESTai") + " - " + t("classifier.title"); 39 api.get("/tools/classifiers", auth.user.token) 40 .then((data) => { 41 setClassifiers(data.classifiers || []); 42 setDefaultModel(data.default || ""); 43 setSelectedModel(data.default || ""); 44 }) 45 .catch(() => {}); 46 // eslint-disable-next-line react-hooks/exhaustive-deps 47 }, [t]); 48 49 const handleClassify = () => { 50 const labels = labelsText.split(",").map((l) => l.trim()).filter(Boolean); 51 if (!sequence.trim() || labels.length === 0) { 52 toast.warning(t("classifier.enterBoth")); 53 return; 54 } 55 56 setLoading(true); 57 setResults(null); 58 const body = { sequence: sequence.trim(), labels }; 59 if (selectedModel && selectedModel !== defaultModel) body.model = selectedModel; 60 api.post("/tools/classifier", body, auth.user.token) 61 .then((data) => setResults(data)) 62 .catch(() => {}) 63 .finally(() => setLoading(false)); 64 }; 65 66 return ( 67 <Container> 68 <Box className="breadcrumb"> 69 <Breadcrumb routeSegments={[{ name: t("classifier.breadcrumb"), path: "/classifier" }]} /> 70 </Box> 71 72 <ContentBox> 73 <Grid container spacing={3}> 74 {/* Input */} 75 <Grid item xs={12} md={6}> 76 <Card elevation={1} sx={{ p: 3 }}> 77 <Typography variant="subtitle1" fontWeight={600} sx={{ mb: 2, display: "flex", alignItems: "center", gap: 1 }}> 78 <Category fontSize="small" /> {t("classifier.playgroundTitle")} 79 </Typography> 80 81 <TextField 82 fullWidth 83 select 84 label={t("classifier.model")} 85 value={selectedModel} 86 onChange={(e) => setSelectedModel(e.target.value)} 87 sx={{ mb: 2 }} 88 helperText={t("classifier.modelHelp")} 89 > 90 {classifiers.map((c) => ( 91 <MenuItem key={c.id} value={c.id}> 92 {c.name} 93 </MenuItem> 94 ))} 95 </TextField> 96 97 <TextField 98 fullWidth 99 multiline 100 rows={4} 101 label={t("classifier.textLabel")} 102 placeholder={t("classifier.textPlaceholder")} 103 value={sequence} 104 onChange={(e) => setSequence(e.target.value)} 105 sx={{ mb: 2 }} 106 /> 107 108 <TextField 109 fullWidth 110 label={t("classifier.labels")} 111 placeholder={t("classifier.labelsPlaceholder")} 112 helperText={t("classifier.labelsHelp")} 113 value={labelsText} 114 onChange={(e) => setLabelsText(e.target.value)} 115 onKeyDown={(e) => { if (e.key === "Enter") handleClassify(); }} 116 sx={{ mb: 2 }} 117 /> 118 119 <Button 120 variant="contained" 121 onClick={handleClassify} 122 disabled={loading || !sequence.trim() || !labelsText.trim()} 123 fullWidth 124 > 125 {loading ? t("classifier.classifying") : t("classifier.classify")} 126 </Button> 127 </Card> 128 </Grid> 129 130 {/* Results */} 131 <Grid item xs={12} md={6}> 132 <Card elevation={1} sx={{ p: 3 }}> 133 <Typography variant="subtitle1" fontWeight={600} sx={{ mb: 2 }}> 134 {t("classifier.results")} 135 </Typography> 136 137 {loading && <LinearProgress sx={{ mb: 2 }} />} 138 139 {!results && !loading && ( 140 <Box sx={{ textAlign: "center", py: 6, color: "text.secondary" }}> 141 <Category sx={{ fontSize: 48, opacity: 0.2, mb: 1 }} /> 142 <Typography variant="body2">{t("classifier.inputHint")}</Typography> 143 </Box> 144 )} 145 146 {results && ( 147 <Box> 148 <Box sx={{ mb: 3, p: 2, bgcolor: "action.hover", borderRadius: 1 }}> 149 <Typography variant="caption" color="text.secondary">{t("classifier.input")}</Typography> 150 <Typography variant="body2" sx={{ fontStyle: "italic" }}> 151 {results.sequence} 152 </Typography> 153 {results.model && ( 154 <Typography variant="caption" color="text.secondary" display="block" sx={{ mt: 0.5 }}> 155 {t("classifier.modelLabel", { model: results.model })} 156 </Typography> 157 )} 158 </Box> 159 160 <Box sx={{ display: "flex", flexDirection: "column", gap: 1.5 }}> 161 {results.labels.map((label, i) => { 162 const score = results.scores[i]; 163 const pct = (score * 100).toFixed(1); 164 return ( 165 <Box key={label}> 166 <Box sx={{ display: "flex", justifyContent: "space-between", alignItems: "center", mb: 0.5 }}> 167 <Chip 168 label={label} 169 size="small" 170 sx={{ 171 bgcolor: COLORS[i % COLORS.length] + "20", 172 color: COLORS[i % COLORS.length], 173 fontWeight: i === 0 ? 700 : 400, 174 border: i === 0 ? `2px solid ${COLORS[i % COLORS.length]}` : "none", 175 }} 176 /> 177 <Typography variant="body2" fontWeight={i === 0 ? 700 : 400}> 178 {pct}% 179 </Typography> 180 </Box> 181 <LinearProgress 182 variant="determinate" 183 value={score * 100} 184 sx={{ 185 height: 8, 186 borderRadius: 4, 187 bgcolor: "action.hover", 188 "& .MuiLinearProgress-bar": { 189 bgcolor: COLORS[i % COLORS.length], 190 borderRadius: 4, 191 }, 192 }} 193 /> 194 </Box> 195 ); 196 })} 197 </Box> 198 199 {results.labels.length > 0 && ( 200 <Box sx={{ mt: 3, p: 2, bgcolor: COLORS[0] + "10", borderRadius: 1, border: `1px solid ${COLORS[0]}30` }}> 201 <Typography variant="body2" color="text.secondary"> 202 <Trans 203 i18nKey="classifier.bestMatch" 204 values={{ label: results.labels[0], pct: (results.scores[0] * 100).toFixed(1) }} 205 components={{ strong: <strong style={{ color: COLORS[0] }} /> }} 206 /> 207 </Typography> 208 </Box> 209 )} 210 </Box> 211 )} 212 </Card> 213 </Grid> 214 </Grid> 215 </ContentBox> 216 </Container> 217 ); 218 }